Skip to content

Commit

Permalink
scrooge: Support union constants
Browse files Browse the repository at this point in the history
Problem

Scrooge fails to compile thrift with a const value of union type, e.g.

```
union U {
  1: i32 a,
  2: string b
}

const U u = { "a": 3 }
```

Solution

Add separate handling for unions.

RB_ID=718749
  • Loading branch information
nshkrob committed Jul 27, 2015
1 parent 2f98399 commit 8554aa1
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 30 deletions.
Expand Up @@ -12,5 +12,6 @@ case class ListRHS(elems: Seq[RHS]) extends RHS
case class SetRHS(elems: Set[RHS]) extends RHS
case class MapRHS(elems: Seq[(RHS, RHS)]) extends RHS
case class StructRHS(sid: SimpleID, elems: Map[Field, RHS]) extends RHS
case class UnionRHS(sid: SimpleID, field: Field, initializer: RHS) extends RHS
case class EnumRHS(enum: Enum, value: EnumField) extends RHS
case class IdRHS(id: Identifier) extends RHS
Expand Up @@ -126,6 +126,8 @@ class CocoaGenerator(
throw new Exception("not implemented")
def genStruct(struct: StructRHS): CodeFragment =
throw new Exception("not implemented")
def genUnion(struct: UnionRHS): CodeFragment =
throw new Exception("not implemented")

// For mutability/immutability support, not implemented
def genToImmutable(t: FieldType): CodeFragment = codify("")
Expand Down
Expand Up @@ -170,6 +170,7 @@ trait Generator extends ThriftGenerator {
case c: EnumRHS => genEnum(c, fieldType)
case iv@IdRHS(id) => genID(id)
case s: StructRHS => genStruct(s)
case u: UnionRHS => genUnion(u)
}
}

Expand All @@ -183,6 +184,8 @@ trait Generator extends ThriftGenerator {

def genStruct(struct: StructRHS): CodeFragment

def genUnion(union: UnionRHS): CodeFragment

/**
* The default value for the specified type and mutability.
*/
Expand Down
Expand Up @@ -151,8 +151,9 @@ class JavaGenerator(
}

// TODO
def genStruct(struct: StructRHS): CodeFragment =
throw new Exception("not implemented")
def genStruct(struct: StructRHS): CodeFragment = ???

def genUnion(union: UnionRHS): CodeFragment = ???

override def genDefaultValue(fieldType: FieldType): CodeFragment = {
val code = fieldType match {
Expand Down
Expand Up @@ -172,6 +172,12 @@ class ScalaGenerator(
codify(genID(struct.sid) + "(" + fields.mkString(", ") + ")")
}

def genUnion(union: UnionRHS): CodeFragment = {
val fieldId = genID(union.field.sid.toTitleCase)
val rhs = genConstant(union.initializer)
codify(s"${genID(union.sid)}.$fieldId($rhs)")
}

override def genDefaultValue(fieldType: FieldType): CodeFragment = {
val code = fieldType match {
case TI64 => "0L"
Expand Down
Expand Up @@ -201,23 +201,47 @@ case class TypeResolver(
fieldType match {
case MapType(keyType, valType, _) =>
m.copy(elems = elems.map { case (k, v) => (apply(k, keyType), apply(v, valType)) })
case st @ StructType(s, _) =>
val structMap = Map.newBuilder[Field, RHS]
s.fields.foreach { f =>
val filtered = elems.filter {
case (StringLiteral(fieldName), _) => fieldName == f.sid.name
case _ => false
}
if (filtered.size == 1) {
val (k, v) = filtered.head
structMap += f -> apply(v, f.fieldType)
} else if (filtered.size > 1) {
throw new TypeMismatchException(s"Duplicate default values for ${f.sid.name} found for $fieldType", m)
} else if (!f.requiredness.isOptional && f.default.isEmpty) {
throw new TypeMismatchException(s"Value required for ${f.sid.name} in $fieldType", m)
}
case st @ StructType(structLike: StructLike, _) =>
val fieldMultiMap: Map[String, Seq[(String, RHS)]] = elems.collect {
case (StringLiteral(fieldName), value) => (fieldName, value)
}.groupBy { case (fieldName, _) => fieldName }

val fieldMap: Map[String, RHS] = fieldMultiMap.collect {
case (fieldName: String, values: Seq[(String, RHS)]) if values.length == 1 =>
values.head
case (fieldName: String, _: Seq[(String, RHS)]) =>
throw new TypeMismatchException(s"Duplicate default values for ${fieldName} found for $fieldType", m)
// Can't have 0 elements here because fieldMultiMap is built by groupBy.
}

structLike match {
case u: Union =>
val definedFields = u.fields.collect {
case field if fieldMap.contains(field.sid.name) =>
(field, fieldMap(field.sid.name))
}
if (definedFields.length == 0)
throw new UndefinedConstantException(s"Constant value missing for union ${u.originalName}", m)
if (definedFields.length > 1)
throw new UndefinedConstantException(s"Multiple constant values for union ${u.originalName}", m)

val (field, rhs) = definedFields.head
val resolvedRhs = apply(rhs, field.fieldType)
UnionRHS(sid = st.sid, field = field, initializer = resolvedRhs)

case struct: StructLike =>
val structMap = Map.newBuilder[Field, RHS]
struct.fields.foreach { field =>
val fieldName = field.sid.name
if (fieldMap.contains(fieldName)) {
val resolvedRhs = apply(fieldMap(fieldName), field.fieldType)
structMap += field -> resolvedRhs
} else if (!field.requiredness.isOptional && field.default.isEmpty) {
throw new TypeMismatchException(s"Value required for ${fieldName} in $fieldType", m)
}
}
StructRHS(sid = st.sid, elems = structMap.result())
}
StructRHS(sid = st.sid, elems = structMap.result())
case _ => throw new TypeMismatchException("Expecting " + fieldType + ", found " + m, m)
}
case i @ IdRHS(id) => {
Expand Down
Expand Up @@ -49,15 +49,27 @@ class PrintConstController(
}
}

def struct_values = {
val values = value.asInstanceOf[StructRHS].elems
val structType = fieldType.asInstanceOf[StructType]
for {
f <- structType.struct.fields
v <- values.get(f)
} yield {
val renderedValue = renderConstValue(v, f.fieldType)
Map("key" -> f.sid.name, "value" -> renderedValue.value, "rendered_value" -> renderedValue.rendered)
def struct_values: Seq[Map[String, String]] = {
value match {
case struct: StructRHS =>
val values = value.asInstanceOf[StructRHS].elems
val structType = fieldType.asInstanceOf[StructType]
for {
f <- structType.struct.fields
v <- values.get(f)
} yield {
val renderedValue = renderConstValue(v, f.fieldType)
Map(
"key" -> f.sid.name,
"value" -> renderedValue.value,
"rendered_value" -> renderedValue.rendered)
}
case union: UnionRHS =>
val renderedValue = renderConstValue(union.initializer, union.field.fieldType)
Seq(Map(
"key" -> union.field.sid.name,
"value" -> renderedValue.value,
"rendered_value" -> renderedValue.rendered))
}
}

Expand Down
8 changes: 6 additions & 2 deletions scrooge-generator/src/test/scala/BUILD
@@ -1,4 +1,4 @@
scala_library(name='scala',
junit_tests(name='scala',
dependencies=[
'3rdparty/jvm/org/jmock',
'3rdparty/jvm/junit',
Expand All @@ -11,7 +11,11 @@ scala_library(name='scala',
'scrooge/scrooge-runtime',
'util/util-core',
],
sources=rglobs('*.scala'),
sources=globs('com/twitter/scrooge/ASTSpec.scala',
'com/twitter/scrooge/TReusableMemoryTransportSpec.scala',
'com/twitter/scrooge/frontend/*.scala',
'com/twitter/scrooge/mustache/*.scala',
'com/twitter/scrooge/testutil/*.scala'),
resources=[
'scrooge/scrooge-generator/src/test/resources:resources'
],
Expand Down
Expand Up @@ -4,7 +4,7 @@ import com.twitter.scrooge.ast._
import com.twitter.scrooge.testutil.Spec

class TypeResolverSpec extends Spec {
"TypeResolve" should {
"TypeResolver" should {
val foo = EnumField(SimpleID("FOO"), 1, None)
val bar = EnumField(SimpleID("BAR"), 2, Some("/** I am a doc. */"))
val enum = Enum(SimpleID("SomeEnum"), Seq(foo, bar), None)
Expand Down Expand Up @@ -252,6 +252,47 @@ class TypeResolverSpec extends Spec {
}
}

"initialize union constants" in {
val input =
"""union U {
| 1: i32 a,
| 2: string b
|}
|
|const U c = { "a": 3 }
""".stripMargin

resolve(input)
}

"require union initializers" in {
val input =
"""union U {
| 1: i32 a,
| 2: string b
|}
|
|const U c = { }
""".stripMargin
val ex = intercept[UndefinedConstantException] {
resolve(input)
}
}

"fail for union initializers with multiple fields." in {
val input =
"""union U {
| 1: i32 a,
| 2: string b
|}
|
|const U c = { "a": 3, "b": "b" }
""".stripMargin
val ex = intercept[UndefinedConstantException] {
resolve(input)
}
}

def resolve(input: String): ResolvedDocument = {
val parser = new ThriftParser(Importer("."), strict = true)

Expand Down
21 changes: 21 additions & 0 deletions scrooge-generator/src/test/thrift/defaults/rhs_structs.thrift
Expand Up @@ -47,3 +47,24 @@ const list<StructB> ListOfComplexStructs = [
{"snake_case_field": 1, "camelCaseField": 2, "required_field": 3, "struct_field": {"id": 1}, "default_field": 1},
{"snake_case_field": 1, "camelCaseField": 2, "required_field": 3, "struct_field": {"id": 1}}
]

union SimpleUnion {
1: i32 a,
2: string b
}

const SimpleUnion sss = { "a": 3 }

union ComplexUnion {
1: StructA a,
2: StructB b
}

const ComplexUnion ccc = {
"b": {
"snake_case_field": 1,
"camelCaseField": 2,
"required_field": 3,
"struct_field": {"id": 1}
}
}

0 comments on commit 8554aa1

Please sign in to comment.