diff --git a/code-gen/src/main/scala/scalapb/grpc_web/GrpcServiceMetadataPrinter.scala b/code-gen/src/main/scala/scalapb/grpc_web/GrpcServiceMetadataPrinter.scala index 6885be6..d1ffdcc 100644 --- a/code-gen/src/main/scala/scalapb/grpc_web/GrpcServiceMetadataPrinter.scala +++ b/code-gen/src/main/scala/scalapb/grpc_web/GrpcServiceMetadataPrinter.scala @@ -23,9 +23,26 @@ final class GrpcServiceMetadataPrinter( val overrideStr = if (overrideSig) "override " else "" method.streamType match { case StreamType.Unary => - s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, metadata: $metadata): scala.concurrent.Future[${method.outputType.scalaType}]" + s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context = defaultContext): scala.concurrent.Future[${method.outputType.scalaType}]" case StreamType.ServerStreaming => - s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, metadata: $metadata, responseObserver: ${observer(method.outputType.scalaType)}): Unit" + s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context = defaultContext, responseObserver: ${observer( + method.outputType.scalaType)}): Unit" + case _ => + "" + } + } + + private[this] def serviceMethodTraitSignature( + method: MethodDescriptor, + overrideSig: Boolean + ) = { + val overrideStr = if (overrideSig) "override " else "" + method.streamType match { + case StreamType.Unary => + s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context): scala.concurrent.Future[${method.outputType.scalaType}]" + case StreamType.ServerStreaming => + s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + s"(request: ${method.inputType.scalaType}, context: $context, responseObserver: ${observer( + method.outputType.scalaType)}): Unit" case _ => "" } @@ -38,21 +55,21 @@ final class GrpcServiceMetadataPrinter( val overrideStr = if (overrideSig) "override " else "" s"${method.deprecatedAnnotation}${overrideStr}def ${method.name}" + (method.streamType match { case StreamType.Unary => - s"(request: ${method.inputType.scalaType}, metadata: $metadata): ${method.outputType.scalaType}" + s"(request: ${method.inputType.scalaType}, context: $context = defaultContext): ${method.outputType.scalaType}" case StreamType.ServerStreaming => - s"(request: ${method.inputType.scalaType}, metadata: $metadata): scala.collection.Iterator[${method.outputType.scalaType}]" + s"(request: ${method.inputType.scalaType}, context: $context = defaultContext): scala.collection.Iterator[${method.outputType.scalaType}]" case _ => throw new IllegalArgumentException("Invalid method type.") }) } private[this] def serviceTrait: PrinterEndo = { p => p.call(generateScalaDoc(service)) - .add(s"trait ${service.name} {") + .add(s"trait ${service.name}[-Context] {") .indent .print(service.methods) { case (p, method) => p.call(generateScalaDoc(method)) - .add(serviceMethodSignature(method, overrideSig = false)) + .add(serviceMethodTraitSignature(method, overrideSig = false)) } .outdent .add("}") @@ -68,6 +85,10 @@ final class GrpcServiceMetadataPrinter( private[this] val metadata = "_root_.scalapb.grpc.grpcweb.Metadata.Metadata" + private[this] val metadataPackage = "_root_.scalapb.grpc.grpcweb.Metadata" + + private[this] val context = "Context" + private[this] def methodDescriptor(method: MethodDescriptor) = PrinterEndo { p => def marshaller(t: MethodDescriptorPimp#MethodTypeWrapper) = @@ -118,7 +139,7 @@ final class GrpcServiceMetadataPrinter( "channel", m.grpcDescriptor.nameSymbol, "options", - "metadata" + "f(context)" ) ++ (if (m.isClientStreaming) Seq() else Seq("request")) ++ (if ((m.isClientStreaming || m.isServerStreaming) && !blocking) @@ -144,14 +165,11 @@ final class GrpcServiceMetadataPrinter( baseClass: String, methods: Seq[PrinterEndo] ): PrinterEndo = { p => - val build = - s"override def build(channel: $channel, options: $callOptions): ${className} = new $className(channel, options)" p.add( - s"class $className(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$className](channel, options) with $baseClass {" + s"class $className[$context](channel: $channel, f: $context => $metadata, defaultContext: => $context, options: $callOptions = $callOptions.DEFAULT) extends $baseClass[$context] {" ) .indent .call(methods: _*) - .add(build) .outdent .add("}") } @@ -162,6 +180,9 @@ final class GrpcServiceMetadataPrinter( stubImplementation(service.stub, service.name, methods) } + private[this] val identityMetadata = + s"implicit def identity(metadata: $metadata): $metadata => $metadata = metadata => metadata" + def printService(printer: FunctionalPrinter): FunctionalPrinter = { printer .add( @@ -176,9 +197,16 @@ final class GrpcServiceMetadataPrinter( .newline .call(stub) .newline + .add( + identityMetadata + ) + .newline + .add( + s"def stub(channel: $channel): ${service.stub}[$metadata] = new ${service.stub}[$metadata](channel, $metadataPackage.empty(), $metadataPackage.empty())" + ) .newline .add( - s"def stub(channel: $channel): ${service.stub} = new ${service.stub}(channel)" + s"def stub(channel: $channel, metadata: $metadata): ${service.stub}[$metadata] = new ${service.stub}[$metadata](channel, metadata, $metadataPackage.empty())" ) .newline .add( diff --git a/code-gen/src/main/scala/scalapb/grpc_web/GrpcWebCodeGenerator.scala b/code-gen/src/main/scala/scalapb/grpc_web/GrpcWebCodeGenerator.scala index 719adcb..3a6e391 100644 --- a/code-gen/src/main/scala/scalapb/grpc_web/GrpcWebCodeGenerator.scala +++ b/code-gen/src/main/scala/scalapb/grpc_web/GrpcWebCodeGenerator.scala @@ -4,19 +4,11 @@ import com.google.protobuf.Descriptors.FileDescriptor import com.google.protobuf.ExtensionRegistry import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse import protocbridge.codegen.{CodeGenApp, CodeGenRequest, CodeGenResponse} -import scalapb.compiler.{ - DescriptorImplicits, - FunctionalPrinter, - GeneratorException, - GeneratorParams, - NameUtils, - ProtoValidation, - ProtobufGenerator -} +import scalapb.compiler._ import scalapb.grpc_web.compat.JavaConverters._ import scalapb.options.compiler.Scalapb -case class GrpcWebCodeGenerator(metadata: Boolean = false) extends CodeGenApp { +object GrpcWebCodeGenerator extends CodeGenApp { override def registerExtensions(registry: ExtensionRegistry): Unit = Scalapb.registerAllExtensions(registry) @@ -28,11 +20,7 @@ case class GrpcWebCodeGenerator(metadata: Boolean = false) extends CodeGenApp { new DescriptorImplicits(params, request.allProtos) validate(request, implicits) val generatedFiles = request.filesToGenerate.flatMap { file => - if (metadata) { - generateWithMetadata(params, file, implicits) - } else { - generate(params, file, implicits) - } + generateWithMetadata(params, file, implicits) } CodeGenResponse.succeed( generatedFiles @@ -55,15 +43,6 @@ case class GrpcWebCodeGenerator(metadata: Boolean = false) extends CodeGenApp { validator.validateFiles(request.allProtos) } - private def generate( - params: GeneratorParams, - file: FileDescriptor, - implicits: DescriptorImplicits - ): Seq[CodeGeneratorResponse.File] = { - val generator = new ProtobufGenerator(params.copy(grpc = true), implicits) - generator.generateMultipleScalaFilesForFileDescriptor(file) - } - private def generateWithMetadata( params: GeneratorParams, file: FileDescriptor, diff --git a/example2/build.sbt b/example2/build.sbt index ea39bc0..f1e4b75 100644 --- a/example2/build.sbt +++ b/example2/build.sbt @@ -14,8 +14,7 @@ resolvers in ThisBuild ++= Seq( // js settings with attribute metadata true in GrpcWebCodeGenerator lazy val jsSettings: Seq[Setting[_]] = Seq( PB.targets in Compile := Seq( - scalapb.grpc_web - .GrpcWebCodeGenerator(true) -> (sourceManaged in Compile).value + scalapb.grpc_web.GrpcWebCodeGenerator -> (sourceManaged in Compile).value ) ) @@ -37,7 +36,7 @@ lazy val protos = ) .jsSettings( // publish locally and update the version for test - libraryDependencies += "com.thesamet.scalapb" %%% "scalapb-grpcweb" % "0.2.0+15-1ed77e30+20200427-2143-SNAPSHOT" + libraryDependencies += "com.thesamet.scalapb" %%% "scalapb-grpcweb" % "0.2.0+18-d8ba44da+20200430-2220-SNAPSHOT" ) lazy val protosJS = protos.js.settings(jsSettings: _*) diff --git a/example2/client/src/main/scala/scalapb/grpc/example/Client.scala b/example2/client/src/main/scala/scalapb/grpc/example/Client.scala index 06474e3..915361f 100644 --- a/example2/client/src/main/scala/scalapb/grpc/example/Client.scala +++ b/example2/client/src/main/scala/scalapb/grpc/example/Client.scala @@ -26,6 +26,11 @@ object Client { val metadata: Metadata = Metadata(header1) // Make an async unary call + stub.unary(req).onComplete { f => + println("Unary", f) + } + + // Make an async unary call with metadata stub.unary(req, metadata).onComplete { f => println("Unary", f) } diff --git a/example2/project/plugins.sbt b/example2/project/plugins.sbt index 22781fc..e0d23e5 100644 --- a/example2/project/plugins.sbt +++ b/example2/project/plugins.sbt @@ -8,4 +8,4 @@ addSbtPlugin("ch.epfl.scala" % "sbt-scalajs-bundler" % "0.17.0") libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.2" // publish locally and update the version for test -libraryDependencies += "com.thesamet.scalapb" %% "scalapb-grpcweb-code-gen" % "0.2.0+15-1ed77e30+20200427-2143-SNAPSHOT" +libraryDependencies += "com.thesamet.scalapb" %% "scalapb-grpcweb-code-gen" % "0.2.0+18-d8ba44da+20200430-2220-SNAPSHOT"