-
Notifications
You must be signed in to change notification settings - Fork 399
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Lazy InputStream over Netty HttpContent Publisher (#3637)
- Loading branch information
Showing
10 changed files
with
297 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
147 changes: 147 additions & 0 deletions
147
...c/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package sttp.tapir.server.netty.internal.reactivestreams | ||
|
||
import io.netty.buffer.ByteBuf | ||
import io.netty.handler.codec.http.HttpContent | ||
import org.reactivestreams.{Publisher, Subscriber, Subscription} | ||
import sttp.capabilities.StreamMaxLengthExceededException | ||
|
||
import java.io.{IOException, InputStream} | ||
import java.util.concurrent.LinkedBlockingQueue | ||
import java.util.concurrent.locks.ReentrantLock | ||
import scala.annotation.tailrec | ||
import scala.concurrent.Promise | ||
|
||
/** A blocking input stream that reads from a reactive streams publisher of [[HttpContent]].
  *
  * Chunks delivered via [[onNext]] are buffered in a bounded queue; the reading thread takes them off the queue and
  * requests one more chunk from the publisher for every chunk it starts consuming (backpressure).
  *
  * @param maxBufferedChunks
  *   maximum number of unread chunks that can be buffered before blocking the publisher
  */
private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1) extends InputStream with Subscriber[HttpContent] {

  require(maxBufferedChunks > 0)

  import SubscriberInputStream._

  // volatile because used in both InputStream & Subscriber methods
  @volatile private[this] var closed = false

  // Calls on the subscription must be synchronized in order to satisfy the Reactive Streams spec
  // (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#2-subscriber-code - rule 7)
  // because they are called both from InputStream & Subscriber methods.
  private[this] var subscription: Subscription = _
  private[this] val lock = new ReentrantLock

  private def locked[T](code: => T): T =
    try {
      lock.lock()
      code
    } finally {
      lock.unlock()
    }

  // the item currently being consumed by read(); null when the next one must be fetched from the queue
  private[this] var currentItem: Item = _
  // the queue serves as a buffer to allow for possible parallelism between the subscriber and the publisher
  private val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1) // +1 to have a spot for End/Error

  /** Returns the current item, taking it off the queue if necessary. For every data chunk taken off the queue, one
    * more chunk is requested from the publisher. In non-blocking mode, returns null when the queue is empty.
    */
  private def readItem(blocking: Boolean): Item = {
    if (currentItem eq null) {
      currentItem = if (blocking) queue.take() else queue.poll()
      currentItem match {
        case _: Chunk => locked(subscription.request(1))
        case _ =>
      }
    }
    currentItem
  }

  override def available(): Int =
    readItem(blocking = false) match {
      case Chunk(data) => data.readableBytes()
      case _ => 0
    }

  override def read(): Int = {
    val buffer = new Array[Byte](1)
    // mask with 0xff: InputStream.read() must return the byte as an unsigned Int in 0..255;
    // returning the sign-extended Byte directly would yield -1 for byte 0xFF, falsely signaling EOF
    if (read(buffer) == -1) -1 else buffer(0) & 0xff
  }

  override def read(b: Array[Byte], off: Int, len: Int): Int =
    if (closed) throw new IOException("Stream closed")
    else if (len == 0) 0
    else
      readItem(blocking = true) match {
        case Chunk(data) =>
          val toRead = Math.min(len, data.readableBytes())
          data.readBytes(b, off, toRead)
          if (data.readableBytes() == 0) {
            // chunk fully consumed - release its buffer and move on to the next item
            data.release()
            currentItem = null
          }
          toRead
        case Error(cause) => throw cause
        case End => -1
      }

  override def close(): Unit = if (!closed) {
    // the subscription may still be null if close() is called before the publisher invoked onSubscribe
    locked(if (subscription ne null) subscription.cancel())
    closed = true
    clearQueue()
  }

  // releases all buffered chunks so that their ByteBufs are not leaked after close()
  @tailrec private def clearQueue(): Unit =
    queue.poll() match {
      case Chunk(data) =>
        data.release()
        clearQueue()
      case _ =>
    }

  override def onSubscribe(s: Subscription): Unit = locked {
    if (s eq null) {
      throw new NullPointerException("Subscription must not be null")
    }
    subscription = s
    subscription.request(maxBufferedChunks)
  }

  override def onNext(chunk: HttpContent): Unit = {
    if (!queue.offer(Chunk(chunk.content()))) {
      // This should be impossible according to the Reactive Streams spec,
      // if it happens then it's a bug in the implementation of the subscriber or publisher
      chunk.release()
      locked(subscription.cancel())
    } else if (closed) {
      // close() may have raced with this onNext - make sure the chunk just offered gets released
      clearQueue()
    }
  }

  override def onError(t: Throwable): Unit =
    if (!closed) {
      queue.offer(Error(t))
    }

  override def onComplete(): Unit =
    if (!closed) {
      queue.offer(End)
    }
}
private[netty] object SubscriberInputStream {
  // items passed from the Subscriber callbacks to the reading thread through the queue
  private sealed abstract class Item
  private case class Chunk(data: ByteBuf) extends Item
  private case class Error(cause: Throwable) extends Item
  private object End extends Item

  /** Subscribes a blocking [[InputStream]] to the given publisher and returns it.
    *
    * @param publisher
    *   the source of [[HttpContent]] chunks
    * @param contentLength
    *   declared content length, if known - used to fail fast before reading anything
    * @param maxBytes
    *   optional limit on the total number of bytes that may be read
    * @param maxBufferedChunks
    *   maximum number of unread chunks buffered before the publisher is blocked
    * @throws StreamMaxLengthExceededException
    *   immediately, when the declared content length already exceeds `maxBytes`
    */
  def processAsStream(
      publisher: Publisher[HttpContent],
      contentLength: Option[Long],
      maxBytes: Option[Long],
      maxBufferedChunks: Int = 1
  ): InputStream =
    // keep the limit only when the declared length is known to exceed it
    maxBytes.filter(max => contentLength.exists(_ > max)) match {
      case Some(max) =>
        throw StreamMaxLengthExceededException(max)
      case None =>
        val streamSubscriber = new SubscriberInputStream(maxBufferedChunks)
        val effectiveSubscriber =
          maxBytes.map(new LimitedLengthSubscriber(_, streamSubscriber)).getOrElse(streamSubscriber)
        publisher.subscribe(effectiveSubscriber)
        streamSubscriber
    }
}
114 changes: 114 additions & 0 deletions
114
...st/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package sttp.tapir.server.netty.internal.reactivestreams | ||
|
||
import cats.effect.IO | ||
import cats.effect.kernel.Resource | ||
import cats.effect.unsafe.IORuntime | ||
import fs2.Stream | ||
import fs2.interop.reactivestreams._ | ||
import io.netty.buffer.Unpooled | ||
import io.netty.handler.codec.http.DefaultHttpContent | ||
import org.scalactic.source.Position | ||
import org.scalatest.freespec.AnyFreeSpec | ||
import org.scalatest.matchers.should.Matchers | ||
|
||
import java.io.InputStream | ||
import scala.annotation.tailrec | ||
import scala.util.Random | ||
|
||
class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
  private implicit def runtime: IORuntime = IORuntime.global

  /** Drains the stream into a byte array, reading at most `batchSize` bytes at a time. */
  private def readAll(is: InputStream, batchSize: Int): Array[Byte] = {
    val accumulator = Unpooled.buffer(batchSize)
    @tailrec def drain(): Array[Byte] =
      if (accumulator.writeBytes(is, batchSize) == -1) accumulator.array().take(accumulator.readableBytes())
      else drain()
    drain()
  }

  /** Publishes `totalSize` random bytes in chunks of at most `publishedChunkLimit` bytes, reads them back through a
    * [[SubscriberInputStream]] in batches of `readBatchSize`, and checks that the bytes round-trip unchanged.
    */
  private def testReading(
      totalSize: Int,
      publishedChunkLimit: Int,
      readBatchSize: Int,
      maxBufferedChunks: Int = 1
  )(implicit pos: Position): Unit = {
    val expectedBytes = new Array[Byte](totalSize)
    Random.nextBytes(expectedBytes)

    val publisherResource = Stream
      .emits(expectedBytes)
      .chunkLimit(publishedChunkLimit)
      .map(chunk => new DefaultHttpContent(Unpooled.wrappedBuffer(chunk.toByteBuffer)))
      .covary[IO]
      .toUnicastPublisher

    publisherResource
      .use { publisher =>
        IO {
          val inputStream = new SubscriberInputStream(maxBufferedChunks)
          publisher.subscribe(inputStream)
          readAll(inputStream, readBatchSize) shouldBe expectedBytes
        }
      }
      .unsafeRunSync()
  }

  "empty stream" in {
    testReading(totalSize = 0, publishedChunkLimit = 1024, readBatchSize = 1024)
  }

  "single chunk stream, one read batch" in {
    testReading(totalSize = 10, publishedChunkLimit = 1024, readBatchSize = 1024)
  }

  "single chunk stream, multiple read batches" in {
    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 10)
    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 11)
  }

  "multiple chunks, read batch larger than chunk" in {
    testReading(totalSize = 100, publishedChunkLimit = 10, readBatchSize = 1024)
    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024)
  }

  "multiple chunks, read batch smaller than chunk" in {
    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 17)
    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 7)
    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 5)
  }

  "multiple chunks, large publishing buffer" in {
    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024, maxBufferedChunks = 5)
  }

  "closing the stream should cancel the subscription" in {
    var subscriptionCanceled = false

    val publisherResource =
      Stream
        .emits(Array.fill(1024)(0.toByte))
        .chunkLimit(100)
        .map(chunk => new DefaultHttpContent(Unpooled.wrappedBuffer(chunk.toByteBuffer)))
        .covary[IO]
        .onFinalizeCase {
          // the stream is finalized as Canceled when the subscription is cancelled before completion
          case Resource.ExitCase.Canceled => IO { subscriptionCanceled = true }
          case _ => IO.unit
        }
        .toUnicastPublisher

    publisherResource
      .use { publisher =>
        IO {
          val inputStream = new SubscriberInputStream()
          publisher.subscribe(inputStream)

          // read only part of the data, then close without draining the stream
          inputStream.readNBytes(120).length shouldBe 120
          inputStream.close()
        }
      }
      .unsafeRunSync()

    subscriptionCanceled shouldBe true
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.