Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Strawberry config to a typed dict #3365

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: major

Make schema config parameter a TypedDict
4 changes: 1 addition & 3 deletions docs/types/schema-configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ example below:
```python
import strawberry

from strawberry.schema.config import StrawberryConfig


@strawberry.type
class Query:
example_field: str


schema = strawberry.Schema(query=Query, config=StrawberryConfig(auto_camel_case=False))
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})
```

In this case we are disabling the auto camel casing feature, so your output schema
Expand Down
20 changes: 20 additions & 0 deletions strawberry/codemods/schema_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

import libcst as cst
import libcst.matchers as m
from libcst._nodes.expression import BaseExpression # noqa: TCH002
from libcst.codemod import VisitorBasedCodemodCommand
from libcst.codemod.visitors import RemoveImportsVisitor


class ConvertStrawberryConfigToDict(VisitorBasedCodemodCommand):
DESCRIPTION: str = "Converts StrawberryConfig(...) to dict"

@m.leave(m.Call(func=m.Name("StrawberryConfig")))
def leave_strawberry_config_call(
self, original_node: cst.Call, _: cst.Call
) -> BaseExpression:
RemoveImportsVisitor.remove_unused_import(
self.context, "strawberry.schema.config", "StrawberryConfig"
)
return cst.Call(func=cst.Name(value="dict"), args=original_node.args)
4 changes: 2 additions & 2 deletions strawberry/federation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from strawberry.enum import EnumDefinition
from strawberry.extensions import SchemaExtension
from strawberry.federation.schema_directives import ComposeDirective
from strawberry.schema.config import StrawberryConfig
from strawberry.schema.config import StrawberryConfigDict
from strawberry.schema.types.concrete_type import TypeMap
from strawberry.schema_directive import StrawberrySchemaDirective
from strawberry.union import StrawberryUnion
Expand All @@ -60,7 +60,7 @@ def __init__(
types: Iterable[Type] = (),
extensions: Iterable[Union[Type["SchemaExtension"], "SchemaExtension"]] = (),
execution_context_class: Optional[Type["GraphQLExecutionContext"]] = None,
config: Optional["StrawberryConfig"] = None,
config: Optional["StrawberryConfigDict"] = None,
scalar_overrides: Optional[
Dict[object, Union[Type, "ScalarWrapper", "ScalarDefinition"]]
] = None,
Expand Down
15 changes: 14 additions & 1 deletion strawberry/schema/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass, field
from typing import Any, Callable
from typing import Any, Callable, TypedDict

from .name_converter import NameConverter


class StrawberryConfigDict(TypedDict, total=False):
auto_camel_case: bool
name_converter: NameConverter
default_resolver: Callable[[Any, str], object]
relay_max_results: int


@dataclass
class StrawberryConfig:
auto_camel_case: InitVar[bool] = None # pyright: reportGeneralTypeIssues=false
Expand All @@ -19,3 +26,9 @@ def __post_init__(
):
if auto_camel_case is not None:
self.name_converter.auto_camel_case = auto_camel_case

@classmethod
def from_dict(cls, data: StrawberryConfigDict | None) -> StrawberryConfig:
if not data:
return cls()
return cls(**data)
5 changes: 3 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from strawberry.enum import EnumDefinition
from strawberry.extensions import SchemaExtension
from strawberry.field import StrawberryField
from strawberry.schema.config import StrawberryConfigDict
from strawberry.type import StrawberryType
from strawberry.types import ExecutionResult
from strawberry.union import StrawberryUnion
Expand All @@ -77,7 +78,7 @@ def __init__(
types: Iterable[Union[Type, StrawberryType]] = (),
extensions: Iterable[Union[Type[SchemaExtension], SchemaExtension]] = (),
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
config: Optional[StrawberryConfig] = None,
config: Optional[StrawberryConfigDict] = None,
scalar_overrides: Optional[
Dict[object, Union[Type, ScalarWrapper, ScalarDefinition]]
] = None,
Expand All @@ -89,7 +90,7 @@ def __init__(

self.extensions = extensions
self.execution_context_class = execution_context_class
self.config = config or StrawberryConfig()
self.config = StrawberryConfig.from_dict(config)

SCALAR_OVERRIDES_DICT_TYPE = Dict[
object, Union["ScalarWrapper", "ScalarDefinition"]
Expand Down
123 changes: 123 additions & 0 deletions tests/codemods/test_schema_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from libcst.codemod import CodemodTest

from strawberry.codemods.schema_config import ConvertStrawberryConfigToDict


class TestConvertConstantCommand(CodemodTest):
TRANSFORM = ConvertStrawberryConfigToDict

def test_update_config(self) -> None:
before = """
import strawberry
from strawberry.schema.config import StrawberryConfig

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
"""

after = """
import strawberry

schema = strawberry.Schema(
query=Query, config=dict(auto_camel_case=False)
)
"""

self.assertCodemod(
before,
after,
)

def test_update_config_default(self) -> None:
before = """
import strawberry
from strawberry.schema.config import StrawberryConfig

schema = strawberry.Schema(
query=Query, config=StrawberryConfig()
)
"""

after = """
import strawberry

schema = strawberry.Schema(
query=Query, config=dict()
)
"""

self.assertCodemod(
before,
after,
)

def test_update_config_with_two_args(self) -> None:
before = """
import strawberry
from strawberry.schema.config import StrawberryConfig

schema = strawberry.Schema(
query=Query,
config=StrawberryConfig(auto_camel_case=True, default_resolver=getitem)
)
"""

after = """
import strawberry

schema = strawberry.Schema(
query=Query,
config=dict(auto_camel_case=True, default_resolver=getitem)
)
"""

self.assertCodemod(
before,
after,
)

def test_update_config_declared_outside(self) -> None:
before = """
import strawberry
from strawberry.schema.config import StrawberryConfig

config = StrawberryConfig(auto_camel_case=True, default_resolver=getitem)

schema = strawberry.Schema(
query=Query,
config=config
)
"""

after = """
import strawberry

config = dict(auto_camel_case=True, default_resolver=getitem)

schema = strawberry.Schema(
query=Query,
config=config
)
"""

self.assertCodemod(
before,
after,
)

def test_update_config_declared_outside_not_used_in_module(self) -> None:
before = """
from strawberry.schema.config import StrawberryConfig

config = StrawberryConfig(auto_camel_case=True, default_resolver=getitem)
"""

after = """
config = dict(auto_camel_case=True, default_resolver=getitem)
"""

self.assertCodemod(
before,
after,
)
25 changes: 6 additions & 19 deletions tests/schema/test_camel_casing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import strawberry
from strawberry.schema.config import StrawberryConfig


def test_camel_case_is_on_by_default():
Expand Down Expand Up @@ -30,9 +29,7 @@ def test_can_set_camel_casing():
class Query:
example_field: str = "Example"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=True)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": True})

query = """
{
Expand All @@ -55,9 +52,7 @@ def test_can_set_camel_casing_to_false():
class Query:
example_field: str = "Example"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand All @@ -80,9 +75,7 @@ def test_can_set_camel_casing_to_false_uses_name():
class Query:
example_field: str = strawberry.field(name="exampleField")

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand All @@ -107,9 +100,7 @@ class Query:
def example_field(self) -> str:
return "ABC"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand Down Expand Up @@ -162,9 +153,7 @@ class Query:
def example_field(self, example_input: str) -> str:
return example_input

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand Down Expand Up @@ -192,9 +181,7 @@ class Query:
def example_field(self, example_input: str) -> str:
return example_input

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand Down
3 changes: 1 addition & 2 deletions tests/schema/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import strawberry
from strawberry.directive import DirectiveLocation, DirectiveValue
from strawberry.extensions import SchemaExtension
from strawberry.schema.config import StrawberryConfig
from strawberry.type import get_object_definition
from strawberry.types.info import Info
from strawberry.utils.await_maybe import await_maybe
Expand Down Expand Up @@ -224,7 +223,7 @@ def replace(value: str, old: str, new: str):
schema = strawberry.Schema(
query=Query,
directives=[turn_uppercase, replace],
config=StrawberryConfig(auto_camel_case=False),
config={"auto_camel_case": False},
)

query = """query People($identified: Boolean!){
Expand Down
3 changes: 1 addition & 2 deletions tests/schema/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import strawberry
from strawberry.field import StrawberryField
from strawberry.printer import print_schema
from strawberry.schema.config import StrawberryConfig


def test_custom_field():
Expand Down Expand Up @@ -62,7 +61,7 @@ def user(self) -> User:

schema = strawberry.Schema(
query=Query,
config=StrawberryConfig(default_resolver=getitem),
config={"default_resolver": getitem},
)

query = "{ user { name } }"
Expand Down
3 changes: 1 addition & 2 deletions tests/schema/test_name_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from strawberry.directive import StrawberryDirective
from strawberry.enum import EnumDefinition, EnumValue
from strawberry.field import StrawberryField
from strawberry.schema.config import StrawberryConfig
from strawberry.schema.name_converter import NameConverter
from strawberry.schema_directive import Location, StrawberrySchemaDirective
from strawberry.type import StrawberryType
Expand Down Expand Up @@ -125,7 +124,7 @@ def print(self, enum: MyEnum) -> str:
schema = strawberry.Schema(
query=Query,
types=[MyScalar, Node],
config=StrawberryConfig(name_converter=AppendsNameConverter("X")),
config={"name_converter": AppendsNameConverter("X")},
)


Expand Down
Loading