Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.github.swagger.scala.converter

import java.lang.annotation.Annotation
import java.lang.reflect.ParameterizedType
import java.util.Iterator

import com.fasterxml.jackson.databind.JavaType
import com.fasterxml.jackson.databind.`type`.ReferenceType
import com.fasterxml.jackson.module.scala.introspect.{BeanIntrospector, PropertyDescriptor}
import com.fasterxml.jackson.module.scala.{DefaultScalaModule, JsonScalaEnumeration}
import io.swagger.v3.core.converter._
import io.swagger.v3.core.jackson.ModelResolver
Expand All @@ -27,6 +29,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
private val SetClass = classOf[scala.collection.Set[_]]
private val BigDecimalClass = classOf[BigDecimal]
private val BigIntClass = classOf[BigInt]
private val ProductClass = classOf[Product]
private val AnyClass = classOf[Any]

override def resolve(`type`: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = {
val javaType = _mapper.constructType(`type`.getType)
Expand All @@ -40,6 +44,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
resolve(nextType(baseType, `type`, javaType), context, chain)
} else if (!annotatedOverrides.headOption.getOrElse(true)) {
resolve(nextType(new AnnotatedTypeForOption(), `type`, javaType), context, chain)
} else if (isCaseClass(cls)) {
caseClassSchema(cls, `type`, context, chain).getOrElse(None.orNull)
} else if (chain.hasNext) {
val nextResolved = Option(chain.next().resolve(`type`, context, chain))
nextResolved match {
Expand Down Expand Up @@ -68,14 +74,41 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
}
}

private def getRequiredSettings(`type`: AnnotatedType): Seq[Boolean] = `type` match {
case _: AnnotatedTypeForOption => Seq.empty
case _ => {
nullSafeList(`type`.getCtxAnnotations).collect {
case p: Parameter => p.required()
case s: SchemaAnnotation => s.required()
case a: ArraySchema => a.arraySchema().required()
private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext,
chain: Iterator[ModelConverter]): Option[Schema[_]] = {
if (chain.hasNext) {
Option(chain.next().resolve(`type`, context, chain)).map { schema =>
val introspector = BeanIntrospector(cls)
introspector.properties.foreach { property =>
getPropertyAnnotations(property) match {
case Seq() => {
val propertyClass = getPropertyClass(property)
if (!isOption(propertyClass)) addRequiredItem(schema, property.name)
}
case annotations => {
val required = getRequiredSettings(annotations).headOption
.getOrElse(!isOption(getPropertyClass(property)))
if (required) addRequiredItem(schema, property.name)
}
}
}
schema
}
} else {
None
}
}

private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match {
case _: AnnotatedTypeForOption => Seq.empty
case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations))
}

private def getRequiredSettings(annotations: Seq[Annotation]): Seq[Boolean] = {
annotations.collect {
case p: Parameter => p.required()
case s: SchemaAnnotation => s.required()
case a: ArraySchema => a.arraySchema().required()
}
}

Expand Down Expand Up @@ -183,8 +216,63 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
else None
}

private def getPropertyClass(property: PropertyDescriptor): Class[_] = {
property.param match {
case Some(constructorParameter) => {
val types = constructorParameter.constructor.getParameterTypes
if (constructorParameter.index > types.size) {
AnyClass
} else {
types(constructorParameter.index)
}
}
case _ => property.field match {
case Some(field) => field.getType
case _ => property.setter match {
case Some(setter) if setter.getParameterCount == 1 => {
setter.getParameterTypes()(0)
}
case _ => property.beanSetter match {
case Some(setter) if setter.getParameterCount == 1 => {
setter.getParameterTypes()(0)
}
case _ => AnyClass
}
}
}
}
}

private def getPropertyAnnotations(property: PropertyDescriptor): Seq[Annotation] = {
property.param match {
case Some(constructorParameter) => {
val types = constructorParameter.constructor.getParameterAnnotations
if (constructorParameter.index > types.size) {
Seq.empty
} else {
types(constructorParameter.index).toSeq
}
}
case _ => property.field match {
case Some(field) => field.getAnnotations.toSeq
case _ => property.setter match {
case Some(setter) if setter.getParameterCount == 1 => {
setter.getAnnotations().toSeq
}
case _ => property.beanSetter match {
case Some(setter) if setter.getParameterCount == 1 => {
setter.getAnnotations().toSeq
}
case _ => Seq.empty
}
}
}
}
}

private def isOption(cls: Class[_]): Boolean = cls == OptionClass
private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls)
private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls)

private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match {
case None => List.empty[T]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
val1Field shouldBe a [IntegerSchema]
val val2Field = model.value.getProperties.get("val2")
val2Field shouldBe a [IntegerSchema]
//TODO try to fix this
//model.value.getRequired().asScala shouldEqual Seq("val1", "val2")
model.value.getRequired().asScala shouldEqual Seq("val1", "val2")
}

private def findModel(schemas: Map[String, Schema[_]], name: String): Option[Schema[_]] = {
Expand All @@ -367,7 +366,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
val schemas = converter.readAll(classOf[ModelWStringSeq]).asScala.toMap
val model = findModel(schemas, "ModelWStringSeq")
model should be(defined)
nullSafeList(model.value.getRequired) shouldEqual Seq()
nullSafeList(model.value.getRequired) shouldBe empty
}

it should "process Array-Model with forced required Scala Option Seq" in {
Expand Down