Skip to content

Commit

Permalink
Fix http4s routes in a context
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Mar 29, 2019
1 parent acc3480 commit 95b96af
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import cats.implicits._
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.server.Directives
import akka.http.scaladsl.server.Directives._
import akka.stream.ActorMaterializer
import cats.data.NonEmptyList
import cats.effect.{IO, Resource}
import tapir.Endpoint
import com.softwaremill.sttp._
import tapir.{Endpoint, endpoint, stringBody}
import tapir.server.tests.ServerTests
import tapir._

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
Expand Down Expand Up @@ -50,4 +53,13 @@ class AkkaHttpServerTests extends ServerTests[Future, AkkaStream, Route] {
import scala.concurrent.ExecutionContext.Implicits.global
Future { t }
}

test("endpoint nested in a path directive") {
val e = endpoint.get.in("test" and "directive").out(stringBody).serverLogic(_ => pureResult("ok".asRight[Unit]))
val port = randomPort()
val route = Directives.pathPrefix("api")(e.toRoute)
server(NonEmptyList.of(route), port).use { _ =>
sttp.get(uri"http://localhost:$port/api/test/directive").send().map(_.body shouldBe Right("ok"))
}.unsafeRunSync
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@ import tapir.model.{Method, ServerRequest}
class Http4sDecodeInputsContext[F[_]](req: Request[F]) extends DecodeInputsContext {
override def method: Method = Method(req.method.name.toUpperCase)
override def nextPathSegment: (Option[String], DecodeInputsContext) = {

val nextStart = req.uri.path.dropWhile(_ == '/')
val (segment, rest) = nextStart.split("/", 2) match {
case Array("") => (None, "")
case Array(s) => (Some(s), "")
case Array(s, t) => (Some(s), t)
val nextStart = req.pathInfo.dropWhile(_ == '/')
val segment = nextStart.split("/", 2) match {
case Array("") => None
case Array(s) => Some(s)
case Array(s, _) => Some(s)
}

(segment, new Http4sDecodeInputsContext(req.withUri(req.uri.withPath(rest))))
// if the routes are mounted within a context (e.g. using a router), we have to match against what comes
// after the context. This information is stored in the the PathInfoCaret attribute
val oldCaret = req.attributes.lookup(Request.Keys.PathInfoCaret).getOrElse(0)
val segmentSlashLength = segment.map(_.length).getOrElse(0) + 1
val reqWithNewCaret = req.withAttribute(Request.Keys.PathInfoCaret, oldCaret + segmentSlashLength)

(segment, new Http4sDecodeInputsContext(reqWithNewCaret))
}
override def header(name: String): List[String] = req.headers.get(CaseInsensitiveString(name)).map(_.value).toList
override def headers: Seq[(String, String)] = req.headers.map(h => (h.name.value, h.value)).toSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@ package tapir.server.http4s
import cats.data.{Kleisli, NonEmptyList}
import cats.effect._
import cats.implicits._
import org.http4s.server.Router
import org.http4s.server.blaze.BlazeServerBuilder
import org.http4s.syntax.kleisli._
import org.http4s.{EntityBody, HttpRoutes, Request, Response}
import tapir.server.tests.ServerTests
import tapir.Endpoint
import tapir._
import com.softwaremill.sttp._

import scala.concurrent.ExecutionContext
import scala.reflect.ClassTag

class Http4sServerTests extends ServerTests[IO, EntityBody[IO], HttpRoutes[IO]] {

implicit private val ec: ExecutionContext = scala.concurrent.ExecutionContext.Implicits.global
implicit private val contextShift: ContextShift[IO] = IO.contextShift(ec)
implicit private val timer: Timer[IO] = IO.timer(ec)
implicit val ec: ExecutionContext = scala.concurrent.ExecutionContext.Implicits.global
implicit val contextShift: ContextShift[IO] = IO.contextShift(ec)
implicit val timer: Timer[IO] = IO.timer(ec)

override def pureResult[T](t: T): IO[T] = IO.pure(t)
override def suspendResult[T](t: => T): IO[T] = IO.apply(t)
Expand All @@ -40,4 +43,19 @@ class Http4sServerTests extends ServerTests[IO, EntityBody[IO], HttpRoutes[IO]]
.resource
.map(_ => ())
}

test("should work with a router and routes in a context") {
val e = endpoint.get.in("test" / "router").out(stringBody).serverLogic(_ => IO.pure("ok".asRight[Unit]))
val routes = e.toRoutes
val port = randomPort()

BlazeServerBuilder[IO]
.bindHttp(port, "localhost")
.withHttpApp(Router("/api" -> routes).orNotFound)
.resource
.use { _ =>
sttp.get(uri"http://localhost:$port/api/test/router").send().map(_.body shouldBe Right("ok"))
}
.unsafeRunSync()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -413,5 +413,5 @@ trait ServerTests[R[_], S, ROUTE] extends FunSuite with Matchers with BeforeAndA
//

private val random = new Random()
private def randomPort(): Port = random.nextInt(29232) + 32768
def randomPort(): Port = random.nextInt(29232) + 32768
}

0 comments on commit 95b96af

Please sign in to comment.