Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add zio-http multipart body support #3690

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,15 @@ trait ZioHttpInterpreter[R] {
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode),
headers = ZioHttpHeaders(allHeaders),
body = body
.map {
case ZioStreamHttpResponseBody(stream, Some(contentLength)) => Body.fromStream(stream, contentLength)
case ZioStreamHttpResponseBody(stream, None) => Body.fromStreamChunked(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)
body
.map {
case ZioStreamHttpResponseBody(stream, Some(contentLength)) => ZIO.succeed(Body.fromStream(stream, contentLength))
case ZioStreamHttpResponseBody(stream, None) => ZIO.succeed(Body.fromStreamChunked(stream))
case ZioMultipartHttpResponseBody(formFields) => Body.fromMultipartFormUUID(Form(Chunk.fromIterable(formFields)))
case ZioRawHttpResponseBody(chunk, _) => ZIO.succeed(Body.fromChunk(chunk))
}
.getOrElse(ZIO.succeed(Body.empty))
.map(zioBody => Response(status = Status.fromInt(statusCode), headers = ZioHttpHeaders(allHeaders), body = zioBody))
}

private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): Seq[ZioHttpHeader] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,96 @@ package sttp.tapir.server.ziohttp

import sttp.capabilities
import sttp.capabilities.zio.ZioStreams
import sttp.model.Part
import sttp.model.Part.FileNameDispositionParam
import sttp.tapir.FileRange
import sttp.tapir.InputStreamRange
import sttp.tapir.RawBodyType
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType}
import zio.http.Request
import zio.stream.{Stream, ZSink, ZStream}
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.interpreter.RequestBody
import zio.{RIO, Task, ZIO}
import zio.http.{FormField, Request, StreamingForm}
import zio.http.FormField.StreamingBinary
import zio.stream.ZSink
import zio.stream.ZStream

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] {
override val streams: capabilities.Streams[ZioStreams] = ZioStreams

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = {
override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] =
toRaw(serverRequest, zStream(serverRequest), bodyType, maxBytes)

def asByteArray: Task[Array[Byte]] =
(toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).runCollect.map(_.toArray)
private def toRaw[RAW](
serverRequest: ServerRequest,
stream: ZStream[Any, Throwable, Byte],
bodyType: RawBodyType[RAW],
maxBytes: Option[Long]
): Task[RawValue[RAW]] = {
val limitedStream = limitedZStream(stream, maxBytes)
val asByteArray = limitedStream.runCollect.map(_.toArray)

bodyType match {
case RawBodyType.StringBody(defaultCharset) => asByteArray.map(new String(_, defaultCharset)).map(RawValue(_))
case RawBodyType.ByteArrayBody => asByteArray.map(RawValue(_))
case RawBodyType.ByteBufferBody => asByteArray.map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_))
case RawBodyType.InputStreamBody => asByteArray.map(new ByteArrayInputStream(_)).map(RawValue(_))
case RawBodyType.InputStreamRangeBody =>
asByteArray.map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_))
asByteArray.map(bytes => InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_))
case RawBodyType.FileBody =>
for {
file <- serverOptions.createFile(serverRequest)
_ <- (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).run(ZSink.fromFile(file)).map(_ => ())
_ <- limitedStream.run(ZSink.fromFile(file)).unit
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported"))
case m: RawBodyType.MultipartBody => handleMultipartBody(serverRequest, m, limitedStream)
}
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val inputStream = stream(serverRequest)
maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream).asInstanceOf[streams.BinaryStream]
private def handleMultipartBody[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType.MultipartBody,
limitedStream: ZStream[Any, Throwable, Byte]
): Task[RawValue[RAW]] = {
zRequest(serverRequest).body.contentType.flatMap(_.boundary) match {
case Some(boundary) =>
StreamingForm(limitedStream, boundary).fields
.flatMap(field => ZStream.fromIterable(bodyType.partType(field.name).map((field, _))))
.mapZIO { case (field, bodyType) => toRawPart(serverRequest, field, bodyType) }
.runCollect
.map(RawValue.fromParts(_).asInstanceOf[RawValue[RAW]])
case None =>
ZIO.fail(
new IllegalStateException("Cannot decode body as streaming multipart/form-data without a known boundary")
)
}
}

private def toRawPart[A](serverRequest: ServerRequest, field: FormField, bodyType: RawBodyType[A]): Task[Part[A]] = {
val fieldsStream = field match {
case StreamingBinary(_, _, _, _, s) => s
case _ => ZStream.fromIterableZIO(field.asChunk)
}
toRaw(serverRequest, fieldsStream, bodyType, None)
.map(raw =>
Part(
field.name,
raw.value,
otherDispositionParams = field.filename.map(name => Map(FileNameDispositionParam -> name)).getOrElse(Map.empty)
).contentType(field.contentType.fullType)
)
}

private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] =
zioHttpRequest(serverRequest).body.asStream
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream =
limitedZStream(zStream(serverRequest), maxBytes).asInstanceOf[streams.BinaryStream]

private def zRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]

private def limitedZStream(stream: ZStream[Any, Throwable, Byte], maxBytes: Option[Long]) = {
maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream)
}

private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]
private def zStream(serverRequest: ServerRequest) = zRequest(serverRequest).body.asStream
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sttp.tapir.server.ziohttp

import zio.stream.ZStream
import zio.Chunk
import zio.http.FormField

sealed trait ZioHttpResponseBody {
def contentLength: Option[Long]
Expand All @@ -10,3 +11,7 @@ sealed trait ZioHttpResponseBody {
case class ZioStreamHttpResponseBody(stream: ZStream[Any, Throwable, Byte], contentLength: Option[Long]) extends ZioHttpResponseBody

case class ZioRawHttpResponseBody(bytes: Chunk[Byte], contentLength: Option[Long]) extends ZioHttpResponseBody

case class ZioMultipartHttpResponseBody(formFields: List[FormField]) extends ZioHttpResponseBody {
override def contentLength: Option[Long] = None
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package sttp.tapir.server.ziohttp

import sttp.capabilities.zio.ZioStreams
import sttp.model.HasHeaders
import sttp.model.Part
import sttp.tapir.server.interpreter.ToResponseBody
import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput}
import sttp.tapir.{CodecFormat, RawBodyType, RawPart, WebSocketBodyOutput}
import zio.Chunk
import zio.http.FormField
import zio.http.MediaType
import zio.stream.ZStream

import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.charset.Charset

Expand Down Expand Up @@ -74,6 +78,59 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams]
}
}
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported")
case m @ RawBodyType.MultipartBody(_, _) =>
val formFields = (r: Seq[RawPart]).flatMap { part =>
m.partType(part.name).map { partType =>
toFormField(partType.asInstanceOf[RawBodyType[Any]], part)
}
}.toList
ZioMultipartHttpResponseBody(formFields)
}

private def toFormField[R](bodyType: RawBodyType[R], part: Part[R]): FormField = {
val mediaType: Option[MediaType] = part.contentType.flatMap(MediaType.forContentType)
bodyType match {
case RawBodyType.StringBody(_) =>
FormField.Text(part.name, part.body, mediaType.getOrElse(MediaType.text.plain), part.fileName)
case RawBodyType.ByteArrayBody =>
FormField.Binary(
part.name,
Chunk.fromArray(part.body),
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.ByteBufferBody =>
val array: Array[Byte] = new Array[Byte](part.body.remaining)
part.body.get(array)
FormField.Binary(
part.name,
Chunk.fromArray(array),
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.FileBody =>
FormField.streamingBinaryField(
part.name,
ZStream.fromFile(part.body.file).orDie,
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.InputStreamBody =>
FormField.streamingBinaryField(
part.name,
ZStream.fromInputStream(part.body).orDie,
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.InputStreamRangeBody =>
FormField.streamingBinaryField(
part.name,
ZStream.fromInputStream(part.body.inputStream()).orDie,
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case _: RawBodyType.MultipartBody =>
throw new UnsupportedOperationException("Nested multipart messages are not supported.")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,10 @@ class ZioHttpServerTest extends TestSuite {
interpreter,
backend,
basic = false,
staticContent = true,
multipart = false,
file = true,
options = false
).tests() ++
new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++
new ZioHttpCompositionTest(createServerTest).tests() ++
new ServerWebSocketTests(
Expand Down
Loading