Skip to content

Commit

Permalink
Merge pull request #2548 from softwaremill/add-suppoer-for-grpc-colle…
Browse files Browse the repository at this point in the history
…ctions

Add support for grpc repeated  and oneof fields
  • Loading branch information
adamw committed Nov 13, 2022
2 parents 5c09dfd + 006b7c4 commit ce51bcb
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 41 deletions.
10 changes: 9 additions & 1 deletion doc/grpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ Definition of an endpoint's inputs and output format is very similar to the one
endpoint body constructor helper `sttp.tapir.grpc.protobuf.pbdirect.grpcBody[T]` that based on the type `T` create a
body definition that can be passed to as input or output of a given endpoint
e.g. `endpoint.in(grpcBody[AddSimpleBook]).out(grpcBody[SimpleBook])`.
Mapping for basic types are defined (e.g. `java.lang.String` -> `string`), but the target protobuf type can be simply customized via `.attribute` schema feature
(e.g. `implicit newSchema = implicitly[Derived[Schema[SimpleBook]]].value.modify(_.title)(_.attribute(ProtobufAttributes.ScalarValueAttribute, ProtobufScalarType.ProtobufBytes))`

Currently, the only supported protocol is protobuf. On the server side, we use
a [PBDirect](https://github.com/47degrees/pbdirect) library for encoding and decoding messages. It derives codecs from a
Expand Down Expand Up @@ -66,4 +68,10 @@ you can find a simple example.

It's worth mentioning that by adjusting slightly encoders/decoders it's possible to expose gRPC endpoints with
`AkkaHttpServerInterpreter` as simple HTTP endpoints. This approach is not recommended, because it does not support
transmitting multiple messages in a single http request.
transmitting multiple messages in a single http request.

## Supported data formats
* Basic scalar types
* Collections (repeated values)
* Top level and nested products
* Tapir schema derivation for coproducts (sealed traits) is supported, but we're missing codecs on the pbdirect side out of the box (oneof)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sttp.tapir.grpc.protobuf

import sttp.tapir.Schema.SName
import sttp.tapir.SchemaType.{SDate, SDateTime, SInteger, SNumber, SProduct, SProductField, SString}
import sttp.tapir.SchemaType.{SArray, SCoproduct, SDate, SDateTime, SInteger, SNumber, SProduct, SProductField, SString}
import sttp.tapir.{Schema, _}
import sttp.tapir.grpc.protobuf.model._

Expand Down Expand Up @@ -75,45 +75,68 @@ class EndpointToProtobufMessage {
// TODO files support?
fromProductField(msgs)(field)
}
List(ProtobufMessage(name.fullName.split('.').last, protoFields)) // FIXME
List(ProtobufProductMessage(toMessageName(name), protoFields))
case SCoproduct(subtypes, discriminator) =>

List(
ProtobufCoproductMessage(
toMessageName(name),
subtypes.map(fromCoproductSubtype(msgs))
)
)
case _ => ???
}
}.toList
}
private def toMessageName(sName: SName): MessageName = sName.fullName.split('.').last // FIXME

private def availableMessagesFromSchema(schema: Schema[_]): Map[SName, Schema[_]] = schema.schemaType match {
case SProduct(fields) =>
schema.name.map(name => Map(name -> schema)).getOrElse(Map.empty) ++
fields.foldLeft(Map.empty[SName, Schema[_]])((m, field) => m ++ availableMessagesFromSchema(field.schema))
case SchemaType.SCoproduct(subtypes, discriminator) => ???
case _ => Map.empty
case SchemaType.SCoproduct(subtypes, discriminator) =>
schema.name.map(name => Map(name -> schema)).getOrElse(Map.empty) ++
subtypes.foldLeft(Map.empty[SName, Schema[_]])((m, subtype) => m ++ availableMessagesFromSchema(subtype))
case SchemaType.SArray(element) => availableMessagesFromSchema(element)
case _ => Map.empty
}

private def defaultScalarMappings(field: SProductField[_]): ProtobufScalarType = field.schema.schemaType match {
case SString() => ProtobufScalarType.ProtobufString
case SInteger() if field.schema.format.contains("int64") => ProtobufScalarType.ProtobufInt64
case SInteger() => ProtobufScalarType.ProtobufInt32
case SNumber() if field.schema.format.contains("float") => ProtobufScalarType.ProtobufFloat
case SNumber() => ProtobufScalarType.ProtobufDouble
case SchemaType.SBoolean() => ProtobufScalarType.ProtobufBool
case SProduct(Nil) => ProtobufScalarType.ProtobufEmpty
case SchemaType.SBinary() => ProtobufScalarType.ProtobufBytes
case SDateTime() => ProtobufScalarType.ProtobufInt64
case SDate() => ProtobufScalarType.ProtobufInt64
case in =>
println(s"Not supported input [$in]") // FIXME
???
private def fromCoproductSubtype(availableMessages: Map[SName, Schema[_]])(subtype: Schema[_]) = {
val `type` = resolveType(availableMessages)(subtype)

ProtobufMessageField(`type`, `type`.filedTypeName.toLowerCase, None)
}
private def fromProductField(availableMessages: Map[SName, Schema[_]])(field: SProductField[_]): ProtobufMessageField =
ProtobufMessageField(resolveType(availableMessages)(field.schema), field.name.name, None)

private def fromProductField(availableMessages: Map[SName, Schema[_]])(field: SProductField[_]): ProtobufMessageField = {
val maybeCustomType = field.schema.attribute(ProtobufAttributes.ScalarValueAttribute)
val maybeMessageRef = field.schema.name match {
private def resolveType(availableMessages: Map[SName, Schema[_]])(schema: Schema[_]): ProtobufType = {
lazy val maybeCustomType = schema.attribute(ProtobufAttributes.ScalarValueAttribute)
lazy val maybeMessageRef = schema.name match {
case Some(name) if availableMessages.contains(name) => Some(ProtobufMessageRef(name))
case _ => None
}
val `type` = maybeCustomType.orElse(maybeMessageRef).getOrElse(defaultScalarMappings(field))
lazy val defaultMappings: ProtobufType = schema.schemaType match {
case SString() => ProtobufScalarType.ProtobufString
case SInteger() if schema.format.contains("int64") => ProtobufScalarType.ProtobufInt64
case SInteger() => ProtobufScalarType.ProtobufInt32
case SNumber() if schema.format.contains("float") => ProtobufScalarType.ProtobufFloat
case SNumber() => ProtobufScalarType.ProtobufDouble
case SchemaType.SBoolean() => ProtobufScalarType.ProtobufBool
case SProduct(Nil) => ProtobufScalarType.ProtobufEmpty
case SchemaType.SBinary() => ProtobufScalarType.ProtobufBytes
case SDateTime() => ProtobufScalarType.ProtobufInt64
case SDate() => ProtobufScalarType.ProtobufInt64
case SArray(element) =>
resolveType(availableMessages)(element) match {
case valueType: SingularValueType => ProtobufRepeatedField(valueType)
case ProtobufRepeatedField(_) => throw new IllegalArgumentException("Nested collections are not supported.")
}
case in =>
println(s"Not supported input [$in]") // FIXME
???
}

ProtobufMessageField(`type`, field.name.name, None)
maybeCustomType.orElse(maybeMessageRef).getOrElse(defaultMappings)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,27 @@ class ProtoRenderer {
|rpc ${method.name} (${method.input}) returns (${method.output}) {}
""".stripMargin

private def renderMessage(msg: ProtobufMessage): String = {
private def renderMessage(msg: ProtobufMessage): String = msg match {
case m: ProtobufCoproductMessage => renderCoproductMessage(m)
case m: ProtobufProductMessage => renderProductMessage(m)
}

private def renderCoproductMessage(msg: ProtobufCoproductMessage): String =
s"""
|message ${msg.name} {
| oneof alternatives {
| ${renderMessageFields(msg.alternatives.toVector)}
| }
|}
|""".stripMargin

private def renderProductMessage(msg: ProtobufProductMessage): String = {
s"""
|message ${msg.name} {
|${renderMessageFields(msg.fields.toVector)}

|}
""".stripMargin
""".stripMargin
}

private def renderOptions(options: ProtobufOptions): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ sealed trait ProtobufType {
def filedTypeName: String
}

case class ProtobufMessageRef(refName: SName) extends ProtobufType {
override def filedTypeName: String = refName.show.split('.').last // FIXME we need to a better way for generating messages names
sealed trait SingularValueType extends ProtobufType
case class ProtobufMessageRef(refName: SName) extends SingularValueType {
override def filedTypeName: String = refName.fullName.split('.').last // FIXME we need to a better way for generating messages names
}
case class ProtobufRepeatedField(element: SingularValueType) extends ProtobufType {
override def filedTypeName: String = s"repeated ${element.filedTypeName}"
}
sealed trait ProtobufScalarType extends ProtobufType

sealed trait ProtobufScalarType extends SingularValueType
object ProtobufScalarType {
case object ProtobufString extends ProtobufScalarType {
override val filedTypeName: String = "string"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
package sttp.tapir.grpc.protobuf.model

case class ProtobufMessage(name: MessageName, fields: Iterable[ProtobufMessageField])
import sttp.tapir.grpc.protobuf.ProtobufMessageRef

sealed trait ProtobufMessage {
def name: MessageName
}

case class ProtobufProductMessage(name: MessageName, fields: Iterable[ProtobufMessageField]) extends ProtobufMessage
case class ProtobufCoproductMessage(name: MessageName, alternatives: Iterable[ProtobufMessageField]) extends ProtobufMessage
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sttp.tapir.grpc.protobuf

import org.scalatest.flatspec.AnyFlatSpec
import sttp.tapir.Schema.SName
import sttp.tapir.grpc.protobuf.model._

class ProtoRendererTest extends AnyFlatSpec with ProtobufMatchers {
Expand All @@ -23,7 +24,8 @@ class ProtoRendererTest extends AnyFlatSpec with ProtobufMatchers {
|""".stripMargin

val proto = Protobuf(
messages = List(ProtobufMessage("SimpleBook", List(ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))))),
messages =
List(ProtobufProductMessage("SimpleBook", List(ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))))),
services = List(ProtobufService("Library", List(ProtobufServiceMethod("AddBook", "SimpleBook", "SimpleBook")))),
options = ProtobufOptions.empty
)
Expand All @@ -49,7 +51,8 @@ class ProtoRendererTest extends AnyFlatSpec with ProtobufMatchers {
|""".stripMargin

val proto = Protobuf(
messages = List(ProtobufMessage("SimpleBook", List(ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))))),
messages =
List(ProtobufProductMessage("SimpleBook", List(ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))))),
services = List(ProtobufService("Library", List(ProtobufServiceMethod("AddBook", "SimpleBook", "SimpleBook")))),
options = ProtobufOptions(Some("com.myexample"))
)
Expand Down Expand Up @@ -80,13 +83,13 @@ class ProtoRendererTest extends AnyFlatSpec with ProtobufMatchers {

val proto = Protobuf(
messages = List(
ProtobufMessage(
ProtobufProductMessage(
"Title",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))
)
),
ProtobufMessage(
ProtobufProductMessage(
"SimpleBook",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1)),
Expand All @@ -101,4 +104,130 @@ class ProtoRendererTest extends AnyFlatSpec with ProtobufMatchers {
matchProtos(renderer.render(proto), expectedProto)
}

it should "render proto file for a message with repeated values" in {
val expectedProto =
"""
|syntax = "proto3";
|
|option java_multiple_files = true;
|
|service Library {
| rpc AddBook (Title) returns (SimpleBook) {}
|}
|
|message Title {
| string title = 1;
|}
|
|message SimpleBook {
| string title = 1;
| string content = 2;
| repeated string authors = 3;
|}
|""".stripMargin

val proto = Protobuf(
messages = List(
ProtobufProductMessage(
"Title",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))
)
),
ProtobufProductMessage(
"SimpleBook",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1)),
ProtobufMessageField(ProtobufScalarType.ProtobufString, "content", Some(2)),
ProtobufMessageField(ProtobufRepeatedField(ProtobufScalarType.ProtobufString), "authors", Some(3))
)
)
),
services = List(ProtobufService("Library", List(ProtobufServiceMethod("AddBook", "Title", "SimpleBook")))),
options = ProtobufOptions.empty
)

matchProtos(renderer.render(proto), expectedProto)
}

it should "render proto file for a message with a coproduct field" in {
val expectedProto =
"""
|syntax = "proto3";
|
|option java_multiple_files = true;
|
|service Library {
| rpc AddBook (Title) returns (SimpleBook) {}
|}
|
|message Title {
| string title = 1;
|}
|
|message Epub {
| string source = 1;
|}
|
|message Paper {
| string size = 1;
|}
|message Format {
| oneof alternatives {
| Epub epub = 1;
| Paper paper = 2;
| }
|}
|
|message SimpleBook {
| string title = 1;
| string content = 2;
| Format format = 3;
|
|}
|""".stripMargin

val proto = Protobuf(
messages = List(
ProtobufProductMessage(
"Title",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1))
)
),
ProtobufProductMessage(
"Epub",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "source", Some(1))
)
),
ProtobufProductMessage(
"Paper",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "size", Some(1))
)
),
ProtobufCoproductMessage(
"Format",
List(
ProtobufMessageField(ProtobufMessageRef(SName("Epub")), "epub", None),
ProtobufMessageField(ProtobufMessageRef(SName("Paper")), "paper", None),
)
),
ProtobufProductMessage(
"SimpleBook",
List(
ProtobufMessageField(ProtobufScalarType.ProtobufString, "title", Some(1)),
ProtobufMessageField(ProtobufScalarType.ProtobufString, "content", Some(2)),
ProtobufMessageField(ProtobufMessageRef(SName("Format")), "format", Some(3))
)
)
),
services = List(ProtobufService("Library", List(ProtobufServiceMethod("AddBook", "Title", "SimpleBook")))),
options = ProtobufOptions.empty
)

println(renderer.render(proto))
matchProtos(renderer.render(proto), expectedProto)
}
}

0 comments on commit ce51bcb

Please sign in to comment.