Permalink
Browse files

Merge lucastorri's before/after filters with matchers. [closes GH-30]

  • Loading branch information...
2 parents 61e294d + 5a003c8 commit 2485aef20ea1fed2d39b3ec5388868cc09f9f8a2 @rossabaker rossabaker committed May 13, 2011
View
@@ -20,6 +20,7 @@
- [Paul Lambert](http://paulitex.com/)
- [Ted Nyman](http://github.com/tnm)
- [Erik Rozendaal](http://github.com/erikrozendaal)
+- [Lucas Torri](http://github.com/lucastorri)
- [Ivan Willig](http://github.com/iwillig)
- [Phil Wills](http://github.com/philwills)
@@ -31,7 +31,7 @@ trait CsrfTokenSupport { self: ScalatraKernel =>
protected def csrfKey = CsrfTokenSupport.DefaultKey
protected def csrfToken = session(csrfKey).asInstanceOf[String]
- before {
+ beforeAll {
if (isForged) {
handleForgery()
}
@@ -54,6 +54,8 @@ trait ScalatraKernel extends Handler with Initializable
HttpMethod.methods foreach { x: HttpMethod => map += ((x, List[Route]())) }
map
}
+ protected val beforeFilters = ListBuffer[Route]()
+ protected val afterFilters = ListBuffer[Route]()
def contentType = response.getContentType
def contentType_=(value: String): Unit = response.setContentType(value)
@@ -130,14 +132,14 @@ trait ScalatraKernel extends Handler with Initializable
_response.withValue(response) {
_multiParams.withValue(Map() ++ realMultiParams) {
val result = try {
- beforeFilters foreach { _() }
+ beforeFilters.toStream.foreach { _(requestPath) }
routes(effectiveMethod).toStream.flatMap { _(requestPath) }.headOption.getOrElse(doNotFound())
}
catch {
case e => handleError(e)
}
finally {
- afterFilters foreach { _() }
+ afterFilters.toStream.foreach { _(requestPath) }
}
renderResponse(result)
}
@@ -152,13 +154,25 @@ trait ScalatraKernel extends Handler with Initializable
}
def requestPath: String
+
+ def beforeAll(fun: => Any) = addBefore(List(string2RouteMatcher("/*")), fun)
+
+ def before(routeMatchers: RouteMatcher*)(fun: => Any) = addBefore(routeMatchers, fun)
+
+ protected def addBefore(routeMatchers: Iterable[RouteMatcher], fun: => Any): Unit = {
+ val route = new Route(routeMatchers, () => fun)
+ beforeFilters += route
+ }
- protected val beforeFilters = new ListBuffer[() => Any]
- def before(fun: => Any) = beforeFilters += { () => fun }
-
- protected val afterFilters = new ListBuffer[() => Any]
- def after(fun: => Any) = afterFilters += { () => fun }
-
+ def afterAll(fun: => Any) = addAfter(List(string2RouteMatcher("/*")), fun)
+
+ def after(routeMatchers: RouteMatcher*)(fun: => Any) = addAfter(routeMatchers, fun)
+
+ protected def addAfter(routeMatchers: Iterable[RouteMatcher], fun: => Any): Unit = {
+ val route = new Route(routeMatchers, () => fun)
+ afterFilters += route
+ }
+
protected var doNotFound: Action
def notFound(fun: => Any) = doNotFound = { () => fun }
@@ -0,0 +1,47 @@
+package org.scalatra
+
+import org.scalatest.matchers.ShouldMatchers
+import javax.servlet.http.HttpServletResponse
+import test.scalatest.ScalatraFunSuite
+
+class AfterTestServlet extends ScalatraServlet {
+
+ afterAll {
+ response.setStatus(204)
+ }
+
+ after("/some/path") {
+ response.setStatus(202)
+ }
+
+ after("/other/path") {
+ response.setStatus(206)
+ }
+
+ get("/some/path") { }
+
+ get("/other/path") { }
+
+ get("/third/path") { }
+
+}
+
+class AfterTest extends ScalatraFunSuite with ShouldMatchers {
+ addServlet(classOf[AfterTestServlet], "/*")
+
+ test("afterAll is applied to all paths") {
+ get("/third/path") {
+ status should equal(204)
+ }
+ }
+
+ test("after only applies to a given path") {
+ get("/some/path") {
+ status should equal(202)
+ }
+ get("/other/path") {
+ status should equal(206)
+ }
+ }
+
+}
@@ -0,0 +1,47 @@
+package org.scalatra
+
+import org.scalatest.matchers.ShouldMatchers
+import javax.servlet.http.HttpServletResponse
+import test.scalatest.ScalatraFunSuite
+
+class BeforeTestServlet extends ScalatraServlet {
+
+ beforeAll {
+ response.setStatus(204)
+ }
+
+ before("/some/path") {
+ response.setStatus(202)
+ }
+
+ before("/other/path") {
+ response.setStatus(206)
+ }
+
+ get("/some/path") { }
+
+ get("/other/path") { }
+
+ get("/third/path") { }
+
+}
+
+class BeforeTest extends ScalatraFunSuite with ShouldMatchers {
+ addServlet(classOf[BeforeTestServlet], "/*")
+
+ test("beforeAll is applied to all paths") {
+ get("/third/path") {
+ status should equal(204)
+ }
+ }
+
+ test("before only applies to a given path") {
+ get("/some/path") {
+ status should equal(202)
+ }
+ get("/other/path") {
+ status should equal(206)
+ }
+ }
+
+}
@@ -8,15 +8,15 @@ class FilterTestServlet extends ScalatraServlet {
var beforeCount = 0
var afterCount = 0
- before {
+ beforeAll {
beforeCount += 1
params.get("before") match {
case Some(x) => response.getWriter.write(x)
case None =>
}
}
- after {
+ afterAll {
afterCount += 1
params.get("after") match {
case Some(x) => response.getWriter.write(x)
@@ -45,7 +45,7 @@ class FilterTestServlet extends ScalatraServlet {
class FilterTestFilter extends ScalatraFilter {
var beforeCount = 0
- before {
+ beforeAll {
beforeCount += 1
response.setHeader("filterBeforeCount", beforeCount.toString)
}
@@ -57,23 +57,23 @@ class FilterTestFilter extends ScalatraFilter {
}
class MultipleFilterTestServlet extends ScalatraServlet {
- before {
+ beforeAll {
response.getWriter.print("one\n")
}
- before {
+ beforeAll {
response.getWriter.print("two\n")
}
get("/") {
response.getWriter.print("three\n")
}
- after {
+ afterAll {
response.getWriter.print("four\n")
}
- after {
+ afterAll {
response.getWriter.print("five\n")
}
}
@@ -4,11 +4,11 @@ import org.scalatest.matchers.ShouldMatchers
import test.scalatest.ScalatraFunSuite
class GetResponseStatusSupportTestServlet extends ScalatraServlet with GetResponseStatusSupport {
- before {
+ beforeAll {
session // Establish a session before we commit the response
}
- after {
+ afterAll {
session("status") = status.toString
}
@@ -22,13 +22,13 @@ class HaltTestServlet extends ScalatraServlet {
"this content must not be returned"
}
- before {
+ beforeAll {
if (params.isDefinedAt("haltBefore")) {
halt(503)
}
}
- after {
+ afterAll {
response.setHeader("After-Block-Ran", "true")
}
}
@@ -4,7 +4,7 @@ import org.scalatest.matchers.ShouldMatchers
import test.scalatest.ScalatraFunSuite
class ScalatraSuiteTestServlet extends ScalatraServlet {
- before {
+ beforeAll {
contentType = "text/html; charset=utf-8"
}
@@ -41,7 +41,7 @@ class TemplateExample extends ScalatraServlet with UrlSupport /*with FileUploadS
}
}
- before {
+ beforeAll {
contentType = "text/html"
}

0 comments on commit 2485aef

Please sign in to comment.