Skip to content

Commit

Permalink
Method is yet another type of input
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Feb 27, 2019
1 parent c1d2b11 commit bb945db
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {

val (uri, req) = setInputParams(e.input.asVectorOfSingle, params, paramsAsArgs, 0, baseUri, baseReq)

var req2 = req.copy[Id, Either[Any, Any], Any](method = SttpMethod(e.method.getOrElse(Method.GET).m), uri = uri)
var req2 = req.copy[Id, Either[Any, Any], Any](method = SttpMethod(e.input.method.getOrElse(Method.GET).m), uri = uri)

if (e.output.asVectorOfSingle.nonEmpty || e.errorOutput.asVectorOfSingle.nonEmpty) {
// by default, reading the body as specified by the output, and optionally adjusting to the error output
Expand Down Expand Up @@ -115,6 +115,8 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {

inputs match {
case Vector() => (uri, req)
case EndpointInput.RequestMethod(_) +: tail =>
setInputParams(tail, params, paramsAsArgs, paramIndex, uri, req)
case EndpointInput.PathSegment(p) +: tail =>
setInputParams(tail, params, paramsAsArgs, paramIndex, uri.copy(path = uri.path :+ p), req)
case EndpointInput.PathCapture(codec, _, _) +: tail =>
Expand Down
32 changes: 14 additions & 18 deletions core/src/main/scala/tapir/Endpoint.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package tapir

import tapir.EndpointInput.RequestMethod
import tapir.model.Method
import tapir.typelevel.{FnComponents, ParamConcat, ParamsAsArgs}

Expand All @@ -9,22 +10,18 @@ import tapir.typelevel.{FnComponents, ParamConcat, ParamsAsArgs}
* @tparam O Output parameter types.
* @tparam S The type of streams that are used by this endpoint's inputs/outputs. `Nothing`, if no streams are used.
*/
case class Endpoint[I, E, O, +S](method: Option[Method],
input: EndpointInput[I],
errorOutput: EndpointIO[E],
output: EndpointIO[O],
info: EndpointInfo) {

def get: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.GET))
def head: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.HEAD))
def post: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.POST))
def put: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.PUT))
def delete: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.DELETE))
def options: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.OPTIONS))
def patch: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.PATCH))
def connect: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.CONNECT))
def trace: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.TRACE))
def method(m: String): Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method(m)))
case class Endpoint[I, E, O, +S](input: EndpointInput[I], errorOutput: EndpointIO[E], output: EndpointIO[O], info: EndpointInfo) {

def get: Endpoint[I, E, O, S] = in(RequestMethod(Method.GET))
def post: Endpoint[I, E, O, S] = in(RequestMethod(Method.POST))
def head: Endpoint[I, E, O, S] = in(RequestMethod(Method.HEAD))
def put: Endpoint[I, E, O, S] = in(RequestMethod(Method.PUT))
def delete: Endpoint[I, E, O, S] = in(RequestMethod(Method.DELETE))
def options: Endpoint[I, E, O, S] = in(RequestMethod(Method.OPTIONS))
def patch: Endpoint[I, E, O, S] = in(RequestMethod(Method.PATCH))
def connect: Endpoint[I, E, O, S] = in(RequestMethod(Method.CONNECT))
def trace: Endpoint[I, E, O, S] = in(RequestMethod(Method.TRACE))
def method(m: String): Endpoint[I, E, O, S] = in(RequestMethod(Method(m)))

def in[J, IJ](i: EndpointInput[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): Endpoint[IJ, E, O, S] =
this.copy[IJ, E, O, S](input = input.and(i))
Expand Down Expand Up @@ -71,8 +68,7 @@ case class Endpoint[I, E, O, +S](method: Option[Method],
def info(i: EndpointInfo): Endpoint[I, E, O, S] = copy(info = i)

def show: String = {
val m = method.fold("")(m => m.m + ", ")
s"Endpoint${info.name.map("[" + _ + "]").getOrElse("")}(${m}in: ${input.show}, errout: ${errorOutput.show}, out: ${output.show})"
s"Endpoint${info.name.map("[" + _ + "]").getOrElse("")}(in: ${input.show}, errout: ${errorOutput.show}, out: ${output.show})"
}
}

Expand Down
15 changes: 14 additions & 1 deletion core/src/main/scala/tapir/EndpointIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package tapir
import tapir.Codec.PlainCodec
import tapir.CodecForMany.PlainCodecForMany
import tapir.internal.ProductToParams
import tapir.model.MultiQueryParams
import tapir.model.{Method, MultiQueryParams}
import tapir.typelevel.{FnComponents, ParamConcat, ParamsAsArgs}

sealed trait EndpointInput[I] {
Expand Down Expand Up @@ -42,6 +42,15 @@ sealed trait EndpointInput[I] {
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.bodyType
case _ => None
}

private[tapir] def method: Option[Method] = this match {
case i: EndpointInput.RequestMethod => Some(i.m)
case EndpointInput.Multiple(inputs) => inputs.flatMap(_.method).headOption
case EndpointIO.Multiple(inputs) => inputs.flatMap(_.method).headOption
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.method
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.method
case _ => None
}
}

object EndpointInput {
Expand All @@ -56,6 +65,10 @@ object EndpointInput {

sealed trait Basic[I] extends Single[I]

case class RequestMethod(m: Method) extends Basic[Unit] {
def show: String = m.m
}

case class PathSegment(s: String) extends Basic[Unit] {
def show = s"/$s"
}
Expand Down
3 changes: 1 addition & 2 deletions core/src/main/scala/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import java.nio.charset.{Charset, StandardCharsets}

import tapir.Codec.PlainCodec
import tapir.CodecForMany.PlainCodecForMany
import tapir.model.{Cookie, CookiePair, Method}
import tapir.model.{Cookie, CookiePair}

trait Tapir {
implicit def stringToPath(s: String): EndpointInput[Unit] = EndpointInput.PathSegment(s)
Expand Down Expand Up @@ -54,7 +54,6 @@ trait Tapir {

val endpoint: Endpoint[Unit, Unit, Unit, Nothing] =
Endpoint[Unit, Unit, Unit, Nothing](
None,
EndpointInput.Multiple(Vector.empty),
EndpointIO.Multiple(Vector.empty),
EndpointIO.Multiple(Vector.empty),
Expand Down
29 changes: 18 additions & 11 deletions core/src/main/scala/tapir/internal/server/DecodeInputs.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package tapir.internal.server

import tapir.model.MultiQueryParams
import tapir.model.{Method, MultiQueryParams}
import tapir.{DecodeFailure, DecodeResult, EndpointIO, EndpointInput}

import scala.annotation.tailrec
Expand All @@ -14,6 +14,8 @@ object DecodeInputsResult {
}

trait DecodeInputsContext {
def method: Method

def nextPathSegment: (Option[String], DecodeInputsContext)

def header(name: String): List[String]
Expand All @@ -37,17 +39,18 @@ object DecodeInputs {
* In case any of the decoding fails, the failure is returned together with the failing input.
*/
def apply(input: EndpointInput[_], ctx: DecodeInputsContext): DecodeInputsResult = {
// the first decoding failure is returned. We decode in the following order: path, query, headers, body
// the first decoding failure is returned. We decode in the following order: method, path, query, headers, body
val inputs = input.asVectorOfBasic.sortBy {
case _: EndpointInput.PathSegment => 0
case _: EndpointInput.PathCapture[_] => 0
case _: EndpointInput.PathsCapture => 0
case _: EndpointInput.Query[_] => 1
case _: EndpointInput.QueryParams => 1
case _: EndpointIO.Header[_] => 2
case _: EndpointIO.Headers => 2
case _: EndpointIO.Body[_, _, _] => 3
case _: EndpointIO.StreamBodyWrapper[_, _] => 3
case _: EndpointInput.RequestMethod => 0
case _: EndpointInput.PathSegment => 1
case _: EndpointInput.PathCapture[_] => 1
case _: EndpointInput.PathsCapture => 1
case _: EndpointInput.Query[_] => 2
case _: EndpointInput.QueryParams => 2
case _: EndpointIO.Header[_] => 3
case _: EndpointIO.Headers => 3
case _: EndpointIO.Body[_, _, _] => 4
case _: EndpointIO.StreamBodyWrapper[_, _] => 4
}

val (result, consumedCtx) = apply(inputs, DecodeInputsResult.Values(Map(), None), ctx)
Expand All @@ -64,6 +67,10 @@ object DecodeInputs {
inputs match {
case Vector() => (values, ctx)

case (input @ EndpointInput.RequestMethod(m)) +: inputsTail =>
if (m == ctx.method) apply(inputsTail, values, ctx)
else (DecodeInputsResult.Failure(input, DecodeResult.Mismatch(m.m, ctx.method.m)), ctx)

case (input @ EndpointInput.PathSegment(ss)) +: inputsTail =>
ctx.nextPathSegment match {
case (Some(`ss`), ctx2) => apply(inputsTail, values, ctx2)
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/tapir/internal/server/InputValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ object InputValues {
private def apply(inputs: Vector[EndpointInput.Single[_]], values: Map[EndpointInput.Single[_], Any]): List[Any] = {
inputs match {
case Vector() => Nil
case (_: EndpointInput.RequestMethod) +: inputsTail =>
apply(inputsTail, values)
case (_: EndpointInput.PathSegment) +: inputsTail =>
apply(inputsTail, values)
case EndpointInput.Mapped(wrapped, f, _, _) +: inputsTail =>
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/tapir/server/ServerDefaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ object ServerDefaults {
* By default, a 400 (bad request) is returned if a query, header or body input can't be decoded (for any reason),
* or if decoding a path capture ends with an error.
*
* Otherwise (e.g. if a path segment, or path capture is missing or there's a mismatch), a "no match" is returned,
* which is a signal to try the next endpoint.
* Otherwise (e.g. if the method, a path segment, or path capture is missing or there's a mismatch), a "no match" is
* returned, which is a signal to try the next endpoint.
*/
def decodeFailureHandler[R]: DecodeFailureHandler[R] = (_, input, failure) => {
input match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti

val inputs = e.input.asVectorOfBasic
val pathComponents = namedPathComponents(inputs)
val method = e.method.getOrElse(Method.GET)
val method = e.input.method.getOrElse(Method.GET)

val pathComponentsForId = pathComponents.map(_.fold(identity, identity))
val defaultId = options.operationIdGenerator(pathComponentsForId, method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ object ObjectSchemasForEndpoints {

private def forInput(input: EndpointInput[_]): List[TSchema.SObject] = {
input match {
case EndpointInput.RequestMethod(_) =>
List.empty
case EndpointInput.PathSegment(_) =>
List.empty
case EndpointInput.PathCapture(tm, _, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package tapir.server.akkahttp
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.server.RequestContext
import tapir.internal.server.DecodeInputsContext
import tapir.model.MultiQueryParams
import tapir.model.{Method, MultiQueryParams}

private[akkahttp] class AkkaDecodeInputsContext(req: RequestContext) extends DecodeInputsContext {
override def method: Method = Method(req.request.method.value.toUpperCase)
override def nextPathSegment: (Option[String], DecodeInputsContext) = {
req.unmatchedPath match {
case Uri.Path.Slash(pathTail) => new AkkaDecodeInputsContext(req.withUnmatchedPath(pathTail)).nextPathSegment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,20 @@ import java.io.ByteArrayInputStream
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives.{
complete,
delete,
extractExecutionContext,
extractMaterializer,
extractRequestContext,
get,
head,
method,
onSuccess,
options,
patch,
post,
put,
reject,
pass
}
import akka.http.scaladsl.server.{Directive0, Directive1, RequestContext}
import akka.http.scaladsl.server.{Directive1, RequestContext}
import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller
import akka.stream.Materializer
import akka.stream.scaladsl.{FileIO, Sink}
import akka.util.ByteString
import tapir.internal.SeqToParams
import tapir.internal.server.{DecodeInputs, DecodeInputsResult, InputValues}
import tapir.model.{Method, Part}
import tapir.model.Part
import tapir.server.DecodeFailureHandling
import tapir.{
ByteArrayValueType,
Expand Down Expand Up @@ -54,8 +45,6 @@ private[akkahttp] class EndpointToAkkaDirective(serverOptions: AkkaHttpServerOpt
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._

val methodDirective = methodToAkkaDirective(e)

val inputDirectives: Directive1[I] = {

def decodeBody(result: DecodeInputsResult): Directive1[DecodeInputsResult] = {
Expand Down Expand Up @@ -85,25 +74,7 @@ private[akkahttp] class EndpointToAkkaDirective(serverOptions: AkkaHttpServerOpt
}
}

methodDirective & inputDirectives
}

private def methodToAkkaDirective[O, E, I](e: Endpoint[I, E, O, AkkaStream]): Directive0 = {
e.method match {
case Some(m) =>
m match {
case Method.GET => get
case Method.HEAD => head
case Method.POST => post
case Method.PUT => put
case Method.DELETE => delete
case Method.OPTIONS => options
case Method.PATCH => patch
case _ => method(HttpMethod.custom(m.m))
}

case None => pass
}
inputDirectives
}

private def rawBodyDirective(bodyType: RawValueType[_]): Directive1[Any] = extractRequestContext.flatMap { ctx =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import cats.implicits._
import org.http4s.{EntityBody, Headers, HttpRoutes, Request, Response, Status}
import tapir.internal.SeqToParams
import tapir.internal.server.{DecodeInputs, DecodeInputsResult, InputValues}
import tapir.model.Method
import tapir.server.{DecodeFailureHandling, StatusMapper}
import tapir.typelevel.ParamsAsArgs
import tapir.{DecodeFailure, DecodeResult, Endpoint, EndpointIO, EndpointInput}
Expand Down Expand Up @@ -49,38 +48,15 @@ class EndpointToHttp4sServer[F[_]: Sync: ContextShift](serverOptions: Http4sServ
}
}

val methodMatches = e.method match {
case Some(m) => http4sMethodToTapirMethodMap.get(req.method).contains(m)
case None => true
}

if (methodMatches) {
OptionT(decodeBody(DecodeInputs(e.input, new Http4sDecodeInputsContext[F](req))).flatMap {
case values: DecodeInputsResult.Values => valuesToResponse(values).map(_.some)
case DecodeInputsResult.Failure(input, failure) => handleDecodeFailure(req, input, failure).pure[F]
})
} else {
OptionT.none
}
OptionT(decodeBody(DecodeInputs(e.input, new Http4sDecodeInputsContext[F](req))).flatMap {
case values: DecodeInputsResult.Values => valuesToResponse(values).map(_.some)
case DecodeInputsResult.Failure(input, failure) => handleDecodeFailure(req, input, failure).pure[F]
})
}

service
}

private val http4sMethodToTapirMethodMap: Map[org.http4s.Method, Method] = {
import org.http4s.Method._
import tapir.model.Method
Map(
GET -> Method.GET,
POST -> Method.POST,
DELETE -> Method.DELETE,
PUT -> Method.PUT,
OPTIONS -> Method.OPTIONS,
PATCH -> Method.PATCH,
CONNECT -> Method.CONNECT
)
}

private def statusCodeToHttp4sStatus(code: tapir.StatusCode): Status =
Status.fromInt(code).right.getOrElse(throw new IllegalArgumentException(s"Invalid status code: $code"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package tapir.server.http4s
import org.http4s.Request
import org.http4s.util.CaseInsensitiveString
import tapir.internal.server.DecodeInputsContext
import tapir.model.Method

class Http4sDecodeInputsContext[F[_]](req: Request[F]) extends DecodeInputsContext {
override def method: Method = Method(req.method.name.toUpperCase)
override def nextPathSegment: (Option[String], DecodeInputsContext) = {

val nextStart = req.uri.path.dropWhile(_ == '/')
Expand Down

0 comments on commit bb945db

Please sign in to comment.