Skip to content

Commit

Permalink
#46: initial implementation of auth inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Feb 28, 2019
1 parent bb945db commit 7c7712f
Show file tree
Hide file tree
Showing 20 changed files with 338 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
val headers = paramsAsArgs.paramAt(params, paramIndex).asInstanceOf[Seq[(String, String)]]
val req2 = headers.foldLeft(req) { case (r, (k, v)) => r.header(k, v) }
setInputParams(tail, params, paramsAsArgs, paramIndex + 1, uri, req2)
case (a: EndpointInput.Auth[_]) +: tail =>
setInputParams(a.input +: tail, params, paramsAsArgs, paramIndex, uri, req)
case EndpointInput.Mapped(wrapped, _, g, wrappedParamsAsArgs) +: tail =>
handleMapped(wrapped, g, wrappedParamsAsArgs, tail)
case EndpointIO.Mapped(wrapped, _, g, wrappedParamsAsArgs) +: tail =>
Expand Down
14 changes: 13 additions & 1 deletion client/tests/src/main/scala/tapir/client/tests/ClientTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import tapir.tests._
import tapir.typelevel.ParamsAsArgs
import TestUtil._
import org.http4s.multipart
import tapir.model.MultiQueryParams
import tapir.model.{MultiQueryParams, UsernamePassword}

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}
Expand Down Expand Up @@ -63,6 +63,13 @@ trait ClientTests[S] extends FunSuite with Matchers with BeforeAndAfterAll {
testClient(in_paths_out_string, Seq("fruit", "apple", "amount", "50"), Right("apple 50 None"))
testClient(in_query_list_out_header_list, List("plum", "watermelon", "apple"), Right(List("apple", "watermelon", "plum")))
testClient(in_simple_multipart_out_string, FruitAmount("melon", 10), Right("melon=10"))
// TODO: test root path
testClient(in_auth_apikey_header_out_string, "1234", Right("Authorization=None; X-Api-Key=Some(1234); Query=None"))
testClient(in_auth_apikey_query_out_string, "1234", Right("Authorization=None; X-Api-Key=None; Query=Some(1234)"))
testClient(in_auth_basic_out_string,
UsernamePassword("teddy", Some("bear")),
Right("Authorization=Some(Basic dGVkZHk6YmVhcg==); X-Api-Key=None; Query=None"))
testClient(in_auth_bearer_out_string, "1234", Right("Authorization=Some(Bearer 1234); X-Api-Key=None; Query=None"))

//

Expand Down Expand Up @@ -91,6 +98,7 @@ trait ClientTests[S] extends FunSuite with Matchers with BeforeAndAfterAll {
private object fruitParam extends QueryParamDecoderMatcher[String]("fruit")
private object amountOptParam extends OptionalQueryParamDecoderMatcher[String]("amount")
private object colorOptParam extends OptionalQueryParamDecoderMatcher[String]("color")
private object apiKeyOptParam extends OptionalQueryParamDecoderMatcher[String]("api-key")

private val service = HttpRoutes.of[IO] {
case GET -> Root :? fruitParam(f) +& amountOptParam(amount) =>
Expand All @@ -117,6 +125,10 @@ trait ClientTests[S] extends FunSuite with Matchers with BeforeAndAfterAll {
case None => Ok()
case Some(h) => Ok("Role: " + h.value)
}
case r @ GET -> Root / "auth" :? apiKeyOptParam(ak) =>
val authHeader = r.headers.get(CaseInsensitiveString("Authorization")).map(_.value)
val xApiKey = r.headers.get(CaseInsensitiveString("X-Api-Key")).map(_.value)
Ok(s"Authorization=$authHeader; X-Api-Key=$xApiKey; Query=$ak")
}

private val app: HttpApp[IO] = Router("/" -> service).orNotFound
Expand Down
38 changes: 33 additions & 5 deletions core/src/main/scala/tapir/EndpointIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ sealed trait EndpointInput[I] {
case m: EndpointIO.Multiple[_] => m.ios
}

private[tapir] def asVectorOfBasic: Vector[EndpointInput.Basic[_]] = this match {
private[tapir] def asVectorOfBasic(includeAuth: Boolean = true): Vector[EndpointInput.Basic[_]] = this match {
case b: EndpointInput.Basic[_] => Vector(b)
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.asVectorOfBasic
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.asVectorOfBasic
case EndpointInput.Multiple(inputs) => inputs.flatMap(_.asVectorOfBasic)
case EndpointIO.Multiple(ios) => ios.flatMap(_.asVectorOfBasic)
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.asVectorOfBasic(includeAuth)
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.asVectorOfBasic(includeAuth)
case EndpointInput.Multiple(inputs) => inputs.flatMap(_.asVectorOfBasic(includeAuth))
case EndpointIO.Multiple(ios) => ios.flatMap(_.asVectorOfBasic(includeAuth))
case a: EndpointInput.Auth[_] => if (includeAuth) a.input.asVectorOfBasic(includeAuth) else Vector.empty
}

private[tapir] def bodyType: Option[RawValueType[_]] = this match {
Expand All @@ -40,6 +41,7 @@ sealed trait EndpointInput[I] {
case EndpointIO.Multiple(inputs) => inputs.flatMap(_.bodyType).headOption
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.bodyType
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.bodyType
case a: EndpointInput.Auth[_] => a.input.bodyType
case _ => None
}

Expand All @@ -51,6 +53,17 @@ sealed trait EndpointInput[I] {
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.method
case _ => None
}

private[tapir] def auths: Vector[EndpointInput.Auth[_]] = this match {
case a: EndpointInput.Auth[_] => Vector(a)
case EndpointInput.Multiple(inputs) => inputs.flatMap(_.auths)
case EndpointIO.Multiple(inputs) => inputs.flatMap(_.auths)
case EndpointInput.Mapped(wrapped, _, _, _) => wrapped.auths
case EndpointIO.Mapped(wrapped, _, _, _) => wrapped.auths
case _ => Vector.empty
}

// TODO: add generic traverse
}

object EndpointInput {
Expand Down Expand Up @@ -100,6 +113,21 @@ object EndpointInput {

//

trait Auth[T] extends EndpointInput.Single[T] {
def input: EndpointInput.Single[T]
}

object Auth {
case class ApiKey[T](input: EndpointInput.Single[T]) extends Auth[T] {
def show = s"auth(api key, via ${input.show})"
}
case class Http[T](scheme: String, input: EndpointInput.Single[T]) extends Auth[T] {
def show = s"auth($scheme http, via ${input.show})"
}
}

//

case class Mapped[I, T](wrapped: EndpointInput[I], f: I => T, g: T => I, paramsAsArgs: ParamsAsArgs[I]) extends Single[T] {
override def show: String = s"map(${wrapped.show})"
}
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ trait Tapir {
def streamBody[S](schema: Schema, mediaType: MediaType): StreamingEndpointIO.Body[S, mediaType.type] =
StreamingEndpointIO.Body(schema, mediaType, EndpointIO.Info.empty)

def auth: TapirAuth.type = TapirAuth

def schemaFor[T: SchemaFor]: Schema = implicitly[SchemaFor[T]].schema

val endpoint: Endpoint[Unit, Unit, Unit, Nothing] =
Expand Down
47 changes: 47 additions & 0 deletions core/src/main/scala/tapir/TapirAuth.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package tapir
import java.util.Base64

import tapir.Codec.PlainCodec
import tapir.model.UsernamePassword

object TapirAuth {
private val BasicAuthType = "Basic"
private val BearerAuthType = "Bearer"

def apiKey[T](input: EndpointInput.Single[T]): EndpointInput.Auth.ApiKey[T] = EndpointInput.Auth.ApiKey[T](input)
val basic: EndpointInput.Auth.Http[UsernamePassword] = httpAuth(BasicAuthType, usernamePasswordCodec(credentialsCodec(BasicAuthType)))
val bearer: EndpointInput.Auth.Http[String] = httpAuth(BearerAuthType, credentialsCodec(BearerAuthType))

private def httpAuth[T](authType: String, codec: PlainCodec[T]): EndpointInput.Auth.Http[T] =
EndpointInput.Auth.Http(authType, header[T]("Authorization")(CodecForMany.fromCodec(codec)))

private def usernamePasswordCodec(baseCodec: PlainCodec[String]): PlainCodec[UsernamePassword] = {
def decode(s: String): DecodeResult[UsernamePassword] =
try {
val s2 = new String(Base64.getDecoder.decode(s))
val up = s2.split(":", 2) match {
case Array() => UsernamePassword("", None)
case Array(u) => UsernamePassword(u, None)
case Array(u, "") => UsernamePassword(u, None)
case Array(u, p) => UsernamePassword(u, Some(p))
}
DecodeResult.Value(up)
} catch {
case e: Exception => DecodeResult.Error(s, e)
}

def encode(up: UsernamePassword): String =
Base64.getEncoder.encodeToString(s"${up.username}:${up.password.getOrElse("")}".getBytes("UTF-8"))

baseCodec.mapDecode(decode)(encode)
}

private def credentialsCodec(authType: String): PlainCodec[String] = {
val authTypeWithSpace = authType + " "
val prefixLength = authTypeWithSpace.length
def removeAuthType(v: String): DecodeResult[String] =
if (v.startsWith(authType)) DecodeResult.Value(v.substring(prefixLength))
else DecodeResult.Error(v, new IllegalArgumentException(s"The given value doesn't start with $authType"))
Codec.stringPlainCodecUtf8.mapDecode(removeAuthType)(v => s"$authType $v")
}
}
8 changes: 4 additions & 4 deletions core/src/main/scala/tapir/internal/server/DecodeInputs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import scala.annotation.tailrec

trait DecodeInputsResult
object DecodeInputsResult {
case class Values(values: Map[EndpointInput.Single[_], Any], bodyInput: Option[EndpointIO.Body[_, _, _]]) extends DecodeInputsResult {
def value(i: EndpointInput.Single[_], v: Any): Values = copy(values = values + (i -> v))
case class Values(values: Map[EndpointInput.Basic[_], Any], bodyInput: Option[EndpointIO.Body[_, _, _]]) extends DecodeInputsResult {
def value(i: EndpointInput.Basic[_], v: Any): Values = copy(values = values + (i -> v))
}
case class Failure(input: EndpointInput.Single[_], failure: DecodeFailure) extends DecodeInputsResult
case class Failure(input: EndpointInput.Basic[_], failure: DecodeFailure) extends DecodeInputsResult
}

trait DecodeInputsContext {
Expand Down Expand Up @@ -40,7 +40,7 @@ object DecodeInputs {
*/
def apply(input: EndpointInput[_], ctx: DecodeInputsContext): DecodeInputsResult = {
// the first decoding failure is returned. We decode in the following order: method, path, query, headers, body
val inputs = input.asVectorOfBasic.sortBy {
val inputs = input.asVectorOfBasic().sortBy {
case _: EndpointInput.RequestMethod => 0
case _: EndpointInput.PathSegment => 1
case _: EndpointInput.PathCapture[_] => 1
Expand Down
10 changes: 6 additions & 4 deletions core/src/main/scala/tapir/internal/server/InputValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ object InputValues {
* Returns the values of the inputs in the order specified by `input`, and mapped if necessary using defined mapping
* functions.
*/
def apply(input: EndpointInput[_], values: Map[EndpointInput.Single[_], Any]): List[Any] = apply(input.asVectorOfSingle, values)
def apply(input: EndpointInput[_], values: Map[EndpointInput.Basic[_], Any]): List[Any] = apply(input.asVectorOfSingle, values)

private def apply(inputs: Vector[EndpointInput.Single[_]], values: Map[EndpointInput.Single[_], Any]): List[Any] = {
private def apply(inputs: Vector[EndpointInput.Single[_]], values: Map[EndpointInput.Basic[_], Any]): List[Any] = {
inputs match {
case Vector() => Nil
case (_: EndpointInput.RequestMethod) +: inputsTail =>
Expand All @@ -21,15 +21,17 @@ object InputValues {
handleMapped(wrapped, f, inputsTail, values)
case EndpointIO.Mapped(wrapped, f, _, _) +: inputsTail =>
handleMapped(wrapped, f, inputsTail, values)
case (input: EndpointInput.Single[_]) +: inputsTail =>
case (auth: EndpointInput.Auth[_]) +: inputsTail =>
apply(auth.input +: inputsTail, values)
case (input: EndpointInput.Basic[_]) +: inputsTail =>
values(input) :: apply(inputsTail, values)
}
}

private def handleMapped[II, T](wrapped: EndpointInput[II],
f: II => T,
inputsTail: Vector[EndpointInput.Single[_]],
values: Map[EndpointInput.Single[_], Any]): List[Any] = {
values: Map[EndpointInput.Basic[_], Any]): List[Any] = {
val wrappedValue = apply(wrapped.asVectorOfSingle, values)
f.asInstanceOf[Any => Any].apply(SeqToParams(wrappedValue)) :: apply(inputsTail, values)
}
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/tapir/model/UsernamePassword.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package tapir.model

case class UsernamePassword(username: String, password: Option[String])
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ object EndpointToOpenAPIDocs {
def toOpenAPI(api: Info, es: Iterable[Endpoint[_, _, _, _]], options: OpenAPIDocsOptions): OpenAPI = {
val es2 = es.map(nameAllPathCapturesInEndpoint)
val objectSchemas = ObjectSchemasForEndpoints(es2)
val pathCreator = new EndpointToOpenApiPaths(objectSchemas, options)
val componentsCreator = new EndpointToOpenApiComponents(objectSchemas)
val securitySchemes = SecuritySchemesForEndpoints(es2)
val pathCreator = new EndpointToOpenApiPaths(objectSchemas, securitySchemes, options)
val componentsCreator = new EndpointToOpenApiComponents(objectSchemas, securitySchemes)

val base = apiToOpenApi(api, componentsCreator)

Expand All @@ -24,7 +25,8 @@ object EndpointToOpenAPIDocs {
info = info,
servers = List.empty,
paths = Map.empty,
components = componentsCreator.components
components = componentsCreator.components,
security = List.empty
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package tapir.docs.openapi
import tapir.docs.openapi.schema.ObjectSchemas
import tapir.openapi.Components

private[openapi] class EndpointToOpenApiComponents(objectSchemas: ObjectSchemas) {
private[openapi] class EndpointToOpenApiComponents(objectSchemas: ObjectSchemas, securitySchemes: SecuritySchemes) {
def components: Option[Components] = {
val keyToSchema = objectSchemas.keyToOSchema
if (keyToSchema.nonEmpty) Some(Components(keyToSchema))
if (keyToSchema.nonEmpty || securitySchemes.nonEmpty) Some(Components(keyToSchema, securitySchemes.values.toMap.mapValues(Right(_))))
else None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import tapir.openapi.OpenAPI.ReferenceOr
import tapir.openapi.{MediaType => OMediaType, _}
import tapir.{EndpointInput, MediaType => SMediaType, Schema => SSchema, _}

private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, options: OpenAPIDocsOptions) {
private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, securitySchemes: SecuritySchemes, options: OpenAPIDocsOptions) {

def pathItem(e: Endpoint[_, _, _, _]): (String, PathItem) = {
import model.Method._

val inputs = e.input.asVectorOfBasic
val inputs = e.input.asVectorOfBasic(includeAuth = false)
val pathComponents = namedPathComponents(inputs)
val method = e.input.method.getOrElse(Method.GET)

Expand All @@ -23,7 +23,7 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti
case Right(p) => p
}

val operation = Some(endpointToOperation(defaultId, e))
val operation = Some(endpointToOperation(defaultId, e, inputs))
val pathItem = PathItem(
None,
None,
Expand All @@ -42,11 +42,15 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti
("/" + pathComponentForPath.mkString("/"), pathItem)
}

private def endpointToOperation(defaultId: String, e: Endpoint[_, _, _, _]): Operation = {
val inputs = e.input.asVectorOfBasic
private def endpointToOperation(defaultId: String, e: Endpoint[_, _, _, _], inputs: Vector[EndpointInput.Basic[_]]): Operation = {
val parameters = operationParameters(inputs)
val body: Vector[ReferenceOr[RequestBody]] = operationInputBody(inputs)
val responses: Map[ResponsesKey, ReferenceOr[Response]] = operationResponse(e)

val authNames = e.input.auths.flatMap(auth => securitySchemes.get(auth).map(_._1))
// for now, all auths have empty scope
val securityRequirement = authNames.map(_ -> Vector.empty).toMap

Operation(
e.info.tags.toList,
e.info.summary,
Expand All @@ -56,6 +60,7 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti
body.headOption,
responses,
None,
if (securityRequirement.isEmpty) List.empty else List(securityRequirement),
List.empty
)
}
Expand Down Expand Up @@ -104,7 +109,7 @@ private[openapi] class EndpointToOpenApiPaths(objectSchemas: ObjectSchemas, opti
}

private def outputToResponse(io: EndpointIO[_]): Option[Response] = {
val ios = io.asVectorOfBasic
val ios = io.asVectorOfBasic()

val headers = ios.collect {
case EndpointIO.Header(name, codec, info) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package tapir.docs.openapi
import tapir.openapi.SecurityScheme
import tapir.{Endpoint, EndpointIO, EndpointInput}

import scala.annotation.tailrec

private[openapi] object SecuritySchemesForEndpoints {

def apply(es: Iterable[Endpoint[_, _, _, _]]): SecuritySchemes = {
val auths = es.flatMap(e => e.input.auths)
val authSecuritySchemes = auths.map(a => (a, authToSecurityScheme(a)))
val securitySchemes = authSecuritySchemes.map(_._2).toSet
val namedSecuritySchemes = nameSecuritySchemes(securitySchemes.toVector, Set(), Map())

authSecuritySchemes.map { case (a, s) => a -> ((namedSecuritySchemes(s), s)) }.toMap
}

@tailrec
private def nameSecuritySchemes(schemes: Vector[SecurityScheme],
takenNames: Set[SchemeName],
acc: Map[SecurityScheme, SchemeName]): Map[SecurityScheme, SchemeName] = {
schemes match {
case Vector() => acc
case scheme +: tail =>
val baseName = scheme.`type` + "Auth"
val name = uniqueName(baseName, !takenNames.contains(_))
nameSecuritySchemes(tail, takenNames + name, acc + (scheme -> name))
}
}

private def authToSecurityScheme(a: EndpointInput.Auth[_]): SecurityScheme = a match {
case EndpointInput.Auth.ApiKey(input) =>
val (name, in) = apiKeyInputNameAndIn(input.asVectorOfBasic())
SecurityScheme("apiKey", None, Some(name), Some(in), None, None, None, None)
case EndpointInput.Auth.Http(scheme, _) =>
SecurityScheme("http", None, None, None, Some(scheme.toLowerCase()), None, None, None)
}

private def apiKeyInputNameAndIn(input: Vector[EndpointInput.Basic[_]]) = input match {
case Vector(EndpointIO.Header(name, _, _)) => (name, "header")
case Vector(EndpointInput.Query(name, _, _)) => (name, "query")
// TODO cookie
case _ => throw new IllegalArgumentException(s"Api key authentication can only be read from headers, queries or cookies, not: $input")
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
package tapir.docs

package object openapi extends OpenAPIDocs
import tapir.EndpointInput
import tapir.openapi.SecurityScheme

package object openapi extends OpenAPIDocs {
private[openapi] type SchemeName = String
private[openapi] type SecuritySchemes = Map[EndpointInput.Auth[_], (SchemeName, SecurityScheme)]

private[openapi] def uniqueName(base: String, isUnique: String => Boolean): String = {
var i = 0
var result = base
while (!isUnique(result)) {
i += 1
result = base + i
}
result
}
}

0 comments on commit 7c7712f

Please sign in to comment.