Skip to content

Commit

Permalink
add support for default values with a field that is a pydantic model (#…
Browse files Browse the repository at this point in the history
…3499)

* add support for default values

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove print schema

* just modify code in pydantic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix url that seems to be updated

* remove redundant cast

* avoid pydantic imports

* add release md

* fix import

* try and remove tag requiring pydantic v2

* remove additional import check

---------

Co-authored-by: Patrick Pease <ppease@thezebra.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 10, 2024
1 parent b7f2881 commit a22a383
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 14 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: patch

Fixes a bug where pydantic models as the default value for an input did not print the proper schema.
See [this issue](https://github.com/strawberry-graphql/strawberry/issues/3285).
6 changes: 6 additions & 0 deletions strawberry/experimental/pydantic/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def get_basic_type(self, type_: Any) -> Type[Any]:

return type_

def model_dump(self, model_instance: BaseModel) -> Dict[Any, Any]:
return model_instance.model_dump()


class PydanticV1Compat:
@property
Expand Down Expand Up @@ -235,6 +238,9 @@ def get_basic_type(self, type_: Any) -> Type[Any]:

return type_

def model_dump(self, model_instance: BaseModel) -> Dict[Any, Any]:
return model_instance.dict()


class PydanticCompat:
def __init__(self, is_v2: bool) -> None:
Expand Down
2 changes: 1 addition & 1 deletion strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _build_dataclass_creation_fields(
graphql_name=graphql_name,
# always unset because we use default_factory instead
default=dataclasses.MISSING,
default_factory=get_default_factory_for_field(field),
default_factory=get_default_factory_for_field(field, compat=compat),
type_annotation=StrawberryAnnotation.from_annotation(field_type),
description=field.description,
deprecation_reason=(
Expand Down
13 changes: 10 additions & 3 deletions strawberry/experimental/pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
cast,
)

from pydantic import BaseModel

from strawberry.experimental.pydantic._compat import (
CompatModelField,
PydanticCompat,
Expand All @@ -33,7 +35,6 @@
)

if TYPE_CHECKING:
from pydantic import BaseModel
from pydantic.typing import NoArgAnyCallable


Expand Down Expand Up @@ -72,6 +73,7 @@ def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]:

def get_default_factory_for_field(
field: CompatModelField,
compat: PydanticCompat,
) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]:
"""
Gets the default factory for a pydantic field.
Expand Down Expand Up @@ -105,9 +107,14 @@ def get_default_factory_for_field(
return default_factory

# if we have a default, we should return it

if has_default:
return lambda: smart_deepcopy(default)
# if the default value is a pydantic base model
# we should return the serialized version of that default for
# printing the value.
if isinstance(default, BaseModel):
return lambda: compat.model_dump(default)
else:
return lambda: smart_deepcopy(default)

# if we don't have default or default_factory, but the field is not required,
# we should return a factory that returns None
Expand Down
80 changes: 80 additions & 0 deletions tests/experimental/pydantic/schema/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,83 @@ def a(self) -> PydanticUser:
"""

assert print_schema(schema) == textwrap.dedent(expected).strip()


def test_v2_input_with_nonscalar_default():
class NonScalarType(pydantic.BaseModel):
id: int = 10
nullable_field: Optional[int] = None

class Owning(pydantic.BaseModel):
non_scalar_type: NonScalarType = NonScalarType()
id: int = 10

@strawberry.experimental.pydantic.type(
model=NonScalarType, all_fields=True, is_input=True
)
class NonScalarTypeInput: ...

@strawberry.experimental.pydantic.type(model=Owning, all_fields=True, is_input=True)
class OwningInput: ...

@strawberry.type
class ExampleOutput:
owning_id: int
non_scalar_id: int
non_scalar_nullable_field: Optional[int]

@strawberry.type
class Query:
@strawberry.field()
def test(self, x: OwningInput) -> ExampleOutput:
return ExampleOutput(
owning_id=x.id,
non_scalar_id=x.non_scalar_type.id,
non_scalar_nullable_field=x.non_scalar_type.nullable_field,
)

schema = strawberry.Schema(Query)

expected = """
type ExampleOutput {
owningId: Int!
nonScalarId: Int!
nonScalarNullableField: Int
}
input NonScalarTypeInput {
id: Int! = 10
nullableField: Int = null
}
input OwningInput {
nonScalarType: NonScalarTypeInput! = {id: 10}
id: Int! = 10
}
type Query {
test(x: OwningInput!): ExampleOutput!
}
"""

assert print_schema(schema) == textwrap.dedent(expected).strip()

query = """
query($input_data: OwningInput!)
{
test(x: $input_data) {
owningId nonScalarId nonScalarNullableField
}
}
"""
result = schema.execute_sync(
query, variable_values=dict(input_data=dict(nonScalarType={}))
)

assert not result.errors
expected_result = {
"owningId": 10,
"nonScalarId": 10,
"nonScalarNullableField": None,
}
assert result.data["test"] == expected_result
20 changes: 10 additions & 10 deletions tests/experimental/pydantic/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from strawberry.experimental.pydantic._compat import (
IS_PYDANTIC_V2,
CompatModelField,
PydanticV1Compat,
PydanticV2Compat,
PydanticCompat,
)
from strawberry.experimental.pydantic.exceptions import (
AutoFieldsNotInBaseModelError,
Expand Down Expand Up @@ -841,10 +840,11 @@ class UserType:


def test_get_default_factory_for_field():
if IS_PYDANTIC_V2:
MISSING_TYPE = PydanticV2Compat().PYDANTIC_MISSING_TYPE
else:
MISSING_TYPE = PydanticV1Compat().PYDANTIC_MISSING_TYPE
class User(BaseModel):
pass

compat = PydanticCompat.from_model(User)
MISSING_TYPE = compat.PYDANTIC_MISSING_TYPE

def _get_field(
default: Any = MISSING_TYPE,
Expand All @@ -867,21 +867,21 @@ def _get_field(

field = _get_field()

assert get_default_factory_for_field(field) is dataclasses.MISSING
assert get_default_factory_for_field(field, compat) is dataclasses.MISSING

def factory_func():
return "strawberry"

field = _get_field(default_factory=factory_func)

# should return the default_factory unchanged
assert get_default_factory_for_field(field) is factory_func
assert get_default_factory_for_field(field, compat) is factory_func

mutable_default = [123, "strawberry"]

field = _get_field(mutable_default)

created_factory = get_default_factory_for_field(field)
created_factory = get_default_factory_for_field(field, compat)

# should return a factory that copies the default parameter
assert created_factory() == mutable_default
Expand All @@ -893,7 +893,7 @@ def factory_func():
BothDefaultAndDefaultFactoryDefinedError,
match=("Not allowed to specify both default and default_factory."),
):
get_default_factory_for_field(field)
get_default_factory_for_field(field, compat)


def test_convert_input_types_to_pydantic_default_and_default_factory():
Expand Down
73 changes: 73 additions & 0 deletions tests/schema/test_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import textwrap
from typing import Optional

import strawberry
from strawberry.printer import print_schema


def test_renaming_input_fields():
Expand All @@ -27,3 +29,74 @@ def filter(self, input: FilterInput) -> str:
assert not result.errors
assert result.data
assert result.data["filter"] == "Hello nope"


def test_input_with_nonscalar_field_default():
@strawberry.input
class NonScalarField:
id: int = 10
nullable_field: Optional[int] = None

@strawberry.input
class Input:
non_scalar_field: NonScalarField = strawberry.field(
default_factory=lambda: NonScalarField()
)
id: int = 10

@strawberry.type
class ExampleOutput:
input_id: int
non_scalar_id: int
non_scalar_nullable_field: Optional[int]

@strawberry.type
class Query:
@strawberry.field
def example(self, data: Input) -> ExampleOutput:
return ExampleOutput(
input_id=data.id,
non_scalar_id=data.non_scalar_field.id,
non_scalar_nullable_field=data.non_scalar_field.nullable_field,
)

schema = strawberry.Schema(query=Query)

expected = """
type ExampleOutput {
inputId: Int!
nonScalarId: Int!
nonScalarNullableField: Int
}
input Input {
nonScalarField: NonScalarField! = {id: 10}
id: Int! = 10
}
input NonScalarField {
id: Int! = 10
nullableField: Int = null
}
type Query {
example(data: Input!): ExampleOutput!
}
"""
assert print_schema(schema) == textwrap.dedent(expected).strip()

query = """
query($input_data: Input!)
{
example(data: $input_data) {
inputId nonScalarId nonScalarNullableField
}
}
"""
result = schema.execute_sync(
query, variable_values=dict(input_data=dict(nonScalarField={}))
)

assert not result.errors
expected_result = {"inputId": 10, "nonScalarId": 10, "nonScalarNullableField": None}
assert result.data["example"] == expected_result

0 comments on commit a22a383

Please sign in to comment.