Skip to content

Commit

Permalink
Added the context to grpc-web
Browse files Browse the repository at this point in the history
  • Loading branch information
mineme0110 committed Apr 30, 2020
1 parent bfe46f3 commit bb5c9b9
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,26 @@ final class GrpcServiceMetadataPrinter(
val overrideStr = if (overrideSig) "override " else ""
method.streamType match {
case StreamType.Unary =>
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, metadata: $metadata): scala.concurrent.Future[${method.outputType.scalaType}]"
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context = defaultContext): scala.concurrent.Future[${method.outputType.scalaType}]"
case StreamType.ServerStreaming =>
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, metadata: $metadata, responseObserver: ${observer(method.outputType.scalaType)}): Unit"
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context = defaultContext, responseObserver: ${observer(
method.outputType.scalaType)}): Unit"
case _ =>
""
}
}

private[this] def serviceMethodTraitSignature(
method: MethodDescriptor,
overrideSig: Boolean
) = {
val overrideStr = if (overrideSig) "override " else ""
method.streamType match {
case StreamType.Unary =>
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context): scala.concurrent.Future[${method.outputType.scalaType}]"
case StreamType.ServerStreaming =>
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context, responseObserver: ${observer(
method.outputType.scalaType)}): Unit"
case _ =>
""
}
Expand All @@ -38,21 +55,21 @@ final class GrpcServiceMetadataPrinter(
val overrideStr = if (overrideSig) "override " else ""
s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + (method.streamType match {
case StreamType.Unary =>
s"(request: ${method.inputType.scalaType}, metadata: $metadata): ${method.outputType.scalaType}"
s"(request: ${method.inputType.scalaType}, context: $context = defaultContext): ${method.outputType.scalaType}"
case StreamType.ServerStreaming =>
s"(request: ${method.inputType.scalaType}, metadata: $metadata): scala.collection.Iterator[${method.outputType.scalaType}]"
s"(request: ${method.inputType.scalaType}, context: $context = defaultContext): scala.collection.Iterator[${method.outputType.scalaType}]"
case _ => throw new IllegalArgumentException("Invalid method type.")
})
}

private[this] def serviceTrait: PrinterEndo = { p =>
p.call(generateScalaDoc(service))
.add(s"trait ${service.name} {")
.add(s"trait ${service.name}[-Context] {")
.indent
.print(service.methods) {
case (p, method) =>
p.call(generateScalaDoc(method))
.add(serviceMethodSignature(method, overrideSig = false))
.add(serviceMethodTraitSignature(method, overrideSig = false))
}
.outdent
.add("}")
Expand All @@ -68,6 +85,10 @@ final class GrpcServiceMetadataPrinter(

private[this] val metadata = "_root_.scalapb.grpc.grpcweb.Metadata.Metadata"

private[this] val metadataPackage = "_root_.scalapb.grpc.grpcweb.Metadata"

private[this] val context = "Context"

private[this] def methodDescriptor(method: MethodDescriptor) = PrinterEndo {
p =>
def marshaller(t: MethodDescriptorPimp#MethodTypeWrapper) =
Expand Down Expand Up @@ -118,7 +139,7 @@ final class GrpcServiceMetadataPrinter(
"channel",
m.grpcDescriptor.nameSymbol,
"options",
"metadata"
"f(context)"
) ++
(if (m.isClientStreaming) Seq() else Seq("request")) ++
(if ((m.isClientStreaming || m.isServerStreaming) && !blocking)
Expand All @@ -144,14 +165,11 @@ final class GrpcServiceMetadataPrinter(
baseClass: String,
methods: Seq[PrinterEndo]
): PrinterEndo = { p =>
val build =
s"override def build(channel: $channel, options: $callOptions): ${className} = new $className(channel, options)"
p.add(
s"class $className(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$className](channel, options) with $baseClass {"
s"class $className[$context](channel: $channel, f: $context => $metadata, defaultContext: => $context, options: $callOptions = $callOptions.DEFAULT) extends $baseClass[$context] {"
)
.indent
.call(methods: _*)
.add(build)
.outdent
.add("}")
}
Expand All @@ -162,6 +180,9 @@ final class GrpcServiceMetadataPrinter(
stubImplementation(service.stub, service.name, methods)
}

private[this] val identityMetadata =
s"implicit def identity(metadata: $metadata): $metadata => $metadata = metadata => metadata"

def printService(printer: FunctionalPrinter): FunctionalPrinter = {
printer
.add(
Expand All @@ -176,9 +197,16 @@ final class GrpcServiceMetadataPrinter(
.newline
.call(stub)
.newline
.add(
identityMetadata
)
.newline
.add(
s"def stub(channel: $channel): ${service.stub}[$metadata] = new ${service.stub}[$metadata](channel, $metadataPackage.empty(), $metadataPackage.empty())"
)
.newline
.add(
s"def stub(channel: $channel): ${service.stub} = new ${service.stub}(channel)"
s"def stub(channel: $channel, metadata: $metadata): ${service.stub}[$metadata] = new ${service.stub}[$metadata](channel, metadata, $metadataPackage.empty())"
)
.newline
.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@ import com.google.protobuf.Descriptors.FileDescriptor
import com.google.protobuf.ExtensionRegistry
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse
import protocbridge.codegen.{CodeGenApp, CodeGenRequest, CodeGenResponse}
import scalapb.compiler.{
DescriptorImplicits,
FunctionalPrinter,
GeneratorException,
GeneratorParams,
NameUtils,
ProtoValidation,
ProtobufGenerator
}
import scalapb.compiler._
import scalapb.grpc_web.compat.JavaConverters._
import scalapb.options.compiler.Scalapb

case class GrpcWebCodeGenerator(metadata: Boolean = false) extends CodeGenApp {
object GrpcWebCodeGenerator extends CodeGenApp {
override def registerExtensions(registry: ExtensionRegistry): Unit =
Scalapb.registerAllExtensions(registry)

Expand All @@ -28,11 +20,7 @@ case class GrpcWebCodeGenerator(metadata: Boolean = false) extends CodeGenApp {
new DescriptorImplicits(params, request.allProtos)
validate(request, implicits)
val generatedFiles = request.filesToGenerate.flatMap { file =>
if (metadata) {
generateWithMetadata(params, file, implicits)
} else {
generate(params, file, implicits)
}
generateWithMetadata(params, file, implicits)
}
CodeGenResponse.succeed(
generatedFiles
Expand All @@ -55,15 +43,6 @@ case class GrpcWebCodeGenerator(metadata: Boolean = false) extends CodeGenApp {
validator.validateFiles(request.allProtos)
}

private def generate(
params: GeneratorParams,
file: FileDescriptor,
implicits: DescriptorImplicits
): Seq[CodeGeneratorResponse.File] = {
val generator = new ProtobufGenerator(params.copy(grpc = true), implicits)
generator.generateMultipleScalaFilesForFileDescriptor(file)
}

private def generateWithMetadata(
params: GeneratorParams,
file: FileDescriptor,
Expand Down
5 changes: 2 additions & 3 deletions example2/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ resolvers in ThisBuild ++= Seq(
// js settings with attribute metadata true in GrpcWebCodeGenerator
lazy val jsSettings: Seq[Setting[_]] = Seq(
PB.targets in Compile := Seq(
scalapb.grpc_web
.GrpcWebCodeGenerator(true) -> (sourceManaged in Compile).value
scalapb.grpc_web.GrpcWebCodeGenerator -> (sourceManaged in Compile).value
)
)

Expand All @@ -37,7 +36,7 @@ lazy val protos =
)
.jsSettings(
// publish locally and update the version for test
libraryDependencies += "com.thesamet.scalapb" %%% "scalapb-grpcweb" % "0.2.0+15-1ed77e30+20200427-2143-SNAPSHOT"
libraryDependencies += "com.thesamet.scalapb" %%% "scalapb-grpcweb" % "0.2.0+18-d8ba44da+20200430-2220-SNAPSHOT"
)

lazy val protosJS = protos.js.settings(jsSettings: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ object Client {

val metadata: Metadata = Metadata(header1)
// Make an async unary call
stub.unary(req).onComplete { f =>
println("Unary", f)
}

// Make an async unary call with metadata
stub.unary(req, metadata).onComplete { f =>
println("Unary", f)
}
Expand Down
2 changes: 1 addition & 1 deletion example2/project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ addSbtPlugin("ch.epfl.scala" % "sbt-scalajs-bundler" % "0.17.0")

libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.2"
// publish locally and update the version for test
libraryDependencies += "com.thesamet.scalapb" %% "scalapb-grpcweb-code-gen" % "0.2.0+15-1ed77e30+20200427-2143-SNAPSHOT"
libraryDependencies += "com.thesamet.scalapb" %% "scalapb-grpcweb-code-gen" % "0.2.0+18-d8ba44da+20200430-2220-SNAPSHOT"

0 comments on commit bb5c9b9

Please sign in to comment.