From 59e91e6404cb802f0297c1fc7cb89e6692b3f4e9 Mon Sep 17 00:00:00 2001 From: George Leontiev Date: Thu, 30 Jun 2022 20:47:10 +0000 Subject: [PATCH] scrooge: Introduce AnnotatedFieldType to support annotating inner container types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Problem Currently, thrift annotations inside of collection types pass the parser, but are subsequently dropped. # Solution This change adds a new `FieldType` – `AnnotatedFieldType`. This is currently used to support thrift annotations on types inside of containers (maps/lists/sets). But this can be further extended to apply annotations on any `FieldType`. # Result By doing this, annotations inside of collection types are now propagated and are available to downstream compilers. JIRA Issues: CSL-11944 Differential Revision: https://phabricator.twitter.biz/D911997 --- CHANGELOG.rst | 4 + .../scrooge/frontend/ThriftParserSpec.scala | 105 +++++++++++++++++- .../scrooge/frontend/TypeResolverSpec.scala | 24 ++++ .../com/twitter/scrooge/AST/Definition.scala | 10 ++ .../scala/com/twitter/scrooge/AST/Type.scala | 46 +++++++- .../android_generator/AndroidGenerator.scala | 24 ++-- .../scrooge/backend/CocoaGenerator.scala | 19 +++- .../twitter/scrooge/backend/Generator.scala | 81 +++++++++----- .../scrooge/backend/ScalaGenerator.scala | 12 +- .../scrooge/backend/StructTemplate.scala | 10 ++ .../scrooge/backend/lua/LuaGenerator.scala | 10 +- .../scrooge/frontend/ThriftParser.scala | 21 ++-- .../scrooge/frontend/TypeResolver.scala | 81 ++++++++------ .../java_generator/ApacheJavaGenerator.scala | 10 +- .../java_generator/FieldTypeController.scala | 83 +++++++++++--- .../FieldValueMetadataController.scala | 6 +- .../java_generator/PrintConstController.scala | 1 + .../swift_generator/SwiftGenerator.scala | 9 +- .../SwiftPrintConstController.scala | 22 +--- 19 files changed, 429 insertions(+), 149 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c57be01bf..a19b867a4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,10 @@ Unreleased New Features ~~~~~~~~~~~~ +* scrooge-generator: Introduce a `AnnotatedFieldType` to abstract type annotations from + `FieldType` definitions. Currently used to propagate thrift annotations inside of + collection types. ``PHAB_ID=D911997`` + * scrooge-core: `c.t.scrooge.ThriftUnion.fieldInfoForUnionClass` API for retrieving `ThriftStructFieldInfo` for a `ThriftUnion` member class without having to instantiate it. ``PHAB_ID=D871986`` diff --git a/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/ThriftParserSpec.scala b/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/ThriftParserSpec.scala index 31d5d2cef..4601acdb9 100644 --- a/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/ThriftParserSpec.scala +++ b/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/ThriftParserSpec.scala @@ -101,11 +101,29 @@ struct MyStruct {} parser.parse("list>", parser.fieldType) must be( ListType(ListType(TString, None), None) ) - // Annotations within container type definitions are parsed properly, but thrown away parser.parse("list < i64 (something = \"else\") >", parser.fieldType) must be( - ListType(TI64, None)) + ListType( + AnnotatedFieldType.wrap( + TI64, + Map("something" -> "else") + ), + None + ) + ) parser.parse("list(c = \"d\")>", parser.fieldType) must be( - ListType(ListType(TString, None), None) + ListType( + AnnotatedFieldType.wrap( + ListType( + AnnotatedFieldType.wrap( + TString, + Map("a" -> "b") + ), + None + ), + Map("c" -> "d") + ), + None + ) ) } @@ -114,7 +132,13 @@ struct MyStruct {} SetType(ReferenceType(Identifier("Monster")), None) ) parser.parse("set", parser.fieldType) must be( - SetType(TBool, None) + SetType( + AnnotatedFieldType.wrap( + TBool, + Map("hello" -> "goodbye") + ), + None + ) ) } @@ -125,7 +149,70 @@ struct MyStruct {} parser.parse( "map (e=\"f\")>", parser.fieldType) must be( - MapType(TString, ListType(TBool, None), None) + MapType( + AnnotatedFieldType.wrap( + TString, + Map("a" -> "b") + ), + AnnotatedFieldType.wrap( + ListType( + AnnotatedFieldType.wrap( + TBool, + Map("c" -> "d") + ), + None + ), + Map("e" -> "f") + ), + None + ) + ) + } + + "inner collection types annotations" in { + parser.parse( + """list (python.immutable = "1"), map(python.immutable = "2")> (python.immutable = "3")>>>>""", + parser.fieldType + ) must be( + ListType( + MapType( + AnnotatedFieldType.wrap( + SetType( + AnnotatedFieldType.wrap( + TI32, + Map("python.immutable" -> "0") + ), + None + ), + Map("python.immutable" -> "1") + ), + MapType( + TI32, + SetType( + AnnotatedFieldType.wrap( + ListType( + AnnotatedFieldType.wrap( + MapType( + ReferenceType( + Identifier("Insanity") + ), + TString, + None + ), + Map("python.immutable" -> "2") + ), + None + ), + Map("python.immutable" -> "3") + ), + None + ), + None + ), + None + ), + None + ) ) } @@ -862,7 +949,13 @@ enum Foo ) must be( Typedef( SimpleID("tiny_float_list", None), - ListType(TDouble, None), + ListType( + AnnotatedFieldType.wrap( + TDouble, + Map("cpp.fixed_point" -> "16") + ), + None + ), Map(), Map() ) diff --git a/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/TypeResolverSpec.scala b/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/TypeResolverSpec.scala index 67516e015..082b58e41 100644 --- a/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/TypeResolverSpec.scala +++ b/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/frontend/TypeResolverSpec.scala @@ -67,6 +67,30 @@ class TypeResolverSpec extends Spec { } } + "resolve annotations inside of container types" in { + val input = + """ struct Test { + | 1: list testList, + | 2: set testSet, + | 3: map testMap + |} + """.stripMargin + val result = resolve(input) + result.document.defs.head match { + case struct: Struct => + struct.fields(0).fieldType must be( + ListType(AnnotatedFieldType.wrap(TI32, Map("my.annotation" -> "list")), None)) + struct.fields(1).fieldType must be( + SetType(AnnotatedFieldType.wrap(TString, Map("my.annotation" -> "set")), None)) + struct.fields(2).fieldType must be( + MapType( + AnnotatedFieldType.wrap(TI32, Map("my.annotation" -> "mapKey")), + AnnotatedFieldType.wrap(TI32, Map("my.annotation" -> "mapValue")), + None)) + case _ => fail() + } + } + "resolve self-referencing types" in { val input = """struct S { diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Definition.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Definition.scala index 569044dfc..25e27c68e 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Definition.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Definition.scala @@ -42,6 +42,16 @@ sealed abstract class StructLike extends Definition { val fields: Seq[Field] val docstring: Option[String] val annotations: Map[String, String] + + def withAnnotations(newAnnotations: Map[String, String]): StructLike = + this match { + case s: Struct => s.copy(annotations = annotations ++ newAnnotations) + case u: Union => u.copy(annotations = annotations ++ newAnnotations) + case e: Exception_ => e.copy(annotations = annotations ++ newAnnotations) + // FunctionResult and FunctionArgs don't keep track of annotations + case r: FunctionResult => r + case a: FunctionArgs => a + } } case class Struct( diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Type.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Type.scala index 4df829b92..ba1a25dbd 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Type.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/AST/Type.scala @@ -14,9 +14,53 @@ case object TDouble extends BaseType case object TString extends BaseType case object TBinary extends BaseType +/** + * AnnotatedFieldType is used to be able to annotate arbitrary FieldTypes. + * The current use case is to allow annotating types inside of collection hierarchies. + * In the future this can/should be expanded for type annotations in general. + * + * Note: please use AnnotatedFieldType.build to construct instances of this. + * + * @param underlying the type being annotated + * @param annos annotations applied to the underlying type + */ +case class AnnotatedFieldType private (underlying: FieldType, annos: Map[String, String]) + extends FieldType { + require(annos.nonEmpty, "Cannot construct an AnnotatedFieldType with empty annotations.") + + /** + * Once type annotations from underlying types have fully migrated to + * AnnotatedFieldType, this will not be necessary anymore. + */ + def unwrap: FieldType = underlying match { + case s: StructType => s.copy(struct = s.struct.withAnnotations(annos)) + case e: EnumType => e.copy(enum = e.enum.copy(annotations = annos)) + // Other types don't keep track of annotations, so there is nothing to propagate + case otherwise => otherwise + } + def annotations: Map[String, String] = annos +} + +/** + * Type builder utils where we make the best effort to unwrap + * nested annotated types and avoid producing extra wrappers + * where possible. + */ +object AnnotatedFieldType { + def wrap(t: FieldType, annotations: Map[String, String]): FieldType = { + if (annotations.isEmpty) t + else + t match { + case t: AnnotatedFieldType => + new AnnotatedFieldType(t.underlying, t.annotations ++ annotations) + case otherwise => new AnnotatedFieldType(otherwise, annotations) + } + } +} + /** * ReferenceType is generated by ThriftParser in the frontend and - * resolved by TypeResolver. There will only ReferenceTypes after + * resolved by TypeResolver. There will only be ReferenceTypes after * resolution seen by the backend when self-reference structs, * mutually recursive structs, or references to further definitions * (structs/enums) are present in the Document. diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/android_generator/AndroidGenerator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/android_generator/AndroidGenerator.scala index e841517a0..7bd96a358 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/android_generator/AndroidGenerator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/android_generator/AndroidGenerator.scala @@ -1,18 +1,21 @@ package com.twitter.scrooge.android_generator -import com.github.mustachejava.{Mustache, DefaultMustacheFactory} -import com.twitter.scrooge.ast.ListType -import com.twitter.scrooge.ast.MapType -import com.twitter.scrooge.ast.ReferenceType -import com.twitter.scrooge.ast.SetType +import com.github.mustachejava.Mustache +import com.github.mustachejava.DefaultMustacheFactory import com.twitter.scrooge.ast._ -import com.twitter.scrooge.backend.{GeneratorFactory, Generator, ServiceOption} +import com.twitter.scrooge.backend.GeneratorFactory +import com.twitter.scrooge.backend.Generator +import com.twitter.scrooge.backend.ServiceOption import com.twitter.scrooge.CompilerDefaults -import com.twitter.scrooge.frontend.{ScroogeInternalException, ResolvedDocument} +import com.twitter.scrooge.frontend.ScroogeInternalException +import com.twitter.scrooge.frontend.ResolvedDocument import com.twitter.scrooge.frontend.ParseException -import com.twitter.scrooge.java_generator.{ApacheJavaGenerator, TypeController} +import com.twitter.scrooge.java_generator.ApacheJavaGenerator +import com.twitter.scrooge.java_generator.TypeController import com.twitter.scrooge.mustache.ScalaObjectHandler -import java.io.{FileWriter, File, StringWriter} +import java.io.FileWriter +import java.io.File +import java.io.StringWriter import scala.collection.concurrent.TrieMap import scala.collection.mutable @@ -70,6 +73,8 @@ class AndroidGenerator( t match { case Void => if (inContainer) "Void" else "void" case OnewayVoid => if (inContainer) "Void" else "void" + case at: AnnotatedFieldType => + typeName(at.unwrap, inContainer, inInit, skipGeneric, fileNamespace) case TBool => if (inContainer) "Boolean" else "boolean" case TByte => if (inContainer) "Byte" else "byte" case TI16 => if (inContainer) "Short" else "short" @@ -120,6 +125,7 @@ class AndroidGenerator( def isListOrSetType(t: FunctionType): Boolean = { t match { + case at: AnnotatedFieldType => isListOrSetType(at.unwrap) case ListType(_, _) => true case SetType(_, _) => true case _ => false diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/CocoaGenerator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/CocoaGenerator.scala index 399b626bf..032454eb3 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/CocoaGenerator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/CocoaGenerator.scala @@ -2,10 +2,14 @@ package com.twitter.scrooge.backend import com.twitter.scrooge.ast._ import com.twitter.scrooge.mustache.Dictionary._ -import com.twitter.scrooge.mustache.{Dictionary, HandlebarLoader} -import com.twitter.scrooge.frontend.{ScroogeInternalException, ResolvedDocument} - -import java.io.{OutputStreamWriter, FileOutputStream, File} +import com.twitter.scrooge.mustache.Dictionary +import com.twitter.scrooge.mustache.HandlebarLoader +import com.twitter.scrooge.frontend.ScroogeInternalException +import com.twitter.scrooge.frontend.ResolvedDocument + +import java.io.OutputStreamWriter +import java.io.FileOutputStream +import java.io.File import scala.collection.mutable object CocoaGeneratorFactory extends GeneratorFactory { @@ -96,6 +100,7 @@ class CocoaGenerator( def getDependentTypes(struct: StructLike): Set[FieldType] = { def getDependentTypes(fieldType: FieldType): Set[FieldType] = { fieldType match { + case at: AnnotatedFieldType => getDependentTypes(at.unwrap) case t: ListType => getDependentTypes(t.eltType) case t: MapType => getDependentTypes(t.keyType) ++ getDependentTypes(t.valueType) case t: SetType => getDependentTypes(t.eltType) @@ -288,7 +293,9 @@ class CocoaGenerator( def toMutable(f: Field): (String, String) = ("", "") def genType(t: FunctionType, immutable: Boolean = false): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case Void => "void" case OnewayVoid => "void" case TBool => "BOOL" @@ -306,7 +313,7 @@ class CocoaGenerator( case r: ReferenceType => throw new ScroogeInternalException("ReferenceType should not appear in backend") } - v(code) + v(getCode(t)) } def genFieldType(f: Field): CodeFragment = { diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala index d3f0ce799..a95f4610d 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala @@ -104,20 +104,25 @@ abstract class Generator(doc: ResolvedDocument) { fieldType: FieldType, fieldAnnotations: scala.collection.immutable.Map[String, String] ): Unit = { - val fieldTypes: Set[Class[_]] = fieldType match { - case TBool => Set(classOf[java.lang.Boolean], classOf[Boolean]) - case TByte => Set(classOf[java.lang.Byte], classOf[Byte]) - case TI16 => Set(classOf[java.lang.Short], classOf[Short]) - case TI32 => Set(classOf[java.lang.Integer], classOf[Int]) - case TI64 => Set(classOf[java.lang.Long], classOf[Long]) - case TDouble => Set(classOf[java.lang.Double], classOf[Double]) - case TString => Set(classOf[String]) - case TBinary => Set(classOf[ByteBuffer]) - case MapType(_, _, _) => Set(classOf[Map[_, _]]) - case SetType(_, _) => Set(classOf[Set[_]]) - case ListType(_, _) => Set(classOf[Seq[_]]) - case _ => Set.empty + @scala.annotation.tailrec + def extractFieldTypes(fieldType: FieldType): Set[Class[_]] = { + fieldType match { + case at: AnnotatedFieldType => extractFieldTypes(at.unwrap) + case TBool => Set(classOf[java.lang.Boolean], classOf[Boolean]) + case TByte => Set(classOf[java.lang.Byte], classOf[Byte]) + case TI16 => Set(classOf[java.lang.Short], classOf[Short]) + case TI32 => Set(classOf[java.lang.Integer], classOf[Int]) + case TI64 => Set(classOf[java.lang.Long], classOf[Long]) + case TDouble => Set(classOf[java.lang.Double], classOf[Double]) + case TString => Set(classOf[String]) + case TBinary => Set(classOf[ByteBuffer]) + case MapType(_, _, _) => Set(classOf[Map[_, _]]) + case SetType(_, _) => Set(classOf[Set[_]]) + case ListType(_, _) => Set(classOf[Seq[_]]) + case _ => Set.empty + } } + val fieldTypes: Set[Class[_]] = extractFieldTypes(fieldType) val violations = AnnotationValidator.validateAnnotations(fieldTypes, fieldAnnotations) if (violations.nonEmpty) throw new IllegalArgumentException(s"The annotation is invalid: ${violations.mkString(", ")}") @@ -287,6 +292,7 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) def isPrimitive(t: FunctionType): Boolean = { t match { + case at: AnnotatedFieldType => isPrimitive(at.unwrap) case Void | TBool | TByte | TI16 | TI32 | TI64 | TDouble => true case _ => false } @@ -294,6 +300,7 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) def isLazyReadEnabled(t: FunctionType, optional: Boolean): Boolean = { t match { + case at: AnnotatedFieldType => isLazyReadEnabled(at.unwrap, optional) case TString => true case Void | TBool | TByte | TI16 | TI32 | TI64 | TDouble => optional case _ => false @@ -358,13 +365,15 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) * The default value for the specified type and mutability. */ def genDefaultValue(fieldType: FieldType): CodeFragment = { - val code = fieldType match { + @scala.annotation.tailrec + def getCode(fieldType: FieldType): String = fieldType match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "false" case TByte | TI16 | TI32 => "0" case TDouble => "0.0" case _ => "null" } - v(code) + v(getCode(fieldType)) } /** @@ -373,7 +382,9 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) * For String, null is not valid value for required field. */ def genUnsafeEmptyValue(fieldType: FieldType): CodeFragment = { - val code = fieldType match { + @scala.annotation.tailrec + def getCode(fieldType: FieldType): String = fieldType match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "false" case TByte | TI16 | TI32 => "0" case TDouble => "0.0" @@ -381,7 +392,7 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case TString => "\"empty\"" case _ => "null" } - v(code) + v(getCode(fieldType)) } def genDefaultFieldValueForFieldInfo(f: Field): Option[CodeFragment] = { @@ -428,7 +439,9 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) genDefaultFieldValue(f).getOrElse(genDefaultValue(f.fieldType)) def genConstType(t: FunctionType): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case Void => "VOID" case TBool => "BOOL" case TByte => "BYTE" @@ -445,7 +458,7 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case ListType(_, _) => "LIST" case x => throw new InternalError("constType#" + t) } - v(code) + v(getCode(t)) } /** @@ -460,7 +473,9 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) } def genProtocolReadMethod(t: FunctionType): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "readBool" case TByte => "readByte" case TI16 => "readI16" @@ -471,11 +486,13 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case TBinary => "readBinary" case x => throw new ScroogeInternalException("genProtocolReadMethod#" + t) } - v(code) + v(getCode(t)) } def genProtocolSkipMethod(t: FunctionType): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "offsetSkipBool" case TByte => "offsetSkipBool" case TI16 => "offsetSkipI16" @@ -486,11 +503,13 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case TBinary => "offsetSkipBinary" case x => throw new ScroogeInternalException("genProtocolSkipMethod#" + t) } - v(code) + v(getCode(t)) } def genOffsetSkipProtocolMethod(t: FunctionType): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "offsetSkipBool" case TByte => "offsetSkipByte" case TI16 => "offsetSkipI16" @@ -502,11 +521,13 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case x => s"""Invalid type passed($x) for genOffsetSkipProtocolMethod method. Compile will fail here.""" } - v(code) + v(getCode(t)) } def genDecodeProtocolMethod(t: FunctionType): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "decodeBool" case TByte => "decodeByte" case TI16 => "decodeI16" @@ -518,11 +539,13 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case x => s"""Invalid type passed ($x) for genDecodeProtocolMethod method. Compile will fail here.""" } - v(code) + v(getCode(t)) } def genProtocolWriteMethod(t: FunctionType): CodeFragment = { - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TBool => "writeBool" case TByte => "writeByte" case TI16 => "writeI16" @@ -533,7 +556,7 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) case TBinary => "writeBinary" case x => throw new ScroogeInternalException("protocolWriteMethod#" + t) } - v(code) + v(getCode(t)) } def genType(t: FunctionType, immutable: Boolean = false): CodeFragment diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/ScalaGenerator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/ScalaGenerator.scala index 34d137304..582dd7a2a 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/ScalaGenerator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/ScalaGenerator.scala @@ -193,7 +193,9 @@ class ScalaGenerator( } override def genDefaultValue(fieldType: FieldType): CodeFragment = { - val code = fieldType match { + @scala.annotation.tailrec + def getCode(fieldType: FieldType): String = fieldType match { + case at: AnnotatedFieldType => getCode(at.unwrap) case TI64 => "0L" case ListType(_, _) => "_root_.scala.collection.immutable.Nil" @@ -203,7 +205,7 @@ class ScalaGenerator( "_root_.scala.collection.immutable.Set.empty" case _ => super.genDefaultValue(fieldType).toData } - v(code) + v(getCode(fieldType)) } override def genConstant(constant: RHS, fieldType: Option[FieldType] = None): CodeFragment = { @@ -215,7 +217,9 @@ class ScalaGenerator( def genType(t: FunctionType, immutable: Boolean = false): CodeFragment = { val prefix = if (immutable) "_root_.scala.collection.immutable." else "_root_.scala.collection." - val code = t match { + @scala.annotation.tailrec + def getCode(t: FunctionType): String = t match { + case at: AnnotatedFieldType => getCode(at.unwrap) case Void => "Unit" case OnewayVoid => "Unit" case TBool => "Boolean" @@ -243,7 +247,7 @@ class ScalaGenerator( case r: ReferenceType => throw new ScroogeInternalException("ReferenceType should not appear in backend") } - v(code) + v(getCode(t)) } def genPrimitiveType(t: FunctionType): CodeFragment = genType(t) diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala index 19cdd739f..c64406f20 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala @@ -45,6 +45,7 @@ trait StructTemplate { self: TemplateGenerator => def readWriteInfo[T <: FieldType](sid: SimpleID, t: FieldType): Dictionary = { t match { + case at: AnnotatedFieldType => readWriteInfo(sid, at.unwrap) case t: ListType => val elt = sid.append("_element") TypeTemplate + Dictionary( @@ -154,8 +155,10 @@ trait StructTemplate { self: TemplateGenerator => * For a set>: * "readSet($protoName, proto => { readList(proto, TProtocols.readI32Fn) })" */ + @scala.annotation.tailrec private[this] def genReadValue(fieldType: FieldType, protoName: String): CodeFragment = { fieldType match { + case at: AnnotatedFieldType => genReadValue(at.unwrap, protoName) case TBool => v(s"$protoName.readBool()") case TByte => v(s"$protoName.readByte()") case TI16 => v(s"$protoName.readI16()") @@ -214,8 +217,10 @@ trait StructTemplate { self: TemplateGenerator => * For a set>: * "proto => { readSet(proto => { readList(protocol, TProtocols.readStringFn) }}}" */ + @scala.annotation.tailrec private[this] def genReadValueFn1(fieldType: FieldType): CodeFragment = { fieldType match { + case at: AnnotatedFieldType => genReadValueFn1(at.unwrap) case TBool => v("_root_.com.twitter.scrooge.internal.TProtocols.readBoolFn") case TByte => v("_root_.com.twitter.scrooge.internal.TProtocols.readByteFn") case TI16 => v("_root_.com.twitter.scrooge.internal.TProtocols.readI16Fn") @@ -231,8 +236,10 @@ trait StructTemplate { self: TemplateGenerator => } } + @scala.annotation.tailrec private[this] def genWriteValueFn2(fieldType: FieldType): CodeFragment = { fieldType match { + case at: AnnotatedFieldType => genWriteValueFn2(at.unwrap) case TBool => v("_root_.com.twitter.scrooge.internal.TProtocols.writeBoolFn") case TByte => @@ -266,12 +273,14 @@ trait StructTemplate { self: TemplateGenerator => } } + @scala.annotation.tailrec private[this] def genWriteValue( fieldName: CodeFragment, fieldType: FieldType, protoName: String ): CodeFragment = { fieldType match { + case at: AnnotatedFieldType => genWriteValue(fieldName, at.unwrap, protoName) case TBool => v(s"$protoName.writeBool($fieldName)") case TByte => @@ -548,6 +557,7 @@ trait StructTemplate { self: TemplateGenerator => */ private def canCallWithoutPassthroughFields(fieldType: FieldType): Boolean = { fieldType match { + case at: AnnotatedFieldType => canCallWithoutPassthroughFields(at.unwrap) case t if isPrimitive(t) => true case TBinary | TString => diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/lua/LuaGenerator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/lua/LuaGenerator.scala index edaf194b8..27d1ba335 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/lua/LuaGenerator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/lua/LuaGenerator.scala @@ -1,9 +1,13 @@ package com.twitter.scrooge.backend.lua import com.twitter.scrooge.ast._ -import com.twitter.scrooge.backend.{Generator, GeneratorFactory, ServiceOption, TemplateGenerator} +import com.twitter.scrooge.backend.Generator +import com.twitter.scrooge.backend.GeneratorFactory +import com.twitter.scrooge.backend.ServiceOption +import com.twitter.scrooge.backend.TemplateGenerator import com.twitter.scrooge.frontend.ResolvedDocument -import com.twitter.scrooge.mustache.Dictionary.{CodeFragment, v} +import com.twitter.scrooge.mustache.Dictionary.CodeFragment +import com.twitter.scrooge.mustache.Dictionary.v import com.twitter.scrooge.mustache.HandlebarLoader import java.io.File import com.twitter.scrooge.mustache.Dictionary @@ -105,6 +109,7 @@ class LuaGenerator( } def genType(t: FunctionType, immutable: Boolean = false): CodeFragment = t match { + case at: AnnotatedFieldType => genType(at.unwrap, immutable) case bt: BaseType => v(s"ttype = '${genPrimitiveType(bt)}'") case StructType(st, _) => v(s"ttype = 'struct', fields = ${genID(st.sid.toTitleCase)}.fields") case EnumType(et, _) => v(s"ttype = 'enum', value = ${genID(et.sid.toTitleCase)}") @@ -173,6 +178,7 @@ class LuaGenerator( excludeSelfType: SimpleID ): Seq[NamedType] = { ft match { + case at: AnnotatedFieldType => findRequireableStructTypes(at.unwrap, excludeSelfType) case t: NamedType if (excludeSelfType == t.sid) => Nil case t: StructType => Seq(t) case t: EnumType => Seq(t) diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/ThriftParser.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/ThriftParser.scala index 57860e9b2..e5f6a10d9 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/ThriftParser.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/ThriftParser.scala @@ -177,24 +177,25 @@ class ThriftParser( lazy val containerType: Parser[ContainerType] = mapType | setType | listType - // Annotations within container types are parsed correctly, but currently thrown away lazy val mapType: Parser[MapType] = ("map" ~> opt(cppType) <~ "<") ~ - (fieldType <~ defaultedAnnotations <~ ",") ~ - (fieldType <~ defaultedAnnotations <~ ">") ^^ { - case cpp ~ key ~ value => MapType(key, value, cpp) + (fieldType ~ defaultedAnnotations <~ ",") ~ + (fieldType ~ defaultedAnnotations <~ ">") ^^ { + case cpp ~ (key ~ keyAnnotations) ~ (value ~ valueAnnotations) => + MapType( + AnnotatedFieldType.wrap(key, keyAnnotations), + AnnotatedFieldType.wrap(value, valueAnnotations), + cpp) } - // Annotations within container types are parsed correctly, but currently thrown away lazy val setType: Parser[SetType] = - ("set" ~> opt(cppType)) ~ ("<" ~> fieldType <~ defaultedAnnotations <~ ">") ^^ { - case cpp ~ t => SetType(t, cpp) + ("set" ~> opt(cppType)) ~ ("<" ~> fieldType ~ defaultedAnnotations <~ ">") ^^ { + case cpp ~ (t ~ annotations) => SetType(AnnotatedFieldType.wrap(t, annotations), cpp) } - // Annotations within container types are parsed correctly, but currently thrown away lazy val listType: Parser[ListType] = - ("list" ~ "<") ~> (fieldType <~ defaultedAnnotations <~ ">") ~ opt(cppType) ^^ { - case t ~ cpp => ListType(t, cpp) + ("list" ~ "<") ~> (fieldType ~ defaultedAnnotations <~ ">") ~ opt(cppType) ^^ { + case (t ~ annotations) ~ cpp => ListType(AnnotatedFieldType.wrap(t, annotations), cpp) } // FFS. i'm very close to removing this and forcably breaking old thrift files. diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/TypeResolver.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/TypeResolver.scala index d8cd917f7..4aa4f74a0 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/TypeResolver.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/frontend/TypeResolver.scala @@ -329,26 +329,21 @@ case class TypeResolver( case m @ MapType(k, v, _) => m.copy(keyType = apply(k), valueType = apply(v)) case s @ SetType(e, _) => s.copy(eltType = apply(e)) case l @ ListType(e, _) => l.copy(eltType = apply(e)) + case at: AnnotatedFieldType => at.copy(underlying = apply(at.unwrap)) case b: BaseType => b case e: EnumType => e case s: StructType => s } - def apply(c: RHS, fieldType: FieldType): RHS = c match { - // list values and map values look the same in Thrift, but different in Java and Scala - // So we need type information in order to generated correct code. - case l @ ListRHS(elems) => - fieldType match { - case ListType(eltType, _) => l.copy(elems = elems.map(e => apply(e, eltType))) - case SetType(eltType, _) => SetRHS(elems.map(e => apply(e, eltType)).toSet) - case _ => throw TypeMismatchException("Expecting " + fieldType + ", found " + l, c) - } - case m @ MapRHS(elems) => + def apply(c: RHS, fieldType: FieldType): RHS = { + @scala.annotation.tailrec + def loop(fieldType: FieldType, m: MapRHS): RHS = { fieldType match { + case at: AnnotatedFieldType => loop(at.unwrap, m) case MapType(keyType, valType, _) => - m.copy(elems = elems.map { case (k, v) => (apply(k, keyType), apply(v, valType)) }) + m.copy(elems = m.elems.map { case (k, v) => (apply(k, keyType), apply(v, valType)) }) case st @ StructType(structLike: StructLike, _) => - val fieldMultiMap: Map[String, Seq[(String, RHS)]] = elems + val fieldMultiMap: Map[String, Seq[(String, RHS)]] = m.elems .collect { case (StringLiteral(fieldName), value) => (fieldName, value) } @@ -401,29 +396,43 @@ case class TypeResolver( } case _ => throw TypeMismatchException("Expecting " + fieldType + ", found " + m, m) } - case i @ IdRHS(id) => - val (constFieldType, constRHS) = id match { - case sid: SimpleID => - // When the rhs value is a simpleID, it can only be a constant - // defined in the same file - resolveConst(sid) - case qid @ QualifiedID(names) => - fieldType match { - case EnumType(enum, _) => - val resolvedFieldType = resolveFieldType(qid.qualifier) - val value = enum.values - .find(_.sid.name == names.last) - .getOrElse(throw UndefinedSymbolException(qid.fullName, qid)) - (resolvedFieldType, EnumRHS(enum, value)) - case t => resolveConst(qid) - } - } - if (constFieldType != fieldType) - throw TypeMismatchException( - s"Type mismatch: Expecting $fieldType, found ${id.fullName}: $constFieldType", - id - ) - constRHS - case _ => c + } + + c match { + // list values and map values look the same in Thrift, but different in Java and Scala + // So we need type information in order to generated correct code. + case l @ ListRHS(elems) => + fieldType match { + case ListType(eltType, _) => l.copy(elems = elems.map(e => apply(e, eltType))) + case SetType(eltType, _) => SetRHS(elems.map(e => apply(e, eltType)).toSet) + case _ => throw TypeMismatchException("Expecting " + fieldType + ", found " + l, c) + } + case m @ MapRHS(elems) => + loop(fieldType, m) + case i @ IdRHS(id) => + val (constFieldType, constRHS) = id match { + case sid: SimpleID => + // When the rhs value is a simpleID, it can only be a constant + // defined in the same file + resolveConst(sid) + case qid @ QualifiedID(names) => + fieldType match { + case EnumType(enum, _) => + val resolvedFieldType = resolveFieldType(qid.qualifier) + val value = enum.values + .find(_.sid.name == names.last) + .getOrElse(throw UndefinedSymbolException(qid.fullName, qid)) + (resolvedFieldType, EnumRHS(enum, value)) + case t => resolveConst(qid) + } + } + if (constFieldType != fieldType) + throw TypeMismatchException( + s"Type mismatch: Expecting $fieldType, found ${id.fullName}: $constFieldType", + id + ) + constRHS + case _ => c + } } } diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/ApacheJavaGenerator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/ApacheJavaGenerator.scala index 20a3456ab..e64f463b5 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/ApacheJavaGenerator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/ApacheJavaGenerator.scala @@ -3,12 +3,6 @@ package com.twitter.scrooge.java_generator import com.github.mustachejava.DefaultMustacheFactory import com.github.mustachejava.Mustache import com.twitter.scrooge.mustache.ScalaObjectHandler -import com.twitter.scrooge.ast.EnumType -import com.twitter.scrooge.ast.ListType -import com.twitter.scrooge.ast.MapType -import com.twitter.scrooge.ast.ReferenceType -import com.twitter.scrooge.ast.SetType -import com.twitter.scrooge.ast.StructType import com.twitter.scrooge.ast._ import com.twitter.scrooge.backend.GeneratorFactory import com.twitter.scrooge.backend.ServiceOption @@ -187,6 +181,8 @@ class ApacheJavaGenerator( fileNamespace: Option[Identifier] = None ): String = { t match { + case at: AnnotatedFieldType => + typeName(at.unwrap, inContainer, inInit, skipGeneric, fileNamespace) case Void => if (inContainer) "Void" else "void" case OnewayVoid => if (inContainer) "Void" else "void" case TBool => if (inContainer) "Boolean" else "boolean" @@ -226,6 +222,7 @@ class ApacheJavaGenerator( def initField(fieldType: FunctionType, inContainer: Boolean = false): String = { fieldType match { + case at: AnnotatedFieldType => initField(at.unwrap, inContainer) case SetType(eltType: EnumType, _) => s"EnumSet.noneOf(${typeName(eltType)}.class)" case _ => @@ -241,6 +238,7 @@ class ApacheJavaGenerator( def getTypeString(fieldType: FunctionType): String = { fieldType match { + case at: AnnotatedFieldType => getTypeString(at.unwrap) case TString => "TType.STRING" case TBool => "TType.BOOL" case TByte => "TType.BYTE" diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldTypeController.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldTypeController.scala index 0b221eebe..a1186907e 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldTypeController.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldTypeController.scala @@ -16,57 +16,109 @@ class FieldTypeController(fieldType: FunctionType, generator: ApacheJavaGenerato val init_type_name: String = generator.typeName(fieldType, inInit = true) def is_enum_set: Boolean = fieldType match { case SetType(_: EnumType, _) => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_enum_set case _ => false } def init_field: String = generator.initField(fieldType) def init_container_field_prelude: String = generator.initContainerFieldPrelude(fieldType) val nullable: Boolean = generator.isNullableType(fieldType) - val double: Boolean = fieldType == TDouble - val boolean: Boolean = fieldType == TBool + val double: Boolean = fieldType match { + case TDouble => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).double + case _ => false + } + val boolean: Boolean = fieldType match { + case TBool => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).boolean + case _ => false + } val is_container: Boolean = fieldType match { case _: MapType | _: SetType | _: ListType => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_container case _ => false } val is_map_or_set: Boolean = fieldType match { case _: MapType | _: SetType => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_map_or_set case _ => false } def is_preallocatable: Boolean = is_container && !is_enum_set - val is_enum: Boolean = fieldType.isInstanceOf[EnumType] + val is_enum: Boolean = fieldType match { + case _: EnumType => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_enum + case _ => false + } // is the field value effectively an i32 - val base_int_type: Boolean = fieldType != TDouble && fieldType != TBool + val base_int_type: Boolean = !double && !boolean val is_list_or_set: Any = fieldType match { case SetType(x, _) => Map("elem_type" -> new FieldTypeController(x, generator)) case ListType(x, _) => Map("elem_type" -> new FieldTypeController(x, generator)) + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_list_or_set + case _ => false + } + val is_list: Boolean = fieldType match { + case _: ListType => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_list + case _ => false + } + val is_map: Boolean = fieldType match { + case _: MapType => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_map case _ => false } - val is_list: Boolean = fieldType.isInstanceOf[ListType] - val is_map: Boolean = fieldType.isInstanceOf[MapType] def map_types: Any = fieldType match { case MapType(k, v, _) => Map( "key_type" -> new FieldTypeController(k, generator), "value_type" -> new FieldTypeController(v, generator) ) + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).map_types + case _ => false + } + val is_binary: Boolean = fieldType match { + case TBinary => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_binary + case _ => false + } + val is_typedef: Boolean = fieldType match { + case _: Typedef => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_typedef case _ => false } - val is_binary: Boolean = fieldType == TBinary - val is_typedef: Boolean = fieldType.isInstanceOf[Typedef] val is_base_type: Boolean = fieldType match { case Void | TString | TBool | TByte | TI16 | TI32 | TI64 | TDouble => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_base_type + case _ => false + } + val is_base_type_or_enum: Boolean = is_base_type || (fieldType match { + case _: EnumType => true + case at: AnnotatedFieldType => + new FieldTypeController(at.unwrap, generator).is_base_type_or_enum + case _ => false + }) + val is_base_type_or_binary: Boolean = is_base_type || is_binary + val is_base_type_not_string: Boolean = is_base_type && (fieldType match { + case TString => false + case at: AnnotatedFieldType => + new FieldTypeController(at.unwrap, generator).is_base_type_not_string + case _ => true + }) + val is_struct: Boolean = fieldType match { + case _: StructType => true // this can be a struct or an exception + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_struct + case _ => false + } + val is_struct_or_enum: Boolean = is_struct || is_enum + val is_void: Boolean = fieldType match { + case Void | OnewayVoid => true + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).is_void case _ => false } - val is_base_type_or_enum: Boolean = is_base_type || fieldType.isInstanceOf[EnumType] - val is_base_type_or_binary: Boolean = is_base_type || fieldType == TBinary - val is_base_type_not_string: Boolean = is_base_type && fieldType != TString - val is_struct: Boolean = - fieldType.isInstanceOf[StructType] // this can be a struct or an exception - val is_struct_or_enum: Boolean = is_struct || fieldType.isInstanceOf[EnumType] - val is_void: Boolean = fieldType == Void || fieldType == OnewayVoid val has_struct_at_leaf: Boolean = hasStructAtLeaf(fieldType) def get_type: String = { fieldType match { + case at: AnnotatedFieldType => new FieldTypeController(at.unwrap, generator).get_type case TString => "String" case TBool => "Bool" case TByte => "Byte" @@ -93,6 +145,7 @@ class FieldTypeController(fieldType: FunctionType, generator: ApacheJavaGenerato functionType match { case fieldType: FieldType => fieldType match { + case at: AnnotatedFieldType => hasStructAtLeaf(at.unwrap) case StructType(_, _) => true case MapType(keyType, valueType, _) => hasStructAtLeaf(keyType) || hasStructAtLeaf(valueType) diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldValueMetadataController.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldValueMetadataController.scala index 2e978d8cc..8fbc449e0 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldValueMetadataController.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/FieldValueMetadataController.scala @@ -36,6 +36,10 @@ class FieldValueMetadataController( } def generateMetadata(k: FieldType): String = { - indent(generator.fieldValueMetaData(k, ns), 4, skipFirst = true, addLast = false) + k match { + case at: AnnotatedFieldType => generateMetadata(at.unwrap) + case otherwise => + indent(generator.fieldValueMetaData(otherwise, ns), 4, skipFirst = true, addLast = false) + } } } diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/PrintConstController.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/PrintConstController.scala index c39fb06d3..173e16e1b 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/PrintConstController.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/java_generator/PrintConstController.scala @@ -85,6 +85,7 @@ class PrintConstController( private[twitter] def renderConstValue(constant: RHS, fieldType: FieldType): ConstValue = { fieldType match { + case at: AnnotatedFieldType => renderConstValue(constant, at.unwrap) case TString => { val constValue = constant.asInstanceOf[StringLiteral].value new ConstValue(null, "\"" + constValue + "\"") diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftGenerator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftGenerator.scala index 9e86089d8..79752c5e0 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftGenerator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftGenerator.scala @@ -2,10 +2,6 @@ package com.twitter.scrooge.swift_generator import com.github.mustachejava.DefaultMustacheFactory import com.github.mustachejava.Mustache -import com.twitter.scrooge.ast.ListType -import com.twitter.scrooge.ast.MapType -import com.twitter.scrooge.ast.ReferenceType -import com.twitter.scrooge.ast.SetType import com.twitter.scrooge.ast._ import com.twitter.scrooge.backend.Generator import com.twitter.scrooge.backend.GeneratorFactory @@ -117,6 +113,8 @@ class SwiftGenerator( fileNamespace: Option[Identifier] = None ): String = { t match { + case at: AnnotatedFieldType => + typeName(at.unwrap, inContainer, inInit, skipGeneric, fileNamespace) case Void => "Void" case OnewayVoid => "Void" case TBool => "Bool" @@ -143,6 +141,7 @@ class SwiftGenerator( def leftElementTypeName(t: FunctionType, skipGeneric: Boolean = false): String = { t match { + case at: AnnotatedFieldType => leftElementTypeName(at.unwrap, skipGeneric) case MapType(k, v, _) => typeName(k, inContainer = true, skipGeneric = skipGeneric) case SetType(x, _) => typeName(x, inContainer = true, skipGeneric = skipGeneric) case ListType(x, _) => typeName(x, inContainer = true, skipGeneric = skipGeneric) @@ -152,6 +151,7 @@ class SwiftGenerator( def rightElementTypeName(t: FunctionType, skipGeneric: Boolean = false): String = { t match { + case at: AnnotatedFieldType => rightElementTypeName(at.unwrap, skipGeneric) case MapType(k, v, _) => typeName(v, inContainer = true, skipGeneric = skipGeneric) case _ => "" } @@ -159,6 +159,7 @@ class SwiftGenerator( def isListOrSetType(t: FunctionType): Boolean = { t match { + case at: AnnotatedFieldType => isListOrSetType(at.unwrap) case ListType(_, _) => true case SetType(_, _) => true case _ => false diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftPrintConstController.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftPrintConstController.scala index 8e80fda12..5a8941abf 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftPrintConstController.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/swift_generator/SwiftPrintConstController.scala @@ -1,25 +1,6 @@ package com.twitter.scrooge.swift_generator -import com.twitter.scrooge.ast.EnumType -import com.twitter.scrooge.ast.FieldType -import com.twitter.scrooge.ast.Identifier -import com.twitter.scrooge.ast.IntLiteral -import com.twitter.scrooge.ast.ListRHS -import com.twitter.scrooge.ast.ListType -import com.twitter.scrooge.ast.MapRHS -import com.twitter.scrooge.ast.MapType -import com.twitter.scrooge.ast.RHS -import com.twitter.scrooge.ast.SetRHS -import com.twitter.scrooge.ast.SetType -import com.twitter.scrooge.ast.StructRHS -import com.twitter.scrooge.ast.StructType -import com.twitter.scrooge.ast.TBool -import com.twitter.scrooge.ast.TByte -import com.twitter.scrooge.ast.TDouble -import com.twitter.scrooge.ast.TI16 -import com.twitter.scrooge.ast.TI32 -import com.twitter.scrooge.ast.TI64 -import com.twitter.scrooge.ast.TString +import com.twitter.scrooge.ast._ import com.twitter.scrooge.frontend.ScroogeInternalException import com.twitter.scrooge.java_generator.ConstValue import com.twitter.scrooge.java_generator.PrintConstController @@ -101,6 +82,7 @@ class SwiftPrintConstController( override def renderConstValue(constant: RHS, fieldType: FieldType): ConstValue = { fieldType match { + case at: AnnotatedFieldType => renderConstValue(constant, at.unwrap) case TByte | TI16 | TI32 | TI64 => new ConstValue( null,