Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,13 @@ jobs:
run: sbt '++${{ matrix.scala }}' coverage test coverageReport

- name: Scala build
if: '!startsWith(matrix.scala, ''2.13'')'
if: '!startsWith(matrix.scala, ''2.13'') && !startsWith(matrix.scala, ''3.0'')'
run: sbt '++${{ matrix.scala }}' test

- name: Scala compile
if: startsWith(matrix.scala, '3.0')
run: sbt '++${{ matrix.scala }}' compile

- name: Publish to Codecov.io
if: startsWith(matrix.scala, '2.13')
uses: codecov/codecov-action@v2
Expand Down
17 changes: 12 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "3.2.11" % Test,
"org.slf4j" % "slf4j-simple" % "1.7.36" % Test
)
libraryDependencies ++= {
CrossVersion.partialVersion(Keys.scalaVersion.value) match {
case Some((3, _)) => Seq()
case _ => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value)
}
}

homepage := Some(new URL("https://github.com/swagger-akka-http/swagger-scala-module"))

Expand All @@ -63,10 +69,10 @@ licenses := Seq(("Apache License 2.0", new URL("http://www.apache.org/licenses/L

pomExtra := {
pomExtra.value ++ Group(
<issueManagement>
<system>github</system>
<url>https://github.com/swagger-api/swagger-scala-module/issues</url>
</issueManagement>
<issueManagement>
<system>github</system>
<url>https://github.com/swagger-api/swagger-scala-module/issues</url>
</issueManagement>
<developers>
<developer>
<id>fehguy</id>
Expand All @@ -84,7 +90,8 @@ pomExtra := {

ThisBuild / githubWorkflowBuild := Seq(
WorkflowStep.Sbt(List("coverage", "test", "coverageReport"), name = Some("Scala 2.13 build"), cond = Some("startsWith(matrix.scala, '2.13')")),
WorkflowStep.Sbt(List("test"), name = Some("Scala build"), cond = Some("!startsWith(matrix.scala, '2.13')")),
WorkflowStep.Sbt(List("test"), name = Some("Scala build"), cond = Some("!startsWith(matrix.scala, '2.13') && !startsWith(matrix.scala, '3.0')")),
WorkflowStep.Sbt(List("compile"), name = Some("Scala compile"), cond = Some("startsWith(matrix.scala, '3.0')")),
)

ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec(Zulu, "8"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.github.swagger.scala.converter

object ErasureHelper {

def erasedOptionalPrimitives(cls: Class[_]): Map[String, Class[_]] = {
import scala.reflect.runtime.universe
val mirror = universe.runtimeMirror(cls.getClassLoader)
val sym = mirror.staticClass(cls.getName)
val properties = sym.selfType.members
.filterNot(_.isMethod)
.filterNot(_.isClass)

properties.flatMap { prop =>
val maybeClass: Option[Class[_]] = prop.typeSignature.typeArgs.headOption.flatMap { signature =>
if (signature.typeSymbol.isClass) {
Option(mirror.runtimeClass(signature.typeSymbol.asClass))
} else None
}
maybeClass.map(prop.name.toString.trim -> _)
}.toMap
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.github.swagger.scala.converter

import io.swagger.v3.oas.models.media.Schema


object ErasureHelper {

def erasedOptionalPrimitives(cls: Class[_]): Map[String, Class[_]] = Map.empty[String, Class[_]]

}

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema => SchemaAnnotat
import io.swagger.v3.oas.models.media.Schema
import org.slf4j.LoggerFactory

import java.util
import scala.util.Try
import scala.util.control.NonFatal

Expand All @@ -31,6 +32,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
private val EnumClass = classOf[scala.Enumeration]
private val OptionClass = classOf[scala.Option[_]]
private val IterableClass = classOf[scala.collection.Iterable[_]]
private val MapClass = classOf[Map[_, _]]
private val SetClass = classOf[scala.collection.Set[_]]
private val BigDecimalClass = classOf[BigDecimal]
private val BigIntClass = classOf[BigInt]
Expand Down Expand Up @@ -71,24 +73,37 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
}

private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext,
chain: Iterator[ModelConverter]): Option[Schema[_]] = {
chain: util.Iterator[ModelConverter]): Option[Schema[_]] = {
val erasedProperties = ErasureHelper.erasedOptionalPrimitives(cls)

if (chain.hasNext) {
Option(chain.next().resolve(`type`, context, chain)).map { schema =>
val introspector = BeanIntrospector(cls)
introspector.properties.foreach { property =>

val propertyClass = getPropertyClass(property)
val isOptional = isOption(propertyClass)

erasedProperties.get(property.name).foreach { erasedType =>
val primitiveType = PrimitiveType.fromType(erasedType)
if (primitiveType != null && isOptional) {
updateTypeOnSchema(schema, primitiveType, property.name)
}
if (primitiveType != null && isIterable(propertyClass) && !isMap(propertyClass)) {
updateTypeOnItemsSchema(schema, primitiveType, property.name)
}
}
getPropertyAnnotations(property) match {
case Seq() => {
val propertyClass = getPropertyClass(property)
val optionalFlag = isOption(propertyClass)
if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(property.name)) {
if (isOptional && schema.getRequired != null && schema.getRequired.contains(property.name)) {
schema.getRequired.remove(property.name)
} else if (!optionalFlag) {
} else if (!isOptional) {
addRequiredItem(schema, property.name)
}
}
case annotations => {
val required = getRequiredSettings(annotations).headOption
.getOrElse(!isOption(getPropertyClass(property)))
.getOrElse(!isOptional)
if (required) addRequiredItem(schema, property.name)
}
}
Expand All @@ -100,6 +115,28 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
}
}

private def updateTypeOnSchema(schema: Schema[_], primitiveType: PrimitiveType, propertyName: String) = {
val property = schema.getProperties.get(propertyName)
val updatedSchema = correctSchema(property, primitiveType)
schema.addProperty(propertyName, updatedSchema)
}

private def updateTypeOnItemsSchema(schema: Schema[_], primitiveType: PrimitiveType, propertyName: String) = {
val property = schema.getProperties.get(propertyName)
val updatedSchema = correctSchema(property.getItems, primitiveType)
property.setItems(updatedSchema)
schema.addProperty(propertyName, property)
}

private def correctSchema(itemSchema: Schema[_], primitiveType: PrimitiveType) = {
val primitiveProperty = primitiveType.createProperty()
val propAsString = objectMapper.writeValueAsString(itemSchema)
val correctedSchema = objectMapper.readValue(propAsString, primitiveProperty.getClass)
correctedSchema.setType(primitiveProperty.getType)
correctedSchema.setFormat(primitiveProperty.getFormat)
correctedSchema
}

private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match {
case _: AnnotatedTypeForOption => Seq.empty
case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations))
Expand Down Expand Up @@ -276,6 +313,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte

private def isOption(cls: Class[_]): Boolean = cls == OptionClass
private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls)
private def isMap(cls: Class[_]): Boolean = MapClass.isAssignableFrom(cls)
private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls)

private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
val model = schemas.get("ModelWOptionInt")
model should be (defined)
model.value.getProperties should not be (null)
val optInt = model.value.getProperties().get("optInt")
val optInt = model.value.getProperties.get("optInt")
optInt should not be (null)
optInt shouldBe a [Schema[_]]
optInt shouldBe a [IntegerSchema]
optInt.asInstanceOf[IntegerSchema].getFormat shouldEqual "int32"
nullSafeList(model.value.getRequired) shouldBe empty
}

Expand All @@ -123,9 +124,22 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
optInt should not be (null)
optInt shouldBe a [IntegerSchema]
optInt.asInstanceOf[IntegerSchema].getFormat shouldEqual "int32"
optInt.getDescription shouldBe "This is an optional int"
nullSafeList(model.value.getRequired) shouldBe empty
}

it should "allow annotation to override required with Scala Option Int" in {
val converter = ModelConverters.getInstance()
val schemas = converter.readAll(classOf[ModelWOptionIntSchemaOverrideForRequired]).asScala.toMap
val model = schemas.get("ModelWOptionIntSchemaOverrideForRequired")
model should be(defined)
model.value.getProperties should not be (null)
val optInt = model.value.getProperties().get("optInt")
optInt should not be (null)
optInt shouldBe an [IntegerSchema]
nullSafeList(model.value.getRequired) shouldEqual Seq("optInt")
}

it should "process Model with Scala Option Long" in {
val converter = ModelConverters.getInstance()
val schemas = converter.readAll(classOf[ModelWOptionLong]).asScala.toMap
Expand All @@ -134,7 +148,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
model.value.getProperties should not be (null)
val optLong = model.value.getProperties().get("optLong")
optLong should not be (null)
optLong shouldBe a [Schema[_]]
optLong shouldBe a [IntegerSchema]
nullSafeList(model.value.getRequired) shouldBe empty
}

Expand Down Expand Up @@ -324,6 +338,43 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
nullSafeList(arraySchema.getRequired()) shouldBe empty
}

it should "process Model with Scala Seq Int" in {
val converter = ModelConverters.getInstance()
val schemas = converter.readAll(classOf[ModelWSeqInt]).asScala.toMap
val model = findModel(schemas, "ModelWSeqInt")
model should be(defined)
model.value.getProperties should not be (null)

val stringsField = model.value.getProperties.get("ints")

stringsField shouldBe a[ArraySchema]
val arraySchema = stringsField.asInstanceOf[ArraySchema]
arraySchema.getUniqueItems() shouldBe (null)
arraySchema.getItems shouldBe a[IntegerSchema]
nullSafeMap(arraySchema.getProperties()) shouldBe empty
nullSafeList(arraySchema.getRequired()) shouldBe empty
}

it should "process Model with Scala Seq Int (annotated)" in {
val converter = ModelConverters.getInstance()
val schemas = converter.readAll(classOf[ModelWSeqIntAnnotated]).asScala.toMap
val model = findModel(schemas, "ModelWSeqIntAnnotated")
model should be(defined)
model.value.getProperties should not be (null)

val stringsField = model.value.getProperties.get("ints")

stringsField shouldBe a[ArraySchema]
val arraySchema = stringsField.asInstanceOf[ArraySchema]
arraySchema.getUniqueItems() shouldBe (null)


arraySchema.getItems shouldBe a[IntegerSchema]
arraySchema.getItems.getDescription shouldBe "These are ints"
nullSafeMap(arraySchema.getProperties()) shouldBe empty
nullSafeList(arraySchema.getRequired()) shouldBe empty
}

it should "process Model with Scala Set" in {
val converter = ModelConverters.getInstance()
val schemas = converter.readAll(classOf[ModelWSetString]).asScala.toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ScalaModelTest extends AnyFlatSpec with Matchers {

val date = userSchema.getProperties().get("date")
date shouldBe a [DateTimeSchema]
//date.getDescription should be ("the birthdate")
// date.getDescription should be ("the birthdate")
}

it should "read a model with vector property" in {
Expand All @@ -85,15 +85,15 @@ class ScalaModelTest extends AnyFlatSpec with Matchers {
val model = schemas("ModelWithIntVector")
val prop = model.getProperties().get("ints")
prop shouldBe a [ArraySchema]
prop.asInstanceOf[ArraySchema].getItems.getType should be ("object")
prop.asInstanceOf[ArraySchema].getItems.getType should be ("integer")
}

it should "read a model with vector of booleans" in {
val schemas = ModelConverters.getInstance().readAll(classOf[ModelWithBooleanVector]).asScala
val model = schemas("ModelWithBooleanVector")
val prop = model.getProperties().get("bools")
prop shouldBe a [ArraySchema]
prop.asInstanceOf[ArraySchema].getItems.getType should be ("object")
prop.asInstanceOf[ArraySchema].getItems.getType should be ("boolean")
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/test/scala/models/ModelWOptionInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ import io.swagger.v3.oas.annotations.media.Schema

case class ModelWOptionInt(optInt: Option[Int])

case class ModelWOptionIntSchemaOverride(@Schema(implementation = classOf[Int]) optInt: Option[Int])
case class ModelWOptionIntSchemaOverride(@Schema(description = "This is an optional int") optInt: Option[Int])

case class ModelWOptionIntSchemaOverrideForRequired(@Schema(required = true) optInt: Option[Int])
7 changes: 7 additions & 0 deletions src/test/scala/models/ModelWSeqInt.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package models

import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema}

case class ModelWSeqInt(ints: Seq[Int])

case class ModelWSeqIntAnnotated(@ArraySchema(arraySchema = new Schema(required = false), schema = new Schema(description = "These are ints")) ints: Seq[Int])