Lazy InputStream over Netty HttpContent Publisher (#3637)
ghik committed Mar 27, 2024
1 parent a2eca04 commit da9a76f
Showing 10 changed files with 297 additions and 20 deletions.
@@ -21,6 +21,8 @@ trait MonadErrorSyntax {
      })

      override def ensure[T](f: G[T], e: => G[Unit]): G[T] = fk(mef.ensure(gK(f), gK(e)))
+
+      override def blocking[T](t: => T): G[T] = fk(mef.blocking(t))
    }
  }
}
@@ -13,4 +13,5 @@ class RIOMonadError[R] extends MonadError[RIO[R, *]] {
  override def suspend[T](t: => RIO[R, T]): RIO[R, T] = ZIO.suspend(t)
  override def flatten[T](ffa: RIO[R, RIO[R, T]]): RIO[R, T] = ffa.flatten
  override def ensure[T](f: RIO[R, T], e: => RIO[R, Unit]): RIO[R, T] = f.ensuring(e.ignore)
+ override def blocking[T](t: => T): RIO[R, T] = ZIO.attemptBlocking(t)
}
@@ -19,5 +19,7 @@ object TestMonadError {
      rt.catchSome(h)

    override def ensure[T](f: TestEffect[T], e: => TestEffect[Unit]): TestEffect[T] = f.ensuring(e.ignore)
+
+   override def blocking[T](t: => T): TestEffect[T] = ZIO.attemptBlocking(t)
  }
}
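The `blocking` operation added above lets effect interpreters run thread-blocking code (such as draining the new lazy request-body `InputStream`) on a scheduler that tolerates blocked threads; the ZIO instances delegate to `ZIO.attemptBlocking`, and wrapped instances forward through the natural transformation `fk`. As a rough sketch of the same idea for a `Future`-based effect (the `SketchMonadError` trait and `FutureBlockingMonad` class below are hypothetical illustrations, not part of this commit):

import scala.concurrent.{ExecutionContext, Future}

// Hypothetical, trimmed-down view of the MonadError operations relevant here.
trait SketchMonadError[F[_]] {
  def unit[T](t: T): F[T]
  def eval[T](t: => T): F[T]
  // New operation: evaluate t, which is allowed to block the current thread.
  def blocking[T](t: => T): F[T]
}

class FutureBlockingMonad(implicit ec: ExecutionContext) extends SketchMonadError[Future] {
  override def unit[T](t: T): Future[T] = Future.successful(t)
  override def eval[T](t: => T): Future[T] = Future(t)
  // scala.concurrent.blocking lets blocking-aware pools (e.g. the global
  // fork-join pool) spawn a compensating thread while t blocks.
  override def blocking[T](t: => T): Future[T] = Future(scala.concurrent.blocking(t))
}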
@@ -4,19 +4,17 @@ import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{FullHttpRequest, HttpContent}
import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
+import sttp.capabilities.Streams
+import sttp.model.HeaderNames
import sttp.monad.MonadError
import sttp.monad.syntax._
+import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
import sttp.tapir.model.ServerRequest
-import sttp.tapir.server.interpreter.RequestBody
-import sttp.tapir.RawBodyType
-import sttp.tapir.TapirFile
-import sttp.tapir.server.interpreter.RawValue
-import sttp.tapir.FileRange
-import sttp.tapir.InputStreamRange
-import java.io.ByteArrayInputStream
+import sttp.tapir.server.interpreter.{RawValue, RequestBody}
+import sttp.tapir.server.netty.internal.reactivestreams.SubscriberInputStream

+import java.io.{ByteArrayInputStream, InputStream}
import java.nio.ByteBuffer
-import sttp.capabilities.Streams
-import sttp.model.HeaderNames

/** Common logic for processing the request body in all Netty backends. It requires particular backends to implement a few operations. */
private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {
@@ -64,17 +62,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
      case RawBodyType.ByteBufferBody =>
        readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
      case RawBodyType.InputStreamBody =>
-       // Possibly can be optimized to avoid loading all data eagerly into memory
-       readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
+       monad.eval(RawValue(readAsStream(serverRequest, maxBytes)))
      case RawBodyType.InputStreamRangeBody =>
-       // Possibly can be optimized to avoid loading all data eagerly into memory
-       readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
+       monad.unit(RawValue(InputStreamRange(() => readAsStream(serverRequest, maxBytes))))
      case RawBodyType.FileBody =>
        for {
          file <- createFile(serverRequest)
          _ <- writeToFile(serverRequest, file, maxBytes)
        } yield RawValue(FileRange(file), Seq(FileRange(file)))
-     case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
+     case _: RawBodyType.MultipartBody =>
+       monad.error(new UnsupportedOperationException)
    }

  private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
@@ -84,7 +81,19 @@
      case req: StreamedHttpRequest =>
        val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
        publisherToBytes(req, contentLength, maxBytes)
      case other =>
        monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
    }

+  private def readAsStream(serverRequest: ServerRequest, maxBytes: Option[Long]): InputStream = {
+    serverRequest.underlying match {
+      case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
+        InputStream.nullInputStream()
+      case req: StreamedHttpRequest =>
+        val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
+        SubscriberInputStream.processAsStream(req, contentLength, maxBytes)
+      case other =>
+        throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}")
+    }
+  }
}
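With `readAsStream` in place, `InputStreamBody` and `InputStreamRangeBody` no longer buffer the whole payload: the server logic receives an `InputStream` that pulls `HttpContent` chunks on demand. A hedged sketch of server logic that benefits from this (the `upload` endpoint and `countBytes` helper are illustrative, not part of this commit); since reads block until chunks arrive, such logic should be lifted with `monad.blocking` or an equivalent:

import sttp.tapir._
import java.io.InputStream

// Illustrative endpoint: return the request-body size without buffering it.
val upload: PublicEndpoint[InputStream, Unit, Long, Any] =
  endpoint.post.in("upload").in(inputStreamBody).out(plainBody[Long])

// Server logic would drain the stream in a blocking-tolerant effect,
// e.g. monad.blocking(countBytes(is)); each read blocks until data arrives.
def countBytes(is: InputStream): Long = {
  val buf = new Array[Byte](8192)
  var total = 0L
  var read = is.read(buf)
  while (read != -1) {
    total += read
    read = is.read(buf)
  }
  total
}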
@@ -0,0 +1,147 @@
package sttp.tapir.server.netty.internal.reactivestreams

import io.netty.buffer.ByteBuf
import io.netty.handler.codec.http.HttpContent
import org.reactivestreams.{Publisher, Subscriber, Subscription}
import sttp.capabilities.StreamMaxLengthExceededException

import java.io.{IOException, InputStream}
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.locks.ReentrantLock
import scala.annotation.tailrec

/** A blocking input stream that reads from a reactive streams publisher of [[HttpContent]].
  *
  * @param maxBufferedChunks
  *   maximum number of unread chunks that can be buffered before blocking the publisher
  */
private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1) extends InputStream with Subscriber[HttpContent] {

  require(maxBufferedChunks > 0)

  import SubscriberInputStream._

  // volatile because used in both InputStream & Subscriber methods
  @volatile private[this] var closed = false

  // Calls on the subscription must be synchronized in order to satisfy the Reactive Streams spec
  // (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#2-subscriber-code - rule 7)
  // because they are called both from InputStream & Subscriber methods.
  private[this] var subscription: Subscription = _
  private[this] val lock = new ReentrantLock

  private def locked[T](code: => T): T = {
    // acquire before the try block, so the finally clause never unlocks a lock we do not hold
    lock.lock()
    try code
    finally lock.unlock()
  }

  private[this] var currentItem: Item = _
  // the queue serves as a buffer, allowing the subscriber and the publisher to work in parallel
  private val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1) // +1 to have a spot for End/Error

  private def readItem(blocking: Boolean): Item = {
    if (currentItem eq null) {
      currentItem = if (blocking) queue.take() else queue.poll()
      currentItem match {
        case _: Chunk => locked(subscription.request(1))
        case _ =>
      }
    }
    currentItem
  }

  override def available(): Int =
    readItem(blocking = false) match {
      case Chunk(data) => data.readableBytes()
      case _ => 0
    }

  override def read(): Int = {
    val buffer = new Array[Byte](1)
    // & 0xff: read() must return an unsigned byte value (0-255), or -1 at EOF
    if (read(buffer) == -1) -1 else buffer(0) & 0xff
  }

  override def read(b: Array[Byte], off: Int, len: Int): Int =
    if (closed) throw new IOException("Stream closed")
    else if (len == 0) 0
    else
      readItem(blocking = true) match {
        case Chunk(data) =>
          val toRead = Math.min(len, data.readableBytes())
          data.readBytes(b, off, toRead)
          if (data.readableBytes() == 0) {
            data.release()
            currentItem = null
          }
          toRead
        case Error(cause) => throw cause
        case End => -1
      }

  override def close(): Unit = if (!closed) {
    locked(subscription.cancel())
    closed = true
    clearQueue()
  }

  @tailrec private def clearQueue(): Unit =
    queue.poll() match {
      case Chunk(data) =>
        data.release()
        clearQueue()
      case _ =>
    }

  override def onSubscribe(s: Subscription): Unit = locked {
    if (s eq null) {
      throw new NullPointerException("Subscription must not be null")
    }
    subscription = s
    subscription.request(maxBufferedChunks)
  }

  override def onNext(chunk: HttpContent): Unit = {
    if (!queue.offer(Chunk(chunk.content()))) {
      // This should be impossible according to the Reactive Streams spec;
      // if it happens, it is a bug in the subscriber or publisher implementation.
      chunk.release()
      locked(subscription.cancel())
    } else if (closed) {
      clearQueue()
    }
  }

  override def onError(t: Throwable): Unit =
    if (!closed) {
      queue.offer(Error(t))
    }

  override def onComplete(): Unit =
    if (!closed) {
      queue.offer(End)
    }
}

private[netty] object SubscriberInputStream {
  private sealed abstract class Item
  private case class Chunk(data: ByteBuf) extends Item
  private case class Error(cause: Throwable) extends Item
  private object End extends Item

  def processAsStream(
      publisher: Publisher[HttpContent],
      contentLength: Option[Long],
      maxBytes: Option[Long],
      maxBufferedChunks: Int = 1
  ): InputStream = maxBytes match {
    case Some(max) if contentLength.exists(_ > max) =>
      throw StreamMaxLengthExceededException(max)
    case _ =>
      val subscriber = new SubscriberInputStream(maxBufferedChunks)
      val maybeLimitedSubscriber = maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)
      publisher.subscribe(maybeLimitedSubscriber)
      subscriber
  }
}
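A brief usage sketch (the `firstKilobyte` helper is hypothetical, not from this commit): `processAsStream` subscribes immediately but delivers data lazily, so reads block until chunks arrive, and closing the stream cancels the subscription and releases any buffered `ByteBuf`s:

import io.netty.handler.codec.http.HttpContent
import org.reactivestreams.Publisher
import java.io.InputStream

// Hypothetical helper: read the first kilobyte of a request body.
// Must run on a thread that is allowed to block.
def firstKilobyte(publisher: Publisher[HttpContent]): Array[Byte] = {
  val is: InputStream =
    SubscriberInputStream.processAsStream(publisher, contentLength = None, maxBytes = Some(1024L * 1024L))
  try is.readNBytes(1024) // blocks until enough data arrives or the stream ends
  finally is.close() // cancels the subscription and releases buffered chunks
}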
@@ -0,0 +1,114 @@
package sttp.tapir.server.netty.internal.reactivestreams

import cats.effect.IO
import cats.effect.kernel.Resource
import cats.effect.unsafe.IORuntime
import fs2.Stream
import fs2.interop.reactivestreams._
import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.DefaultHttpContent
import org.scalactic.source.Position
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import java.io.InputStream
import scala.annotation.tailrec
import scala.util.Random

class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
  private implicit def runtime: IORuntime = IORuntime.global

  private def readAll(is: InputStream, batchSize: Int): Array[Byte] = {
    val buf = Unpooled.buffer(batchSize)
    @tailrec def writeLoop(): Array[Byte] = buf.writeBytes(is, batchSize) match {
      case -1 => buf.array().take(buf.readableBytes())
      case _ => writeLoop()
    }
    writeLoop()
  }

  private def testReading(
      totalSize: Int,
      publishedChunkLimit: Int,
      readBatchSize: Int,
      maxBufferedChunks: Int = 1
  )(implicit pos: Position): Unit = {
    val bytes = new Array[Byte](totalSize)
    Random.nextBytes(bytes)

    val publisherResource = Stream
      .emits(bytes)
      .chunkLimit(publishedChunkLimit)
      .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
      .covary[IO]
      .toUnicastPublisher

    val io = publisherResource.use { publisher =>
      IO {
        val subscriberInputStream = new SubscriberInputStream(maxBufferedChunks)
        publisher.subscribe(subscriberInputStream)
        readAll(subscriberInputStream, readBatchSize) shouldBe bytes
      }
    }

    io.unsafeRunSync()
  }

  "empty stream" in {
    testReading(totalSize = 0, publishedChunkLimit = 1024, readBatchSize = 1024)
  }

  "single chunk stream, one read batch" in {
    testReading(totalSize = 10, publishedChunkLimit = 1024, readBatchSize = 1024)
  }

  "single chunk stream, multiple read batches" in {
    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 10)
    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 11)
  }

  "multiple chunks, read batch larger than chunk" in {
    testReading(totalSize = 100, publishedChunkLimit = 10, readBatchSize = 1024)
    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024)
  }

  "multiple chunks, read batch smaller than chunk" in {
    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 17)
    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 7)
    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 5)
  }

  "multiple chunks, large publishing buffer" in {
    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024, maxBufferedChunks = 5)
  }

  "closing the stream should cancel the subscription" in {
    var canceled = false

    val publisherResource =
      Stream
        .emits(Array.fill(1024)(0.toByte))
        .chunkLimit(100)
        .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
        .covary[IO]
        .onFinalizeCase {
          case Resource.ExitCase.Canceled => IO { canceled = true }
          case _ => IO.unit
        }
        .toUnicastPublisher

    publisherResource
      .use(publisher =>
        IO {
          val stream = new SubscriberInputStream()
          publisher.subscribe(stream)

          stream.readNBytes(120).length shouldBe 120
          stream.close()
        }
      )
      .unsafeRunSync()

    canceled shouldBe true
  }
}
@@ -735,7 +735,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](

  def inputStreamTests(): List[Test] = List(
    testServer(in_input_stream_out_input_stream)((is: InputStream) =>
-     pureResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])
+     blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])
    ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) },
    testServer(in_string_out_stream_with_header)(_ => pureResult(Right((new ByteArrayInputStream(Array.fill[Byte](128)(0)), Some(128))))) {
      (backend, baseUri) =>
@@ -795,7 +795,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
        "checks payload limit and returns OK on content length below or equal max (request)"
      )(i => {
        // Forcing server logic to drain the InputStream
-       suspendResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit])
+       blockingResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit])
      }) { (backend, baseUri) =>
        val tooLargeBody: String = List.fill(maxLength)('x').mkString
        basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).response(asByteArray).send(backend).map { r =>
@@ -65,7 +65,7 @@ class ServerMetricsTest[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest
    testServer(
      in_input_stream_out_input_stream.name("metrics"),
      interceptors = (ci: CustomiseInterceptors[F, OPTIONS]) => ci.metricsInterceptor(metrics)
-   )(is => (new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit].unit) { (backend, baseUri) =>
+   )(is => blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])) { (backend, baseUri) =>
      basicRequest
        .post(uri"$baseUri/api/echo")
        .body("okoń")
@@ -10,6 +10,7 @@ import sttp.monad.MonadError
package object tests {
  val backendResource: Resource[IO, SttpBackend[IO, Fs2Streams[IO] with WebSockets]] = HttpClientFs2Backend.resource()
  val basicStringRequest: PartialRequest[String, Any] = basicRequest.response(asStringAlways)
- def pureResult[F[_]: MonadError, T](t: T): F[T] = implicitly[MonadError[F]].unit(t)
- def suspendResult[F[_]: MonadError, T](t: => T): F[T] = implicitly[MonadError[F]].eval(t)
+ def pureResult[F[_]: MonadError, T](t: T): F[T] = MonadError[F].unit(t)
+ def suspendResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].eval(t)
+ def blockingResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].blocking(t)
}
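Because the Netty request body is now drained lazily, server logic that reads it performs real blocking I/O, which is why the tests above switch from `pureResult`/`suspendResult` to the new `blockingResult`. A sketch of the distinction (the `echoSize` helper is illustrative only):

import sttp.monad.MonadError
import java.io.InputStream

// pureResult:     the value is already computed; nothing is suspended
// suspendResult:  evaluated lazily, but on the regular scheduler
// blockingResult: evaluated lazily on a blocking-tolerant scheduler,
//                 appropriate for draining a lazy request-body stream
def echoSize[F[_]: MonadError](is: InputStream): F[Int] =
  blockingResult(is.readAllBytes().length)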
