From 9cb6aec2cb9e74863e0ec2aade3bd5d0b5828da3 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Tue, 6 Jun 2023 15:35:08 -0300 Subject: [PATCH] feat: new input_mutation field (#2580) * feat: new input_mutation field * refactor: use future annotations * fix: fix mypy issues * refactor: correct example in the release file * refactor: use field extensions instead of a StrawberryField subclass * Update docs/general/mutations.md Co-authored-by: Erik Wrede * refactor: we don't need field_class argument anymore * Apply suggestions from code review Co-authored-by: Jonathan Kim * refactor: modify arguments directly on the field * refactor: retrieve directives directly from the argument * refactor: remove input_mutation and move InputMutationExtension to a new module * Apply suggestions from code review Co-authored-by: Patrick Arminio * refactor: use capitalize_first --------- Co-authored-by: Erik Wrede Co-authored-by: Jonathan Kim Co-authored-by: Patrick Arminio --- RELEASE.md | 60 +++++++ docs/general/mutations.md | 50 ++++++ strawberry/field_extensions/__init__.py | 5 + strawberry/field_extensions/input_mutation.py | 95 +++++++++++ strawberry/schema/schema_converter.py | 17 +- tests/http/test_input_mutation.py | 149 ++++++++++++++++++ 6 files changed, 375 insertions(+), 1 deletion(-) create mode 100644 RELEASE.md create mode 100644 strawberry/field_extensions/__init__.py create mode 100644 strawberry/field_extensions/input_mutation.py create mode 100644 tests/http/test_input_mutation.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..8e11f3b7f9 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,60 @@ +Release type: minor + +This release adds a new field extension called `InputMutationExtension`, which makes +it easier to create mutations that receive a single input type called `input`, +while still being able to define the arguments of that input on the resolver itself. + +The following example: + +```python +import strawberry +from strawberry.field_extensions import InputMutationExtension + + +@strawberry.type +class Fruit: + id: strawberry.ID + name: str + weight: float + + +@strawberry.type +class Mutation: + @strawberry.mutation(extensions=[InputMutationExtension()]) + def update_fruit_weight( + self, + info: Info, + id: strawberry.ID, + weight: Annotated[ + float, + strawberry.argument(description="The fruit's new weight in grams"), + ], + ) -> Fruit: + fruit = ... # retrieve the fruit with the given ID + fruit.weight = weight + ... # maybe save the fruit in the database + return fruit +``` + +Would generate a schema like this: + +```graphql +input UpdateFruitInput { + id: ID! + + """ + The fruit's new weight in grams + """ + weight: Float! +} + +type Fruit { + id: ID! + name: String! + weight: Float! +} + +type Mutation { + updateFruitWeight(input: UpdateFruitInput!): Fruit! +} +``` diff --git a/docs/general/mutations.md b/docs/general/mutations.md index df4149838d..cb9058dc81 100644 --- a/docs/general/mutations.md +++ b/docs/general/mutations.md @@ -84,3 +84,53 @@ type Mutation { Mutations with void-result go against [GQL best practices](https://graphql-rules.com/rules/mutation-payload) + +### The Input Mutation Extension + +It is usually useful to use a pattern of defining a mutation that receives a single +[input type](./input-types) argument called `input`. + +Strawberry provides a helper to create a mutation that automatically +creates an input type for you, whose attributes are the same as the args in the resolver. + +For example, suppose we want the mutation defined in the section above to be an +input mutation. We can add the `InputMutationExtension` to the field like this: + +```python +from strawberry.field_extensions import InputMutationExtension + + +@strawberry.type +class Mutation: + @strawberry.mutation(extensions=[InputMutationExtension()]) + def update_fruit_weight( + self, + info: Info, + id: strawberry.ID, + weight: Annotated[ + float, + strawberry.argument(description="The fruit's new weight in grams"), + ], + ) -> Fruit: + fruit = ... # retrieve the fruit with the given ID + fruit.weight = weight + ... # maybe save the fruit in the database + return fruit +``` + +That would generate a schema like this: + +```graphql +input UpdateFruitWeightInput { + id: ID! + + """ + The fruit's new weight in grams + """ + weight: Float! +} + +type Mutation { + updateFruitWeight(input: UpdateFruitWeightInput!): Fruit! +} +``` diff --git a/strawberry/field_extensions/__init__.py b/strawberry/field_extensions/__init__.py new file mode 100644 index 0000000000..80c5fa97cc --- /dev/null +++ b/strawberry/field_extensions/__init__.py @@ -0,0 +1,5 @@ +from .input_mutation import InputMutationExtension + +__all__ = [ + "InputMutationExtension", +] diff --git a/strawberry/field_extensions/input_mutation.py b/strawberry/field_extensions/input_mutation.py new file mode 100644 index 0000000000..9bb1b4fb70 --- /dev/null +++ b/strawberry/field_extensions/input_mutation.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + TypeVar, +) + +import strawberry +from strawberry.annotation import StrawberryAnnotation +from strawberry.arguments import StrawberryArgument +from strawberry.extensions.field_extension import ( + AsyncExtensionResolver, + FieldExtension, + SyncExtensionResolver, +) +from strawberry.field import StrawberryField +from strawberry.utils.str_converters import capitalize_first, to_camel_case + +if TYPE_CHECKING: + from strawberry.types.info import Info + +_T = TypeVar("_T") + + +class InputMutationExtension(FieldExtension): + def apply(self, field: StrawberryField) -> None: + resolver = field.base_resolver + assert resolver + + name = field.graphql_name or to_camel_case(resolver.name) + type_dict: Dict[str, Any] = { + "__doc__": f"Input data for `{name}` mutation", + "__annotations__": {}, + } + annotations = resolver.wrapped_func.__annotations__ + + for arg in resolver.arguments: + arg_field = StrawberryField( + python_name=arg.python_name, + graphql_name=arg.graphql_name, + description=arg.description, + default=arg.default, + type_annotation=arg.type_annotation, + directives=tuple(arg.directives), + ) + type_dict[arg_field.python_name] = arg_field + type_dict["__annotations__"][arg_field.python_name] = annotations[ + arg.python_name + ] + + caps_name = capitalize_first(name) + new_type = strawberry.input(type(f"{caps_name}Input", (), type_dict)) + field.arguments = [ + StrawberryArgument( + python_name="input", + graphql_name=None, + type_annotation=StrawberryAnnotation( + new_type, + namespace=resolver._namespace, + ), + description=type_dict["__doc__"], + ) + ] + + def resolve( + self, + next_: SyncExtensionResolver, + source: Any, + info: Info, + **kwargs: Any, + ) -> Any: + input_args = kwargs.pop("input") + return next_( + source, + info, + **kwargs, + **vars(input_args), + ) + + async def resolve_async( + self, + next_: AsyncExtensionResolver, + source: Any, + info: Info, + **kwargs: Any, + ) -> Any: + input_args = kwargs.pop("input") + return await next_( + source, + info, + **kwargs, + **vars(input_args), + ) diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 7554c9a3d4..87a646a778 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -471,9 +471,24 @@ def _get_arguments( info: Info, kwargs: Any, ) -> Tuple[List[Any], Dict[str, Any]]: + # FIXME: An extension might have changed the resolver arguments, + # but we need them here since we are calling it. + # This is a bit of a hack, but it's the easiest way to get the arguments + # This happens in mutation.InputMutationExtension + field_arguments = field.arguments[:] + if field.base_resolver: + existing = {arg.python_name for arg in field_arguments} + field_arguments.extend( + [ + arg + for arg in field.base_resolver.arguments + if arg.python_name not in existing + ] + ) + kwargs = convert_arguments( kwargs, - field.arguments, + field_arguments, scalar_registry=self.scalar_registry, config=self.config, ) diff --git a/tests/http/test_input_mutation.py b/tests/http/test_input_mutation.py new file mode 100644 index 0000000000..d62cdcf963 --- /dev/null +++ b/tests/http/test_input_mutation.py @@ -0,0 +1,149 @@ +import textwrap +from typing_extensions import Annotated + +import strawberry +from strawberry.field_extensions import InputMutationExtension +from strawberry.schema_directive import Location, schema_directive +from strawberry.types import Info + + +@schema_directive( + locations=[Location.FIELD_DEFINITION], + name="some_directive", +) +class SomeDirective: + some: str + directive: str + + +@strawberry.type +class Fruit: + name: str + color: str + + +@strawberry.type +class Query: + @strawberry.mutation(extensions=[InputMutationExtension()]) + def create_fruit( + self, + info: Info, + name: str, + color: Annotated[ + str, + strawberry.argument( + description="The color of the fruit", + directives=[SomeDirective(some="foo", directive="bar")], + ), + ], + ) -> Fruit: + return Fruit( + name=name, + color=color, + ) + + @strawberry.mutation(extensions=[InputMutationExtension()]) + async def create_fruit_async( + self, + info: Info, + name: str, + color: Annotated[str, object()], + ) -> Fruit: + return Fruit( + name=name, + color=color, + ) + + +schema = strawberry.Schema(query=Query) + + +def test_schema(): + expected = ''' + directive @some_directive(some: String!, directive: String!) on FIELD_DEFINITION + + input CreateFruitAsyncInput { + name: String! + color: String! + } + + input CreateFruitInput { + name: String! + + """The color of the fruit""" + color: String! @some_directive(some: "foo", directive: "bar") + } + + type Fruit { + name: String! + color: String! + } + + type Query { + createFruit( + """Input data for `createFruit` mutation""" + input: CreateFruitInput! + ): Fruit! + createFruitAsync( + """Input data for `createFruitAsync` mutation""" + input: CreateFruitAsyncInput! + ): Fruit! + } + ''' + assert str(schema).strip() == textwrap.dedent(expected).strip() + + +def test_input_mutation(): + result = schema.execute_sync( + """ + query TestQuery ($input: CreateFruitInput!) { + createFruit (input: $input) { + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "input": { + "name": "Dragonfruit", + "color": "red", + } + }, + ) + assert result.errors is None + assert result.data == { + "createFruit": { + "name": "Dragonfruit", + "color": "red", + }, + } + + +async def test_input_mutation_async(): + result = await schema.execute( + """ + query TestQuery ($input: CreateFruitAsyncInput!) { + createFruitAsync (input: $input) { + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "input": { + "name": "Dragonfruit", + "color": "red", + } + }, + ) + assert result.errors is None + assert result.data == { + "createFruitAsync": { + "name": "Dragonfruit", + "color": "red", + }, + }