Skip to content

Commit

Permalink
scrooge: Add camel case aliases for union field names
Browse files Browse the repository at this point in the history
Problem

Union field names in generated Scala code are not camel cased when
they have underscores in them.

Solution

Generate camel cased aliases as necessary. The previous names are left
intact in order to preserve binary compatibility.

Result

Given a union such as:

  union TheUnion {
    1: Field1 a_field
    ...

The generated code for Field1 will look something like:

  case class Field1(a_field: Field1Alias) extends TheUnion {
    /** An alias for `a_field` */
    def aField: Field1Alias = a_field

RB_ID=638688
  • Loading branch information
kevinoliver authored and jenkins committed Apr 20, 2015
1 parent f1f85be commit 12366f4
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 61 deletions.
6 changes: 5 additions & 1 deletion scrooge-generator/src/main/resources/scalagen/union.scala
Expand Up @@ -124,7 +124,6 @@ object {{StructName}} extends ThriftStructCodec3[{{StructName}}] {
{{/fields|,}}
)


override def encode(_item: {{StructName}}, _oprot: TProtocol): Unit = { _item.write(_oprot) }
override def decode(_iprot: TProtocol): {{StructName}} = {{StructName}}Decoder(_iprot, UnknownUnionField(_))

Expand Down Expand Up @@ -177,6 +176,11 @@ object {{StructName}} extends ThriftStructCodec3[{{StructName}}] {
}

case class {{FieldName}}({{fieldName}}: {{FieldName}}Alias{{#hasDefaultValue}} = {{FieldName}}DefaultValue{{/hasDefaultValue}}) extends {{StructName}} {
{{#fieldNameCamelCase}}
/** An alias for `{{fieldName}}` */
def {{fieldNameCamelCase}}: {{FieldName}}Alias = {{fieldName}}
{{/fieldNameCamelCase}}

override def write(_oprot: TProtocol): Unit = {
{{^isPrimitive}}
if ({{fieldName}} == null)
Expand Down
Expand Up @@ -6,22 +6,21 @@ import com.twitter.scrooge.ast.{Enum, Identifier}

trait EnumTemplate { self: TemplateGenerator =>
def enumDict(
namespace: Identifier,
enum: Enum
): Dictionary =
namespace: Identifier,
enum: Enum
): Dictionary =
Dictionary(
"package" -> genID(namespace),
"EnumName" -> genID(enum.sid.toTitleCase),
"docstring" -> codify(enum.docstring.getOrElse("")),
"values" -> v(enum.values map {
value =>
Dictionary(
"valuedocstring" -> codify(value.docstring.getOrElse("")),
"name" -> genID(value.sid),
"originalName" -> codify(value.sid.originalName),
"unquotedNameLowerCase" -> codify(value.sid.fullName.toLowerCase),
"value" -> codify(value.value.toString)
)
"values" -> v(enum.values.map { value =>
Dictionary(
"valuedocstring" -> codify(value.docstring.getOrElse("")),
"name" -> genID(value.sid),
"originalName" -> codify(value.sid.originalName),
"unquotedNameLowerCase" -> codify(value.sid.fullName.toLowerCase),
"value" -> codify(value.value.toString)
)
})
)
}
Expand Up @@ -16,15 +16,13 @@ package com.twitter.scrooge.backend
* limitations under the License.
*/

import java.io.{OutputStreamWriter, FileOutputStream, File}
import scala.collection.mutable
import com.twitter.scrooge.mustache.HandlebarLoader
import com.twitter.finagle.util.LoadService
import com.twitter.scrooge.ast._
import com.twitter.scrooge.mustache.Dictionary
import com.twitter.scrooge.frontend.{ResolvedDocument, ScroogeInternalException}
import com.twitter.scrooge.java_generator.ApacheJavaGeneratorFactory
import scala.collection.JavaConverters._
import com.twitter.scrooge.frontend.{ScroogeInternalException, ResolvedDocument}
import com.twitter.finagle.util.LoadService
import com.twitter.scrooge.mustache.{Dictionary, HandlebarLoader}
import java.io.{File, FileOutputStream, OutputStreamWriter}
import scala.collection.mutable

abstract sealed class ServiceOption

Expand Down Expand Up @@ -158,7 +156,7 @@ trait Generator extends ThriftGenerator {
// methods that convert AST nodes to CodeFragment
def genID(data: Identifier): CodeFragment = data match {
case SimpleID(name, _) => codify(quoteKeyword(name))
case QualifiedID(names) => codify(names.map { quoteKeyword(_) }.mkString("."))
case QualifiedID(names) => codify(names.map(quoteKeyword).mkString("."))
}

def genConstant(constant: RHS, mutable: Boolean = false, fieldType: Option[FieldType] = None): CodeFragment = {
Expand Down Expand Up @@ -338,7 +336,7 @@ trait TemplateGenerator extends Generator
val namespace = getNamespace(_doc)
val packageDir = namespacedFolder(outputPath, namespace.fullName, dryRun)
val includes = doc.headers.collect {
case x@ Include(_, doc) => x
case x@Include(_, _) => x
}

if (doc.consts.nonEmpty) {
Expand Down
Expand Up @@ -65,37 +65,37 @@ class ScalaGenerator(
else
str

def normalizeCase[N <: Node](node: N) = {
def normalizeCase[N <: Node](node: N): N = {
(node match {
case d: Document =>
d.copy(defs = d.defs.map(normalizeCase(_)))
d.copy(defs = d.defs.map(normalizeCase))
case id: Identifier => id.toTitleCase
case e: EnumRHS =>
e.copy(normalizeCase(e.enum), normalizeCase(e.value))
case f: Field =>
f.copy(
sid = f.sid.toCamelCase,
default = f.default.map(normalizeCase(_)))
default = f.default.map(normalizeCase))
case f: Function =>
f.copy(
args = f.args.map(normalizeCase(_)),
throws = f.throws.map(normalizeCase(_)))
args = f.args.map(normalizeCase),
throws = f.throws.map(normalizeCase))
case c: ConstDefinition =>
c.copy(value = normalizeCase(c.value))
case e: Enum =>
e.copy(values = e.values.map(normalizeCase(_)))
e.copy(values = e.values.map(normalizeCase))
case e: EnumField =>
e.copy(sid = e.sid.toTitleCase)
case s: Struct =>
s.copy(fields = s.fields.map(normalizeCase(_)))
s.copy(fields = s.fields.map(normalizeCase))
case f: FunctionArgs =>
f.copy(fields = f.fields.map(normalizeCase(_)))
f.copy(fields = f.fields.map(normalizeCase))
case f: FunctionResult =>
f.copy(fields = f.fields.map(normalizeCase(_)))
f.copy(fields = f.fields.map(normalizeCase))
case e: Exception_ =>
e.copy(fields = e.fields.map(normalizeCase(_)))
e.copy(fields = e.fields.map(normalizeCase))
case s: Service =>
s.copy(functions = s.functions.map(normalizeCase(_)))
s.copy(functions = s.functions.map(normalizeCase))
case n => n
}).asInstanceOf[N]
}
Expand Down Expand Up @@ -148,7 +148,7 @@ class ScalaGenerator(

def genStruct(struct: StructRHS): CodeFragment = {
val values = struct.elems
val fields = values map { case (f, value) =>
val fields = values.map { case (f, value) =>
val v = genConstant(value)
genID(f.sid.toCamelCase) + "=" + (if (f.requiredness.isOptional) "Some(" + v + ")" else v)
}
Expand Down
Expand Up @@ -110,6 +110,12 @@ trait StructTemplate { self: TemplateGenerator =>
fields.zipWithIndex map {
case (field, index) =>
val valueVariableID = field.sid.append("_item")
val fieldName = genID(field.sid)
val camelCaseFieldName = if (fieldName.toString.indexOf('_') >= 0)
genID(field.sid.toCamelCase)
else
NoValue

Dictionary(
"index" -> codify(index.toString),
"indexP1" -> codify((index + 1).toString),
Expand All @@ -120,8 +126,9 @@ trait StructTemplate { self: TemplateGenerator =>
"readBlobName" -> genID(field.sid.toTitleCase.prepend("read").append("Blob")),
"getName" -> genID(field.sid.toTitleCase.prepend("get")), // for Java only
"isSetName" -> genID(field.sid.toTitleCase.prepend("isSet")), // for Java only
"fieldName" -> genID(field.sid),
"fieldName" -> fieldName,
"fieldNameForWire" -> codify(field.originalName),
"fieldNameCamelCase" -> camelCaseFieldName,
"newFieldName" -> genID(field.sid.toTitleCase.prepend("new")),
"FieldName" -> genID(field.sid.toTitleCase),
"FIELD_NAME" -> genID(field.sid.toUpperCase),
Expand Down Expand Up @@ -249,28 +256,25 @@ trait StructTemplate { self: TemplateGenerator =>
}

private def exceptionMsgFieldName(struct: StructLike): Option[SimpleID] = {
val msgField: Option[Field] = struct.fields find {
field =>
val msgField: Option[Field] = struct.fields.find { field =>
// 1st choice: find a field called message
field.sid.name == "message"
} orElse {
field.sid.name == "message"
}.orElse {
// 2nd choice: the first string field
struct.fields find {
struct.fields.find {
field => field.fieldType == TString
}
}

msgField map {
_.sid
}
msgField.map { _.sid }
}

def structDict(
struct: StructLike,
namespace: Option[Identifier],
includes: Seq[Include],
serviceOptions: Set[ServiceOption]
) = {
): Dictionary = {
val isStruct = struct.isInstanceOf[Struct]
val isException = struct.isInstanceOf[Exception_]
val isUnion = struct.isInstanceOf[Union]
Expand All @@ -288,32 +292,34 @@ trait StructTemplate { self: TemplateGenerator =>
}
val arity = struct.fields.size
val product = if (arity >= 1 && arity <= 22) {
val fieldTypes = struct.fields.map {
f => genFieldType(f).toData
val fieldTypes = struct.fields.map { f =>
genFieldType(f).toData
}.mkString(", ")
"scala.Product" + arity + "[" + fieldTypes + "]"
s"scala.Product$arity[$fieldTypes]"
} else {
"scala.Product"
}

val exceptionMsgField: Option[SimpleID] = if (isException) exceptionMsgFieldName(struct) else None
val exceptionMsgField: Option[SimpleID] =
if (isException) exceptionMsgFieldName(struct) else None

val fieldDictionaries = fieldsToDict(
struct.fields,
if (isException) Seq("message") else Seq())
if (isException) Seq("message") else Nil)

val isPublic = namespace.isDefined
val structName = if (isPublic) genID(struct.sid.toTitleCase) else genID(struct.sid)

Dictionary(
"public" -> v(isPublic),
"package" -> namespace.map{ genID(_) }.getOrElse(codify("")),
"package" -> namespace.map(genID).getOrElse(codify("")),
"docstring" -> codify(struct.docstring.getOrElse("")),
"parentType" -> codify(parentType),
"fields" -> v(fieldDictionaries),
"defaultFields" -> v(fieldsToDict(struct.fields.filter(!_.requiredness.isOptional), Seq())),
"defaultFields" -> v(fieldsToDict(struct.fields.filter(!_.requiredness.isOptional), Nil)),
"alternativeConstructor" -> v(
struct.fields.exists(_.requiredness.isOptional) && struct.fields.exists(_.requiredness.isDefault)),
struct.fields.exists(_.requiredness.isOptional)
&& struct.fields.exists(_.requiredness.isDefault)),
"StructNameForWire" -> codify(struct.originalName),
"StructName" ->
// if isPublic, the struct comes from a Thrift definition. Otherwise
Expand All @@ -324,7 +330,7 @@ trait StructTemplate { self: TemplateGenerator =>
"arity" -> codify(arity.toString),
"isException" -> v(isException),
"hasExceptionMessage" -> v(exceptionMsgField.isDefined),
"exceptionMessageField" -> exceptionMsgField.map { genID(_) }.getOrElse { codify("")},
"exceptionMessageField" -> exceptionMsgField.map(genID).getOrElse { codify("")},
"product" -> codify(product),
"arity0" -> v(arity == 0),
"arity1" -> v((if (arity == 1) fieldDictionaries.take(1) else Nil)),
Expand All @@ -347,4 +353,3 @@ object StructTemplate {
Dictionary("pairs" -> v(pairDicts))
}
}

Expand Up @@ -62,6 +62,7 @@ object Dictionary {
* Wrap generated code fragments in the form of Strings in a dictionary value.
*/
def codify(code: String): CodeFragment = CodeFragment(code)

/**
* Wrap a boolean flag in a dictionary value.
*/
Expand All @@ -87,7 +88,7 @@ object Dictionary {
*/
def v(data: Handlebar): Value = PartialValue(data)

def apply(values: (String, Value)*) = new Dictionary ++= (values: _*)
def apply(values: (String, Value)*): Dictionary = new Dictionary ++= (values: _*)
}

case class Dictionary private(
Expand All @@ -109,28 +110,28 @@ case class Dictionary private(
}.getOrElse(NoValue)
}

def update(key: String, data: String) {
def update(key: String, data: String): Unit = {
map(key) = CodeFragment(data)
}

def update(key: String, data: Boolean) {
def update(key: String, data: Boolean): Unit = {
map(key) = BooleanValue(data)
}

def update(key: String, data: Seq[Dictionary]) {
def update(key: String, data: Seq[Dictionary]): Unit = {
map(key) = ListValue(data)
}

def update(key: String, data: Handlebar) {
def update(key: String, data: Handlebar): Unit = {
map(key) = PartialValue(data)
}

def ++=(values: (String, Value)*) = {
def ++=(values: (String, Value)*): Dictionary = {
map ++= values.toMap
this
}

def +(dict: Dictionary) = {
def +(dict: Dictionary): Dictionary = {
new Dictionary() ++= (this.map.toSeq: _*) ++= (dict.map.toSeq: _*)
}
}
Expand Up @@ -185,6 +185,15 @@ class ScalaGeneratorSpec extends JMockSpec with EvalHelper {
decodedNew must be(UnionPostEvolution.NewField(unionField))
}

"produce aliases for union fields with underscores" in { _ =>
val unionField = NewUnionField(
14653230,
SomeInnerUnionStruct(26, "a_a")
)
val union = UnionWithUnderscores.NewField(unionField)
assert(union.newField == union.new_field)
}

"be identified as an ENUM" in { _ =>
EnumStruct.NumberField.`type` must be(TType.ENUM)
}
Expand Down
5 changes: 5 additions & 0 deletions scrooge-generator/src/test/thrift/standalone/union.thrift
Expand Up @@ -23,6 +23,11 @@ union UnionPostEvolution {
2: NewUnionField newField
}

union UnionWithUnderscores {
1: OldUnionField old_field
2: NewUnionField new_field
}

struct MatchingStructField { 1: i64 id }
struct MatchingStructList { 1: i64 id }
struct MatchingStructSet { 1: i64 id }
Expand Down

0 comments on commit 12366f4

Please sign in to comment.