Skip to content

Commit

Permalink
Alternative literal path segments for route definitions (#2815)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jun 19, 2024
1 parent 63c0616 commit f7c392e
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 13 deletions.
19 changes: 19 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,24 @@ object RoutesSpec extends ZIOHttpSpec {
)
.map(response => assertTrue(response.status == Status.Ok))
},
test("alternative path segments") {
val app = Routes(
Method.GET / anyOf("foo", "bar", "baz") -> Handler.ok,
)

for {
foo <- app.runZIO(Request.get("/foo"))
bar <- app.runZIO(Request.get("/bar"))
baz <- app.runZIO(Request.get("/baz"))
box <- app.runZIO(Request.get("/box"))
} yield {
assertTrue(
extractStatus(foo) == Status.Ok,
extractStatus(bar) == Status.Ok,
extractStatus(baz) == Status.Ok,
extractStatus(box) == Status.NotFound,
)
}
},
)
}
6 changes: 4 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/HttpApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package zio.http
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.Routes.Tree

/**
* An HTTP application is a collection of routes, all of whose errors have been
* handled through conversion into HTTP responses.
Expand Down Expand Up @@ -137,10 +139,10 @@ object HttpApp {
Tree(self.tree ++ that.tree)

final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.add(route.routePattern, route.toHandler))
Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler))))

final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler))))
Tree[Env1](self.tree.addAll(routes.map(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler))).flatten))

final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] =
tree.get(method, path)
Expand Down
4 changes: 2 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/Route.scala
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ object Route {
handler: Handler[Env1, Response, In, Response],
)(implicit zippable: Zippable.Out[Params, Request, In], trace: Trace): Route[Env1, Nothing] = {
val handler2: Handler[Any, Nothing, RoutePattern[_], Handler[Env1, Response, Request, Response]] = {
Handler.fromFunction[RoutePattern[_]] { _ =>
Handler.fromFunction[RoutePattern[_]] { pattern =>
val paramHandler =
Handler.fromFunctionZIO[(rpm.Context, Request)] { case (ctx, request) =>
rpm.routePattern.decode(request.method, request.path) match {
pattern.asInstanceOf[RoutePattern[rpm.PathInput]].decode(request.method, request.path) match {
case Left(error) => ZIO.dieMessage(error)
case Right(value) =>
val params = rpm.zippable.zip(value, ctx)
Expand Down
2 changes: 2 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/RoutePattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ final case class RoutePattern[A](method: Method, pathCodec: PathCodec[A]) { self
): Route.Builder[Env, zippable.Out] =
Route.Builder(self, middleware)(zippable)

def alternatives: List[RoutePattern[A]] = pathCodec.alternatives.map(RoutePattern(method, _))

/**
* Reinteprets the type parameter, given evidence it is equal to some other
* type.
Expand Down
6 changes: 4 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.io.File

import zio._

import zio.http.HttpApp.Tree
import zio.http.Routes.ApplyContextAspect
import zio.http.codec.PathCodec

Expand Down Expand Up @@ -331,10 +332,11 @@ object Routes extends RoutesCompanionVersionSpecific {
Tree(self.tree ++ that.tree)

final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.add(route.routePattern, route.toHandler))
Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler))))

final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler))))
// only change to flatMap when Scala 2.12 is dropped
Tree(self.tree.addAll(routes.map(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler))).flatten))

final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] =
tree.get(method, path)
Expand Down
99 changes: 98 additions & 1 deletion zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package zio.http.codec

import scala.annotation.tailrec
import scala.collection.immutable.ListMap
import scala.language.implicitConversions

Expand Down Expand Up @@ -61,6 +62,57 @@ sealed trait PathCodec[A] { self =>
}
}

private[http] def orElse(value: PathCodec[Unit])(implicit ev: A =:= Unit): PathCodec[Unit] =
Fallback(self.asInstanceOf[PathCodec[Unit]], value)

private def fallbackAlternatives(f: Fallback[_]): List[PathCodec[Any]] = {
@tailrec
def loop(codecs: List[PathCodec[_]], result: List[PathCodec[_]]): List[PathCodec[_]] =
if (codecs.isEmpty) result
else
codecs.head match {
case PathCodec.Annotated(codec, _) =>
loop(codec :: codecs.tail, result)
case PathCodec.Segment(SegmentCodec.Literal(_)) =>
loop(codecs.tail, result :+ codecs.head)
case PathCodec.Segment(SegmentCodec.Empty) =>
loop(codecs.tail, result)
case Fallback(left, right) =>
loop(left :: right :: codecs.tail, result)
case other =>
throw new IllegalStateException(s"Alternative path segments should only contain literals, found: $other")
}
loop(List(f.left, f.right), List.empty).asInstanceOf[List[PathCodec[Any]]]
}

final def alternatives: List[PathCodec[A]] = {
var alts = List.empty[PathCodec[Any]]
def loop(codec: PathCodec[_], combiner: Combiner[_, _]): Unit = codec match {
case Concat(left, right, combiner) =>
loop(left, combiner)
loop(right, combiner)
case f: Fallback[_] =>
if (alts.isEmpty) alts = fallbackAlternatives(f)
else
alts ++= alts.flatMap { alt =>
fallbackAlternatives(f).map(fa =>
Concat(alt, fa.asInstanceOf[PathCodec[Any]], combiner.asInstanceOf[Combiner.WithOut[Any, Any, Any]]),
)
}
case Segment(SegmentCodec.Empty) =>
case pc =>
if (alts.isEmpty) alts :+= pc.asInstanceOf[PathCodec[Any]]
else
alts = alts
.map(l =>
Concat(l, pc.asInstanceOf[PathCodec[Any]], combiner.asInstanceOf[Combiner.WithOut[Any, Any, Any]])
.asInstanceOf[PathCodec[Any]],
)
}
loop(self, Combiner.leftUnit[Unit])
alts.asInstanceOf[List[PathCodec[A]]]
}

final def asType[B](implicit ev: A =:= B): PathCodec[B] = self.asInstanceOf[PathCodec[B]]

/**
Expand All @@ -84,14 +136,22 @@ sealed trait PathCodec[A] { self =>
val opt = instructions(i)

opt match {
case Match(value) =>
case Match(value) =>
if (j >= segments.length || segments(j) != value) {
fail = "Expected path segment \"" + value + "\" but found end of path"
i = instructions.length
} else {
stack.push(())
j = j + 1
}
case MatchAny(values) =>
if (j >= segments.length || !values.contains(segments(j))) {
fail = "Expected one of the following path segments: " + values.mkString(", ") + " but found end of path"
i = instructions.length
} else {
stack.push(())
j = j + 1
}

case Combine(combiner0) =>
val combiner = combiner0.asInstanceOf[Combiner[Any, Any]]
Expand Down Expand Up @@ -227,6 +287,7 @@ sealed trait PathCodec[A] { self =>
case Concat(left, right, _) => left.doc + right.doc
case Annotated(codec, annotations) =>
codec.doc + annotations.collectFirst { case MetaData.Documented(doc) => doc }.getOrElse(Doc.empty)
case Fallback(left, right) => left.doc + right.doc
}

/**
Expand Down Expand Up @@ -264,6 +325,8 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, g) =>
g.asInstanceOf[Any => Either[String, Any]](value).flatMap(loop(api, _))
case Fallback(left, _) =>
loop(left, value)
}

loop(self, value).map { path =>
Expand Down Expand Up @@ -298,6 +361,9 @@ sealed trait PathCodec[A] { self =>
case SegmentCodec.Trailing => Opt.TrailingOpt
})

case f: Fallback[_] =>
Chunk(Opt.MatchAny(fallbacks(f)))

case Concat(left, right, combiner) =>
loop(left) ++ loop(right) ++ Chunk(Opt.Combine(combiner))

Expand All @@ -310,6 +376,26 @@ sealed trait PathCodec[A] { self =>
_optimize
}

private def fallbacks(f: Fallback[_]): Set[String] = {
@tailrec
def loop(codecs: List[PathCodec[_]], result: Set[String]): Set[String] =
if (codecs.isEmpty) result
else
codecs.head match {
case PathCodec.Annotated(codec, _) =>
loop(codec :: codecs.tail, result)
case PathCodec.Segment(SegmentCodec.Literal(value)) =>
loop(codecs.tail, result + value)
case PathCodec.Segment(SegmentCodec.Empty) =>
loop(codecs.tail, result)
case Fallback(left, right) =>
loop(left :: right :: codecs.tail, result)
case other =>
throw new IllegalStateException(s"Alternative path segments should only contain literals, found: $other")
}
loop(List(f.left, f.right), Set.empty)
}

/**
* Renders the path codec as a string.
*/
Expand All @@ -324,6 +410,9 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, _) =>
loop(api)

case PathCodec.Fallback(left, _) =>
loop(left)
}

loop(self)
Expand All @@ -341,6 +430,8 @@ sealed trait PathCodec[A] { self =>
case PathCodec.Segment(segment) => segment.render

case PathCodec.TransformOrFail(api, _, _) => loop(api)

case PathCodec.Fallback(left, _) => loop(left)
}

loop(self)
Expand All @@ -360,6 +451,9 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, _) =>
loop(api)

case PathCodec.Fallback(left, _) =>
loop(left)
}

loop(self)
Expand Down Expand Up @@ -418,6 +512,8 @@ object PathCodec {

def uuid(name: String): PathCodec[java.util.UUID] = Segment(SegmentCodec.uuid(name))

private[http] final case class Fallback[A](left: PathCodec[Unit], right: PathCodec[Unit]) extends PathCodec[A]

private[http] final case class Segment[A](segment: SegmentCodec[A]) extends PathCodec[A]

private[http] final case class Concat[A, B, C](
Expand Down Expand Up @@ -458,6 +554,7 @@ object PathCodec {
private[http] sealed trait Opt
private[http] object Opt {
final case class Match(value: String) extends Opt
final case class MatchAny(values: Set[String]) extends Opt
final case class Combine(combiner: Combiner[_, _]) extends Opt
case object IntOpt extends Opt
case object LongOpt extends Opt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ object OpenAPIGen {
}
}),
)
case PathCodec.Fallback(left, _) =>
loop(left, annotations)
}

loop(codec, annotations).map { case (sc, annotations) =>
Expand Down
15 changes: 9 additions & 6 deletions zio-http/shared/src/main/scala/zio/http/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ package object http extends UrlInterpolator with MdInterpolator {
def withContext[C](fn: => C)(implicit c: WithContext[C]): ZIO[c.Env, c.Err, c.Out] =
c.toZIO(fn)

def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name)
def int(name: String): PathCodec[Int] = PathCodec.int(name)
def long(name: String): PathCodec[Long] = PathCodec.long(name)
def string(name: String): PathCodec[String] = PathCodec.string(name)
val trailing: PathCodec[Path] = PathCodec.trailing
def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name)
def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name)
def int(name: String): PathCodec[Int] = PathCodec.int(name)
def long(name: String): PathCodec[Long] = PathCodec.long(name)
def string(name: String): PathCodec[String] = PathCodec.string(name)
val trailing: PathCodec[Path] = PathCodec.trailing
def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name)
def anyOf(name: String, names: String*): PathCodec[Unit] =
if (names.isEmpty) PathCodec.literal(name)
else names.foldLeft(PathCodec.literal(name))((acc, n) => acc.orElse(PathCodec.literal(n)))

val Root: PathCodec[Unit] = PathCodec.empty

Expand Down

0 comments on commit f7c392e

Please sign in to comment.