Skip to content

Commit

Permalink
Merge pull request #274 from unfiltered/deprecate_qparams
Browse files Browse the repository at this point in the history
Deprecate `QParams` validation so that we can later remove it.
  • Loading branch information
Nathan Hamblen committed Nov 12, 2014
2 parents e674351 + b4b9e8d commit 8ec6b89
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 206 deletions.
4 changes: 2 additions & 2 deletions directives/src/main/scala/directives/Directive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ class FilterDirective[-T, +R, +A](
run: HttpRequest[T] => Result[R, A],
onEmpty: HttpRequest[T] => Result[R, A]
) extends Directive[T,R,A](run) {
def filter(filt: A => Boolean): Directive[T, R, A] = withFilter(filt)
def withFilter(filt: A => Boolean): Directive[T, R, A] =
def filter(filt: A => Boolean): FilterDirective[T, R, A] = withFilter(filt)
def withFilter(filt: A => Boolean): FilterDirective[T, R, A] =
new FilterDirective({ req =>
run(req).flatMap { a =>
if (filt(a)) Result.Success(a)
Expand Down
4 changes: 2 additions & 2 deletions directives/src/main/scala/directives/Directives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ trait Directives {
* is required to satisfy an intent's interface but all requests are acceptable */
def success[A](value:A) = result[Nothing, A](Success(value))

def failure[R](r:ResponseFunction[R]) = result[ResponseFunction[R], Nothing](Failure(r))
def failure[R](r:R) = result[R, Nothing](Failure(r))

def error[R](r:ResponseFunction[R]) = result[ResponseFunction[R], Nothing](Error(r))
def error[R](r:R) = result[R, Nothing](Error(r))

object commit extends Directive[Any, Nothing, Unit](_ => Success(())){
override def flatMap[T, R, A](f:Unit => Directive[T, R, A]):Directive[T, R, A] =
Expand Down
1 change: 1 addition & 0 deletions library/src/main/scala/request/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ object QueryParams {

/** Fined-grained error reporting for arbitrarily many failing parameters.
* Import QParams._ to use; see ParamsSpec for examples. */
@deprecated("This validation scheme is deprecated, use Directives instead", since="0.8.3")
object QParams {
type Log[E] = List[Fail[E]]
type QueryFn[E,A] = (Params.Map, Option[String], Log[E]) =>
Expand Down
3 changes: 3 additions & 0 deletions library/src/test/scala/ParamsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ package unfiltered.request

import org.specs2.mutable._

@deprecated("Deprecated until we remove its references to QParams", since="0.8.3")
object ParamsSpecJetty
extends Specification
with unfiltered.specs2.jetty.Planned
with ParamsSpec

@deprecated("Deprecated until we remove its references to QParams", since="0.8.3")
object ParamsSpecNetty
extends Specification
with unfiltered.specs2.netty.Planned
with ParamsSpec

@deprecated("Deprecated until we remove its references to QParams", since="0.8.3")
trait ParamsSpec extends Specification with unfiltered.specs2.Hosted {
import unfiltered.response._
import unfiltered.request.{Path => UFPath}
Expand Down
23 changes: 9 additions & 14 deletions mac/src/main/scala/mac.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,26 @@ object MacAuthorization {
val MacKey = "mac"

object MacHeader {
import QParams._
val NonceFormat = """^(\d+)[:](\S+)$""".r
val KeyVal = """(\w+)="([\w|=|:|\/|.|%|-|+]+)" """.trim.r
val keys = Id :: Nonce :: BodyHash :: Ext :: MacKey :: Nil
val headerSpace = "MAC" + " "

def unapply(hvals: List[String]) = hvals match {
case x :: xs if x startsWith headerSpace =>
val map = Map(hvals map { _.replace(headerSpace, "") } flatMap {
case KeyVal(k, v) if(keys.contains(k)) => Seq((k -> Seq(v)))
val headers = Map(hvals map { _.replace(headerSpace, "") } flatMap {
case KeyVal(k, v) if(keys.contains(k)) => Seq((k -> v))
case e =>
Nil
}: _*)
val expect = for {
id <- lookup(Id) is nonempty("id is empty") is required("id is required")
nonce <- lookup(Nonce) is nonempty("nonce is empty") is required("nonce is required") is
pred({NonceFormat.findFirstIn(_).isDefined}, _ + " is an invalid format")
bodyhash <- lookup(BodyHash) is optional[String, String]
ext <- lookup(Ext) is optional[String, String]
mac <- lookup(MacKey) is nonempty("mac is nempty") is required("mac is required")
for {
id <- headers.get(Id) if !id.isEmpty
nonce <- headers.get(Nonce) if NonceFormat.findFirstIn(nonce).isDefined
bodyhash <- Some(headers.get(BodyHash))
ext <- Some(headers.get(Ext))
mac <- headers.get(MacKey) if !mac.isEmpty
} yield {
Some(id.get, nonce.get, bodyhash.get, ext.get, mac.get)
}
expect(map) orFail { f =>
None
(id, nonce, bodyhash, ext, mac)
}
case _ => None
}
Expand Down
188 changes: 96 additions & 92 deletions oauth/src/main/scala/oauth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package unfiltered.oauth
import unfiltered.request._
import unfiltered.response._
import unfiltered.request.{HttpRequest => Req}
import unfiltered.directives._, Directives._

object OAuth {
val ConsumerKey = "oauth_consumer_key"
Expand Down Expand Up @@ -54,6 +55,34 @@ trait DefaultOAuthPaths extends OAuthPaths {
trait Messages {
def blankMsg(param: String): String
def requiredMsg(param: String): String

def required[T] = data.Requiring[T].fail(name =>
BadParam(requiredMsg(name))
)
val nonEmptyString = data.as.String.nonEmpty.fail { (k, v) =>
BadParam(blankMsg(k))
}

case class BadParam(msg: String) extends ResponseJoiner(msg)( messages =>
BadRequest ~> ResponseString(messages.mkString(". "))
)

/** Combined header and parameter input */
case class Inputs(request: HttpRequest[Any]) {
val headers = Authorization(request) match {
case Some(a) => OAuth.Header(a.split(","))
case _ => Map.empty[String, Seq[String]]
}
val Params(params) = request
val inputs = headers ++ params
def named(name: String) = inputs.get(name).flatMap(_.headOption)
def requiredNamed(name: String) =
(nonEmptyString ~> required).named(name, named(name))

val version =
(data.as.String.named(OAuth.Version, named(OAuth.Version).toSeq) orElse
failure(BadParam("invalid oauth version"))) filter (_.forall(_ == "1.0"))
}
}

trait DefaultMessages extends Messages {
Expand All @@ -64,34 +93,32 @@ trait DefaultMessages extends Messages {
trait Protected extends OAuthProvider with unfiltered.filter.Plan {
self: OAuthStores with Messages =>
import unfiltered.filter.request.ContextPath
import QParams._
import OAuth._

def intent = {
case Params(params) & request =>
val headers = Authorization(request) match {
case Some(a) => OAuth.Header(a.split(","))
case _ => Map.empty[String, Seq[String]]
}
val expected = for {
oauth_consumer_key <- lookup(ConsumerKey) is
nonempty(blankMsg(ConsumerKey)) is required(requiredMsg(ConsumerKey))
oauth_signature_method <- lookup(SignatureMethod) is
nonempty(blankMsg(SignatureMethod)) is required(requiredMsg(SignatureMethod))
timestamp <- lookup(Timestamp) is
nonempty(blankMsg(Timestamp)) is required(requiredMsg(Timestamp))
nonce <- lookup(Nonce) is
nonempty(blankMsg(Nonce)) is required(requiredMsg(Nonce))
token <- lookup(TokenKey) is
nonempty(blankMsg(TokenKey)) is required(requiredMsg(TokenKey))
signature <- lookup(Sig) is
nonempty(blankMsg(Sig)) is required(requiredMsg(Sig))
version <- lookup(Version) is
pred ( _ == "1.0", "invalid oauth version " + _ ) is
optional[String,String]
realm <- lookup("realm") is optional[String, String]
def intent = Directive.Intent {
case request =>
val in = Inputs(request)

for {
( oauth_consumer_key &
oauth_signature_method &
timestamp &
nonce &
token &
signature &
version &
realm
) <-
in.requiredNamed(ConsumerKey) &
in.requiredNamed(SignatureMethod) &
in.requiredNamed(Timestamp) &
in.requiredNamed(Nonce) &
in.requiredNamed(TokenKey) &
in.requiredNamed(Sig) &
in.version &
(nonEmptyString named "realm")
} yield {
protect(request.method, request.underlying.getRequestURL.toString, params ++ headers) match {
protect(request.method, request.underlying.getRequestURL.toString, in.inputs) match {
case Failure(_, _) =>
Unauthorized ~> WWWAuthenticate("OAuth realm=\"%s\"" format(realm match {
case Some(value) => value
Expand All @@ -102,60 +129,47 @@ trait Protected extends OAuthProvider with unfiltered.filter.Plan {
Pass
}
}

expected(params ++ headers) orFail { errors =>
BadRequest ~> ResponseString(errors.map { _.error } mkString(". "))
}

}
}

trait OAuthed extends OAuthProvider with unfiltered.filter.Plan {
self: OAuthStores with Messages with OAuthPaths =>
import unfiltered.filter.request.ContextPath
import QParams._
import OAuth._

def intent = {
def intent = Directive.Intent {
case POST(ContextPath(_, RequestTokenPath) & Params(params)) & request =>
val headers = Authorization(request) match {
case Some(a) => OAuth.Header(a.split(","))
case _ => Map.empty[String, Seq[String]]
}
val expected = for {
consumer_key <- lookup(ConsumerKey) is
nonempty(blankMsg(ConsumerKey)) is required(requiredMsg(ConsumerKey))
oauth_signature_method <- lookup(SignatureMethod) is
nonempty(blankMsg(SignatureMethod)) is required(requiredMsg(SignatureMethod))
timestamp <- lookup(Timestamp) is
nonempty(blankMsg(Timestamp)) is required(requiredMsg(Timestamp))
nonce <- lookup(Nonce) is
nonempty(blankMsg(Nonce)) is required(requiredMsg(Nonce))
callback <- lookup(Callback) is
nonempty(blankMsg(Callback)) is required(requiredMsg(Callback))
signature <- lookup(Sig) is
nonempty(blankMsg(Sig)) is required(requiredMsg(Sig))
version <- lookup(Version) is
pred ( _ == "1.0", "invalid oauth version " + _ ) is
optional[String,String]
val in = Inputs(request)

for {
( consumer_key &
oauth_signature_method &
timestamp &
nonce &
callback &
signature &
version
) <-
in.requiredNamed(ConsumerKey) &
in.requiredNamed(SignatureMethod) &
in.requiredNamed(Timestamp) &
in.requiredNamed(Nonce) &
in.requiredNamed(Callback) &
in.requiredNamed(Sig) &
in.version
} yield {
// TODO how to extract the full url and not rely on underlying
requestToken(request.method, request.underlying.getRequestURL.toString, params ++ headers) match {
requestToken(request.method, request.underlying.getRequestURL.toString, in.inputs) match {
case Failure(status, msg) => fail(status, msg)
case resp: OAuthResponseWriter => resp ~> FormEncodedContent
}
}

expected(params ++ headers) orFail { errors =>
BadRequest ~> ResponseString(errors.map { _.error } mkString(". "))
}

case ContextPath(_, AuthorizationPath) & Params(params) & request =>
val expected = for {
token <- lookup(TokenKey) is
nonempty(blankMsg(TokenKey)) is required(requiredMsg(TokenKey))
for {
token <- nonEmptyString ~> required named TokenKey
} yield {
authorize(token.get, request) match {
authorize(token, request) match {
case Failure(code, msg) => fail(code, msg)
case HostResponse(resp) => Ok ~> (resp.asInstanceOf[ResponseFunction[Any]])
case AuthorizeResponse(callback, token, verifier) => callback match {
Expand All @@ -166,44 +180,34 @@ trait OAuthed extends OAuthProvider with unfiltered.filter.Plan {
}
}

expected(params) orFail { errors =>
BadRequest ~> ResponseString(errors.map { _.error } mkString(". "))
}

case request @ POST(ContextPath(_, AccessTokenPath) & Params(params)) =>
val headers = Authorization(request) match {
case Some(a) => OAuth.Header(a.split(","))
case _ => Map.empty[String, Seq[String]]
}
val expected = for {
consumer_key <- lookup(ConsumerKey) is
nonempty(blankMsg(ConsumerKey)) is required(requiredMsg(ConsumerKey))
oauth_signature_method <- lookup(SignatureMethod) is
nonempty(blankMsg(SignatureMethod)) is required(requiredMsg(SignatureMethod))
timestamp <- lookup(Timestamp) is
nonempty(blankMsg(Timestamp)) is required(requiredMsg(Timestamp))
nonce <- lookup(Nonce) is
nonempty(blankMsg(Nonce)) is required(requiredMsg(Nonce))
token <- lookup(TokenKey) is
nonempty(blankMsg(TokenKey)) is required(requiredMsg(TokenKey))
verifier <- lookup(Verifier) is
nonempty(blankMsg(Verifier)) is required(requiredMsg(Verifier))
signature <- lookup(Sig) is
nonempty(blankMsg(Sig)) is required(requiredMsg(Sig))
version <- lookup(Version) is
pred ( _ == "1.0", "invalid oauth version " + _ ) is
optional[String,String]
val in = Inputs(request)

for {
( consumer_key &
oauth_signature_method &
timestamp &
nonce &
token &
verifier &
signature &
version
) <-
in.requiredNamed(ConsumerKey) &
in.requiredNamed(SignatureMethod) &
in.requiredNamed(Timestamp) &
in.requiredNamed(Nonce) &
in.requiredNamed(TokenKey) &
in.requiredNamed(Verifier) &
in.requiredNamed(Sig) &
in.version
} yield {
accessToken(request.method, request.underlying.getRequestURL.toString, params ++ headers) match {
accessToken(request.method, request.underlying.getRequestURL.toString, in.inputs) match {
case Failure(code, msg) => fail(code, msg)
case resp@AccessResponse(_, _) =>
resp ~> FormEncodedContent
}
}

expected(params ++ headers) orFail { fails =>
BadRequest ~> ResponseString(fails.map { _.error } mkString(". "))
}
}

def fail(status: Int, msg: String) =
Expand Down

0 comments on commit 8ec6b89

Please sign in to comment.