Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Strawberry config to a typed dict #3365

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: major

Make schema config parameter a TypedDict
4 changes: 1 addition & 3 deletions docs/types/schema-configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ example below:
```python
import strawberry

from strawberry.schema.config import StrawberryConfig


@strawberry.type
class Query:
example_field: str


schema = strawberry.Schema(query=Query, config=StrawberryConfig(auto_camel_case=False))
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})
```

In this case we are disabling the auto camel casing feature, so your output schema
Expand Down
4 changes: 2 additions & 2 deletions strawberry/federation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from strawberry.enum import EnumDefinition
from strawberry.extensions import SchemaExtension
from strawberry.federation.schema_directives import ComposeDirective
from strawberry.schema.config import StrawberryConfig
from strawberry.schema.config import StrawberryConfigDict
from strawberry.schema.types.concrete_type import TypeMap
from strawberry.schema_directive import StrawberrySchemaDirective
from strawberry.union import StrawberryUnion
Expand All @@ -60,7 +60,7 @@ def __init__(
types: Iterable[Type] = (),
extensions: Iterable[Union[Type["SchemaExtension"], "SchemaExtension"]] = (),
execution_context_class: Optional[Type["GraphQLExecutionContext"]] = None,
config: Optional["StrawberryConfig"] = None,
config: Optional["StrawberryConfigDict"] = None,
scalar_overrides: Optional[
Dict[object, Union[Type, "ScalarWrapper", "ScalarDefinition"]]
] = None,
Expand Down
15 changes: 14 additions & 1 deletion strawberry/schema/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass, field
from typing import Any, Callable
from typing import Any, Callable, TypedDict

from .name_converter import NameConverter


class StrawberryConfigDict(TypedDict, total=False):
auto_camel_case: bool
name_converter: NameConverter
default_resolver: Callable[[Any, str], object]
relay_max_results: int


@dataclass
class StrawberryConfig:
auto_camel_case: InitVar[bool] = None # pyright: reportGeneralTypeIssues=false
Expand All @@ -19,3 +26,9 @@ def __post_init__(
):
if auto_camel_case is not None:
self.name_converter.auto_camel_case = auto_camel_case

@classmethod
def from_dict(cls, data: StrawberryConfigDict | None) -> StrawberryConfig:
if not data:
return cls()
return cls(**data)
5 changes: 3 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from strawberry.enum import EnumDefinition
from strawberry.extensions import SchemaExtension
from strawberry.field import StrawberryField
from strawberry.schema.config import StrawberryConfigDict
from strawberry.type import StrawberryType
from strawberry.types import ExecutionResult
from strawberry.union import StrawberryUnion
Expand All @@ -77,7 +78,7 @@ def __init__(
types: Iterable[Union[Type, StrawberryType]] = (),
extensions: Iterable[Union[Type[SchemaExtension], SchemaExtension]] = (),
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
config: Optional[StrawberryConfig] = None,
config: Optional[StrawberryConfigDict] = None,
scalar_overrides: Optional[
Dict[object, Union[Type, ScalarWrapper, ScalarDefinition]]
] = None,
Expand All @@ -89,7 +90,7 @@ def __init__(

self.extensions = extensions
self.execution_context_class = execution_context_class
self.config = config or StrawberryConfig()
self.config = StrawberryConfig.from_dict(config)

SCALAR_OVERRIDES_DICT_TYPE = Dict[
object, Union["ScalarWrapper", "ScalarDefinition"]
Expand Down
25 changes: 6 additions & 19 deletions tests/schema/test_camel_casing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import strawberry
from strawberry.schema.config import StrawberryConfig


def test_camel_case_is_on_by_default():
Expand Down Expand Up @@ -30,9 +29,7 @@ def test_can_set_camel_casing():
class Query:
example_field: str = "Example"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=True)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": True})

query = """
{
Expand All @@ -55,9 +52,7 @@ def test_can_set_camel_casing_to_false():
class Query:
example_field: str = "Example"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand All @@ -80,9 +75,7 @@ def test_can_set_camel_casing_to_false_uses_name():
class Query:
example_field: str = strawberry.field(name="exampleField")

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand All @@ -107,9 +100,7 @@ class Query:
def example_field(self) -> str:
return "ABC"

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand Down Expand Up @@ -162,9 +153,7 @@ class Query:
def example_field(self, example_input: str) -> str:
return example_input

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand Down Expand Up @@ -192,9 +181,7 @@ class Query:
def example_field(self, example_input: str) -> str:
return example_input

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(auto_camel_case=False)
)
schema = strawberry.Schema(query=Query, config={"auto_camel_case": False})

query = """
{
Expand Down
3 changes: 1 addition & 2 deletions tests/schema/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import strawberry
from strawberry.directive import DirectiveLocation, DirectiveValue
from strawberry.extensions import SchemaExtension
from strawberry.schema.config import StrawberryConfig
from strawberry.type import get_object_definition
from strawberry.types.info import Info
from strawberry.utils.await_maybe import await_maybe
Expand Down Expand Up @@ -224,7 +223,7 @@ def replace(value: str, old: str, new: str):
schema = strawberry.Schema(
query=Query,
directives=[turn_uppercase, replace],
config=StrawberryConfig(auto_camel_case=False),
config={"auto_camel_case": False},
)

query = """query People($identified: Boolean!){
Expand Down
3 changes: 1 addition & 2 deletions tests/schema/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import strawberry
from strawberry.field import StrawberryField
from strawberry.printer import print_schema
from strawberry.schema.config import StrawberryConfig


def test_custom_field():
Expand Down Expand Up @@ -62,7 +61,7 @@ def user(self) -> User:

schema = strawberry.Schema(
query=Query,
config=StrawberryConfig(default_resolver=getitem),
config={"default_resolver": getitem},
)

query = "{ user { name } }"
Expand Down
3 changes: 1 addition & 2 deletions tests/schema/test_name_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from strawberry.directive import StrawberryDirective
from strawberry.enum import EnumDefinition, EnumValue
from strawberry.field import StrawberryField
from strawberry.schema.config import StrawberryConfig
from strawberry.schema.name_converter import NameConverter
from strawberry.schema_directive import Location, StrawberrySchemaDirective
from strawberry.type import StrawberryType
Expand Down Expand Up @@ -125,7 +124,7 @@ def print(self, enum: MyEnum) -> str:
schema = strawberry.Schema(
query=Query,
types=[MyScalar, Node],
config=StrawberryConfig(name_converter=AppendsNameConverter("X")),
config={"name_converter": AppendsNameConverter("X")},
)


Expand Down
Loading