Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2066,7 +2066,8 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
scalaTest.value,
logback
),
publishArtifact := false
publishArtifact := false,
Compile / run / fork := true
Comment thread
kciesielski marked this conversation as resolved.
)
.jvmPlatform(scalaVersions = examplesScalaVersions)
.dependsOn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl),
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ import scala.concurrent.duration._
* contains tapir's server processing logic.
*
* @param requestTimeout
* The maximum duration, to wait for a response before considering the request timed out.
* @throws ReadTimeoutException
* when no data is read from Netty within the specified period of time for a request.
* The maximum duration to wait for a response to be produced. If exceeded, the server will return a HTTP 503 response and close the
Comment thread
kciesielski marked this conversation as resolved.
* connection. This timeout is ignored in Web Sockets (after a handshake is established). Make sure it's lower than `idleTimeout`.
*
* @param connectionTimeout
* Specifies the maximum duration within which a connection between a client and a server must be established.
Expand All @@ -45,6 +44,10 @@ import scala.concurrent.duration._
* If set, attempts to wait for a given time for all in-flight requests to complete, before proceeding with shutting down the server. If
* `None`, closes the channels and terminates the server without waiting.
*
* @param idleTimeout
* Maximum inactivity time of a given connection. If nothing is sent or received within this timeout, the connection closes. Make sure
* it's greater than `requestTimeout`, which should be the first one to be triggered if it's taking too long to produce a response.
*
* @param serverHeader
* If set, send this value in the 'Server' response header. If None, don't set the header.
*/
Expand All @@ -64,6 +67,7 @@ case class NettyConfig(
socketConfig: NettySocketConfig,
initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit,
gracefulShutdownTimeout: Option[FiniteDuration],
idleTimeout: Option[FiniteDuration],
serverHeader: Option[String]
) {
def host(h: String): NettyConfig = copy(host = h)
Expand All @@ -80,7 +84,8 @@ case class NettyConfig(

def requestTimeout(r: FiniteDuration): NettyConfig = copy(requestTimeout = Some(r))
def connectionTimeout(c: FiniteDuration): NettyConfig = copy(connectionTimeout = Some(c))
def lingerTimeout(l: FiniteDuration): NettyConfig = copy(requestTimeout = Some(l))
def lingerTimeout(l: FiniteDuration): NettyConfig = copy(lingerTimeout = Some(l))
def idleTimeout(r: FiniteDuration): NettyConfig = copy(idleTimeout = Some(r))

def withSocketKeepAlive: NettyConfig = copy(socketKeepAlive = true)
def withNoSocketKeepAlive: NettyConfig = copy(socketKeepAlive = false)
Expand Down Expand Up @@ -119,6 +124,7 @@ object NettyConfig {
connectionTimeout = Some(10.seconds),
lingerTimeout = None, // see #3576
gracefulShutdownTimeout = Some(10.seconds),
idleTimeout = Some(60.seconds),
maxConnections = None,
addLoggingHandler = false,
sslContext = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl),
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@ package sttp.tapir.server.netty.internal

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.{Channel, ChannelFuture, ChannelHandler, ChannelInitializer, ChannelOption, EventLoopGroup}
import io.netty.handler.timeout.ReadTimeoutHandler
import sttp.tapir.server.netty.NettyConfig

import java.net.{InetSocketAddress, SocketAddress}

object NettyBootstrap {

private val ReadTimeoutHandlerName = "readTimeoutHandler"

def apply[F[_]](
nettyConfig: NettyConfig,
handler: => NettyServerHandler[F],
Expand All @@ -26,15 +23,8 @@ object NettyBootstrap {
.childHandler(new ChannelInitializer[Channel] {
override def initChannel(ch: Channel): Unit = {
val nettyConfigBuilder = nettyConfig.initPipeline(nettyConfig)

nettyConfig.requestTimeout match {
case Some(requestTimeout) =>
nettyConfigBuilder(
ch.pipeline().addLast(ReadTimeoutHandlerName, new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)),
handler,
)
case None => nettyConfigBuilder(ch.pipeline(), handler)
}
nettyConfigBuilder(ch.pipeline(), handler)
ch.pipeline().addLast(new UnhandledExceptionHandler)

connectionCounterOpt.map(counter => {
ch.pipeline().addFirst(counter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.netty.channel.group.ChannelGroup
import io.netty.handler.codec.http._
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory
import io.netty.handler.stream.{ChunkedFile, ChunkedStream}
import io.netty.handler.timeout.ReadTimeoutHandler
import io.netty.handler.timeout.{IdleState, IdleStateEvent, IdleStateHandler}
import org.playframework.netty.http.{DefaultStreamedHttpResponse, DefaultWebSocketHttpResponse, StreamedHttpRequest}
import org.reactivestreams.Publisher
import org.slf4j.LoggerFactory
Expand All @@ -21,8 +21,9 @@ import sttp.tapir.server.netty.NettyResponseContent.{
ReactiveWebSocketProcessorNettyResponseContent
}
import sttp.tapir.server.netty.internal.ws.{NettyControlFrameHandler, WebSocketAutoPingHandler}
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, NettyServerRequest, Route}

import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.collection.mutable.{Queue => MutableQueue}
Expand All @@ -40,8 +41,7 @@ class NettyServerHandler[F[_]](
unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]),
channelGroup: ChannelGroup,
isShuttingDown: AtomicBoolean,
serverHeader: Option[String],
isSsl: Boolean = false
config: NettyConfig
)(implicit
me: MonadError[F]
) extends SimpleChannelInboundHandler[HttpRequest] {
Expand Down Expand Up @@ -82,7 +82,9 @@ class NettyServerHandler[F[_]](
if (eventLoopContext == null) {
// Initialize our ExecutionContext
eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop)

config.idleTimeout.foreach { idleTimeout =>
ctx.pipeline().addFirst(new IdleStateHandler(0, 0, idleTimeout.toMillis.toInt, TimeUnit.MILLISECONDS))
Comment thread
kciesielski marked this conversation as resolved.
}
// When the channel closes we want to cancel any pending dispatches.
// Since the listener will be executed from the channels EventLoop everything is thread safe.
val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) =>
Expand All @@ -96,6 +98,29 @@ class NettyServerHandler[F[_]](
}
}

def writeError503ThenClose(ctx: ChannelHandlerContext): Unit = {
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE)
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
res.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
val _ = ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE)
}

override def userEventTriggered(ctx: ChannelHandlerContext, evt: Any): Unit = {
evt match {
case e: IdleStateEvent =>
if (e.state() == IdleState.WRITER_IDLE) {
logger.error(s"Closing connection due to exceeded response timeout of ${config.requestTimeout}")
writeError503ThenClose(ctx)
}
if (e.state() == IdleState.ALL_IDLE) {
logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout}")
val _ = ctx.close()
}
case other =>
super.userEventTriggered(ctx, evt)
}
}

override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = {

def writeError500(req: HttpRequest, reason: Throwable): Unit = {
Expand All @@ -107,14 +132,11 @@ class NettyServerHandler[F[_]](
ctx.writeAndFlush(res).closeIfNeeded(req)
}

def writeError503(req: HttpRequest): Unit = {
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE)
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
res.handleCloseAndKeepAliveHeaders(req)
ctx.writeAndFlush(res).closeIfNeeded(req)
}

def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = {
val requestTimeoutHandler = config.requestTimeout.map { requestTimeout =>
new IdleStateHandler(0, requestTimeout.toMillis.toInt, 0, TimeUnit.MILLISECONDS)
}
requestTimeoutHandler.foreach(h => ctx.pipeline().addFirst(h))
val (runningFuture, cancellationSwitch) = unsafeRunAsync { () =>
route(NettyServerRequest(req))
.map {
Expand All @@ -124,34 +146,38 @@ class NettyServerHandler[F[_]](
}
pendingResponses.enqueue(cancellationSwitch)
lastResponseSent = lastResponseSent.flatMap { _ =>
runningFuture.transform {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
runningFuture
.andThen { case _ =>
requestTimeoutHandler.foreach(ctx.pipeline().remove)
}(eventLoopContext)
.transform {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(NonFatal(ex)) =>
try {
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(NonFatal(ex)) =>
try {
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(fatalException) => Failure(fatalException)
}(eventLoopContext)
} finally {
val _ = releaseReq()
}
case Failure(fatalException) => Failure(fatalException)
}(eventLoopContext)
}(eventLoopContext)
}

if (isShuttingDown.get()) {
logger.info("Rejecting request, server is shutting down")
writeError503(request)
writeError503ThenClose(ctx)
} else if (HttpUtil.is100ContinueExpected(request)) {
ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE))
()
Expand Down Expand Up @@ -250,7 +276,6 @@ class NettyServerHandler[F[_]](
handshakeReq: HttpRequest
) = {
ctx.pipeline().remove(this)
ctx.pipeline().remove(classOf[ReadTimeoutHandler])
ctx
.pipeline()
.addAfter(
Expand Down Expand Up @@ -291,7 +316,7 @@ class NettyServerHandler[F[_]](

// Only ancient WS protocol versions will use this in the response header.
private def wsUrl(req: HttpRequest): String = {
val scheme = if (isSsl) "wss" else "ws"
val scheme = if (config.isSsl) "wss" else "ws"
s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}"
}
private implicit class RichServerNettyResponse(r: ServerResponse[NettyResponse]) {
Expand Down Expand Up @@ -323,7 +348,7 @@ class NettyServerHandler[F[_]](

private implicit class RichHttpMessage(m: HttpMessage) {
def setHeadersFrom(response: ServerResponse[_]): Unit = {
serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _))
config.serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _))
response.headers
.groupBy(_.name)
.foreach { case (k, v) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sttp.tapir.server.netty.internal

import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}
import io.netty.util.internal.logging.InternalLoggerFactory

private[internal] class UnhandledExceptionHandler extends ChannelInboundHandlerAdapter {
private lazy val logger = InternalLoggerFactory.getInstance(getClass)

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
cause match {
case ex =>
logger.warn("Unhandled exception", ex)
}
val _ = ctx.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ import scala.util.control.NonFatal
* options.
*/
private[sync] case class NettySyncServerEndpointListOverridenOptions(
ses: List[ServerEndpoint[OxStreams & WebSockets, Identity]],
overridenOptions: NettySyncServerOptions
ses: List[ServerEndpoint[OxStreams & WebSockets, Identity]],
overridenOptions: NettySyncServerOptions
)

case class NettySyncServer(
endpoints: List[ServerEndpoint[OxStreams & WebSockets, Identity]],
endpointsWithOptions: List[NettySyncServerEndpointListOverridenOptions],
options: NettySyncServerOptions,
config: NettyConfig
endpoints: List[ServerEndpoint[OxStreams & WebSockets, Identity]],
endpointsWithOptions: List[NettySyncServerEndpointListOverridenOptions],
options: NettySyncServerOptions,
config: NettyConfig
):
private val executor = Executors.newVirtualThreadPerTaskExecutor()

Expand Down Expand Up @@ -134,8 +134,7 @@ case class NettySyncServer(
unsafeRunF,
channelGroup,
isShuttingDown,
config.serverHeader,
config.isSsl
config
),
eventLoopGroup,
socketOverride
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options:
unsafeRunAsync(runtime),
channelGroup,
isShuttingDown,
config.serverHeader,
config.isSsl
config
),
eventLoopGroup,
socketOverride
Expand Down