Skip to content

Commit

Permalink
feat(coercer): handle list sequentially/concurrently coerced at decor…
Browse files Browse the repository at this point in the history
…ators level
  • Loading branch information
Maximilien-R committed Nov 11, 2020
1 parent fcdefb3 commit 2a1142a
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 6 deletions.
18 changes: 15 additions & 3 deletions tartiflette/coercers/outputs/compute.py
@@ -1,18 +1,25 @@
from functools import partial
from typing import Callable

from tartiflette.coercers.outputs.list_coercer import list_coercer
from tartiflette.coercers.outputs.list_coercer import (
list_coercer_concurrently,
list_coercer_sequentially,
)
from tartiflette.coercers.outputs.non_null_coercer import non_null_coercer

__all__ = ("get_output_coercer",)


def get_output_coercer(graphql_type: "GraphQLType") -> Callable:
def get_output_coercer(
graphql_type: "GraphQLType", concurrently: bool
) -> Callable:
"""
Computes and returns the output coercer to use for the filled in schema
type.
:param graphql_type: the schema type for which compute the coercer
:param concurrently: whether list should be coerced concurrently
:type graphql_type: GraphQLType
:type concurrently: bool
:return: the computed coercer wrap with directives if defined
:rtype: Callable
"""
Expand All @@ -22,7 +29,12 @@ def get_output_coercer(graphql_type: "GraphQLType") -> Callable:
wrapped_type = inner_type.wrapped_type
if inner_type.is_list_type:
wrapper_coercers.append(
partial(list_coercer, item_type=wrapped_type)
partial(
list_coercer_concurrently
if concurrently
else list_coercer_sequentially,
item_type=wrapped_type,
)
)
elif inner_type.is_non_null_type:
wrapper_coercers.append(non_null_coercer)
Expand Down
65 changes: 63 additions & 2 deletions tartiflette/coercers/outputs/list_coercer.py
@@ -1,15 +1,17 @@
import asyncio

from typing import Any, Callable, List

from tartiflette.coercers.common import Path
from tartiflette.coercers.outputs.null_coercer import null_coercer_wrapper
from tartiflette.resolver.factory import complete_value_catching_error
from tartiflette.utils.errors import extract_exceptions_from_results

__all__ = ("list_coercer",)
__all__ = ("list_coercer_sequentially", "list_coercer_concurrently")


@null_coercer_wrapper
async def list_coercer(
async def list_coercer_sequentially(
result: Any,
info: "ResolveInfo",
execution_context: "ExecutionContext",
Expand Down Expand Up @@ -65,3 +67,62 @@ async def list_coercer(
raise exceptions

return results


@null_coercer_wrapper
async def list_coercer_concurrently(
result: Any,
info: "ResolveInfo",
execution_context: "ExecutionContext",
field_nodes: List["FieldNode"],
path: "Path",
item_type: "GraphQLOutputType",
inner_coercer: Callable,
) -> List[Any]:
"""
Computes the value of a list.
:param result: resolved value
:param info: information related to the execution and the resolved field
:param execution_context: instance of the query execution context
:param field_nodes: AST nodes related to the resolved field
:param path: the path traveled until this resolver
:param item_type: GraphQLType of list items
:param inner_coercer: the pre-computed coercer to use on the result
:type result: Any
:type info: ResolveInfo
:type execution_context: ExecutionContext
:type field_nodes: List[FieldNode]
:type path: Path
:type item_type: GraphQLOutputType
:type inner_coercer: Callable
:return: the computed value
:rtype: List[Any]
"""
# pylint: disable=too-many-locals
if not isinstance(result, list):
raise TypeError(
"Expected Iterable, but did not find one for field "
f"{info.parent_type.name}.{info.field_name}."
)

results = await asyncio.gather(
*[
complete_value_catching_error(
item,
info,
execution_context,
field_nodes,
Path(path, index),
item_type,
inner_coercer,
)
for index, item in enumerate(result)
],
return_exceptions=True,
)

exceptions = extract_exceptions_from_results(results)
if exceptions:
raise exceptions

return results
5 changes: 5 additions & 0 deletions tartiflette/resolver/resolver.py
Expand Up @@ -38,23 +38,27 @@ def __init__(
schema_name: str = "default",
type_resolver: Optional[Callable] = None,
arguments_coercer: Optional[Callable] = None,
concurrently: Optional[bool] = None,
) -> None:
"""
:param name: name of the field to wrap
:param schema_name: name of the schema to which link the resolver
:param type_resolver: callable to use to resolve the type of an
abstract type
:param arguments_coercer: the callable to use to coerce field arguments
:param concurrently: whether or not list will be coerced concurrently
:type name: str
:type schema_name: str
:type type_resolver: Optional[Callable]
:type arguments_coercer: Optional[Callable]
:type concurrently: Optional[bool]
"""
self.name = name
self._type_resolver = type_resolver
self._implementation = None
self._schema_name = schema_name
self._arguments_coercer = arguments_coercer
self._concurrently = concurrently

def bake(self, schema: "GraphQLSchema") -> None:
"""
Expand All @@ -71,6 +75,7 @@ def bake(self, schema: "GraphQLSchema") -> None:
field = schema.get_field_by_name(self.name)
field.raw_resolver = self._implementation
field.query_arguments_coercer = self._arguments_coercer
field.query_concurrently = self._concurrently

field_wrapped_type = get_wrapped_type(
get_graphql_type(schema, field.gql_type)
Expand Down
5 changes: 5 additions & 0 deletions tartiflette/subscription/subscription.py
Expand Up @@ -40,19 +40,23 @@ def __init__(
name: str,
schema_name: str = "default",
arguments_coercer: Optional[Callable] = None,
concurrently: Optional[bool] = None,
) -> None:
"""
:param name: name of the subscription field
:param schema_name: name of the schema to which link the subscription
:param arguments_coercer: callable to use to coerce field arguments
:param concurrently: whether list should be coerced concurrently
:type name: str
:type schema_name: str
:type arguments_coercer: Optional[Callable]
:type concurrently: Optional[bool]
"""
self.name = name
self._implementation = None
self._schema_name = schema_name
self._arguments_coercer = arguments_coercer
self._concurrently = concurrently

def bake(self, schema: "GraphQLSchema") -> None:
"""
Expand Down Expand Up @@ -81,6 +85,7 @@ def bake(self, schema: "GraphQLSchema") -> None:

field.subscribe = self._implementation
field.subscription_arguments_coercer = self._arguments_coercer
field.subscription_concurrently = self._concurrently

def __call__(self, implementation: Callable) -> Callable:
"""
Expand Down
16 changes: 15 additions & 1 deletion tartiflette/types/field.py
Expand Up @@ -64,6 +64,11 @@ def __init__(
self.query_arguments_coercer: Optional[Callable] = None
self.subscription_arguments_coercer: Optional[Callable] = None

# Concurrently
self.concurrently: Optional[bool] = None
self.query_concurrently: Optional[bool] = None
self.subscription_concurrently: Optional[bool] = None

# Introspection attributes
self.isDeprecated: bool = False # pylint: disable=invalid-name
self.args: List["GraphQLArgument"] = []
Expand Down Expand Up @@ -160,6 +165,13 @@ def bake(
else:
self.arguments_coercer = schema.default_arguments_coercer

if self.subscription_concurrently is not None:
self.concurrently = self.subscription_concurrently
elif self.query_concurrently is not None:
self.concurrently = self.query_concurrently
else: # TODO: handle a default value at schema level
self.concurrently = True

# Directives
directives_definition = compute_directive_nodes(
schema, self.directives
Expand Down Expand Up @@ -192,7 +204,9 @@ def bake(
is_resolver=True,
with_default=True,
),
output_coercer=get_output_coercer(self.graphql_type),
output_coercer=get_output_coercer(
self.graphql_type, self.concurrently
),
)

for argument in self.arguments.values():
Expand Down
72 changes: 72 additions & 0 deletions tests/functional/regressions/issue278/test_issue457.py
@@ -0,0 +1,72 @@
import asyncio
import random

import pytest

from tartiflette import Resolver, create_engine

_BOOKS = [
{"id": 1, "title": "Book #1"},
{"id": 2, "title": "Book #2"},
{"id": 3, "title": "Book #3"},
{"id": 4, "title": "Book #4"},
{"id": 5, "title": "Book #5"},
{"id": 6, "title": "Book #6"},
{"id": 7, "title": "Book #7"},
{"id": 8, "title": "Book #8"},
]

_SDL = """
type Book {
id: Int!
title: String!
}
type Query {
books: [Book!]
}
"""


@pytest.mark.asyncio
async def test_issue_457_sequentially(random_schema_name):
@Resolver(
"Query.books", concurrently=False, schema_name=random_schema_name
)
async def test_query_books(parent, args, ctx, info):
return _BOOKS

books_parsing_order = []

@Resolver("Book.id", schema_name=random_schema_name)
async def test_book_id(parent, args, ctx, info):
await asyncio.sleep(random.randint(1, 10) / 10)
books_parsing_order.append(parent["id"])
return parent["id"]

engine = await create_engine(_SDL, schema_name=random_schema_name)
assert await engine.execute("{ books { id title } }") == {
"data": {"books": _BOOKS}
}
assert books_parsing_order == [book["id"] for book in _BOOKS]


@pytest.mark.asyncio
async def test_issue_457_concurrently(random_schema_name):
@Resolver("Query.books", concurrently=True, schema_name=random_schema_name)
async def test_query_books(parent, args, ctx, info):
return _BOOKS

books_parsing_order = []

@Resolver("Book.id", schema_name=random_schema_name)
async def test_book_id(parent, args, ctx, info):
await asyncio.sleep(random.randint(1, 10) / 10)
books_parsing_order.append(parent["id"])
return parent["id"]

engine = await create_engine(_SDL, schema_name=random_schema_name)
assert await engine.execute("{ books { id title } }") == {
"data": {"books": _BOOKS}
}
assert books_parsing_order != [book["id"] for book in _BOOKS]

0 comments on commit 2a1142a

Please sign in to comment.