Skip to content

Commit

Permalink
Executor.execute now returns Future with failure if error happene…
Browse files Browse the repository at this point in the history
…d before query execution. Closes #109
  • Loading branch information
OlegIlyenko committed Mar 5, 2016
1 parent 93bb579 commit f4fe3ff
Show file tree
Hide file tree
Showing 33 changed files with 292 additions and 184 deletions.
48 changes: 47 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,52 @@
## Upcoming

* `Executor.execute` now returns `Future` with failure if error happened before query execution (#109). It can be extremely helpful when you need to take some action or produce different result in case of error. Typical example is returning different HTTP status code.

**CAUTION: breaking change and action needed!** Since things like validation errors and errors in query reducers are now explicitly returned in `Future` failure and not as a successful result, you need to take some action to handle them. In order to migrate, all you need to do is following:

```scala
Executor.execute(schema, query).recover {
case error: ErrorWithResolver error.resolveError
}
```

`recover` function will make sure that all of the errors, that were previously handled internally in `Executor`, are now properly handled. **Code above will produce exactly the same result as before.** `resolveError` produces a valid GraphQL response JSON and will use custom exception handler, if you have provided one.

This new approach to error handling gives you much more flexibility. For example in most cases it makes a lot of sense to return 400 HTTP status code if query validation failed. It was not really possible to do this before. Now you able to do something like this (using playframefork in this particular example):

```scala
executor.execute(query, ...)
.map(Ok(_))
.recover {
case error: QueryAnalysisError BadRequest(error.resolveError)
case error: ErrorWithResolver InternalServerError(error.resolveError)
}
```

This code will produce status code 400 in case of any error caused by client (query validation, invalid operation name, etc.).

Errors that happened in a query reducer would be wrapped in `QueryReducingError`. Here is an example of returning custom status code in case of error in the query reducer:

```scala
val authReducer = QueryReducer.collectTags[MyContext, String] {
case Permission(name) name
} { (permissionNames, ctx)
if (ctx.isUserAuthorized(permissionNames)) ctx
else throw AuthorizationException("User is not authorized!")
}

Executor.execute(schema, queryAst, userContext = new MyContext, queryReducers = authReducer :: Nil)
.map(Ok(_))
.recover {
case QueryReducingError(error: AuthorizationException) Unauthorized(error.getMessage)
case error: QueryAnalysisError BadRequest(error.resolveError)
case error: ErrorWithResolver InternalServerError(error.resolveError)
}
```

HTTP status code would be 401 for unauthorized users.
* Detect name collisions with incompatible types during schema definition (#117)
* Introduced a type alias `Executor.ExceptionHandler` for exception handler partial function

## v0.5.2 (2016-02-28)

Expand Down Expand Up @@ -302,7 +348,7 @@
throw new IllegalArgumentException(s"Too complex query: max allowed complexity is 1000.0, but got $c")
else ()

val exceptionHandler: PartialFunction[(ResultMarshaller, Throwable), HandledException] = {
val exceptionHandler: Executor.ExceptionHandler = {
case (m, e: IllegalArgumentException) HandledException(e.getMessage)
}

Expand Down
33 changes: 26 additions & 7 deletions src/main/scala/sangria/execution/ExecutionError.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sangria.execution

import org.parboiled2.Position
import sangria.marshalling.ResultMarshaller
import sangria.parser.SourceMapper
import sangria.validation.{AstNodeLocation, Violation}

Expand All @@ -12,16 +13,34 @@ trait WithViolations extends UserFacingError{
def violations: Vector[Violation]
}

class ExecutionError(message: String, val sourceMapper: Option[SourceMapper] = None, val positions: List[Position] = Nil) extends Exception(message) with AstNodeLocation with UserFacingError {
trait ErrorWithResolver {
this: Throwable

def exceptionHandler: Executor.ExceptionHandler

def resolveError(implicit marshaller: ResultMarshaller): marshaller.Node =
new ResultResolver(marshaller, exceptionHandler).resolveError(this).asInstanceOf[marshaller.Node]
}

class ExecutionError(message: String, val exceptionHandler: Executor.ExceptionHandler, val sourceMapper: Option[SourceMapper] = None, val positions: List[Position] = Nil) extends Exception(message) with AstNodeLocation with UserFacingError with ErrorWithResolver {
override def simpleErrorMessage = super.getMessage
override def getMessage = super.getMessage + astLocation
}

case class VariableCoercionError(violations: Vector[Violation]) extends ExecutionError(
s"Error during variable coercion. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}") with WithViolations
case class VariableCoercionError(violations: Vector[Violation], eh: Executor.ExceptionHandler) extends ExecutionError(
s"Error during variable coercion. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}", eh) with WithViolations

case class AttributeCoercionError(violations: Vector[Violation], eh: Executor.ExceptionHandler) extends ExecutionError(
s"Error during attribute coercion. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}", eh) with WithViolations

trait QueryAnalysisError extends ErrorWithResolver {
this: Throwable
}

case class ValidationError(violations: Vector[Violation], eh: Executor.ExceptionHandler) extends ExecutionError(
s"Query does not pass validation. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}", eh) with WithViolations with QueryAnalysisError

case class ValidationError(violations: Vector[Violation]) extends ExecutionError(
s"Query does not pass validation. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}") with WithViolations
case class QueryReducingError(cause: Throwable, exceptionHandler: Executor.ExceptionHandler) extends Exception(s"Query reducing error: ${cause.getMessage}", cause) with QueryAnalysisError

case class AttributeCoercionError(violations: Vector[Violation]) extends ExecutionError(
s"Error during attribute coercion. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}") with WithViolations
case class OperationSelectionError(message: String, eh: Executor.ExceptionHandler, sm: Option[SourceMapper] = None, pos: List[Position] = Nil)
extends ExecutionError(message, eh, sm, pos) with QueryAnalysisError
143 changes: 79 additions & 64 deletions src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import sangria.validation.QueryValidator
import InputUnmarshaller.emptyMapVars

import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import scala.util.{Success, Failure, Try}

case class Executor[Ctx, Root](
Expand All @@ -16,7 +17,7 @@ case class Executor[Ctx, Root](
userContext: Ctx = (),
queryValidator: QueryValidator = QueryValidator.default,
deferredResolver: DeferredResolver[Ctx] = DeferredResolver.empty,
exceptionHandler: PartialFunction[(ResultMarshaller, Throwable), HandledException] = PartialFunction.empty,
exceptionHandler: Executor.ExceptionHandler = PartialFunction.empty,
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
Expand All @@ -29,14 +30,14 @@ case class Executor[Ctx, Root](
val violations = queryValidator.validateQuery(schema, queryAst)

if (violations.nonEmpty)
Future.successful(new ResultResolver(marshaller, exceptionHandler).resolveError(ValidationError(violations)).asInstanceOf[marshaller.Node])
else {
val valueCollector = new ValueCollector[Ctx, Input](schema, variables, queryAst.sourceMapper, deprecationTracker, userContext)(um)
Future.failed(ValidationError(violations, exceptionHandler))
else {
val valueCollector = new ValueCollector[Ctx, Input](schema, variables, queryAst.sourceMapper, deprecationTracker, userContext, exceptionHandler)(um)

val executionResult = for {
operation getOperation(queryAst, operationName)
unmarshalledVariables valueCollector.getVariableValues(operation.variables)
fieldCollector = new FieldCollector[Ctx, Root](schema, queryAst, unmarshalledVariables, queryAst.sourceMapper, valueCollector)
fieldCollector = new FieldCollector[Ctx, Root](schema, queryAst, unmarshalledVariables, queryAst.sourceMapper, valueCollector, exceptionHandler)
res executeOperation(
queryAst,
operationName,
Expand All @@ -52,18 +53,18 @@ case class Executor[Ctx, Root](

executionResult match {
case Success(future) future
case Failure(error) Future.successful(new ResultResolver(marshaller, exceptionHandler).resolveError(error).asInstanceOf[marshaller.Node])
case Failure(error) Future.failed(error)
}
}
}

def getOperation(document: ast.Document, operationName: Option[String]): Try[ast.OperationDefinition] =
if (document.operations.size != 1 && operationName.isEmpty)
Failure(new ExecutionError("Must provide operation name if query contains multiple operations"))
Failure(OperationSelectionError("Must provide operation name if query contains multiple operations", exceptionHandler))
else {
val operation = operationName flatMap (opName document.operations get Some(opName)) orElse document.operations.values.headOption

operation map (Success(_)) getOrElse Failure(new ExecutionError(s"Unknown operation name: ${operationName.get}"))
operation map (Success(_)) getOrElse Failure(OperationSelectionError(s"Unknown operation name: ${operationName.get}", exceptionHandler))
}

def executeOperation[Input](
Expand All @@ -84,58 +85,66 @@ case class Executor[Ctx, Root](
def doExecute(ctx: Ctx) = {
val middlewareCtx = MiddlewareQueryContext(ctx, this, queryAst, operationName, inputVariables, inputUnmarshaller)

val middlewareVal = middleware map (m m.beforeQuery(middlewareCtx) m)

val resolver = new Resolver[Ctx](
marshaller,
middlewareCtx,
schema,
valueCollector,
variables,
fieldCollector,
ctx,
exceptionHandler,
deferredResolver,
sourceMapper,
deprecationTracker,
middlewareVal,
maxQueryDepth)

val result =
operation.operationType match {
case ast.OperationType.Query resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Subscription resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Mutation resolver.resolveFieldsSeq(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
}

if (middlewareVal.nonEmpty) {
def onAfter() =
middlewareVal foreach { case (v, m) m.afterQuery(v.asInstanceOf[m.QueryVal], middlewareCtx)}

result
.map { x onAfter(); x}
.recover { case e onAfter(); throw e}
} else result
try {
val middlewareVal = middleware map (m m.beforeQuery(middlewareCtx) m)

val resolver = new Resolver[Ctx](
marshaller,
middlewareCtx,
schema,
valueCollector,
variables,
fieldCollector,
ctx,
exceptionHandler,
deferredResolver,
sourceMapper,
deprecationTracker,
middlewareVal,
maxQueryDepth)

val result =
operation.operationType match {
case ast.OperationType.Query resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Subscription resolver.resolveFieldsPar(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
case ast.OperationType.Mutation resolver.resolveFieldsSeq(tpe, root, fields).asInstanceOf[Future[marshaller.Node]]
}

if (middlewareVal.nonEmpty) {
def onAfter() =
middlewareVal foreach { case (v, m) m.afterQuery(v.asInstanceOf[m.QueryVal], middlewareCtx)}

result
.map { x onAfter(); x}
.recover { case e onAfter(); throw e}
} else result
} catch {
case NonFatal(error)
Future.failed(error)
}
}

if (queryReducers.nonEmpty)
reduceQuery(fieldCollector, valueCollector, variables, tpe, fields, queryReducers.toVector) match {
case future: Future[Ctx]
future
.map (newCtx doExecute(newCtx))
.recover {case error: Throwable Future.successful(new ResultResolver(marshaller, exceptionHandler).resolveError(error).asInstanceOf[marshaller.Node])}
.recover {case error: Throwable throw QueryReducingError(error, exceptionHandler)}
.flatMap (identity)
case newCtx: Ctx @unchecked doExecute(newCtx)
}
else doExecute(userContext)
}

def getOperationRootType(operation: ast.OperationDefinition, sourceMapper: Option[SourceMapper]) = operation.operationType match {
case ast.OperationType.Query Success(schema.query)
case ast.OperationType.Mutation schema.mutation map (Success(_)) getOrElse
Failure(new ExecutionError("Schema is not configured for mutations", sourceMapper, operation.position.toList))
case ast.OperationType.Subscription schema.subscription map (Success(_)) getOrElse
Failure(new ExecutionError("Schema is not configured for subscriptions", sourceMapper, operation.position.toList))
case ast.OperationType.Query
Success(schema.query)
case ast.OperationType.Mutation
schema.mutation map (Success(_)) getOrElse
Failure(OperationSelectionError("Schema is not configured for mutations", exceptionHandler, sourceMapper, operation.position.toList))
case ast.OperationType.Subscription
schema.subscription map (Success(_)) getOrElse
Failure(OperationSelectionError("Schema is not configured for subscriptions", exceptionHandler, sourceMapper, operation.position.toList))
}

private def reduceQuery[Val](
Expand Down Expand Up @@ -167,7 +176,7 @@ case class Executor[Ctx, Root](
val newPath = path :+ astField.outputName
val childReduced = loop(newPath, field.fieldType, fields)

for (i 0 until reducers.size) {
for (i reducers.indices) {
val reducer = reducers(i)

acc(i) = reducer.reduceField[Any](
Expand Down Expand Up @@ -205,7 +214,7 @@ case class Executor[Ctx, Root](
val path = Vector(astField.outputName)
val childReduced = loop(path, field.fieldType, astFields)

for (i 0 until reducers.size) {
for (i reducers.indices) {
val reducer = reducers(i)

acc(i) = reducer.reduceField(
Expand All @@ -220,26 +229,32 @@ case class Executor[Ctx, Root](
case (acc, _) acc
}

// Unsafe part to avoid addition boxing in order to reduce the footprint
reducers.zipWithIndex.foldLeft(userContext: Any) {
case (acc: Future[Ctx], (reducer, idx))
acc.flatMap(a reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], a) match {
case FutureValue(future) future
case Value(value) Future.successful(value)
case TryValue(value) Future.fromTry(value)
})

case (acc: Ctx @unchecked, (reducer, idx))
reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], acc) match {
case FutureValue(future) future
case Value(value) value
case TryValue(value) value.get
}
try {
// Unsafe part to avoid addition boxing in order to reduce the footprint
reducers.zipWithIndex.foldLeft(userContext: Any) {
case (acc: Future[Ctx], (reducer, idx))
acc.flatMap(a reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], a) match {
case FutureValue(future) future
case Value(value) Future.successful(value)
case TryValue(value) Future.fromTry(value)
})

case (acc: Ctx@unchecked, (reducer, idx))
reducer.reduceCtx(reduced(idx).asInstanceOf[reducer.Acc], acc) match {
case FutureValue(future) future
case Value(value) value
case TryValue(value) value.get
}
}
} catch {
case NonFatal(error) Future.failed(error)
}
}
}

object Executor {
type ExceptionHandler = PartialFunction[(ResultMarshaller, Throwable), HandledException]

def execute[Ctx, Root, Input](
schema: Schema[Ctx, Root],
queryAst: ast.Document,
Expand All @@ -249,7 +264,7 @@ object Executor {
userContext: Ctx = (),
queryValidator: QueryValidator = QueryValidator.default,
deferredResolver: DeferredResolver[Ctx] = DeferredResolver.empty,
exceptionHandler: PartialFunction[(ResultMarshaller, Throwable), HandledException] = PartialFunction.empty,
exceptionHandler: Executor.ExceptionHandler = PartialFunction.empty,
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
Expand Down
Loading

0 comments on commit f4fe3ff

Please sign in to comment.