Skip to content

Commit

Permalink
Added prepare phase to execution in order to help with multiple que…
Browse files Browse the repository at this point in the history
…ry executions (e.g. helpful for subscriptions queries handled from outside of the `execute`)
  • Loading branch information
OlegIlyenko committed Mar 6, 2016
1 parent 91fdd6d commit 7bffbe2
Show file tree
Hide file tree
Showing 16 changed files with 277 additions and 185 deletions.
257 changes: 173 additions & 84 deletions src/main/scala/sangria/execution/Executor.scala
Expand Up @@ -13,8 +13,6 @@ import scala.util.{Success, Failure, Try}

case class Executor[Ctx, Root](
schema: Schema[Ctx, Root],
root: Root = (),
userContext: Ctx = (),
queryValidator: QueryValidator = QueryValidator.default,
deferredResolver: DeferredResolver[Ctx] = DeferredResolver.empty,
exceptionHandler: Executor.ExceptionHandler = PartialFunction.empty,
Expand All @@ -23,8 +21,62 @@ case class Executor[Ctx, Root](
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil)(implicit executionContext: ExecutionContext) {

def prepare[Input](
queryAst: ast.Document,
userContext: Ctx,
root: Root,
operationName: Option[String] = None,
variables: Input = emptyMapVars)(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = {
val violations = queryValidator.validateQuery(schema, queryAst)

if (violations.nonEmpty)
Future.failed(ValidationError(violations, exceptionHandler))
else {
val valueCollector = new ValueCollector[Ctx, Input](schema, variables, queryAst.sourceMapper, deprecationTracker, userContext, exceptionHandler)(um)

val executionResult = for {
operation getOperation(queryAst, operationName)
unmarshalledVariables valueCollector.getVariableValues(operation.variables)
fieldCollector = new FieldCollector[Ctx, Root](schema, queryAst, unmarshalledVariables, queryAst.sourceMapper, valueCollector, exceptionHandler)
tpe getOperationRootType(operation, queryAst.sourceMapper)
fields fieldCollector.collectFields(Vector.empty, tpe, Vector(operation))
} yield {
val preparedFields = fields.toVector.flatMap {
case (_, (astField, Success(_)))
val allFields = tpe.getField(schema, astField.name).asInstanceOf[Vector[Field[Ctx, Root]]]
val field = allFields.head
val args = valueCollector.getFieldArgumentValues(Vector(astField.name), field.arguments, astField.arguments, unmarshalledVariables)

args.toOption.map(PreparedField(field, _))
case _ None
}

reduceQuerySafe(fieldCollector, valueCollector, unmarshalledVariables, tpe, fields, userContext) match {
case fut: Future[Ctx]
fut.map(newCtx
new PreparedQuery[Ctx, Root, Input](queryAst, operation, tpe, newCtx, root, preparedFields,
(c: Ctx, r: Root, m: ResultMarshaller)
executeOperation(queryAst, operationName, variables, um, operation, queryAst.sourceMapper, valueCollector,
fieldCollector, m, unmarshalledVariables, tpe, fields, c, r)))
case newCtx
Future.successful(new PreparedQuery[Ctx, Root, Input](queryAst, operation, tpe, newCtx.asInstanceOf[Ctx], root, preparedFields,
(c: Ctx, r: Root, m: ResultMarshaller)
executeOperation(queryAst, operationName, variables, um, operation, queryAst.sourceMapper, valueCollector,
fieldCollector, m, unmarshalledVariables, tpe, fields, c, r)))
}
}

executionResult match {
case Success(future) future
case Failure(error) Future.failed(error)
}
}
}

def execute[Input](
queryAst: ast.Document,
userContext: Ctx,
root: Root,
operationName: Option[String] = None,
variables: Input = emptyMapVars)(implicit marshaller: ResultMarshaller, um: InputUnmarshaller[Input]): Future[marshaller.Node] = {
val violations = queryValidator.validateQuery(schema, queryAst)
Expand All @@ -38,18 +90,16 @@ case class Executor[Ctx, Root](
operation getOperation(queryAst, operationName)
unmarshalledVariables valueCollector.getVariableValues(operation.variables)
fieldCollector = new FieldCollector[Ctx, Root](schema, queryAst, unmarshalledVariables, queryAst.sourceMapper, valueCollector, exceptionHandler)
res executeOperation(
queryAst,
operationName,
variables,
um,
operation,
queryAst.sourceMapper,
valueCollector,
fieldCollector,
marshaller,
unmarshalledVariables)
} yield res
tpe getOperationRootType(operation, queryAst.sourceMapper)
fields fieldCollector.collectFields(Vector.empty, tpe, Vector(operation))
} yield reduceQuerySafe(fieldCollector, valueCollector, unmarshalledVariables, tpe, fields, userContext) match {
case fut: Future[Ctx]
fut.flatMap(executeOperation(queryAst, operationName, variables, um, operation, queryAst.sourceMapper, valueCollector,
fieldCollector, marshaller, unmarshalledVariables, tpe, fields, _, root))
case ctx: Ctx @unchecked
executeOperation(queryAst, operationName, variables, um, operation, queryAst.sourceMapper, valueCollector,
fieldCollector, marshaller, unmarshalledVariables, tpe, fields, ctx, root)
}

executionResult match {
case Success(future) future
Expand All @@ -68,72 +118,59 @@ case class Executor[Ctx, Root](
}

def executeOperation[Input](
queryAst: ast.Document,
operationName: Option[String] = None,
inputVariables: Input,
inputUnmarshaller: InputUnmarshaller[Input],
operation: ast.OperationDefinition,
sourceMapper: Option[SourceMapper],
valueCollector: ValueCollector[Ctx, _],
fieldCollector: FieldCollector[Ctx, Root],
marshaller: ResultMarshaller,
variables: Map[String, VariableValue]): Try[Future[marshaller.Node]] =
for {
tpe getOperationRootType(operation, sourceMapper)
fields fieldCollector.collectFields(Vector.empty, tpe, Vector(operation))
} yield {
def doExecute(ctx: Ctx) = {
val middlewareCtx = MiddlewareQueryContext(ctx, this, queryAst, operationName, inputVariables, inputUnmarshaller)

try {
val middlewareVal = middleware map (m m.beforeQuery(middlewareCtx) m)

val resolver = new Resolver[Ctx](
marshaller,
middlewareCtx,
schema,
valueCollector,
variables,
fieldCollector,
ctx,
exceptionHandler,
deferredResolver,
sourceMapper,
deprecationTracker,
middlewareVal,
maxQueryDepth)

val result =
operation.operationType match {
case ast.OperationType.Query resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Subscription resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Mutation resolver.resolveFieldsSeq(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
}

if (middlewareVal.nonEmpty) {
def onAfter() =
middlewareVal foreach { case (v, m) m.afterQuery(v.asInstanceOf[m.QueryVal], middlewareCtx)}

result
.map { x onAfter(); x}
.recover { case e onAfter(); throw e}
} else result
} catch {
case NonFatal(error)
Future.failed(error)
}
}
queryAst: ast.Document,
operationName: Option[String] = None,
inputVariables: Input,
inputUnmarshaller: InputUnmarshaller[Input],
operation: ast.OperationDefinition,
sourceMapper: Option[SourceMapper],
valueCollector: ValueCollector[Ctx, _],
fieldCollector: FieldCollector[Ctx, Root],
marshaller: ResultMarshaller,
variables: Map[String, VariableValue],
tpe: ObjectType[Ctx, Root],
fields: Map[String, (ast.Field, Try[Vector[ast.Field]])],
ctx: Ctx,
root: Root): Future[marshaller.Node] = {
val middlewareCtx = MiddlewareQueryContext(ctx, this, queryAst, operationName, inputVariables, inputUnmarshaller)

try {
val middlewareVal = middleware map (m m.beforeQuery(middlewareCtx) m)

val resolver = new Resolver[Ctx](
marshaller,
middlewareCtx,
schema,
valueCollector,
variables,
fieldCollector,
ctx,
exceptionHandler,
deferredResolver,
sourceMapper,
deprecationTracker,
middlewareVal,
maxQueryDepth)

val result =
operation.operationType match {
case ast.OperationType.Query resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Subscription resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Mutation resolver.resolveFieldsSeq(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
}

if (queryReducers.nonEmpty)
reduceQuery(fieldCollector, valueCollector, variables, tpe, fields, queryReducers.toVector) match {
case future: Future[Ctx]
future
.map (newCtx doExecute(newCtx))
.recover {case error: Throwable throw QueryReducingError(error, exceptionHandler)}
.flatMap (identity)
case newCtx: Ctx @unchecked doExecute(newCtx)
}
else doExecute(userContext)
if (middlewareVal.nonEmpty) {
def onAfter() =
middlewareVal foreach { case (v, m) m.afterQuery(v.asInstanceOf[m.QueryVal], middlewareCtx)}

result
.map { x onAfter(); x}
.recover { case e onAfter(); throw e}
} else result
} catch {
case NonFatal(error)
Future.failed(error)
}
}

def getOperationRootType(operation: ast.OperationDefinition, sourceMapper: Option[SourceMapper]) = operation.operationType match {
Expand All @@ -147,13 +184,32 @@ case class Executor[Ctx, Root](
Failure(OperationSelectionError("Schema is not configured for subscriptions", exceptionHandler, sourceMapper, operation.position.toList))
}

// returns either new Ctx or future of it
private def reduceQuerySafe[Val](
fieldCollector: FieldCollector[Ctx, Root],
valueCollector: ValueCollector[Ctx, _],
variables: Map[String, VariableValue],
rootTpe: ObjectType[Ctx, Root],
fields: Map[String, (ast.Field, Try[Vector[ast.Field]])],
userContext: Ctx): Any =
if (queryReducers.nonEmpty)
reduceQuery(fieldCollector, valueCollector, variables, rootTpe, fields, queryReducers.toVector, userContext) match {
case future: Future[Ctx]
future
.map (newCtx newCtx)
.recover {case error: Throwable throw QueryReducingError(error, exceptionHandler)}
case newCtx: Ctx @unchecked Future.successful(newCtx)
}
else userContext

private def reduceQuery[Val](
fieldCollector: FieldCollector[Ctx, Val],
valueCollector: ValueCollector[Ctx, _],
variables: Map[String, VariableValue],
rootTpe: ObjectType[_, _],
fields: Map[String, (ast.Field, Try[Vector[ast.Field]])],
reducers: Vector[QueryReducer[Ctx, _]]): Any = {
reducers: Vector[QueryReducer[Ctx, _]],
userContext: Ctx): Any = {
// Using mutability here locally in order to reduce footprint
import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -239,7 +295,7 @@ case class Executor[Ctx, Root](
case TryValue(value) Future.fromTry(value)
})

case (acc: Ctx@unchecked, (reducer, idx))
case (acc: Ctx @unchecked, (reducer, idx))
reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], acc) match {
case FutureValue(future) future
case Value(value) value
Expand All @@ -258,10 +314,10 @@ object Executor {
def execute[Ctx, Root, Input](
schema: Schema[Ctx, Root],
queryAst: ast.Document,
userContext: Ctx = (),
root: Root = (),
operationName: Option[String] = None,
variables: Input = emptyMapVars,
root: Root = (),
userContext: Ctx = (),
queryValidator: QueryValidator = QueryValidator.default,
deferredResolver: DeferredResolver[Ctx] = DeferredResolver.empty,
exceptionHandler: Executor.ExceptionHandler = PartialFunction.empty,
Expand All @@ -270,7 +326,40 @@ object Executor {
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
)(implicit executionContext: ExecutionContext, marshaller: ResultMarshaller, um: InputUnmarshaller[Input]): Future[marshaller.Node] =
Executor(schema, root, userContext, queryValidator, deferredResolver, exceptionHandler, deprecationTracker, middleware, maxQueryDepth, queryReducers).execute(queryAst, operationName, variables)
Executor(schema, queryValidator, deferredResolver, exceptionHandler, deprecationTracker, middleware, maxQueryDepth, queryReducers)
.execute(queryAst, userContext, root, operationName, variables)

def prepare[Ctx, Root, Input](
schema: Schema[Ctx, Root],
queryAst: ast.Document,
userContext: Ctx = (),
root: Root = (),
operationName: Option[String] = None,
variables: Input = emptyMapVars,
queryValidator: QueryValidator = QueryValidator.default,
deferredResolver: DeferredResolver[Ctx] = DeferredResolver.empty,
exceptionHandler: Executor.ExceptionHandler = PartialFunction.empty,
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
)(implicit executionContext: ExecutionContext, um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] =
Executor(schema, queryValidator, deferredResolver, exceptionHandler, deprecationTracker, middleware, maxQueryDepth, queryReducers)
.prepare(queryAst, userContext, root, operationName, variables)
}

case class HandledException(message: String, additionalFields: Map[String, ResultMarshaller#Node] = Map.empty)

class PreparedQuery[Ctx, Root, Input] private[execution] (
val queryAst: ast.Document,
val operation: ast.OperationDefinition,
val tpe: ObjectType[Ctx, Root],
val userContext: Ctx,
val root: Root,
val fields: Seq[PreparedField[Ctx, Root]],
execFn: (Ctx, Root, ResultMarshaller) Future[ResultMarshaller#Node]) {
def execute(userContext: Ctx = userContext, root: Root = root)(implicit marshaller: ResultMarshaller): Future[marshaller.Node] =
execFn(userContext, root, marshaller).asInstanceOf[Future[marshaller.Node]]
}

case class PreparedField[Ctx, Root](field: Field[Ctx, Root], args: Args)
11 changes: 9 additions & 2 deletions src/main/scala/sangria/execution/Resolver.scala
Expand Up @@ -88,7 +88,7 @@ class Resolver[Ctx](

def resolveError(e: Throwable) = {
try {
newUc map (_.onError(e))
newUc foreach (_.onError(e))
} catch {
case NonFatal(ee) ee.printStackTrace()
}
Expand Down Expand Up @@ -430,7 +430,14 @@ class Resolver[Ctx](
Result(errorReg, if (canceled) None else Some(marshaller.arrayNode(listBuilder.result())))
}

def resolveField(userCtx: Ctx, tpe: ObjectType[Ctx, _], path: Vector[String], value: Any, errors: ErrorRegistry, name: String, astFields: Vector[ast.Field]): (ErrorRegistry, Option[LeafAction[Ctx, Any]], Option[MappedCtxUpdate[Ctx, Any, Any]]) = {
def resolveField(
userCtx: Ctx,
tpe: ObjectType[Ctx, _],
path: Vector[String],
value: Any,
errors: ErrorRegistry,
name: String,
astFields: Vector[ast.Field]): (ErrorRegistry, Option[LeafAction[Ctx, Any]], Option[MappedCtxUpdate[Ctx, Any, Any]]) = {
val astField = astFields.head
val allFields = tpe.getField(schema, astField.name).asInstanceOf[Vector[Field[Ctx, Any]]]
val field = allFields.head
Expand Down

0 comments on commit 7bffbe2

Please sign in to comment.