Skip to content

Commit

Permalink
airframe-grpc: Add GrpcServerFactory (#1430)
Browse files Browse the repository at this point in the history
* airframe-grpc: Add GrpcServerFactory
* Add xxxServerFactory.awaitTermination
  • Loading branch information
xerial committed Jan 15, 2021
1 parent 07df5a4 commit b1d1e43
Show file tree
Hide file tree
Showing 7 changed files with 403 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ case class MessageCodecFactory(codecFinder: MessageCodecFinder = Compat.messageC
}

object MessageCodecFactory {
val defaultFactory: MessageCodecFactory = new MessageCodecFactory()
def defaultFactoryForJSON: MessageCodecFactory = defaultFactory.withMapOutput
val defaultFactory: MessageCodecFactory = new MessageCodecFactory()
def defaultFactoryForJSON: MessageCodecFactory = defaultFactory.withMapOutput
def defaultFactoryForMapOutput: MessageCodecFactory = defaultFactoryForJSON

/**
* Create a custom MessageCodecFactory from a partial mapping
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
*/
package wvlet.airframe.http.finagle
import java.lang.reflect.InvocationTargetException

import com.twitter.finagle._
import com.twitter.finagle.http.{Request, Response, Status}
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.tracing.Tracer
import com.twitter.util.{Await, Future}

import javax.annotation.PostConstruct
import wvlet.airframe._
import wvlet.airframe.codec.MessageCodec
Expand All @@ -32,6 +32,7 @@ import wvlet.log.LogSupport
import wvlet.log.io.IOUtil

import scala.annotation.tailrec
import scala.collection.parallel.immutable.ParVector
import scala.concurrent.ExecutionException
import scala.util.control.NonFatal

Expand Down Expand Up @@ -328,6 +329,15 @@ trait FinagleServerFactory extends AutoCloseable with LogSupport {
)
}

/**
* Block until all servers created by this factory terminate
*/
def awaitTermination: Unit = {
val b = ParVector.newBuilder[FinagleServer]
b ++= createdServers
b.result().foreach(_.waitServerTermination)
}

override def close(): Unit = {
debug(s"Closing FinagleServerFactory")
val ex = Seq.newBuilder[Throwable]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
* limitations under the License.
*/
package wvlet.airframe.http.grpc
import java.util.concurrent.Executors

import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder
import io.grpc.{Channel, ManagedChannel, ManagedChannelBuilder, Server, ServerBuilder, ServerInterceptor}
import io.grpc._
import wvlet.airframe.codec.MessageCodecFactory
import wvlet.airframe.control.MultipleExceptions
import wvlet.airframe.http.Router
import wvlet.airframe.http.grpc.GrpcServiceBuilder.GrpcServiceThreadExecutor
import wvlet.airframe.{Design, Session}
import wvlet.log.LogSupport
import wvlet.log.io.IOUtil

import java.util.concurrent.{ExecutorService, Executors}
import scala.collection.parallel.immutable.ParVector
import scala.language.existentials
import scala.util.control.NonFatal

/**
*/
Expand All @@ -31,32 +33,49 @@ case class GrpcServerConfig(
private val serverPort: Option[Int] = None,
router: Router = Router.empty,
interceptors: Seq[ServerInterceptor] = Seq.empty,
serverInitializer: ServerBuilder[_] => ServerBuilder[_] = identity
serverInitializer: ServerBuilder[_] => ServerBuilder[_] = identity,
executorProvider: GrpcServerConfig => ExecutorService = { config: GrpcServerConfig =>
Executors.newCachedThreadPool()
},
codecFactory: MessageCodecFactory = MessageCodecFactory.defaultFactoryForMapOutput
) extends LogSupport {
lazy val port = serverPort.getOrElse(IOUtil.unusedPort)

def withName(name: String): GrpcServerConfig = this.copy(name = name)
def withPort(port: Int): GrpcServerConfig = this.copy(serverPort = Some(port))
def withRouter(router: Router): GrpcServerConfig = this.copy(router = router)

/**
* Use this method to customize gRPC server, e.g., setting tracer, add transport filter, etc.
* @param serverInitializer
* @return
*/
def withServerInitializer(serverInitializer: ServerBuilder[_] => ServerBuilder[_]) =
this.copy(serverInitializer = serverInitializer)

/**
* Add an gRPC interceptor
* @param interceptor
* @return
*/
def withInterceptor(interceptor: ServerInterceptor): GrpcServerConfig =
this.copy(interceptors = interceptors :+ interceptor)
def noInterceptor: GrpcServerConfig = this.copy(interceptors = Seq.empty)

/**
* Set a custom thread pool. The default is Executors.newCachedThreadPool()
*/
def withExecutorServiceProvider(provider: GrpcServerConfig => ExecutorService) =
this.copy(executorProvider = provider)

def withCodecFactory(newCodecFactory: MessageCodecFactory) = this.copy(codecFactory = newCodecFactory)

/**
* Create and start a new server based on this config.
*/
def newServer(session: Session): GrpcServer = {
val services = GrpcServiceBuilder.buildService(router, session)
trace(s"service:\n${services.map(_.getServiceDescriptor).mkString("\n")}")
// We need to use NettyServerBuilder explicitly when NettyServerBuilder cannot be found from the classpath (e.g., onejar)
val serverBuilder = NettyServerBuilder.forPort(port)
for (service <- services) {
serverBuilder.addService(service)
}
for (interceptor <- interceptors) {
serverBuilder.intercept(interceptor)
}
val customServerBuilder = serverInitializer(serverBuilder)
new GrpcServer(this, customServerBuilder.build())
val grpcService = GrpcServiceBuilder.buildService(this, session)
grpcService.newServer
}

/**
Expand All @@ -78,13 +97,10 @@ case class GrpcServerConfig(
Design.newDesign
.bind[GrpcServerConfig].toInstance(this)
.bind[GrpcServer].toProvider { (config: GrpcServerConfig, session: Session) => config.newServer(session) }
.onStart { _.start }
.bind[GrpcServiceThreadExecutor].toInstance(Executors.newCachedThreadPool())
.onShutdown(_.shutdownNow())
}

/**
* Create a design for GrpcServer and ManagedChannel. Useful for testing purpsoe
* Create a design for GrpcServer and ManagedChannel. Useful for testing purpose
* @return
*/
def designWithChannel: Design = {
Expand All @@ -99,12 +115,42 @@ case class GrpcServerConfig(
}
}

class GrpcServer(grpcServerConfig: GrpcServerConfig, server: Server) extends AutoCloseable with LogSupport {
def port: Int = grpcServerConfig.port
def localAddress: String = s"localhost:${grpcServerConfig.port}"
/**
* GrpcService is a holder of the thread executor and service definitions for running gRPC servers
*/
case class GrpcService(
config: GrpcServerConfig,
executorService: ExecutorService,
serviceDefinitions: Seq[ServerServiceDefinition]
) extends AutoCloseable
with LogSupport {
def newServer: GrpcServer = {
trace(s"service:\n${serviceDefinitions.map(_.getServiceDescriptor).mkString("\n")}")
// We need to use NettyServerBuilder explicitly when NettyServerBuilder cannot be found from the classpath (e.g., onejar)
val serverBuilder = NettyServerBuilder.forPort(config.port)
for (service <- serviceDefinitions) {
serverBuilder.addService(service)
}
for (interceptor <- config.interceptors) {
serverBuilder.intercept(interceptor)
}
val customServerBuilder = config.serverInitializer(serverBuilder)
val server = new GrpcServer(this, customServerBuilder.build())
server.start
server
}

override def close(): Unit = {
executorService.shutdownNow()
}
}

class GrpcServer(grpcService: GrpcService, server: Server) extends AutoCloseable with LogSupport {
def port: Int = grpcService.config.port
def localAddress: String = s"localhost:${port}"

def start: Unit = {
info(s"Starting gRPC server ${grpcServerConfig.name} at ${localAddress}")
info(s"Starting gRPC server: (${grpcService.config.name}) at ${localAddress}")
server.start()
}

Expand All @@ -113,7 +159,56 @@ class GrpcServer(grpcServerConfig: GrpcServerConfig, server: Server) extends Aut
}

override def close(): Unit = {
info(s"Closing gRPC server ${grpcServerConfig.name} at ${localAddress}")
info(s"Closing gRPC server (${grpcService.config.name}) at ${localAddress}")
server.shutdownNow()
grpcService.close()
}
}

/**
* GrpcServerFactory manages
* @param session
*/
class GrpcServerFactory(session: Session) extends AutoCloseable with LogSupport {
private var createdServers = List.empty[GrpcServer]

def newServer(config: GrpcServerConfig): GrpcServer = {
val server = config.newServer(session)
synchronized {
createdServers = server :: createdServers
}
server
}

def awaitTermination: Unit = {
// Workaround for `.par` in Scala 2.13, which requires import scala.collection.parallel.CollectionConverters._
// But this import doesn't work in Scala 2.12
val b = ParVector.newBuilder[GrpcServer]
b ++= createdServers
b.result().foreach(_.awaitTermination)
}

override def close(): Unit = {
debug(s"Closing GrpcServerFactory")
val ex = Seq.newBuilder[Throwable]
for (server <- createdServers) {
try {
server.close()
} catch {
case NonFatal(e) =>
ex += e
}
}
createdServers = List.empty

val exceptions = ex.result()
if (exceptions.nonEmpty) {
if (exceptions.size == 1) {
throw exceptions.head
} else {
throw MultipleExceptions(exceptions)
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,24 @@
* limitations under the License.
*/
package wvlet.airframe.http.grpc
import java.io.{ByteArrayInputStream, InputStream}
import java.util.concurrent.ExecutorService

import io.grpc.MethodDescriptor.Marshaller
import io.grpc.stub.ServerCalls
import io.grpc.{MethodDescriptor, ServerServiceDefinition}
import wvlet.airframe.Session
import wvlet.airframe.codec.{MessageCodec, MessageCodecFactory}
import wvlet.airframe.control.IO
import wvlet.airframe.http.Router
import wvlet.airframe.http.router.Route
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.airframe.rx.Rx
import wvlet.airframe.surface.{MethodParameter, MethodSurface, Surface}
import wvlet.log.LogSupport

import java.io.{ByteArrayInputStream, InputStream}
import java.util.concurrent.ExecutorService

/**
*/
object GrpcServiceBuilder {

type GrpcServiceThreadExecutor = ExecutorService

private implicit class RichMethod(val m: MethodSurface) extends AnyVal {

private def findClientStreamingArg: Option[MethodParameter] = {
Expand Down Expand Up @@ -86,24 +82,22 @@ object GrpcServiceBuilder {
}

def buildService(
router: Router,
session: Session,
codecFactory: MessageCodecFactory = MessageCodecFactory.defaultFactoryForJSON
): Seq[ServerServiceDefinition] = {
config: GrpcServerConfig,
session: Session
): GrpcService = {
val threadManager: ExecutorService = config.executorProvider(config)
val services =
for ((serviceName, routes) <- router.routes.groupBy(_.serviceName))
for ((serviceName, routes) <- config.router.routes.groupBy(_.serviceName))
yield {
val routeAndMethods = for (route <- routes) yield {
(route, buildMethodDescriptor(route, codecFactory))
(route, buildMethodDescriptor(route, config.codecFactory))
}

val serviceBuilder = ServerServiceDefinition.builder(serviceName)

for ((r, m) <- routeAndMethods) {
// TODO Support Client/Server Streams
val controller = session.getInstanceOf(r.controllerSurface)
val threadManager = session.build[GrpcServiceThreadExecutor]
val requestHandler = new RPCRequestHandler(controller, r.methodSurface, codecFactory, threadManager)
val requestHandler = new RPCRequestHandler(controller, r.methodSurface, config.codecFactory, threadManager)
val serverCall = r.methodSurface.grpcMethodType match {
case MethodDescriptor.MethodType.UNARY =>
ServerCalls.asyncUnaryCall(new RPCUnaryMethodHandler(requestHandler))
Expand All @@ -126,7 +120,7 @@ object GrpcServiceBuilder {
serviceDef
}

services.toSeq
GrpcService(config, threadManager, services.toSeq)
}

object RPCRequestMarshaller extends Marshaller[MsgPack] with LogSupport {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.grpc

import wvlet.airframe.http.Router
import wvlet.airframe.http.grpc.example.Greeter
import wvlet.airspec.AirSpec

object GrpcServerFactoryTest extends AirSpec {

private val r = Router.add[Greeter]

test("Build multiple gRPC servers") { f: GrpcServerFactory =>
val s1 = f.newServer(gRPC.server.withName("grpc1").withRouter(r))
val s2 = f.newServer(gRPC.server.withName("grpc2").withRouter(r))
}
}
Loading

0 comments on commit b1d1e43

Please sign in to comment.