diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 590e10c481..a312ffdff6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,12 @@ jobs: ~/.ivy2/cache ~/.coursier key: sbt-cache-${{ runner.os }}-${{ matrix.target-platform }}-${{ hashFiles('project/build.properties') }} + - name: Install sam cli + run: | + wget -q https://github.com/aws/aws-sam-cli/releases/latest/download/aws-sam-cli-linux-x86_64.zip + unzip -q aws-sam-cli-linux-x86_64.zip -d sam-installation + sudo ./sam-installation/install --update + sam --version - name: Compile run: sbt -v compile compileDocumentation - name: Test diff --git a/build.sbt b/build.sbt index 4d98ba678b..1fa549484b 100644 --- a/build.sbt +++ b/build.sbt @@ -1,9 +1,12 @@ -import java.net.URL import com.softwaremill.SbtSoftwareMillBrowserTestJS._ import com.softwaremill.UpdateVersionInDocs import sbt.Reference.display import sbt.internal.ProjectMatrix +import java.net.URL +import scala.concurrent.duration.DurationInt +import scala.sys.process.Process + val scala2_12 = "2.12.13" val scala2_13 = "2.13.6" @@ -112,6 +115,11 @@ lazy val allAggregates = core.projectRefs ++ playServer.projectRefs ++ vertxServer.projectRefs ++ zioServer.projectRefs ++ + awsLambda.projectRefs ++ + awsLambdaTests.projectRefs ++ + awsSam.projectRefs ++ + awsTerraform.projectRefs ++ + awsExamples.projectRefs ++ http4sClient.projectRefs ++ sttpClient.projectRefs ++ playClient.projectRefs ++ @@ -263,7 +271,7 @@ lazy val tests: ProjectMatrix = (projectMatrix in file("tests")) "com.beachape" %%% "enumeratum-circe" % Versions.enumeratum, "com.softwaremill.common" %%% "tagging" % "2.3.0", scalaTest.value, - "com.softwaremill.macwire" %% "macros" % "2.3.7" % "provided", + "com.softwaremill.macwire" %% "macros" % "2.3.7", "org.typelevel" %%% "cats-effect" % Versions.catsEffect ), libraryDependencies ++= loggerDependencies @@ -283,6 +291,7 @@ lazy val cats: ProjectMatrix = (projectMatrix in file("integrations/cats")) name := "tapir-cats", libraryDependencies ++= Seq( "org.typelevel" %%% "cats-core" % "2.6.1", + "org.typelevel" %%% "cats-effect" % Versions.catsEffect, scalaTest.value % Test, scalaCheck.value % Test, scalaTestPlusScalaCheck.value % Test, @@ -397,8 +406,8 @@ lazy val circeJson: ProjectMatrix = (projectMatrix in file("json/circe")) libraryDependencies ++= Seq( "io.circe" %%% "circe-core" % Versions.circe, "io.circe" %%% "circe-parser" % Versions.circe, - scalaTest.value % Test, - "io.circe" %%% "circe-generic" % Versions.circe % Test + "io.circe" %%% "circe-generic" % Versions.circe, + scalaTest.value % Test ) ) .jvmPlatform(scalaVersions = allScalaVersions) @@ -775,7 +784,7 @@ lazy val http4sServer: ProjectMatrix = (projectMatrix in file("server/http4s-ser ) ) .jvmPlatform(scalaVersions = allScalaVersions) - .dependsOn(core, serverTests % Test) + .dependsOn(core, cats, serverTests % Test) lazy val sttpStubServer: ProjectMatrix = (projectMatrix in file("server/sttp-stub-server")) .settings(commonJvmSettings) @@ -874,6 +883,110 @@ lazy val zioServer: ProjectMatrix = (projectMatrix in file("server/zio-http4s-se .jvmPlatform(scalaVersions = allScalaVersions) .dependsOn(zio, http4sServer, serverTests % Test) +// serverless + +lazy val awsLambda: ProjectMatrix = (projectMatrix in file("serverless/aws/lambda")) + .settings(commonJvmSettings) + .settings( + name := "tapir-aws-lambda", + libraryDependencies ++= loggerDependencies, + libraryDependencies ++= Seq( + "com.softwaremill.sttp.client3" %% "http4s-ce2-backend" % Versions.sttp, + "org.http4s" %% "http4s-blaze-client" % Versions.http4s + ) + ) + .jvmPlatform(scalaVersions = allScalaVersions) + .dependsOn(core, cats, circeJson, awsSam, sttpStubServer % "test", tests % "test", serverTests) + +// integration tests for lambda interpreter +// it's a separate project since it needs a fat jar with lambda code which cannot be build from tests sources +// runs sam local cmd line tool to start AWS Api Gateway with lambda proxy +lazy val awsLambdaTests: ProjectMatrix = (projectMatrix in file("serverless/aws/lambda-tests")) + .settings(commonJvmSettings) + .settings( + name := "tapir-aws-lambda-tests", + libraryDependencies += "com.amazonaws" % "aws-lambda-java-runtime-interface-client" % Versions.awsLambdaInterface, + assembly / assemblyJarName := "tapir-aws-lambda-tests.jar", + assembly / test := {}, // no tests before building jar + assembly / assemblyMergeStrategy := { + case PathList("META-INF", "io.netty.versions.properties") => MergeStrategy.first + case _ @("scala/annotation/nowarn.class" | "scala/annotation/nowarn$.class") => MergeStrategy.first + case x => (assembly / assemblyMergeStrategy).value(x) + }, + Test / test := (Test / test) + .dependsOn((Compile / runMain).toTask(" sttp.tapir.serverless.aws.lambda.tests.LambdaSamTemplate")) + .dependsOn(assembly) + .value, + Test / testOptions ++= { + val log = sLog.value + // process uses template.yaml which is generated by `LambdaSamTemplate` called above + lazy val sam = Process("sam local start-api --warm-containers EAGER").run() + Seq( + Tests.Setup(() => { + val samReady = PollingUtils.poll(60.seconds, 1.second) { + sam.isAlive() && PollingUtils.urlConnectionAvailable(new URL(s"http://127.0.0.1:3000/health")) + } + if (!samReady) { + sam.destroy() + val exit = sam.exitValue() + log.error(s"failed to start sam local within 30 seconds (exit code: $exit") + } + }), + Tests.Cleanup(() => { + sam.destroy() + val exit = sam.exitValue() + log.info(s"stopped sam local (exit code: $exit") + }) + ) + }, + Test / parallelExecution := false + ) + .jvmPlatform(scalaVersions = allScalaVersions) + .dependsOn(core, cats, circeJson, awsLambda, awsSam, tests) + +lazy val awsSam: ProjectMatrix = (projectMatrix in file("serverless/aws/sam")) + .settings(commonJvmSettings) + .settings( + name := "tapir-aws-sam", + libraryDependencies ++= Seq( + "io.circe" %% "circe-yaml" % Versions.circeYaml, + "io.circe" %% "circe-generic" % Versions.circe + ) + ) + .jvmPlatform(scalaVersions = allScalaVersions) + .dependsOn(core, tests % Test) + +lazy val awsTerraform: ProjectMatrix = (projectMatrix in file("serverless/aws/terraform")) + .settings(commonJvmSettings) + .settings( + name := "tapir-aws-terraform", + libraryDependencies ++= Seq( + "io.circe" %% "circe-yaml" % Versions.circeYaml, + "io.circe" %% "circe-generic" % Versions.circe, + "io.circe" %% "circe-literal" % Versions.circe, + "org.typelevel" %% "jawn-parser" % "1.0.0" + ) + ) + .jvmPlatform(scalaVersions = allScalaVersions) + .dependsOn(core, tests % Test) + +lazy val awsExamples: ProjectMatrix = (projectMatrix in file("serverless/aws/examples")) + .settings(commonJvmSettings) + .settings( + libraryDependencies += "com.amazonaws" % "aws-lambda-java-runtime-interface-client" % Versions.awsLambdaInterface + ) + .settings( + name := "tapir-aws-examples", + assembly / assemblyJarName := "tapir-aws-examples.jar", + assembly / assemblyMergeStrategy := { + case PathList("META-INF", "io.netty.versions.properties") => MergeStrategy.first + case _ @("scala/annotation/nowarn.class" | "scala/annotation/nowarn$.class") => MergeStrategy.first + case x => (assembly / assemblyMergeStrategy).value(x) + } + ) + .jvmPlatform(scalaVersions = allScalaVersions) + .dependsOn(awsLambda, awsSam, awsTerraform) + // client lazy val clientTests: ProjectMatrix = (projectMatrix in file("client/tests")) diff --git a/doc/index.md b/doc/index.md index d7fb1f125c..a27d9992b6 100644 --- a/doc/index.md +++ b/doc/index.md @@ -136,6 +136,7 @@ Development and maintenance of sttp tapir is sponsored by [SoftwareMill](https:/ :caption: Server interpreters server/akkahttp + server/aws server/http4s server/finatra server/play diff --git a/doc/server/aws.md b/doc/server/aws.md new file mode 100644 index 0000000000..53ab09a4ad --- /dev/null +++ b/doc/server/aws.md @@ -0,0 +1,66 @@ +# Running behind AWS API Gateway + +[AWS API Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/welcome.html) provides a proxy +integration +with [AWS Lambda](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html) +which allows you to implement API routes using Lambda functions. On the other hand tools +like [AWS SAM](https://aws.amazon.com/serverless/sam/) and [Terraform](https://www.terraform.io/) provides a +configuration mechanism for binding AWS Api Gateway routes to Lambda functions and automating cloud deployments. + +This concept of serverless API has been adapted to Tapir in form of three components. + +The first one is `AwsServerInterpreter` which routes AWS API Gateway requests to responses just as any other server +interpreter does. It should be used in your lambda function code. + +```scala +"com.softwaremill.sttp.tapir" %% "tapir-aws-lambda" % "@VERSION@" +``` + +The remaining two are `AwsSamInterpreter` which interprets Tapir `Endpoints` into AWS SAM template file +and `AwsTerraformInterpreter` which interprets `Endpoints` into terraform configuration file. One of them should be used +to configure your API Gateway. + +```scala +"com.softwaremill.sttp.tapir" %% "tapir-aws-sam" % "@VERSION@" +"com.softwaremill.sttp.tapir" %% "tapir-aws-terraform" % "@VERSION@" +``` + +## Examples + +In +our [GitHub repository](https://github.com/softwaremill/tapir/tree/master/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples) +you'll find a `LambdaApiExample` handler which uses `AwsServerInterpreter` to route a hello endpoint along +with `SamTemplateExample` and `TerraformConfigExample` which interpret endpoints to SAM/Terraform configuration. Go +ahead and clone tapir project and select `project awsExamples` from sbt shell. + +Make sure you +have [AWS command line tools installed](https://docs.aws.amazon.com/cli/latest/userguide/install-cliv2.html). + +### SAM + +To try it out using SAM template you don't need an AWS account. + +* install [AWS SAM command line tool](https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli-command-reference.html) +* run `assembly` task and `runMain sttp.tapir.serverless.aws.examples.SamTemplateExample` +* open a terminal and in tapir root directory run `sam local start-api --warm-containers EAGER` + +That will create `template.yaml` and start up AWS Api Gateway locally. Hello endpoint will be available +under `curl http://127.0.0.1:3000/api/hello`. First invocation will take a while but subsequent ones will be faster +since the created container will be reused. + +### Terraform + +To run the example using terraform you will need an AWS account, and an S3 bucket. + +* install [Terraform](https://learn.hashicorp.com/tutorials/terraform/install-cli) +* run `assembly` task +* open a terminal in `tapir/serverless/aws/examples/target/jvm-2.13` directory. That's where the fat jar is saved. You + need to upload it into your s3 bucket. Using command line + tools: `aws s3 cp tapir-aws-examples.jar s3://{your-bucket}/{your-key}`. +* Run `runMain sttp.tapir.serverless.aws.examples.TerraformConfigExample {your-aws-region} {your-bucket} {your-key}` +* open terminal in tapir root directory, run `terraform init` and `terraform apply` + +That will create `api_gateway.tf.json` configuration and deploy Api Gateway and lambda function to AWS. Terraform will +output the url of the created API Gateway which you can call followed by `/api/hello` path. + +To destroy all the created resources run `terraform destroy`. \ No newline at end of file diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/CatsMonadError.scala b/integrations/cats/src/main/scala/sttp/tapir/integ/cats/CatsMonadError.scala similarity index 84% rename from server/http4s-server/src/main/scala/sttp/tapir/server/http4s/CatsMonadError.scala rename to integrations/cats/src/main/scala/sttp/tapir/integ/cats/CatsMonadError.scala index 9793cea610..7c642abbd5 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/CatsMonadError.scala +++ b/integrations/cats/src/main/scala/sttp/tapir/integ/cats/CatsMonadError.scala @@ -1,9 +1,9 @@ -package sttp.tapir.server.http4s +package sttp.tapir.integ.cats import cats.effect.Sync import sttp.monad.MonadError -private[http4s] class CatsMonadError[F[_]](implicit F: Sync[F]) extends MonadError[F] { +class CatsMonadError[F[_]](implicit F: Sync[F]) extends MonadError[F] { override def unit[T](t: T): F[T] = F.pure(t) override def map[T, T2](fa: F[T])(f: T => T2): F[T2] = F.map(fa)(f) override def flatMap[T, T2](fa: F[T])(f: T => F[T2]): F[T2] = F.flatMap(fa)(f) @@ -13,4 +13,4 @@ private[http4s] class CatsMonadError[F[_]](implicit F: Sync[F]) extends MonadErr override def suspend[T](t: => F[T]): F[T] = F.suspend(t) override def flatten[T](ffa: F[F[T]]): F[T] = F.flatten(ffa) override def ensure[T](f: F[T], e: => F[Unit]): F[T] = F.guarantee(f)(e) -} +} \ No newline at end of file diff --git a/project/Versions.scala b/project/Versions.scala index 367802f346..e8c6cd086c 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -31,4 +31,5 @@ object Versions { val jwtScala = "5.0.0" val derevo = "0.12.5" val newtype = "0.4.4" + val awsLambdaInterface = "1.0.0" } diff --git a/project/plugins.sbt b/project/plugins.sbt index 83e7a34fe5..b5055193a4 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -10,3 +10,4 @@ addSbtPlugin("com.eed3si9n" % "sbt-projectmatrix" % "0.8.0") addSbtPlugin("org.jetbrains.scala" % "sbt-ide-settings" % "1.1.1") addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.5.1") addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.15.0") diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index 6df022d607..a3bf04f9df 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -9,21 +9,15 @@ import cats.implicits._ import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.akka.AkkaStreams +import akka.http.scaladsl.server.Route +import sttp.capabilities.{WebSockets, akka} import sttp.client3._ import sttp.client3.akkahttp.AkkaHttpBackend import sttp.model.sse.ServerSentEvent import sttp.monad.FutureMonad import sttp.monad.syntax._ import sttp.tapir._ -import sttp.tapir.server.tests.{ - CreateServerTest, - ServerAuthenticationTests, - ServerBasicTests, - ServerMetricsTest, - ServerStreamingTests, - ServerWebSocketTests, - backendResource -} +import sttp.tapir.server.tests.{DefaultCreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerFileMultipartTests, ServerMetricsTest, ServerStreamingTests, ServerWebSocketTests, backendResource} import sttp.tapir.tests.{Test, TestSuite} import java.util.UUID @@ -43,7 +37,7 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { implicit val m: FutureMonad = new FutureMonad()(actorSystem.dispatcher) val interpreter = new AkkaHttpTestServerInterpreter()(actorSystem) - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) def additionalTests(): List[Test] = List( Test("endpoint nested in a path directive") { @@ -86,13 +80,14 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { } ) - new ServerBasicTests(backend, createServerTest, interpreter).tests() ++ - new ServerStreamingTests(backend, createServerTest, AkkaStreams).tests() ++ - new ServerWebSocketTests(backend, createServerTest, AkkaStreams) { + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests(createServerTest).tests() ++ + new ServerWebSocketTests(createServerTest, AkkaStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) }.tests() ++ - new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerMetricsTest(backend, createServerTest).tests() ++ + new ServerStreamingTests(createServerTest, AkkaStreams).tests() ++ + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerMetricsTest(createServerTest).tests() ++ additionalTests() } } diff --git a/server/finatra-server/finatra-server-cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala b/server/finatra-server/finatra-server-cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala index fbe8037474..c0f2a755d6 100644 --- a/server/finatra-server/finatra-server-cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala +++ b/server/finatra-server/finatra-server-cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala @@ -2,7 +2,7 @@ package sttp.tapir.server.finatra.cats import cats.effect.{IO, Resource} import sttp.client3.impl.cats.CatsMonadAsyncError -import sttp.tapir.server.tests.{CreateServerTest, ServerAuthenticationTests, ServerBasicTests, backendResource} +import sttp.tapir.server.tests.{DefaultCreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerFileMultipartTests, backendResource} import sttp.tapir.tests.{Test, TestSuite} class FinatraServerCatsTests extends TestSuite { @@ -10,9 +10,10 @@ class FinatraServerCatsTests extends TestSuite { implicit val m: CatsMonadAsyncError[IO] = new CatsMonadAsyncError[IO]() val interpreter = new FinatraCatsTestServerInterpreter() - val createServerTest = new CreateServerTest(interpreter) + val createTestServer = new DefaultCreateServerTest(backend, interpreter) - new ServerBasicTests(backend, createServerTest, interpreter).tests() ++ - new ServerAuthenticationTests(backend, createServerTest).tests() + new ServerBasicTests(createTestServer, interpreter).tests() ++ + new ServerFileMultipartTests(createTestServer).tests() ++ + new ServerAuthenticationTests(createTestServer).tests() } } diff --git a/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala b/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala index 0812b2a5d6..3c414ad4a9 100644 --- a/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala +++ b/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala @@ -3,19 +3,27 @@ package sttp.tapir.server.finatra import cats.effect.{IO, Resource} import sttp.monad.MonadError import sttp.tapir.server.finatra.FinatraServerInterpreter.FutureMonadError -import sttp.tapir.server.tests.{CreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerMetricsTest, backendResource} +import sttp.tapir.server.tests.{ + DefaultCreateServerTest, + ServerAuthenticationTests, + ServerBasicTests, + ServerFileMultipartTests, + ServerMetricsTest, + backendResource +} import sttp.tapir.tests.{Test, TestSuite} class FinatraServerTest extends TestSuite { override def tests: Resource[IO, List[Test]] = backendResource.map { backend => val interpreter = new FinatraTestServerInterpreter() - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) implicit val m: MonadError[com.twitter.util.Future] = FutureMonadError - new ServerBasicTests(backend, createServerTest, interpreter).tests() ++ - new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerMetricsTest(backend, createServerTest).tests() + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests(createServerTest).tests() ++ + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerMetricsTest(createServerTest).tests() } } diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala index c5d96e9ecf..33e1957ac9 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala @@ -16,6 +16,7 @@ import sttp.capabilities.WebSockets import sttp.capabilities.fs2.Fs2Streams import sttp.model.{Header => SttpHeader} import sttp.tapir.Endpoint +import sttp.tapir.integ.cats.CatsMonadError import sttp.tapir.model.ServerResponse import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.interpreter.{BodyListener, ServerInterpreter} diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 9d954299e1..bf51df836f 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -2,6 +2,7 @@ package sttp.tapir.server.http4s import cats.effect._ import cats.syntax.all._ +import org.http4s.HttpRoutes import org.http4s.server.Router import org.http4s.server.blaze.BlazeServerBuilder import org.http4s.syntax.kleisli._ @@ -12,15 +13,8 @@ import sttp.capabilities.fs2.Fs2Streams import sttp.client3._ import sttp.model.sse.ServerSentEvent import sttp.tapir._ -import sttp.tapir.server.tests.{ - CreateServerTest, - ServerAuthenticationTests, - ServerBasicTests, - ServerMetricsTest, - ServerStreamingTests, - ServerWebSocketTests, - backendResource -} +import sttp.tapir.integ.cats.CatsMonadError +import sttp.tapir.server.tests.{DefaultCreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerFileMultipartTests, ServerMetricsTest, ServerStreamingTests, ServerWebSocketTests, backendResource} import sttp.tapir.tests.{Test, TestSuite} import sttp.ws.{WebSocket, WebSocketFrame} @@ -35,7 +29,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi implicit val m: CatsMonadError[IO] = new CatsMonadError[IO] val interpreter = new Http4sTestServerInterpreter() - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) def randomUUID = Some(UUID.randomUUID().toString) val sse1 = ServerSentEvent(randomUUID, randomUUID, randomUUID, Some(Random.nextInt(200))) val sse2 = ServerSentEvent(randomUUID, randomUUID, randomUUID, Some(Random.nextInt(200))) @@ -64,7 +58,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi .autoPing(Some((1.second, WebSocketFrame.ping))) ), "automatic pings" - )((_: Unit) => IO(Right((in: fs2.Stream[IO, String]) => in))) { baseUri => + )((_: Unit) => IO(Right((in: fs2.Stream[IO, String]) => in))) { (backend, baseUri) => basicRequest .response(asWebSocket { ws: WebSocket[IO] => List(ws.receive().timeout(60.seconds), ws.receive().timeout(60.seconds)).sequence @@ -78,7 +72,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi "streaming should send data according to producer stream rate" )((_: Unit) => IO(Right(fs2.Stream.awakeEvery[IO](1.second).map(_.toString()).through(fs2.text.utf8Encode).interruptAfter(5.seconds))) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .response( asStream(Fs2Streams[IO])(bs => { @@ -95,7 +89,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi createServerTest.testServer( endpoint.out(serverSentEventsBody[IO]), "Send and receive SSE" - )((_: Unit) => IO(Right(fs2.Stream(sse1, sse2)))) { baseUri => + )((_: Unit) => IO(Right(fs2.Stream(sse1, sse2)))) { (backend, baseUri) => basicRequest .response(asStream[IO, List[ServerSentEvent], Fs2Streams[IO]](Fs2Streams[IO]) { stream => Http4sServerSentEvents @@ -110,12 +104,13 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi } ) - new ServerBasicTests(backend, createServerTest, interpreter).tests() ++ - new ServerStreamingTests(backend, createServerTest, Fs2Streams[IO]).tests() ++ - new ServerWebSocketTests(backend, createServerTest, Fs2Streams[IO]) { + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests(createServerTest).tests() ++ + new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ + new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) }.tests() ++ - new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerMetricsTest(backend, createServerTest).tests() ++ additionalTests() + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerMetricsTest(createServerTest).tests() ++ additionalTests() } } diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 981aa3ece9..c031778cd5 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -3,7 +3,14 @@ package sttp.tapir.server.play import akka.actor.ActorSystem import cats.effect.{IO, Resource} import sttp.monad.FutureMonad -import sttp.tapir.server.tests.{CreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerMetricsTest, backendResource} +import sttp.tapir.server.tests.{ + DefaultCreateServerTest, + ServerAuthenticationTests, + ServerBasicTests, + ServerFileMultipartTests, + ServerMetricsTest, + backendResource +} import sttp.tapir.tests.{Test, TestSuite} class PlayServerTest extends TestSuite { @@ -16,17 +23,17 @@ class PlayServerTest extends TestSuite { implicit val m: FutureMonad = new FutureMonad()(actorSystem.dispatcher) val interpreter = new PlayTestServerInterpreter()(actorSystem) - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) new ServerBasicTests( - backend, createServerTest, interpreter, multipleValueHeaderSupport = false, - multipartInlineHeaderSupport = false, inputStreamSupport = false - ).tests() ++ new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerMetricsTest(backend, createServerTest).tests() + ).tests() ++ + new ServerFileMultipartTests(createServerTest, multipartInlineHeaderSupport = false).tests() + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerMetricsTest(createServerTest).tests() } } } diff --git a/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpStubServer.scala b/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpStubServer.scala index 48573c9c07..a1003ce77b 100644 --- a/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpStubServer.scala +++ b/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpStubServer.scala @@ -17,7 +17,6 @@ import java.nio.ByteBuffer import java.nio.charset.Charset import scala.collection.immutable.Seq import scala.util.{Success, Try} -import sttp.monad.syntax._ trait SttpStubServer { diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala deleted file mode 100644 index 5e503298de..0000000000 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala +++ /dev/null @@ -1,54 +0,0 @@ -package sttp.tapir.server.tests - -import cats.data.NonEmptyList -import cats.effect.{IO, Resource} -import cats.implicits._ -import com.typesafe.scalalogging.StrictLogging -import org.scalatest.Assertion -import sttp.client3._ -import sttp.model._ -import sttp.tapir._ -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor.decodefailure.DecodeFailureHandler -import sttp.tapir.server.interceptor.metrics.{MetricsEndpointInterceptor, MetricsRequestInterceptor} -import sttp.tapir.tests._ - -class CreateServerTest[F[_], +R, ROUTE, B](interpreter: TestServerInterpreter[F, R, ROUTE, B]) extends StrictLogging { - def testServer[I, E, O]( - e: Endpoint[I, E, O, R], - testNameSuffix: String = "", - decodeFailureHandler: Option[DecodeFailureHandler] = None, - metricsInterceptor: Option[MetricsRequestInterceptor[F, B]] = None - )( - fn: I => F[Either[E, O]] - )(runTest: Uri => IO[Assertion]): Test = { - testServer( - e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix), - NonEmptyList.of(interpreter.route(e.serverLogic(fn), decodeFailureHandler, metricsInterceptor)) - )(runTest) - } - - def testServerLogic[I, E, O](e: ServerEndpoint[I, E, O, R, F], testNameSuffix: String = "")(runTest: Uri => IO[Assertion]): Test = { - testServer( - e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix), - NonEmptyList.of(interpreter.route(e)) - )(runTest) - } - - def testServer(name: String, rs: => NonEmptyList[ROUTE])(runTest: Uri => IO[Assertion]): Test = { - val resources = for { - port <- interpreter.server(rs).onError { case e: Exception => - Resource.eval(IO(logger.error(s"Starting server failed because of ${e.getMessage}"))) - } - _ <- Resource.eval(IO(logger.info(s"Bound server on port: $port"))) - } yield port - - Test(name)( - resources - .use { port => - runTest(uri"http://localhost:$port").guarantee(IO(logger.info(s"Tests completed on port $port"))) - } - .unsafeRunSync() - ) - } -} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/DefaultCreateServerTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/DefaultCreateServerTest.scala new file mode 100644 index 0000000000..6639e31be6 --- /dev/null +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/DefaultCreateServerTest.scala @@ -0,0 +1,81 @@ +package sttp.tapir.server.tests + +import cats.data.NonEmptyList +import cats.effect.{IO, Resource} +import cats.implicits._ +import com.typesafe.scalalogging.StrictLogging +import org.scalatest.Assertion +import sttp.capabilities.WebSockets +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ +import sttp.model._ +import sttp.tapir._ +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.interceptor.decodefailure.DecodeFailureHandler +import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor +import sttp.tapir.tests._ + +class DefaultCreateServerTest[F[_], +R, ROUTE, B]( + backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets], + interpreter: TestServerInterpreter[F, R, ROUTE, B] +) extends CreateServerTest[F, R, ROUTE, B] + with StrictLogging { + override def testServer[I, E, O]( + e: Endpoint[I, E, O, R], + testNameSuffix: String = "", + decodeFailureHandler: Option[DecodeFailureHandler] = None, + metricsInterceptor: Option[MetricsRequestInterceptor[F, B]] = None + )( + fn: I => F[Either[E, O]] + )(runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion]): Test = { + testServer( + e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix), + NonEmptyList.of(interpreter.route(e.serverLogic(fn), decodeFailureHandler, metricsInterceptor)) + )(runTest) + } + + override def testServerLogic[I, E, O](e: ServerEndpoint[I, E, O, R, F], testNameSuffix: String = "")( + runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = { + testServer( + e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix), + NonEmptyList.of(interpreter.route(e)) + )(runTest) + } + + override def testServer(name: String, rs: => NonEmptyList[ROUTE])( + runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = { + val resources = for { + port <- interpreter.server(rs).onError { case e: Exception => + Resource.eval(IO(logger.error(s"Starting server failed because of ${e.getMessage}"))) + } + _ <- Resource.eval(IO(logger.info(s"Bound server on port: $port"))) + } yield port + + Test(name)( + resources + .use { port => + runTest(backend, uri"http://localhost:$port").guarantee(IO(logger.info(s"Tests completed on port $port"))) + } + .unsafeRunSync() + ) + } +} + +trait CreateServerTest[F[_], +R, ROUTE, B] { + def testServer[I, E, O]( + e: Endpoint[I, E, O, R], + testNameSuffix: String = "", + decodeFailureHandler: Option[DecodeFailureHandler] = None, + metricsInterceptor: Option[MetricsRequestInterceptor[F, B]] = None + )(fn: I => F[Either[E, O]])(runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion]): Test + + def testServerLogic[I, E, O](e: ServerEndpoint[I, E, O, R, F], testNameSuffix: String = "")( + runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test + + def testServer(name: String, rs: => NonEmptyList[ROUTE])( + runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerAuthenticationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerAuthenticationTests.scala index 101ea3746d..6b9da09f6d 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerAuthenticationTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerAuthenticationTests.scala @@ -1,22 +1,19 @@ package sttp.tapir.server.tests -import cats.effect.IO +import cats.implicits._ import org.scalatest.matchers.should.Matchers import sttp.client3._ -import sttp.model._ -import sttp.model.StatusCode -import sttp.monad.MonadError -import sttp.tapir._ -import sttp.tapir.tests.Test -import cats.implicits._ import sttp.model.Uri.QuerySegment +import sttp.model.{StatusCode, _} +import sttp.monad.MonadError import sttp.tapir.EndpointInput.WWWAuthenticate +import sttp.tapir._ import sttp.tapir.model.UsernamePassword +import sttp.tapir.tests.Test -class ServerAuthenticationTests[F[_], S, ROUTE, B](backend: SttpBackend[IO, Any], serverTests: CreateServerTest[F, S, ROUTE, B])(implicit - m: MonadError[F] -) extends Matchers { - import serverTests._ +class ServerAuthenticationTests[F[_], S, ROUTE, B](createServerTest: CreateServerTest[F, S, ROUTE, B])(implicit m: MonadError[F]) + extends Matchers { + import createServerTest._ private val Realm = "realm" private val base = endpoint.post.in("secret" / path[Long]("id")).in(query[String]("q")) @@ -50,7 +47,7 @@ class ServerAuthenticationTests[F[_], S, ROUTE, B](backend: SttpBackend[IO, Any] def tests(): List[Test] = missingAuthTests ++ correctAuthTests ++ badRequestWithCorrectAuthTests private def missingAuthTests = endpoints.map { case (authType, endpoint, _) => - testServer(endpoint, s"missing $authType")(_ => result) { baseUri => + testServer(endpoint, s"missing $authType")(_ => result) { (backend, baseUri) => validRequest(baseUri).send(backend).map { r => r.code shouldBe StatusCode.Unauthorized r.header("WWW-Authenticate") shouldBe Some(expectedChallenge(authType)) @@ -65,7 +62,7 @@ class ServerAuthenticationTests[F[_], S, ROUTE, B](backend: SttpBackend[IO, Any] } private def correctAuthTests = endpoints.map { case (authType, endpoint, auth) => - testServer(endpoint, s"correct $authType")(_ => result) { baseUri => + testServer(endpoint, s"correct $authType")(_ => result) { (backend, baseUri) => auth(validRequest(baseUri)) .send(backend) .map(_.code shouldBe StatusCode.Ok) @@ -73,7 +70,7 @@ class ServerAuthenticationTests[F[_], S, ROUTE, B](backend: SttpBackend[IO, Any] } private def badRequestWithCorrectAuthTests = endpoints.map { case (authType, endpoint, auth) => - testServer(endpoint, s"invalid request $authType")(_ => result) { baseUri => + testServer(endpoint, s"invalid request $authType")(_ => result) { (backend, baseUri) => auth(invalidRequest(baseUri)).send(backend).map(_.code shouldBe StatusCode.BadRequest) } } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 7cd2eef73a..850097082f 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -31,11 +31,9 @@ import scala.concurrent.Await import scala.concurrent.duration.DurationInt class ServerBasicTests[F[_], ROUTE, B]( - backend: SttpBackend[IO, Any], createServerTest: CreateServerTest[F, Any, ROUTE, B], serverInterpreter: TestServerInterpreter[F, Any, ROUTE, B], multipleValueHeaderSupport: Boolean = true, - multipartInlineHeaderSupport: Boolean = true, inputStreamSupport: Boolean = true )(implicit m: MonadError[F] @@ -48,118 +46,120 @@ class ServerBasicTests[F[_], ROUTE, B]( private def suspendResult[T](t: => T): F[T] = m.eval(t) def tests(): List[Test] = - basicTests() ++ - (if (multipartInlineHeaderSupport) multipartInlineHeaderTests() else Nil) ++ - (if (inputStreamSupport) inputStreamTests() else Nil) + basicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) def basicTests(): List[Test] = List( testServer(in_string_out_status_from_type_erasure_using_partial_matcher)((v: String) => pureResult((if (v == "right") Some(Right("right")) else if (v == "left") Some(Left(42)) else None).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=nothing").send(backend).map(_.code shouldBe StatusCode.NoContent) >> basicRequest.get(uri"$baseUri?fruit=right").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri?fruit=left").send(backend).map(_.code shouldBe StatusCode.Accepted) }, // method matching - testServer(endpoint, "GET empty endpoint")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(endpoint, "GET empty endpoint")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.get(baseUri).send(backend).map(_.body shouldBe Right("")) }, - testServer(endpoint, "POST empty endpoint")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(endpoint, "POST empty endpoint")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.post(baseUri).send(backend).map(_.body shouldBe Right("")) }, - testServer(endpoint.get, "GET a GET endpoint")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(endpoint.get, "GET a GET endpoint")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.get(baseUri).send(backend).map(_.body shouldBe Right("")) }, - testServer(endpoint.get, "POST a GET endpoint")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(endpoint.get, "POST a GET endpoint")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.post(baseUri).send(backend).map(_.body shouldBe Symbol("left")) }, // - testServer(in_query_out_string)((fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { baseUri => + testServer(in_query_out_string)((fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.body shouldBe Right("fruit: orange")) }, - testServer(in_query_out_string, "with URL encoding")((fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { baseUri => - basicRequest.get(uri"$baseUri?fruit=red%20apple").send(backend).map(_.body shouldBe Right("fruit: red apple")) + testServer(in_query_out_string, "with URL encoding")((fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri?fruit=red%20apple").send(backend).map(_.body shouldBe Right("fruit: red apple")) }, testServer[String, Nothing, String](in_query_out_infallible_string)((fruit: String) => pureResult(s"fruit: $fruit".asRight[Nothing])) { - baseUri => + (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=kiwi").send(backend).map(_.body shouldBe Right("fruit: kiwi")) }, testServer(in_query_query_out_string) { case (fruit: String, amount: Option[Int]) => pureResult(s"$fruit $amount".asRight[Unit]) } { - baseUri => + (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.body shouldBe Right("orange None")) *> basicRequest.get(uri"$baseUri?fruit=orange&amount=10").send(backend).map(_.body shouldBe Right("orange Some(10)")) }, - testServer(in_header_out_string)((p1: String) => pureResult(s"$p1".asRight[Unit])) { baseUri => + testServer(in_header_out_string)((p1: String) => pureResult(s"$p1".asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri").header("X-Role", "Admin").send(backend).map(_.body shouldBe Right("Admin")) }, - testServer(in_path_path_out_string) { case (fruit: String, amount: Int) => pureResult(s"$fruit $amount".asRight[Unit]) } { baseUri => - basicRequest.get(uri"$baseUri/fruit/orange/amount/20").send(backend).map(_.body shouldBe Right("orange 20")) + testServer(in_path_path_out_string) { case (fruit: String, amount: Int) => pureResult(s"$fruit $amount".asRight[Unit]) } { + (backend, baseUri) => + basicRequest.get(uri"$baseUri/fruit/orange/amount/20").send(backend).map(_.body shouldBe Right("orange 20")) }, testServer(in_path_path_out_string, "with URL encoding") { case (fruit: String, amount: Int) => pureResult(s"$fruit $amount".asRight[Unit]) - } { baseUri => + } { (backend, baseUri) => basicRequest.get(uri"$baseUri/fruit/apple%2Fred/amount/20").send(backend).map(_.body shouldBe Right("apple/red 20")) }, - testServer(in_path, "Empty path should not be passed to path capture decoding") { _ => pureResult(Right(())) } { baseUri => + testServer(in_path, "Empty path should not be passed to path capture decoding") { _ => pureResult(Right(())) } { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/").send(backend).map(_.code shouldBe StatusCode.NotFound) }, testServer(in_two_path_capture, "capturing two path parameters with the same specification") { case (a: Int, b: Int) => pureResult(Right((a, b))) - } { baseUri => + } { (backend, baseUri) => basicRequest.get(uri"$baseUri/in/12/23").send(backend).map { response => response.header("a") shouldBe Some("12") response.header("b") shouldBe Some("23") } }, - testServer(in_string_out_string)((b: String) => pureResult(b.asRight[Unit])) { baseUri => + testServer(in_string_out_string)((b: String) => pureResult(b.asRight[Unit])) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("Sweet").send(backend).map(_.body shouldBe Right("Sweet")) }, - testServer(in_string_out_string, "with get method")((b: String) => pureResult(b.asRight[Unit])) { baseUri => + testServer(in_string_out_string, "with get method")((b: String) => pureResult(b.asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/echo").body("Sweet").send(backend).map(_.body shouldBe Symbol("left")) }, - testServer(in_mapped_query_out_string)((fruit: List[Char]) => pureResult(s"fruit length: ${fruit.length}".asRight[Unit])) { baseUri => - basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.body shouldBe Right("fruit length: 6")) + testServer(in_mapped_query_out_string)((fruit: List[Char]) => pureResult(s"fruit length: ${fruit.length}".asRight[Unit])) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.body shouldBe Right("fruit length: 6")) }, - testServer(in_mapped_path_out_string)((fruit: Fruit) => pureResult(s"$fruit".asRight[Unit])) { baseUri => + testServer(in_mapped_path_out_string)((fruit: Fruit) => pureResult(s"$fruit".asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/fruit/kiwi").send(backend).map(_.body shouldBe Right("Fruit(kiwi)")) }, - testServer(in_mapped_path_path_out_string)((p1: FruitAmount) => pureResult(s"FA: $p1".asRight[Unit])) { baseUri => + testServer(in_mapped_path_path_out_string)((p1: FruitAmount) => pureResult(s"FA: $p1".asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/fruit/orange/amount/10").send(backend).map(_.body shouldBe Right("FA: FruitAmount(orange,10)")) }, testServer(in_query_mapped_path_path_out_string) { case (fa: FruitAmount, color: String) => pureResult(s"FA: $fa color: $color".asRight[Unit]) - } { baseUri => + } { (backend, baseUri) => basicRequest .get(uri"$baseUri/fruit/orange/amount/10?color=yellow") .send(backend) .map(_.body shouldBe Right("FA: FruitAmount(orange,10) color: yellow")) }, - testServer(in_query_out_mapped_string)((p1: String) => pureResult(p1.toList.asRight[Unit])) { baseUri => + testServer(in_query_out_mapped_string)((p1: String) => pureResult(p1.toList.asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.body shouldBe Right("orange")) }, - testServer(in_query_out_mapped_string_header)((p1: String) => pureResult(FruitAmount(p1, p1.length).asRight[Unit])) { baseUri => - basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map { r => - r.body shouldBe Right("orange") - r.header("X-Role") shouldBe Some("6") - } + testServer(in_query_out_mapped_string_header)((p1: String) => pureResult(FruitAmount(p1, p1.length).asRight[Unit])) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map { r => + r.body shouldBe Right("orange") + r.header("X-Role") shouldBe Some("6") + } }, testServer(in_header_before_path, "Header input before path capture input") { case (str: String, i: Int) => pureResult((i, str).asRight[Unit]) - } { baseUri => + } { (backend, baseUri) => basicRequest.get(uri"$baseUri/12").header("SomeHeader", "hello").send(backend).map { response => response.body shouldBe Right("hello") response.header("IntHeader") shouldBe Some("12") } }, testServer(in_json_out_json)((fa: FruitAmount) => pureResult(FruitAmount(fa.fruit + " banana", fa.amount * 2).asRight[Unit])) { - baseUri => + (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .body("""{"fruit":"orange","amount":11}""") .send(backend) .map(_.body shouldBe Right("""{"fruit":"orange banana","amount":22}""")) }, - testServer(in_json_out_json, "with accept header")((fa: FruitAmount) => pureResult(fa.asRight[Unit])) { baseUri => + testServer(in_json_out_json, "with accept header")((fa: FruitAmount) => pureResult(fa.asRight[Unit])) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .body("""{"fruit":"banana","amount":12}""") @@ -167,37 +167,30 @@ class ServerBasicTests[F[_], ROUTE, B]( .send(backend) .map(_.body shouldBe Right("""{"fruit":"banana","amount":12}""")) }, - testServer(in_json_out_json, "content type")((fa: FruitAmount) => pureResult(fa.asRight[Unit])) { baseUri => + testServer(in_json_out_json, "content type")((fa: FruitAmount) => pureResult(fa.asRight[Unit])) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .body("""{"fruit":"banana","amount":12}""") .send(backend) .map(_.contentType shouldBe Some(sttp.model.MediaType.ApplicationJson.toString)) }, - testServer(in_byte_array_out_byte_array)((b: Array[Byte]) => pureResult(b.asRight[Unit])) { baseUri => + testServer(in_byte_array_out_byte_array)((b: Array[Byte]) => pureResult(b.asRight[Unit])) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("banana kiwi".getBytes).send(backend).map(_.body shouldBe Right("banana kiwi")) }, - testServer(in_byte_buffer_out_byte_buffer)((b: ByteBuffer) => pureResult(b.asRight[Unit])) { baseUri => + testServer(in_byte_buffer_out_byte_buffer)((b: ByteBuffer) => pureResult(b.asRight[Unit])) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) }, - testServer(in_unit_out_json_unit, "unit json mapper")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(in_unit_out_json_unit, "unit json mapper")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/unit").send(backend).map(_.body shouldBe Right("{}")) }, - testServer(in_unit_out_string, "default status mapper")((_: Unit) => pureResult("".asRight[Unit])) { baseUri => + testServer(in_unit_out_string, "default status mapper")((_: Unit) => pureResult("".asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/not-existing-path").send(backend).map(_.code shouldBe StatusCode.NotFound) }, - testServer(in_unit_error_out_string, "default error status mapper")((_: Unit) => pureResult("".asLeft[Unit])) { baseUri => + testServer(in_unit_error_out_string, "default error status mapper")((_: Unit) => pureResult("".asLeft[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api").send(backend).map(_.code shouldBe StatusCode.BadRequest) }, - testServer(in_file_out_file)((file: File) => pureResult(file.asRight[Unit])) { baseUri => - basicRequest - .post(uri"$baseUri/api/echo") - .body("pen pineapple apple pen") - .send(backend) - .map(_.body shouldBe Right("pen pineapple apple pen")) - }, testServer(in_form_out_form)((fa: FruitAmount) => pureResult(fa.copy(fruit = fa.fruit.reverse, amount = fa.amount + 1).asRight[Unit])) { - baseUri => + (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .body(Map("fruit" -> "plum", "amount" -> "10")) @@ -206,7 +199,7 @@ class ServerBasicTests[F[_], ROUTE, B]( }, testServer(in_query_params_out_string)((mqp: QueryParams) => pureResult(mqp.toSeq.sortBy(_._1).map(p => s"${p._1}=${p._2}").mkString("&").asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => val params = Map("name" -> "apple", "weight" -> "42", "kind" -> "very good") basicRequest .get(uri"$baseUri/api/echo/params?$params") @@ -215,81 +208,31 @@ class ServerBasicTests[F[_], ROUTE, B]( }, testServer(in_query_params_out_string, "should support value-less query param")((mqp: QueryParams) => pureResult(mqp.toMultiMap.map(data => s"${data._1}=${data._2}").mkString("&").asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .get(uri"$baseUri/api/echo/params?flag") .send(backend) .map(_.body shouldBe Right("flag=List()")) }, testServer(in_headers_out_headers)((hs: List[Header]) => pureResult(hs.map(h => Header(h.name, h.value.reverse)).asRight[Unit])) { - baseUri => + (backend, baseUri) => basicRequest .get(uri"$baseUri/api/echo/headers") .headers(Header.unsafeApply("X-Fruit", "apple"), Header.unsafeApply("Y-Fruit", "Orange")) .send(backend) .map(_.headers should contain allOf (Header.unsafeApply("X-Fruit", "elppa"), Header.unsafeApply("Y-Fruit", "egnarO"))) }, - testServer(in_paths_out_string)((ps: Seq[String]) => pureResult(ps.mkString(" ").asRight[Unit])) { baseUri => + testServer(in_paths_out_string)((ps: Seq[String]) => pureResult(ps.mkString(" ").asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/hello/it/is/me/hal").send(backend).map(_.body shouldBe Right("hello it is me hal")) }, testServer(in_paths_out_string, "paths should match empty path")((ps: Seq[String]) => pureResult(ps.mkString(" ").asRight[Unit])) { - baseUri => basicRequest.get(uri"$baseUri").send(backend).map(_.body shouldBe Right("")) - }, - testServer(in_simple_multipart_out_multipart)((fa: FruitAmount) => - pureResult(FruitAmount(fa.fruit + " apple", fa.amount * 2).asRight[Unit]) - ) { baseUri => - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody(multipart("fruit", "pineapple"), multipart("amount", "120")) - .send(backend) - .map { r => - r.body should include regex "name=\"fruit\"[\\s\\S]*pineapple apple" - r.body should include regex "name=\"amount\"[\\s\\S]*240" - } + (backend, baseUri) => basicRequest.get(uri"$baseUri").send(backend).map(_.body shouldBe Right("")) }, - testServer(in_file_multipart_out_multipart)((fd: FruitData) => - pureResult( - FruitData( - Part("", writeToFile(Await.result(readFromFile(fd.data.body), 3.seconds).reverse), fd.data.otherDispositionParams, Nil) - .header("X-Auth", fd.data.headers.find(_.is("X-Auth")).map(_.value).toString) - ).asRight[Unit] - ) - ) { baseUri => - val file = writeToFile("peach mario") - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody(multipartFile("data", file).fileName("fruit-data.txt").header("X-Auth", "12Aa")) - .send(backend) - .map { r => - r.code shouldBe StatusCode.Ok - if (multipartInlineHeaderSupport) r.body should include regex "X-Auth: Some\\(12Aa\\)" - r.body should include regex "name=\"data\"[\\s\\S]*oiram hcaep" - } + testServer(in_query_out_string, "invalid query parameter")((fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri?fruit2=orange").send(backend).map(_.code shouldBe StatusCode.BadRequest) }, - testServer(in_raw_multipart_out_string)((parts: Seq[Part[Array[Byte]]]) => - pureResult( - parts.map(part => s"${part.name}:${new String(part.body)}").mkString("\n").asRight[Unit] - ) - ) { baseUri => - val file1 = writeToFile("peach mario") - val file2 = writeToFile("daisy luigi") - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody( - multipartFile("file1", file1).fileName("file1.txt"), - multipartFile("file2", file2).fileName("file2.txt") - ) - .send(backend) - .map { r => - r.code shouldBe StatusCode.Ok - r.body should include("file1:peach mario") - r.body should include("file2:daisy luigi") - } - }, - testServer(in_query_out_string, "invalid query parameter")((fruit: String) => pureResult(s"fruit: $fruit".asRight[Unit])) { baseUri => - basicRequest.get(uri"$baseUri?fruit2=orange").send(backend).map(_.code shouldBe StatusCode.BadRequest) - }, - testServer(in_query_list_out_header_list)((l: List[String]) => pureResult(("v0" :: l).reverse.asRight[Unit])) { baseUri => + testServer(in_query_list_out_header_list)((l: List[String]) => pureResult(("v0" :: l).reverse.asRight[Unit])) { (backend, baseUri) => basicRequest .get(uri"$baseUri/api/echo/param-to-header?qq=${List("v1", "v2", "v3")}") .send(backend) @@ -303,14 +246,14 @@ class ServerBasicTests[F[_], ROUTE, B]( }, testServer(in_cookies_out_cookies)((cs: List[sttp.model.headers.Cookie]) => pureResult(cs.map(c => CookieWithMeta.unsafeApply(c.name, c.value.reverse)).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/echo/headers").cookies(("c1", "v1"), ("c2", "v2")).send(backend).map { r => r.unsafeCookies.map(c => (c.name, c.value)).toList shouldBe List(("c1", "1v"), ("c2", "2v")) } }, testServer(in_set_cookie_value_out_set_cookie_value)((c: CookieValueWithMeta) => pureResult(c.copy(value = c.value.reverse).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/echo/headers").header("Set-Cookie", "c1=xy; HttpOnly; Path=/").send(backend).map { r => r.unsafeCookies.toList shouldBe List( CookieWithMeta.unsafeApply("c1", "yx", None, None, None, Some("/"), secure = false, httpOnly = true) @@ -318,25 +261,25 @@ class ServerBasicTests[F[_], ROUTE, B]( } }, testServer(in_string_out_content_type_string, "dynamic content type")((b: String) => pureResult((b, "image/png").asRight[Unit])) { - baseUri => + (backend, baseUri) => basicStringRequest.get(uri"$baseUri/api/echo").body("test").send(backend).map { r => r.contentType shouldBe Some("image/png") r.body shouldBe "test" } }, - testServer(in_content_type_out_string)((ct: String) => pureResult(ct.asRight[Unit])) { baseUri => + testServer(in_content_type_out_string)((ct: String) => pureResult(ct.asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/echo").contentType("application/dicom+json").send(backend).map { r => r.body shouldBe Right("application/dicom+json") } }, - testServer(in_content_type_fixed_header, "mismatch content-type")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(in_content_type_fixed_header, "mismatch content-type")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .contentType(MediaType.ApplicationXml) .send(backend) .map(_.code shouldBe StatusCode.UnsupportedMediaType) }, - testServer(in_content_type_fixed_header, "missing content-type")((_: Unit) => pureResult(().asRight[Unit])) { baseUri => + testServer(in_content_type_fixed_header, "missing content-type")((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .send(backend) @@ -344,7 +287,7 @@ class ServerBasicTests[F[_], ROUTE, B]( }, testServer(in_content_type_header_with_custom_decode_results, "mismatch content-type")((_: MediaType) => pureResult(Either.right[Unit, Unit](())) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .contentType(MediaType.ApplicationXml) @@ -353,27 +296,27 @@ class ServerBasicTests[F[_], ROUTE, B]( }, testServer(in_content_type_header_with_custom_decode_results, "missing content-type")((_: MediaType) => pureResult(Either.right[Unit, Unit](())) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .send(backend) .map(_.code shouldBe StatusCode.BadRequest) }, - testServer(in_unit_out_html)(_ => pureResult("".asRight[Unit])) { baseUri => + testServer(in_unit_out_html)(_ => pureResult("".asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/echo").send(backend).map { r => r.contentType shouldBe Some("text/html; charset=UTF-8") } }, - testServer(in_unit_out_header_redirect)(_ => pureResult("http://new.com".asRight[Unit])) { baseUri => + testServer(in_unit_out_header_redirect)(_ => pureResult("http://new.com".asRight[Unit])) { (backend, baseUri) => basicRequest.followRedirects(false).get(uri"$baseUri").send(backend).map { r => r.code shouldBe StatusCode.PermanentRedirect r.header("Location") shouldBe Some("http://new.com") } }, - testServer(in_unit_out_fixed_header)(_ => pureResult(().asRight[Unit])) { baseUri => + testServer(in_unit_out_fixed_header)(_ => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri").send(backend).map { r => r.header("Location") shouldBe Some("Poland") } }, - testServer(in_optional_json_out_optional_json)((fa: Option[FruitAmount]) => pureResult(fa.asRight[Unit])) { baseUri => + testServer(in_optional_json_out_optional_json)((fa: Option[FruitAmount]) => pureResult(fa.asRight[Unit])) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .send(backend) @@ -388,30 +331,35 @@ class ServerBasicTests[F[_], ROUTE, B]( .map(_.body shouldBe Right("""{"fruit":"orange","amount":11}""")) }, // path matching - testServer(endpoint, "no path should match anything")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => + testServer(endpoint, "no path should match anything")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) => basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri/").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri/nonemptypath").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri/nonemptypath/nonemptypath2").send(backend).map(_.code shouldBe StatusCode.Ok) }, - testServer(in_root_path, "root path should not match non-root path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => - basicRequest.get(uri"$baseUri/nonemptypath").send(backend).map(_.code shouldBe StatusCode.NotFound) + testServer(in_root_path, "root path should not match non-root path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri/nonemptypath").send(backend).map(_.code shouldBe StatusCode.NotFound) }, - testServer(in_root_path, "root path should match empty path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => - basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.Ok) + testServer(in_root_path, "root path should match empty path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.Ok) }, - testServer(in_root_path, "root path should match root path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => - basicRequest.get(uri"$baseUri/").send(backend).map(_.code shouldBe StatusCode.Ok) + testServer(in_root_path, "root path should match root path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri/").send(backend).map(_.code shouldBe StatusCode.Ok) }, - testServer(in_single_path, "single path should match single path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => - basicRequest.get(uri"$baseUri/api").send(backend).map(_.code shouldBe StatusCode.Ok) + testServer(in_single_path, "single path should match single path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri/api").send(backend).map(_.code shouldBe StatusCode.Ok) }, - testServer(in_single_path, "single path should match single/ path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => - basicRequest.get(uri"$baseUri/api/").send(backend).map(_.code shouldBe StatusCode.Ok) + testServer(in_single_path, "single path should match single/ path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri/api/").send(backend).map(_.code shouldBe StatusCode.Ok) }, testServer(in_path_paths_out_header_body, "Capturing paths after path capture") { case (i, paths) => pureResult(Right((i, paths.mkString(",")))) - } { baseUri => + } { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/15/and/some/more/path").send(backend).map { r => r.code shouldBe StatusCode.Ok r.header("IntPath") shouldBe Some("15") @@ -420,47 +368,49 @@ class ServerBasicTests[F[_], ROUTE, B]( }, testServer(in_path_paths_out_header_body, "Capturing paths after path capture (when empty)") { case (i, paths) => pureResult(Right((i, paths.mkString(",")))) - } { baseUri => + } { (backend, baseUri) => basicRequest.get(uri"$baseUri/api/15/and/").send(backend).map { r => r.code shouldBe StatusCode.Ok r.header("IntPath") shouldBe Some("15") r.body shouldBe Right("") } }, - testServer(in_single_path, "single path should not match root path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { baseUri => - basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.NotFound) >> - basicRequest.get(uri"$baseUri/").send(backend).map(_.code shouldBe StatusCode.NotFound) + testServer(in_single_path, "single path should not match root path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.NotFound) >> + basicRequest.get(uri"$baseUri/").send(backend).map(_.code shouldBe StatusCode.NotFound) }, testServer(in_single_path, "single path should not match larger path")((_: Unit) => pureResult(Either.right[Unit, Unit](()))) { - baseUri => + (backend, baseUri) => basicRequest.get(uri"$baseUri/api/echo/hello").send(backend).map(_.code shouldBe StatusCode.NotFound) >> basicRequest.get(uri"$baseUri/api/echo/").send(backend).map(_.code shouldBe StatusCode.NotFound) }, - testServer(in_string_out_status, "custom status code")((_: String) => pureResult(StatusCode(470).asRight[Unit])) { baseUri => + testServer(in_string_out_status, "custom status code")((_: String) => pureResult(StatusCode(470).asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=apple").send(backend).map(_.code shouldBe StatusCode(470)) }, testServer(in_string_out_status_from_string)((v: String) => pureResult((if (v == "apple") Right("x") else Left(10)).asRight[Unit])) { - baseUri => + (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=apple").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.code shouldBe StatusCode.Accepted) }, - testServer(in_int_out_value_form_exact_match)((num: Int) => pureResult(if (num % 2 == 0) Right("A") else Right("B"))) { baseUri => - basicRequest.get(uri"$baseUri/mapping?num=1").send(backend).map(_.code shouldBe StatusCode.Ok) >> - basicRequest.get(uri"$baseUri/mapping?num=2").send(backend).map(_.code shouldBe StatusCode.Accepted) + testServer(in_int_out_value_form_exact_match)((num: Int) => pureResult(if (num % 2 == 0) Right("A") else Right("B"))) { + (backend, baseUri) => + basicRequest.get(uri"$baseUri/mapping?num=1").send(backend).map(_.code shouldBe StatusCode.Ok) >> + basicRequest.get(uri"$baseUri/mapping?num=2").send(backend).map(_.code shouldBe StatusCode.Accepted) }, testServer(in_string_out_status_from_string_one_empty)((v: String) => pureResult((if (v == "apple") Right("x") else Left(())).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=apple").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.code shouldBe StatusCode.Accepted) }, - testServer(in_extract_request_out_string)((v: String) => pureResult(v.asRight[Unit])) { baseUri => + testServer(in_extract_request_out_string)((v: String) => pureResult(v.asRight[Unit])) { (backend, baseUri) => basicStringRequest.get(uri"$baseUri").send(backend).map(_.body shouldBe "GET") >> basicStringRequest.post(uri"$baseUri").send(backend).map(_.body shouldBe "POST") }, testServer(in_string_out_status)((v: String) => pureResult((if (v == "apple") StatusCode.Accepted else StatusCode.NotFound).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=apple").send(backend).map(_.code shouldBe StatusCode.Accepted) >> basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.code shouldBe StatusCode.NotFound) }, @@ -469,7 +419,7 @@ class ServerBasicTests[F[_], ROUTE, B]( in_path_fixed_capture_fixed_capture, "Returns 400 if path 'shape' matches, but failed to parse a path parameter", Some(decodeFailureHandlerBadRequestOnPathFailure) - )(_ => pureResult(Either.right[Unit, Unit](()))) { baseUri => + )(_ => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) => basicRequest.get(uri"$baseUri/customer/asd/orders/2").send(backend).map { response => response.body shouldBe Left("Invalid value for: path parameter customer_id") response.code shouldBe StatusCode.BadRequest @@ -479,7 +429,7 @@ class ServerBasicTests[F[_], ROUTE, B]( in_path_fixed_capture_fixed_capture, "Returns 404 if path 'shape' doesn't match", Some(decodeFailureHandlerBadRequestOnPathFailure) - )(_ => pureResult(Either.right[Unit, Unit](()))) { baseUri => + )(_ => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) => basicRequest.get(uri"$baseUri/customer").send(backend).map(response => response.code shouldBe StatusCode.NotFound) >> basicRequest.get(uri"$baseUri/customer/asd").send(backend).map(response => response.code shouldBe StatusCode.NotFound) >> basicRequest @@ -488,13 +438,13 @@ class ServerBasicTests[F[_], ROUTE, B]( .map(response => response.code shouldBe StatusCode.NotFound) }, // auth - testServer(in_auth_apikey_header_out_string)((s: String) => pureResult(s.asRight[Unit])) { baseUri => + testServer(in_auth_apikey_header_out_string)((s: String) => pureResult(s.asRight[Unit])) { (backend, baseUri) => basicStringRequest.get(uri"$baseUri/auth").header("X-Api-Key", "1234").send(backend).map(_.body shouldBe "1234") }, - testServer(in_auth_apikey_query_out_string)((s: String) => pureResult(s.asRight[Unit])) { baseUri => + testServer(in_auth_apikey_query_out_string)((s: String) => pureResult(s.asRight[Unit])) { (backend, baseUri) => basicStringRequest.get(uri"$baseUri/auth?api-key=1234").send(backend).map(_.body shouldBe "1234") }, - testServer(in_auth_basic_out_string)((up: UsernamePassword) => pureResult(up.toString.asRight[Unit])) { baseUri => + testServer(in_auth_basic_out_string)((up: UsernamePassword) => pureResult(up.toString.asRight[Unit])) { (backend, baseUri) => basicStringRequest .get(uri"$baseUri/auth") .auth @@ -502,7 +452,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .send(backend) .map(_.body shouldBe "UsernamePassword(teddy,Some(bear))") }, - testServer(in_auth_bearer_out_string)((s: String) => pureResult(s.asRight[Unit])) { baseUri => + testServer(in_auth_bearer_out_string)((s: String) => pureResult(s.asRight[Unit])) { (backend, baseUri) => basicStringRequest.get(uri"$baseUri/auth").auth.bearer("1234").send(backend).map(_.body shouldBe "1234") }, // @@ -512,7 +462,7 @@ class ServerBasicTests[F[_], ROUTE, B]( route(endpoint.get.in("p1").out(stringBody).serverLogic((_: Unit) => pureResult("e1".asRight[Unit]))), route(endpoint.get.in("p1" / "p2").out(stringBody).serverLogic((_: Unit) => pureResult("e2".asRight[Unit]))) ) - ) { baseUri => + ) { (backend, baseUri) => basicStringRequest.get(uri"$baseUri/p1").send(backend).map(_.body shouldBe "e1") >> basicStringRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe "e2") }, @@ -534,7 +484,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .serverLogic((s: Array[Byte]) => pureResult(s"p2 ${s.length}".asRight[Unit])) ) ) - ) { baseUri => + ) { (backend, baseUri) => basicStringRequest .post(uri"$baseUri/p2") .body("a" * 1000000) @@ -547,7 +497,7 @@ class ServerBasicTests[F[_], ROUTE, B]( route(endpoint.get.in(query[String]("q1")).in("p1").serverLogic((_: String) => pureResult(().asRight[Unit]))), route(endpoint.get.in(query[String]("q2")).in("p2").serverLogic((_: String) => pureResult(().asRight[Unit]))) ) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri/p1?q1=10").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri/p1?q2=10").send(backend).map(_.code shouldBe StatusCode.BadRequest) >> basicRequest.get(uri"$baseUri/p2?q2=10").send(backend).map(_.code shouldBe StatusCode.Ok) >> @@ -559,7 +509,7 @@ class ServerBasicTests[F[_], ROUTE, B]( route(endpoint.get.in("p1").in(query[String]("q1")).out(stringBody).serverLogic((_: String) => pureResult("e1".asRight[Unit]))), route(endpoint.get.in("p1" / "p2").out(stringBody).serverLogic((_: Unit) => pureResult("e2".asRight[Unit]))) ) - ) { baseUri => basicStringRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe "e2") }, + ) { (backend, baseUri) => basicStringRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe "e2") }, testServer( "two endpoints with validation: should not try the second one if validation fails", NonEmptyList.of( @@ -568,12 +518,12 @@ class ServerBasicTests[F[_], ROUTE, B]( ), route(endpoint.get.in("p2").serverLogic((_: Unit) => pureResult(().asRight[Unit]))) ) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri/p1/abcde").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri/p1/ab").send(backend).map(_.code shouldBe StatusCode.BadRequest) >> basicRequest.get(uri"$baseUri/p2").send(backend).map(_.code shouldBe StatusCode.Ok) }, - testServer(in_header_out_header_unit_extended)(in => pureResult(in.asRight[Unit])) { baseUri => + testServer(in_header_out_header_unit_extended)(in => pureResult(in.asRight[Unit])) { (backend, baseUri) => basicRequest .get(uri"$baseUri") .header("A", "1") @@ -581,31 +531,31 @@ class ServerBasicTests[F[_], ROUTE, B]( .send(backend) .map(_.headers.map(h => h.name -> h.value).toSet should contain allOf ("Y" -> "3", "B" -> "2")) }, - testServer(in_4query_out_4header_extended)(in => pureResult(in.asRight[Unit])) { baseUri => + testServer(in_4query_out_4header_extended)(in => pureResult(in.asRight[Unit])) { (backend, baseUri) => basicRequest .get(uri"$baseUri?a=1&b=2&x=3&y=4") .send(backend) .map(_.headers.map(h => h.name -> h.value).toSet should contain allOf ("A" -> "1", "B" -> "2", "X" -> "3", "Y" -> "4")) }, - testServer(in_3query_out_3header_mapped_to_tuple)(in => pureResult(in.asRight[Unit])) { baseUri => + testServer(in_3query_out_3header_mapped_to_tuple)(in => pureResult(in.asRight[Unit])) { (backend, baseUri) => basicRequest .get(uri"$baseUri?p1=1&p2=2&p3=3") .send(backend) .map(_.headers.map(h => h.name -> h.value).toSet should contain allOf ("P1" -> "1", "P2" -> "2", "P3" -> "3")) }, - testServer(in_2query_out_2query_mapped_to_unit)(in => pureResult(in.asRight[Unit])) { baseUri => + testServer(in_2query_out_2query_mapped_to_unit)(in => pureResult(in.asRight[Unit])) { (backend, baseUri) => basicRequest .get(uri"$baseUri?p1=1&p2=2") .send(backend) .map(_.headers.map(h => h.name -> h.value).toSet should contain allOf ("P1" -> "DEFAULT_HEADER", "P2" -> "2")) }, - testServer(in_query_with_default_out_string)(in => pureResult(in.asRight[Unit])) { baseUri => + testServer(in_query_with_default_out_string)(in => pureResult(in.asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri?p1=x").send(backend).map(_.body shouldBe Right("x")) >> basicRequest.get(uri"$baseUri").send(backend).map(_.body shouldBe Right("DEFAULT")) }, testServer(out_json_or_default_json)(entityType => pureResult((if (entityType == "person") Person("mary", 20) else Organization("work")).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri/entity/person").send(backend).map { r => r.code shouldBe StatusCode.Created r.body.right.get should include("mary") @@ -616,10 +566,10 @@ class ServerBasicTests[F[_], ROUTE, B]( } }, // - testServer(endpoint, "handle exceptions")(_ => throw new RuntimeException()) { baseUri => + testServer(endpoint, "handle exceptions")(_ => throw new RuntimeException()) { (backend, baseUri) => basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.InternalServerError) }, - testServer(out_json_xml_text_common_schema)(_ => pureResult(Organization("sml").asRight[Unit])) { baseUri => + testServer(out_json_xml_text_common_schema)(_ => pureResult(Organization("sml").asRight[Unit])) { (backend, baseUri) => def ok(body: String) = (StatusCode.Ok, body.asRight[String]) def unsupportedMediaType() = (StatusCode.UnsupportedMediaType, "".asLeft[String]) def badRequest() = (StatusCode.BadRequest, "".asLeft[String]) @@ -662,15 +612,16 @@ class ServerBasicTests[F[_], ROUTE, B]( } }) }, - testServer(in_root_path, testNameSuffix = "accepts header without output body")(_ => pureResult(().asRight[Unit])) { baseUri => - basicRequest.header(HeaderNames.Accept, "text/plain").get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.Ok) + testServer(in_root_path, testNameSuffix = "accepts header without output body")(_ => pureResult(().asRight[Unit])) { + (backend, baseUri) => + basicRequest.header(HeaderNames.Accept, "text/plain").get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.Ok) }, testServer( "recover errors from exceptions", NonEmptyList.of( routeRecoverErrors(endpoint.in(query[String]("name")).errorOut(jsonBody[FruitError]).out(stringBody), throwFruits) ) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?name=apple").send(backend).map(_.body shouldBe Right("ok")) >> basicRequest.get(uri"$baseUri?name=banana").send(backend).map { r => r.code shouldBe StatusCode.BadRequest @@ -682,18 +633,18 @@ class ServerBasicTests[F[_], ROUTE, B]( } }, testServer(Validation.in_query_tagged, "support query validation with tagged type")((_: String) => pureResult(().asRight[Unit])) { - baseUri => + (backend, baseUri) => basicRequest.get(uri"$baseUri?fruit=apple").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri?fruit=orange").send(backend).map(_.code shouldBe StatusCode.BadRequest) >> basicRequest.get(uri"$baseUri?fruit=banana").send(backend).map(_.code shouldBe StatusCode.Ok) }, - testServer(Validation.in_query, "support query validation")((_: Int) => pureResult(().asRight[Unit])) { baseUri => + testServer(Validation.in_query, "support query validation")((_: Int) => pureResult(().asRight[Unit])) { (backend, baseUri) => basicRequest.get(uri"$baseUri?amount=3").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri?amount=-3").send(backend).map(_.code shouldBe StatusCode.BadRequest) }, testServer(Validation.in_valid_json, "support jsonBody validation with wrapped type")((_: ValidFruitAmount) => pureResult(().asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri").body("""{"fruit":"orange","amount":11}""").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest .get(uri"$baseUri") @@ -703,14 +654,14 @@ class ServerBasicTests[F[_], ROUTE, B]( basicRequest.get(uri"$baseUri").body("""{"fruit":"orange","amount":1}""").send(backend).map(_.code shouldBe StatusCode.Ok) }, testServer(Validation.in_valid_query, "support query validation with wrapper type")((_: IntWrapper) => pureResult(().asRight[Unit])) { - baseUri => + (backend, baseUri) => basicRequest.get(uri"$baseUri?amount=11").send(backend).map(_.code shouldBe StatusCode.Ok) >> basicRequest.get(uri"$baseUri?amount=0").send(backend).map(_.code shouldBe StatusCode.BadRequest) >> basicRequest.get(uri"$baseUri?amount=1").send(backend).map(_.code shouldBe StatusCode.Ok) }, testServer(Validation.in_valid_json_collection, "support jsonBody validation with list of wrapped type")((_: BasketOfFruits) => pureResult(().asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .get(uri"$baseUri") .body("""{"fruits":[{"fruit":"orange","amount":11}]}""") @@ -732,7 +683,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .out(plainBody[Int]) .serverLogic { case (x, y) => pureResult((x * y.toInt).asRight[Unit]) }, "partial server logic - current, one part" - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?x=2&y=3").send(backend).map(_.body shouldBe Right("6")) }, testServerLogic( @@ -745,7 +696,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .out(plainBody[Long]) .serverLogic { case ((x, y), z) => pureResult((x * y * z.toLong).asRight[Unit]) }, "partial server logic - current, two parts" - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?x=2&y=3&z=5").send(backend).map(_.body shouldBe Right("30")) }, testServerLogic( @@ -758,7 +709,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .out(plainBody[Int]) .serverLogic { case (xy, (z, u)) => pureResult((xy * z.toInt * u.toInt).asRight[Unit]) }, "partial server logic - current, one part, multiple values" - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?x=2&y=3&z=5&u=7").send(backend).map(_.body shouldBe Right("175")) }, testServerLogic( @@ -769,7 +720,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .serverLogicPart((x: String) => pureResult(x.toInt.asRight[Unit])) .andThen { case (x, y) => pureResult((x * y.toInt).asRight[Unit]) }, "partial server logic - parts, one part" - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?x=2&y=3").send(backend).map(_.body shouldBe Right("6")) }, testServerLogic( @@ -782,7 +733,7 @@ class ServerBasicTests[F[_], ROUTE, B]( .andThenPart((y: String) => pureResult(y.toLong.asRight[Unit])) .andThen { case ((x, y), z) => pureResult((x * y * z.toLong).asRight[Unit]) }, "partial server logic - parts, two parts" - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?x=2&y=3&z=5").send(backend).map(_.body shouldBe Right("30")) }, testServerLogic( @@ -795,37 +746,17 @@ class ServerBasicTests[F[_], ROUTE, B]( .serverLogicPart { t: (String, String) => pureResult((t._1.toInt + t._2.toInt).asRight[Unit]) } .andThen { case (xy, (z, u)) => pureResult((xy * z.toInt * u.toInt).asRight[Unit]) }, "partial server logic - parts, one part, multiple values" - ) { baseUri => + ) { (backend, baseUri) => basicRequest.get(uri"$baseUri?x=2&y=3&z=5&u=7").send(backend).map(_.body shouldBe Right("175")) } ) - def multipartInlineHeaderTests(): List[Test] = List( - testServer(in_file_multipart_out_multipart, "with part content type header")((fd: FruitData) => - pureResult( - FruitData( - Part("", fd.data.body, fd.data.otherDispositionParams, fd.data.headers) - ).asRight[Unit] - ) - ) { baseUri => - val file = writeToFile("peach mario") - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody(multipartFile("data", file).contentType("text/html")) - .send(backend) - .map { r => - r.code shouldBe StatusCode.Ok - r.body.toLowerCase() should include("content-type: text/html") - } - } - ) - def inputStreamTests(): List[Test] = List( testServer(in_input_stream_out_input_stream)((is: InputStream) => pureResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit]) - ) { baseUri => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) }, + ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) }, testServer(in_string_out_stream_with_header)(_ => pureResult(Right((new ByteArrayInputStream(Array.fill[Byte](128)(0)), Some(128))))) { - baseUri => + (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("test string body").response(asByteArray).send(backend).map { r => r.body.map(_.length) shouldBe Right(128) r.body.map(_.foreach(b => b shouldBe 0)) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFileMultipartTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFileMultipartTests.scala new file mode 100644 index 0000000000..faa88fe7c0 --- /dev/null +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFileMultipartTests.scala @@ -0,0 +1,117 @@ +package sttp.tapir.server.tests + +import cats.implicits._ +import org.scalatest.matchers.should.Matchers._ +import sttp.client3.{basicRequest, multipartFile, _} +import sttp.model.{Part, StatusCode} +import sttp.monad.MonadError +import sttp.tapir.tests.TestUtil.{readFromFile, writeToFile} +import sttp.tapir.tests.{ + FruitAmount, + FruitData, + Test, + in_file_multipart_out_multipart, + in_file_out_file, + in_raw_multipart_out_string, + in_simple_multipart_out_multipart +} + +import java.io.File +import scala.concurrent.Await +import scala.concurrent.duration.DurationInt + +class ServerFileMultipartTests[F[_], ROUTE, B]( + createServerTest: CreateServerTest[F, Any, ROUTE, B], + multipartInlineHeaderSupport: Boolean = true +)(implicit m: MonadError[F]) { + import createServerTest._ + + private val basicStringRequest = basicRequest.response(asStringAlways) + private def pureResult[T](t: T): F[T] = m.unit(t) + + def tests(): List[Test] = + basicTests() ++ (if (multipartInlineHeaderSupport) multipartInlineHeaderTests() else Nil) + + def basicTests(): List[Test] = { + List( + testServer(in_file_out_file)((file: File) => pureResult(file.asRight[Unit])) { (backend, baseUri) => + basicRequest + .post(uri"$baseUri/api/echo") + .body("pen pineapple apple pen") + .send(backend) + .map(_.body shouldBe Right("pen pineapple apple pen")) + }, + testServer(in_simple_multipart_out_multipart)((fa: FruitAmount) => + pureResult(FruitAmount(fa.fruit + " apple", fa.amount * 2).asRight[Unit]) + ) { (backend, baseUri) => + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipart("fruit", "pineapple"), multipart("amount", "120")) + .send(backend) + .map { r => + r.body should include regex "name=\"fruit\"[\\s\\S]*pineapple apple" + r.body should include regex "name=\"amount\"[\\s\\S]*240" + } + }, + testServer(in_file_multipart_out_multipart)((fd: FruitData) => + pureResult( + FruitData( + Part("", writeToFile(Await.result(readFromFile(fd.data.body), 3.seconds).reverse), fd.data.otherDispositionParams, Nil) + .header("X-Auth", fd.data.headers.find(_.is("X-Auth")).map(_.value).toString) + ).asRight[Unit] + ) + ) { (backend, baseUri) => + val file = writeToFile("peach mario") + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipartFile("data", file).fileName("fruit-data.txt").header("X-Auth", "12Aa")) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + if (multipartInlineHeaderSupport) r.body should include regex "X-Auth: Some\\(12Aa\\)" + r.body should include regex "name=\"data\"[\\s\\S]*oiram hcaep" + } + }, + testServer(in_raw_multipart_out_string)((parts: Seq[Part[Array[Byte]]]) => + pureResult( + parts.map(part => s"${part.name}:${new String(part.body)}").mkString("\n").asRight[Unit] + ) + ) { (backend, baseUri) => + val file1 = writeToFile("peach mario") + val file2 = writeToFile("daisy luigi") + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody( + multipartFile("file1", file1).fileName("file1.txt"), + multipartFile("file2", file2).fileName("file2.txt") + ) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + r.body should include("file1:peach mario") + r.body should include("file2:daisy luigi") + } + } + ) + } + + def multipartInlineHeaderTests(): List[Test] = List( + testServer(in_file_multipart_out_multipart, "with part content type header")((fd: FruitData) => + pureResult( + FruitData( + Part("", fd.data.body, fd.data.otherDispositionParams, fd.data.headers) + ).asRight[Unit] + ) + ) { (backend, baseUri) => + val file = writeToFile("peach mario") + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipartFile("data", file).contentType("text/html")) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + r.body.toLowerCase() should include("content-type: text/html") + } + } + ) +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala index a5ee77ddbe..be8ad4aa96 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala @@ -1,10 +1,9 @@ package sttp.tapir.server.tests -import cats.effect.IO import cats.implicits._ import org.scalatest.concurrent.Eventually.eventually import org.scalatest.matchers.should.Matchers._ -import sttp.client3.{SttpBackend, _} +import sttp.client3._ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.metrics.{EndpointMetric, Metric} @@ -16,10 +15,7 @@ import sttp.tapir.tests.{Test, _} import java.io.{ByteArrayInputStream, InputStream} import java.util.concurrent.atomic.AtomicInteger -class ServerMetricsTest[F[_], ROUTE, B]( - backend: SttpBackend[IO, Any], - createServerTest: CreateServerTest[F, Any, ROUTE, B] -)(implicit m: MonadError[F]) { +class ServerMetricsTest[F[_], ROUTE, B](createServerTest: CreateServerTest[F, Any, ROUTE, B])(implicit m: MonadError[F]) { import createServerTest._ def tests(): List[Test] = List( @@ -30,7 +26,7 @@ class ServerMetricsTest[F[_], ROUTE, B]( testServer(in_json_out_json.name("metrics"), metricsInterceptor = metrics.some)(f => (if (f.fruit == "apple") Right(f) else Left(())).unit - ) { baseUri => + ) { (backend, baseUri) => basicRequest // onDecodeSuccess path .post(uri"$baseUri/api/echo") .body("""{"fruit":"apple","amount":1}""") @@ -56,7 +52,7 @@ class ServerMetricsTest[F[_], ROUTE, B]( testServer(in_input_stream_out_input_stream.name("metrics"), metricsInterceptor = metrics.some)(is => (new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit].unit - ) { baseUri => + ) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .body("okoń") @@ -72,7 +68,7 @@ class ServerMetricsTest[F[_], ROUTE, B]( val resCounter = newResponseCounter[F] val metrics = new MetricsRequestInterceptor[F, B](List(resCounter), Seq.empty) - testServer(in_root_path.name("metrics"), metricsInterceptor = metrics.some)(_ => ().asRight[Unit].unit) { baseUri => + testServer(in_root_path.name("metrics"), metricsInterceptor = metrics.some)(_ => ().asRight[Unit].unit) { (backend, baseUri) => basicRequest .get(uri"$baseUri") .send(backend) @@ -91,7 +87,7 @@ class ServerMetricsTest[F[_], ROUTE, B]( testServer(in_root_path.name("metrics on exception"), metricsInterceptor = metrics.some) { _ => Thread.sleep(100) throw new RuntimeException("Ups") - } { baseUri => + } { (backend, baseUri) => basicRequest .get(uri"$baseUri") .send(backend) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 0685e23632..65f1413be7 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -1,42 +1,45 @@ package sttp.tapir.server.tests -import cats.effect.IO -import sttp.capabilities.Streams -import sttp.client3._ -import sttp.tapir.tests.{Test, in_stream_out_stream, in_stream_out_stream_with_content_length} import cats.syntax.all._ -import sttp.monad.MonadError import org.scalatest.matchers.should.Matchers._ +import sttp.capabilities.Streams +import sttp.client3._ import sttp.model.{Header, HeaderNames} +import sttp.monad.MonadError +import sttp.tapir.tests.{Test, in_stream_out_stream, in_stream_out_stream_with_content_length} -class ServerStreamingTests[F[_], S, ROUTE, B](backend: SttpBackend[IO, Any], serverTests: CreateServerTest[F, S, ROUTE, B], streams: Streams[S])( - implicit m: MonadError[F] +class ServerStreamingTests[F[_], S, ROUTE, B](createServerTest: CreateServerTest[F, S, ROUTE, B], streams: Streams[S])(implicit + m: MonadError[F] ) { private def pureResult[T](t: T): F[T] = m.unit(t) def tests(): List[Test] = { - import serverTests._ + import createServerTest._ val penPineapple = "pen pineapple apple pen" List( - testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { baseUri => + testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body(penPineapple).send(backend).map(_.body shouldBe Right(penPineapple)) }, testServer( in_stream_out_stream_with_content_length(streams) - )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { baseUri => + )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => { - basicRequest.post(uri"$baseUri/api/echo").contentLength(penPineapple.length.toLong).body(penPineapple).send(backend).map { - response => + basicRequest + .post(uri"$baseUri/api/echo") + .contentLength(penPineapple.length.toLong) + .body(penPineapple) + .send(backend) + .map { response => response.body shouldBe Right(penPineapple) if (response.headers.contains(Header(HeaderNames.TransferEncoding, "chunked"))) { response.contentLength shouldBe None } else { response.contentLength shouldBe Some(penPineapple.length) } - } + } } } ) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 12ca4ca590..591bf92106 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -4,7 +4,6 @@ import cats.effect.IO import cats.syntax.all._ import io.circe.generic.auto._ import org.scalatest.matchers.should.Matchers._ -import sttp.capabilities.fs2.Fs2Streams import sttp.capabilities.{Streams, WebSockets} import sttp.client3._ import sttp.monad.MonadError @@ -17,7 +16,6 @@ import sttp.tapir.tests.{Fruit, Test} import sttp.ws.{WebSocket, WebSocketFrame} abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE, B]( - backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets], createServerTest: CreateServerTest[F, S with WebSockets, ROUTE, B], val streams: S )(implicit @@ -35,7 +33,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE, B]( testServer( endpoint.out(stringWs), "string client-terminated echo" - )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { baseUri => + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => basicRequest .response(asWebSocket { ws: WebSocket[IO] => for { @@ -56,7 +54,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE, B]( testServer(endpoint.out(stringWs).name("metrics"), metricsInterceptor = metrics.some)((_: Unit) => pureResult(stringEcho.asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .response(asWebSocket { ws: WebSocket[IO] => for { @@ -75,7 +73,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE, B]( }, testServer(endpoint.out(webSocketBody[Fruit, CodecFormat.Json, Fruit, CodecFormat.Json](streams)), "json client-terminated echo")( (_: Unit) => pureResult(functionToPipe((f: Fruit) => Fruit(s"echo: ${f.f}")).asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .response(asWebSocket { ws: WebSocket[IO] => for { @@ -97,7 +95,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE, B]( case "end" => None case msg => Some(s"echo: $msg") }.asRight[Unit]) - ) { baseUri => + ) { (backend, baseUri) => basicRequest .response(asWebSocket { ws: WebSocket[IO] => for { @@ -123,7 +121,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE, B]( .errorOut(stringBody) .out(stringWs), "non web-socket request" - )(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { baseUri => + )(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { (backend, baseUri) => basicRequest .response(asString) .get(baseUri.scheme("http")) diff --git a/server/vertx/src/test/scala/sttp/tapir/server/vertx/CatsVertxServerTest.scala b/server/vertx/src/test/scala/sttp/tapir/server/vertx/CatsVertxServerTest.scala index d215029ab9..ae90da4577 100644 --- a/server/vertx/src/test/scala/sttp/tapir/server/vertx/CatsVertxServerTest.scala +++ b/server/vertx/src/test/scala/sttp/tapir/server/vertx/CatsVertxServerTest.scala @@ -4,7 +4,7 @@ import cats.effect.{IO, Resource} import io.vertx.core.Vertx import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError -import sttp.tapir.server.tests.{ServerAuthenticationTests, ServerBasicTests, ServerStreamingTests, CreateServerTest, backendResource} +import sttp.tapir.server.tests.{DefaultCreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerFileMultipartTests, ServerStreamingTests, backendResource} import sttp.tapir.tests.{Test, TestSuite} class CatsVertxServerTest extends TestSuite { @@ -17,16 +17,15 @@ class CatsVertxServerTest extends TestSuite { vertxResource.map { implicit vertx => implicit val m: MonadError[IO] = VertxCatsServerInterpreter.monadError[IO] val interpreter = new CatsVertxTestServerInterpreter(vertx) - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new ServerBasicTests( - backend, - createServerTest, - interpreter, - multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong - ).tests() ++ - new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerStreamingTests(backend, createServerTest, Fs2Streams.apply[IO]).tests() + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests( + createServerTest, + multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong + ).tests() + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerStreamingTests(createServerTest, Fs2Streams.apply[IO]).tests() } } } diff --git a/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxBlockingServerTest.scala b/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxBlockingServerTest.scala index fc4e68742b..a42a88e556 100644 --- a/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxBlockingServerTest.scala +++ b/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxBlockingServerTest.scala @@ -3,7 +3,7 @@ package sttp.tapir.server.vertx import cats.effect.{IO, Resource} import io.vertx.core.Vertx import sttp.monad.FutureMonad -import sttp.tapir.server.tests.{ServerAuthenticationTests, ServerBasicTests, CreateServerTest, backendResource} +import sttp.tapir.server.tests.{DefaultCreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerFileMultipartTests, backendResource} import sttp.tapir.tests.{Test, TestSuite} import scala.concurrent.ExecutionContext @@ -16,14 +16,14 @@ class VertxBlockingServerTest extends TestSuite { vertxResource.map { implicit vertx => implicit val m: FutureMonad = new FutureMonad()(ExecutionContext.global) val interpreter = new VertxTestServerBlockingInterpreter(vertx) - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new ServerBasicTests( - backend, - createServerTest, - interpreter, - multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong - ).tests() ++ new ServerAuthenticationTests(backend, createServerTest).tests() + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests( + createServerTest, + multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong + ).tests() ++ + new ServerAuthenticationTests(createServerTest).tests() } } } diff --git a/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index e290a152b1..23e135bc3e 100644 --- a/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -3,7 +3,14 @@ package sttp.tapir.server.vertx import cats.effect.{IO, Resource} import io.vertx.core.Vertx import sttp.monad.FutureMonad -import sttp.tapir.server.tests.{CreateServerTest, ServerAuthenticationTests, ServerBasicTests, ServerMetricsTest, backendResource} +import sttp.tapir.server.tests.{ + DefaultCreateServerTest, + ServerAuthenticationTests, + ServerBasicTests, + ServerFileMultipartTests, + ServerMetricsTest, + backendResource +} import sttp.tapir.tests.{Test, TestSuite} import scala.concurrent.ExecutionContext @@ -17,15 +24,15 @@ class VertxServerTest extends TestSuite { implicit val m: FutureMonad = new FutureMonad()(ExecutionContext.global) val interpreter = new VertxTestServerInterpreter(vertx) - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new ServerBasicTests( - backend, - createServerTest, - interpreter, - multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong - ).tests() ++ new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerMetricsTest(backend, createServerTest).tests() + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests( + createServerTest, + multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong + ).tests() ++ + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerMetricsTest(createServerTest).tests() } } } diff --git a/server/vertx/src/test/scala/sttp/tapir/server/vertx/ZioVertxServerTest.scala b/server/vertx/src/test/scala/sttp/tapir/server/vertx/ZioVertxServerTest.scala index d479439ba5..71ba10c074 100644 --- a/server/vertx/src/test/scala/sttp/tapir/server/vertx/ZioVertxServerTest.scala +++ b/server/vertx/src/test/scala/sttp/tapir/server/vertx/ZioVertxServerTest.scala @@ -2,12 +2,20 @@ package sttp.tapir.server.vertx import cats.effect.{IO, Resource} import io.vertx.core.Vertx +import io.vertx.ext.web.{Route, Router, RoutingContext} import sttp.capabilities.zio.ZioStreams import sttp.monad.MonadError -import sttp.tapir.server.tests.{ServerAuthenticationTests, ServerBasicTests, ServerStreamingTests, CreateServerTest, backendResource} +import sttp.tapir.server.tests.{ + DefaultCreateServerTest, + ServerAuthenticationTests, + ServerBasicTests, + ServerFileMultipartTests, + ServerStreamingTests, + backendResource +} import sttp.tapir.tests.{Test, TestSuite} -import zio.interop.catz._ import zio.Task +import zio.interop.catz._ class ZioVertxServerTest extends TestSuite { import VertxZioServerInterpreter._ @@ -20,16 +28,16 @@ class ZioVertxServerTest extends TestSuite { vertxResource.map { implicit vertx => implicit val m: MonadError[Task] = VertxZioServerInterpreter.monadError val interpreter = new ZioVertxTestServerInterpreter(vertx) - val createServerTest = new CreateServerTest(interpreter) + val createServerTest = + new DefaultCreateServerTest(backend, interpreter).asInstanceOf[DefaultCreateServerTest[Task, ZioStreams, Router => Route, RoutingContext => Unit]] - new ServerBasicTests( - backend, - createServerTest, - interpreter, - multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong - ).tests() ++ - new ServerAuthenticationTests(backend, createServerTest).tests() ++ - new ServerStreamingTests(backend, createServerTest, ZioStreams).tests() + new ServerBasicTests(createServerTest, interpreter).tests() ++ + new ServerFileMultipartTests( + createServerTest, + multipartInlineHeaderSupport = false // README: doesn't seem supported but I may be wrong + ).tests() ++ + new ServerAuthenticationTests(createServerTest).tests() ++ + new ServerStreamingTests(createServerTest, ZioStreams).tests() } } } diff --git a/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/LambdaApiExample.scala b/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/LambdaApiExample.scala new file mode 100644 index 0000000000..a165cb25d4 --- /dev/null +++ b/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/LambdaApiExample.scala @@ -0,0 +1,46 @@ +package sttp.tapir.serverless.aws.examples + +import cats.effect.IO +import cats.syntax.all._ +import com.amazonaws.services.lambda.runtime.{Context, RequestStreamHandler} +import io.circe.Printer +import io.circe.generic.auto._ +import io.circe.parser.decode +import io.circe.syntax._ +import sttp.model.StatusCode +import sttp.tapir._ +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.serverless.aws.lambda._ + +import java.io.{BufferedWriter, InputStream, OutputStream, OutputStreamWriter} +import java.nio.charset.StandardCharsets.UTF_8 + +object LambdaApiExample extends RequestStreamHandler { + + val helloEndpoint: ServerEndpoint[Unit, Unit, String, Any, IO] = endpoint.get + .in("api" / "hello") + .out(stringBody) + .serverLogic { _ => IO.pure(s"Hello!".asRight[Unit]) } + + implicit val options: AwsServerOptions[IO] = AwsServerOptions.customInterceptors(encodeResponseBody = false) + + val route: Route[IO] = AwsCatsEffectServerInterpreter.toRoute(helloEndpoint) + + override def handleRequest(input: InputStream, output: OutputStream, context: Context): Unit = { + + /** Read input as string */ + val json = new String(input.readAllBytes(), UTF_8) + + /** Decode input to `AwsRequest` which is send by API Gateway */ + (decode[AwsRequest](json) match { + /** Process request using interpreted route */ + case Right(awsRequest) => route(awsRequest) + case Left(ex) => IO.pure(AwsResponse(Nil, isBase64Encoded = false, StatusCode.BadRequest.code, Map.empty, ex.getMessage)) + }).map { awsRes => + /** Write response to output */ + val writer = new BufferedWriter(new OutputStreamWriter(output, UTF_8)) + writer.write(Printer.noSpaces.print(awsRes.asJson)) + writer.flush() + }.unsafeRunSync() + } +} diff --git a/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/SamTemplateExample.scala b/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/SamTemplateExample.scala new file mode 100644 index 0000000000..9e800131ee --- /dev/null +++ b/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/SamTemplateExample.scala @@ -0,0 +1,30 @@ +package sttp.tapir.serverless.aws.examples + +import sttp.tapir.serverless.aws.examples.LambdaApiExample.helloEndpoint +import sttp.tapir.serverless.aws.sam.{AwsSamInterpreter, AwsSamOptions, CodeSource} + +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Files, Paths} + +/** Before running the actual example we need to interpret our api as SAM template */ +object SamTemplateExample extends App { + + val jarPath = Paths.get("serverless/aws/examples/target/jvm-2.13/tapir-aws-examples.jar").toAbsolutePath.toString + + implicit val samOptions: AwsSamOptions = AwsSamOptions( + "PersonsApi", + source = + /** Specifying a fat jar build from example sources */ + CodeSource( + runtime = "java11", + codeUri = jarPath, + handler = "sttp.tapir.serverless.aws.examples.LambdaApiExample::handleRequest" + ) + ) + + val templateYaml = AwsSamInterpreter.toSamTemplate(helloEndpoint).toYaml + + /** Write template to file, it's required to run the example using sam local */ + Files.write(Paths.get("template.yaml"), templateYaml.getBytes(UTF_8)) + +} diff --git a/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/TerraformConfigExample.scala b/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/TerraformConfigExample.scala new file mode 100644 index 0000000000..ab00b2082f --- /dev/null +++ b/serverless/aws/examples/src/main/scala/sttp/tapir/serverless/aws/examples/TerraformConfigExample.scala @@ -0,0 +1,37 @@ +package sttp.tapir.serverless.aws.examples + +import io.circe.Printer +import io.circe.syntax._ +import sttp.tapir.serverless.aws.examples.LambdaApiExample.helloEndpoint +import sttp.tapir.serverless.aws.terraform.AwsTerraformEncoders._ +import sttp.tapir.serverless.aws.terraform.{AwsTerraformApiGateway, AwsTerraformInterpreter, AwsTerraformOptions, S3Source} + +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Files, Paths} +import scala.concurrent.duration.DurationInt + +/** Before running the actual example we need to interpret our api as Terraform resources */ +object TerraformConfigExample extends App { + + if (args.length != 3) sys.error("Usage: [aws region] [s3 bucket] [s3 key]") + + val region = args(0) + val bucket = args(1) + val key = args(2) + + implicit val terraformOptions: AwsTerraformOptions = AwsTerraformOptions( + region, + functionName = "PersonsFunction", + apiGatewayName = "PersonsApiGateway", + autoDeploy = true, + functionSource = S3Source(bucket, key, "java11", "sttp.tapir.serverless.aws.examples.LambdaApiExample::handleRequest"), + timeout = 30.seconds, + memorySize = 1024 + ) + + val apiGateway: AwsTerraformApiGateway = AwsTerraformInterpreter.toTerraformConfig(helloEndpoint) + + val apiGatewayConfig = Printer.spaces2.print(apiGateway.asJson) + + Files.write(Paths.get("api_gateway.tf.json"), apiGatewayConfig.getBytes(UTF_8)) +} diff --git a/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/LambdaHandler.scala b/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/LambdaHandler.scala new file mode 100644 index 0000000000..a714e16103 --- /dev/null +++ b/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/LambdaHandler.scala @@ -0,0 +1,32 @@ +package sttp.tapir.serverless.aws.lambda.tests + +import cats.effect.IO +import com.amazonaws.services.lambda.runtime.{Context, RequestStreamHandler} +import io.circe.Printer +import io.circe.generic.auto._ +import io.circe.parser.decode +import io.circe.syntax._ +import sttp.model.StatusCode +import sttp.tapir.serverless.aws.lambda._ + +import java.io.{BufferedWriter, InputStream, OutputStream, OutputStreamWriter} +import java.nio.charset.StandardCharsets + +object LambdaHandler extends RequestStreamHandler { + override def handleRequest(input: InputStream, output: OutputStream, context: Context): Unit = { + + implicit val options: AwsServerOptions[IO] = AwsServerOptions.customInterceptors(encodeResponseBody = false) + + val route: Route[IO] = AwsCatsEffectServerInterpreter.toRoute(allEndpoints.toList) + val json = new String(input.readAllBytes(), StandardCharsets.UTF_8) + + (decode[AwsRequest](json) match { + case Right(awsRequest) => route(awsRequest) + case Left(_) => IO.pure(AwsResponse(Nil, isBase64Encoded = false, StatusCode.BadRequest.code, Map.empty, "")) + }).map { awsRes => + val writer = new BufferedWriter(new OutputStreamWriter(output, StandardCharsets.UTF_8)) + writer.write(Printer.noSpaces.print(awsRes.asJson)) + writer.flush() + }.unsafeRunSync() + } +} diff --git a/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/LambdaSamTemplate.scala b/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/LambdaSamTemplate.scala new file mode 100644 index 0000000000..61da8e63e8 --- /dev/null +++ b/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/LambdaSamTemplate.scala @@ -0,0 +1,23 @@ +package sttp.tapir.serverless.aws.lambda.tests + +import sttp.tapir.serverless.aws.sam._ + +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Files, Paths} + +object LambdaSamTemplate extends App { + + val jarPath = Paths.get("serverless/aws/lambda-tests/target/jvm-2.13/tapir-aws-lambda-tests.jar").toAbsolutePath.toString + + implicit val samOptions: AwsSamOptions = AwsSamOptions( + "Tests", + source = CodeSource( + "java11", + jarPath, + "sttp.tapir.serverless.aws.lambda.tests.LambdaHandler::handleRequest" + ), + memorySize = 1024 + ) + val yaml = AwsSamInterpreter.toSamTemplate(allEndpoints.map(_.endpoint).toList).toYaml + Files.write(Paths.get("template.yaml"), yaml.getBytes(UTF_8)) +} diff --git a/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/package.scala b/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/package.scala new file mode 100644 index 0000000000..77cd5a631b --- /dev/null +++ b/serverless/aws/lambda-tests/src/main/scala/sttp/tapir/serverless/aws/lambda/tests/package.scala @@ -0,0 +1,43 @@ +package sttp.tapir.serverless.aws.lambda + +import cats.effect.IO +import cats.implicits._ +import com.softwaremill.macwire.wireSet +import sttp.model.Header +import sttp.tapir._ +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.tests.TestUtil.inputStreamToByteArray +import sttp.tapir.tests._ + +import java.io.{ByteArrayInputStream, InputStream} + +package object tests { + + // this endpoint is used to wait until sam local starts up before running actual tests + val health_endpoint: ServerEndpoint[Unit, Unit, Unit, Any, IO] = endpoint.get.in("health").serverLogic(_ => IO.pure(().asRight[Unit])) + + val in_path_path_out_string_endpoint: ServerEndpoint[(String, Port), Unit, String, Any, IO] = in_path_path_out_string.serverLogic { + case (fruit: String, amount: Int) => IO.pure(s"$fruit $amount".asRight[Unit]) + } + + val in_string_out_string_endpoint: ServerEndpoint[String, Unit, String, Any, IO] = + in_string_out_string.in("string").serverLogic(s => IO.pure(s.asRight[Unit])) + + val in_json_out_json_endpoint: ServerEndpoint[FruitAmount, Unit, FruitAmount, Any, IO] = + in_json_out_json.in("json").serverLogic(fa => IO.pure(fa.asRight[Unit])) + + val in_headers_out_headers_endpoint: ServerEndpoint[List[Header], Unit, List[Header], Any, IO] = in_headers_out_headers.serverLogic { + headers => IO.pure(headers.asRight[Unit]) + } + + val in_input_stream_out_input_stream_endpoint: ServerEndpoint[InputStream, Unit, InputStream, Any, IO] = + in_input_stream_out_input_stream.in("is").serverLogic { is => + IO.pure((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit]) + } + + val in_4query_out_4header_extended_endpoint + : ServerEndpoint[((String, String), String, String), Unit, ((String, String), String, String), Any, IO] = + in_4query_out_4header_extended.in("echo" / "query").serverLogic { in => IO.pure(in.asRight[Unit]) } + + val allEndpoints: Set[ServerEndpoint[_, _, _, Any, IO]] = wireSet[ServerEndpoint[_, _, _, Any, IO]] +} diff --git a/serverless/aws/lambda-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaSamLocalHttpTest.scala b/serverless/aws/lambda-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaSamLocalHttpTest.scala new file mode 100644 index 0000000000..adde9a6cf5 --- /dev/null +++ b/serverless/aws/lambda-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaSamLocalHttpTest.scala @@ -0,0 +1,70 @@ +package sttp.tapir.serverless.aws.lambda.tests + +import cats.effect.IO +import org.scalatest.Assertions +import org.scalatest.compatible.Assertion +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers._ +import sttp.capabilities.WebSockets +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3.{basicRequest, _} +import sttp.model.{Header, Uri} +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.tests.backendResource + +/** Requires running sam-local process with template generated by `LambdaSamTemplate`, + * it's automated in sbt test task but requires sam cli installed. + */ +class AwsLambdaSamLocalHttpTest extends AnyFunSuite { + + private val baseUri: Uri = uri"http://localhost:3000" + + testServer(in_path_path_out_string_endpoint) { backend => + basicRequest.get(uri"$baseUri/fruit/orange/amount/20").send(backend).map { req => + req.body + .map(_ shouldBe "orange 20") + .getOrElse(Assertions.fail()) + } + } + + testServer(in_string_out_string_endpoint) { backend => + basicRequest.post(uri"$baseUri/api/echo/string").body("Sweet").send(backend).map { req => + req.body.map(_ shouldBe "Sweet").getOrElse(Assertions.fail()) + } + } + + testServer(in_json_out_json_endpoint) { backend => + basicRequest + .post(uri"$baseUri/api/echo/json") + .body("""{"fruit":"orange","amount":11}""") + .send(backend) + .map { req => + req.body + .map(_ shouldBe """{"fruit":"orange","amount":11}""") + .getOrElse(Assertions.fail()) + } + } + + testServer(in_headers_out_headers_endpoint) { backend => + basicRequest + .get(uri"$baseUri/api/echo/headers") + .headers(Header.unsafeApply("X-Fruit", "apple"), Header.unsafeApply("Y-Fruit", "Orange")) + .send(backend) + .map(_.headers should contain allOf (Header.unsafeApply("X-Fruit", "apple"), Header.unsafeApply("Y-Fruit", "Orange"))) + } + + testServer(in_input_stream_out_input_stream_endpoint) { backend => + basicRequest.post(uri"$baseUri/api/echo/is").body("mango").send(backend).map(_.body shouldBe Right("mango")) + } + + testServer(in_4query_out_4header_extended_endpoint) { backend => + basicRequest + .get(uri"$baseUri/echo/query?a=1&b=2&x=3&y=4") + .send(backend) + .map(_.headers.map(h => h.name -> h.value).toSet should contain allOf ("A" -> "1", "B" -> "2", "X" -> "3", "Y" -> "4")) + } + + private def testServer(t: ServerEndpoint[_, _, _, Any, IO], suffix: String = "")( + f: SttpBackend[IO, Fs2Streams[IO] with WebSockets] => IO[Assertion] + ): Unit = test(s"${t.endpoint.showDetail} $suffix")(backendResource.use(f(_)).unsafeRunSync()) +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsBodyListener.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsBodyListener.scala new file mode 100644 index 0000000000..2ab0bf64bb --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsBodyListener.scala @@ -0,0 +1,11 @@ +package sttp.tapir.serverless.aws.lambda + +import sttp.monad.MonadError +import sttp.monad.syntax._ +import sttp.tapir.server.interpreter.BodyListener + +import scala.util.{Success, Try} + +private[lambda] class AwsBodyListener[F[_]: MonadError] extends BodyListener[F, String] { + override def onComplete(body: String)(cb: Try[Unit] => F[Unit]): F[String] = cb(Success(())).map(_ => body) +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsCatsEffectServerInterpreter.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsCatsEffectServerInterpreter.scala new file mode 100644 index 0000000000..aa2e61429b --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsCatsEffectServerInterpreter.scala @@ -0,0 +1,52 @@ +package sttp.tapir.serverless.aws.lambda + +import cats.effect.Sync +import sttp.model.StatusCode +import sttp.monad.syntax._ +import sttp.tapir.Endpoint +import sttp.tapir.integ.cats.CatsMonadError +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.interpreter.{BodyListener, ServerInterpreter} + +import scala.reflect.ClassTag + +trait AwsCatsEffectServerInterpreter { + def toRoute[I, E, O, F[_]](e: Endpoint[I, E, O, Any])( + logic: I => F[Either[E, O]] + )(implicit serverOptions: AwsServerOptions[F], sync: Sync[F]): Route[F] = toRoute(e.serverLogic(logic)) + + def toRoute[I, E, O, F[_]](se: ServerEndpoint[I, E, O, Any, F])(implicit serverOptions: AwsServerOptions[F], sync: Sync[F]): Route[F] = + toRoute(List(se)) + + def toRouteRecoverErrors[I, E, O, F[_]](e: Endpoint[I, E, O, Any])( + logic: I => F[O] + )(implicit eIsThrowable: E <:< Throwable, eClassTag: ClassTag[E], serverOptions: AwsServerOptions[F], sync: Sync[F]): Route[F] = + toRoute(e.serverLogicRecoverErrors(logic)) + + def toRoute[F[_]](ses: List[ServerEndpoint[_, _, _, Any, F]])(implicit serverOptions: AwsServerOptions[F], sync: Sync[F]): Route[F] = { + implicit val monad: CatsMonadError[F] = new CatsMonadError[F] + implicit val bodyListener: BodyListener[F, String] = new AwsBodyListener[F] + + { request: AwsRequest => + implicit val monad: CatsMonadError[F] = new CatsMonadError[F] + implicit val bodyListener: BodyListener[F, String] = new AwsBodyListener[F] + val serverRequest = new AwsServerRequest(request) + val interpreter = new ServerInterpreter[Any, F, String, Nothing]( + new AwsRequestBody[F](request), + new AwsToResponseBody, + serverOptions.interceptors, + deleteFile = _ => ().unit // no file support + ) + + interpreter.apply(serverRequest, ses).map { + case None => AwsResponse(Nil, isBase64Encoded = serverOptions.encodeResponseBody, StatusCode.NotFound.code, Map.empty, "") + case Some(res) => + val cookies = res.cookies.collect { case Right(cookie) => cookie.value }.toList + val headers = res.headers.groupBy(_.name).map { case (n, v) => n -> v.map(_.value).mkString(",") } + AwsResponse(cookies, isBase64Encoded = serverOptions.encodeResponseBody, res.code.code, headers, res.body.getOrElse("")) + } + } + } +} + +object AwsCatsEffectServerInterpreter extends AwsCatsEffectServerInterpreter diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala new file mode 100644 index 0000000000..23c5bfec05 --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala @@ -0,0 +1,34 @@ +package sttp.tapir.serverless.aws.lambda + +import sttp.capabilities +import sttp.monad.MonadError +import sttp.monad.syntax._ +import sttp.tapir.RawBodyType +import sttp.tapir.internal.NoStreams +import sttp.tapir.server.interpreter.{RawValue, RequestBody} + +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer +import java.util.Base64 + +private[lambda] class AwsRequestBody[F[_]: MonadError](request: AwsRequest) extends RequestBody[F, Nothing] { + override val streams: capabilities.Streams[Nothing] = NoStreams + + override def toRaw[R](bodyType: RawBodyType[R]): F[RawValue[R]] = { + val decoded = + if (request.isBase64Encoded) Left(Base64.getDecoder.decode(request.body.getOrElse(""))) else Right(request.body.getOrElse("")) + + def asByteArray: Array[Byte] = decoded.fold(identity[Array[Byte]], _.getBytes()) + + RawValue(bodyType match { + case RawBodyType.StringBody(charset) => decoded.fold(new String(_, charset), identity[String]) + case RawBodyType.ByteArrayBody => asByteArray + case RawBodyType.ByteBufferBody => ByteBuffer.wrap(asByteArray) + case RawBodyType.InputStreamBody => new ByteArrayInputStream(asByteArray) + case RawBodyType.FileBody => throw new UnsupportedOperationException + case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException + }).asInstanceOf[RawValue[R]].unit + } + + override def toStream(): streams.BinaryStream = throw new UnsupportedOperationException +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsServerOptions.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsServerOptions.scala new file mode 100644 index 0000000000..aa6c0a7573 --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsServerOptions.scala @@ -0,0 +1,35 @@ +package sttp.tapir.serverless.aws.lambda + +import sttp.tapir.server.interceptor.Interceptor +import sttp.tapir.server.interceptor.content.UnsupportedMediaTypeInterceptor +import sttp.tapir.server.interceptor.decodefailure.{DecodeFailureHandler, DecodeFailureInterceptor, DefaultDecodeFailureHandler} +import sttp.tapir.server.interceptor.exception.{DefaultExceptionHandler, ExceptionHandler, ExceptionInterceptor} +import sttp.tapir.server.interceptor.log.ServerLogInterceptor +import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor + +case class AwsServerOptions[F[_]](encodeResponseBody: Boolean = true, interceptors: List[Interceptor[F, String]]) { + def prependInterceptor(i: Interceptor[F, String]): AwsServerOptions[F] = copy(interceptors = i :: interceptors) + def appendInterceptor(i: Interceptor[F, String]): AwsServerOptions[F] = copy(interceptors = interceptors :+ i) +} + +object AwsServerOptions { + def customInterceptors[F[_], T]( + encodeResponseBody: Boolean = true, + metricsInterceptor: Option[MetricsRequestInterceptor[F, String]] = None, + exceptionHandler: Option[ExceptionHandler] = Some(DefaultExceptionHandler), + serverLogInterceptor: Option[ServerLogInterceptor[T, F, String]] = None, + additionalInterceptors: List[Interceptor[F, String]] = Nil, + unsupportedMediaTypeInterceptor: Option[UnsupportedMediaTypeInterceptor[F, String]] = Some( + new UnsupportedMediaTypeInterceptor[F, String]() + ), + decodeFailureHandler: DecodeFailureHandler = DefaultDecodeFailureHandler.handler + ): AwsServerOptions[F] = AwsServerOptions( + encodeResponseBody, + interceptors = metricsInterceptor.toList ++ + exceptionHandler.map(new ExceptionInterceptor[F, String](_)).toList ++ + serverLogInterceptor.toList ++ + additionalInterceptors ++ + unsupportedMediaTypeInterceptor.toList ++ + List(new DecodeFailureInterceptor[F, String](decodeFailureHandler)) + ) +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsServerRequest.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsServerRequest.scala new file mode 100644 index 0000000000..06553facf1 --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsServerRequest.scala @@ -0,0 +1,29 @@ +package sttp.tapir.serverless.aws.lambda + +import sttp.model.{Header, Method, QueryParams, Uri} +import sttp.tapir.model.{ConnectionInfo, ServerRequest} + +import java.net.{InetSocketAddress, URLDecoder} +import scala.collection.immutable.Seq + +private[lambda] class AwsServerRequest(request: AwsRequest) extends ServerRequest { + private val sttpUri: Uri = { + val queryString = if (request.rawQueryString.nonEmpty) "?" + request.rawQueryString else "" + Uri.unsafeParse(s"$protocol://${request.requestContext.domainName.getOrElse("")}${request.rawPath}$queryString") + } + + override def protocol: String = request.headers.getOrElse("x-forwarded-proto", "http") + override def connectionInfo: ConnectionInfo = + ConnectionInfo(None, Some(InetSocketAddress.createUnresolved(request.requestContext.http.sourceIp, 80)), None) + override def underlying: Any = request + override def pathSegments: List[String] = { + request.rawPath.dropWhile(_ == '/').split("/").toList.map(value => URLDecoder.decode(value, "UTF-8")) + } + override def queryParameters: QueryParams = sttpUri.params + override def method: Method = Method.unsafeApply(request.requestContext.http.method) + override def uri: Uri = sttpUri + override def headers: Seq[Header] = request.headers + .map { case (n, v) => Header(n, v) } + .toSeq + .asInstanceOf[scala.collection.immutable.Seq[Header]] +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsToResponseBody.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsToResponseBody.scala new file mode 100644 index 0000000000..52eda6b654 --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsToResponseBody.scala @@ -0,0 +1,59 @@ +package sttp.tapir.serverless.aws.lambda + +import sttp.capabilities +import sttp.model.HasHeaders +import sttp.tapir.internal.NoStreams +import sttp.tapir.server.interpreter.ToResponseBody +import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} + +import java.io.InputStream +import java.nio.ByteBuffer +import java.nio.charset.Charset +import java.util.Base64 + +private[lambda] class AwsToResponseBody[F[_]](implicit options: AwsServerOptions[F]) extends ToResponseBody[String, Nothing] { + override val streams: capabilities.Streams[Nothing] = NoStreams + + override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): String = bodyType match { + case RawBodyType.StringBody(charset) => + val str = v.asInstanceOf[String] + if (options.encodeResponseBody) Base64.getEncoder.encodeToString(str.getBytes(charset)) else str + + case RawBodyType.ByteArrayBody => + val bytes = v.asInstanceOf[Array[Byte]] + if (options.encodeResponseBody) Base64.getEncoder.encodeToString(bytes) else new String(bytes) + + case RawBodyType.ByteBufferBody => + val byteBuffer = v.asInstanceOf[ByteBuffer] + if (options.encodeResponseBody) Base64.getEncoder.encodeToString(safeRead(byteBuffer)) else new String(safeRead(byteBuffer)) + + case RawBodyType.InputStreamBody => + val stream = v.asInstanceOf[InputStream] + if (options.encodeResponseBody) Base64.getEncoder.encodeToString(stream.readAllBytes()) else new String(stream.readAllBytes()) + + case RawBodyType.FileBody => throw new UnsupportedOperationException + case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException + } + + private def safeRead(byteBuffer: ByteBuffer): Array[Byte] = { + if (byteBuffer.hasArray) { + if (byteBuffer.array().length != byteBuffer.limit()) { + val array = new Array[Byte](byteBuffer.limit()) + byteBuffer.get(array, 0, byteBuffer.limit()) + array + } else byteBuffer.array() + } else { + val array = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(array) + array + } + } + + override def fromStreamValue(v: streams.BinaryStream, headers: HasHeaders, format: CodecFormat, charset: Option[Charset]): String = + throw new UnsupportedOperationException + + override def fromWebSocketPipe[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Nothing] + ): String = throw new UnsupportedOperationException +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/model.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/model.scala new file mode 100644 index 0000000000..c0403f598e --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/model.scala @@ -0,0 +1,14 @@ +package sttp.tapir.serverless.aws.lambda + +case class AwsRequest( + rawPath: String, + rawQueryString: String, + headers: Map[String, String], + requestContext: AwsRequestContext, + body: Option[String], + isBase64Encoded: Boolean +) +case class AwsRequestContext(domainName: Option[String], http: AwsHttp) +case class AwsHttp(method: String, path: String, protocol: String, sourceIp: String, userAgent: String) + +case class AwsResponse(cookies: List[String], isBase64Encoded: Boolean, statusCode: Int, headers: Map[String, String], body: String) diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/package.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/package.scala new file mode 100644 index 0000000000..77c40a163a --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/package.scala @@ -0,0 +1,5 @@ +package sttp.tapir.serverless.aws + +package object lambda { + type Route[F[_]] = AwsRequest => F[AwsResponse] +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntime.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntime.scala new file mode 100644 index 0000000000..0b3375dd39 --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntime.scala @@ -0,0 +1,27 @@ +package sttp.tapir.serverless.aws.lambda.runtime + +import cats.effect.{Blocker, ConcurrentEffect, ContextShift} +import cats.syntax.all._ +import com.typesafe.scalalogging.StrictLogging +import org.http4s.client.blaze.BlazeClientBuilder +import sttp.client3.http4s.Http4sBackend +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.serverless.aws.lambda._ + +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.DurationInt + +abstract class AwsLambdaRuntime[F[_]: ContextShift: ConcurrentEffect] extends StrictLogging { + def endpoints: Iterable[ServerEndpoint[_, _, _, Any, F]] + implicit def executionContext: ExecutionContext + implicit def serverOptions: AwsServerOptions[F] = AwsServerOptions.customInterceptors() + + def main(args: Array[String]): Unit = { + val backend = Http4sBackend.usingBlazeClientBuilder( + BlazeClientBuilder[F](executionContext).withConnectTimeout(0.seconds), + Blocker.liftExecutionContext(implicitly) + ) + val route: Route[F] = AwsCatsEffectServerInterpreter.toRoute(endpoints.toList) + ConcurrentEffect[F].toIO(AwsLambdaRuntimeLoop(route, sys.env("AWS_LAMBDA_RUNTIME_API"), backend)).foreverM.unsafeRunSync() + } +} diff --git a/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntimeLoop.scala b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntimeLoop.scala new file mode 100644 index 0000000000..6cf0809a95 --- /dev/null +++ b/serverless/aws/lambda/src/main/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntimeLoop.scala @@ -0,0 +1,103 @@ +package sttp.tapir.serverless.aws.lambda.runtime + +import cats.effect.{ConcurrentEffect, ContextShift, Resource} +import cats.syntax.either._ +import com.typesafe.scalalogging.StrictLogging +import io.circe.Printer +import io.circe.generic.auto._ +import io.circe.parser.decode +import io.circe.syntax._ +import sttp.client3._ +import sttp.monad.MonadError +import sttp.monad.syntax._ +import sttp.tapir.integ.cats.CatsMonadError +import sttp.tapir.serverless.aws.lambda.{AwsRequest, AwsResponse, Route} + +import scala.concurrent.duration.DurationInt + +// loosely based on https://github.com/carpe/scalambda/blob/master/native/src/main/scala/io/carpe/scalambda/native/ScalambdaIO.scala +object AwsLambdaRuntimeLoop extends StrictLogging { + + def apply[F[_]: ContextShift: ConcurrentEffect]( + route: Route[F], + awsRuntimeApi: String, + backend: Resource[F, SttpBackend[F, Any]] + ): F[Either[Throwable, Unit]] = { + implicit val monad: MonadError[F] = new CatsMonadError[F] + + val runtimeApiInvocationUri = uri"http://${awsRuntimeApi}/2018-06-01/runtime/invocation" + + /** Make request (without a timeout as prescribed by the AWS Custom Lambda Runtime documentation). + * This is due to the possibility of the runtime being frozen between lambda function invocations. + */ + val nextEventRequest = basicRequest.get(uri"$runtimeApiInvocationUri/next").response(asStringAlways).readTimeout(0.seconds) + + val pollEvent: F[RequestEvent] = { + logger.info(s"Fetching request event") + backend + .use(nextEventRequest.send(_)) + .flatMap { response => + response.header("lambda-runtime-aws-request-id") match { + case Some(id) => RequestEvent(id, response.body).unit + case _ => + monad.error[RequestEvent](new RuntimeException(s"Missing lambda-runtime-aws-request-id header in request event $response")) + } + } + .handleError { case e => monad.error(new RuntimeException(s"Failed to fetch request event, ${e.getMessage}")) } + } + + val decodeEvent: RequestEvent => F[AwsRequest] = event => { + decode[AwsRequest](event.body) match { + case Right(awsRequest) => awsRequest.unit + case Left(e) => monad.error(new RuntimeException(s"Failed to decode request event ${event.requestId}, ${e.getMessage}")) + } + } + + val routeRequest: (RequestEvent, AwsRequest) => F[Either[Throwable, AwsResponse]] = (event, request) => + route(request).map(_.asRight[Throwable]).handleError { case e => + logger.error(s"Failed to process request event ${event.requestId}", e) + e.asLeft[AwsResponse].unit + } + + val sendResponse: (RequestEvent, AwsResponse) => F[Unit] = (event, response) => + backend + .use { b => + basicRequest + .post(uri"$runtimeApiInvocationUri/${event.requestId}/response") + .body(Printer.noSpaces.print(response.asJson)) + .send(b) + } + .map(_ => ()) + .handleError { case e => monad.error(new RuntimeException(s"Failed to send response for event ${event.requestId}")) } + + val sendError: (RequestEvent, Throwable) => F[Unit] = (event, e) => + backend + .use { b => + basicRequest.post(uri"$runtimeApiInvocationUri/${event.requestId}/error").body(e.getMessage).send(b) + } + .map(_ => ()) + .handleError { case e => monad.error(new RuntimeException(s"Failed to send error for event ${event.requestId}")) } + + val sendResult: (RequestEvent, Either[Throwable, AwsResponse]) => F[Unit] = (event, result) => + result match { + case Right(response) => + logger.info(s"Request event ${event.requestId} completed successfully") + sendResponse(event, response) + case Left(e) => + logger.error(s"Request event ${event.requestId} failed", e) + sendError(event, e) + } + + (for { + event <- pollEvent + request <- decodeEvent(event) + result <- routeRequest(event, request) + _ <- sendResult(event, result) + } yield ().asRight[Throwable]).handleError { case e => + logger.error(e.getMessage) + e.asLeft[Unit].unit + } + } +} + +case class RequestEvent(requestId: String, body: String) diff --git a/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/AwsLambdaCreateServerStubTest.scala b/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/AwsLambdaCreateServerStubTest.scala new file mode 100644 index 0000000000..5c8717be52 --- /dev/null +++ b/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/AwsLambdaCreateServerStubTest.scala @@ -0,0 +1,111 @@ +package sttp.tapir.serverless.aws.lambda + +import cats.data.NonEmptyList +import cats.effect.IO +import org.scalatest.Assertion +import sttp.capabilities.WebSockets +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3 +import sttp.client3.testing.SttpBackendStub +import sttp.client3.{ByteArrayBody, ByteBufferBody, InputStreamBody, NoBody, Request, Response, StringBody, SttpBackend, _} +import sttp.model.{Header, StatusCode, Uri} +import sttp.tapir.Endpoint +import sttp.tapir.integ.cats.CatsMonadError +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.interceptor.decodefailure.{DecodeFailureHandler, DefaultDecodeFailureHandler} +import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor +import sttp.tapir.server.tests.CreateServerTest +import sttp.tapir.serverless.aws.lambda.AwsLambdaCreateServerStubTest._ +import sttp.tapir.tests.Test + +class AwsLambdaCreateServerStubTest extends CreateServerTest[IO, Any, Route[IO], String] { + + override def testServer[I, E, O]( + e: Endpoint[I, E, O, Any], + testNameSuffix: String, + decodeFailureHandler: Option[DecodeFailureHandler], + metricsInterceptor: Option[MetricsRequestInterceptor[IO, String]] + )(fn: I => IO[Either[E, O]])(runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion]): Test = { + implicit val serverOptions: AwsServerOptions[IO] = AwsServerOptions.customInterceptors( + encodeResponseBody = false, + metricsInterceptor = metricsInterceptor, + decodeFailureHandler = decodeFailureHandler.getOrElse(DefaultDecodeFailureHandler.handler) + ) + val se: ServerEndpoint[I, E, O, Any, IO] = e.serverLogic(fn) + val route: Route[IO] = AwsCatsEffectServerInterpreter.toRoute(se) + val name = e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix) + Test(name)(runTest(stubBackend(route), uri"http://localhost:3000").unsafeRunSync()) + } + + override def testServerLogic[I, E, O](e: ServerEndpoint[I, E, O, Any, IO], testNameSuffix: String)( + runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = { + implicit val serverOptions: AwsServerOptions[IO] = AwsServerOptions.customInterceptors(encodeResponseBody = false) + val route: Route[IO] = AwsCatsEffectServerInterpreter.toRoute(e) + val name = e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix) + Test(name)(runTest(stubBackend(route), uri"http://localhost:3000").unsafeRunSync()) + } + + override def testServer(name: String, rs: => NonEmptyList[Route[IO]])( + runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = { + val backend = SttpBackendStub[IO, Fs2Streams[IO] with WebSockets](catsMonadIO).whenAnyRequest + .thenRespondF { request => + val responses: NonEmptyList[Response[String]] = rs.map { route => + route(sttpToAwsRequest(request)).map(awsToSttpResponse).unsafeRunSync() + } + IO.pure(responses.find(_.code != StatusCode.NotFound).getOrElse(Response("", StatusCode.NotFound))) + } + Test(name)(runTest(backend, uri"http://localhost:3000").unsafeRunSync()) + } + + private def stubBackend(route: Route[IO]): SttpBackend[IO, Fs2Streams[IO] with WebSockets] = + SttpBackendStub[IO, Fs2Streams[IO] with WebSockets](catsMonadIO).whenAnyRequest.thenRespondF { request => + route(sttpToAwsRequest(request)).map(awsToSttpResponse) + } +} + +object AwsLambdaCreateServerStubTest { + implicit val catsMonadIO: CatsMonadError[IO] = new CatsMonadError[IO] + + def sttpToAwsRequest(request: Request[_, _]): AwsRequest = { + AwsRequest( + rawPath = request.uri.pathSegments.toString, + rawQueryString = request.uri.params.toMultiSeq.foldLeft("") { case (q, (name, values)) => + s"${if (q == "") "" else s"$q&"}${if (values.isEmpty) name else values.map(v => s"$name=$v").mkString("&")}" + }, + headers = request.headers.map(h => h.name -> h.value).toMap, + requestContext = AwsRequestContext( + domainName = Some("localhost:3000"), + http = AwsHttp( + request.method.method, + request.uri.path.mkString("/"), + "http", + "127.0.0.1", + "Internet Explorer" + ) + ), + Some(request.body match { + case NoBody => "" + case StringBody(b, _, _) => new String(b) + case ByteArrayBody(b, _) => new String(b) + case ByteBufferBody(b, _) => new String(b.array()) + case InputStreamBody(b, _) => new String(b.readAllBytes()) + case _ => throw new UnsupportedOperationException + }), + isBase64Encoded = false + ) + } + + def awsToSttpResponse(response: AwsResponse): Response[String] = + client3.Response( + new String(response.body), + new StatusCode(response.statusCode), + "", + response.headers + .map { case (n, v) => v.split(",").map(Header(n, _)) } + .flatten + .toSeq + .asInstanceOf[scala.collection.immutable.Seq[Header]] + ) +} diff --git a/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/AwsLambdaStubHttpTest.scala b/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/AwsLambdaStubHttpTest.scala new file mode 100644 index 0000000000..618ee46089 --- /dev/null +++ b/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/AwsLambdaStubHttpTest.scala @@ -0,0 +1,47 @@ +package sttp.tapir.serverless.aws.lambda + +import cats.data.NonEmptyList +import cats.effect.{IO, Resource} +import sttp.tapir.Endpoint +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.interceptor.decodefailure.{DecodeFailureHandler, DefaultDecodeFailureHandler} +import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor +import sttp.tapir.server.tests.{ServerBasicTests, ServerMetricsTest, TestServerInterpreter} +import sttp.tapir.serverless.aws.lambda.AwsLambdaCreateServerStubTest.catsMonadIO +import sttp.tapir.tests.{Port, Test, TestSuite} + +import scala.reflect.ClassTag + +class AwsLambdaStubHttpTest extends TestSuite { + override def tests: Resource[IO, List[Test]] = Resource.eval( + IO.pure { + val createTestServer = new AwsLambdaCreateServerStubTest + new ServerBasicTests(createTestServer, AwsLambdaStubHttpTest.testServerInterpreter)(catsMonadIO).tests() ++ + new ServerMetricsTest(createTestServer).tests() + } + ) +} + +object AwsLambdaStubHttpTest { + private val testServerInterpreter = new TestServerInterpreter[IO, Any, Route[IO], String] { + override def route[I, E, O]( + e: ServerEndpoint[I, E, O, Any, IO], + decodeFailureHandler: Option[DecodeFailureHandler], + metricsInterceptor: Option[MetricsRequestInterceptor[IO, String]] + ): Route[IO] = { + implicit val options: AwsServerOptions[IO] = AwsServerOptions.customInterceptors( + encodeResponseBody = false, + metricsInterceptor = metricsInterceptor, + decodeFailureHandler = decodeFailureHandler.getOrElse(DefaultDecodeFailureHandler.handler) + ) + AwsCatsEffectServerInterpreter.toRoute(e) + } + override def routeRecoverErrors[I, E <: Throwable, O](e: Endpoint[I, E, O, Any], fn: I => IO[O])(implicit + eClassTag: ClassTag[E] + ): Route[IO] = { + implicit val options: AwsServerOptions[IO] = AwsServerOptions.customInterceptors(encodeResponseBody = false) + AwsCatsEffectServerInterpreter.toRouteRecoverErrors(e)(fn) + } + override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = ??? + } +} diff --git a/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntimeLoopTest.scala b/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntimeLoopTest.scala new file mode 100644 index 0000000000..b4813f4057 --- /dev/null +++ b/serverless/aws/lambda/src/test/scala/sttp/tapir/serverless/aws/lambda/runtime/AwsLambdaRuntimeLoopTest.scala @@ -0,0 +1,156 @@ +package sttp.tapir.serverless.aws.lambda.runtime + +import cats.effect.{ContextShift, IO, Resource} +import cats.syntax.all._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import sttp.client3._ +import sttp.client3.testing.SttpBackendStub +import sttp.model.{Header, StatusCode} +import sttp.tapir._ +import sttp.tapir.integ.cats.CatsMonadError +import sttp.tapir.serverless.aws.lambda.runtime.AwsLambdaRuntimeLoopTest._ +import sttp.tapir.serverless.aws.lambda.{AwsCatsEffectServerInterpreter, AwsServerOptions} + +import scala.concurrent.ExecutionContext.Implicits.global + +import scala.collection.immutable.Seq + +class AwsLambdaRuntimeLoopTest extends AnyFunSuite with Matchers { + + test("should process event") { + // given + var hello = "" + + val route = AwsCatsEffectServerInterpreter.toRoute(testEp.serverLogic { _ => + hello = "hello" + IO.pure(().asRight[Unit]) + }) + + val backend = SttpBackendStub(monadError) + .whenRequestMatches(_.uri == uri"http://aws/2018-06-01/runtime/invocation/next") + .thenRespondF(IO.pure(Response(awsRequest, StatusCode.Ok, "Ok", Seq(Header("lambda-runtime-aws-request-id", "43214"))))) + .whenAnyRequest + .thenRespondOk() + + // when + val result = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))).unsafeRunSync() + + // then + hello shouldBe "hello" + result shouldBe Right(()) + } + + test("should handle error while fetching event") { + // given + val route = AwsCatsEffectServerInterpreter.toRoute(testEp)(_ => IO(().asRight[Unit])) + + val backend = SttpBackendStub(monadError) + .whenRequestMatches(_.uri == uri"http://aws/2018-06-01/runtime/invocation/next") + .thenRespondF(_ => throw new RuntimeException) + + val loop = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))) + + // when + val result = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))).unsafeRunSync() + + // then + result.isLeft shouldBe true + } + + test("should handle decode failure") { + // given + val route = AwsCatsEffectServerInterpreter.toRoute(testEp)(_ => IO(().asRight[Unit])) + + val backend = SttpBackendStub(monadError) + .whenRequestMatches(_.uri == uri"http://aws/2018-06-01/runtime/invocation/next") + .thenRespondF(IO.pure(Response("???", StatusCode.Ok, "Ok", Seq(Header("lambda-runtime-aws-request-id", "43214"))))) + .whenAnyRequest + .thenRespondOk() + + // when + val result = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))).unsafeRunSync() + + // then + result.isLeft shouldBe true + } + + test("should handle missing lambda-runtime-aws-request-id header") { + // given + val route = AwsCatsEffectServerInterpreter.toRoute(testEp)(_ => IO(().asRight[Unit])) + + val backend = SttpBackendStub(monadError) + .whenRequestMatches(_.uri == uri"http://aws/2018-06-01/runtime/invocation/next") + .thenRespondF(IO.pure(Response(awsRequest, StatusCode.Ok))) + + // when + val result = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))).unsafeRunSync() + + // then + result.isLeft shouldBe true + } + + test("should handle error from server logic") { + // given + val route = AwsCatsEffectServerInterpreter.toRoute(testEp)(_ => throw new RuntimeException) + + val backend = SttpBackendStub(monadError) + .whenRequestMatches(_.uri == uri"http://aws/2018-06-01/runtime/invocation/next") + .thenRespondF(IO.pure(Response(awsRequest, StatusCode.Ok, "Ok", Seq(Header("lambda-runtime-aws-request-id", "43214"))))) + .whenAnyRequest + .thenRespondOk() + + // when + val result = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))).unsafeRunSync() + + // then + result shouldBe Right(()) + } + + test("should handle error when sending response to lambda") { + // given + val route = AwsCatsEffectServerInterpreter.toRoute(testEp)(_ => IO(().asRight[Unit])) + + val backend = SttpBackendStub(monadError) + .whenRequestMatches(_.uri == uri"http://aws/2018-06-01/runtime/invocation/next") + .thenRespondF(IO.pure(Response(awsRequest, StatusCode.Ok, "Ok", Seq(Header("lambda-runtime-aws-request-id", "43214"))))) + .whenAnyRequest + .thenRespondF(_ => throw new RuntimeException) + + // when + val result = AwsLambdaRuntimeLoop(route, "aws", Resource.eval(IO.pure(backend))).unsafeRunSync() + + // then + result.isLeft shouldBe true + } +} + +object AwsLambdaRuntimeLoopTest { + implicit val contextShift: ContextShift[IO] = IO.contextShift(global) + implicit val options: AwsServerOptions[IO] = AwsServerOptions.customInterceptors() + + val awsRequest: String = + """ + |{ + | "version": "2.0", + | "routeKey": "GET /api/hello", + | "rawPath": "/api/hello", + | "rawQueryString": "", + | "headers": {}, + | "requestContext": { + | "http": { + | "method": "GET", + | "path": "/api/hello", + | "protocol": "HTTP/1.1", + | "sourceIp": "188.146.66.23", + | "userAgent": "Chrome" + | } + | }, + | "isBase64Encoded": false + |} + |""".stripMargin + + val testEp: Endpoint[Unit, Unit, Unit, Any] = endpoint.get.in("api" / "hello") + + val monadError: CatsMonadError[IO] = new CatsMonadError[IO] +} diff --git a/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamInterpreter.scala b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamInterpreter.scala new file mode 100644 index 0000000000..3a0bb467a1 --- /dev/null +++ b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamInterpreter.scala @@ -0,0 +1,20 @@ +package sttp.tapir.serverless.aws.sam + +import sttp.tapir.Endpoint +import sttp.tapir.server.ServerEndpoint + +trait AwsSamInterpreter { + def toSamTemplate[I, E, O, S](e: Endpoint[I, E, O, S])(implicit options: AwsSamOptions): SamTemplate = EndpointsToSamTemplate(List(e)) + + def toSamTemplate(es: Iterable[Endpoint[_, _, _, _]])(implicit options: AwsSamOptions): SamTemplate = EndpointsToSamTemplate(es.toList) + + def toSamTemplate[I, E, O, S, F[_]](se: ServerEndpoint[I, E, O, S, F])(implicit options: AwsSamOptions): SamTemplate = + EndpointsToSamTemplate( + List(se.endpoint) + ) + + def serverEndpointsToSamTemplate[F[_]](ses: Iterable[ServerEndpoint[_, _, _, _, F]])(implicit options: AwsSamOptions): SamTemplate = + EndpointsToSamTemplate(ses.map(_.endpoint).toList) +} + +object AwsSamInterpreter extends AwsSamInterpreter \ No newline at end of file diff --git a/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamOptions.scala b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamOptions.scala new file mode 100644 index 0000000000..312eed00e3 --- /dev/null +++ b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamOptions.scala @@ -0,0 +1,14 @@ +package sttp.tapir.serverless.aws.sam + +import scala.concurrent.duration.{DurationInt, FiniteDuration} + +case class AwsSamOptions( + namePrefix: String, + source: FunctionSource, + timeout: FiniteDuration = 10.seconds, + memorySize: Int = 256 +) + +sealed trait FunctionSource +case class ImageSource(imageUri: String) extends FunctionSource +case class CodeSource(runtime: String, codeUri: String, handler: String) extends FunctionSource diff --git a/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamTemplateEncoders.scala b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamTemplateEncoders.scala new file mode 100644 index 0000000000..a0096ec026 --- /dev/null +++ b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/AwsSamTemplateEncoders.scala @@ -0,0 +1,40 @@ +package sttp.tapir.serverless.aws.sam + +import io.circe.generic.semiauto.deriveEncoder +import io.circe.syntax.EncoderOps +import io.circe.{Encoder, Json} + +import scala.collection.immutable.ListMap + +object AwsSamTemplateEncoders { + implicit def encodeListMap[V: Encoder]: Encoder[ListMap[String, V]] = { case m: ListMap[String, V] => + val properties = m.view.map { case (k, v) => k -> implicitly[Encoder[V]].apply(v) }.toList + Json.obj(properties: _*) + } + + implicit val encoderOutput: Encoder[Output] = deriveEncoder[Output] + implicit val encoderFunctionHttpApiEventProperties: Encoder[FunctionHttpApiEventProperties] = + deriveEncoder[FunctionHttpApiEventProperties] + implicit val encoderFunctionHttpApiEvent: Encoder[FunctionHttpApiEvent] = { + val encoder = deriveEncoder[FunctionHttpApiEvent] + e => Json.fromJsonObject(encoder(e).asJson.asObject.get.add("Type", Json.fromString("HttpApi"))) + } + + implicit val encoderHttpProperties: Encoder[HttpProperties] = deriveEncoder[HttpProperties] + implicit val encoderFunctionImageProperties: Encoder[FunctionImageProperties] = deriveEncoder[FunctionImageProperties] + implicit val encoderFunctionCodeProperties: Encoder[FunctionCodeProperties] = deriveEncoder[FunctionCodeProperties] + implicit val encoderProperties: Encoder[Properties] = { + case v: HttpProperties => v.asJson + case v: FunctionImageProperties => v.asJson + case v: FunctionCodeProperties => v.asJson + } + + implicit val encoderHttpResource: Encoder[HttpResource] = deriveEncoder[HttpResource] + implicit val encoderFunctionResource: Encoder[FunctionResource] = deriveEncoder[FunctionResource] + implicit val encoderResource: Encoder[Resource] = { + case v: HttpResource => Json.fromJsonObject(v.asJson.asObject.get.add("Type", Json.fromString("AWS::Serverless::HttpApi"))) + case v: FunctionResource => Json.fromJsonObject(v.asJson.asObject.get.add("Type", Json.fromString("AWS::Serverless::Function"))) + } + + implicit val encoderSamTemplate: Encoder[SamTemplate] = deriveEncoder[SamTemplate] +} diff --git a/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/EndpointsToSamTemplate.scala b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/EndpointsToSamTemplate.scala new file mode 100644 index 0000000000..cae49d7632 --- /dev/null +++ b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/EndpointsToSamTemplate.scala @@ -0,0 +1,73 @@ +package sttp.tapir.serverless.aws.sam + +import sttp.model.Method +import sttp.tapir.internal._ +import sttp.tapir.{Endpoint, EndpointInput} + +private[sam] object EndpointsToSamTemplate { + def apply(es: List[Endpoint[_, _, _, _]])(implicit options: AwsSamOptions): SamTemplate = { + val functionName = options.namePrefix + "Function" + val httpApiName = options.namePrefix + "HttpApi" + + val apiEvents = es + .map(endpointNameMethodAndPath) + .map { case (name, method, path) => + name -> FunctionHttpApiEvent( + FunctionHttpApiEventProperties(s"!Ref $httpApiName", method.map(_.method).getOrElse("ANY"), path, options.timeout.toMillis) + ) + } + .toMap + + SamTemplate( + Resources = Map( + functionName -> FunctionResource( + options.source match { + case ImageSource(imageUri) => + FunctionImageProperties(options.timeout.toSeconds, options.memorySize, apiEvents, imageUri) + case cs @ CodeSource(_, _, _) => + FunctionCodeProperties( + options.timeout.toSeconds, + options.memorySize, + apiEvents, + cs.runtime, + cs.codeUri, + cs.handler + ) + } + ), + httpApiName -> HttpResource(HttpProperties("$default")) + ), + Outputs = Map( + (options.namePrefix + "Url") -> Output( + "Base URL of your endpoints", + Map("Fn::Sub" -> ("https://${" + httpApiName + "}.execute-api.${AWS::Region}.${AWS::URLSuffix}")) + ) + ) + ) + } + + private def endpointNameMethodAndPath(e: Endpoint[_, _, _, _]): (String, Option[Method], String) = { + val pathComponents = e.input + .asVectorOfBasicInputs() + .foldLeft((Vector.empty[Either[String, String]], 0)) { case ((acc, c), input) => + input match { + case EndpointInput.PathCapture(name, _, _) => (acc :+ Left(name.getOrElse(s"param$c")), if (name.isEmpty) c + 1 else c) + case EndpointInput.FixedPath(p, _, _) => (acc :+ Right(p), c) + case _ => (acc, c) + } + } + ._1 + + val method = e.httpMethod + + val nameComponents = if (pathComponents.isEmpty) Vector("root") else pathComponents.map(_.fold(identity, identity)) + val name = (method.map(_.method.toLowerCase).getOrElse("any").capitalize +: nameComponents.map(_.toLowerCase.capitalize)).mkString + + val idComponents = pathComponents.map { + case Left(s) => s"{$s}" + case Right(s) => s + } + + (name, method, "/" + idComponents.mkString("/")) + } +} diff --git a/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/Printer.scala b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/Printer.scala new file mode 100644 index 0000000000..4f32fa3705 --- /dev/null +++ b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/Printer.scala @@ -0,0 +1,167 @@ +package sttp.tapir.serverless.aws.sam + +import Printer._ +import io.circe._ +import java.io.StringWriter +import org.yaml.snakeyaml.DumperOptions +import org.yaml.snakeyaml.emitter.Emitter +import org.yaml.snakeyaml.nodes._ +import org.yaml.snakeyaml.resolver.Resolver +import org.yaml.snakeyaml.serializer.Serializer +import scala.collection.JavaConverters._ + +// modified stringNode to handle !Ref tags +final case class Printer( + preserveOrder: Boolean = false, + dropNullKeys: Boolean = false, + indent: Int = 2, + maxScalarWidth: Int = 80, + splitLines: Boolean = true, + indicatorIndent: Int = 0, + tags: Map[String, String] = Map.empty, + sequenceStyle: FlowStyle = FlowStyle.Block, + mappingStyle: FlowStyle = FlowStyle.Block, + stringStyle: StringStyle = StringStyle.Plain, + lineBreak: LineBreak = LineBreak.Unix, + explicitStart: Boolean = false, + explicitEnd: Boolean = false, + version: YamlVersion = YamlVersion.Auto +) { + + def pretty(json: Json): String = { + val rootTag = yamlTag(json) + val writer = new StringWriter() + val serializer = new Serializer(new Emitter(writer, options), new Resolver, options, rootTag) + serializer.open() + serializer.serialize(jsonToYaml(json)) + serializer.close() + writer.toString + } + + private lazy val options = { + val options = new DumperOptions() + options.setIndent(indent) + options.setWidth(maxScalarWidth) + options.setSplitLines(splitLines) + options.setIndicatorIndent(indicatorIndent) + options.setTags(tags.asJava) + options.setDefaultScalarStyle(StringStyle.toScalarStyle(stringStyle)) + options.setLineBreak(lineBreak match { + case LineBreak.Unix => DumperOptions.LineBreak.UNIX + case LineBreak.Windows => DumperOptions.LineBreak.WIN + case LineBreak.Mac => DumperOptions.LineBreak.MAC + }) + options.setVersion(version match { + case YamlVersion.Auto => null + case YamlVersion.Yaml1_0 => DumperOptions.Version.V1_0 + case YamlVersion.Yaml1_1 => DumperOptions.Version.V1_1 + }) + options.setExplicitStart(explicitStart) + options.setExplicitEnd(explicitEnd) + options + } + + private def isBad(s: String): Boolean = s.indexOf('\u0085') >= 0 || s.indexOf('\ufeff') >= 0 + + private def scalarStyle(value: String): DumperOptions.ScalarStyle = + if (isBad(value)) DumperOptions.ScalarStyle.DOUBLE_QUOTED else DumperOptions.ScalarStyle.PLAIN + + private def stringScalarStyle(value: String): DumperOptions.ScalarStyle = + if (isBad(value)) DumperOptions.ScalarStyle.DOUBLE_QUOTED else StringStyle.toScalarStyle(stringStyle) + + private def scalarNode(tag: Tag, value: String) = new ScalarNode(tag, value, null, null, scalarStyle(value)) + private def stringNode(value: String) = if (value.startsWith("!Ref ")) { + new ScalarNode(new Tag("!Ref"), value.substring(5), null, null, stringScalarStyle(value)) + } else { + new ScalarNode(Tag.STR, value, null, null, stringScalarStyle(value)) + } + + private def keyNode(value: String) = new ScalarNode(Tag.STR, value, null, null, scalarStyle(value)) + + private def jsonToYaml(json: Json): Node = { + + def convertObject(obj: JsonObject) = { + val fields = if (preserveOrder) obj.keys else obj.keys.toSet + val m = obj.toMap + val childNodes = fields.flatMap { key => + val value = m(key) + if (!dropNullKeys || !value.isNull) Some(new NodeTuple(keyNode(key), jsonToYaml(value))) + else None + } + new MappingNode( + Tag.MAP, + childNodes.toList.asJava, + if (mappingStyle == FlowStyle.Flow) DumperOptions.FlowStyle.FLOW else DumperOptions.FlowStyle.BLOCK + ) + } + + json.fold( + scalarNode(Tag.NULL, "null"), + bool => scalarNode(Tag.BOOL, bool.toString), + number => scalarNode(numberTag(number), number.toString), + str => stringNode(str), + arr => + new SequenceNode( + Tag.SEQ, + arr.map(jsonToYaml).asJava, + if (sequenceStyle == FlowStyle.Flow) DumperOptions.FlowStyle.FLOW else DumperOptions.FlowStyle.BLOCK + ), + obj => convertObject(obj) + ) + } +} + +object Printer { + + val spaces2 = Printer() + val spaces4 = Printer(indent = 4) + + sealed trait FlowStyle + object FlowStyle { + case object Flow extends FlowStyle + case object Block extends FlowStyle + } + + sealed trait StringStyle + object StringStyle { + case object Plain extends StringStyle + case object DoubleQuoted extends StringStyle + case object SingleQuoted extends StringStyle + case object Literal extends StringStyle + case object Folded extends StringStyle + + def toScalarStyle(style: StringStyle): DumperOptions.ScalarStyle = style match { + case StringStyle.Plain => DumperOptions.ScalarStyle.PLAIN + case StringStyle.DoubleQuoted => DumperOptions.ScalarStyle.DOUBLE_QUOTED + case StringStyle.SingleQuoted => DumperOptions.ScalarStyle.SINGLE_QUOTED + case StringStyle.Literal => DumperOptions.ScalarStyle.LITERAL + case StringStyle.Folded => DumperOptions.ScalarStyle.FOLDED + } + } + + sealed trait LineBreak + object LineBreak { + case object Unix extends LineBreak + case object Windows extends LineBreak + case object Mac extends LineBreak + } + + sealed trait YamlVersion + object YamlVersion { + case object Yaml1_0 extends YamlVersion + case object Yaml1_1 extends YamlVersion + case object Auto extends YamlVersion + } + + private def yamlTag(json: Json) = json.fold( + Tag.NULL, + _ => Tag.BOOL, + number => numberTag(number), + _ => Tag.STR, + _ => Tag.SEQ, + _ => Tag.MAP + ) + + private def numberTag(number: JsonNumber): Tag = + if (number.toString.contains(".")) Tag.FLOAT else Tag.INT +} diff --git a/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/model.scala b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/model.scala new file mode 100644 index 0000000000..fb04b6d0c3 --- /dev/null +++ b/serverless/aws/sam/src/main/scala/sttp/tapir/serverless/aws/sam/model.scala @@ -0,0 +1,60 @@ +package sttp.tapir.serverless.aws.sam + +import io.circe.syntax._ +import sttp.tapir.serverless.aws.sam.AwsSamTemplateEncoders._ + +case class SamTemplate( + AWSTemplateFormatVersion: String = "2010-09-09", + Transform: String = "AWS::Serverless-2016-10-31", + Resources: Map[String, Resource], + Outputs: Map[String, Output] +) { + def toYaml: String = Printer(dropNullKeys = true, preserveOrder = true, stringStyle = Printer.StringStyle.Plain).pretty(this.asJson) +} + +sealed trait Resource { + def Properties: Properties +} +case class FunctionResource(Properties: Properties) extends Resource +case class HttpResource(Properties: HttpProperties) extends Resource + +sealed trait Properties + +sealed trait FunctionProperties { + val Timeout: Long + val MemorySize: Int + val Events: Map[String, FunctionHttpApiEvent] +} + +case class FunctionImageProperties( + Timeout: Long, + MemorySize: Int, + Events: Map[String, FunctionHttpApiEvent], + ImageUri: String, + PackageType: String = "Image" +) extends Properties + with FunctionProperties + +case class FunctionCodeProperties( + Timeout: Long, + MemorySize: Int, + Events: Map[String, FunctionHttpApiEvent], + Runtime: String, + CodeUri: String, + Handler: String +) extends Properties + with FunctionProperties + +case class HttpProperties(StageName: String) extends Properties + +case class FunctionHttpApiEvent(Properties: FunctionHttpApiEventProperties) + +case class FunctionHttpApiEventProperties( + ApiId: String, + Method: String, + Path: String, + TimeoutInMillis: Long, + PayloadFormatVersion: String = "2.0" +) + +case class Output(Description: String, Value: Map[String, String]) diff --git a/serverless/aws/sam/src/test/resources/code_source_template.yaml b/serverless/aws/sam/src/test/resources/code_source_template.yaml new file mode 100644 index 0000000000..ee3ce3b8e9 --- /dev/null +++ b/serverless/aws/sam/src/test/resources/code_source_template.yaml @@ -0,0 +1,37 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Resources: + PetApiFunction: + Properties: + Timeout: 10 + MemorySize: 1024 + Events: + GetApiPetsId: + Properties: + ApiId: !Ref 'PetApiHttpApi' + Method: GET + Path: /api/pets/{id} + TimeoutInMillis: 10000 + PayloadFormatVersion: '2.0' + Type: HttpApi + PostApiPets: + Properties: + ApiId: !Ref 'PetApiHttpApi' + Method: POST + Path: /api/pets + TimeoutInMillis: 10000 + PayloadFormatVersion: '2.0' + Type: HttpApi + Runtime: java11 + CodeUri: /somewhere/pet-api.jar + Handler: pet.api.Handler::handleRequest + Type: AWS::Serverless::Function + PetApiHttpApi: + Properties: + StageName: $default + Type: AWS::Serverless::HttpApi +Outputs: + PetApiUrl: + Description: Base URL of your endpoints + Value: + Fn::Sub: https://${PetApiHttpApi}.execute-api.${AWS::Region}.${AWS::URLSuffix} diff --git a/serverless/aws/sam/src/test/resources/image_source_template.yaml b/serverless/aws/sam/src/test/resources/image_source_template.yaml new file mode 100644 index 0000000000..e683ad4bd6 --- /dev/null +++ b/serverless/aws/sam/src/test/resources/image_source_template.yaml @@ -0,0 +1,36 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 +Resources: + PetApiFunction: + Properties: + Timeout: 10 + MemorySize: 1024 + Events: + GetApiPetsId: + Properties: + ApiId: !Ref 'PetApiHttpApi' + Method: GET + Path: /api/pets/{id} + TimeoutInMillis: 10000 + PayloadFormatVersion: '2.0' + Type: HttpApi + PostApiPets: + Properties: + ApiId: !Ref 'PetApiHttpApi' + Method: POST + Path: /api/pets + TimeoutInMillis: 10000 + PayloadFormatVersion: '2.0' + Type: HttpApi + ImageUri: image.repository:pet-api + PackageType: Image + Type: AWS::Serverless::Function + PetApiHttpApi: + Properties: + StageName: $default + Type: AWS::Serverless::HttpApi +Outputs: + PetApiUrl: + Description: Base URL of your endpoints + Value: + Fn::Sub: https://${PetApiHttpApi}.execute-api.${AWS::Region}.${AWS::URLSuffix} \ No newline at end of file diff --git a/serverless/aws/sam/src/test/scala/sttp/tapir/serverless/aws/sam/VerifySamTemplateTest.scala b/serverless/aws/sam/src/test/scala/sttp/tapir/serverless/aws/sam/VerifySamTemplateTest.scala new file mode 100644 index 0000000000..145fbe76cf --- /dev/null +++ b/serverless/aws/sam/src/test/scala/sttp/tapir/serverless/aws/sam/VerifySamTemplateTest.scala @@ -0,0 +1,62 @@ +package sttp.tapir.serverless.aws.sam + +import io.circe.generic.auto._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import sttp.tapir.generic.auto._ +import sttp.tapir.json.circe._ +import sttp.tapir.serverless.aws.sam.VerifySamTemplateTest._ +import sttp.tapir.{Endpoint, endpoint, path, _} + +import scala.io.Source + +class VerifySamTemplateTest extends AnyFunSuite with Matchers { + + test("should match the expected yaml with image source") { + val expectedYaml = load("image_source_template.yaml") + + implicit val samOptions: AwsSamOptions = AwsSamOptions( + "PetApi", + source = ImageSource("image.repository:pet-api"), + memorySize = 1024 + ) + + val actualYaml = AwsSamInterpreter.toSamTemplate(List(getPetEndpoint, addPetEndpoint)).toYaml + + expectedYaml shouldBe noIndentation(actualYaml) + } + + test("should match the expected yaml with code source") { + val expectedYaml = load("code_source_template.yaml") + + implicit val samOptions: AwsSamOptions = AwsSamOptions( + "PetApi", + source = CodeSource(runtime = "java11", codeUri = "/somewhere/pet-api.jar", "pet.api.Handler::handleRequest"), + memorySize = 1024 + ) + + val actualYaml = AwsSamInterpreter.toSamTemplate(List(getPetEndpoint, addPetEndpoint)).toYaml + + expectedYaml shouldBe noIndentation(actualYaml) + } + +} + +object VerifySamTemplateTest { + + case class Pet(name: String, species: String) + + val getPetEndpoint: Endpoint[Int, Unit, Pet, Any] = endpoint.get + .in("api" / "pets" / path[Int]("id")) + .out(jsonBody[Pet]) + + val addPetEndpoint: Endpoint[Pet, Unit, Unit, Any] = endpoint.post + .in("api" / "pets") + .in(jsonBody[Pet]) + + def load(fileName: String): String = { + noIndentation(Source.fromInputStream(getClass.getResourceAsStream(s"/$fileName")).getLines().mkString("\n")) + } + + def noIndentation(s: String): String = s.replaceAll("[ \t]", "").trim +} diff --git a/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformEncoders.scala b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformEncoders.scala new file mode 100644 index 0000000000..17b644231f --- /dev/null +++ b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformEncoders.scala @@ -0,0 +1,49 @@ +package sttp.tapir.serverless.aws.terraform + +import io.circe.{Encoder, Json} +import sttp.tapir.serverless.aws.terraform.TerraformResource.TapirApiGateway + +object AwsTerraformEncoders { + + implicit def encoderAwsTerraformApiGateway(implicit options: AwsTerraformOptions): Encoder[AwsTerraformApiGateway] = + gateway => { + + val provider = + Json.fromFields(Seq("aws" -> Json.fromValues(Seq(Json.fromFields(Seq("region" -> Json.fromString(options.awsRegion))))))) + val terraform = Json.fromFields( + Seq("required_providers" -> Json.fromFields(Seq("aws" -> Json.fromFields(Seq("source" -> Json.fromString("hashicorp/aws")))))) + ) + + val integrations: Seq[TerraformResource] = gateway.routes.flatMap { m => + val integration = AwsApiGatewayV2Integration(m.name) + val route = AwsApiGatewayV2Route(m.name, s"${m.httpMethod.method} /${m.path}", m.name) + Seq(integration, route) + } + + val output = Json.fromFields( + Seq( + "base_url" -> Json.fromFields( + Seq("value" -> Json.fromString(s"$${aws_apigatewayv2_api.$TapirApiGateway.api_endpoint}")) + ) + ) + ) + + Json.fromFields( + Seq( + "terraform" -> terraform, + "provider" -> provider, + "resource" -> Json.fromValues( + Seq( + AwsLambdaFunction(options.functionName, options.timeout, options.memorySize, options.functionSource).json(), + AwsIamRole(options.assumeRolePolicy.noSpaces).json(), + AwsLambdaPermission.json(), + AwsApiGatewayV2Api(options.apiGatewayName, options.apiGatewayDescription).json(), + AwsApiGatewayV2Deployment(integrations.collect { case i @ AwsApiGatewayV2Integration(_) => i.name }).json(), + AwsApiGatewayV2Stage(options.apiGatewayStage, options.autoDeploy).json() + ) ++ integrations.map(_.json()) + ), + "output" -> output + ) + ) + } +} diff --git a/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformInterpreter.scala b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformInterpreter.scala new file mode 100644 index 0000000000..81b77f6d72 --- /dev/null +++ b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformInterpreter.scala @@ -0,0 +1,26 @@ +package sttp.tapir.serverless.aws.terraform + +import sttp.tapir.Endpoint +import sttp.tapir.server.ServerEndpoint + +trait AwsTerraformInterpreter { + def toTerraformConfig[I, E, O, S](e: Endpoint[I, E, O, S])(implicit options: AwsTerraformOptions): AwsTerraformApiGateway = + EndpointsToTerraformConfig(List(e)) + + def toTerraformConfig(es: Iterable[Endpoint[_, _, _, _]])(implicit options: AwsTerraformOptions): AwsTerraformApiGateway = + EndpointsToTerraformConfig(es.toList) + + def toTerraformConfig[I, E, O, S, F[_]](se: ServerEndpoint[I, E, O, S, F])(implicit + options: AwsTerraformOptions + ): AwsTerraformApiGateway = + EndpointsToTerraformConfig( + List(se.endpoint) + ) + + def serverEndpointsToTerraformConfig[F[_]](ses: Iterable[ServerEndpoint[_, _, _, _, F]])(implicit + options: AwsTerraformOptions + ): AwsTerraformApiGateway = + EndpointsToTerraformConfig(ses.map(_.endpoint).toList) +} + +object AwsTerraformInterpreter extends AwsTerraformInterpreter diff --git a/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformOptions.scala b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformOptions.scala new file mode 100644 index 0000000000..6541eb2190 --- /dev/null +++ b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/AwsTerraformOptions.scala @@ -0,0 +1,45 @@ +package sttp.tapir.serverless.aws.terraform + +import io.circe.Json +import io.circe.literal._ +import sttp.tapir.serverless.aws.terraform.AwsTerraformOptions.lambdaDefaultAssumeRolePolicy + +import scala.concurrent.duration.{DurationInt, FiniteDuration} + +case class AwsTerraformOptions( + awsRegion: String, + functionName: String, + apiGatewayName: String, + apiGatewayDescription: String = "Serverless Application", + apiGatewayStage: String = "$default", + autoDeploy: Boolean = false, + assumeRolePolicy: Json = lambdaDefaultAssumeRolePolicy, + functionSource: FunctionSource, + timeout: FiniteDuration = 10.seconds, + memorySize: Int = 256 +) + +object AwsTerraformOptions { + // grants no policies for lambda function - it cannot access any other AWS services + private val lambdaDefaultAssumeRolePolicy = + json""" + { + "Version": "2012-10-17", + "Statement": [ + { + "Action": "sts:AssumeRole", + "Principal": { + "Service": "lambda.amazonaws.com" + }, + "Effect": "Allow", + "Sid": "" + } + ] + } + """ +} + +sealed trait FunctionSource +case class S3Source(bucket: String, key: String, runtime: String, handler: String) extends FunctionSource +case class ImageSource(imageUri: String) extends FunctionSource +case class CodeSource(fileName: String, runtime: String, handler: String) extends FunctionSource diff --git a/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/EndpointsToTerraformConfig.scala b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/EndpointsToTerraformConfig.scala new file mode 100644 index 0000000000..b8ef23710c --- /dev/null +++ b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/EndpointsToTerraformConfig.scala @@ -0,0 +1,41 @@ +package sttp.tapir.serverless.aws.terraform + +import sttp.model.Method +import sttp.tapir.internal._ +import sttp.tapir.{Endpoint, EndpointInput} + +private[terraform] object EndpointsToTerraformConfig { + def apply(eps: List[Endpoint[_, _, _, _]])(implicit options: AwsTerraformOptions): AwsTerraformApiGateway = { + + val routes: Seq[AwsApiGatewayRoute] = eps.map { endpoint => + val method = endpoint.httpMethod.getOrElse(Method("ANY")) + + val basicInputs = endpoint.input.asVectorOfBasicInputs() + + val pathComponents: Seq[(Either[EndpointInput.FixedPath[_], EndpointInput.PathCapture[_]], String)] = basicInputs + .foldLeft((Seq.empty[(Either[EndpointInput.FixedPath[_], EndpointInput.PathCapture[_]], String)], 0)) { case ((acc, c), input) => + input match { + case fp @ EndpointInput.FixedPath(p, _, _) => (acc :+ Left(fp) -> p, c) + case pc @ EndpointInput.PathCapture(name, _, _) => + (acc :+ Right(pc) -> name.getOrElse(s"param$c"), if (name.isEmpty) c + 1 else c) + case _ => (acc, c) + } + } + ._1 + + val path = pathComponents + .map { + case (Left(_), p) => p + case (Right(_), p) => s"{$p}" + } + .mkString("/") + + val nameComponents = if (pathComponents.isEmpty) Vector("root") else pathComponents.map { case (_, name) => name } + val name = s"${method.method.toLowerCase.capitalize}${nameComponents.map(_.toLowerCase.capitalize).mkString}" + + AwsApiGatewayRoute(name, path, method) + } + + AwsTerraformApiGateway(routes) + } +} diff --git a/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/TerraformResource.scala b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/TerraformResource.scala new file mode 100644 index 0000000000..fcaee84b30 --- /dev/null +++ b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/TerraformResource.scala @@ -0,0 +1,161 @@ +package sttp.tapir.serverless.aws.terraform + +import io.circe.Json +import sttp.tapir.serverless.aws.terraform.TerraformResource.{TapirApiGateway, terraformResource} + +import scala.concurrent.duration.FiniteDuration + +sealed trait TerraformResource { + def json(): Json +} + +private[terraform] object TerraformResource { + val TapirApiGateway = "TapirApiGateway" // main resource name + + def terraformResource[R](`type`: String, name: String, encoded: Json): Json = + Json.fromFields(Seq(`type` -> Json.fromFields(Seq(name -> encoded)))) +} + +case class AwsLambdaFunction(name: String, timeout: FiniteDuration, memorySize: Int, source: FunctionSource) extends TerraformResource { + override def json(): Json = { + val functionSource: Seq[(String, Json)] = source match { + case s3: S3Source => + Seq( + "s3_bucket" -> Json.fromString(s3.bucket), + "s3_key" -> Json.fromString(s3.key), + "runtime" -> Json.fromString(s3.runtime), + "handler" -> Json.fromString(s3.handler) + ) + case image: ImageSource => Seq("image_uri" -> Json.fromString(image.imageUri)) + case code: CodeSource => + Seq( + "filename" -> Json.fromString(code.fileName), + "runtime" -> Json.fromString(code.runtime), + "handler" -> Json.fromString(code.handler) + ) + } + + val lambdaFunction = Json.fromFields( + Seq( + "function_name" -> Json.fromString(name), + "role" -> Json.fromString("${aws_iam_role.lambda_exec.arn}"), + "timeout" -> Json.fromLong(timeout.toSeconds), + "memory_size" -> Json.fromInt(memorySize) + ) ++ functionSource + ) + + terraformResource("aws_lambda_function", "lambda", lambdaFunction) + } +} + +case class AwsIamRole(assumeRolePolicy: String) extends TerraformResource { + override def json(): Json = terraformResource( + "aws_iam_role", + "lambda_exec", + Json.fromFields( + Seq( + "name" -> Json.fromString("lambda_exec_role"), + "assume_role_policy" -> Json.fromString(assumeRolePolicy) + ) + ) + ) +} + +case object AwsLambdaPermission extends TerraformResource { + override def json(): Json = + terraformResource( + "aws_lambda_permission", + "api_gateway_permission", + Json.fromFields( + Seq( + "statement_id" -> Json.fromString("AllowAPIGatewayInvoke"), + "action" -> Json.fromString("lambda:InvokeFunction"), + "function_name" -> Json.fromString(s"$${aws_lambda_function.lambda.function_name}"), + "principal" -> Json.fromString("apigateway.amazonaws.com"), + "source_arn" -> Json.fromString(s"$${aws_apigatewayv2_api.$TapirApiGateway.execution_arn}/*/*") + ) + ) + ) +} + +case class AwsApiGatewayV2Api(name: String, description: String) extends TerraformResource { + override def json(): Json = { + terraformResource( + "aws_apigatewayv2_api", + TapirApiGateway, + Json.fromFields( + Seq( + "name" -> Json.fromString(name), + "description" -> Json.fromString(description), + "protocol_type" -> Json.fromString("HTTP") + ) + ) + ) + } +} + +case class AwsApiGatewayV2Route( + name: String, + routeKey: String, // "METHOD PATH" + integration: String +) extends TerraformResource { + override def json(): Json = terraformResource( + "aws_apigatewayv2_route", + name, + Json.fromFields( + Seq( + "api_id" -> Json.fromString(s"$${aws_apigatewayv2_api.$TapirApiGateway.id}"), + "route_key" -> Json.fromString(routeKey), + "authorization_type" -> Json.fromString("NONE"), + "target" -> Json.fromString(s"integrations/$${aws_apigatewayv2_integration.$integration.id}") + ) + ) + ) +} + +case class AwsApiGatewayV2Integration(name: String) extends TerraformResource { + override def json(): Json = terraformResource( + "aws_apigatewayv2_integration", + name, + Json.fromFields( + Seq( + "api_id" -> Json.fromString(s"$${aws_apigatewayv2_api.$TapirApiGateway.id}"), + "integration_type" -> Json.fromString("AWS_PROXY"), + "integration_method" -> Json.fromString("POST"), + "integration_uri" -> Json.fromString(s"$${aws_lambda_function.lambda.invoke_arn}"), + "payload_format_version" -> Json.fromString("2.0") + ) + ) + ) +} + +case class AwsApiGatewayV2Deployment(dependsOn: Seq[String]) extends TerraformResource { + override def json(): Json = + terraformResource( + "aws_apigatewayv2_deployment", + TapirApiGateway, + Json.fromFields( + Seq( + "depends_on" -> Json.fromValues( + dependsOn.map { d => Json.fromString(s"aws_apigatewayv2_route.$d") } + ), + "api_id" -> Json.fromString(s"$${aws_apigatewayv2_api.$TapirApiGateway.id}") + ) + ) + ) +} + +case class AwsApiGatewayV2Stage(stage: String, autoDeploy: Boolean) extends TerraformResource { + override def json(): Json = + terraformResource( + "aws_apigatewayv2_stage", + TapirApiGateway, + Json.fromFields( + Seq( + "api_id" -> Json.fromString(s"$${aws_apigatewayv2_api.$TapirApiGateway.id}"), + "name" -> Json.fromString(stage), + "auto_deploy" -> Json.fromBoolean(autoDeploy) + ) + ) + ) +} diff --git a/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/model.scala b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/model.scala new file mode 100644 index 0000000000..e118a545ca --- /dev/null +++ b/serverless/aws/terraform/src/main/scala/sttp/tapir/serverless/aws/terraform/model.scala @@ -0,0 +1,15 @@ +package sttp.tapir.serverless.aws.terraform +import io.circe.Printer +import io.circe.syntax._ +import sttp.model.Method +import sttp.tapir.serverless.aws.terraform.AwsTerraformEncoders._ + +case class AwsTerraformApiGateway(routes: Seq[AwsApiGatewayRoute]) { + def toJson()(implicit options: AwsTerraformOptions): String = Printer.spaces2.print(this.asJson) +} + +case class AwsApiGatewayRoute( + name: String, + path: String, + httpMethod: Method +) diff --git a/serverless/aws/terraform/src/test/resources/endpoint_with_params.json b/serverless/aws/terraform/src/test/resources/endpoint_with_params.json new file mode 100644 index 0000000000..80edc5d20d --- /dev/null +++ b/serverless/aws/terraform/src/test/resources/endpoint_with_params.json @@ -0,0 +1,105 @@ +{ + "terraform" : { + "required_providers" : { + "aws" : { + "source" : "hashicorp/aws" + } + } + }, + "provider" : { + "aws" : [ + { + "region" : "eu-central-1" + } + ] + }, + "resource" : [ + { + "aws_lambda_function" : { + "lambda" : { + "function_name" : "Tapir", + "role" : "${aws_iam_role.lambda_exec.arn}", + "timeout" : 10, + "memory_size" : 256, + "s3_bucket" : "bucket", + "s3_key" : "key", + "runtime" : "java11", + "handler" : "Handler::handleRequest" + } + } + }, + { + "aws_iam_role" : { + "lambda_exec" : { + "name" : "lambda_exec_role", + "assume_role_policy" : "{\"Version\":\"2012-10-17\",\"Statement\":[{\"Action\":\"sts:AssumeRole\",\"Principal\":{\"Service\":\"lambda.amazonaws.com\"},\"Effect\":\"Allow\",\"Sid\":\"\"}]}" + } + } + }, + { + "aws_lambda_permission" : { + "api_gateway_permission" : { + "statement_id" : "AllowAPIGatewayInvoke", + "action" : "lambda:InvokeFunction", + "function_name" : "${aws_lambda_function.lambda.function_name}", + "principal" : "apigateway.amazonaws.com", + "source_arn" : "${aws_apigatewayv2_api.TapirApiGateway.execution_arn}/*/*" + } + } + }, + { + "aws_apigatewayv2_api" : { + "TapirApiGateway" : { + "name" : "TapirApiGateway", + "description" : "Serverless Application", + "protocol_type" : "HTTP" + } + } + }, + { + "aws_apigatewayv2_deployment" : { + "TapirApiGateway" : { + "depends_on" : [ + "aws_apigatewayv2_route.GetAccountsIdHistory" + ], + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}" + } + } + }, + { + "aws_apigatewayv2_stage" : { + "TapirApiGateway" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "name" : "$default", + "auto_deploy" : false + } + } + }, + { + "aws_apigatewayv2_integration" : { + "GetAccountsIdHistory" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "GetAccountsIdHistory" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "GET /accounts/{id}/history", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.GetAccountsIdHistory.id}" + } + } + } + ], + "output" : { + "base_url" : { + "value" : "${aws_apigatewayv2_api.TapirApiGateway.api_endpoint}" + } + } +} \ No newline at end of file diff --git a/serverless/aws/terraform/src/test/resources/endpoints_common_paths.json b/serverless/aws/terraform/src/test/resources/endpoints_common_paths.json new file mode 100644 index 0000000000..19ee4ae1ce --- /dev/null +++ b/serverless/aws/terraform/src/test/resources/endpoints_common_paths.json @@ -0,0 +1,171 @@ +{ + "terraform" : { + "required_providers" : { + "aws" : { + "source" : "hashicorp/aws" + } + } + }, + "provider" : { + "aws" : [ + { + "region" : "eu-central-1" + } + ] + }, + "resource" : [ + { + "aws_lambda_function" : { + "lambda" : { + "function_name" : "Tapir", + "role" : "${aws_iam_role.lambda_exec.arn}", + "timeout" : 10, + "memory_size" : 256, + "s3_bucket" : "bucket", + "s3_key" : "key", + "runtime" : "java11", + "handler" : "Handler::handleRequest" + } + } + }, + { + "aws_iam_role" : { + "lambda_exec" : { + "name" : "lambda_exec_role", + "assume_role_policy" : "{\"Version\":\"2012-10-17\",\"Statement\":[{\"Action\":\"sts:AssumeRole\",\"Principal\":{\"Service\":\"lambda.amazonaws.com\"},\"Effect\":\"Allow\",\"Sid\":\"\"}]}" + } + } + }, + { + "aws_lambda_permission" : { + "api_gateway_permission" : { + "statement_id" : "AllowAPIGatewayInvoke", + "action" : "lambda:InvokeFunction", + "function_name" : "${aws_lambda_function.lambda.function_name}", + "principal" : "apigateway.amazonaws.com", + "source_arn" : "${aws_apigatewayv2_api.TapirApiGateway.execution_arn}/*/*" + } + } + }, + { + "aws_apigatewayv2_api" : { + "TapirApiGateway" : { + "name" : "TapirApiGateway", + "description" : "Serverless Application", + "protocol_type" : "HTTP" + } + } + }, + { + "aws_apigatewayv2_deployment" : { + "TapirApiGateway" : { + "depends_on" : [ + "aws_apigatewayv2_route.GetAccountsId", + "aws_apigatewayv2_route.PostAccounts", + "aws_apigatewayv2_route.GetAccountsIdTransactions", + "aws_apigatewayv2_route.PostAccountsIdTransactions" + ], + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}" + } + } + }, + { + "aws_apigatewayv2_stage" : { + "TapirApiGateway" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "name" : "$default", + "auto_deploy" : false + } + } + }, + { + "aws_apigatewayv2_integration" : { + "GetAccountsId" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "GetAccountsId" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "GET /accounts/{id}", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.GetAccountsId.id}" + } + } + }, + { + "aws_apigatewayv2_integration" : { + "PostAccounts" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "PostAccounts" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "POST /accounts", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.PostAccounts.id}" + } + } + }, + { + "aws_apigatewayv2_integration" : { + "GetAccountsIdTransactions" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "GetAccountsIdTransactions" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "GET /accounts/{id}/transactions", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.GetAccountsIdTransactions.id}" + } + } + }, + { + "aws_apigatewayv2_integration" : { + "PostAccountsIdTransactions" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "PostAccountsIdTransactions" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "POST /accounts/{id}/transactions", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.PostAccountsIdTransactions.id}" + } + } + } + ], + "output" : { + "base_url" : { + "value" : "${aws_apigatewayv2_api.TapirApiGateway.api_endpoint}" + } + } +} \ No newline at end of file diff --git a/serverless/aws/terraform/src/test/resources/root_endpoint.json b/serverless/aws/terraform/src/test/resources/root_endpoint.json new file mode 100644 index 0000000000..f65843e805 --- /dev/null +++ b/serverless/aws/terraform/src/test/resources/root_endpoint.json @@ -0,0 +1,105 @@ +{ + "terraform" : { + "required_providers" : { + "aws" : { + "source" : "hashicorp/aws" + } + } + }, + "provider" : { + "aws" : [ + { + "region" : "eu-central-1" + } + ] + }, + "resource" : [ + { + "aws_lambda_function" : { + "lambda" : { + "function_name" : "Tapir", + "role" : "${aws_iam_role.lambda_exec.arn}", + "timeout" : 10, + "memory_size" : 256, + "s3_bucket" : "bucket", + "s3_key" : "key", + "runtime" : "java11", + "handler" : "Handler::handleRequest" + } + } + }, + { + "aws_iam_role" : { + "lambda_exec" : { + "name" : "lambda_exec_role", + "assume_role_policy" : "{\"Version\":\"2012-10-17\",\"Statement\":[{\"Action\":\"sts:AssumeRole\",\"Principal\":{\"Service\":\"lambda.amazonaws.com\"},\"Effect\":\"Allow\",\"Sid\":\"\"}]}" + } + } + }, + { + "aws_lambda_permission" : { + "api_gateway_permission" : { + "statement_id" : "AllowAPIGatewayInvoke", + "action" : "lambda:InvokeFunction", + "function_name" : "${aws_lambda_function.lambda.function_name}", + "principal" : "apigateway.amazonaws.com", + "source_arn" : "${aws_apigatewayv2_api.TapirApiGateway.execution_arn}/*/*" + } + } + }, + { + "aws_apigatewayv2_api" : { + "TapirApiGateway" : { + "name" : "TapirApiGateway", + "description" : "Serverless Application", + "protocol_type" : "HTTP" + } + } + }, + { + "aws_apigatewayv2_deployment" : { + "TapirApiGateway" : { + "depends_on" : [ + "aws_apigatewayv2_route.AnyRoot" + ], + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}" + } + } + }, + { + "aws_apigatewayv2_stage" : { + "TapirApiGateway" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "name" : "$default", + "auto_deploy" : false + } + } + }, + { + "aws_apigatewayv2_integration" : { + "AnyRoot" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "AnyRoot" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "ANY /", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.AnyRoot.id}" + } + } + } + ], + "output" : { + "base_url" : { + "value" : "${aws_apigatewayv2_api.TapirApiGateway.api_endpoint}" + } + } +} \ No newline at end of file diff --git a/serverless/aws/terraform/src/test/resources/simple_endpoint.json b/serverless/aws/terraform/src/test/resources/simple_endpoint.json new file mode 100644 index 0000000000..8f8515fd1a --- /dev/null +++ b/serverless/aws/terraform/src/test/resources/simple_endpoint.json @@ -0,0 +1,105 @@ +{ + "terraform" : { + "required_providers" : { + "aws" : { + "source" : "hashicorp/aws" + } + } + }, + "provider" : { + "aws" : [ + { + "region" : "eu-central-1" + } + ] + }, + "resource" : [ + { + "aws_lambda_function" : { + "lambda" : { + "function_name" : "Tapir", + "role" : "${aws_iam_role.lambda_exec.arn}", + "timeout" : 10, + "memory_size" : 256, + "s3_bucket" : "bucket", + "s3_key" : "key", + "runtime" : "java11", + "handler" : "Handler::handleRequest" + } + } + }, + { + "aws_iam_role" : { + "lambda_exec" : { + "name" : "lambda_exec_role", + "assume_role_policy" : "{\"Version\":\"2012-10-17\",\"Statement\":[{\"Action\":\"sts:AssumeRole\",\"Principal\":{\"Service\":\"lambda.amazonaws.com\"},\"Effect\":\"Allow\",\"Sid\":\"\"}]}" + } + } + }, + { + "aws_lambda_permission" : { + "api_gateway_permission" : { + "statement_id" : "AllowAPIGatewayInvoke", + "action" : "lambda:InvokeFunction", + "function_name" : "${aws_lambda_function.lambda.function_name}", + "principal" : "apigateway.amazonaws.com", + "source_arn" : "${aws_apigatewayv2_api.TapirApiGateway.execution_arn}/*/*" + } + } + }, + { + "aws_apigatewayv2_api" : { + "TapirApiGateway" : { + "name" : "TapirApiGateway", + "description" : "Serverless Application", + "protocol_type" : "HTTP" + } + } + }, + { + "aws_apigatewayv2_deployment" : { + "TapirApiGateway" : { + "depends_on" : [ + "aws_apigatewayv2_route.GetHelloWorld" + ], + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}" + } + } + }, + { + "aws_apigatewayv2_stage" : { + "TapirApiGateway" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "name" : "$default", + "auto_deploy" : false + } + } + }, + { + "aws_apigatewayv2_integration" : { + "GetHelloWorld" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "integration_type" : "AWS_PROXY", + "integration_method" : "POST", + "integration_uri" : "${aws_lambda_function.lambda.invoke_arn}", + "payload_format_version" : "2.0" + } + } + }, + { + "aws_apigatewayv2_route" : { + "GetHelloWorld" : { + "api_id" : "${aws_apigatewayv2_api.TapirApiGateway.id}", + "route_key" : "GET /hello/world", + "authorization_type" : "NONE", + "target" : "integrations/${aws_apigatewayv2_integration.GetHelloWorld.id}" + } + } + } + ], + "output" : { + "base_url" : { + "value" : "${aws_apigatewayv2_api.TapirApiGateway.api_endpoint}" + } + } +} \ No newline at end of file diff --git a/serverless/aws/terraform/src/test/scala/sttp/tapir/serverless/aws/terraform/VerifyTerraformTemplateTest.scala b/serverless/aws/terraform/src/test/scala/sttp/tapir/serverless/aws/terraform/VerifyTerraformTemplateTest.scala new file mode 100644 index 0000000000..73c0983f42 --- /dev/null +++ b/serverless/aws/terraform/src/test/scala/sttp/tapir/serverless/aws/terraform/VerifyTerraformTemplateTest.scala @@ -0,0 +1,76 @@ +package sttp.tapir.serverless.aws.terraform + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import sttp.tapir._ +import sttp.tapir.serverless.aws.terraform.VerifyTerraformTemplateTest.{load, noIndentation} + +import scala.io.Source + +class VerifyTerraformTemplateTest extends AnyFunSuite with Matchers { + + private implicit val options: AwsTerraformOptions = AwsTerraformOptions( + awsRegion = "eu-central-1", + functionName = "Tapir", + apiGatewayName = "TapirApiGateway", + functionSource = S3Source("bucket", "key", "java11", "Handler::handleRequest") + ) + + test("should handle empty endpoint list") { + AwsTerraformInterpreter.toTerraformConfig(List.empty).toJson() + } + + test("should match expected json root endpoint") { + val ep = endpoint + + val expectedJson = load("root_endpoint.json") + val actualJson = AwsTerraformInterpreter.toTerraformConfig(List(ep)).toJson() + + expectedJson shouldBe noIndentation(actualJson) + } + + test("should match expected json simple endpoint") { + val ep = endpoint.get.in("hello" / "world") + + val expectedJson = load("simple_endpoint.json") + val actualJson = AwsTerraformInterpreter.toTerraformConfig(List(ep)).toJson() + + expectedJson shouldBe noIndentation(actualJson) + } + + test("should match expected json endpoint with params") { + val ep = endpoint.get + .in("accounts" / path[String]("id") / "history") + .in(query[Int]("limit")) + .in(header[String]("X-Account")) + .in(header[String]("X-Secret")) + + val expectedJson = load("endpoint_with_params.json") + val actualJson = AwsTerraformInterpreter.toTerraformConfig(List(ep)).toJson() + + expectedJson shouldBe noIndentation(actualJson) + } + + test("should match expected json endpoints with common path") { + val eps = List( + endpoint.get.in("accounts" / path[String]("id")), + endpoint.post.in("accounts"), + endpoint.get.in("accounts" / path[String]("id") / "transactions"), + endpoint.post.in("accounts" / path[String]("id") / "transactions") + ) + + val expectedJson = load("endpoints_common_paths.json") + val actualJson = AwsTerraformInterpreter.toTerraformConfig(eps).toJson() + + expectedJson shouldBe noIndentation(actualJson) + } +} + +object VerifyTerraformTemplateTest { + + def load(fileName: String): String = { + noIndentation(Source.fromInputStream(getClass.getResourceAsStream(s"/$fileName")).getLines().mkString("\n")) + } + + def noIndentation(s: String): String = s.replaceAll("[ \t]", "").trim +}