Skip to content

Commit

Permalink
More type safety for QueryReducer + introduced ReduceAction (#88).
Browse files Browse the repository at this point in the history
  • Loading branch information
OlegIlyenko committed Oct 15, 2015
1 parent 6bef0fb commit 4293436
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 61 deletions.
28 changes: 15 additions & 13 deletions src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ case class Executor[Ctx, Root](
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer] = Nil)(implicit executionContext: ExecutionContext) {
queryReducers: List[QueryReducer[Ctx, _]] = Nil)(implicit executionContext: ExecutionContext) {

def execute[Input](
queryAst: ast.Document,
Expand Down Expand Up @@ -137,7 +137,7 @@ case class Executor[Ctx, Root](
variables: Map[String, Any],
rootTpe: ObjectType[_, _],
fields: Map[String, (ast.Field, Try[List[ast.Field]])],
reducers: Vector[QueryReducer]): Any = {
reducers: Vector[QueryReducer[Ctx, _]]): Any = {
// Using mutability here locally in order to reduce footprint
import scala.collection.mutable.ListBuffer

Expand All @@ -163,11 +163,11 @@ case class Executor[Ctx, Root](
for (i 0 until reducers.size) {
val reducer = reducers(i)

acc(i) = reducer.reduceField(
acc(i) = reducer.reduceField[Any](
acc(i).asInstanceOf[reducer.Acc],
childReduced(i).asInstanceOf[reducer.Acc],
newPath, userContext, fields,
objTpe.asInstanceOf[ObjectType[Ctx, Any]],
objTpe.asInstanceOf[ObjectType[Any, Any]],
field.asInstanceOf[Field[Ctx, Any]], argumentValuesFn)
}

Expand Down Expand Up @@ -206,7 +206,7 @@ case class Executor[Ctx, Root](
childReduced(i).asInstanceOf[reducer.Acc],
path, userContext, astFields,
rootTpe.asInstanceOf[ObjectType[Any, Any]],
field.asInstanceOf[Field[Any, Any]], argumentValuesFn)
field.asInstanceOf[Field[Ctx, Any]], argumentValuesFn)
}

acc
Expand All @@ -215,18 +215,20 @@ case class Executor[Ctx, Root](

// Unsafe part to avoid addition boxing in order to reduce the footprint
reducers.zipWithIndex.foldLeft(userContext: Any) {
case (acc: Future[_], (reducer, idx))
case (acc: Future[Ctx], (reducer, idx))
acc.flatMap(a reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], a) match {
case Left(future) future
case Right(value) Future.successful(value)
case FutureValue(future) future
case Value(value) Future.successful(value)
case TryValue(value) Future.fromTry(value)
})

case (acc, (reducer, idx))
case (acc: Ctx @unchecked, (reducer, idx))
reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], acc) match {
case Left(future) future
case Right(value) value
case FutureValue(future) future
case Value(value) value
case TryValue(value) value.get
}
}.asInstanceOf[Ctx]
}
}
}

Expand All @@ -244,7 +246,7 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil
)(implicit executionContext: ExecutionContext, marshaller: ResultMarshaller, um: InputUnmarshaller[Input]): Future[marshaller.Node] =
Executor(schema, root, userContext, queryValidator, deferredResolver, exceptionHandler, deprecationTracker, middleware, maxQueryDepth, queryReducers).execute(queryAst, operationName, variables)
}
Expand Down
69 changes: 35 additions & 34 deletions src/main/scala/sangria/execution/QueryReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,53 @@ package sangria.execution
import sangria.ast
import sangria.schema._

import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.Future
import scala.util.{Try, Failure, Success}

trait QueryReducer {
trait QueryReducer[-Ctx, +Out] {
type Acc

def initial: Acc

def reduceAlternatives(alterntives: Seq[Acc]): Acc

def reduceField[Ctx, Val](
def reduceField[Val](
fieldAcc: Acc,
childrenAcc: Acc,
path: List[String],
ctx: Ctx,
astFields: List[ast.Field],
parentType: ObjectType[Ctx, Val],
field: Field[Ctx, Val],
parentType: ObjectType[Out, Val] @uncheckedVariance,
field: Field[Ctx, Val] @uncheckedVariance,
argumentValuesFn: (List[String], List[Argument[_]], List[ast.Argument]) Try[Args]): Acc

def reduceScalar[Ctx, T](
def reduceScalar[T](
path: List[String],
ctx: Ctx,
tpe: ScalarType[T]): Acc

def reduceEnum[Ctx, T](
def reduceEnum[T](
path: List[String],
ctx: Ctx,
tpe: EnumType[T]): Acc

def reduceCtx[Ctx](acc: Acc, ctx: Ctx): Either[Future[Ctx], Ctx]
def reduceCtx(acc: Acc, ctx: Ctx): ReduceAction[Out, Out]
}

object QueryReducer {
def measureComplexity[Ctx](fn: (Double, Ctx) Either[Future[Ctx], Ctx]): QueryReducer =
def measureComplexity[Ctx](fn: (Double, Ctx) ReduceAction[Ctx, Ctx]): QueryReducer[Ctx, Ctx] =
new MeasureComplexity[Ctx](fn)

def rejectComplexQueries(complexityThreshold: Double, error: Double Throwable): QueryReducer =
new MeasureComplexity[Any]((c, ctx)
if (c >= complexityThreshold) throw error(c) else Right(ctx))
def rejectComplexQueries[Ctx](complexityThreshold: Double, error: (Double, Ctx) Throwable): QueryReducer[Ctx, Ctx] =
new MeasureComplexity[Ctx]((c, ctx)
if (c >= complexityThreshold) throw error(c, ctx) else ctx)

def collectTags[Ctx, T](tagMatcher: PartialFunction[FieldTag, T])(fn: (Seq[T], Ctx) Either[Future[Ctx], Ctx]): QueryReducer =
def collectTags[Ctx, T](tagMatcher: PartialFunction[FieldTag, T])(fn: (Seq[T], Ctx) ReduceAction[Ctx, Ctx]): QueryReducer[Ctx, Ctx] =
new TagCollector[Ctx, T](tagMatcher, fn)
}

class MeasureComplexity[Ctx](action: (Double, Ctx) Either[Future[Ctx], Ctx]) extends QueryReducer {
class MeasureComplexity[Ctx](action: (Double, Ctx) ReduceAction[Ctx, Ctx]) extends QueryReducer[Ctx, Ctx] {
type Acc = Double

import MeasureComplexity.DefaultComplexity
Expand All @@ -57,14 +58,14 @@ class MeasureComplexity[Ctx](action: (Double, Ctx) ⇒ Either[Future[Ctx], Ctx])

def reduceAlternatives(alterntives: Seq[Acc]) = alterntives.max

def reduceField[C, Val](
def reduceField[Val](
fieldAcc: Acc,
childrenAcc: Acc,
path: List[String],
ctx: C,
ctx: Ctx,
astFields: List[ast.Field],
parentType: ObjectType[C, Val],
field: Field[C, Val],
parentType: ObjectType[Ctx, Val],
field: Field[Ctx, Val],
argumentValuesFn: (List[String], List[Argument[_]], List[ast.Argument]) Try[Args]): Acc = {
val estimate = field.complexity match {
case Some(fn)
Expand All @@ -78,52 +79,52 @@ class MeasureComplexity[Ctx](action: (Double, Ctx) ⇒ Either[Future[Ctx], Ctx])
fieldAcc + estimate
}

def reduceScalar[C, T](
def reduceScalar[T](
path: List[String],
ctx: C,
ctx: Ctx,
tpe: ScalarType[T]): Acc = tpe.complexity

def reduceEnum[C, T](
def reduceEnum[T](
path: List[String],
ctx: C,
ctx: Ctx,
tpe: EnumType[T]): Acc = initial

def reduceCtx[C](acc: Acc, ctx: C) =
action(acc, ctx.asInstanceOf[Ctx]).asInstanceOf[Either[Future[C], C]]
def reduceCtx(acc: Acc, ctx: Ctx) =
action(acc, ctx)
}

object MeasureComplexity {
val DefaultComplexity = 1.0D
}

class TagCollector[Ctx, T](tagMatcher: PartialFunction[FieldTag, T], action: (Seq[T], Ctx) Either[Future[Ctx], Ctx]) extends QueryReducer {
class TagCollector[Ctx, T](tagMatcher: PartialFunction[FieldTag, T], action: (Seq[T], Ctx) ReduceAction[Ctx, Ctx]) extends QueryReducer[Ctx, Ctx] {
type Acc = Vector[T]

val initial = Vector.empty

def reduceAlternatives(alterntives: Seq[Acc]) = alterntives.toVector.flatten

def reduceField[C, Val](
def reduceField[Val](
fieldAcc: Acc,
childrenAcc: Acc,
path: List[String],
ctx: C,
ctx: Ctx,
astFields: List[ast.Field],
parentType: ObjectType[C, Val],
field: Field[C, Val],
parentType: ObjectType[Ctx, Val],
field: Field[Ctx, Val],
argumentValuesFn: (List[String], List[Argument[_]], List[ast.Argument]) Try[Args]): Acc =
fieldAcc ++ childrenAcc ++ field.tags.collect {case t if tagMatcher.isDefinedAt(t) tagMatcher(t)}

def reduceScalar[C, ST](
def reduceScalar[ST](
path: List[String],
ctx: C,
ctx: Ctx,
tpe: ScalarType[ST]): Acc = initial

def reduceEnum[C, ET](
def reduceEnum[ET](
path: List[String],
ctx: C,
ctx: Ctx,
tpe: EnumType[ET]): Acc = initial

def reduceCtx[C](acc: Acc, ctx: C) =
action(acc, ctx.asInstanceOf[Ctx]).asInstanceOf[Either[Future[C], C]]
def reduceCtx(acc: Acc, ctx: Ctx) =
action(acc, ctx)
}
18 changes: 14 additions & 4 deletions src/main/scala/sangria/schema/Context.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,36 @@ sealed trait Action[+Ctx, +Val] {
sealed trait LeafAction[+Ctx, +Val] extends Action[Ctx, Val] {
def map[NewVal](fn: Val NewVal)(implicit ec: ExecutionContext): LeafAction[Ctx, NewVal]
}
sealed trait ReduceAction[+Ctx, +Val] extends Action[Ctx, Val] {
def map[NewVal](fn: Val NewVal)(implicit ec: ExecutionContext): LeafAction[Ctx, NewVal]
}

object ReduceAction {
implicit def futureAction[Ctx, Val](value: Future[Val]): ReduceAction[Ctx, Val] = FutureValue(value)
implicit def tryAction[Ctx, Val](value: Try[Val]): ReduceAction[Ctx, Val] = TryValue(value)
implicit def defaultAction[Ctx, Val](value: Val): ReduceAction[Ctx, Val] = Value(value)
}

object Action {
implicit def futureAction[Ctx, Val](value: Future[Val]): LeafAction[Ctx, Val] = FutureValue(value)
implicit def deferredAction[Ctx, Val](value: Deferred[Val]): LeafAction[Ctx, Val] = DeferredValue(value)
implicit def deferredFutureAction[Ctx, Val, D <: Deferred[Val]](value: Future[D])(implicit ev: D <:< Deferred[Val]): LeafAction[Ctx, Val] = DeferredFutureValue(value)

implicit def futureAction[Ctx, Val](value: Future[Val]): LeafAction[Ctx, Val] = FutureValue(value)
implicit def tryAction[Ctx, Val](value: Try[Val]): LeafAction[Ctx, Val] = TryValue(value)
implicit def defaultAction[Ctx, Val](value: Val): LeafAction[Ctx, Val] = Value(value)
}

case class Value[Ctx, Val](value: Val) extends LeafAction[Ctx, Val] {
case class Value[Ctx, Val](value: Val) extends LeafAction[Ctx, Val] with ReduceAction[Ctx, Val] {
override def map[NewVal](fn: Val NewVal)(implicit ec: ExecutionContext): Value[Ctx, NewVal] =
Value(fn(value))
}

case class TryValue[Ctx, Val](value: Try[Val]) extends LeafAction[Ctx, Val] {
case class TryValue[Ctx, Val](value: Try[Val]) extends LeafAction[Ctx, Val] with ReduceAction[Ctx, Val] {
override def map[NewVal](fn: Val NewVal)(implicit ec: ExecutionContext): TryValue[Ctx, NewVal] =
TryValue(value map fn)
}

case class FutureValue[Ctx, Val](value: Future[Val]) extends LeafAction[Ctx, Val] {
case class FutureValue[Ctx, Val](value: Future[Val]) extends LeafAction[Ctx, Val] with ReduceAction[Ctx, Val] {
override def map[NewVal](fn: Val NewVal)(implicit ec: ExecutionContext): FutureValue[Ctx, NewVal] =
FutureValue(value map fn)
}
Expand Down
20 changes: 10 additions & 10 deletions src/test/scala/sangria/execution/QueryReducerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

Executor.execute(schema, query, userContext = Info(Nil), queryReducers = complReducer :: Nil).await should be (
Expand Down Expand Up @@ -170,7 +170,7 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

Executor.execute(schema, query, userContext = Info(Nil), queryReducers = complReducer :: Nil).await should be (
Expand Down Expand Up @@ -219,7 +219,7 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

Executor.execute(schema, query, userContext = Info(Nil), queryReducers = complReducer :: Nil).await should be (
Expand Down Expand Up @@ -276,7 +276,7 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

Executor.execute(schema, query, userContext = Info(Nil), queryReducers = complReducer :: Nil).await should be (
Expand Down Expand Up @@ -318,7 +318,7 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

Executor.execute(schema, query, userContext = Info(Nil), queryReducers = complReducer :: Nil).await should be (
Expand Down Expand Up @@ -346,7 +346,7 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {
case (m, e: IllegalArgumentException) HandledException(e.getMessage)
}

val rejectComplexQuery = QueryReducer.rejectComplexQueries(14, c
val rejectComplexQuery = QueryReducer.rejectComplexQueries[Info](14, (c, _)
new IllegalArgumentException(s"Too complex query: max allowed complexity is 14.0, but got $c"))

Executor.execute(schema, query,
Expand Down Expand Up @@ -376,11 +376,11 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

val tagColl = QueryReducer.collectTags[Info, Int] {case ATag(num) num} ((nums, ctx)
Right(ctx.copy(nums = nums)))
ctx.copy(nums = nums))

Executor.execute(schema, query,
userContext = Info(Nil),
Expand Down Expand Up @@ -411,11 +411,11 @@ class QueryReducerSpec extends WordSpec with Matchers with AwaitSupport {

val complReducer = QueryReducer.measureComplexity[Info] { (c, ctx)
complexity = c
Right(ctx)
ctx
}

val tagColl = QueryReducer.collectTags[Info, Int] {case ATag(num) num} ((nums, ctx)
Right(ctx.copy(nums = nums)))
ctx.copy(nums = nums))

Executor.execute(schema, query,
userContext = Info(Nil),
Expand Down

0 comments on commit 4293436

Please sign in to comment.