Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative literal path segments for route definitions (#2815) #2920

Merged
merged 1 commit into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions zio-http-testkit/src/test/scala/zio/http/TestServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ object TestServerSpec extends ZIOHttpSpec {
client <- ZIO.service[Client]
testRequest <- requestToCorrectPort
_ <- TestServer.addRequestResponse(testRequest, Response(Status.Ok))
finalResponse <-
client(
testRequest,
)
finalResponse <- client(testRequest)

} yield assertTrue(status(finalResponse) == Status.Ok)
},
Expand All @@ -59,10 +56,7 @@ object TestServerSpec extends ZIOHttpSpec {
client <- ZIO.service[Client]
testRequest <- requestToCorrectPort
_ <- TestServer.addRequestResponse(testRequest, Response(Status.Ok))
finalResponse <-
client(
testRequest.addHeaders(Headers(Header.ContentLanguage.French)),
)
finalResponse <- client(testRequest.addHeaders(Headers(Header.ContentLanguage.French)))

} yield assertTrue(status(finalResponse) == Status.Ok)
},
Expand Down
19 changes: 19 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,24 @@ object RoutesSpec extends ZIOHttpSpec {
)
.map(response => assertTrue(response.status == Status.Ok))
},
test("alternative path segments") {
val app = Routes(
Method.GET / anyOf("foo", "bar", "baz") -> Handler.ok,
)

for {
foo <- app.runZIO(Request.get("/foo"))
bar <- app.runZIO(Request.get("/bar"))
baz <- app.runZIO(Request.get("/baz"))
box <- app.runZIO(Request.get("/box"))
} yield {
assertTrue(
extractStatus(foo) == Status.Ok,
extractStatus(bar) == Status.Ok,
extractStatus(baz) == Status.Ok,
extractStatus(box) == Status.NotFound,
)
}
},
)
}
6 changes: 4 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/HttpApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package zio.http
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.Routes.Tree

/**
* An HTTP application is a collection of routes, all of whose errors have been
* handled through conversion into HTTP responses.
Expand Down Expand Up @@ -137,10 +139,10 @@ object HttpApp {
Tree(self.tree ++ that.tree)

final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.add(route.routePattern, route.toHandler))
Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler))))

final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler))))
Tree[Env1](self.tree.addAll(routes.map(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler))).flatten))

final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] =
tree.get(method, path)
Expand Down
4 changes: 2 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/Route.scala
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ object Route {
handler: Handler[Env1, Response, In, Response],
)(implicit zippable: Zippable.Out[Params, Request, In], trace: Trace): Route[Env1, Nothing] = {
val handler2: Handler[Any, Nothing, RoutePattern[_], Handler[Env1, Response, Request, Response]] = {
Handler.fromFunction[RoutePattern[_]] { _ =>
Handler.fromFunction[RoutePattern[_]] { pattern =>
val paramHandler =
Handler.fromFunctionZIO[(rpm.Context, Request)] { case (ctx, request) =>
rpm.routePattern.decode(request.method, request.path) match {
pattern.asInstanceOf[RoutePattern[rpm.PathInput]].decode(request.method, request.path) match {
case Left(error) => ZIO.dieMessage(error)
case Right(value) =>
val params = rpm.zippable.zip(value, ctx)
Expand Down
2 changes: 2 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/RoutePattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ final case class RoutePattern[A](method: Method, pathCodec: PathCodec[A]) { self
): Route.Builder[Env, zippable.Out] =
Route.Builder(self, middleware)(zippable)

def alternatives: List[RoutePattern[A]] = pathCodec.alternatives.map(RoutePattern(method, _))

/**
* Reinteprets the type parameter, given evidence it is equal to some other
* type.
Expand Down
6 changes: 4 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.io.File

import zio._

import zio.http.HttpApp.Tree
import zio.http.Routes.ApplyContextAspect
import zio.http.codec.PathCodec

Expand Down Expand Up @@ -331,10 +332,11 @@ object Routes extends RoutesCompanionVersionSpecific {
Tree(self.tree ++ that.tree)

final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.add(route.routePattern, route.toHandler))
Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler))))

final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler))))
// only change to flatMap when Scala 2.12 is dropped
Tree(self.tree.addAll(routes.map(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler))).flatten))

final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] =
tree.get(method, path)
Expand Down
100 changes: 99 additions & 1 deletion zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package zio.http.codec

import scala.annotation.tailrec
import scala.collection.immutable.ListMap
import scala.language.implicitConversions

Expand Down Expand Up @@ -61,6 +62,58 @@ sealed trait PathCodec[A] { self =>
}
}

private[http] def orElse(value: PathCodec[Unit])(implicit ev: A =:= Unit): PathCodec[Unit] =
Fallback(self.asInstanceOf[PathCodec[Unit]], value)

private def fallbackAlternatives(f: Fallback[_]): List[PathCodec[Any]] = {
@tailrec
def loop(codecs: List[PathCodec[_]], result: List[PathCodec[_]]): List[PathCodec[_]] =
if (codecs.isEmpty) result
else
codecs.head match {
case PathCodec.Annotated(codec, _) =>
loop(codec :: codecs.tail, result)
case PathCodec.Segment(SegmentCodec.Literal(_)) =>
loop(codecs.tail, result :+ codecs.head)
case PathCodec.Segment(SegmentCodec.Empty) =>
loop(codecs.tail, result)
case Fallback(left, right) =>
loop(left :: right :: codecs.tail, result)
case other =>
throw new IllegalStateException(s"Alternative path segments should only contain literals, found: $other")
}
loop(List(f.left, f.right), List.empty).asInstanceOf[List[PathCodec[Any]]]
}

final def alternatives: List[PathCodec[A]] = {
var alts = List.empty[PathCodec[Any]]
def loop(codec: PathCodec[_], combiner: Combiner[_, _]): Unit = codec match {
case Concat(left, right, combiner) =>
loop(left, combiner)
loop(right, combiner)
case f: Fallback[_] =>
if (alts.isEmpty) alts = fallbackAlternatives(f)
else
alts ++= alts.flatMap { alt =>
fallbackAlternatives(f).map(fa =>
Concat(alt, fa.asInstanceOf[PathCodec[Any]], combiner.asInstanceOf[Combiner.WithOut[Any, Any, Any]]),
)
}
case Segment(SegmentCodec.Empty) =>
alts :+= codec.asInstanceOf[PathCodec[Any]]
case pc =>
if (alts.isEmpty) alts :+= pc.asInstanceOf[PathCodec[Any]]
else
alts = alts
.map(l =>
Concat(l, pc.asInstanceOf[PathCodec[Any]], combiner.asInstanceOf[Combiner.WithOut[Any, Any, Any]])
.asInstanceOf[PathCodec[Any]],
)
}
loop(self, Combiner.leftUnit[Unit])
alts.asInstanceOf[List[PathCodec[A]]]
}

final def asType[B](implicit ev: A =:= B): PathCodec[B] = self.asInstanceOf[PathCodec[B]]

/**
Expand All @@ -84,14 +137,22 @@ sealed trait PathCodec[A] { self =>
val opt = instructions(i)

opt match {
case Match(value) =>
case Match(value) =>
if (j >= segments.length || segments(j) != value) {
fail = "Expected path segment \"" + value + "\" but found end of path"
i = instructions.length
} else {
stack.push(())
j = j + 1
}
case MatchAny(values) =>
if (j >= segments.length || !values.contains(segments(j))) {
fail = "Expected one of the following path segments: " + values.mkString(", ") + " but found end of path"
i = instructions.length
} else {
stack.push(())
j = j + 1
}

case Combine(combiner0) =>
val combiner = combiner0.asInstanceOf[Combiner[Any, Any]]
Expand Down Expand Up @@ -227,6 +288,7 @@ sealed trait PathCodec[A] { self =>
case Concat(left, right, _) => left.doc + right.doc
case Annotated(codec, annotations) =>
codec.doc + annotations.collectFirst { case MetaData.Documented(doc) => doc }.getOrElse(Doc.empty)
case Fallback(left, right) => left.doc + right.doc
}

/**
Expand Down Expand Up @@ -264,6 +326,8 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, g) =>
g.asInstanceOf[Any => Either[String, Any]](value).flatMap(loop(api, _))
case Fallback(left, _) =>
loop(left, value)
}

loop(self, value).map { path =>
Expand Down Expand Up @@ -298,6 +362,9 @@ sealed trait PathCodec[A] { self =>
case SegmentCodec.Trailing => Opt.TrailingOpt
})

case f: Fallback[_] =>
Chunk(Opt.MatchAny(fallbacks(f)))

case Concat(left, right, combiner) =>
loop(left) ++ loop(right) ++ Chunk(Opt.Combine(combiner))

Expand All @@ -310,6 +377,26 @@ sealed trait PathCodec[A] { self =>
_optimize
}

private def fallbacks(f: Fallback[_]): Set[String] = {
@tailrec
def loop(codecs: List[PathCodec[_]], result: Set[String]): Set[String] =
if (codecs.isEmpty) result
else
codecs.head match {
case PathCodec.Annotated(codec, _) =>
loop(codec :: codecs.tail, result)
case PathCodec.Segment(SegmentCodec.Literal(value)) =>
loop(codecs.tail, result + value)
case PathCodec.Segment(SegmentCodec.Empty) =>
loop(codecs.tail, result)
case Fallback(left, right) =>
loop(left :: right :: codecs.tail, result)
case other =>
throw new IllegalStateException(s"Alternative path segments should only contain literals, found: $other")
}
loop(List(f.left, f.right), Set.empty)
}

/**
* Renders the path codec as a string.
*/
Expand All @@ -324,6 +411,9 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, _) =>
loop(api)

case PathCodec.Fallback(left, _) =>
loop(left)
}

loop(self)
Expand All @@ -341,6 +431,8 @@ sealed trait PathCodec[A] { self =>
case PathCodec.Segment(segment) => segment.render

case PathCodec.TransformOrFail(api, _, _) => loop(api)

case PathCodec.Fallback(left, _) => loop(left)
}

loop(self)
Expand All @@ -360,6 +452,9 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, _) =>
loop(api)

case PathCodec.Fallback(left, _) =>
loop(left)
}

loop(self)
Expand Down Expand Up @@ -418,6 +513,8 @@ object PathCodec {

def uuid(name: String): PathCodec[java.util.UUID] = Segment(SegmentCodec.uuid(name))

private[http] final case class Fallback[A](left: PathCodec[Unit], right: PathCodec[Unit]) extends PathCodec[A]

private[http] final case class Segment[A](segment: SegmentCodec[A]) extends PathCodec[A]

private[http] final case class Concat[A, B, C](
Expand Down Expand Up @@ -458,6 +555,7 @@ object PathCodec {
private[http] sealed trait Opt
private[http] object Opt {
final case class Match(value: String) extends Opt
final case class MatchAny(values: Set[String]) extends Opt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to use a set here. We can do even better than that by producing a finite state machine that produces true for a match, and false otherwise (due to the char-by-char nature of testing for string equality). But, separate issue and PR. 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean some structure, where we have like n states, where n is max length of the alternatives and there is a way to progress to the next state with the right char, true if the string ends or false for any other char, right?

final case class Combine(combiner: Combiner[_, _]) extends Opt
case object IntOpt extends Opt
case object LongOpt extends Opt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ object OpenAPIGen {
}
}),
)
case PathCodec.Fallback(left, _) =>
loop(left, annotations)
}

loop(codec, annotations).map { case (sc, annotations) =>
Expand Down
15 changes: 9 additions & 6 deletions zio-http/shared/src/main/scala/zio/http/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ package object http extends UrlInterpolator with MdInterpolator {
def withContext[C](fn: => C)(implicit c: WithContext[C]): ZIO[c.Env, c.Err, c.Out] =
c.toZIO(fn)

def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name)
def int(name: String): PathCodec[Int] = PathCodec.int(name)
def long(name: String): PathCodec[Long] = PathCodec.long(name)
def string(name: String): PathCodec[String] = PathCodec.string(name)
val trailing: PathCodec[Path] = PathCodec.trailing
def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name)
def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name)
def int(name: String): PathCodec[Int] = PathCodec.int(name)
def long(name: String): PathCodec[Long] = PathCodec.long(name)
def string(name: String): PathCodec[String] = PathCodec.string(name)
val trailing: PathCodec[Path] = PathCodec.trailing
def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name)
def anyOf(name: String, names: String*): PathCodec[Unit] =
if (names.isEmpty) PathCodec.literal(name)
else names.foldLeft(PathCodec.literal(name))((acc, n) => acc.orElse(PathCodec.literal(n)))

val Root: PathCodec[Unit] = PathCodec.empty

Expand Down
Loading