From 4495bf73e5cfe877506fed93e770e38ab0e47f66 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 3 Jun 2024 10:00:44 +0200 Subject: [PATCH 1/2] Use IdleStateHandlers --- build.sbt | 3 +- .../server/netty/cats/NettyCatsServer.scala | 2 +- .../sttp/tapir/server/netty/NettyConfig.scala | 14 ++- .../server/netty/NettyFutureServer.scala | 2 +- .../netty/internal/NettyBootstrap.scala | 14 +-- .../netty/internal/NettyServerHandler.scala | 97 ++++++++++++------- .../internal/UnhandledExceptionHandler.scala | 16 +++ .../server/netty/sync/NettySyncServer.scala | 15 ++- .../server/netty/zio/NettyZioServer.scala | 3 +- 9 files changed, 101 insertions(+), 65 deletions(-) create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/UnhandledExceptionHandler.scala diff --git a/build.sbt b/build.sbt index e12b109649..9141de9878 100644 --- a/build.sbt +++ b/build.sbt @@ -2066,7 +2066,8 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples")) scalaTest.value, logback ), - publishArtifact := false + publishArtifact := false, + Compile / run / fork := true ) .jvmPlatform(scalaVersions = examplesScalaVersions) .dependsOn( diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 8ebe94f74c..417e0957ad 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -78,7 +78,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index d098a6874c..0f607af361 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -30,9 +30,8 @@ import scala.concurrent.duration._ * contains tapir's server processing logic. * * @param requestTimeout - * The maximum duration, to wait for a response before considering the request timed out. - * @throws ReadTimeoutException - * when no data is read from Netty within the specified period of time for a request. + * The maximum duration to wait for a response to be produced. If exceeded, the server will return a HTTP 503 response and close the + * channel. This timeout is ignored in Web Sockets (after a handshake is established). Make sure it's lower than `idleTimeout`. * * @param connectionTimeout * Specifies the maximum duration within which a connection between a client and a server must be established. @@ -45,6 +44,10 @@ import scala.concurrent.duration._ * If set, attempts to wait for a given time for all in-flight requests to complete, before proceeding with shutting down the server. If * `None`, closes the channels and terminates the server without waiting. * + * @param idleTimeout + * Maximum inactivity time of a given connection. If nothing is sent or received within this timeout, the connection closes. Make sure + * it's greater than `requestTimeout`, which should be the first one to be triggered if it's taking too long to produce a response. + * * @param serverHeader * If set, send this value in the 'Server' response header. If None, don't set the header. */ @@ -64,6 +67,7 @@ case class NettyConfig( socketConfig: NettySocketConfig, initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit, gracefulShutdownTimeout: Option[FiniteDuration], + idleTimeout: Option[FiniteDuration], serverHeader: Option[String] ) { def host(h: String): NettyConfig = copy(host = h) @@ -80,7 +84,8 @@ case class NettyConfig( def requestTimeout(r: FiniteDuration): NettyConfig = copy(requestTimeout = Some(r)) def connectionTimeout(c: FiniteDuration): NettyConfig = copy(connectionTimeout = Some(c)) - def lingerTimeout(l: FiniteDuration): NettyConfig = copy(requestTimeout = Some(l)) + def lingerTimeout(l: FiniteDuration): NettyConfig = copy(lingerTimeout = Some(l)) + def idleTimeout(r: FiniteDuration): NettyConfig = copy(idleTimeout = Some(r)) def withSocketKeepAlive: NettyConfig = copy(socketKeepAlive = true) def withNoSocketKeepAlive: NettyConfig = copy(socketKeepAlive = false) @@ -119,6 +124,7 @@ object NettyConfig { connectionTimeout = Some(10.seconds), lingerTimeout = None, // see #3576 gracefulShutdownTimeout = Some(10.seconds), + idleTimeout = Some(60.seconds), maxConnections = None, addLoggingHandler = false, sslContext = None, diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 0d4285f4e5..5162e29eb8 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -71,7 +71,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 07edb5ace2..143efc3bc8 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -2,15 +2,12 @@ package sttp.tapir.server.netty.internal import io.netty.bootstrap.ServerBootstrap import io.netty.channel.{Channel, ChannelFuture, ChannelHandler, ChannelInitializer, ChannelOption, EventLoopGroup} -import io.netty.handler.timeout.ReadTimeoutHandler import sttp.tapir.server.netty.NettyConfig import java.net.{InetSocketAddress, SocketAddress} object NettyBootstrap { - private val ReadTimeoutHandlerName = "readTimeoutHandler" - def apply[F[_]]( nettyConfig: NettyConfig, handler: => NettyServerHandler[F], @@ -26,15 +23,8 @@ object NettyBootstrap { .childHandler(new ChannelInitializer[Channel] { override def initChannel(ch: Channel): Unit = { val nettyConfigBuilder = nettyConfig.initPipeline(nettyConfig) - - nettyConfig.requestTimeout match { - case Some(requestTimeout) => - nettyConfigBuilder( - ch.pipeline().addLast(ReadTimeoutHandlerName, new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), - handler, - ) - case None => nettyConfigBuilder(ch.pipeline(), handler) - } + nettyConfigBuilder(ch.pipeline(), handler) + ch.pipeline().addLast(new UnhandledExceptionHandler) connectionCounterOpt.map(counter => { ch.pipeline().addFirst(counter) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 8ce0004b9a..5d070c7995 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -6,7 +6,7 @@ import io.netty.channel.group.ChannelGroup import io.netty.handler.codec.http._ import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import io.netty.handler.timeout.ReadTimeoutHandler +import io.netty.handler.timeout.{IdleState, IdleStateEvent, IdleStateHandler} import org.playframework.netty.http.{DefaultStreamedHttpResponse, DefaultWebSocketHttpResponse, StreamedHttpRequest} import org.reactivestreams.Publisher import org.slf4j.LoggerFactory @@ -21,8 +21,9 @@ import sttp.tapir.server.netty.NettyResponseContent.{ ReactiveWebSocketProcessorNettyResponseContent } import sttp.tapir.server.netty.internal.ws.{NettyControlFrameHandler, WebSocketAutoPingHandler} -import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import sttp.tapir.server.netty.{NettyConfig, NettyResponse, NettyServerRequest, Route} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.{Queue => MutableQueue} @@ -40,8 +41,7 @@ class NettyServerHandler[F[_]]( unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), channelGroup: ChannelGroup, isShuttingDown: AtomicBoolean, - serverHeader: Option[String], - isSsl: Boolean = false + config: NettyConfig )(implicit me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { @@ -82,7 +82,9 @@ class NettyServerHandler[F[_]]( if (eventLoopContext == null) { // Initialize our ExecutionContext eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop) - + config.idleTimeout.foreach { idleTimeout => + ctx.pipeline().addFirst(new IdleStateHandler(0, 0, idleTimeout.toMillis.toInt, TimeUnit.MILLISECONDS)) + } // When the channel closes we want to cancel any pending dispatches. // Since the listener will be executed from the channels EventLoop everything is thread safe. val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) => @@ -96,6 +98,29 @@ class NettyServerHandler[F[_]]( } } + def writeError503ThenClose(ctx: ChannelHandlerContext): Unit = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + val _ = ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE) + } + + override def userEventTriggered(ctx: ChannelHandlerContext, evt: Any): Unit = { + evt match { + case e: IdleStateEvent => + if (e.state() == IdleState.WRITER_IDLE) { + logger.error(s"Closing connection due to exceeded response timeout of ${config.requestTimeout}") + writeError503ThenClose(ctx) + } + if (e.state() == IdleState.ALL_IDLE) { + logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout}") + val _ = ctx.close() + } + case other => + super.userEventTriggered(ctx, evt) + } + } + override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { def writeError500(req: HttpRequest, reason: Throwable): Unit = { @@ -107,14 +132,11 @@ class NettyServerHandler[F[_]]( ctx.writeAndFlush(res).closeIfNeeded(req) } - def writeError503(req: HttpRequest): Unit = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - ctx.writeAndFlush(res).closeIfNeeded(req) - } - def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = { + val idleHandler = config.requestTimeout.map { requestTimeout => + new IdleStateHandler(0, requestTimeout.toMillis.toInt, 0, TimeUnit.MILLISECONDS) + } + idleHandler.foreach(h => ctx.pipeline().addFirst(h)) val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => route(NettyServerRequest(req)) .map { @@ -124,34 +146,38 @@ class NettyServerHandler[F[_]]( } pendingResponses.enqueue(cancellationSwitch) lastResponseSent = lastResponseSent.flatMap { _ => - runningFuture.transform { - case Success(serverResponse) => - pendingResponses.dequeue() - try { - handleResponse(ctx, req, serverResponse) - Success(()) - } catch { - case NonFatal(ex) => + runningFuture + .andThen { case _ => + idleHandler.foreach(ctx.pipeline().remove) + }(eventLoopContext) + .transform { + case Success(serverResponse) => + pendingResponses.dequeue() + try { + handleResponse(ctx, req, serverResponse) + Success(()) + } catch { + case NonFatal(ex) => + writeError500(req, ex) + Failure(ex) + } finally { + val _ = releaseReq() + } + case Failure(NonFatal(ex)) => + try { writeError500(req, ex) Failure(ex) - } finally { - val _ = releaseReq() - } - case Failure(NonFatal(ex)) => - try { - writeError500(req, ex) - Failure(ex) - } finally { - val _ = releaseReq() - } - case Failure(fatalException) => Failure(fatalException) - }(eventLoopContext) + } finally { + val _ = releaseReq() + } + case Failure(fatalException) => Failure(fatalException) + }(eventLoopContext) }(eventLoopContext) } if (isShuttingDown.get()) { logger.info("Rejecting request, server is shutting down") - writeError503(request) + writeError503ThenClose(ctx) } else if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () @@ -250,7 +276,6 @@ class NettyServerHandler[F[_]]( handshakeReq: HttpRequest ) = { ctx.pipeline().remove(this) - ctx.pipeline().remove(classOf[ReadTimeoutHandler]) ctx .pipeline() .addAfter( @@ -291,7 +316,7 @@ class NettyServerHandler[F[_]]( // Only ancient WS protocol versions will use this in the response header. private def wsUrl(req: HttpRequest): String = { - val scheme = if (isSsl) "wss" else "ws" + val scheme = if (config.isSsl) "wss" else "ws" s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}" } private implicit class RichServerNettyResponse(r: ServerResponse[NettyResponse]) { @@ -323,7 +348,7 @@ class NettyServerHandler[F[_]]( private implicit class RichHttpMessage(m: HttpMessage) { def setHeadersFrom(response: ServerResponse[_]): Unit = { - serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) + config.serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) response.headers .groupBy(_.name) .foreach { case (k, v) => diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/UnhandledExceptionHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/UnhandledExceptionHandler.scala new file mode 100644 index 0000000000..2c57673c32 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/UnhandledExceptionHandler.scala @@ -0,0 +1,16 @@ +package sttp.tapir.server.netty.internal + +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.util.internal.logging.InternalLoggerFactory + +private[internal] class UnhandledExceptionHandler extends ChannelInboundHandlerAdapter { + private lazy val logger = InternalLoggerFactory.getInstance(getClass) + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + cause match { + case ex => + logger.warn("Unhandled exception", ex) + } + val _ = ctx.close() + } +} diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServer.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServer.scala index 2e8e11e693..70ae9f9f41 100644 --- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServer.scala +++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/NettySyncServer.scala @@ -27,15 +27,15 @@ import scala.util.control.NonFatal * options. */ private[sync] case class NettySyncServerEndpointListOverridenOptions( - ses: List[ServerEndpoint[OxStreams & WebSockets, Identity]], - overridenOptions: NettySyncServerOptions + ses: List[ServerEndpoint[OxStreams & WebSockets, Identity]], + overridenOptions: NettySyncServerOptions ) case class NettySyncServer( - endpoints: List[ServerEndpoint[OxStreams & WebSockets, Identity]], - endpointsWithOptions: List[NettySyncServerEndpointListOverridenOptions], - options: NettySyncServerOptions, - config: NettyConfig + endpoints: List[ServerEndpoint[OxStreams & WebSockets, Identity]], + endpointsWithOptions: List[NettySyncServerEndpointListOverridenOptions], + options: NettySyncServerOptions, + config: NettyConfig ): private val executor = Executors.newVirtualThreadPerTaskExecutor() @@ -134,8 +134,7 @@ case class NettySyncServer( unsafeRunF, channelGroup, isShuttingDown, - config.serverHeader, - config.isSsl + config ), eventLoopGroup, socketOverride diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 390899e21a..12d1973d76 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -91,8 +91,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: unsafeRunAsync(runtime), channelGroup, isShuttingDown, - config.serverHeader, - config.isSsl + config ), eventLoopGroup, socketOverride From 686e83fc42634ccbbdf136590d5b810545df5ce6 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 5 Jun 2024 09:06:54 +0200 Subject: [PATCH 2/2] Review fixes --- .../main/scala/sttp/tapir/server/netty/NettyConfig.scala | 2 +- .../tapir/server/netty/internal/NettyServerHandler.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 0f607af361..dc799c362a 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -31,7 +31,7 @@ import scala.concurrent.duration._ * * @param requestTimeout * The maximum duration to wait for a response to be produced. If exceeded, the server will return a HTTP 503 response and close the - * channel. This timeout is ignored in Web Sockets (after a handshake is established). Make sure it's lower than `idleTimeout`. + * connection. This timeout is ignored in Web Sockets (after a handshake is established). Make sure it's lower than `idleTimeout`. * * @param connectionTimeout * Specifies the maximum duration within which a connection between a client and a server must be established. diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 5d070c7995..a0d49f546b 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -133,10 +133,10 @@ class NettyServerHandler[F[_]]( } def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = { - val idleHandler = config.requestTimeout.map { requestTimeout => + val requestTimeoutHandler = config.requestTimeout.map { requestTimeout => new IdleStateHandler(0, requestTimeout.toMillis.toInt, 0, TimeUnit.MILLISECONDS) } - idleHandler.foreach(h => ctx.pipeline().addFirst(h)) + requestTimeoutHandler.foreach(h => ctx.pipeline().addFirst(h)) val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => route(NettyServerRequest(req)) .map { @@ -148,7 +148,7 @@ class NettyServerHandler[F[_]]( lastResponseSent = lastResponseSent.flatMap { _ => runningFuture .andThen { case _ => - idleHandler.foreach(ctx.pipeline().remove) + requestTimeoutHandler.foreach(ctx.pipeline().remove) }(eventLoopContext) .transform { case Success(serverResponse) =>