Skip to content

Commit

Permalink
#55: Tapir.statusFrom endpoint output
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Mar 18, 2019
1 parent 2a4adaf commit 96e630f
Show file tree
Hide file tree
Showing 18 changed files with 297 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {

case EndpointIO.Mapped(wrapped, f, _, _) =>
f.asInstanceOf[Any => Any].apply(getOutputParams(wrapped.asVectorOfSingle, body, meta))

case EndpointIO.StatusFrom(wrapped, _, _, _) =>
getOutputParams(wrapped.asVectorOfSingle, body, meta)
}

SeqToParams(values)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ trait ClientTests[S] extends FunSuite with Matchers with BeforeAndAfterAll {
UsernamePassword("teddy", Some("bear")),
Right("Authorization=Some(Basic dGVkZHk6YmVhcg==); X-Api-Key=None; Query=None"))
testClient(in_auth_bearer_out_string, "1234", Right("Authorization=Some(Bearer 1234); X-Api-Key=None; Query=None"))
testClient(in_string_out_status_from_string, "apple", Right("fruit: apple")) // status from should be a no-op from the client interpreter's point of view

//

Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/tapir/EndpointIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ sealed trait EndpointInput[I] {
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.traverse(handle)
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.traverse(handle)
case a: EndpointInput.Auth[_] => a.input.traverse(handle)
case s: EndpointIO.StatusFrom[_] => s.io.traverse(handle)
case _ => Vector.empty
}

private[tapir] def asVectorOfBasic(includeAuth: Boolean = true): Vector[EndpointInput.Basic[_]] = traverse {
case b: EndpointInput.Basic[_] => Vector(b)
case a: EndpointInput.Auth[_] => if (includeAuth) a.input.asVectorOfBasic(includeAuth) else Vector.empty
case b: EndpointInput.Basic[_] => Vector(b)
case a: EndpointInput.Auth[_] => if (includeAuth) a.input.asVectorOfBasic(includeAuth) else Vector.empty
case s: EndpointIO.StatusFrom[_] => s.io.asVectorOfBasic(includeAuth)
}

private[tapir] def auths: Vector[EndpointInput.Auth[_]] = traverse {
Expand Down Expand Up @@ -195,6 +197,13 @@ object EndpointIO {

//

// TODO: should be output-only
case class StatusFrom[I](io: EndpointIO[I], default: StatusCode, defaultSchema: Option[Schema], when: Vector[(When[I], StatusCode)])
extends Single[I] {
def defaultSchema(s: Schema): StatusFrom[I] = this.copy(defaultSchema = Some(s))
override def show: String = s"status from(${io.show}, $default or ${when.map(_._2).mkString("/")})"
}

case class Mapped[I, T](wrapped: EndpointIO[I], f: I => T, g: T => I, paramsAsArgs: ParamsAsArgs[I]) extends Single[T] {
override def show: String = s"map(${wrapped.show})"
}
Expand Down
8 changes: 8 additions & 0 deletions core/src/main/scala/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import tapir.CodecForMany.PlainCodecForMany
import tapir.CodecForOptional.PlainCodecForOptional
import tapir.model.{Cookie, SetCookie, SetCookieValue}

import scala.reflect.ClassTag

trait Tapir {
implicit def stringToPath(s: String): EndpointInput[Unit] = EndpointInput.PathSegment(s)

Expand Down Expand Up @@ -59,6 +61,12 @@ trait Tapir {

def auth: TapirAuth.type = TapirAuth

def statusFrom[I](io: EndpointIO[I], default: StatusCode, when: (When[I], StatusCode)*): EndpointIO.StatusFrom[I] =
EndpointIO.StatusFrom(io, default, None, when.toVector)

def whenClass[U: ClassTag: SchemaFor]: When[Any] = WhenClass(implicitly[ClassTag[U]], implicitly[SchemaFor[U]].schema)
def whenValue[U](p: U => Boolean): When[U] = WhenValue(p)

def schemaFor[T: SchemaFor]: Schema = implicitly[SchemaFor[T]].schema

val endpoint: Endpoint[Unit, Unit, Unit, Nothing] =
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/scala/tapir/When.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package tapir
import scala.reflect.ClassTag

/**
* Describe conditions for status code mapping using `Tapir.statusFrom`.
*/
trait When[-I] {
def matches(i: I): Boolean
}

case class WhenClass[T](ct: ClassTag[T], s: Schema) extends When[Any] {
override def matches(i: Any): Boolean = ct.runtimeClass.isInstance(i)
}
case class WhenValue[T](p: T => Boolean) extends When[T] {
override def matches(i: T): Boolean = p(i)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package tapir.docs.openapi

import tapir.docs.openapi.schema.ObjectSchemas
import tapir.openapi.{MediaType => OMediaType, _}
import tapir.{MediaType => SMediaType, Schema => SSchema, _}

private[openapi] class CodecToMediaType(objectSchemas: ObjectSchemas) {
def apply[T, M <: SMediaType](o: CodecForOptional[T, M, _],
example: Option[T],
overrideSchema: Option[SSchema]): Map[String, OMediaType] = {
val schema = overrideSchema.getOrElse(o.meta.schema)
Map(
o.meta.mediaType.mediaTypeNoParams -> OMediaType(Some(objectSchemas(schema)),
example.flatMap(exampleValue(o, _)),
Map.empty,
Map.empty))
}

def apply[M <: SMediaType](schema: SSchema, mediaType: M, example: Option[String]): Map[String, OMediaType] = {
Map(mediaType.mediaTypeNoParams -> OMediaType(Some(objectSchemas(schema)), example.map(ExampleValue), Map.empty, Map.empty))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package tapir.docs.openapi
import tapir.docs.openapi.schema.ObjectSchemas
import tapir.model.Method
import tapir.openapi.OpenAPI.ReferenceOr
import tapir.openapi.{MediaType => OMediaType, _}
import tapir.{EndpointInput, MediaType => SMediaType, Schema => SSchema, _}
import tapir.openapi._
import tapir._

private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, securitySchemes: SecuritySchemes, options: OpenAPIDocsOptions) {

private val codecToMediaType = new CodecToMediaType(objectSchemas)
private val endpointToOperationResponse = new EndpointToOperationResponse(objectSchemas, codecToMediaType)

def pathItem(e: Endpoint[_, _, _, _]): (String, PathItem) = {
import model.Method._

Expand Down Expand Up @@ -45,7 +48,7 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, secu
private def endpointToOperation(defaultId: String, e: Endpoint[_, _, _, _], inputs: Vector[EndpointInput.Basic[_]]): Operation = {
val parameters = operationParameters(inputs)
val body: Vector[ReferenceOr[RequestBody]] = operationInputBody(inputs)
val responses: Map[ResponsesKey, ReferenceOr[Response]] = operationResponse(e)
val responses: Map[ResponsesKey, ReferenceOr[Response]] = endpointToOperationResponse(e)

val authNames = e.input.auths.flatMap(auth => securitySchemes.get(auth).map(_._1))
// for now, all auths have empty scope
Expand All @@ -68,7 +71,7 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, secu
private def operationInputBody(inputs: Vector[EndpointInput.Basic[_]]) = {
inputs.collect {
case EndpointIO.Body(codec, info) =>
Right(RequestBody(info.description, codecToMediaType(codec, info.example), Some(!codec.meta.isOptional)))
Right(RequestBody(info.description, codecToMediaType(codec, info.example, None), Some(!codec.meta.isOptional)))
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(s, mt, i)) =>
Right(RequestBody(i.description, codecToMediaType(s, mt, i.example), Some(true)))
}
Expand Down Expand Up @@ -103,72 +106,6 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, secu
query.info.example.flatMap(exampleValue(query.codec, _)))
}

private def operationResponse(e: Endpoint[_, _, _, _]): Map[ResponsesKey, Right[Nothing, Response]] = {
// There always needs to be at least a 200 empty response
val okResponse = outputToResponse(e.output).getOrElse(Response("", Map.empty, Map.empty))

List(
Some(ResponsesCodeKey(200) -> Right(okResponse)),
outputToResponse(e.errorOutput).map { r =>
ResponsesDefaultKey -> Right(r)
}
).flatten.toMap
}

private def outputToResponse(io: EndpointIO[_]): Option[Response] = {
val ios = io.asVectorOfBasic()

val headers = ios.collect {
case EndpointIO.Header(name, codec, info) =>
name -> Right(
Header(
info.description,
Some(!codec.meta.isOptional),
None,
None,
None,
None,
None,
Some(objectSchemas(codec.meta.schema)),
info.example.flatMap(exampleValue(codec, _)),
Map.empty,
Map.empty
))
}

val bodies = ios.collect {
case EndpointIO.Body(m, i) => (i.description, codecToMediaType(m, i.example))
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(s, mt, i)) => (i.description, codecToMediaType(s, mt, i.example))
}
val body = bodies.headOption

val description = body.flatMap(_._1).getOrElse("")
val content = body.map(_._2).getOrElse(Map.empty)

if (body.isDefined || headers.nonEmpty) {
Some(Response(description, headers.toMap, content))
} else {
None
}
}

private def codecToMediaType[T, M <: SMediaType](o: CodecForOptional[T, M, _], example: Option[T]): Map[String, OMediaType] = {
Map(
o.meta.mediaType.mediaTypeNoParams -> OMediaType(Some(objectSchemas(o.meta.schema)),
example.flatMap(exampleValue(o, _)),
Map.empty,
Map.empty))
}

private def codecToMediaType[M <: SMediaType](schema: SSchema, mediaType: M, example: Option[String]): Map[String, OMediaType] = {
Map(mediaType.mediaTypeNoParams -> OMediaType(Some(objectSchemas(schema)), example.map(ExampleValue), Map.empty, Map.empty))
}

private def exampleValue[T](v: Any): ExampleValue = ExampleValue(v.toString)
private def exampleValue[T](codec: Codec[T, _, _], e: T): Option[ExampleValue] = Some(exampleValue(codec.encode(e)))
private def exampleValue[T](codec: CodecForOptional[T, _, _], e: T): Option[ExampleValue] = codec.encode(e).map(exampleValue)
private def exampleValue[T](codec: CodecForMany[T, _, _], e: T): Option[ExampleValue] = codec.encode(e).headOption.map(exampleValue)

/**
* @return `Left` if the component is a capture, `Right` if it is a segment
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package tapir.docs.openapi

import tapir.docs.openapi.schema.ObjectSchemas
import tapir.openapi.OpenAPI.ReferenceOr
import tapir.openapi._
import tapir.{Schema => SSchema, _}

private[openapi] class EndpointToOperationResponse(objectSchemas: ObjectSchemas, codecToMediaType: CodecToMediaType) {
def apply(e: Endpoint[_, _, _, _]): Map[ResponsesKey, ReferenceOr[Response]] = {
// There always needs to be at least a 200 empty response
outputToResponses(e.output, ResponsesCodeKey(200), Some(Response("", Map.empty, Map.empty))) ++
outputToResponses(e.errorOutput, ResponsesDefaultKey, None)
}

private def statusCodesToBodySchemas(io: EndpointIO[_]): Map[StatusCode, Option[SSchema]] = {
io.traverse {
case EndpointIO.StatusFrom(_: EndpointIO.Body[_, _, _], default, defaultSchema, whens) =>
val fromWhens = whens.map {
case (WhenClass(_, schema), statusCode) => statusCode -> Some(schema)
case (_, statusCode) => statusCode -> None
}
(default -> defaultSchema) +: fromWhens
case EndpointIO.StatusFrom(_, default, _, whens) =>
val statusCodes = default +: whens.map(_._2)
statusCodes.map(_ -> None)
}.toMap
}

private def outputToResponses(io: EndpointIO[_],
defaultResponseKey: ResponsesKey,
defaultResponse: Option[Response]): Map[ResponsesKey, ReferenceOr[Response]] = {
val statusCodes = statusCodesToBodySchemas(io)
val responses = if (statusCodes.isEmpty) {
// no status code mapping defined in the output - using the default response key, if there's any response defined at all
outputToResponse(io, None).map(defaultResponseKey -> Right(_)).toMap
} else {
statusCodes.flatMap {
case (statusCode, bodySchema) =>
outputToResponse(io, bodySchema).map((ResponsesCodeKey(statusCode): ResponsesKey) -> Right(_))
}
}

if (responses.isEmpty) {
// no output at all - using default if defined
defaultResponse.map(defaultResponseKey -> Right(_)).toMap
} else responses
}

private def outputToResponse(io: EndpointIO[_], overrideBodySchema: Option[SSchema]): Option[Response] = {
val ios = io.asVectorOfBasic()

val headers = ios.collect {
case EndpointIO.Header(name, codec, info) =>
name -> Right(
Header(
info.description,
Some(!codec.meta.isOptional),
None,
None,
None,
None,
None,
Some(objectSchemas(codec.meta.schema)),
info.example.flatMap(exampleValue(codec, _)),
Map.empty,
Map.empty
))
}

val bodies = ios.collect {
case EndpointIO.Body(m, i) => (i.description, codecToMediaType(m, i.example, overrideBodySchema))
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(s, mt, i)) =>
(i.description, codecToMediaType(overrideBodySchema.getOrElse(s), mt, i.example))
}
val body = bodies.headOption

val description = body.flatMap(_._1).getOrElse("")
val content = body.map(_._2).getOrElse(Map.empty)

if (body.isDefined || headers.nonEmpty) {
Some(Response(description, headers.toMap, content))
} else {
None
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package tapir.docs

import tapir.EndpointInput
import tapir.openapi.SecurityScheme
import tapir.{Codec, CodecForMany, CodecForOptional, EndpointInput}
import tapir.openapi.{ExampleValue, SecurityScheme}

package object openapi extends OpenAPIDocs {
private[openapi] type SchemeName = String
Expand All @@ -16,4 +16,10 @@ package object openapi extends OpenAPIDocs {
}
result
}

private[openapi] def exampleValue[T](v: Any): ExampleValue = ExampleValue(v.toString)
private[openapi] def exampleValue[T](codec: Codec[T, _, _], e: T): Option[ExampleValue] = Some(exampleValue(codec.encode(e)))
private[openapi] def exampleValue[T](codec: CodecForOptional[T, _, _], e: T): Option[ExampleValue] = codec.encode(e).map(exampleValue)
private[openapi] def exampleValue[T](codec: CodecForMany[T, _, _], e: T): Option[ExampleValue] =
codec.encode(e).headOption.map(exampleValue)
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ object ObjectSchemasForEndpoints {
filterIsObjectSchema(schema)
case EndpointIO.Mapped(wrapped, _, _, _) =>
forInput(wrapped)
case EndpointIO.StatusFrom(wrapped, _, defaultSchema, whens) =>
val fromDefaultSchema = defaultSchema.toList.flatMap(filterIsObjectSchema)
val fromWhens = whens.collect {
case (WhenClass(_, s), _) => filterIsObjectSchema(s)
}.flatten
val fromInput = forInput(wrapped)

// if there's a default schema, we exclude the one from the input
val fromInputOrDefault = if (fromDefaultSchema.nonEmpty) fromDefaultSchema else fromInput

fromInputOrDefault ++ fromWhens
}
}
}

0 comments on commit 96e630f

Please sign in to comment.