Skip to content

Commit

Permalink
Merge branch 'master' into anonymous-objects
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Strnad committed Aug 5, 2019
2 parents 90936b6 + 6b746c1 commit 606ff73
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 42 deletions.
Expand Up @@ -7,6 +7,7 @@ import cats.free.Free
import cats.implicits._
import com.twilio.guardrail.SwaggerUtil.Resolved
import com.twilio.guardrail.extract.VendorExtension.VendorExtensible._
import com.twilio.guardrail.generators.RawParameterType
import com.twilio.guardrail.generators.syntax._
import com.twilio.guardrail.languages.LA
import com.twilio.guardrail.protocol.terms.protocol._
Expand All @@ -31,6 +32,7 @@ case object DataRedacted extends RedactionBehaviour
case class ProtocolParameter[L <: LA](term: L#MethodParameter,
name: String,
dep: Option[L#TermName],
rawType: RawParameterType,
readOnlyKey: Option[String],
emptyToNull: EmptyToNullBehaviour,
dataRedaction: RedactionBehaviour,
Expand Down
Expand Up @@ -113,6 +113,7 @@ object CirceProtocolGenerator {
_ <- Target.log.debug(s"Args: (${clsName}, ${name}, ...)")

argName = if (needCamelSnakeConversion) name.toCamelCase else name
rawType = RawParameterType(Option(property.getType), Option(property.getFormat))

defaultValue = property match {
case _: MapSchema =>
Expand Down Expand Up @@ -170,7 +171,7 @@ object CirceProtocolGenerator {
)(Function.const((tpe, defaultValue)) _)
term = param"${Term.Name(argName)}: ${finalDeclType}".copy(default = finalDefaultValue)
dep = classDep.filterNot(_.value == clsName) // Filter out our own class name
} yield ProtocolParameter[ScalaLanguage](term, name, dep, readOnlyKey, emptyToNull, dataRedaction, finalDefaultValue))
} yield ProtocolParameter[ScalaLanguage](term, name, dep, rawType, readOnlyKey, emptyToNull, dataRedaction, finalDefaultValue))

case RenderDTOClass(clsName, selfParams, parents) =>
val discriminators = parents.flatMap(_.discriminators)
Expand Down
Expand Up @@ -5,8 +5,9 @@ package Java
import _root_.io.swagger.v3.oas.models.media._
import cats.data.NonEmptyList
import cats.implicits._
import cats.instances.map
import cats.~>
import com.github.javaparser.ast.`type`.{ PrimitiveType, Type, UnknownType }
import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, PrimitiveType, Type, UnknownType }
import com.twilio.guardrail.Discriminator
import com.twilio.guardrail.extract.{ DataRedaction, Default, EmptyValueIsNull }
import com.twilio.guardrail.generators.syntax.Java._
Expand All @@ -20,17 +21,22 @@ import com.github.javaparser.ast.stmt._
import com.github.javaparser.ast.Modifier.{ ABSTRACT, FINAL, PRIVATE, PROTECTED, PUBLIC, STATIC }
import com.github.javaparser.ast.body._
import com.github.javaparser.ast.expr._
import java.math.BigInteger
import java.util
import java.util.Locale
import scala.language.existentials
import scala.util.Try

object JacksonGenerator {
private val BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("Builder")
private val BUILDER_TYPE = JavaParser.parseClassOrInterfaceType("Builder")
private val BIG_INTEGER_FQ_TYPE = JavaParser.parseClassOrInterfaceType("java.math.BigInteger")
private val BIG_DECIMAL_FQ_TYPE = JavaParser.parseClassOrInterfaceType("java.math.BigDecimal")

private case class ParameterTerm(propertyName: String,
parameterName: String,
fieldType: Type,
parameterType: Type,
rawType: RawParameterType,
defaultValue: Option[Expression],
dataRedacted: RedactionBehaviour)

Expand All @@ -43,16 +49,15 @@ object JacksonGenerator {
}

params
.map({
case ProtocolParameter(term, name, _, _, _, dataRedaction, selfDefaultValue) =>
val parameterType = if (term.getType.isOptional) {
term.getType.containedType.unbox
} else {
term.getType.unbox
}
val defaultValue = defaultValueToExpression(selfDefaultValue)
.map({ param =>
val parameterType = if (param.term.getType.isOptional) {
param.term.getType.containedType.unbox
} else {
param.term.getType.unbox
}
val defaultValue = defaultValueToExpression(param.defaultValue)

ParameterTerm(name, term.getNameAsString, term.getType.unbox, parameterType, defaultValue, dataRedaction)
ParameterTerm(param.name, param.term.getNameAsString, param.term.getType.unbox, parameterType, param.rawType, defaultValue, param.dataRedaction)
})
.partition(
pt => !pt.fieldType.isOptional && pt.defaultValue.isEmpty
Expand Down Expand Up @@ -296,13 +301,77 @@ object JacksonGenerator {
case _ if parents.length == 1 => Target.pure(parents.headOption)
case _ => Target.pure(None)
}

discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
parentParams = parentOpt.toList.flatMap(_.params)
parentParamNames = parentParams.map(_.name)
(parentRequiredTerms, parentOptionalTerms) = sortParams(parentParams)
parentTerms = parentRequiredTerms ++ parentOptionalTerms

discriminatorValues <- parentTerms
.flatMap({ term =>
discriminators.find(_.propertyName == term.propertyName).map((term, _))
})
.traverse({
case (term, discriminator) =>
val discriminatorValue = discriminator.mapping
.collectFirst({ case (value, elem) if elem.name == clsName => value })
.getOrElse(clsName)

def parseLiteral(parser: String => Expression, friendlyName: String): Target[Expression] =
Try(parser(discriminatorValue)).fold(
t => Target.raiseError[Expression](s"Unable to parse '$discriminatorValue' as '$friendlyName': ${t.getMessage}"),
Target.pure[Expression]
)

val discriminatorValueExpr = term.rawType.tpe match {
case Some(tpe @ "string") =>
term.rawType.format match {
case Some("date") | Some("date-time") | Some("byte") | Some("binary") =>
Target.raiseError[Expression](
s"Unsupported discriminator type '$tpe' with format '${term.rawType.format.getOrElse("unknown")}' for property '${term.propertyName}'"
)
case _ => Target.pure[Expression](new StringLiteralExpr(discriminatorValue))
}
case Some(tpe @ "boolean") => parseLiteral(x => new BooleanLiteralExpr(x.toBoolean), tpe)
case Some(tpe @ "integer") =>
term.rawType.format match {
case Some(fmt @ "int32") => parseLiteral(x => new IntegerLiteralExpr(x.toInt), fmt)
case Some(fmt @ "int64") => parseLiteral(x => new LongLiteralExpr(x.toLong), fmt)
case Some(fmt) =>
Target.raiseError[Expression](s"Unsupported discriminator type '$tpe' with format '$fmt' for property '${term.propertyName}'")
case None =>
parseLiteral(x => new ObjectCreationExpr(null, BIG_INTEGER_FQ_TYPE, new NodeList(new StringLiteralExpr(new BigInteger(x).toString))),
"BigInteger")
}
case Some(tpe @ "number") =>
term.rawType.format match {
case Some(fmt @ "float") => parseLiteral(x => new DoubleLiteralExpr(x.toFloat), fmt)
case Some(fmt @ "double") => parseLiteral(x => new DoubleLiteralExpr(x.toDouble), fmt)
case Some(fmt) =>
Target.raiseError[Expression](s"Unsupported discriminator type '$tpe' with format '$fmt' for property '${term.propertyName}'")
case None =>
parseLiteral(x =>
new ObjectCreationExpr(null, BIG_DECIMAL_FQ_TYPE, new NodeList(new StringLiteralExpr(new java.math.BigDecimal(x).toString))),
"BigDecimal")
}
case Some(tpe) =>
Target.raiseError[Expression](s"Unsupported discriminator type '$tpe' for property '${term.propertyName}'")
case None =>
term.fieldType match {
case cls: ClassOrInterfaceType =>
// hopefully it's an enum type; nothing else really makes sense here
Target.pure[Expression](new FieldAccessExpr(cls.getNameAsExpression, discriminatorValue.toSnakeCase.toUpperCase(Locale.US)))
case tpe =>
Target.raiseError[Expression](s"Unsupported discriminator type '${tpe.asString}' for property '${term.propertyName}'")
}
}

discriminatorValueExpr.map((term.propertyName, _))
})
.map(_.toMap)
} yield {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val parentParams = parentOpt.toList.flatMap(_.params)
val parentParamNames = parentParams.map(_.name)
val (parentRequiredTerms, parentOptionalTerms) = sortParams(parentParams)
val parentTerms = parentRequiredTerms ++ parentOptionalTerms
val params = parents.filterNot(parent => parentOpt.contains(parent)).flatMap(_.params) ++ selfParams.filterNot(
param => discriminatorNames.contains(param.term.getName.getIdentifier) || parentParamNames.contains(param.term.getName.getIdentifier)
)
Expand All @@ -328,7 +397,7 @@ object JacksonGenerator {
terms.filterNot(term => discriminatorNames.contains(term.propertyName))

terms.foreach({
case ParameterTerm(propertyName, parameterName, fieldType, _, _, _) =>
case ParameterTerm(propertyName, parameterName, fieldType, _, _, _, _) =>
val field: FieldDeclaration = dtoClass.addField(fieldType, parameterName, PRIVATE, FINAL)
field.addSingleMemberAnnotation("JsonProperty", new StringLiteralExpr(propertyName))
})
Expand All @@ -338,26 +407,15 @@ object JacksonGenerator {
primaryConstructor.setParameters(
new NodeList(
withoutDiscriminators(parentTerms ++ terms).map({
case ParameterTerm(propertyName, parameterName, fieldType, _, _, _) =>
case ParameterTerm(propertyName, parameterName, fieldType, _, _, _, _) =>
new Parameter(util.EnumSet.of(FINAL), fieldType, new SimpleName(parameterName))
.addAnnotation(new SingleMemberAnnotationExpr(new Name("JsonProperty"), new StringLiteralExpr(propertyName)))
}): _*
)
)
val superCall = new MethodCallExpr(
"super",
parentTerms.map({ term =>
discriminators
.find(_.propertyName == term.propertyName)
.fold[Expression](new NameExpr(term.parameterName))(
discriminator =>
new StringLiteralExpr(
discriminator.mapping
.collectFirst({ case (value, elem) if elem.name == clsName => value })
.getOrElse(clsName)
)
)
}): _*
parentTerms.map(term => discriminatorValues.getOrElse(term.propertyName, new NameExpr(term.parameterName))): _*
)
primaryConstructor.setBody(dtoConstructorBody(superCall, terms))

Expand Down Expand Up @@ -495,11 +553,11 @@ object JacksonGenerator {
val builderClass = new ClassOrInterfaceDeclaration(util.EnumSet.of(PUBLIC, STATIC), false, "Builder")

withoutDiscriminators(parentRequiredTerms ++ requiredTerms).foreach({
case ParameterTerm(_, parameterName, fieldType, _, _, _) =>
case ParameterTerm(_, parameterName, fieldType, _, _, _, _) =>
builderClass.addField(fieldType, parameterName, PRIVATE)
})
withoutDiscriminators(parentOptionalTerms ++ optionalTerms).foreach({
case ParameterTerm(_, parameterName, fieldType, _, defaultValue, _) =>
case ParameterTerm(_, parameterName, fieldType, _, _, defaultValue, _) =>
val initializer = defaultValue.fold[Expression](
new MethodCallExpr(new NameExpr("Optional"), "empty")
)(
Expand All @@ -517,7 +575,7 @@ object JacksonGenerator {
builderConstructor.setParameters(
new NodeList(
withoutDiscriminators(parentRequiredTerms ++ requiredTerms).map({
case ParameterTerm(_, parameterName, _, parameterType, _, _) =>
case ParameterTerm(_, parameterName, _, parameterType, _, _, _) =>
new Parameter(util.EnumSet.of(FINAL), parameterType, new SimpleName(parameterName))
}): _*
)
Expand All @@ -526,7 +584,7 @@ object JacksonGenerator {
new BlockStmt(
new NodeList(
withoutDiscriminators(parentRequiredTerms ++ requiredTerms).map({
case ParameterTerm(_, parameterName, fieldType, _, _, _) =>
case ParameterTerm(_, parameterName, fieldType, _, _, _, _) =>
new ExpressionStmt(
new AssignExpr(
new FieldAccessExpr(new ThisExpr, parameterName),
Expand All @@ -549,7 +607,7 @@ object JacksonGenerator {
new BlockStmt(
withoutDiscriminators(parentTerms ++ terms)
.map({
case term @ ParameterTerm(_, parameterName, _, _, _, _) =>
case term @ ParameterTerm(_, parameterName, _, _, _, _, _) =>
new ExpressionStmt(
new AssignExpr(
new FieldAccessExpr(new ThisExpr, parameterName),
Expand All @@ -564,7 +622,7 @@ object JacksonGenerator {

// TODO: leave out with${name}() if readOnlyKey?
withoutDiscriminators(parentTerms ++ terms).foreach({
case ParameterTerm(_, parameterName, fieldType, parameterType, _, _) =>
case ParameterTerm(_, parameterName, fieldType, parameterType, _, _, _) =>
val methodName = s"with${parameterName.unescapeIdentifier.capitalize}"

builderClass
Expand Down Expand Up @@ -740,6 +798,8 @@ object JacksonGenerator {
(tpe, classDep) = tpeClassDep

argName = if (needCamelSnakeConversion) name.toCamelCase else name
rawType = RawParameterType(Option(property.getType), Option(property.getFormat))

_declDefaultPair <- Option(isRequired)
.filterNot(_ == false)
.fold[Target[(Type, Option[Expression])]](
Expand All @@ -758,7 +818,7 @@ object JacksonGenerator {
(finalDeclType, finalDefaultValue) = _declDefaultPair
term <- safeParseParameter(s"final ${finalDeclType} ${argName.escapeIdentifier}")
dep = classDep.filterNot(_.asString == clsName) // Filter out our own class name
} yield ProtocolParameter[JavaLanguage](term, name, dep, readOnlyKey, emptyToNull, dataRedaction, defaultValue)
} yield ProtocolParameter[JavaLanguage](term, name, dep, rawType, readOnlyKey, emptyToNull, dataRedaction, defaultValue)
}

case RenderDTOClass(clsName, selfParams, parents) =>
Expand Down
@@ -1,7 +1,7 @@
package com.twilio.guardrail.generators.syntax

import com.github.javaparser.JavaParser
import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, Type }
import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, PrimitiveType, Type }
import com.github.javaparser.ast.body._
import com.github.javaparser.ast.comments.{ BlockComment, Comment }
import com.github.javaparser.ast.expr.{
Expand Down Expand Up @@ -52,14 +52,16 @@ object Java {
case cls: ClassOrInterfaceType if name.contains(".") =>
(cls.getScope.asScala.fold("")(_.getName.asString + ".") + cls.getNameAsString) == name
case cls: ClassOrInterfaceType => cls.getNameAsString == name
case pt: PrimitiveType => pt.asString == name
case _ => false
}

def name: Option[String] =
tpe match {
case cls: ClassOrInterfaceType =>
Some(cls.getScope.asScala.fold("")(_.getName.asString + ".") + cls.getNameAsString)
case _ => None
case pt: PrimitiveType => Option(pt.asString)
case _ => None
}
}

Expand Down
Expand Up @@ -2,7 +2,7 @@ package core.Jackson

import com.fasterxml.jackson.databind.ObjectMapper
import org.scalatest.{FreeSpec, Matchers}
import polymorphismMapped.client.dropwizard.definitions.{A, B, Base, C}
import polymorphismMapped.client.dropwizard.definitions.{A, B, Base, C, DiscrimEnum, EnumA, EnumB, EnumBase, EnumC}
import scala.reflect.ClassTag

class JacksonPolyMappingTest extends FreeSpec with Matchers {
Expand Down Expand Up @@ -33,4 +33,30 @@ class JacksonPolyMappingTest extends FreeSpec with Matchers {
verify[C]("""{"polytype": "C", "some_c": 42.42}""", "C")
}
}

"Polymorphic definitions with enum discriminator mappings" - {
"should have their discriminator initialized properly" in {
val a = new EnumA.Builder(42).build()
a.getPolytype shouldBe DiscrimEnum.SOME_VALUE_ONE

val b = new EnumB.Builder("foo").build()
b.getPolytype shouldBe DiscrimEnum.ANOTHER_VALUE

val c = new EnumC.Builder(42.42).build()
c.getPolytype shouldBe DiscrimEnum.YET_ANOTHER_VALUE
}

"should deserialize properly" in {
def verify[T](json: String, discriminatorValue: DiscrimEnum)(implicit cls: ClassTag[T]): Unit = {
val pojo = mapper.readValue(json, classOf[EnumBase])
pojo shouldNot be(null)
pojo.getClass shouldBe cls.runtimeClass
pojo.getPolytype shouldBe discriminatorValue
}

verify[EnumA]("""{"polytype": "some-value-one", "some_a": 42}""", DiscrimEnum.SOME_VALUE_ONE)
verify[EnumB]("""{"polytype": "another-value", "some_b": "foo"}""", DiscrimEnum.ANOTHER_VALUE)
verify[EnumC]("""{"polytype": "yet-another-value", "some_c": 42.42}""", DiscrimEnum.YET_ANOTHER_VALUE)
}
}
}

0 comments on commit 606ff73

Please sign in to comment.