From 909a190ac79b57e0dba37c12c78ad9b34a33ae2f Mon Sep 17 00:00:00 2001 From: SamTheisens <1911436+SamTheisens@users.noreply.github.com> Date: Fri, 5 Aug 2022 14:34:27 +0800 Subject: [PATCH 1/5] Save optional primitive types from erasure --- build.sbt | 9 ++--- .../SwaggerScalaModelConverter.scala | 34 +++++++++++++++++-- .../converter/ModelPropertyParserTest.scala | 7 ++-- .../scala/converter/ScalaModelTest.scala | 2 +- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/build.sbt b/build.sbt index 698a5c0e..a05fbc94 100644 --- a/build.sbt +++ b/build.sbt @@ -49,6 +49,7 @@ libraryDependencies ++= Seq( "org.slf4j" % "slf4j-api" % "1.7.36", "io.swagger.core.v3" % "swagger-core-jakarta" % "2.2.2", "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.13.3", + "org.scala-lang" % "scala-reflect" % scalaVersion.value, "org.scalatest" %% "scalatest" % "3.2.11" % Test, "org.slf4j" % "slf4j-simple" % "1.7.36" % Test ) @@ -63,10 +64,10 @@ licenses := Seq(("Apache License 2.0", new URL("http://www.apache.org/licenses/L pomExtra := { pomExtra.value ++ Group( - - github - https://github.com/swagger-api/swagger-scala-module/issues - + + github + https://github.com/swagger-api/swagger-scala-module/issues + fehguy diff --git a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala index aa1046b0..1263b2d3 100644 --- a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala +++ b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala @@ -9,12 +9,14 @@ import com.fasterxml.jackson.module.scala.introspect.{BeanIntrospector, Property import com.fasterxml.jackson.module.scala.{DefaultScalaModule, JsonScalaEnumeration} import io.swagger.v3.core.converter._ import io.swagger.v3.core.jackson.ModelResolver -import io.swagger.v3.core.util.{Json, PrimitiveType} +import io.swagger.v3.core.util.{Json, PrimitiveType, ReflectionUtils} import io.swagger.v3.oas.annotations.Parameter import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema => SchemaAnnotation} import io.swagger.v3.oas.models.media.Schema import org.slf4j.LoggerFactory +import java.util +import scala.reflect.runtime.universe import scala.util.Try import scala.util.control.NonFatal @@ -37,6 +39,30 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private val ProductClass = classOf[Product] private val AnyClass = classOf[Any] + private def erasedOptionalPrimitives(cls: Class[_]) = { + val mirror = universe.runtimeMirror(cls.getClassLoader) + val sym = mirror.staticClass(cls.getName) // obtain class symbol for `c` + val properties = sym.selfType.members + .filterNot(_.isMethod) + .filterNot(_.isClass) + .map(prop => prop.name.toString.trim -> prop.typeSignature).toMap + + properties.view.mapValues { typeSignature => + if (mirror.runtimeClass(typeSignature.typeSymbol.asClass) != OptionClass) { + None + } else { + val typeArg = typeSignature.typeArgs.headOption + typeArg.flatMap { signature => + if (signature.typeSymbol.isClass) { + val clazz = mirror.runtimeClass(signature.typeSymbol.asClass) + Option(PrimitiveType.fromType(clazz)).map(_.createProperty()) + } else None + } + } + }.collect { case (k, Some(v)) => k -> v }.toMap + } + + override def resolve(`type`: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = { val javaType = _mapper.constructType(`type`.getType) val cls = javaType.getRawClass @@ -71,7 +97,9 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte } private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext, - chain: Iterator[ModelConverter]): Option[Schema[_]] = { + chain: util.Iterator[ModelConverter]): Option[Schema[_]] = { + val erasedProperties = erasedOptionalPrimitives(cls) + if (chain.hasNext) { Option(chain.next().resolve(`type`, context, chain)).map { schema => val introspector = BeanIntrospector(cls) @@ -79,6 +107,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte getPropertyAnnotations(property) match { case Seq() => { val propertyClass = getPropertyClass(property) + erasedProperties.get(property.name).foreach(schema.addProperty(property.name, _)) + val optionalFlag = isOption(propertyClass) if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(property.name)) { schema.getRequired.remove(property.name) diff --git a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala index 56c761e0..04fb268b 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala @@ -107,9 +107,10 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue val model = schemas.get("ModelWOptionInt") model should be (defined) model.value.getProperties should not be (null) - val optInt = model.value.getProperties().get("optInt") + val optInt = model.value.getProperties.get("optInt") optInt should not be (null) - optInt shouldBe a [Schema[_]] + optInt shouldBe a [IntegerSchema] + optInt.asInstanceOf[IntegerSchema].getFormat shouldEqual "int32" nullSafeList(model.value.getRequired) shouldBe empty } @@ -134,7 +135,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue model.value.getProperties should not be (null) val optLong = model.value.getProperties().get("optLong") optLong should not be (null) - optLong shouldBe a [Schema[_]] + optLong shouldBe a [IntegerSchema] nullSafeList(model.value.getRequired) shouldBe empty } diff --git a/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala b/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala index 38f9e461..b00fc919 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala @@ -70,7 +70,7 @@ class ScalaModelTest extends AnyFlatSpec with Matchers { val date = userSchema.getProperties().get("date") date shouldBe a [DateTimeSchema] - //date.getDescription should be ("the birthdate") +// date.getDescription should be ("the birthdate") } it should "read a model with vector property" in { From b7fa29df280533c78fe5c77762f4eae888076484 Mon Sep 17 00:00:00 2001 From: SamTheisens <1911436+SamTheisens@users.noreply.github.com> Date: Fri, 5 Aug 2022 16:23:22 +0800 Subject: [PATCH 2/5] Override implementation when optional primitive but leave all other annotation properties in tact branch: --- build.sbt | 6 +++++ .../SwaggerScalaModelConverter.scala | 24 ++++++++++++------- .../converter/ModelPropertyParserTest.scala | 13 ++++++++++ src/test/scala/models/ModelWOptionInt.scala | 4 +++- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/build.sbt b/build.sbt index a05fbc94..c034cbdb 100644 --- a/build.sbt +++ b/build.sbt @@ -53,6 +53,12 @@ libraryDependencies ++= Seq( "org.scalatest" %% "scalatest" % "3.2.11" % Test, "org.slf4j" % "slf4j-simple" % "1.7.36" % Test ) +libraryDependencies ++= { + CrossVersion.partialVersion(Keys.scalaVersion.value) match { + case Some((3, _)) => Seq() + case _ => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value) + } +} homepage := Some(new URL("https://github.com/swagger-akka-http/swagger-scala-module")) diff --git a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala index 1263b2d3..339797f5 100644 --- a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala +++ b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala @@ -9,7 +9,7 @@ import com.fasterxml.jackson.module.scala.introspect.{BeanIntrospector, Property import com.fasterxml.jackson.module.scala.{DefaultScalaModule, JsonScalaEnumeration} import io.swagger.v3.core.converter._ import io.swagger.v3.core.jackson.ModelResolver -import io.swagger.v3.core.util.{Json, PrimitiveType, ReflectionUtils} +import io.swagger.v3.core.util.{Json, PrimitiveType} import io.swagger.v3.oas.annotations.Parameter import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema => SchemaAnnotation} import io.swagger.v3.oas.models.media.Schema @@ -39,23 +39,23 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private val ProductClass = classOf[Product] private val AnyClass = classOf[Any] - private def erasedOptionalPrimitives(cls: Class[_]) = { + private def erasedOptionalPrimitives(cls: Class[_]): Map[String, Schema[_]] = { val mirror = universe.runtimeMirror(cls.getClassLoader) - val sym = mirror.staticClass(cls.getName) // obtain class symbol for `c` + val sym = mirror.staticClass(cls.getName) val properties = sym.selfType.members .filterNot(_.isMethod) .filterNot(_.isClass) .map(prop => prop.name.toString.trim -> prop.typeSignature).toMap - properties.view.mapValues { typeSignature => + properties.mapValues { typeSignature => if (mirror.runtimeClass(typeSignature.typeSymbol.asClass) != OptionClass) { None } else { val typeArg = typeSignature.typeArgs.headOption typeArg.flatMap { signature => if (signature.typeSymbol.isClass) { - val clazz = mirror.runtimeClass(signature.typeSymbol.asClass) - Option(PrimitiveType.fromType(clazz)).map(_.createProperty()) + val runtimeClass = mirror.runtimeClass(signature.typeSymbol.asClass) + Option(PrimitiveType.fromType(runtimeClass)).map(_.createProperty()) } else None } } @@ -104,11 +104,12 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte Option(chain.next().resolve(`type`, context, chain)).map { schema => val introspector = BeanIntrospector(cls) introspector.properties.foreach { property => + erasedProperties.get(property.name).foreach { schemaOverride => + overrideImplementationOfAnnotationSchema(schema, schemaOverride, property.name) + } getPropertyAnnotations(property) match { case Seq() => { val propertyClass = getPropertyClass(property) - erasedProperties.get(property.name).foreach(schema.addProperty(property.name, _)) - val optionalFlag = isOption(propertyClass) if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(property.name)) { schema.getRequired.remove(property.name) @@ -130,6 +131,13 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte } } + private def overrideImplementationOfAnnotationSchema(schema: Schema[_], schemaOverride: Schema[_], propertyName: String) = { + val prop = schema.getProperties.get(propertyName) + val propAsString = objectMapper.writeValueAsString(prop) + val parsedSchema = objectMapper.readValue(propAsString, schemaOverride.getClass) + schema.addProperty(propertyName, parsedSchema) + } + private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match { case _: AnnotatedTypeForOption => Seq.empty case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations)) diff --git a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala index 04fb268b..11ad02de 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala @@ -124,9 +124,22 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue optInt should not be (null) optInt shouldBe a [IntegerSchema] optInt.asInstanceOf[IntegerSchema].getFormat shouldEqual "int32" + optInt.getDescription shouldBe "This is an optional int" nullSafeList(model.value.getRequired) shouldBe empty } + it should "allow annotation to override required with Scala Option Int" in { + val converter = ModelConverters.getInstance() + val schemas = converter.readAll(classOf[ModelWOptionIntSchemaOverrideForRequired]).asScala.toMap + val model = schemas.get("ModelWOptionIntSchemaOverrideForRequired") + model should be(defined) + model.value.getProperties should not be (null) + val optInt = model.value.getProperties().get("optInt") + optInt should not be (null) + optInt shouldBe an [IntegerSchema] + nullSafeList(model.value.getRequired) shouldEqual Seq("optInt") + } + it should "process Model with Scala Option Long" in { val converter = ModelConverters.getInstance() val schemas = converter.readAll(classOf[ModelWOptionLong]).asScala.toMap diff --git a/src/test/scala/models/ModelWOptionInt.scala b/src/test/scala/models/ModelWOptionInt.scala index 6c05dd26..58702c4f 100644 --- a/src/test/scala/models/ModelWOptionInt.scala +++ b/src/test/scala/models/ModelWOptionInt.scala @@ -4,4 +4,6 @@ import io.swagger.v3.oas.annotations.media.Schema case class ModelWOptionInt(optInt: Option[Int]) -case class ModelWOptionIntSchemaOverride(@Schema(implementation = classOf[Int]) optInt: Option[Int]) +case class ModelWOptionIntSchemaOverride(@Schema(description = "This is an optional int") optInt: Option[Int]) + +case class ModelWOptionIntSchemaOverrideForRequired(@Schema(required = true) optInt: Option[Int]) From bd0034ca89ec1be130da4b870854519bb37b1777 Mon Sep 17 00:00:00 2001 From: SamTheisens <1911436+SamTheisens@users.noreply.github.com> Date: Sat, 6 Aug 2022 11:34:30 +0800 Subject: [PATCH 3/5] Ensure that project compiles on Scala 3 by introducing a Scala 3 stub implementation of `erasedOptionalPrimitives`. Tests still fail --- .github/workflows/ci.yml | 6 +++- build.sbt | 4 +-- .../scala/converter/ErasureHelper.scala | 32 +++++++++++++++++++ .../scala/converter/ErasureHelper.scala | 11 +++++++ .../SwaggerScalaModelConverter.scala | 27 +--------------- 5 files changed, 51 insertions(+), 29 deletions(-) create mode 100644 src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala create mode 100644 src/main/scala-3/com/github/swagger/scala/converter/ErasureHelper.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e746758e..cb2f3d27 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,9 +59,13 @@ jobs: run: sbt '++${{ matrix.scala }}' coverage test coverageReport - name: Scala build - if: '!startsWith(matrix.scala, ''2.13'')' + if: '!startsWith(matrix.scala, ''2.13'') && !startsWith(matrix.scala, ''3.0'')' run: sbt '++${{ matrix.scala }}' test + - name: Scala compile + if: startsWith(matrix.scala, '3.0') + run: sbt '++${{ matrix.scala }}' compile + - name: Publish to Codecov.io if: startsWith(matrix.scala, '2.13') uses: codecov/codecov-action@v2 diff --git a/build.sbt b/build.sbt index c034cbdb..66d4c54f 100644 --- a/build.sbt +++ b/build.sbt @@ -49,7 +49,6 @@ libraryDependencies ++= Seq( "org.slf4j" % "slf4j-api" % "1.7.36", "io.swagger.core.v3" % "swagger-core-jakarta" % "2.2.2", "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.13.3", - "org.scala-lang" % "scala-reflect" % scalaVersion.value, "org.scalatest" %% "scalatest" % "3.2.11" % Test, "org.slf4j" % "slf4j-simple" % "1.7.36" % Test ) @@ -91,7 +90,8 @@ pomExtra := { ThisBuild / githubWorkflowBuild := Seq( WorkflowStep.Sbt(List("coverage", "test", "coverageReport"), name = Some("Scala 2.13 build"), cond = Some("startsWith(matrix.scala, '2.13')")), - WorkflowStep.Sbt(List("test"), name = Some("Scala build"), cond = Some("!startsWith(matrix.scala, '2.13')")), + WorkflowStep.Sbt(List("test"), name = Some("Scala build"), cond = Some("!startsWith(matrix.scala, '2.13') && !startsWith(matrix.scala, '3.0')")), + WorkflowStep.Sbt(List("compile"), name = Some("Scala compile"), cond = Some("startsWith(matrix.scala, '3.0')")), ) ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec(Zulu, "8")) diff --git a/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala b/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala new file mode 100644 index 00000000..f8e2c62d --- /dev/null +++ b/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala @@ -0,0 +1,32 @@ +package com.github.swagger.scala.converter + +import io.swagger.v3.core.util.PrimitiveType +import io.swagger.v3.oas.models.media.Schema + +object ErasureHelper { + + def erasedOptionalPrimitives(cls: Class[_]): Map[String, Schema[_]] = { + import scala.reflect.runtime.universe + val mirror = universe.runtimeMirror(cls.getClassLoader) + val sym = mirror.staticClass(cls.getName) + val properties = sym.selfType.members + .filterNot(_.isMethod) + .filterNot(_.isClass) + .map(prop => prop.name.toString.trim -> prop.typeSignature).toMap + + properties.mapValues { typeSignature => + if (typeSignature.typeSymbol.isClass && mirror.runtimeClass(typeSignature.typeSymbol.asClass) != classOf[scala.Option[_]]) { + None + } else { + val typeArg = typeSignature.typeArgs.headOption + typeArg.flatMap { signature => + if (signature.typeSymbol.isClass) { + val runtimeClass = mirror.runtimeClass(signature.typeSymbol.asClass) + Option(PrimitiveType.fromType(runtimeClass)).map(_.createProperty()) + } else None + } + } + }.collect { case (k, Some(v)) => k -> v }.toMap + } + +} diff --git a/src/main/scala-3/com/github/swagger/scala/converter/ErasureHelper.scala b/src/main/scala-3/com/github/swagger/scala/converter/ErasureHelper.scala new file mode 100644 index 00000000..bb0d2bbb --- /dev/null +++ b/src/main/scala-3/com/github/swagger/scala/converter/ErasureHelper.scala @@ -0,0 +1,11 @@ +package com.github.swagger.scala.converter + +import io.swagger.v3.oas.models.media.Schema + + +object ErasureHelper { + + def erasedOptionalPrimitives(cls: Class[_]): Map[String, Class[_]] = Map.empty[String, Class[_]] + +} + diff --git a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala index 339797f5..870eff83 100644 --- a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala +++ b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala @@ -16,7 +16,6 @@ import io.swagger.v3.oas.models.media.Schema import org.slf4j.LoggerFactory import java.util -import scala.reflect.runtime.universe import scala.util.Try import scala.util.control.NonFatal @@ -39,30 +38,6 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private val ProductClass = classOf[Product] private val AnyClass = classOf[Any] - private def erasedOptionalPrimitives(cls: Class[_]): Map[String, Schema[_]] = { - val mirror = universe.runtimeMirror(cls.getClassLoader) - val sym = mirror.staticClass(cls.getName) - val properties = sym.selfType.members - .filterNot(_.isMethod) - .filterNot(_.isClass) - .map(prop => prop.name.toString.trim -> prop.typeSignature).toMap - - properties.mapValues { typeSignature => - if (mirror.runtimeClass(typeSignature.typeSymbol.asClass) != OptionClass) { - None - } else { - val typeArg = typeSignature.typeArgs.headOption - typeArg.flatMap { signature => - if (signature.typeSymbol.isClass) { - val runtimeClass = mirror.runtimeClass(signature.typeSymbol.asClass) - Option(PrimitiveType.fromType(runtimeClass)).map(_.createProperty()) - } else None - } - } - }.collect { case (k, Some(v)) => k -> v }.toMap - } - - override def resolve(`type`: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = { val javaType = _mapper.constructType(`type`.getType) val cls = javaType.getRawClass @@ -98,7 +73,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext, chain: util.Iterator[ModelConverter]): Option[Schema[_]] = { - val erasedProperties = erasedOptionalPrimitives(cls) + val erasedProperties = ErasureHelper.erasedOptionalPrimitives(cls) if (chain.hasNext) { Option(chain.next().resolve(`type`, context, chain)).map { schema => From 2cd53ca799a516b76d0485f549b997926e5ea773 Mon Sep 17 00:00:00 2001 From: SamTheisens <1911436+SamTheisens@users.noreply.github.com> Date: Sat, 6 Aug 2022 12:34:12 +0800 Subject: [PATCH 4/5] Prepare for supporting collections branch: --- .../scala/converter/ErasureHelper.scala | 25 ++++++------------- .../SwaggerScalaModelConverter.scala | 25 ++++++++++++------- .../converter/ModelPropertyParserTest.scala | 17 +++++++++++++ src/test/scala/models/ModelWSeqInt.scala | 3 +++ 4 files changed, 44 insertions(+), 26 deletions(-) create mode 100644 src/test/scala/models/ModelWSeqInt.scala diff --git a/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala b/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala index f8e2c62d..2e76f9dd 100644 --- a/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala +++ b/src/main/scala-2/com/github/swagger/scala/converter/ErasureHelper.scala @@ -1,32 +1,23 @@ package com.github.swagger.scala.converter -import io.swagger.v3.core.util.PrimitiveType -import io.swagger.v3.oas.models.media.Schema - object ErasureHelper { - def erasedOptionalPrimitives(cls: Class[_]): Map[String, Schema[_]] = { + def erasedOptionalPrimitives(cls: Class[_]): Map[String, Class[_]] = { import scala.reflect.runtime.universe val mirror = universe.runtimeMirror(cls.getClassLoader) val sym = mirror.staticClass(cls.getName) val properties = sym.selfType.members .filterNot(_.isMethod) .filterNot(_.isClass) - .map(prop => prop.name.toString.trim -> prop.typeSignature).toMap - properties.mapValues { typeSignature => - if (typeSignature.typeSymbol.isClass && mirror.runtimeClass(typeSignature.typeSymbol.asClass) != classOf[scala.Option[_]]) { - None - } else { - val typeArg = typeSignature.typeArgs.headOption - typeArg.flatMap { signature => - if (signature.typeSymbol.isClass) { - val runtimeClass = mirror.runtimeClass(signature.typeSymbol.asClass) - Option(PrimitiveType.fromType(runtimeClass)).map(_.createProperty()) - } else None - } + properties.flatMap { prop => + val maybeClass: Option[Class[_]] = prop.typeSignature.typeArgs.headOption.flatMap { signature => + if (signature.typeSymbol.isClass) { + Option(mirror.runtimeClass(signature.typeSymbol.asClass)) + } else None } - }.collect { case (k, Some(v)) => k -> v }.toMap + maybeClass.map(prop.name.toString.trim -> _) + }.toMap } } diff --git a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala index 870eff83..e3b454c0 100644 --- a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala +++ b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala @@ -79,12 +79,14 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte Option(chain.next().resolve(`type`, context, chain)).map { schema => val introspector = BeanIntrospector(cls) introspector.properties.foreach { property => - erasedProperties.get(property.name).foreach { schemaOverride => - overrideImplementationOfAnnotationSchema(schema, schemaOverride, property.name) + val propertyClass = getPropertyClass(property) + erasedProperties.get(property.name).foreach { erasedType => + if (isOption(propertyClass)) { + overrideImplementationOfAnnotationSchema(schema, erasedType, property.name) + } } getPropertyAnnotations(property) match { case Seq() => { - val propertyClass = getPropertyClass(property) val optionalFlag = isOption(propertyClass) if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(property.name)) { schema.getRequired.remove(property.name) @@ -94,7 +96,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte } case annotations => { val required = getRequiredSettings(annotations).headOption - .getOrElse(!isOption(getPropertyClass(property))) + .getOrElse(!isOption(propertyClass)) if (required) addRequiredItem(schema, property.name) } } @@ -106,11 +108,16 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte } } - private def overrideImplementationOfAnnotationSchema(schema: Schema[_], schemaOverride: Schema[_], propertyName: String) = { - val prop = schema.getProperties.get(propertyName) - val propAsString = objectMapper.writeValueAsString(prop) - val parsedSchema = objectMapper.readValue(propAsString, schemaOverride.getClass) - schema.addProperty(propertyName, parsedSchema) + private def overrideImplementationOfAnnotationSchema(schema: Schema[_], erasedType: Class[_], propertyName: String) = { + val primitiveType = PrimitiveType.fromType(erasedType) + if (primitiveType == null) { + schema + } else { + val prop = schema.getProperties.get(propertyName) + val propAsString = objectMapper.writeValueAsString(prop) + val parsedSchema = objectMapper.readValue(propAsString, primitiveType.createProperty().getClass) + schema.addProperty(propertyName, parsedSchema) + } } private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match { diff --git a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala index 11ad02de..acc1ed67 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala @@ -338,6 +338,23 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue nullSafeList(arraySchema.getRequired()) shouldBe empty } + it should "process Model with Scala Seq Int" in { + val converter = ModelConverters.getInstance() + val schemas = converter.readAll(classOf[ModelWSeqInt]).asScala.toMap + val model = findModel(schemas, "ModelWSeqInt") + model should be(defined) + model.value.getProperties should not be (null) + + val stringsField = model.value.getProperties.get("ints") + + stringsField shouldBe a[ArraySchema] + val arraySchema = stringsField.asInstanceOf[ArraySchema] + arraySchema.getUniqueItems() shouldBe (null) + arraySchema.getItems shouldBe a[ObjectSchema] + nullSafeMap(arraySchema.getProperties()) shouldBe empty + nullSafeList(arraySchema.getRequired()) shouldBe empty + } + it should "process Model with Scala Set" in { val converter = ModelConverters.getInstance() val schemas = converter.readAll(classOf[ModelWSetString]).asScala.toMap diff --git a/src/test/scala/models/ModelWSeqInt.scala b/src/test/scala/models/ModelWSeqInt.scala new file mode 100644 index 00000000..44396cb3 --- /dev/null +++ b/src/test/scala/models/ModelWSeqInt.scala @@ -0,0 +1,3 @@ +package models + +case class ModelWSeqInt(ints: Seq[Int]) \ No newline at end of file From c5137b516b260eafc9c1fc633ca85e7f67f4307f Mon Sep 17 00:00:00 2001 From: SamTheisens <1911436+SamTheisens@users.noreply.github.com> Date: Sat, 6 Aug 2022 15:08:36 +0800 Subject: [PATCH 5/5] Add support for collections --- .../SwaggerScalaModelConverter.scala | 50 +++++++++++++------ .../converter/ModelPropertyParserTest.scala | 22 +++++++- .../scala/converter/ScalaModelTest.scala | 4 +- src/test/scala/models/ModelWSeqInt.scala | 6 ++- 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala index e3b454c0..987a0621 100644 --- a/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala +++ b/src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala @@ -32,6 +32,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private val EnumClass = classOf[scala.Enumeration] private val OptionClass = classOf[scala.Option[_]] private val IterableClass = classOf[scala.collection.Iterable[_]] + private val MapClass = classOf[Map[_, _]] private val SetClass = classOf[scala.collection.Set[_]] private val BigDecimalClass = classOf[BigDecimal] private val BigIntClass = classOf[BigInt] @@ -79,24 +80,30 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte Option(chain.next().resolve(`type`, context, chain)).map { schema => val introspector = BeanIntrospector(cls) introspector.properties.foreach { property => + val propertyClass = getPropertyClass(property) + val isOptional = isOption(propertyClass) + erasedProperties.get(property.name).foreach { erasedType => - if (isOption(propertyClass)) { - overrideImplementationOfAnnotationSchema(schema, erasedType, property.name) + val primitiveType = PrimitiveType.fromType(erasedType) + if (primitiveType != null && isOptional) { + updateTypeOnSchema(schema, primitiveType, property.name) + } + if (primitiveType != null && isIterable(propertyClass) && !isMap(propertyClass)) { + updateTypeOnItemsSchema(schema, primitiveType, property.name) } } getPropertyAnnotations(property) match { case Seq() => { - val optionalFlag = isOption(propertyClass) - if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(property.name)) { + if (isOptional && schema.getRequired != null && schema.getRequired.contains(property.name)) { schema.getRequired.remove(property.name) - } else if (!optionalFlag) { + } else if (!isOptional) { addRequiredItem(schema, property.name) } } case annotations => { val required = getRequiredSettings(annotations).headOption - .getOrElse(!isOption(propertyClass)) + .getOrElse(!isOptional) if (required) addRequiredItem(schema, property.name) } } @@ -108,16 +115,26 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte } } - private def overrideImplementationOfAnnotationSchema(schema: Schema[_], erasedType: Class[_], propertyName: String) = { - val primitiveType = PrimitiveType.fromType(erasedType) - if (primitiveType == null) { - schema - } else { - val prop = schema.getProperties.get(propertyName) - val propAsString = objectMapper.writeValueAsString(prop) - val parsedSchema = objectMapper.readValue(propAsString, primitiveType.createProperty().getClass) - schema.addProperty(propertyName, parsedSchema) - } + private def updateTypeOnSchema(schema: Schema[_], primitiveType: PrimitiveType, propertyName: String) = { + val property = schema.getProperties.get(propertyName) + val updatedSchema = correctSchema(property, primitiveType) + schema.addProperty(propertyName, updatedSchema) + } + + private def updateTypeOnItemsSchema(schema: Schema[_], primitiveType: PrimitiveType, propertyName: String) = { + val property = schema.getProperties.get(propertyName) + val updatedSchema = correctSchema(property.getItems, primitiveType) + property.setItems(updatedSchema) + schema.addProperty(propertyName, property) + } + + private def correctSchema(itemSchema: Schema[_], primitiveType: PrimitiveType) = { + val primitiveProperty = primitiveType.createProperty() + val propAsString = objectMapper.writeValueAsString(itemSchema) + val correctedSchema = objectMapper.readValue(propAsString, primitiveProperty.getClass) + correctedSchema.setType(primitiveProperty.getType) + correctedSchema.setFormat(primitiveProperty.getFormat) + correctedSchema } private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match { @@ -296,6 +313,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte private def isOption(cls: Class[_]): Boolean = cls == OptionClass private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls) + private def isMap(cls: Class[_]): Boolean = MapClass.isAssignableFrom(cls) private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls) private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match { diff --git a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala index acc1ed67..1fdf5d6e 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala @@ -350,7 +350,27 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue stringsField shouldBe a[ArraySchema] val arraySchema = stringsField.asInstanceOf[ArraySchema] arraySchema.getUniqueItems() shouldBe (null) - arraySchema.getItems shouldBe a[ObjectSchema] + arraySchema.getItems shouldBe a[IntegerSchema] + nullSafeMap(arraySchema.getProperties()) shouldBe empty + nullSafeList(arraySchema.getRequired()) shouldBe empty + } + + it should "process Model with Scala Seq Int (annotated)" in { + val converter = ModelConverters.getInstance() + val schemas = converter.readAll(classOf[ModelWSeqIntAnnotated]).asScala.toMap + val model = findModel(schemas, "ModelWSeqIntAnnotated") + model should be(defined) + model.value.getProperties should not be (null) + + val stringsField = model.value.getProperties.get("ints") + + stringsField shouldBe a[ArraySchema] + val arraySchema = stringsField.asInstanceOf[ArraySchema] + arraySchema.getUniqueItems() shouldBe (null) + + + arraySchema.getItems shouldBe a[IntegerSchema] + arraySchema.getItems.getDescription shouldBe "These are ints" nullSafeMap(arraySchema.getProperties()) shouldBe empty nullSafeList(arraySchema.getRequired()) shouldBe empty } diff --git a/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala b/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala index b00fc919..08ab791d 100644 --- a/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala +++ b/src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala @@ -85,7 +85,7 @@ class ScalaModelTest extends AnyFlatSpec with Matchers { val model = schemas("ModelWithIntVector") val prop = model.getProperties().get("ints") prop shouldBe a [ArraySchema] - prop.asInstanceOf[ArraySchema].getItems.getType should be ("object") + prop.asInstanceOf[ArraySchema].getItems.getType should be ("integer") } it should "read a model with vector of booleans" in { @@ -93,7 +93,7 @@ class ScalaModelTest extends AnyFlatSpec with Matchers { val model = schemas("ModelWithBooleanVector") val prop = model.getProperties().get("bools") prop shouldBe a [ArraySchema] - prop.asInstanceOf[ArraySchema].getItems.getType should be ("object") + prop.asInstanceOf[ArraySchema].getItems.getType should be ("boolean") } } diff --git a/src/test/scala/models/ModelWSeqInt.scala b/src/test/scala/models/ModelWSeqInt.scala index 44396cb3..68ed1a70 100644 --- a/src/test/scala/models/ModelWSeqInt.scala +++ b/src/test/scala/models/ModelWSeqInt.scala @@ -1,3 +1,7 @@ package models -case class ModelWSeqInt(ints: Seq[Int]) \ No newline at end of file +import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema} + +case class ModelWSeqInt(ints: Seq[Int]) + +case class ModelWSeqIntAnnotated(@ArraySchema(arraySchema = new Schema(required = false), schema = new Schema(description = "These are ints")) ints: Seq[Int])