Permalink
Browse files

Adds the full scalatra dsl

  • Loading branch information...
1 parent 730e064 commit 385836557e14a7c42bf497f5477f4aa4599be68d @casualjim casualjim committed Sep 11, 2011
Showing with 298 additions and 13 deletions.
  1. +4 −2 build.sbt
  2. +294 −11 src/main/scala/Example.scala
View
@@ -7,9 +7,11 @@ version := "0.1.0-SNAPSHOT"
libraryDependencies ++= Seq(
"net.databinder" %% "unfiltered-filter" % "0.4.1",
"net.databinder" %% "unfiltered-jetty" % "0.4.1",
- "org.clapper" %% "avsl" % "0.3.1"
+ "org.clapper" %% "avsl" % "0.3.1",
+ "org.scalatra" %% "scalatra" % "2.1.0-SNAPSHOT"
)
resolvers ++= Seq(
- "java m2" at "http://download.java.net/maven/2"
+ "java m2" at "http://download.java.net/maven/2",
+ "sonatype oss snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/"
)
@@ -1,10 +1,152 @@
package com.example
+import _root_.org.scalatra.{RouteMatcher, Route}
import unfiltered.request._
import unfiltered.response._
import org.clapper.avsl.Logger
import util._
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
+import matching.Regex
+import org.scalatra._
+
+import java.lang.{Integer => JInteger}
+import javax.servlet.ServletContext
+import javax.servlet.http.{HttpServletRequest, HttpServletResponse, HttpSession}
+import scala.annotation.tailrec
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ConcurrentMap
+import java.util.concurrent.ConcurrentHashMap
+import ScalatraKernel.{Action, MultiParams}
+import util.{MapWithIndifferentAccess, MultiMapHeadView, MultiMap}
+
+trait CoreDsl[Res] {
+
+ type UnfilteredErrorHandler = PartialFunction[scala.Throwable, ResponseFunction[Res]]
+ def params: Map[String, String]
+
+ def multiParams: MultiParams
+
+ def redirect(uri: String) = throw new RuntimeException("TODO") // TODO: provide redirect etc methods
+
+ def before(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): Unit
+
+ def after(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): Unit
+
+ def notFound(block: => ResponseFunction[Res]): Unit
+
+ def methodNotAllowed(block: Set[HttpMethod] => ResponseFunction[Res]): Unit
+
+ def error(handler: UnfilteredErrorHandler): Unit
+
+ def get(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): UnfilteredRoute[Res]
+
+ def post(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): UnfilteredRoute[Res]
+
+ def put(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): UnfilteredRoute[Res]
+
+ def delete(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): UnfilteredRoute[Res]
+
+ def options(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): UnfilteredRoute[Res]
+
+ def patch(routeMatchers: RouteMatcher*)(block: => ResponseFunction[Res]): UnfilteredRoute[Res]
+
+ def halt(status: JInteger = null,
+ body: String = "",
+ headers: Map[String, String] = Map.empty,
+ reason: String = null): Nothing
+
+ def pass()
+}
+
+
+case class UnfilteredRoute[T](
+ routeMatchers: Iterable[RouteMatcher],
+ action: () => ResponseFunction[T],
+ contextPath: () => String = () => ""
+)
+{
+ def apply(): Option[MatchedUnfilteredRoute[T]] = {
+ routeMatchers.foldLeft(Option(MultiMap())) {
+ (acc: Option[MultiParams], routeMatcher: RouteMatcher) => for {
+ routeParams <- acc
+ matcherParams <- routeMatcher()
+ } yield routeParams ++ matcherParams
+ } map { routeParams => MatchedUnfilteredRoute(action, routeParams) }
+ }
+
+ lazy val reversibleMatcher: Option[RouteMatcher] =
+ routeMatchers find (_.isInstanceOf[ReversibleRouteMatcher])
+
+ lazy val isReversible: Boolean = !reversibleMatcher.isEmpty
+
+ override def toString: String = routeMatchers mkString " "
+}
+
+case class MatchedUnfilteredRoute[T](action: () => ResponseFunction[T], multiParams: MultiParams)
+
+
+class UnfilteredRouteRegistry[T] {
+ private val methodRoutes: ConcurrentMap[HttpMethod, Seq[UnfilteredRoute[T]]] =
+ new ConcurrentHashMap[HttpMethod, Seq[UnfilteredRoute[T]]]
+
+ private var _beforeFilters: Seq[UnfilteredRoute[T]] = Vector.empty
+ private var _afterFilters: Seq[UnfilteredRoute[T]] = Vector.empty
+
+ def apply(method: HttpMethod): Seq[UnfilteredRoute[T]] =
+ method match {
+ case Head => methodRoutes.getOrElse(Get, Vector.empty)
+ case m => methodRoutes.getOrElse(m, Vector.empty)
+ }
+
+ def matchingMethods: Set[HttpMethod] = matchingMethodsExcept { _ => false }
+
+ def matchingMethodsExcept(method: HttpMethod): Set[HttpMethod] = {
+ val p: HttpMethod => Boolean = method match {
+ case Get | Head => { m => m == Get || m == Head }
+ case _ => { _ == method }
+ }
+ matchingMethodsExcept(p)
+ }
+
+ private def matchingMethodsExcept(p: HttpMethod => Boolean) = {
+ var methods = (methodRoutes filter { case (method, routes) =>
+ !p(method) && (routes exists { _().isDefined })
+ }).keys.toSet
+ if (methods.contains(Get))
+ methods += Head
+ methods
+ }
+
+ def prependRoute(method: HttpMethod, route: UnfilteredRoute[T]): Unit =
+ modifyRoutes(method, route +: _)
+
+ def removeRoute(method: HttpMethod, route: UnfilteredRoute[T]): Unit =
+ modifyRoutes(method, _ filterNot (_ == route))
+
+ def beforeFilters: Seq[UnfilteredRoute[T]] = _beforeFilters
+
+ def appendBeforeFilter(route: UnfilteredRoute[T]): Unit = _beforeFilters :+= route
+
+ def afterFilters: Seq[UnfilteredRoute[T]] = _afterFilters
+
+ def appendAfterFilter(route: UnfilteredRoute[T]): Unit = _afterFilters :+= route
+
+ @tailrec private def modifyRoutes(method: HttpMethod, f: (Seq[UnfilteredRoute[T]] => Seq[UnfilteredRoute[T]])): Unit = {
+ if (methodRoutes.putIfAbsent(method, f(Vector.empty)).isDefined) {
+ val oldRoutes = methodRoutes(method)
+ if (!methodRoutes.replace(method, oldRoutes, f(oldRoutes)))
+ modifyRoutes(method, f)
+ }
+ }
+
+ def entryPoints: Seq[String] =
+ (for {
+ (method, routes) <- methodRoutes
+ route <- routes
+ } yield method + " " + route).toSeq sortWith (_ < _)
+
+ override def toString: String = entryPoints mkString ", "
+}
trait ImplicitResponses {
@@ -14,32 +156,170 @@ trait ImplicitResponses {
Html(xml)
}
-trait Scalatra[Req,Res] {
+trait Scalatra[Req,Res] extends CoreDsl[Res] {
- //this is used for all request methods for now
- private lazy val handlers = collection.mutable.Map[String,Function0[ResponseFunction[Res]]]()
+ protected lazy val routes = new UnfilteredRouteRegistry[Res]()
private lazy val _request = new DynamicVariable[HttpRequest[_]](null)
+ private lazy val _multiParams = new DynamicVariable[MultiMap](null)
+ def multiParams: MultiParams = _multiParams.value.withDefaultValue(Seq.empty)
- def get(r:String)( f: => ResponseFunction[Res]) = {
- val p = () => f
- handlers += (r -> p)
+ protected val _params = new MultiMapHeadView[String, String] with MapWithIndifferentAccess[String] {
+ protected def multiMap = multiParams
}
+
+ def params = _params
+
+
implicit def request = _request value
- protected def executeRoutes(req: HttpRequest[_]):ResponseFunction[Res] = {
- //TODO:proper matching logic should come here, for now it's matching all request methods from right to left
- val handler = handlers.keys.filter(req.uri.startsWith(_))
- handler.lastOption map(handlers(_)()) getOrElse ( NotFound ~> ResponseString("could not find handler"))
+ protected implicit def string2RouteMatcher(path: String): RouteMatcher =
+ new SinatraRouteMatcher(path, request.uri)
+
+ protected implicit def pathPatternParser2RouteMatcher(pattern: PathPattern): RouteMatcher =
+ new PathPatternRouteMatcher(pattern, request.uri)
+
+ protected implicit def regex2RouteMatcher(regex: Regex): RouteMatcher =
+ new RegexRouteMatcher(regex, request.uri)
+
+ protected implicit def booleanBlock2RouteMatcher(block: => Boolean): RouteMatcher =
+ new BooleanBlockRouteMatcher(block)
+
+ protected case class HaltException(
+ status: Option[Int],
+ reason: Option[String],
+ headers: Map[String, String],
+ body: String)
+ extends RuntimeException
+
+ protected def executeRoutes:ResponseFunction[Res] = {
+ val result = try {
+ runFilters(routes.beforeFilters)
+ val actionResult = runRoutes(routes(HttpMethod(request.method))).headOption
+ actionResult orElse matchOtherMethods() getOrElse doNotFound()
+ }
+ catch {
+ case e: HaltException => {
+ Status(e.status getOrElse 500) ~> ResponseString(e.body.toString)
+ }
+ case e => errorHandler(e)
+ }
+ finally {
+ runFilters(routes.afterFilters)
+ }
+ result
+ }
+
+ protected def runFilters(filters: Traversable[UnfilteredRoute[Res]]) =
+ for {
+ route <- filters
+ matchedRoute <- route()
+ } invoke(matchedRoute)
+
+ protected def runRoutes(routes: Traversable[UnfilteredRoute[Res]]) =
+ for {
+ route <- routes.toStream // toStream makes it lazy so we stop after match
+ matchedRoute <- route()
+ actionResult <- invoke(matchedRoute)
+ } yield actionResult
+
+ protected def invoke(matchedRoute: MatchedUnfilteredRoute[Res]) =
+ withRouteMultiParams(Some(matchedRoute)) {
+ try {
+ Some(matchedRoute.action())
+ }
+ catch {
+ case e: PassException => None
+ }
+ }
+
+ protected def withRouteMultiParams[S](matchedRoute: Option[MatchedUnfilteredRoute[Res]])(thunk: => S): S = {
+ val originalParams = multiParams
+ _multiParams.withValue(originalParams ++ matchedRoute.map(_.multiParams).getOrElse(Map.empty))(thunk)
+ }
+
+
+ def pass() = throw new PassException
+
+ protected class PassException extends RuntimeException
+
+ def get(routeMatchers: RouteMatcher*)(action: => ResponseFunction[Res]) = addRoute(Get, routeMatchers, action)
+
+ def post(routeMatchers: RouteMatcher*)(action: => ResponseFunction[Res]) = addRoute(Post, routeMatchers, action)
+
+ def put(routeMatchers: RouteMatcher*)(action: => ResponseFunction[Res]) = addRoute(Put, routeMatchers, action)
+
+ def delete(routeMatchers: RouteMatcher*)(action: => ResponseFunction[Res]) = addRoute(Delete, routeMatchers, action)
+
+ def options(routeMatchers: RouteMatcher*)(action: => ResponseFunction[Res]) = addRoute(Options, routeMatchers, action)
+
+ def patch(routeMatchers: RouteMatcher*)(action: => ResponseFunction[Res]) = addRoute(Patch, routeMatchers, action)
+
+ protected def addRoute(method: HttpMethod, routeMatchers: Iterable[RouteMatcher], action: => ResponseFunction[Res]): UnfilteredRoute[Res] = {
+ val route = UnfilteredRoute(routeMatchers, () => action, () => request.uri)
+ routes.prependRoute(method, route)
+ route
}
+
+ protected def removeRoute(method: HttpMethod, route: UnfilteredRoute[Res]): Unit =
+ routes.removeRoute(method, route)
+
+ protected def removeRoute(method: String, route: UnfilteredRoute[Res]): Unit =
+ removeRoute(HttpMethod(method), route)
+
+
val logger = Logger(classOf[App])
+ def halt(status: JInteger = null,
+ body: String,
+ headers: Map[String, String] = Map.empty,
+ reason: String = null): Nothing = {
+ val statusOpt = if (status == null) None else Some(status.intValue)
+ throw new HaltException(statusOpt, Some(reason), headers, body)
+ }
+
//capture all requests
def intent: unfiltered.Cycle.Intent[Req,Res] = {
case req @ _ => _request.withValue(req) {
- executeRoutes(req)
+ val multiMap = request.parameterNames.foldLeft(MultiMap()){ (acc, name) =>
+ acc + (name -> request.parameterValues(name))
+ }
+ _multiParams.withValue(multiMap) {
+ executeRoutes
+ }
}
}
+
+ private def matchOtherMethods(): Option[ResponseFunction[Res]] = {
+ val allow = routes.matchingMethodsExcept(HttpMethod(request.method))
+ if (allow.isEmpty) None else Some(doMethodNotAllowed(allow))
+ }
+
+ def before(routeMatchers: RouteMatcher*)(fun: => ResponseFunction[Res]) =
+ addBefore(routeMatchers, fun)
+
+ private def addBefore(routeMatchers: Iterable[RouteMatcher], fun: => ResponseFunction[Res]) =
+ routes.appendBeforeFilter(UnfilteredRoute(routeMatchers, () => fun))
+
+ def after(routeMatchers: RouteMatcher*)(fun: => ResponseFunction[Res]) =
+ addAfter(routeMatchers, fun)
+
+ private def addAfter(routeMatchers: Iterable[RouteMatcher], fun: => ResponseFunction[Res]) =
+ routes.appendAfterFilter(UnfilteredRoute(routeMatchers, () => fun))
+
+ protected var doMethodNotAllowed: (Set[HttpMethod] => ResponseFunction[Res]) = { allow =>
+ MethodNotAllowed ~> ResponseString(allow mkString ", ")
+ }
+ def methodNotAllowed(f: Set[HttpMethod] => ResponseFunction[Res]) = doMethodNotAllowed = f
+
+ protected var doNotFound: () => ResponseFunction[Res] = () => NotFound ~> ResponseString("could not find handler")
+ def notFound(block: => ResponseFunction[Res]) {
+ doNotFound = () => block
+ }
+
+
+ protected var errorHandler: UnfilteredErrorHandler = { case t => throw t }
+ def error(handler: UnfilteredErrorHandler) = errorHandler = handler orElse errorHandler
}
/**
@@ -64,6 +344,9 @@ with ImplicitResponses {
"hello index page!"
}
+ get("/param/:value") {
+ "hello %s" format params('value)
+ }
}

0 comments on commit 3858365

Please sign in to comment.