Skip to content

Commit

Permalink
Separate EndpointOuput trait
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Mar 24, 2019
1 parent 06e66a6 commit 563d5a1
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
.response(ignore)
.mapResponse(Right(_): Either[Any, Any])

val (uri, req) = setInputParams(e.input.asVectorOfSingle, params, paramsAsArgs, 0, baseUri, baseReq)
val (uri, req) = setInputParams(e.input.asVectorOfSingleInputs, params, paramsAsArgs, 0, baseUri, baseReq)

var req2 = req.copy[Id, Either[Any, Any], Any](method = SttpMethod(e.input.method.getOrElse(Method.GET).m), uri = uri)

if (e.output.asVectorOfSingle.nonEmpty || e.errorOutput.asVectorOfSingle.nonEmpty) {
if (e.output.asVectorOfSingleOutputs.nonEmpty || e.errorOutput.asVectorOfSingleOutputs.nonEmpty) {
// by default, reading the body as specified by the output, and optionally adjusting to the error output
// if there's no body in the output, reading the body as specified by the error output
// otherwise, ignoring
Expand All @@ -47,7 +47,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {

val responseAs = baseResponseAs2.mapWithMetadata {
(body, meta) =>
val outputs = if (meta.isSuccess) e.output.asVectorOfSingle else e.errorOutput.asVectorOfSingle
val outputs = if (meta.isSuccess) e.output.asVectorOfSingleOutputs else e.errorOutput.asVectorOfSingleOutputs

// the body type of the success output takes priority; that's why it might not match
val adjustedBody =
Expand All @@ -65,7 +65,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
})
}

private def getOutputParams(outputs: Vector[EndpointIO.Single[_]], body: Any, meta: ResponseMetadata): Any = {
private def getOutputParams(outputs: Vector[EndpointOutput.Single[_]], body: Any, meta: ResponseMetadata): Any = {
val values = outputs
.map {
case EndpointIO.Body(codec, _) =>
Expand All @@ -81,14 +81,17 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
case EndpointIO.Headers(_) =>
meta.headers

case EndpointIO.StatusCode() =>
case EndpointIO.Mapped(wrapped, f, _, _) =>
f.asInstanceOf[Any => Any].apply(getOutputParams(wrapped.asVectorOfSingleOutputs, body, meta))

case EndpointOutput.StatusCode() =>
meta.code

case EndpointIO.Mapped(wrapped, f, _, _) =>
f.asInstanceOf[Any => Any].apply(getOutputParams(wrapped.asVectorOfSingle, body, meta))
case EndpointOutput.StatusFrom(wrapped, _, _, _) =>
getOutputParams(wrapped.asVectorOfSingleOutputs, body, meta)

case EndpointIO.StatusFrom(wrapped, _, _, _) =>
getOutputParams(wrapped.asVectorOfSingle, body, meta)
case EndpointOutput.Mapped(wrapped, f, _, _) =>
f.asInstanceOf[Any => Any].apply(getOutputParams(wrapped.asVectorOfSingleOutputs, body, meta))
}

SeqToParams(values)
Expand All @@ -108,7 +111,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
wrappedParamsAsArgs: ParamsAsArgs[II],
tail: Vector[EndpointInput.Single[_]]): (Uri, PartialAnyRequest) = {
val (uri2, req2) = setInputParams(
wrapped.asVectorOfSingle,
wrapped.asVectorOfSingleInputs,
g(paramsAsArgs.paramAt(params, paramIndex).asInstanceOf[T]),
wrappedParamsAsArgs,
0,
Expand Down Expand Up @@ -206,14 +209,14 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
case MultipartValueType(_, _) => throw new IllegalArgumentException("Nested multipart bodies aren't supported")
}

private def bodyIsStream[I](in: EndpointInput[I]): Boolean = {
in match {
case _: EndpointIO.StreamBodyWrapper[_, _] => true
case EndpointIO.Multiple(inputs) => inputs.exists(i => bodyIsStream(i))
case EndpointInput.Multiple(inputs) => inputs.exists(i => bodyIsStream(i))
case EndpointIO.Mapped(wrapped, _, _, _) => bodyIsStream(wrapped)
case EndpointInput.Mapped(wrapped, _, _, _) => bodyIsStream(wrapped)
case _ => false
private def bodyIsStream[I](out: EndpointOutput[I]): Boolean = {
out match {
case _: EndpointIO.StreamBodyWrapper[_, _] => true
case EndpointIO.Multiple(inputs) => inputs.exists(i => bodyIsStream(i))
case EndpointOutput.Multiple(inputs) => inputs.exists(i => bodyIsStream(i))
case EndpointIO.Mapped(wrapped, _, _, _) => bodyIsStream(wrapped)
case EndpointOutput.Mapped(wrapped, _, _, _) => bodyIsStream(wrapped)
case _ => false
}
}

Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/tapir/Endpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import tapir.typelevel.{FnComponents, ParamConcat, ParamsAsArgs}
* @tparam O Output parameter types.
* @tparam S The type of streams that are used by this endpoint's inputs/outputs. `Nothing`, if no streams are used.
*/
case class Endpoint[I, E, O, +S](input: EndpointInput[I], errorOutput: EndpointIO[E], output: EndpointIO[O], info: EndpointInfo) {
case class Endpoint[I, E, O, +S](input: EndpointInput[I], errorOutput: EndpointOutput[E], output: EndpointOutput[O], info: EndpointInfo) {

def get: Endpoint[I, E, O, S] = in(RequestMethod(Method.GET))
def post: Endpoint[I, E, O, S] = in(RequestMethod(Method.POST))
Expand All @@ -29,13 +29,13 @@ case class Endpoint[I, E, O, +S](input: EndpointInput[I], errorOutput: EndpointI
def in[J, IJ, S2 >: S](i: StreamingEndpointIO[J, S2])(implicit ts: ParamConcat.Aux[I, J, IJ]): Endpoint[IJ, E, O, S2] =
this.copy[IJ, E, O, S2](input = input.and(i.toEndpointIO))

def out[P, OP](i: EndpointIO[P])(implicit ts: ParamConcat.Aux[O, P, OP]): Endpoint[I, E, OP, S] =
def out[P, OP](i: EndpointOutput[P])(implicit ts: ParamConcat.Aux[O, P, OP]): Endpoint[I, E, OP, S] =
this.copy[I, E, OP, S](output = output.and(i))

def out[P, OP, S2 >: S](i: StreamingEndpointIO[P, S2])(implicit ts: ParamConcat.Aux[O, P, OP]): Endpoint[I, E, OP, S2] =
this.copy[I, E, OP, S2](output = output.and(i.toEndpointIO))

def errorOut[F, EF](i: EndpointIO[F])(implicit ts: ParamConcat.Aux[E, F, EF]): Endpoint[I, EF, O, S] =
def errorOut[F, EF](i: EndpointOutput[F])(implicit ts: ParamConcat.Aux[E, F, EF]): Endpoint[I, EF, O, S] =
this.copy[I, EF, O, S](errorOutput = errorOutput.and(i))

def mapIn[II](f: I => II)(g: II => I)(implicit paramsAsArgs: ParamsAsArgs[I]): Endpoint[II, E, O, S] =
Expand Down
149 changes: 106 additions & 43 deletions core/src/main/scala/tapir/EndpointIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,35 @@ sealed trait EndpointInput[I] {
map[CASE_CLASS](fc.tupled(c).apply)(ProductToParams(_, fc.arity).asInstanceOf[I])(paramsAsArgs)
}

private[tapir] def asVectorOfSingle: Vector[EndpointInput.Single[_]] = this match {
private[tapir] def asVectorOfSingleInputs: Vector[EndpointInput.Single[_]] = this match {
case s: EndpointInput.Single[_] => Vector(s)
case m: EndpointInput.Multiple[_] => m.inputs
case m: EndpointIO.Multiple[_] => m.ios
}

private[tapir] def traverse[T](handle: PartialFunction[EndpointInput[_], Vector[T]]): Vector[T] = this match {
private[tapir] def traverseInputs[T](handle: PartialFunction[EndpointInput[_], Vector[T]]): Vector[T] = this match {
case i: EndpointInput[_] if handle.isDefinedAt(i) => handle(i)
case EndpointInput.Multiple(inputs) => inputs.flatMap(_.traverse(handle))
case EndpointIO.Multiple(inputs) => inputs.flatMap(_.traverse(handle))
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.traverse(handle)
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.traverse(handle)
case a: EndpointInput.Auth[_] => a.input.traverse(handle)
case s: EndpointIO.StatusFrom[_] => s.io.traverse(handle)
case EndpointInput.Multiple(inputs) => inputs.flatMap(_.traverseInputs(handle))
case EndpointIO.Multiple(inputs) => inputs.flatMap(_.traverseInputs(handle))
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.traverseInputs(handle)
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.traverseInputs(handle)
case a: EndpointInput.Auth[_] => a.input.traverseInputs(handle)
case _ => Vector.empty
}

private[tapir] def asVectorOfBasic(includeAuth: Boolean = true): Vector[EndpointInput.Basic[_]] = traverse {
case b: EndpointInput.Basic[_] => Vector(b)
case a: EndpointInput.Auth[_] => if (includeAuth) a.input.asVectorOfBasic(includeAuth) else Vector.empty
case s: EndpointIO.StatusFrom[_] => s.io.asVectorOfBasic(includeAuth)
private[tapir] def asVectorOfBasicInputs(includeAuth: Boolean = true): Vector[EndpointInput.Basic[_]] = traverseInputs {
case b: EndpointInput.Basic[_] => Vector(b)
case a: EndpointInput.Auth[_] => if (includeAuth) a.input.asVectorOfBasicInputs(includeAuth) else Vector.empty
}

private[tapir] def auths: Vector[EndpointInput.Auth[_]] = traverse {
private[tapir] def auths: Vector[EndpointInput.Auth[_]] = traverseInputs {
case a: EndpointInput.Auth[_] => Vector(a)
}

private[tapir] def method: Option[Method] =
traverse {
traverseInputs {
case i: EndpointInput.RequestMethod => Vector(i.m)
}.headOption

private[tapir] def bodyType: Option[RawValueType[_]] =
traverse[RawValueType[_]] {
case b: EndpointIO.Body[_, _, _] => Vector(b.codec.meta.rawValueType)
}.headOption
}

object EndpointInput {
Expand Down Expand Up @@ -148,7 +141,92 @@ object EndpointInput {
}
}

sealed trait EndpointIO[I] extends EndpointInput[I] {
sealed trait EndpointOutput[I] {
def and[J, IJ](other: EndpointOutput[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): EndpointOutput[IJ]

def show: String

def map[II](f: I => II)(g: II => I)(implicit paramsAsArgs: ParamsAsArgs[I]): EndpointOutput[II] =
EndpointOutput.Mapped(this, f, g, paramsAsArgs)

def mapTo[COMPANION, CASE_CLASS <: Product](c: COMPANION)(implicit fc: FnComponents[COMPANION, I, CASE_CLASS],
paramsAsArgs: ParamsAsArgs[I]): EndpointOutput[CASE_CLASS] = {
map[CASE_CLASS](fc.tupled(c).apply)(ProductToParams(_, fc.arity).asInstanceOf[I])(paramsAsArgs)
}

private[tapir] def asVectorOfSingleOutputs: Vector[EndpointOutput.Single[_]] = this match {
case s: EndpointOutput.Single[_] => Vector(s)
case m: EndpointOutput.Multiple[_] => m.outputs
case m: EndpointIO.Multiple[_] => m.ios
}

private[tapir] def traverseOutputs[T](handle: PartialFunction[EndpointOutput[_], Vector[T]]): Vector[T] = this match {
case o: EndpointOutput[_] if handle.isDefinedAt(o) => handle(o)
case EndpointOutput.Multiple(outputs) => outputs.flatMap(_.traverseOutputs(handle))
case EndpointIO.Multiple(outputs) => outputs.flatMap(_.traverseOutputs(handle))
case EndpointOutput.Mapped(wrapped, _, _, _) => wrapped.traverseOutputs(handle)
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.traverseOutputs(handle)
case s: EndpointOutput.StatusFrom[_] => s.output.traverseOutputs(handle)
case _ => Vector.empty
}

private[tapir] def asVectorOfBasicOutputs: Vector[EndpointOutput.Basic[_]] = traverseOutputs {
case b: EndpointOutput.Basic[_] => Vector(b)
}

private[tapir] def bodyType: Option[RawValueType[_]] =
traverseOutputs[RawValueType[_]] {
case b: EndpointIO.Body[_, _, _] => Vector(b.codec.meta.rawValueType)
}.headOption
}

object EndpointOutput {
sealed trait Single[I] extends EndpointOutput[I] {
def and[J, IJ](other: EndpointOutput[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): EndpointOutput[IJ] =
other match {
case s: Single[_] => Multiple(Vector(this, s))
case Multiple(outputs) => Multiple(this +: outputs)
case EndpointIO.Multiple(ios) => Multiple(this +: ios)
}
}

sealed trait Basic[I] extends Single[I]

//

case class StatusCode() extends Basic[tapir.StatusCode] {
override def show: String = "{status code}"
}

//

case class StatusFrom[I](output: EndpointOutput[I],
default: tapir.StatusCode,
defaultSchema: Option[Schema],
when: Vector[(When[I], tapir.StatusCode)])
extends Single[I] {
def defaultSchema(s: Schema): StatusFrom[I] = this.copy(defaultSchema = Some(s))
override def show: String = s"status from(${output.show}, $default or ${when.map(_._2).mkString("/")})"
}

case class Mapped[I, T](wrapped: EndpointOutput[I], f: I => T, g: T => I, paramsAsArgs: ParamsAsArgs[I]) extends Single[T] {
override def show: String = s"map(${wrapped.show})"
}

//

case class Multiple[I](outputs: Vector[Single[_]]) extends EndpointOutput[I] {
override def and[J, IJ](other: EndpointOutput[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): EndpointOutput.Multiple[IJ] =
other match {
case s: Single[_] => Multiple(outputs :+ s)
case Multiple(m) => Multiple(outputs ++ m)
case EndpointIO.Multiple(m) => Multiple(outputs ++ m)
}
def show: String = if (outputs.isEmpty) "-" else outputs.map(_.show).mkString(" ")
}
}

sealed trait EndpointIO[I] extends EndpointInput[I] with EndpointOutput[I] {
def and[J, IJ](other: EndpointIO[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): EndpointIO[IJ]

def show: String
Expand All @@ -159,23 +237,18 @@ sealed trait EndpointIO[I] extends EndpointInput[I] {
paramsAsArgs: ParamsAsArgs[I]): EndpointIO[CASE_CLASS] = {
map[CASE_CLASS](fc.tupled(c).apply)(ProductToParams(_, fc.arity).asInstanceOf[I])(paramsAsArgs)
}

private[tapir] override def asVectorOfSingle: Vector[EndpointIO.Single[_]] = this match {
case s: EndpointIO.Single[_] => Vector(s)
case m: EndpointIO.Multiple[_] => m.ios
}
}

object EndpointIO {
sealed trait Single[I] extends EndpointIO[I] with EndpointInput.Single[I] {
sealed trait Single[I] extends EndpointIO[I] with EndpointInput.Single[I] with EndpointOutput.Single[I] {
def and[J, IJ](other: EndpointIO[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): EndpointIO[IJ] =
other match {
case s: Single[_] => Multiple(Vector(this, s))
case Multiple(outputs) => Multiple(this +: outputs)
}
}

sealed trait Basic[I] extends Single[I] with EndpointInput.Basic[I]
sealed trait Basic[I] extends Single[I] with EndpointInput.Basic[I] with EndpointOutput.Basic[I]

case class Body[T, M <: MediaType, R](codec: CodecForOptional[T, M, R], info: Info[T]) extends Basic[T] {
def description(d: String): Body[T, M, R] = copy(info = info.description(d))
Expand All @@ -199,22 +272,6 @@ object EndpointIO {
def show = s"{multiple headers}"
}

case class StatusCode() extends Basic[tapir.StatusCode] {
override def show: String = "{status code}"
}

//

// TODO: should be output-only
case class StatusFrom[I](io: EndpointIO[I],
default: tapir.StatusCode,
defaultSchema: Option[Schema],
when: Vector[(When[I], tapir.StatusCode)])
extends Single[I] {
def defaultSchema(s: Schema): StatusFrom[I] = this.copy(defaultSchema = Some(s))
override def show: String = s"status from(${io.show}, $default or ${when.map(_._2).mkString("/")})"
}

case class Mapped[I, T](wrapped: EndpointIO[I], f: I => T, g: T => I, paramsAsArgs: ParamsAsArgs[I]) extends Single[T] {
override def show: String = s"map(${wrapped.show})"
}
Expand All @@ -228,6 +285,12 @@ object EndpointIO {
case EndpointInput.Multiple(m) => EndpointInput.Multiple((ios: Vector[EndpointInput.Single[_]]) ++ m)
case EndpointIO.Multiple(m) => EndpointInput.Multiple((ios: Vector[EndpointInput.Single[_]]) ++ m)
}
override def and[J, IJ](other: EndpointOutput[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): EndpointOutput.Multiple[IJ] =
other match {
case s: EndpointOutput.Single[_] => EndpointOutput.Multiple((ios: Vector[EndpointOutput.Single[_]]) :+ s)
case EndpointOutput.Multiple(m) => EndpointOutput.Multiple((ios: Vector[EndpointOutput.Single[_]]) ++ m)
case EndpointIO.Multiple(m) => EndpointOutput.Multiple((ios: Vector[EndpointOutput.Single[_]]) ++ m)
}
override def and[J, IJ](other: EndpointIO[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): Multiple[IJ] =
other match {
case s: Single[_] => Multiple(ios :+ s)
Expand Down
10 changes: 5 additions & 5 deletions core/src/main/scala/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ trait Tapir extends TapirDerivedInputs {

// TODO
@deprecated
def statusFrom[I](io: EndpointIO[I], default: StatusCode, when: (When[I], StatusCode)*): EndpointIO.StatusFrom[I] =
EndpointIO.StatusFrom(io, default, None, when.toVector)
def statusFrom[I](io: EndpointIO[I], default: StatusCode, when: (When[I], StatusCode)*): EndpointOutput.StatusFrom[I] =
EndpointOutput.StatusFrom(io, default, None, when.toVector)

def statusCode: EndpointIO.StatusCode = EndpointIO.StatusCode()
def statusCode: EndpointOutput.StatusCode = EndpointOutput.StatusCode()

def whenClass[U: ClassTag: SchemaFor]: When[Any] = WhenClass(implicitly[ClassTag[U]], implicitly[SchemaFor[U]].schema)
def whenValue[U](p: U => Boolean): When[U] = WhenValue(p)
Expand All @@ -82,8 +82,8 @@ trait Tapir extends TapirDerivedInputs {
val endpoint: Endpoint[Unit, Unit, Unit, Nothing] =
Endpoint[Unit, Unit, Unit, Nothing](
EndpointInput.Multiple(Vector.empty),
EndpointIO.Multiple(Vector.empty),
EndpointIO.Multiple(Vector.empty),
EndpointOutput.Multiple(Vector.empty),
EndpointOutput.Multiple(Vector.empty),
EndpointInfo(None, None, None, Vector.empty)
)
}
Expand Down
3 changes: 1 addition & 2 deletions core/src/main/scala/tapir/internal/server/DecodeInputs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ object DecodeInputs {
*/
def apply(input: EndpointInput[_], ctx: DecodeInputsContext): DecodeInputsResult = {
// the first decoding failure is returned. We decode in the following order: method, path, query, headers (incl. cookies), request, status, body
val inputs = input.asVectorOfBasic().sortBy {
val inputs = input.asVectorOfBasicInputs().sortBy {
case _: EndpointInput.RequestMethod => 0
case _: EndpointInput.PathSegment => 1
case _: EndpointInput.PathCapture[_] => 1
Expand All @@ -53,7 +53,6 @@ object DecodeInputs {
case _: EndpointIO.Header[_] => 3
case _: EndpointIO.Headers => 3
case _: EndpointInput.ExtractFromRequest[_] => 4
case _: EndpointIO.StatusCode => 5
case _: EndpointIO.Body[_, _, _] => 6
case _: EndpointIO.StreamBodyWrapper[_, _] => 6
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/tapir/internal/server/InputValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ object InputValues {
* Returns the values of the inputs in the order specified by `input`, and mapped if necessary using defined mapping
* functions.
*/
def apply(input: EndpointInput[_], values: Map[EndpointInput.Basic[_], Any]): List[Any] = apply(input.asVectorOfSingle, values)
def apply(input: EndpointInput[_], values: Map[EndpointInput.Basic[_], Any]): List[Any] = apply(input.asVectorOfSingleInputs, values)

private def apply(inputs: Vector[EndpointInput.Single[_]], values: Map[EndpointInput.Basic[_], Any]): List[Any] = {
inputs match {
Expand All @@ -32,7 +32,7 @@ object InputValues {
f: II => T,
inputsTail: Vector[EndpointInput.Single[_]],
values: Map[EndpointInput.Basic[_], Any]): List[Any] = {
val wrappedValue = apply(wrapped.asVectorOfSingle, values)
val wrappedValue = apply(wrapped.asVectorOfSingleInputs, values)
f.asInstanceOf[Any => Any].apply(SeqToParams(wrappedValue)) :: apply(inputsTail, values)
}
}

0 comments on commit 563d5a1

Please sign in to comment.