From f84a96a19d2c7fb84e46fcc46deb1011cdda6262 Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Sun, 19 Oct 2014 14:48:21 -0400 Subject: [PATCH 1/9] Deprecate `QParams` validation so that we can later remove it. This is used in the oauth and mac modules. It might be better to port those over to Directives in the same PR, as they generate a lot of deprecation warnings. Let's discuss before merging. --- library/src/main/scala/request/params.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/library/src/main/scala/request/params.scala b/library/src/main/scala/request/params.scala index 3aa6779ee..2c58a8f2c 100644 --- a/library/src/main/scala/request/params.scala +++ b/library/src/main/scala/request/params.scala @@ -88,6 +88,7 @@ object QueryParams { /** Fined-grained error reporting for arbitrarily many failing parameters. * Import QParams._ to use; see ParamsSpec for examples. */ +@deprecated("This validation scheme is deprecated, use Directives instead", since="0.8.3") object QParams { type Log[E] = List[Fail[E]] type QueryFn[E,A] = (Params.Map, Option[String], Log[E]) => From c4323e36ee76f40a86bb54d26e6485173174b09d Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Sun, 19 Oct 2014 15:42:03 -0400 Subject: [PATCH 2/9] Remove use of deprecated `QParams` in `mac`. Since the error messages generated were just being thrown away, we can instead just `flatMap` through options. --- mac/src/main/scala/mac.scala | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/mac/src/main/scala/mac.scala b/mac/src/main/scala/mac.scala index 1e20a10e6..53409e1eb 100644 --- a/mac/src/main/scala/mac.scala +++ b/mac/src/main/scala/mac.scala @@ -23,7 +23,6 @@ object MacAuthorization { val MacKey = "mac" object MacHeader { - import QParams._ val NonceFormat = """^(\d+)[:](\S+)$""".r val KeyVal = """(\w+)="([\w|=|:|\/|.|%|-|+]+)" """.trim.r val keys = Id :: Nonce :: BodyHash :: Ext :: MacKey :: Nil @@ -31,23 +30,19 @@ object MacAuthorization { def unapply(hvals: List[String]) = hvals match { case x :: xs if x startsWith headerSpace => - val map = Map(hvals map { _.replace(headerSpace, "") } flatMap { - case KeyVal(k, v) if(keys.contains(k)) => Seq((k -> Seq(v))) + val headers = Map(hvals map { _.replace(headerSpace, "") } flatMap { + case KeyVal(k, v) if(keys.contains(k)) => Seq((k -> v)) case e => Nil }: _*) - val expect = for { - id <- lookup(Id) is nonempty("id is empty") is required("id is required") - nonce <- lookup(Nonce) is nonempty("nonce is empty") is required("nonce is required") is - pred({NonceFormat.findFirstIn(_).isDefined}, _ + " is an invalid format") - bodyhash <- lookup(BodyHash) is optional[String, String] - ext <- lookup(Ext) is optional[String, String] - mac <- lookup(MacKey) is nonempty("mac is nempty") is required("mac is required") + for { + id <- headers.get(Id) if !id.isEmpty + nonce <- headers.get(Nonce) if NonceFormat.findFirstIn(nonce).isDefined + bodyhash <- Some(headers.get(BodyHash)) + ext <- Some(headers.get(Ext)) + mac <- headers.get(MacKey) if !mac.isEmpty } yield { - Some(id.get, nonce.get, bodyhash.get, ext.get, mac.get) - } - expect(map) orFail { f => - None + (id, nonce, bodyhash, ext, mac) } case _ => None } From bde94366106a4fc51ffe75bf87183d08d12b7f04 Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Sun, 19 Oct 2014 22:55:20 -0400 Subject: [PATCH 3/9] Adapt `oauth` module from QParams to Directives. --- oauth/src/main/scala/oauth.scala | 140 ++++++++++++++----------------- project/build.scala | 2 +- 2 files changed, 65 insertions(+), 77 deletions(-) diff --git a/oauth/src/main/scala/oauth.scala b/oauth/src/main/scala/oauth.scala index 1a18f0fbb..6be8a690c 100644 --- a/oauth/src/main/scala/oauth.scala +++ b/oauth/src/main/scala/oauth.scala @@ -3,6 +3,7 @@ package unfiltered.oauth import unfiltered.request._ import unfiltered.response._ import unfiltered.request.{HttpRequest => Req} +import unfiltered.directives._, Directives._ object OAuth { val ConsumerKey = "oauth_consumer_key" @@ -54,6 +55,17 @@ trait DefaultOAuthPaths extends OAuthPaths { trait Messages { def blankMsg(param: String): String def requiredMsg(param: String): String + + def required[T] = data.Requiring[T].fail(name => + BadParam(requiredMsg(name)) + ) + val nonEmptyString = data.as.String.nonEmpty.fail { (k, v) => + BadParam(blankMsg(k)) + } + + case class BadParam(msg: String) extends ResponseJoiner(msg)( messages => + BadRequest ~> ResponseString(messages.mkString(". ")) + ) } trait DefaultMessages extends Messages { @@ -64,32 +76,30 @@ trait DefaultMessages extends Messages { trait Protected extends OAuthProvider with unfiltered.filter.Plan { self: OAuthStores with Messages => import unfiltered.filter.request.ContextPath - import QParams._ import OAuth._ - def intent = { + def intent = Directive.Intent { case Params(params) & request => val headers = Authorization(request) match { case Some(a) => OAuth.Header(a.split(",")) case _ => Map.empty[String, Seq[String]] } - val expected = for { - oauth_consumer_key <- lookup(ConsumerKey) is - nonempty(blankMsg(ConsumerKey)) is required(requiredMsg(ConsumerKey)) - oauth_signature_method <- lookup(SignatureMethod) is - nonempty(blankMsg(SignatureMethod)) is required(requiredMsg(SignatureMethod)) - timestamp <- lookup(Timestamp) is - nonempty(blankMsg(Timestamp)) is required(requiredMsg(Timestamp)) - nonce <- lookup(Nonce) is - nonempty(blankMsg(Nonce)) is required(requiredMsg(Nonce)) - token <- lookup(TokenKey) is - nonempty(blankMsg(TokenKey)) is required(requiredMsg(TokenKey)) - signature <- lookup(Sig) is - nonempty(blankMsg(Sig)) is required(requiredMsg(Sig)) - version <- lookup(Version) is - pred ( _ == "1.0", "invalid oauth version " + _ ) is - optional[String,String] - realm <- lookup("realm") is optional[String, String] + val inputs = params ++ headers + def input(name: String) = inputs.get(name).flatMap(_.headOption) + + def lookupRequired(name: String) = + nonEmptyString ~> required named (name, input(name)) + + for { + oauth_consumer_key <- lookupRequired(ConsumerKey) + oauth_signature_method <- lookupRequired(SignatureMethod) + timestamp <- lookupRequired(Timestamp) + nonce <- lookupRequired(Nonce) + token <- lookupRequired(TokenKey) + signature <- lookupRequired(Sig) + version <- (data.as.String named(Version, input(Version).toSeq) orElse + failure(BadParam("invalid oauth version"))) if version.forall(_ == "1.0") + realm <- nonEmptyString named "realm" } yield { protect(request.method, request.underlying.getRequestURL.toString, params ++ headers) match { case Failure(_, _) => @@ -102,42 +112,35 @@ trait Protected extends OAuthProvider with unfiltered.filter.Plan { Pass } } - - expected(params ++ headers) orFail { errors => - BadRequest ~> ResponseString(errors.map { _.error } mkString(". ")) - } - } } trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { self: OAuthStores with Messages with OAuthPaths => import unfiltered.filter.request.ContextPath - import QParams._ import OAuth._ - def intent = { + def intent = Directive.Intent { case POST(ContextPath(_, RequestTokenPath) & Params(params)) & request => val headers = Authorization(request) match { case Some(a) => OAuth.Header(a.split(",")) case _ => Map.empty[String, Seq[String]] } - val expected = for { - consumer_key <- lookup(ConsumerKey) is - nonempty(blankMsg(ConsumerKey)) is required(requiredMsg(ConsumerKey)) - oauth_signature_method <- lookup(SignatureMethod) is - nonempty(blankMsg(SignatureMethod)) is required(requiredMsg(SignatureMethod)) - timestamp <- lookup(Timestamp) is - nonempty(blankMsg(Timestamp)) is required(requiredMsg(Timestamp)) - nonce <- lookup(Nonce) is - nonempty(blankMsg(Nonce)) is required(requiredMsg(Nonce)) - callback <- lookup(Callback) is - nonempty(blankMsg(Callback)) is required(requiredMsg(Callback)) - signature <- lookup(Sig) is - nonempty(blankMsg(Sig)) is required(requiredMsg(Sig)) - version <- lookup(Version) is - pred ( _ == "1.0", "invalid oauth version " + _ ) is - optional[String,String] + val inputs = params ++ headers + def input(name: String) = inputs.get(name).flatMap(_.headOption) + + def lookupRequired(name: String) = + nonEmptyString ~> required named (name, input(name)) + + for { + consumer_key <- lookupRequired(ConsumerKey) + oauth_signature_method <- lookupRequired(SignatureMethod) + timestamp <- lookupRequired(Timestamp) + nonce <- lookupRequired(Nonce) + callback <- lookupRequired(Callback) + signature <- lookupRequired(Sig) + version <- (data.as.String named(Version, input(Version).toSeq) orElse + failure(BadParam("invalid oauth version"))) if version.forall(_ == "1.0") } yield { // TODO how to extract the full url and not rely on underlying requestToken(request.method, request.underlying.getRequestURL.toString, params ++ headers) match { @@ -146,16 +149,11 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { } } - expected(params ++ headers) orFail { errors => - BadRequest ~> ResponseString(errors.map { _.error } mkString(". ")) - } - case ContextPath(_, AuthorizationPath) & Params(params) & request => - val expected = for { - token <- lookup(TokenKey) is - nonempty(blankMsg(TokenKey)) is required(requiredMsg(TokenKey)) + for { + token <- nonEmptyString ~> required named TokenKey } yield { - authorize(token.get, request) match { + authorize(token, request) match { case Failure(code, msg) => fail(code, msg) case HostResponse(resp) => Ok ~> (resp.asInstanceOf[ResponseFunction[Any]]) case AuthorizeResponse(callback, token, verifier) => callback match { @@ -166,33 +164,27 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { } } - expected(params) orFail { errors => - BadRequest ~> ResponseString(errors.map { _.error } mkString(". ")) - } - case request @ POST(ContextPath(_, AccessTokenPath) & Params(params)) => val headers = Authorization(request) match { case Some(a) => OAuth.Header(a.split(",")) case _ => Map.empty[String, Seq[String]] } - val expected = for { - consumer_key <- lookup(ConsumerKey) is - nonempty(blankMsg(ConsumerKey)) is required(requiredMsg(ConsumerKey)) - oauth_signature_method <- lookup(SignatureMethod) is - nonempty(blankMsg(SignatureMethod)) is required(requiredMsg(SignatureMethod)) - timestamp <- lookup(Timestamp) is - nonempty(blankMsg(Timestamp)) is required(requiredMsg(Timestamp)) - nonce <- lookup(Nonce) is - nonempty(blankMsg(Nonce)) is required(requiredMsg(Nonce)) - token <- lookup(TokenKey) is - nonempty(blankMsg(TokenKey)) is required(requiredMsg(TokenKey)) - verifier <- lookup(Verifier) is - nonempty(blankMsg(Verifier)) is required(requiredMsg(Verifier)) - signature <- lookup(Sig) is - nonempty(blankMsg(Sig)) is required(requiredMsg(Sig)) - version <- lookup(Version) is - pred ( _ == "1.0", "invalid oauth version " + _ ) is - optional[String,String] + val inputs = params ++ headers + def input(name: String) = inputs.get(name).flatMap(_.headOption) + + def lookupRequired(name: String) = + nonEmptyString ~> required named (name, input(name)) + + for { + consumer_key <- lookupRequired(ConsumerKey) + oauth_signature_method <- lookupRequired(SignatureMethod) + timestamp <- lookupRequired(Timestamp) + nonce <- lookupRequired(Nonce) + token <- lookupRequired(TokenKey) + verifier <- lookupRequired(Verifier) + signature <- lookupRequired(Sig) + version <- (data.as.String named(Version, input(Version).toSeq) orElse + failure(BadParam("invalid oauth version"))) if version.forall(_ == "1.0") } yield { accessToken(request.method, request.underlying.getRequestURL.toString, params ++ headers) match { case Failure(code, msg) => fail(code, msg) @@ -200,10 +192,6 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { resp ~> FormEncodedContent } } - - expected(params ++ headers) orFail { fails => - BadRequest ~> ResponseString(fails.map { _.error } mkString(". ")) - } } def fail(status: Int, msg: String) = diff --git a/project/build.scala b/project/build.scala index acd457d52..40e0c7db8 100644 --- a/project/build.scala +++ b/project/build.scala @@ -104,7 +104,7 @@ object Unfiltered extends Build { lazy val websockets = module("netty-websockets")().dependsOn(nettyServer) - lazy val oauth = module("oauth")().dependsOn(jetty, filters) + lazy val oauth = module("oauth")().dependsOn(jetty, filters, directives) lazy val mac = module("mac")().dependsOn(library) From 3a5cf59a950d896b602ccc18a50a229083a5773d Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Sun, 19 Oct 2014 23:48:27 -0400 Subject: [PATCH 4/9] Factor out the input (header + params) processing. --- oauth/src/main/scala/oauth.scala | 102 ++++++++++++++----------------- 1 file changed, 46 insertions(+), 56 deletions(-) diff --git a/oauth/src/main/scala/oauth.scala b/oauth/src/main/scala/oauth.scala index 6be8a690c..495f23dfa 100644 --- a/oauth/src/main/scala/oauth.scala +++ b/oauth/src/main/scala/oauth.scala @@ -66,6 +66,23 @@ trait Messages { case class BadParam(msg: String) extends ResponseJoiner(msg)( messages => BadRequest ~> ResponseString(messages.mkString(". ")) ) + + /** Combined header and parameter input */ + case class Inputs(request: HttpRequest[Any]) { + val headers = Authorization(request) match { + case Some(a) => OAuth.Header(a.split(",")) + case _ => Map.empty[String, Seq[String]] + } + val Params(params) = request + val inputs = headers ++ params + def named(name: String) = inputs.get(name).flatMap(_.headOption) + def requiredNamed(name: String) = + (nonEmptyString ~> required).named(name, named(name)) + + val version = + (data.as.String.named(OAuth.Version, named(OAuth.Version).toSeq) orElse + failure(BadParam("invalid oauth version"))) filter (_.forall(_ == "1.0")) + } } trait DefaultMessages extends Messages { @@ -79,29 +96,20 @@ trait Protected extends OAuthProvider with unfiltered.filter.Plan { import OAuth._ def intent = Directive.Intent { - case Params(params) & request => - val headers = Authorization(request) match { - case Some(a) => OAuth.Header(a.split(",")) - case _ => Map.empty[String, Seq[String]] - } - val inputs = params ++ headers - def input(name: String) = inputs.get(name).flatMap(_.headOption) - - def lookupRequired(name: String) = - nonEmptyString ~> required named (name, input(name)) + case request => + val in = Inputs(request) for { - oauth_consumer_key <- lookupRequired(ConsumerKey) - oauth_signature_method <- lookupRequired(SignatureMethod) - timestamp <- lookupRequired(Timestamp) - nonce <- lookupRequired(Nonce) - token <- lookupRequired(TokenKey) - signature <- lookupRequired(Sig) - version <- (data.as.String named(Version, input(Version).toSeq) orElse - failure(BadParam("invalid oauth version"))) if version.forall(_ == "1.0") + oauth_consumer_key <- in.requiredNamed(ConsumerKey) + oauth_signature_method <- in.requiredNamed(SignatureMethod) + timestamp <- in.requiredNamed(Timestamp) + nonce <- in.requiredNamed(Nonce) + token <- in.requiredNamed(TokenKey) + signature <- in.requiredNamed(Sig) + version <- in.version realm <- nonEmptyString named "realm" } yield { - protect(request.method, request.underlying.getRequestURL.toString, params ++ headers) match { + protect(request.method, request.underlying.getRequestURL.toString, in.inputs) match { case Failure(_, _) => Unauthorized ~> WWWAuthenticate("OAuth realm=\"%s\"" format(realm match { case Some(value) => value @@ -122,28 +130,19 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { def intent = Directive.Intent { case POST(ContextPath(_, RequestTokenPath) & Params(params)) & request => - val headers = Authorization(request) match { - case Some(a) => OAuth.Header(a.split(",")) - case _ => Map.empty[String, Seq[String]] - } - val inputs = params ++ headers - def input(name: String) = inputs.get(name).flatMap(_.headOption) - - def lookupRequired(name: String) = - nonEmptyString ~> required named (name, input(name)) + val in = Inputs(request) for { - consumer_key <- lookupRequired(ConsumerKey) - oauth_signature_method <- lookupRequired(SignatureMethod) - timestamp <- lookupRequired(Timestamp) - nonce <- lookupRequired(Nonce) - callback <- lookupRequired(Callback) - signature <- lookupRequired(Sig) - version <- (data.as.String named(Version, input(Version).toSeq) orElse - failure(BadParam("invalid oauth version"))) if version.forall(_ == "1.0") + consumer_key <- in.requiredNamed(ConsumerKey) + oauth_signature_method <- in.requiredNamed(SignatureMethod) + timestamp <- in.requiredNamed(Timestamp) + nonce <- in.requiredNamed(Nonce) + callback <- in.requiredNamed(Callback) + signature <- in.requiredNamed(Sig) + version <- in.version } yield { // TODO how to extract the full url and not rely on underlying - requestToken(request.method, request.underlying.getRequestURL.toString, params ++ headers) match { + requestToken(request.method, request.underlying.getRequestURL.toString, in.inputs) match { case Failure(status, msg) => fail(status, msg) case resp: OAuthResponseWriter => resp ~> FormEncodedContent } @@ -165,28 +164,19 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { } case request @ POST(ContextPath(_, AccessTokenPath) & Params(params)) => - val headers = Authorization(request) match { - case Some(a) => OAuth.Header(a.split(",")) - case _ => Map.empty[String, Seq[String]] - } - val inputs = params ++ headers - def input(name: String) = inputs.get(name).flatMap(_.headOption) - - def lookupRequired(name: String) = - nonEmptyString ~> required named (name, input(name)) + val in = Inputs(request) for { - consumer_key <- lookupRequired(ConsumerKey) - oauth_signature_method <- lookupRequired(SignatureMethod) - timestamp <- lookupRequired(Timestamp) - nonce <- lookupRequired(Nonce) - token <- lookupRequired(TokenKey) - verifier <- lookupRequired(Verifier) - signature <- lookupRequired(Sig) - version <- (data.as.String named(Version, input(Version).toSeq) orElse - failure(BadParam("invalid oauth version"))) if version.forall(_ == "1.0") + consumer_key <- in.requiredNamed(ConsumerKey) + oauth_signature_method <- in.requiredNamed(SignatureMethod) + timestamp <- in.requiredNamed(Timestamp) + nonce <- in.requiredNamed(Nonce) + token <- in.requiredNamed(TokenKey) + verifier <- in.requiredNamed(Verifier) + signature <- in.requiredNamed(Sig) + version <- in.version } yield { - accessToken(request.method, request.underlying.getRequestURL.toString, params ++ headers) match { + accessToken(request.method, request.underlying.getRequestURL.toString, in.inputs) match { case Failure(code, msg) => fail(code, msg) case resp@AccessResponse(_, _) => resp ~> FormEncodedContent From 4447b86944bb836085e79eed59504570d8cb6e79 Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Mon, 20 Oct 2014 00:05:30 -0400 Subject: [PATCH 5/9] Deprecate `ParamsSpec` so we won't be warned about its QParams references. Later we'll remove the references and leave the few plain Params tests that it has. --- library/src/test/scala/ParamsSpec.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/library/src/test/scala/ParamsSpec.scala b/library/src/test/scala/ParamsSpec.scala index 1acc5eb07..36c264f24 100644 --- a/library/src/test/scala/ParamsSpec.scala +++ b/library/src/test/scala/ParamsSpec.scala @@ -3,16 +3,19 @@ package unfiltered.request import org.specs2.mutable._ +@deprecated("Deprecated until we remove its references to QParams", since="0.8.3") object ParamsSpecJetty extends Specification with unfiltered.specs2.jetty.Planned with ParamsSpec +@deprecated("Deprecated until we remove its references to QParams", since="0.8.3") object ParamsSpecNetty extends Specification with unfiltered.specs2.netty.Planned with ParamsSpec +@deprecated("Deprecated until we remove its references to QParams", since="0.8.3") trait ParamsSpec extends Specification with unfiltered.specs2.Hosted { import unfiltered.response._ import unfiltered.request.{Path => UFPath} From 56fcfc598660538669c52418efeda673cefdd6eb Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Mon, 20 Oct 2014 22:22:59 -0400 Subject: [PATCH 6/9] Preserve FilterDirective type after filtering. --- directives/src/main/scala/directives/Directive.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/directives/src/main/scala/directives/Directive.scala b/directives/src/main/scala/directives/Directive.scala index 4b7a74db3..2be4d873f 100644 --- a/directives/src/main/scala/directives/Directive.scala +++ b/directives/src/main/scala/directives/Directive.scala @@ -104,8 +104,8 @@ class FilterDirective[-T, +R, +A]( run: HttpRequest[T] => Result[R, A], onEmpty: HttpRequest[T] => Result[R, A] ) extends Directive[T,R,A](run) { - def filter(filt: A => Boolean): Directive[T, R, A] = withFilter(filt) - def withFilter(filt: A => Boolean): Directive[T, R, A] = + def filter(filt: A => Boolean): FilterDirective[T, R, A] = withFilter(filt) + def withFilter(filt: A => Boolean): FilterDirective[T, R, A] = new FilterDirective({ req => run(req).flatMap { a => if (filt(a)) Result.Success(a) From 2ab33d2b0d6298934a08aabfe22ea1d17d55afae Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Mon, 20 Oct 2014 22:23:37 -0400 Subject: [PATCH 7/9] Preserve type of response (e.g., JoiningResponseFunction). --- directives/src/main/scala/directives/Directives.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/directives/src/main/scala/directives/Directives.scala b/directives/src/main/scala/directives/Directives.scala index 1112b73ff..fc2382ad9 100644 --- a/directives/src/main/scala/directives/Directives.scala +++ b/directives/src/main/scala/directives/Directives.scala @@ -14,9 +14,9 @@ trait Directives { * is required to satisfy an intent's interface but all requests are acceptable */ def success[A](value:A) = result[Nothing, A](Success(value)) - def failure[R](r:ResponseFunction[R]) = result[ResponseFunction[R], Nothing](Failure(r)) + def failure[R](r:R) = result[R, Nothing](Failure(r)) - def error[R](r:ResponseFunction[R]) = result[ResponseFunction[R], Nothing](Error(r)) + def error[R](r:R) = result[R, Nothing](Error(r)) object commit extends Directive[Any, Nothing, Unit](_ => Success(())){ override def flatMap[T, R, A](f:Unit => Directive[T, R, A]):Directive[T, R, A] = From 43614f023dfcd8c0ef8de838eaa0d9c85140add7 Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Mon, 20 Oct 2014 22:24:57 -0400 Subject: [PATCH 8/9] Use joined directives to report multiple errors. --- oauth/src/main/scala/oauth.scala | 72 ++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/oauth/src/main/scala/oauth.scala b/oauth/src/main/scala/oauth.scala index 495f23dfa..17e9b4c8c 100644 --- a/oauth/src/main/scala/oauth.scala +++ b/oauth/src/main/scala/oauth.scala @@ -100,14 +100,23 @@ trait Protected extends OAuthProvider with unfiltered.filter.Plan { val in = Inputs(request) for { - oauth_consumer_key <- in.requiredNamed(ConsumerKey) - oauth_signature_method <- in.requiredNamed(SignatureMethod) - timestamp <- in.requiredNamed(Timestamp) - nonce <- in.requiredNamed(Nonce) - token <- in.requiredNamed(TokenKey) - signature <- in.requiredNamed(Sig) - version <- in.version - realm <- nonEmptyString named "realm" + ( oauth_consumer_key & + oauth_signature_method & + timestamp & + nonce & + token & + signature & + version & + realm + ) <- + in.requiredNamed(ConsumerKey) & + in.requiredNamed(SignatureMethod) & + in.requiredNamed(Timestamp) & + in.requiredNamed(Nonce) & + in.requiredNamed(TokenKey) & + in.requiredNamed(Sig) & + in.version & + (nonEmptyString named "realm") } yield { protect(request.method, request.underlying.getRequestURL.toString, in.inputs) match { case Failure(_, _) => @@ -133,13 +142,21 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { val in = Inputs(request) for { - consumer_key <- in.requiredNamed(ConsumerKey) - oauth_signature_method <- in.requiredNamed(SignatureMethod) - timestamp <- in.requiredNamed(Timestamp) - nonce <- in.requiredNamed(Nonce) - callback <- in.requiredNamed(Callback) - signature <- in.requiredNamed(Sig) - version <- in.version + ( consumer_key & + oauth_signature_method & + timestamp & + nonce & + callback & + signature & + version + ) <- + in.requiredNamed(ConsumerKey) & + in.requiredNamed(SignatureMethod) & + in.requiredNamed(Timestamp) & + in.requiredNamed(Nonce) & + in.requiredNamed(Callback) & + in.requiredNamed(Sig) & + in.version } yield { // TODO how to extract the full url and not rely on underlying requestToken(request.method, request.underlying.getRequestURL.toString, in.inputs) match { @@ -167,14 +184,23 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan { val in = Inputs(request) for { - consumer_key <- in.requiredNamed(ConsumerKey) - oauth_signature_method <- in.requiredNamed(SignatureMethod) - timestamp <- in.requiredNamed(Timestamp) - nonce <- in.requiredNamed(Nonce) - token <- in.requiredNamed(TokenKey) - verifier <- in.requiredNamed(Verifier) - signature <- in.requiredNamed(Sig) - version <- in.version + ( consumer_key & + oauth_signature_method & + timestamp & + nonce & + token & + verifier & + signature & + version + ) <- + in.requiredNamed(ConsumerKey) & + in.requiredNamed(SignatureMethod) & + in.requiredNamed(Timestamp) & + in.requiredNamed(Nonce) & + in.requiredNamed(TokenKey) & + in.requiredNamed(Verifier) & + in.requiredNamed(Sig) & + in.version } yield { accessToken(request.method, request.underlying.getRequestURL.toString, in.inputs) match { case Failure(code, msg) => fail(code, msg) From b4b9e8d83b33500ad58075e05364764b4ce22e75 Mon Sep 17 00:00:00 2001 From: Nathan Hamblen Date: Sun, 26 Oct 2014 21:43:53 -0400 Subject: [PATCH 9/9] Replace use of QParams with parameter directives. This is a blunt translation of QParams code to directives, it could probably benefit from a comprehensive rewrite at some point. My aim was to preserve the exiting behavior. The tests pass. --- oauth2/src/main/scala/authorizations.scala | 209 ++++++++++++--------- project/build.scala | 2 +- 2 files changed, 116 insertions(+), 95 deletions(-) diff --git a/oauth2/src/main/scala/authorizations.scala b/oauth2/src/main/scala/authorizations.scala index 717db742d..91110b2e1 100644 --- a/oauth2/src/main/scala/authorizations.scala +++ b/oauth2/src/main/scala/authorizations.scala @@ -4,6 +4,7 @@ import unfiltered.request._ import unfiltered.response._ import unfiltered.request.{ HttpRequest => Req } import unfiltered.filter.request.ContextPath // work on removing this dep +import unfiltered.directives._, Directives._ import scala.language.reflectiveCalls import scala.language.implicitConversions @@ -97,7 +98,6 @@ trait Authorized extends AuthorizationProvider with AuthorizationEndpoints with Formatting with ValidationMessages with Flows with unfiltered.filter.Plan { - import QParams._ import OAuthorization._ /** Syntactic sugar for appending query strings to paths */ @@ -139,6 +139,10 @@ trait Authorized extends AuthorizationProvider protected def spaceEncoder(scopes: Seq[String]) = scopes.mkString("+") + val spaceDecoded = data.Interpreter[Option[String],Option[Seq[String]]]( + _.map(spaceDecoder) + ) + def onAuthCode( req: HttpRequest[Any], responseType: Seq[String], clientId: String, redirectUri: String, scope: Seq[String], state: Option[String]) = @@ -257,32 +261,16 @@ trait Authorized extends AuthorizationProvider errorResponder(error, desc, euri, state) } - def intent = { - case req @ ContextPath(_, AuthorizePath) & Params(params) => - val expected = for { - responseType <- lookup(ResponseType) is required(requiredMsg(ResponseType)) is - watch(_.map(spaceDecoder), e => "") - clientId <- lookup(ClientId) is required(requiredMsg(ClientId)) - redirectURI <- lookup(RedirectURI) is required(requiredMsg(RedirectURI)) - scope <- lookup(Scope) is watch(_.map(spaceDecoder), e => "") - state <- lookup(State) is optional[String, String] - } yield { - (redirectURI.get, responseType.get) match { - case (ruri, rtx) if(rtx contains(Code)) => - onAuthCode(req, rtx, clientId.get, ruri, scope.getOrElse(Nil), state.get) - case (ruri, rtx) if(rtx contains(TokenKey)) => - onToken(req, rtx, clientId.get, ruri, scope.getOrElse(Nil), state.get) - case (ruri, unsupported) => - onUnsupportedAuth(req, unsupported, clientId.get, ruri, scope.getOrElse(Nil), state.get) - } - } - - expected(params) orFail { errs => + def intent = Directive.Intent { + case req @ ContextPath(_, AuthorizePath) => + case class BadParam(req: HttpRequest[Any], msg: String) + extends ResponseJoiner(msg)({ errs => + val Params(params) = req params(RedirectURI) match { case Seq(uri) => val qs = qstr(Map( Error -> InvalidRequest, - ErrorDescription -> errs.map { _.error }.mkString(", ") + ErrorDescription -> errs.mkString(", ") )) params(ResponseType) match { case Seq(TokenKey) => @@ -293,75 +281,38 @@ trait Authorized extends AuthorizationProvider } case _ => auth.mismatchedRedirectUri(req) } - } - - case req @ POST(ContextPath(_, TokenPath)) & Params(params) => - val expected = for { - grantType <- lookup(GrantType) is required(requiredMsg(GrantType)) - code <- lookup(Code) is optional[String, String] - clientId <- lookup(ClientId) is required(requiredMsg(ClientId)) - redirectURI <- lookup(RedirectURI) is optional[String, String] - // clientSecret is not recommended to be passed as a parameter by instead - // encoded in a basic auth header http://tools.ietf.org/html/draft-ietf-oauth-v2-16#section-3.1 - clientSecret <- lookup(ClientSecret) is required(requiredMsg(ClientSecret)) - refreshToken <- lookup(RefreshToken) is optional[String, String] - scope <- lookup(Scope) is watch(_.map(spaceDecoder), e => "") - userName <- lookup(Username) is optional[String, String] - password <- lookup(Password) is optional[String, String] + }) + + def required[T](req: HttpRequest[Any]) = data.Requiring[T].fail(name => + BadParam(req, requiredMsg(name)) + ) + + for { + responseType <- spaceDecoded ~> required(req) named ResponseType + clientId <- required(req) named ClientId + redirectURI <- required(req) named RedirectURI + scope <- spaceDecoded named Scope + state <- data.as.Option[String] named State } yield { - - grantType.get match { - - case ClientCredentials => - onClientCredentials(clientId.get, clientSecret.get, scope.getOrElse(Nil)) - - case Password => - (userName.get, password.get) match { - case (Some(u), Some(pw)) => - onPassword(u, pw, clientId.get, clientSecret.get, scope.getOrElse(Nil)) - case _ => - errorResponder( - InvalidRequest, - (requiredMsg(Username) :: requiredMsg(Password) :: Nil).mkString(" and "), - auth.errUri(InvalidRequest), None - ) - } - - case RefreshToken => - refreshToken.get match { - case Some(rtoken) => - onRefresh(rtoken, clientId.get, clientSecret.get, scope.getOrElse(Nil)) - case _ => errorResponder(InvalidRequest, requiredMsg(RefreshToken), None, None) - } - - case AuthorizationCode => - (code.get, redirectURI.get) match { - case (Some(c), Some(r)) => - onGrantAuthCode(c, r, clientId.get, clientSecret.get) - case _ => - errorResponder( - InvalidRequest, - (requiredMsg(Code) :: requiredMsg(RedirectURI) :: Nil).mkString(" and "), - auth.errUri(InvalidRequest), None - ) - } - case unsupported => - // note the oauth2 spec does allow for extension grant types, - // this implementation currently does not - errorResponder( - UnsupportedGrantType, "%s is unsupported" format unsupported, - auth.errUri(UnsupportedGrantType), None) + (redirectURI, responseType) match { + case (ruri, rtx) if(rtx contains(Code)) => + onAuthCode(req, rtx, clientId, ruri, scope.getOrElse(Nil), state) + case (ruri, rtx) if(rtx contains(TokenKey)) => + onToken(req, rtx, clientId, ruri, scope.getOrElse(Nil), state) + case (ruri, unsupported) => + onUnsupportedAuth(req, unsupported, clientId, ruri, scope.getOrElse(Nil), state) } } - // here, we are combining requests parameters with basic authentication headers - // the preferred way of providing client credentials is through - // basic auth but this is not required. The following folds basic auth data - // into the params ensuring there is no conflict in transports - val combinedParams = ( - (Right(params): Either[String, Map[String, Seq[String]]]) /: BasicAuth(req) - )((a,e) => e match { - case (clientId, clientSecret) => + case req @ POST(ContextPath(_, TokenPath)) & Params(params) => + // here, we are combining requests parameters with basic authentication headers + // the preferred way of providing client credentials is through + // basic auth but this is not required. The following folds basic auth data + // into the params ensuring there is no conflict in transports + val combinedParams = ( + (Right(params): Either[String, Map[String, Seq[String]]]) /: BasicAuth(req) + )((a,e) => e match { + case (clientId, clientSecret) => val preferred = Right( a.right.get ++ Map(ClientId -> Seq(clientId), ClientSecret-> Seq(clientSecret)) ) @@ -370,15 +321,85 @@ trait Authorized extends AuthorizationProvider if(id == clientId) preferred else Left("client ids did not match") case _ => preferred } - case _ => a - }) + case _ => a + }) combinedParams fold({ err => - errorResponder(InvalidRequest, err, None, None) + failure(errorResponder(InvalidRequest, err, None, None)) }, { mixed => - expected(mixed) orFail { errs => - errorResponder(InvalidRequest, errs.map { _.error }.mkString(", "), None, None) - } + case class BadParam(req: HttpRequest[Any], msg: String) + extends ResponseJoiner(msg)({ errs => + errorResponder(InvalidRequest, errs.mkString(", "), None, None) + }) + + def required[T](req: HttpRequest[Any]) = data.Requiring[T].fail(name => + BadParam(req, requiredMsg(name)) + ) + def named(name: String) = mixed.get(name).flatMap(_.headOption) + + def requiredNamed(name: String) = + data.as.String ~> required(req) named (name, named(name).toSeq) + + def optionNamed(name: String) = + data.as.String.nonEmpty.named(name, named(name)) + + for { + grantType <- requiredNamed(GrantType) + code <- optionNamed(Code) + clientId <- requiredNamed(ClientId) + redirectURI <- optionNamed(RedirectURI) + // clientSecret is not recommended to be passed as a parameter by instead + // encoded in a basic auth header http://tools.ietf.org/html/draft-ietf-oauth-v2-16#section-3.1 + clientSecret <- requiredNamed(ClientSecret) + refreshToken <- optionNamed(RefreshToken) + scope <- spaceDecoded.named(Scope, named(Scope)) + userName <- optionNamed(Username) + password <- optionNamed(Password) + } yield { + + grantType match { + + case ClientCredentials => + onClientCredentials(clientId, clientSecret, scope.getOrElse(Nil)) + + case Password => + (userName, password) match { + case (Some(u), Some(pw)) => + onPassword(u, pw, clientId, clientSecret, scope.getOrElse(Nil)) + case _ => + errorResponder( + InvalidRequest, + (requiredMsg(Username) :: requiredMsg(Password) :: Nil).mkString(" and "), + auth.errUri(InvalidRequest), None + ) + } + + case RefreshToken => + refreshToken match { + case Some(rtoken) => + onRefresh(rtoken, clientId, clientSecret, scope.getOrElse(Nil)) + case _ => errorResponder(InvalidRequest, requiredMsg(RefreshToken), None, None) + } + + case AuthorizationCode => + (code, redirectURI) match { + case (Some(c), Some(r)) => + onGrantAuthCode(c, r, clientId, clientSecret) + case _ => + errorResponder( + InvalidRequest, + (requiredMsg(Code) :: requiredMsg(RedirectURI) :: Nil).mkString(" and "), + auth.errUri(InvalidRequest), None + ) + } + case unsupported => + // note the oauth2 spec does allow for extension grant types, + // this implementation currently does not + errorResponder( + UnsupportedGrantType, "%s is unsupported" format unsupported, + auth.errUri(UnsupportedGrantType), None) + } + } }) } } diff --git a/project/build.scala b/project/build.scala index 2a0901a28..cc27ac51c 100644 --- a/project/build.scala +++ b/project/build.scala @@ -110,7 +110,7 @@ object Unfiltered extends Build { lazy val mac = module("mac")().dependsOn(library) - lazy val oauth2 = module("oauth2")().dependsOn(jetty, filters, mac) + lazy val oauth2 = module("oauth2")().dependsOn(jetty, filters, mac, directives) lazy val nettyUploads = module("netty-uploads")().dependsOn(nettyServer, uploads) }