Skip to content

Commit

Permalink
#23: require exact path matches
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Feb 27, 2019
1 parent a5f98e9 commit 936b53b
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 11 deletions.
47 changes: 37 additions & 10 deletions core/src/main/scala/tapir/internal/server/DecodeInputs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,36 @@ object DecodeInputs {
case _: EndpointIO.StreamBodyWrapper[_, _] => 3
}

apply(inputs, DecodeInputsResult.Values(Map(), None), ctx)
val (result, consumedCtx) = apply(inputs, DecodeInputsResult.Values(Map(), None), ctx)

result match {
case v: DecodeInputsResult.Values => verifyPathExactMatch(inputs, consumedCtx).getOrElse(v)
case r => r
}
}

private def apply(inputs: Vector[EndpointInput.Basic[_]],
values: DecodeInputsResult.Values,
ctx: DecodeInputsContext): DecodeInputsResult = {
ctx: DecodeInputsContext): (DecodeInputsResult, DecodeInputsContext) = {
inputs match {
case Vector() => values
case Vector() => (values, ctx)

case (input @ EndpointInput.PathSegment(ss)) +: inputsTail =>
ctx.nextPathSegment match {
case (Some(`ss`), ctx2) => apply(inputsTail, values, ctx2)
case (Some(s), _) => DecodeInputsResult.Failure(input, DecodeResult.Mismatch(ss, s))
case (None, _) => DecodeInputsResult.Failure(input, DecodeResult.Missing)
case (Some(`ss`), ctx2) => apply(inputsTail, values, ctx2)
case (None, ctx2) if ss == "" => apply(inputsTail, values, ctx2) // root path
case (Some(s), _) => (DecodeInputsResult.Failure(input, DecodeResult.Mismatch(ss, s)), ctx)
case (None, _) => (DecodeInputsResult.Failure(input, DecodeResult.Missing), ctx)
}

case (input @ EndpointInput.PathCapture(codec, _, _)) +: inputsTail =>
ctx.nextPathSegment match {
case (Some(s), ctx2) =>
codec.decode(s) match {
case DecodeResult.Value(v) => apply(inputsTail, values.value(input, v), ctx2)
case failure: DecodeFailure => DecodeInputsResult.Failure(input, failure)
case failure: DecodeFailure => (DecodeInputsResult.Failure(input, failure), ctx)
}
case (None, _) => DecodeInputsResult.Failure(input, DecodeResult.Missing)
case (None, _) => (DecodeInputsResult.Failure(input, DecodeResult.Missing), ctx)
}

case (input @ EndpointInput.PathsCapture(_)) +: inputsTail =>
Expand All @@ -90,7 +96,7 @@ object DecodeInputs {
case (input @ EndpointInput.Query(name, codec, _)) +: inputsTail =>
codec.decode(ctx.queryParameter(name).toList) match {
case DecodeResult.Value(v) => apply(inputsTail, values.value(input, v), ctx)
case failure: DecodeFailure => DecodeInputsResult.Failure(input, failure)
case failure: DecodeFailure => (DecodeInputsResult.Failure(input, failure), ctx)
}

case (input @ EndpointInput.QueryParams(_)) +: inputsTail =>
Expand All @@ -99,7 +105,7 @@ object DecodeInputs {
case (input @ EndpointIO.Header(name, codec, _)) +: inputsTail =>
codec.decode(ctx.header(name)) match {
case DecodeResult.Value(v) => apply(inputsTail, values.value(input, v), ctx)
case failure: DecodeFailure => DecodeInputsResult.Failure(input, failure)
case failure: DecodeFailure => (DecodeInputsResult.Failure(input, failure), ctx)
}

case (input @ EndpointIO.Headers(_)) +: inputsTail =>
Expand All @@ -112,4 +118,25 @@ object DecodeInputs {
apply(inputsTail, values.value(input, ctx.bodyStream), ctx)
}
}

/**
* If there's any path input, the path must match exactly.
*/
private def verifyPathExactMatch(inputs: Vector[EndpointInput.Basic[_]], ctx: DecodeInputsContext): Option[DecodeInputsResult.Failure] = {
inputs.filter {
case _: EndpointInput.PathSegment => true
case _: EndpointInput.PathCapture[_] => true
case _: EndpointInput.PathsCapture => true
case _ => false
}.lastOption match {
case Some(lastPathInput) =>
ctx.nextPathSegment._1 match {
case Some(nextPathSegment) =>
Some(DecodeInputsResult.Failure(lastPathInput, DecodeResult.Mismatch("", nextPathSegment)))
case None => None
}

case None => None
}
}
}
11 changes: 11 additions & 0 deletions doc/endpoint/ios.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ val paging: EndpointInput[Paging] =
Mapping methods can also be called on an endpoint (which is useful if inputs/outputs are accumulated, for example).
The `Endpoint.mapIn`, `Endpoint.mapInTo` etc. have the same signatures are the ones above.

## Path matching

By default (as with all other types of inputs), if no path input/path segments are defined, any path will match.

If any path input/path segment is defined, the path must match *exactly* - any remaining path segments will cause the
endpoint not to match the request. For example, `endpoint.path("api")` will match `/api`, `/api/`, but won't match
`/`, `/api/users`.

To match only the root path, use an empty string: `endpoint.path("")` will match `http://server.com/` and
`http://server.com`.

## Next

Read on about [codecs](codecs.html).
46 changes: 45 additions & 1 deletion server/tests/src/main/scala/tapir/server/tests/ServerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ trait ServerTests[R[_], S, ROUTE] extends FunSuite with Matchers with BeforeAndA
sttp.get(uri"$baseUri/hello/it/is/me/hal").send().map(_.body shouldBe Right("hello it is me hal"))
}

testServer(in_paths_out_string, (ps: Seq[String]) => pureResult(ps.mkString(" ").asRight[Unit]), "paths should match empty path") {
baseUri =>
sttp.get(uri"$baseUri").send().map(_.body shouldBe Right(""))
}

testServer(in_stream_out_stream[S], (s: S) => pureResult(s.asRight[Unit])) { baseUri =>
sttp.post(uri"$baseUri/api/echo").body("pen pineapple apple pen").send().map(_.body shouldBe Right("pen pineapple apple pen"))
}
Expand Down Expand Up @@ -213,7 +218,7 @@ trait ServerTests[R[_], S, ROUTE] extends FunSuite with Matchers with BeforeAndA
}

testServer(in_query_out_string, (fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit]), "invalid query parameter") { baseUri =>
sttp.get(uri"$baseUri?fruit2=orange").send().map(_.code shouldBe 400)
sttp.get(uri"$baseUri?fruit2=orange").send().map(_.code shouldBe StatusCodes.BadRequest)
}

testServer(
Expand All @@ -225,6 +230,45 @@ trait ServerTests[R[_], S, ROUTE] extends FunSuite with Matchers with BeforeAndA
}
}

// path matching

testServer(endpoint, () => pureResult(Either.right[Unit, Unit](())), "no path should match anything") { baseUri =>
sttp.get(uri"$baseUri").send().map(_.code shouldBe StatusCodes.Ok) >>
sttp.get(uri"$baseUri/").send().map(_.code shouldBe StatusCodes.Ok) >>
sttp.get(uri"$baseUri/nonemptypath").send().map(_.code shouldBe StatusCodes.Ok) >>
sttp.get(uri"$baseUri/nonemptypath/nonemptypath2").send().map(_.code shouldBe StatusCodes.Ok)
}

testServer(in_root_path, () => pureResult(Either.right[Unit, Unit](())), "root path should not match non-root path") { baseUri =>
sttp.get(uri"$baseUri/nonemptypath").send().map(_.code shouldBe StatusCodes.NotFound)
}

testServer(in_root_path, () => pureResult(Either.right[Unit, Unit](())), "root path should match empty path") { baseUri =>
sttp.get(uri"$baseUri").send().map(_.code shouldBe StatusCodes.Ok)
}

testServer(in_root_path, () => pureResult(Either.right[Unit, Unit](())), "root path should match root path") { baseUri =>
sttp.get(uri"$baseUri/").send().map(_.code shouldBe StatusCodes.Ok)
}

testServer(in_single_path, () => pureResult(Either.right[Unit, Unit](())), "single path should match single path") { baseUri =>
sttp.get(uri"$baseUri/api").send().map(_.code shouldBe StatusCodes.Ok)
}

testServer(in_single_path, () => pureResult(Either.right[Unit, Unit](())), "single path should match single/ path") { baseUri =>
sttp.get(uri"$baseUri/api/").send().map(_.code shouldBe StatusCodes.Ok)
}

testServer(in_single_path, () => pureResult(Either.right[Unit, Unit](())), "single path should not match root path") { baseUri =>
sttp.get(uri"$baseUri").send().map(_.code shouldBe StatusCodes.NotFound) >>
sttp.get(uri"$baseUri/").send().map(_.code shouldBe StatusCodes.NotFound)
}

testServer(in_single_path, () => pureResult(Either.right[Unit, Unit](())), "single path should not match larger path") { baseUri =>
sttp.get(uri"$baseUri/api/echo").send().map(_.code shouldBe StatusCodes.NotFound) >>
sttp.get(uri"$baseUri/api/echo/").send().map(_.code shouldBe StatusCodes.NotFound)
}

//

testServer(
Expand Down
4 changes: 4 additions & 0 deletions tests/src/main/scala/tapir/tests/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ package object tests {
val in_cookies_out_cookies: Endpoint[List[CookiePair], Unit, List[Cookie], Nothing] =
endpoint.get.in("api" / "echo" / "headers").in(cookies).out(setCookies)

val in_root_path: Endpoint[Unit, Unit, Unit, Nothing] = endpoint.get.in("")

val in_single_path: Endpoint[Unit, Unit, Unit, Nothing] = endpoint.get.in("api")

val allTestEndpoints: Set[Endpoint[_, _, _, _]] = wireSet[Endpoint[_, _, _, _]]

def writeToFile(s: String): File = {
Expand Down

0 comments on commit 936b53b

Please sign in to comment.