diff --git a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala index 40692ec4..f902adca 100644 --- a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala +++ b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala @@ -1,10 +1,12 @@ package com.github.swagger.scala.converter +import java.lang.annotation.Annotation import java.lang.reflect.ParameterizedType import java.util.Iterator import com.fasterxml.jackson.databind.JavaType import com.fasterxml.jackson.databind.`type`.ReferenceType +import com.fasterxml.jackson.module.scala.introspect.{BeanIntrospector, PropertyDescriptor} import com.fasterxml.jackson.module.scala.{DefaultScalaModule, JsonScalaEnumeration} import io.swagger.v3.core.converter._ import io.swagger.v3.core.jackson.ModelResolver @@ -27,6 +29,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private val SetClass = classOf[scala.collection.Set[_]] private val BigDecimalClass = classOf[BigDecimal] private val BigIntClass = classOf[BigInt] + private val ProductClass = classOf[Product] + private val AnyClass = classOf[Any] override def resolve(`type`: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = { val javaType = _mapper.constructType(`type`.getType) @@ -40,6 +44,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte resolve(nextType(baseType, `type`, javaType), context, chain) } else if (!annotatedOverrides.headOption.getOrElse(true)) { resolve(nextType(new AnnotatedTypeForOption(), `type`, javaType), context, chain) + } else if (isCaseClass(cls)) { + caseClassSchema(cls, `type`, context, chain).getOrElse(None.orNull) } else if (chain.hasNext) { val nextResolved = Option(chain.next().resolve(`type`, context, chain)) nextResolved match { @@ -68,14 +74,41 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte } } - private def getRequiredSettings(`type`: AnnotatedType): Seq[Boolean] = `type` match { - case _: AnnotatedTypeForOption => Seq.empty - case _ => { - nullSafeList(`type`.getCtxAnnotations).collect { - case p: Parameter => p.required() - case s: SchemaAnnotation => s.required() - case a: ArraySchema => a.arraySchema().required() + private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext, + chain: Iterator[ModelConverter]): Option[Schema[_]] = { + if (chain.hasNext) { + Option(chain.next().resolve(`type`, context, chain)).map { schema => + val introspector = BeanIntrospector(cls) + introspector.properties.foreach { property => + getPropertyAnnotations(property) match { + case Seq() => { + val propertyClass = getPropertyClass(property) + if (!isOption(propertyClass)) addRequiredItem(schema, property.name) + } + case annotations => { + val required = getRequiredSettings(annotations).headOption + .getOrElse(!isOption(getPropertyClass(property))) + if (required) addRequiredItem(schema, property.name) + } + } + } + schema } + } else { + None + } + } + + private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match { + case _: AnnotatedTypeForOption => Seq.empty + case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations)) + } + + private def getRequiredSettings(annotations: Seq[Annotation]): Seq[Boolean] = { + annotations.collect { + case p: Parameter => p.required() + case s: SchemaAnnotation => s.required() + case a: ArraySchema => a.arraySchema().required() } } @@ -183,8 +216,63 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte else None } + private def getPropertyClass(property: PropertyDescriptor): Class[_] = { + property.param match { + case Some(constructorParameter) => { + val types = constructorParameter.constructor.getParameterTypes + if (constructorParameter.index > types.size) { + AnyClass + } else { + types(constructorParameter.index) + } + } + case _ => property.field match { + case Some(field) => field.getType + case _ => property.setter match { + case Some(setter) if setter.getParameterCount == 1 => { + setter.getParameterTypes()(0) + } + case _ => property.beanSetter match { + case Some(setter) if setter.getParameterCount == 1 => { + setter.getParameterTypes()(0) + } + case _ => AnyClass + } + } + } + } + } + + private def getPropertyAnnotations(property: PropertyDescriptor): Seq[Annotation] = { + property.param match { + case Some(constructorParameter) => { + val types = constructorParameter.constructor.getParameterAnnotations + if (constructorParameter.index > types.size) { + Seq.empty + } else { + types(constructorParameter.index).toSeq + } + } + case _ => property.field match { + case Some(field) => field.getAnnotations.toSeq + case _ => property.setter match { + case Some(setter) if setter.getParameterCount == 1 => { + setter.getAnnotations().toSeq + } + case _ => property.beanSetter match { + case Some(setter) if setter.getParameterCount == 1 => { + setter.getAnnotations().toSeq + } + case _ => Seq.empty + } + } + } + } + } + private def isOption(cls: Class[_]): Boolean = cls == OptionClass private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls) + private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls) private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match { case None => List.empty[T] diff --git a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala index 6273aef1..4bf523e9 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala @@ -347,8 +347,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue val1Field shouldBe a [IntegerSchema] val val2Field = model.value.getProperties.get("val2") val2Field shouldBe a [IntegerSchema] - //TODO try to fix this - //model.value.getRequired().asScala shouldEqual Seq("val1", "val2") + model.value.getRequired().asScala shouldEqual Seq("val1", "val2") } private def findModel(schemas: Map[String, Schema[_]], name: String): Option[Schema[_]] = { @@ -367,7 +366,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue val schemas = converter.readAll(classOf[ModelWStringSeq]).asScala.toMap val model = findModel(schemas, "ModelWStringSeq") model should be(defined) - nullSafeList(model.value.getRequired) shouldEqual Seq() + nullSafeList(model.value.getRequired) shouldBe empty } it should "process Array-Model with forced required Scala Option Seq" in {