diff --git a/src/main/scala/sangria/execution/Resolver.scala b/src/main/scala/sangria/execution/Resolver.scala index de6a7e8b..489f4588 100644 --- a/src/main/scala/sangria/execution/Resolver.scala +++ b/src/main/scala/sangria/execution/Resolver.scala @@ -43,8 +43,16 @@ class Resolver[Ctx]( def resolveFieldsPar(tpe: ObjectType[Ctx, _], value: Any, fields: CollectedFields)(scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = { val actions = collectActionsPar(ExecutionPath.empty, tpe, value, fields, ErrorRegistry.empty, userContext) + val resolvedActionPar = resolveActionsPar(ExecutionPath.empty, tpe, actions, userContext, fields.namesOrdered) - handleScheme(processFinalResolve(resolveActionsPar(ExecutionPath.empty, tpe, actions, userContext, fields.namesOrdered)) map (_ → userContext), scheme) + val newCtx = resolvedActionPar match { + case Result(_, _, Some(ctx)) => ctx + case _ => userContext + } + + val finalResolve = processFinalResolve(resolvedActionPar) map (_ -> newCtx) + + handleScheme(finalResolve, scheme) } def resolveFieldsSeq(tpe: ObjectType[Ctx, _], value: Any, fields: CollectedFields)(scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = { @@ -727,7 +735,7 @@ class Resolver[Ctx]( val simpleRes = resolvedValues.collect {case (af, r: Result) ⇒ af → r} val resSoFar = simpleRes.foldLeft(Result(errors, Some(marshaller.emptyMapNode(fieldsNamesOrdered)))) { - case (res, (astField, other)) ⇒ res addToMap (other, astField.outputName, isOptional(tpe, astField.name), path.add(astField, tpe), astField.location, res.errors) + case (res, (astField, other)) ⇒ res addToMap (other, astField.outputName, isOptional(tpe, astField.name), path.add(astField, tpe), astField.location, res.errors, other.userContext) } val complexRes = resolvedValues.collect{case (af, r: DeferredResult) ⇒ af → r} @@ -737,7 +745,7 @@ class Resolver[Ctx]( val allDeferred = complexRes.flatMap(_._2.deferred) val finalValue = Future.sequence(complexRes.map {case (astField, DeferredResult(_, future)) ⇒ future map (astField → _)}) map { results ⇒ results.foldLeft(resSoFar) { - case (res, (astField, other)) ⇒ res addToMap (other, astField.outputName, isOptional(tpe, astField.name), path.add(astField, tpe), astField.location, res.errors) + case (res, (astField, other)) ⇒ res addToMap (other, astField.outputName, isOptional(tpe, astField.name), path.add(astField, tpe), astField.location, res.errors, other.userContext) }.buildValue } @@ -868,7 +876,7 @@ class Resolver[Ctx]( Some(marshalScalarValue(coercedWithMiddleware, marshaller, scalar.name, scalar.scalarInfo)) } - }) + },Some(userCtx)) } catch { case NonFatal(e) ⇒ Result(ErrorRegistry(path, e), None) } @@ -886,7 +894,7 @@ class Resolver[Ctx]( None else Some(marshalEnumValue(coerced, marshaller, enum.name)) - }) + }, Some(userCtx)) } catch { case NonFatal(e) ⇒ Result(ErrorRegistry(path, e), None) } @@ -1113,6 +1121,18 @@ class Resolver[Ctx]( res.nextCtx, if (mAfter.nonEmpty) doAfterMiddlewareWithMap(res.mapFn) else res.mapFn, if (mError.nonEmpty) doErrorMiddleware else identity))) + + case res: MappedActionAndUpdateCtx[Ctx, Any @unchecked, Any @unchecked] => { + StandardFieldResolution( + errors, + res.action, + Some(MappedCtxUpdate( + res.nextCtx, + if (mAfter.nonEmpty) doAfterMiddlewareWithMap(res.mapFn) else res.mapFn, + if (mError.nonEmpty) doErrorMiddleware else identity) + )) + } + } res match { @@ -1195,7 +1215,7 @@ class Resolver[Ctx]( case class Defer(promise: Promise[(ChildDeferredContext, Any, Vector[Throwable])], deferred: Deferred[Any], complexity: Double, field: Field[_, _], astFields: Vector[ast.Field], args: Args) extends DeferredWithInfo case class Result(errors: ErrorRegistry, value: Option[Any /* Either marshaller.Node or marshaller.MapBuilder */], userContext: Option[Ctx] = None) extends Resolve { - def addToMap(other: Result, key: String, optional: Boolean, path: ExecutionPath, position: Option[AstLocation], updatedErrors: ErrorRegistry) = + def addToMap(other: Result, key: String, optional: Boolean, path: ExecutionPath, position: Option[AstLocation], updatedErrors: ErrorRegistry, updatedContext: Option[Ctx] = None) = copy( errors = if (!optional && other.value.isEmpty && other.errors.isEmpty) @@ -1206,7 +1226,9 @@ class Resolver[Ctx]( if (optional && other.value.isEmpty) value map (v ⇒ marshaller.addMapNodeElem(v.asInstanceOf[marshaller.MapBuilder], key, marshaller.nullNode, optional = false)) else - for {myVal ← value; otherVal ← other.value} yield marshaller.addMapNodeElem(myVal.asInstanceOf[marshaller.MapBuilder], key, otherVal.asInstanceOf[marshaller.Node], optional = false)) + for {myVal ← value; otherVal ← other.value} yield marshaller.addMapNodeElem(myVal.asInstanceOf[marshaller.MapBuilder], key, otherVal.asInstanceOf[marshaller.Node], optional = false), + userContext = if(updatedContext.isDefined) updatedContext else userContext + ) def nodeValue = value.asInstanceOf[Option[marshaller.Node]] def builderValue = value.asInstanceOf[Option[marshaller.MapBuilder]] @@ -1303,4 +1325,4 @@ trait DeferredWithInfo { def field: Field[_, _] def astFields: Vector[ast.Field] def args: Args -} \ No newline at end of file +} diff --git a/src/main/scala/sangria/schema/Context.scala b/src/main/scala/sangria/schema/Context.scala index 21d0ea94..a5ea1784 100644 --- a/src/main/scala/sangria/schema/Context.scala +++ b/src/main/scala/sangria/schema/Context.scala @@ -123,13 +123,26 @@ class UpdateCtx[Ctx, Val](val action: LeafAction[Ctx, Val], val nextCtx: Val ⇒ class MappedUpdateCtx[Ctx, Val, NewVal](val action: LeafAction[Ctx, Val], val nextCtx: Val ⇒ Ctx, val mapFn: Val ⇒ NewVal) extends Action[Ctx, NewVal] { override def map[NewNewVal](fn: NewVal ⇒ NewNewVal)(implicit ec: ExecutionContext): MappedUpdateCtx[Ctx, Val, NewNewVal] = - new MappedUpdateCtx[Ctx, Val, NewNewVal](action, nextCtx, v ⇒ fn(mapFn(v))) + new MappedUpdateCtx[Ctx, Val, NewNewVal](action, nextCtx, (v: Val) ⇒ fn(mapFn(v))) } object UpdateCtx { def apply[Ctx, Val](action: LeafAction[Ctx, Val])(newCtx: Val ⇒ Ctx): UpdateCtx[Ctx, Val] = new UpdateCtx(action, newCtx) } +class MappedActionAndUpdateCtx[Ctx, Val, NewVal](val action: LeafAction[Ctx, Val], val nextCtx: Val ⇒ Ctx, val mapFn: Val ⇒ NewVal) extends Action[Ctx, NewVal] { + override def map[NewNewVal](fn: NewVal ⇒ NewNewVal)(implicit ec: ExecutionContext): MappedActionAndUpdateCtx[Ctx, Val, NewNewVal] = + new MappedActionAndUpdateCtx[Ctx, Val, NewNewVal](action, nextCtx, (v: Val) ⇒ fn(mapFn(v))) +} + +object ActionAndUpdateCtx { + def apply[Ctx, Val](action: LeafAction[Ctx, Val])(newCtx: Val ⇒ Ctx): MappedActionAndUpdateCtx[Ctx, Val, Val] = + new MappedActionAndUpdateCtx(action, newCtx, identity) + + def apply[Ctx, Val, NewVal](action: LeafAction[Ctx, Val], convert: Val => NewVal)(newCtx: Val => Ctx): MappedActionAndUpdateCtx[Ctx, Val, NewVal] = + new MappedActionAndUpdateCtx(action, newCtx, convert) +} + private[sangria] case class SubscriptionValue[Ctx, Val, S[_]](source: Val, stream: SubscriptionStream[S]) extends LeafAction[Ctx, Val] { override def map[NewVal](fn: Val ⇒ NewVal)(implicit ec: ExecutionContext): SubscriptionValue[Ctx, NewVal, S] = throw new IllegalStateException("`map` is not supported subscription actions. Action is only intended for internal use.") diff --git a/src/test/scala/sangria/execution/ExecutorSpec.scala b/src/test/scala/sangria/execution/ExecutorSpec.scala index 02bb8206..1acfde6d 100644 --- a/src/test/scala/sangria/execution/ExecutorSpec.scala +++ b/src/test/scala/sangria/execution/ExecutorSpec.scala @@ -1067,5 +1067,98 @@ class ExecutorSpec extends WordSpec with Matchers with FutureResultSupport { "locations" → Vector(Map("line" → 1, "column" → (5 + offset))))))) } } + + "support for ActionAndUpdateCtx in queries" in { + import ExecutionScheme.Extended + + case class MyCtx(info: String) + + case class ComplexModel(name: String, surname: String) + case class BasicModel(name: String) + case class WrappedModel[T](value: T, meta: String) + + val BasicModelType = ObjectType[Unit, BasicModel]( + "BasicModel", + fields[Unit, BasicModel]( + Field("name", StringType, resolve = _.value.name) + ) + ) + + val ctx = MyCtx("EMPTY") + + def wrapped: WrappedModel[ComplexModel] = WrappedModel(ComplexModel("John", "Doe"), "META INFO") + + val QueryType = ObjectType("Query", fields[MyCtx, Unit]( + Field("basic", BasicModelType, + resolve = ctx => ActionAndUpdateCtx[MyCtx, WrappedModel[ComplexModel], BasicModel](wrapped, cm => BasicModel(s"${cm.value.name} ${cm.value.surname}")){ + wrapped => + ctx.ctx.copy(info = wrapped.meta) + }))) + + val schema = Schema(QueryType) + + val exceptionHandler = ExceptionHandler { + case (m, e: IllegalStateException) ⇒ HandledException(e.getMessage) + } + + val query = + graphql""" + query { + q1: basic { + name + } + } + """ + + val result = Executor.execute(schema, query, ctx, + exceptionHandler = exceptionHandler).await + + result.result.asInstanceOf[Map[String, Any]]("data") should be ( + Map("q1" → Map("name" -> "John Doe"))) + + result.ctx.info shouldEqual "META INFO" + } + + "support for ActionAndUpdateCtx in mutations" in { + import ExecutionScheme.Extended + + case class MyCtx(acc: String) + + val ctx = MyCtx("") + + case class BasicModel(name: String) + + val QueryType = ObjectType("Query", fields[MyCtx, Unit]( + Field("hello", StringType, resolve = _ ⇒ "world"))) + + val MutationType = ObjectType("Mutation", fields[MyCtx, Unit]( + Field("addModel", StringType, + arguments = Argument("name", StringType) :: Nil, + resolve = c => ActionAndUpdateCtx[MyCtx, BasicModel, String](Future(BasicModel(c.ctx.acc + c.arg[String]("name"))),_.name ){ + bm => + c.ctx.copy(bm.name+" ") + } + ))) + + val schema = Schema(QueryType, Some(MutationType)) + + val query = + graphql""" + mutation { + a1: addModel(name: "One") + a2: addModel(name: "Two") + a3: addModel(name: "Three") + } + """ + + val result = Executor.execute(schema, query, ctx).await + + result.result.asInstanceOf[Map[String, Any]]("data") should be ( + Map("a1" → "One", "a2" → "One Two", "a3" → "One Two Three")) + + result.ctx.acc shouldEqual "One Two Three " + } + + } }