Skip to content

Commit

Permalink
Implemented more generic way to reduce the query. This allows better …
Browse files Browse the repository at this point in the history
…separation between execution engine and features like complexity analysis. This also makes it possible to easily define other kinds of query reducers, like tag collector (#88).
  • Loading branch information
OlegIlyenko committed Oct 14, 2015
1 parent ee84644 commit 1cd155c
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 107 deletions.
178 changes: 157 additions & 21 deletions src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ case class Executor[Ctx, Root](
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware] = Nil,
maxQueryDepth: Option[Int] = None,
measureComplexity: Option[Double Unit] = None)(implicit executionContext: ExecutionContext) {
queryReducers: List[QueryReducer] = Nil)(implicit executionContext: ExecutionContext) {

def execute[Input](
queryAst: ast.Document,
Expand All @@ -37,7 +37,17 @@ case class Executor[Ctx, Root](
operation getOperation(queryAst, operationName)
unmarshalledVariables valueCollector.getVariableValues(operation.variables)
fieldCollector = new FieldCollector[Ctx, Root](schema, queryAst, unmarshalledVariables, queryAst.sourceMapper, valueCollector)
res executeOperation(MiddlewareQueryContext(this, queryAst, operationName, variables, um), operation, queryAst.sourceMapper, valueCollector, fieldCollector, marshaller, unmarshalledVariables)
res executeOperation(
queryAst,
operationName,
variables,
um,
operation,
queryAst.sourceMapper,
valueCollector,
fieldCollector,
marshaller,
unmarshalledVariables)
} yield res

executionResult match {
Expand All @@ -56,8 +66,11 @@ case class Executor[Ctx, Root](
operation map (Success(_)) getOrElse Failure(new ExecutionError(s"Unknown operation name: ${operationName.get}"))
}

def executeOperation(
middlewareCtx: MiddlewareQueryContext[Ctx, _],
def executeOperation[Input](
queryAst: ast.Document,
operationName: Option[String] = None,
inputVariables: Input,
inputUnmarshaller: InputUnmarshaller[Input],
operation: ast.OperationDefinition,
sourceMapper: Option[SourceMapper],
valueCollector: ValueCollector[Ctx, _],
Expand All @@ -67,31 +80,154 @@ case class Executor[Ctx, Root](
for {
tpe getOperationRootType(operation, sourceMapper)
fields fieldCollector.collectFields(Nil, tpe, operation :: Nil)
middlewareVal = middleware map (m m.beforeQuery(middlewareCtx) m)
resolver = new Resolver[Ctx](marshaller, middlewareCtx, schema, valueCollector, variables, fieldCollector, userContext, exceptionHandler, deferredResolver, sourceMapper, deprecationTracker, middlewareVal, maxQueryDepth)
_ measureComplexity.fold(Success(()): Try[Unit])(fn Try(fn(resolver.estimateComplexity(tpe, fields))))
} yield {
val result =
operation.operationType match {
case ast.OperationType.Query resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Mutation resolver.resolveFieldsSeq(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
}
def doExecute(ctx: Ctx) = {
val middlewareCtx = MiddlewareQueryContext(ctx, this, queryAst, operationName, inputVariables, inputUnmarshaller)

val middlewareVal = middleware map (m m.beforeQuery(middlewareCtx) m)

val resolver = new Resolver[Ctx](
marshaller,
middlewareCtx,
schema,
valueCollector,
variables,
fieldCollector,
ctx,
exceptionHandler,
deferredResolver,
sourceMapper,
deprecationTracker,
middlewareVal,
maxQueryDepth)

val result =
operation.operationType match {
case ast.OperationType.Query resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Mutation resolver.resolveFieldsSeq(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
}

if (middlewareVal.nonEmpty) {
def onAfter() =
middlewareVal foreach { case (v, m) m.afterQuery(v.asInstanceOf[m.QueryVal], middlewareCtx)}

if (middlewareVal.nonEmpty) {
def onAfter() =
middlewareVal foreach {case (v, m) m.afterQuery(v.asInstanceOf[m.QueryVal], middlewareCtx)}
result
.map { x onAfter(); x}
.recover { case e onAfter(); throw e}
} else result
}

result
.map {x onAfter(); x}
.recover {case e onAfter(); throw e}
} else result
if (queryReducers.nonEmpty)
reduceQuery(fieldCollector, valueCollector, variables, tpe, fields, queryReducers.toVector) match {
case future: Future[Ctx] future.flatMap(newCtx doExecute(newCtx))
case newCtx: Ctx @unchecked doExecute(newCtx)
}
else doExecute(userContext)
}

def getOperationRootType(operation: ast.OperationDefinition, sourceMapper: Option[SourceMapper]) = operation.operationType match {
case ast.OperationType.Query Success(schema.query)
case ast.OperationType.Mutation schema.mutation map (Success(_)) getOrElse
Failure(new ExecutionError("Schema is not configured for mutations", sourceMapper, operation.position.toList))
}

private def reduceQuery[Val](
fieldCollector: FieldCollector[Ctx, Val],
valueCollector: ValueCollector[Ctx, _],
variables: Map[String, Any],
rootTpe: ObjectType[_, _],
fields: Map[String, (ast.Field, Try[List[ast.Field]])],
reducers: Vector[QueryReducer]): Any = {
// Using mutability here locally in order to reduce footprint
import scala.collection.mutable.ListBuffer

val argumentValuesFn = (path: List[String], argumentDefs: List[Argument[_]], argumentAsts: List[ast.Argument])
valueCollector.getFieldArgumentValues(path, argumentDefs, argumentAsts, variables)

val initialValues: Vector[Any] = reducers map (_.initial)

def loop(path: List[String], tpe: OutputType[_], astFields: List[ast.Field]): Seq[Any] =
tpe match {
case OptionType(ofType) loop(path, ofType, astFields)
case ListType(ofType) loop(path, ofType, astFields)
case objTpe: ObjectType[Ctx, _]
fieldCollector.collectFields(path, objTpe, astFields) match {
case Success(ff)
ff.values.toVector.foldLeft(ListBuffer(initialValues: _*)) {
case (acc, (_, Success(fields))) if objTpe.getField(schema, fields.head.name).nonEmpty
val astField = fields.head
val field = objTpe.getField(schema, astField.name).head
val newPath = path :+ astField.outputName
val childReduced = loop(newPath, field.fieldType, fields)

for (i 0 until reducers.size) {
val reducer = reducers(i)

acc(i) = reducer.reduceField(
acc(i).asInstanceOf[reducer.Acc],
childReduced(i).asInstanceOf[reducer.Acc],
newPath, userContext, fields,
objTpe.asInstanceOf[ObjectType[Ctx, Any]],
field.asInstanceOf[Field[Ctx, Any]], argumentValuesFn)
}

acc
case (acc, _) acc
}
case Failure(_) initialValues
}
case abst: AbstractType
schema.possibleTypes
.get (abst.name)
.map (types
types.map(loop(path, _, astFields)).transpose.zipWithIndex.map{
case (values, idx)
val reducer = reducers(idx)
reducer.reduceAlternatives(values.asInstanceOf[Seq[reducer.Acc]])
})
.getOrElse (initialValues)
case s: ScalarType[_] reducers map (_.reduceScalar(path, userContext, s))
case e: EnumType[_] reducers map (_.reduceEnum(path, userContext, e))
case _ initialValues
}

val reduced = fields.values.toVector.foldLeft(ListBuffer(initialValues: _*)) {
case (acc, (_, Success(astFields))) if rootTpe.getField(schema, astFields.head.name).nonEmpty =>
val astField = astFields.head
val field = rootTpe.getField(schema, astField.name).head
val path = astField.outputName :: Nil
val childReduced = loop(path, field.fieldType, astFields)

for (i 0 until reducers.size) {
val reducer = reducers(i)

acc(i) = reducer.reduceField(
acc(i).asInstanceOf[reducer.Acc],
childReduced(i).asInstanceOf[reducer.Acc],
path, userContext, astFields,
rootTpe.asInstanceOf[ObjectType[Any, Any]],
field.asInstanceOf[Field[Any, Any]], argumentValuesFn)
}

acc
case (acc, _) acc
}

// Unsafe part to avoid addition boxing in order to reduce the footprint
reducers.zipWithIndex.foldLeft(userContext: Any) {
case (acc: Future[_], (reducer, idx))
acc.flatMap(a reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], a) match {
case Left(future) future
case Right(value) Future.successful(value)
})

case (acc, (reducer, idx))
reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], acc) match {
case Left(future) future
case Right(value) value
}
}.asInstanceOf[Ctx]
}
}

object Executor {
Expand All @@ -108,9 +244,9 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware] = Nil,
maxQueryDepth: Option[Int] = None,
measureComplexity: Option[Double Unit] = None
queryReducers: List[QueryReducer] = Nil
)(implicit executionContext: ExecutionContext, marshaller: ResultMarshaller, um: InputUnmarshaller[Input]): Future[marshaller.Node] =
Executor(schema, root, userContext, queryValidator, deferredResolver, exceptionHandler, deprecationTracker, middleware, maxQueryDepth, measureComplexity).execute(queryAst, operationName, variables)
Executor(schema, root, userContext, queryValidator, deferredResolver, exceptionHandler, deprecationTracker, middleware, maxQueryDepth, queryReducers).execute(queryAst, operationName, variables)
}

case class HandledException(message: String, additionalFields: Map[String, ResultMarshaller#Node] = Map.empty)
2 changes: 1 addition & 1 deletion src/main/scala/sangria/execution/FieldCollector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class FieldCollector[Ctx, Val](

possibleDirs.collect{case Failure(error) error}.headOption map (Failure(_)) getOrElse {
val validDirs = possibleDirs collect {case Success(v) v}
val should = validDirs.forall { case (dir, args) dir.shouldInclude(DirectiveContext(selection, dir, Args(args))) }
val should = validDirs.forall { case (dir, args) dir.shouldInclude(DirectiveContext(selection, dir, args)) }

Success(should)
}
Expand Down
129 changes: 129 additions & 0 deletions src/main/scala/sangria/execution/QueryReducer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package sangria.execution

import sangria.ast
import sangria.schema._

import scala.concurrent.Future
import scala.util.{Try, Failure, Success}

trait QueryReducer {
type Acc

def initial: Acc

def reduceAlternatives(alterntives: Seq[Acc]): Acc

def reduceField[Ctx, Val](
fieldAcc: Acc,
childrenAcc: Acc,
path: List[String],
ctx: Ctx,
astFields: List[ast.Field],
parentType: ObjectType[Ctx, Val],
field: Field[Ctx, Val],
argumentValuesFn: (List[String], List[Argument[_]], List[ast.Argument]) Try[Args]): Acc

def reduceScalar[Ctx, T](
path: List[String],
ctx: Ctx,
tpe: ScalarType[T]): Acc

def reduceEnum[Ctx, T](
path: List[String],
ctx: Ctx,
tpe: EnumType[T]): Acc

def reduceCtx[Ctx](acc: Acc, ctx: Ctx): Either[Future[Ctx], Ctx]
}

object QueryReducer {
def measureComplexity[Ctx](fn: (Double, Ctx) Either[Future[Ctx], Ctx]): QueryReducer =
new MeasureComplexity[Ctx](fn)

def rejectComplexQueries(complexityThreshold: Double, error: Double Throwable): QueryReducer =
new MeasureComplexity[Any]((c, ctx)
if (c >= complexityThreshold) throw error(c) else Right(ctx))

def collectTags[Ctx, T](tagMatcher: PartialFunction[FieldTag, T])(fn: (Seq[T], Ctx) Either[Future[Ctx], Ctx]): QueryReducer =
new TagCollector[Ctx, T](tagMatcher, fn)
}

class MeasureComplexity[Ctx](action: (Double, Ctx) Either[Future[Ctx], Ctx]) extends QueryReducer {
type Acc = Double

import MeasureComplexity.DefaultComplexity

val initial = 0.0D

def reduceAlternatives(alterntives: Seq[Acc]) = alterntives.max

def reduceField[C, Val](
fieldAcc: Acc,
childrenAcc: Acc,
path: List[String],
ctx: C,
astFields: List[ast.Field],
parentType: ObjectType[C, Val],
field: Field[C, Val],
argumentValuesFn: (List[String], List[Argument[_]], List[ast.Argument]) Try[Args]): Acc = {
val estimate = field.complexity match {
case Some(fn)
argumentValuesFn(path, field.arguments, astFields.head.arguments) match {
case Success(args) fn(ctx, args, childrenAcc)
case Failure(_) DefaultComplexity + childrenAcc
}
case None DefaultComplexity + childrenAcc
}

fieldAcc + estimate
}

def reduceScalar[C, T](
path: List[String],
ctx: C,
tpe: ScalarType[T]): Acc = tpe.complexity

def reduceEnum[C, T](
path: List[String],
ctx: C,
tpe: EnumType[T]): Acc = initial

def reduceCtx[C](acc: Acc, ctx: C) =
action(acc, ctx.asInstanceOf[Ctx]).asInstanceOf[Either[Future[C], C]]
}

object MeasureComplexity {
val DefaultComplexity = 1.0D
}

class TagCollector[Ctx, T](tagMatcher: PartialFunction[FieldTag, T], action: (Seq[T], Ctx) Either[Future[Ctx], Ctx]) extends QueryReducer {
type Acc = Vector[T]

val initial = Vector.empty

def reduceAlternatives(alterntives: Seq[Acc]) = alterntives.toVector.flatten

def reduceField[C, Val](
fieldAcc: Acc,
childrenAcc: Acc,
path: List[String],
ctx: C,
astFields: List[ast.Field],
parentType: ObjectType[C, Val],
field: Field[C, Val],
argumentValuesFn: (List[String], List[Argument[_]], List[ast.Argument]) Try[Args]): Acc =
fieldAcc ++ childrenAcc ++ field.tags.collect {case t if tagMatcher.isDefinedAt(t) tagMatcher(t)}

def reduceScalar[C, ST](
path: List[String],
ctx: C,
tpe: ScalarType[ST]): Acc = initial

def reduceEnum[C, ET](
path: List[String],
ctx: C,
tpe: EnumType[ET]): Acc = initial

def reduceCtx[C](acc: Acc, ctx: C) =
action(acc, ctx.asInstanceOf[Ctx]).asInstanceOf[Either[Future[C], C]]
}
Loading

0 comments on commit 1cd155c

Please sign in to comment.