Skip to content

Commit

Permalink
Method is optional, if no method is specified any matches
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Feb 27, 2019
1 parent 936b53b commit c1d2b11
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package tapir.client.sttp
import java.io.{BufferedOutputStream, ByteArrayInputStream, FileOutputStream}
import java.nio.ByteBuffer

import com.softwaremill.sttp._
import com.softwaremill.sttp.{Method => SttpMethod, _}
import tapir.Codec.PlainCodec
import tapir.internal.SeqToParams
import tapir.typelevel.ParamsAsArgs
import tapir._
import tapir.model.{MultiQueryParams, Part}
import tapir.model.{MultiQueryParams, Part, Method}

class EndpointToSttpClient(clientOptions: SttpClientOptions) {
// don't look. The code is really, really ugly.
Expand All @@ -22,7 +22,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {

val (uri, req) = setInputParams(e.input.asVectorOfSingle, params, paramsAsArgs, 0, baseUri, baseReq)

var req2 = req.copy[Id, Either[Any, Any], Any](method = com.softwaremill.sttp.Method(e.method.m), uri = uri)
var req2 = req.copy[Id, Either[Any, Any], Any](method = SttpMethod(e.method.getOrElse(Method.GET).m), uri = uri)

if (e.output.asVectorOfSingle.nonEmpty || e.errorOutput.asVectorOfSingle.nonEmpty) {
// by default, reading the body as specified by the output, and optionally adjusting to the error output
Expand Down
28 changes: 15 additions & 13 deletions core/src/main/scala/tapir/Endpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ import tapir.typelevel.{FnComponents, ParamConcat, ParamsAsArgs}
* @tparam O Output parameter types.
* @tparam S The type of streams that are used by this endpoint's inputs/outputs. `Nothing`, if no streams are used.
*/
case class Endpoint[I, E, O, +S](method: Method,
case class Endpoint[I, E, O, +S](method: Option[Method],
input: EndpointInput[I],
errorOutput: EndpointIO[E],
output: EndpointIO[O],
info: EndpointInfo) {

def get: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.GET)
def head: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.HEAD)
def post: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.POST)
def put: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.PUT)
def delete: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.DELETE)
def options: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.OPTIONS)
def patch: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.PATCH)
def connect: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.CONNECT)
def trace: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method.TRACE)
def method(m: String): Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Method(m))
def get: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.GET))
def head: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.HEAD))
def post: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.POST))
def put: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.PUT))
def delete: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.DELETE))
def options: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.OPTIONS))
def patch: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.PATCH))
def connect: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.CONNECT))
def trace: Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method.TRACE))
def method(m: String): Endpoint[I, E, O, S] = this.copy[I, E, O, S](method = Some(Method(m)))

def in[J, IJ](i: EndpointInput[J])(implicit ts: ParamConcat.Aux[I, J, IJ]): Endpoint[IJ, E, O, S] =
this.copy[IJ, E, O, S](input = input.and(i))
Expand Down Expand Up @@ -70,8 +70,10 @@ case class Endpoint[I, E, O, +S](method: Method,

def info(i: EndpointInfo): Endpoint[I, E, O, S] = copy(info = i)

def show: String =
s"Endpoint${info.name.map("[" + _ + "]").getOrElse("")}(${method.m}, in: ${input.show}, errout: ${errorOutput.show}, out: ${output.show})"
def show: String = {
val m = method.fold("")(m => m.m + ", ")
s"Endpoint${info.name.map("[" + _ + "]").getOrElse("")}(${m}in: ${input.show}, errout: ${errorOutput.show}, out: ${output.show})"
}
}

case class EndpointInfo(name: Option[String], summary: Option[String], description: Option[String], tags: Vector[String]) {
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ trait Tapir {

val endpoint: Endpoint[Unit, Unit, Unit, Nothing] =
Endpoint[Unit, Unit, Unit, Nothing](
Method.GET,
None,
EndpointInput.Multiple(Vector.empty),
EndpointIO.Multiple(Vector.empty),
EndpointIO.Multiple(Vector.empty),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tapir.docs.openapi

import tapir.docs.openapi.schema.ObjectSchemas
import tapir.model.Method
import tapir.openapi.OpenAPI.ReferenceOr
import tapir.openapi.{MediaType => OMediaType, _}
import tapir.{EndpointInput, MediaType => SMediaType, Schema => SSchema, _}
Expand All @@ -12,9 +13,10 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti

val inputs = e.input.asVectorOfBasic
val pathComponents = namedPathComponents(inputs)
val method = e.method.getOrElse(Method.GET)

val pathComponentsForId = pathComponents.map(_.fold(identity, identity))
val defaultId = options.operationIdGenerator(pathComponentsForId, e.method)
val defaultId = options.operationIdGenerator(pathComponentsForId, method)

val pathComponentForPath = pathComponents.map {
case Left(p) => s"{$p}"
Expand All @@ -25,14 +27,14 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti
val pathItem = PathItem(
None,
None,
get = if (e.method == GET) operation else None,
put = if (e.method == PUT) operation else None,
post = if (e.method == POST) operation else None,
delete = if (e.method == DELETE) operation else None,
options = if (e.method == OPTIONS) operation else None,
head = if (e.method == HEAD) operation else None,
patch = if (e.method == PATCH) operation else None,
trace = if (e.method == TRACE) operation else None,
get = if (method == GET) operation else None,
put = if (method == PUT) operation else None,
post = if (method == POST) operation else None,
delete = if (method == DELETE) operation else None,
options = if (method == OPTIONS) operation else None,
head = if (method == HEAD) operation else None,
patch = if (method == PATCH) operation else None,
trace = if (method == TRACE) operation else None,
servers = List.empty,
parameters = List.empty
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import akka.http.scaladsl.server.Directives.{
patch,
post,
put,
reject
reject,
pass
}
import akka.http.scaladsl.server.{Directive0, Directive1, RequestContext}
import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller
Expand Down Expand Up @@ -89,14 +90,19 @@ private[akkahttp] class EndpointToAkkaDirective(serverOptions: AkkaHttpServerOpt

private def methodToAkkaDirective[O, E, I](e: Endpoint[I, E, O, AkkaStream]): Directive0 = {
e.method match {
case Method.GET => get
case Method.HEAD => head
case Method.POST => post
case Method.PUT => put
case Method.DELETE => delete
case Method.OPTIONS => options
case Method.PATCH => patch
case m => method(HttpMethod.custom(m.m))
case Some(m) =>
m match {
case Method.GET => get
case Method.HEAD => head
case Method.POST => post
case Method.PUT => put
case Method.DELETE => delete
case Method.OPTIONS => options
case Method.PATCH => patch
case _ => method(HttpMethod.custom(m.m))
}

case None => pass
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ class EndpointToHttp4sServer[F[_]: Sync: ContextShift](serverOptions: Http4sServ
}
}

val methodMatches = http4sMethodToTapirMethodMap.get(req.method).contains(e.method)
val methodMatches = e.method match {
case Some(m) => http4sMethodToTapirMethodMap.get(req.method).contains(m)
case None => true
}

if (methodMatches) {
OptionT(decodeBody(DecodeInputs(e.input, new Http4sDecodeInputsContext[F](req))).flatMap {
Expand Down
16 changes: 14 additions & 2 deletions server/tests/src/main/scala/tapir/server/tests/ServerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,26 @@ import scala.util.Random

trait ServerTests[R[_], S, ROUTE] extends FunSuite with Matchers with BeforeAndAfterAll {

testServer(endpoint, () => pureResult(().asRight[Unit])) { baseUri =>
// method matching

testServer(endpoint, () => pureResult(().asRight[Unit]), "GET empty endpoint") { baseUri =>
sttp.get(baseUri).send().map(_.body shouldBe Right(""))
}

testServer(endpoint, () => pureResult(().asRight[Unit]), "POST empty endpoint") { baseUri =>
sttp.post(baseUri).send().map(_.body shouldBe Right(""))
}

testServer(endpoint.get, () => pureResult(().asRight[Unit]), "GET a GET endpoint") { baseUri =>
sttp.get(baseUri).send().map(_.body shouldBe Right(""))
}

testServer(endpoint, () => pureResult(().asRight[Unit]), "with post method") { baseUri =>
testServer(endpoint.get, () => pureResult(().asRight[Unit]), "POST a GET endpoint") { baseUri =>
sttp.post(baseUri).send().map(_.body shouldBe 'left)
}

//

testServer(in_query_out_string, (fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { baseUri =>
sttp.get(uri"$baseUri?fruit=orange").send().map(_.body shouldBe Right("fruit: orange"))
}
Expand Down

0 comments on commit c1d2b11

Please sign in to comment.