Skip to content

Commit

Permalink
Implement ASCII format parser and generator in Scala
Browse files Browse the repository at this point in the history
  • Loading branch information
thesamet committed Jan 1, 2016
1 parent 91eb7c3 commit 44af26e
Show file tree
Hide file tree
Showing 19 changed files with 1,141 additions and 61 deletions.
5 changes: 4 additions & 1 deletion build.sbt
Expand Up @@ -56,9 +56,12 @@ lazy val runtime = crossProject.crossType(CrossType.Full).in(file("scalapb-runti
name := "scalapb-runtime",
libraryDependencies ++= Seq(
"com.trueaccord.lenses" %%% "lenses" % "0.4.4",
"com.lihaoyi" %%% "fastparse" % "0.3.4",
"com.lihaoyi" %%% "utest" % "0.3.1" % "test",
"org.scalacheck" %% "scalacheck" % "1.12.5" % "test",
"org.scalatest" %% "scalatest" % (if (scalaVersion.value.startsWith("2.12")) "2.2.5-M2" else "2.2.5") % "test"
),
testFrameworks += new TestFramework("utest.runner.Framework"),
unmanagedResourceDirectories in Compile += baseDirectory.value / "../../protobuf"
)
.jvmSettings(
Expand All @@ -70,7 +73,7 @@ lazy val runtime = crossProject.crossType(CrossType.Full).in(file("scalapb-runti
.jsSettings(
// Add JS-specific settings here
libraryDependencies ++= Seq(
"com.trueaccord.scalapb" %%% "protobuf-runtime-scala" % "0.1.4"
"com.trueaccord.scalapb" %%% "protobuf-runtime-scala" % "0.1.5-SNAPSHOT"
),
unmanagedResourceDirectories in Compile += baseDirectory.value / "../../third_party"
)
Expand Down
Expand Up @@ -276,6 +276,9 @@ trait DescriptorPimps {
def isTopLevel = enum.getContainingType == null

def javaTypeName = enum.getFile.fullJavaName(enum.getFullName)

def valuesWithNoDuplicates = enum.getValues.groupBy(_.getNumber)
.mapValues(_.head).values.toVector.sortBy(_.getNumber)
}

implicit class EnumValueDescriptorPimp(val enumValue: EnumValueDescriptor) {
Expand Down
Expand Up @@ -4,6 +4,8 @@ import com.google.protobuf.Descriptors._
import com.google.protobuf.CodedOutputStream
import com.google.protobuf.{ByteString => GoogleByteString}
import com.google.protobuf.compiler.PluginProtos.{CodeGeneratorRequest, CodeGeneratorResponse}
import com.trueaccord.scalapb.TextFormat
import com.trueaccord.scalapb.textformat.TextFormatUtils
import scala.collection.JavaConversions._

case class GeneratorParams(javaConversions: Boolean = false, flatPackage: Boolean = false, grpc: Boolean = false)
Expand Down Expand Up @@ -47,7 +49,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
|""")
.add(s"lazy val values = Seq(${e.getValues.map(_.getName.asSymbol).mkString(", ")})")
.add(s"def fromValue(value: Int): $name = value match {")
.print(e.getValues) {
.print(e.valuesWithNoDuplicates) {
case (v, p) => p.add(s" case ${v.getNumber} => ${v.getName.asSymbol}")
}
.add(s" case __other => Unrecognized(__other)")
Expand Down Expand Up @@ -104,14 +106,18 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.add("}")
}

def escapeString(raw: String): String = {
import scala.reflect.runtime.universe._
Literal(Constant(raw)).toString
}

def byteArrayAsBase64Literal(buffer: Array[Byte]): String = {
"\"\"\"" + new sun.misc.BASE64Encoder().encode(buffer) + "\"\"\""
}
def escapeString(raw: String): String = raw.map {
case u if u >= ' ' && u <= '~' => u.toString
case '\b' => "\\b"
case '\f' => "\\f"
case '\n' => "\\n"
case '\r' => "\\r"
case '\t' => "\\t"
case '\\' => "\\\\"
case '\"' => "\\\""
case '\'' => "\\\'"
case c: Char => "\\u%4s".format(c.toInt.toHexString).replace(' ','0')
}.mkString("\"", "", "\"")

def defaultValueForGet(field: FieldDescriptor, uncustomized: Boolean = false) = {
// Needs to be 'def' and not val since for some of the cases it's invalid to call it.
Expand Down Expand Up @@ -253,8 +259,26 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.add("__field.getNumber match {")
.indent
.print(message.fields) {
case (f, fp) => val e = toBaseFieldType(f).apply(fieldAccessorSymbol(f), isCollection = !f.isSingular)
fp.add(s"case ${f.getNumber} => $e")
case (f, fp) =>
val e = toBaseFieldType(f)
.apply(fieldAccessorSymbol(f), isCollection = !f.isSingular)
if (f.supportsPresence || f.isInOneof)
fp.add(s"case ${f.getNumber} => $e.getOrElse(null)")
else if (f.isOptional) {
// In proto3, drop default value
fp.add(s"case ${f.getNumber} => {")
.indent
.add(s"val __t = $e")
.add({
val cond = if (!f.isEnum)
s"__t != ${defaultValueForGet(f, uncustomized = true)}"
else
s"__t.getNumber() != 0"
s"if ($cond) __t else null"
})
.outdent
.add("}")
} else fp.add(s"case ${f.getNumber} => $e")
}
.outdent
.add("}")
Expand Down Expand Up @@ -347,15 +371,15 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.add("lazy val serializedSize: Int = {")
.indent
.add("var __size = 0")
.print(message.getFields)(generateSerializedSizeForField)
.print(message.fields)(generateSerializedSizeForField)
.add("__size")
.outdent
.add("}")
}

def generateSerializedSizeForPackedFields(message: Descriptor)(fp: FunctionalPrinter) =
fp
.print(message.getFields.filter(_.isPacked).zipWithIndex) {
.print(message.fields.filter(_.isPacked).zipWithIndex) {
case ((field, index), printer) =>
printer
.add(s"lazy val ${field.scalaName}SerializedSize =")
Expand Down Expand Up @@ -383,7 +407,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
def generateWriteTo(message: Descriptor)(fp: FunctionalPrinter) =
fp.add(s"def writeTo(output: com.google.protobuf.CodedOutputStream): Unit = {")
.indent
.print(message.getFields.sortBy(_.getNumber).zipWithIndex) {
.print(message.fields.sortBy(_.getNumber).zipWithIndex) {
case ((field, index), printer) =>
val fieldNameSymbol = fieldAccessorSymbol(field)
val capTypeName = Types.capitalizedType(field.getType)
Expand Down Expand Up @@ -432,7 +456,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.add("}")

def printConstructorFieldList(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
val regularFields = message.getFields.collect {
val regularFields = message.fields.collect {
case field if !field.isInOneof =>
val typeName = field.scalaTypeName
val ctorDefaultValue =
Expand Down Expand Up @@ -472,7 +496,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
| val _tag__ = __input.readTag()
| _tag__ match {
| case 0 => _done__ = true""")
.print(message.getFields) {
.print(message.fields) {
(field, printer) =>
if (!field.isPacked) {
val newValBase = if (field.isMessage) {
Expand Down Expand Up @@ -547,7 +571,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
printer.add(s"def toJavaProto(scalaPbSource: $myFullScalaName): ${message.javaTypeName} = {")
.indent
.add(s"val javaPbOut = ${message.javaTypeName}.newBuilder")
.print(message.getFields) {
.print(message.fields) {
case (field, printer) =>
printer.add(assignScalaFieldToJava("scalaPbSource", "javaPbOut", field))
}
Expand All @@ -562,7 +586,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.indent
.call {
printer =>
val normal = message.getFields.collect {
val normal = message.fields.collect {
case field if !field.isInOneof =>
val conversion = if (field.isMap) javaMapFieldToScala("javaPbSource", field)
else javaFieldToScala("javaPbSource", field)
Expand All @@ -587,8 +611,9 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {

def generateFromFieldsMap(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
def transform(field: FieldDescriptor) =
(if (!field.isEnum) Identity else (MethodApplication("getNumber") andThen
FunctionApplication(field.getEnumType.scalaTypeName + ".fromValue"))) andThen
(if (!field.isEnum) Identity else (
MethodApplication("getNumber") andThen
FunctionApplication(field.getEnumType.scalaTypeName + ".fromValue"))) andThen
toCustomTypeExpr(field)

val myFullScalaName = message.scalaTypeName
Expand All @@ -600,15 +625,21 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.indent
.call {
printer =>
val fields = message.getFields.collect {
val fields = message.fields.collect {
case field if !field.isInOneof =>
val baseTypeName = field.typeCategory(if (field.isEnum) "com.google.protobuf.Descriptors.EnumValueDescriptor" else field.baseSingleScalaTypeName)
val e = if (field.isOptional)
s"__fieldsMap.getOrElse(__fields.get(${field.getIndex}), None).asInstanceOf[$baseTypeName]"
val e = if (field.supportsPresence)
s"__fieldsMap.get(__fields.get(${field.getIndex})).asInstanceOf[$baseTypeName]"
else if (field.isRepeated)
s"__fieldsMap.getOrElse(__fields.get(${field.getIndex}), Nil).asInstanceOf[$baseTypeName]"
else
else if (field.isRequired)
s"__fieldsMap(__fields.get(${field.getIndex})).asInstanceOf[$baseTypeName]"
else {
// This is for proto3, no default value.
val t = defaultValueForGet(field, uncustomized = true) + (if (field.isEnum)
".valueDescriptor" else "")
s"__fieldsMap.getOrElse(__fields.get(${field.getIndex}), $t).asInstanceOf[$baseTypeName]"
}

val s = transform(field).apply(e, isCollection = !field.isSingular)
if (field.isMap) s + "(scala.collection.breakOut)"
Expand All @@ -619,7 +650,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
val elems = oneOf.fields.map {
field =>
val typeName = if (field.isEnum) "com.google.protobuf.Descriptors.EnumValueDescriptor" else field.baseSingleScalaTypeName
val e = s"__fieldsMap.getOrElse(__fields.get(${field.getIndex}), None).asInstanceOf[Option[$typeName]]"
val e = s"__fieldsMap.get(__fields.get(${field.getIndex})).asInstanceOf[Option[$typeName]]"
(transform(field) andThen FunctionApplication(field.oneOfTypeName)).apply(e, isCollection = true)
} mkString (" orElse\n")
s"${oneOf.scalaName.asSymbol} = $elems getOrElse ${oneOf.empty}"
Expand All @@ -632,15 +663,6 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.add("}")
}

def generateFromAscii(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
printer.addM(
s"""override def fromAscii(ascii: String): ${message.scalaTypeName} = {
| val javaProtoBuilder = ${message.javaTypeName}.newBuilder
| com.google.protobuf.TextFormat.merge(ascii, javaProtoBuilder)
| fromJavaProto(javaProtoBuilder.build)
|}""")
}

def generateDescriptor(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
printer
.when(message.isTopLevel)(_.add(s"def descriptor: com.google.protobuf.Descriptors.Descriptor = ${message.getFile.fileDescriptorObjectName}.descriptor.getMessageTypes.get(${message.getIndex})"))
Expand All @@ -652,7 +674,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
printer
.add(s"lazy val defaultInstance = $myFullScalaName(")
.indent
.addWithDelimiter(",")(message.getFields.collect {
.addWithDelimiter(",")(message.fields.collect {
case field if field.isRequired =>
val default = defaultValueForDefaultInstance(field)
s"${field.scalaName.asSymbol} = $default"
Expand All @@ -669,7 +691,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
printer.add(
s"implicit class ${className}Lens[UpperPB](_l: com.trueaccord.lenses.Lens[UpperPB, $classNameSymbol]) extends com.trueaccord.lenses.ObjectLens[UpperPB, $classNameSymbol](_l) {")
.indent
.print(message.getFields) {
.print(message.fields) {
case (field, printer) =>
val fieldName = field.scalaName.asSymbol
if (!field.isInOneof) {
Expand Down Expand Up @@ -699,15 +721,15 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {

def generateFieldNumbers(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
printer
.print(message.getFields) {
.print(message.fields) {
case (field, printer) =>
printer.add(s"final val ${field.fieldNumberConstantName} = ${field.getNumber}")
}
}

def generateTypeMappers(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
val customizedFields: Seq[(FieldDescriptor, String)] = for {
field <- message.getFields
field <- message.fields
custom <- field.customSingleScalaTypeName
} yield (field, custom)

Expand Down Expand Up @@ -789,7 +811,6 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.indent
.when(message.javaConversions)(generateToJavaProto(message))
.when(message.javaConversions)(generateFromJavaProto(message))
.when(message.javaConversions)(generateFromAscii(message))
.call(generateFromFieldsMap(message))
.call(generateDescriptor(message))
.call(generateMessageCompanionForField(message))
Expand Down Expand Up @@ -820,7 +841,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.call(generateSerializedSize(message))
.call(generateWriteTo(message))
.call(generateMergeFrom(message))
.print(message.getFields) {
.print(message.fields) {
case (field, printer) =>
val withMethod = "with" + field.upperScalaName
val clearMethod = "clear" + field.upperScalaName
Expand Down Expand Up @@ -857,8 +878,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
|def with${oneof.upperScalaName}(__v: ${oneof.scalaTypeName}): ${message.nameSymbol} = copy(${oneof.scalaName.asSymbol} = __v)""")
}
.call(generateGetField(message))
.when(message.javaConversions)(
_.add(s"override def toString: String = com.google.protobuf.TextFormat.printToUnicodeString(${message.scalaTypeName}.toJavaProto(this))"))
.add(s"override def toString: String = com.trueaccord.scalapb.TextFormat.printToUnicodeString(this)")
.add(s"def companion = ${message.scalaTypeName}")
.outdent
.outdent
Expand Down Expand Up @@ -890,15 +910,6 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.print(message.nestedTypes)(generateInternalFields)
}

def generateInternalFieldsFor(file: FileDescriptor)(fp: FunctionalPrinter): FunctionalPrinter =
if (file.getMessageTypes.nonEmpty) {
fp.add("def internalFieldsFor(scalaName: String): Seq[Descriptors.FieldDescriptor] = scalaName match {")
.indent
.print(file.getMessageTypes)(generateInternalFields)
.outdent
.add("}")
} else fp

def scalaFileHeader(file: FileDescriptor): FunctionalPrinter = {
new FunctionalPrinter().addM(
s"""// Generated by the Scala Plugin for the Protocol Buffer Compiler.
Expand Down Expand Up @@ -941,6 +952,33 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps {
.add("}")
}

private def encodeByteArray(a: GoogleByteString): Seq[String] = {
val CH_SLASH: java.lang.Byte = '\\'.toByte
val CH_SQ: java.lang.Byte = '\''.toByte
val CH_DQ: java.lang.Byte = '\"'.toByte
for {
groups <- a.grouped(60).toSeq
} yield {
val sb = scala.collection.mutable.StringBuilder.newBuilder
sb.append('\"')
groups.foreach {
b =>
b match {
case CH_SLASH => sb.append("\\\\")
case CH_SQ => sb.append("\\\'")
case CH_DQ => sb.append("\\\"")
case b if b >= 0x20 => sb.append(b)
case b =>
sb.append("\\u00")
sb.append(Integer.toHexString((b >>> 4) & 0xf))
sb.append(Integer.toHexString(b & 0xf))
}
}
sb.append('\"')
sb.result()
}
}

def generateScalaFilesForFileDescriptor(file: FileDescriptor): Seq[CodeGeneratorResponse.File] = {
val serviceFiles = if(params.grpc) {
file.getServices.map { service =>
Expand Down

0 comments on commit 44af26e

Please sign in to comment.