Skip to content

Commit

Permalink
Codegen: Generate attributes on endpoints for any specification exten…
Browse files Browse the repository at this point in the history
…sions defined on path or operation objects (#3599)
  • Loading branch information
hughsimpson committed Mar 27, 2024
1 parent a026beb commit a2eca04
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package sttp.tapir.codegen
import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaAny,
OpenapiSchemaBoolean,
OpenapiSchemaBinary,
OpenapiSchemaBoolean,
OpenapiSchemaDateTime,
OpenapiSchemaDouble,
OpenapiSchemaFloat,
Expand All @@ -15,6 +15,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaString,
OpenapiSchemaUUID
}
import sttp.tapir.codegen.openapi.models.SpecificationExtensionRenderer

object JsonSerdeLib extends Enumeration {
val Circe, Jsoniter = Value
Expand Down Expand Up @@ -61,6 +62,24 @@ object BasicGenerator {
|}""".stripMargin
headTag -> taggedObj
}

val maybeSpecificationExtensionKeys = doc.paths
.flatMap { p =>
p.specificationExtensions.toSeq ++ p.methods.flatMap(_.specificationExtensions.toSeq)
}
.groupBy(_._1)
.map { case (keyName, pairs) =>
val values = pairs.map(_._2)
val `type` = SpecificationExtensionRenderer.renderCombinedType(values)
val name = strippedToCamelCase(keyName)
val uncapitalisedName = name.head.toLower + name.tail
val capitalisedName = name.head.toUpper + name.tail
s"""type ${capitalisedName}Extension = ${`type`}
|val ${uncapitalisedName}ExtensionKey = new sttp.tapir.AttributeKey[${capitalisedName}Extension]("$packagePath.$objName.${capitalisedName}Extension")
|""".stripMargin
}
.mkString("\n")

val mainObj = s"""|
|package $packagePath
|
Expand All @@ -70,8 +89,9 @@ object BasicGenerator {
|
|${indent(2)(classGenerator.classDefs(doc, targetScala3, queryParamRefs, normalisedJsonLib, jsonParamRefs).getOrElse(""))}
|
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|${indent(2)(maybeSpecificationExtensionKeys)}
|
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|}
|""".stripMargin
taggedObjs + (objName -> mainObj)
Expand Down Expand Up @@ -127,4 +147,11 @@ object BasicGenerator {
case x => throw new NotImplementedError(s"Not all simple types supported! Found $x")
}
}

def strippedToCamelCase(string: String): String = string
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package sttp.tapir.codegen

import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument
import sttp.tapir.codegen.openapi.models.{OpenapiSchemaType, Renderer}
import sttp.tapir.codegen.openapi.models.{OpenapiSchemaType, DefaultValueRenderer}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType._

import scala.annotation.tailrec
Expand Down Expand Up @@ -267,7 +266,8 @@ class ClassDefinitionGenerator {
val tpe = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson)
val fixedKey = fixKey(key)
val optional = schemaType.nullable || !obj.required.contains(key)
val maybeExplicitDefault = maybeDefault.map(" = " + Renderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val maybeExplicitDefault =
maybeDefault.map(" = " + DefaultValueRenderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val default = maybeExplicitDefault getOrElse (if (optional) " = None" else "")
s"$fixedKey: $tpe$default"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package sttp.tapir.codegen
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaAny,
OpenapiSchemaArray,
OpenapiSchemaBinary,
OpenapiSchemaRef,
OpenapiSchemaAny,
OpenapiSchemaSimpleType
}
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType}
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType, SpecificationExtensionRenderer}
import sttp.tapir.codegen.util.JavaEscape

case class Location(path: String, method: String) {
Expand Down Expand Up @@ -68,6 +69,18 @@ class EndpointGenerator {
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)

val attributeString = {
val pathAttributes = attributes(p.specificationExtensions)
val operationAttributes = attributes(m.specificationExtensions)
(pathAttributes, operationAttributes) match {
case (None, None) => ""
case (Some(atts), None) => indent(2)(atts)
case (None, Some(atts)) => indent(2)(atts)
case (Some(pathAtts), Some(operationAtts)) => indent(2)(pathAtts + "\n" + operationAtts)
}
}

val definition =
s"""|endpoint
| .${m.methodType}
Expand All @@ -76,15 +89,10 @@ class EndpointGenerator {
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|$attributeString
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize))
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema }
Expand Down Expand Up @@ -196,6 +204,17 @@ class EndpointGenerator {
openapiTags.map(_.distinct.mkString(".tags(List(\"", "\", \"", "\"))")).mkString
}

private def attributes(atts: Map[String, Json]): Option[String] = if (atts.nonEmpty) Some {
atts
.map { case (k, v) =>
val camelCaseK = strippedToCamelCase(k)
val uncapitalisedName = camelCaseK.head.toLower + camelCaseK.tail
s""".attribute[${camelCaseK.capitalize}Extension](${uncapitalisedName}ExtensionKey, ${SpecificationExtensionRenderer.renderValue(v)})"""
}
.mkString("\n")
}
else None

// treats redirects as ok
private val okStatus = """([23]\d\d)""".r
private val errorStatus = """([45]\d\d)""".r
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaUUID
}

object Renderer {
object DefaultValueRenderer {
private def lookup(allModels: Map[String, OpenapiSchemaType], ref: OpenapiSchemaRef): OpenapiSchemaType = allModels(
ref.name.stripPrefix("#/components/schemas/")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package sttp.tapir.codegen.openapi.models

import cats.implicits.toTraverseOps
import cats.syntax.either._

import OpenapiSchemaType.OpenapiSchemaRef
import io.circe.Json
// https://swagger.io/specification/
object OpenapiModels {

Expand Down Expand Up @@ -35,7 +35,8 @@ object OpenapiModels {
case class OpenapiPath(
url: String,
methods: Seq[OpenapiPathMethod],
parameters: Seq[Resolvable[OpenapiParameter]] = Nil
parameters: Seq[Resolvable[OpenapiParameter]] = Nil,
specificationExtensions: Map[String, Json] = Map.empty
)

case class OpenapiPathMethod(
Expand All @@ -46,7 +47,8 @@ object OpenapiModels {
security: Seq[Seq[String]] = Nil,
summary: Option[String] = None,
tags: Option[Seq[String]] = None,
operationId: Option[String] = None
operationId: Option[String] = None,
specificationExtensions: Map[String, Json] = Map.empty
) {
def resolvedParameters: Seq[OpenapiParameter] = parameters.collect { case Resolved(t) => t }
def withResolvedParentParameters(
Expand Down Expand Up @@ -176,6 +178,10 @@ object OpenapiModels {
summary <- c.get[Option[String]]("summary")
tags <- c.get[Option[Seq[String]]]("tags")
operationId <- c.get[Option[String]]("operationId")
specificationExtensionKeys = c.keys.toSeq.flatMap(_.filter(_.startsWith("x-")))
specificationExtensions = specificationExtensionKeys
.flatMap(key => c.downField(key).as[Option[Json]].toOption.flatten.map(key.stripPrefix("x-") -> _))
.toMap
} yield {
OpenapiPathMethod(
"--partial--",
Expand All @@ -185,7 +191,8 @@ object OpenapiModels {
security.map(_.keys.toSeq),
summary,
tags,
operationId
operationId,
specificationExtensions
)
}
}
Expand All @@ -198,7 +205,11 @@ object OpenapiModels {
.map(_.getOrElse(Nil))
methods <- List("get", "put", "post", "delete", "options", "head", "patch", "connect", "trace")
.traverse(method => c.downField(method).as[Option[OpenapiPathMethod]].map(_.map(_.copy(methodType = method))))
} yield OpenapiPath("--partial--", methods.flatten, parameters)
specificationExtensionKeys = c.keys.toSeq.flatMap(_.filter(_.startsWith("x-")))
specificationExtensions = specificationExtensionKeys
.flatMap(key => c.downField(key).as[Option[Json]].toOption.flatten.map(key.stripPrefix("x-") -> _))
.toMap
} yield OpenapiPath("--partial--", methods.flatten, parameters, specificationExtensions)
}

implicit val OpenapiPathsDecoder: Decoder[Seq[OpenapiPath]] = { (c: HCursor) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package sttp.tapir.codegen.openapi.models

import io.circe.Json

object SpecificationExtensionRenderer {

def renderCombinedType(jsons: Seq[Json]): String = {
// permit nulls for any type, but specify type as null if every value is null
val nonNull = jsons.filterNot(_.isNull)
if (jsons.isEmpty) "Nothing"
else if (nonNull.isEmpty) "Null"
else {
val groupedByBaseType = nonNull.groupBy(j =>
if (j.isBoolean) "Boolean"
else if (j.isNumber) "Number"
else if (j.isString) "String"
else if (j.isArray) "Array"
else if (j.isObject) "Object"
else throw new IllegalStateException("json must be one of boolean, number, string, array or object")
)
// Cannot resolve types if totally different...
if (groupedByBaseType.size > 1) "Any"
else
groupedByBaseType.head match {
case (t @ ("Boolean" | "String"), _) => t
case ("Number", vs) => if (vs.forall(_.asNumber.flatMap(_.toLong).isDefined)) "Long" else "Double"
case ("Array", vs) =>
val t = renderCombinedType(vs.flatMap(_.asArray).flatten)
s"Seq[$t]"
case ("Object", kvs) =>
val t = renderCombinedType(kvs.flatMap(_.asObject).flatMap(_.toMap.values))
s"Map[String, $t]"
case (x, _) => throw new IllegalStateException(s"No such group $x")
}
}
}

def renderValue(json: Json): String = json.fold(
"null",
bool => bool.toString,
n => n.toLong.map(l => s"${l}L") getOrElse s"${n.toDouble}d", // the long repr is fine even if type expanded to Double
s => '"' +: s :+ '"',
arr => if (arr.isEmpty) "Vector.empty" else s"Vector(${arr.map(renderValue).mkString(", ")})",
obj =>
if (obj.isEmpty) "Map.empty[String, Nothing]"
else s"Map(${obj.toMap.map { case (k, v) => s""""$k" -> ${renderValue(v)}""" }.mkString(", ")})"
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,39 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
generatedCode shouldCompile ()
}

it should "generate attributes for specification extensions on path and operation objects" in {
val doc = TestHelpers.specificationExtensionDocs
val generatedCode = BasicGenerator.generateObjects(
doc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = false,
jsonSerdeLib = "circe"
)("TapirGeneratedEndpoints")
generatedCode shouldCompile ()
val expectedAttrDecls = Seq(
""".attribute[CustomStringExtensionOnPathExtension](customStringExtensionOnPathExtensionKey, "another string")""",
""".attribute[CustomStringExtensionOnOperationExtension](customStringExtensionOnOperationExtensionKey, "bazquux")""",
""".attribute[CustomListExtensionOnOperationExtension](customListExtensionOnOperationExtensionKey, Vector("baz", "quux"))""",
""".attribute[CustomMapExtensionOnPathExtension](customMapExtensionOnPathExtensionKey, Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))""",
""".attribute[CustomStringExtensionOnPathDoubleTypeExtension](customStringExtensionOnPathDoubleTypeExtensionKey, 123L)"""
)
expectedAttrDecls foreach (decl => generatedCode should include(decl))
generatedCode should include(
"""val customMapExtensionOnOperationExtensionKey = new sttp.tapir.AttributeKey[CustomMapExtensionOnOperationExtension]("sttp.tapir.generated.TapirGeneratedEndpoints.CustomMapExtensionOnOperationExtension")""".stripMargin
)
val expectedKeyDeclarations = Seq(
"""type CustomMapExtensionOnOperationExtension = Map[String, Any]""",
"""type CustomListExtensionOnPathAnyTypeExtension = Seq[Any]""",
"""type CustomMapExtensionOnPathSingleValueTypeExtension = Map[String, String]""",
"""type CustomListExtensionOnOperationExtension = Seq[String]""",
"""type CustomStringExtensionOnPathAnyTypeExtension = Any""",
"""type CustomStringExtensionOnPathDoubleTypeExtension = Double""",
"""type CustomListExtensionOnPathExtension = Seq[String]""",
"""type CustomStringExtensionOnPathExtension = String"""
)
expectedKeyDeclarations foreach (decl => generatedCode should include(decl))
}

}

0 comments on commit a2eca04

Please sign in to comment.