Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing of model_validator #6514

Merged
merged 3 commits into from Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/usage/dataclasses.md
Expand Up @@ -347,10 +347,10 @@ class User:
return values

@model_validator(mode='after')
def post_root(cls, values: Dict[str, Any]) -> Dict[str, Any]:
print(values)
def post_root(self) -> 'User':
print(self)
#> User(birth=Birth(year=1995, month=3, day=2))
return values
return self

def __post_init__(self):
print(self.birth)
Expand Down
8 changes: 4 additions & 4 deletions docs/usage/validators.md
Expand Up @@ -415,12 +415,12 @@ class UserModel(BaseModel):
return data

@model_validator(mode='after')
def check_passwords_match(cls, m: 'UserModel'):
pw1 = m.password1
pw2 = m.password2
def check_passwords_match(self) -> 'UserModel':
pw1 = self.password1
pw2 = self.password2
if pw1 is not None and pw2 is not None and pw1 != pw2:
raise ValueError('passwords do not match')
return m
return self


print(UserModel(username='scolvin', password1='zxcvbn', password2='zxcvbn'))
Expand Down
33 changes: 10 additions & 23 deletions pydantic/functional_validators.py
Expand Up @@ -409,32 +409,17 @@ def __call__( # noqa: D102
...


class ModelAfterValidatorWithoutInfo(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='after'` and the function does not
have info argument.
"""

@staticmethod
def __call__( # noqa: D102
self: _ModelType, # type: ignore
) -> _ModelType:
...


class ModelAfterValidator(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""

@staticmethod
def __call__( # noqa: D102
self: _ModelType, # type: ignore
__info: _core_schema.ValidationInfo,
) -> _ModelType:
...
ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType]
"""A `@model_validator` decorated function signature. This is used when `mode='after'` and the function does not
have info argument.
"""
Comment on lines +412 to +415
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, mypy didn't recognize the the previously annotated Protocol object as callable? (Even though it had call method? By switching to explicitly annotating with Callable, the annotations are correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that we're annotating callable functions as callable though. (Though Protocol is supposed to solve this for us, no?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory you can use either, but it's a little bit weird because here we want to require the self argument, in particular because it helps distinguish from the other alternative signatures. But my hack to make that work does not seem to have been compatible with mypy so this is a simpler albeit less pedantic version that hopefully will be.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we want to require the self argument, in particular because it helps distinguish from the other alternative signatures

I don't know if I'm understanding your thought correctly.
But if those validators work with other signatures as well (e.g. with classmethod signatures) it would be misleading imo. The implementation as far as I saw supports static-, class- and instance-signatures and even decorators. Therefore, I don't think you have to enforce an instance-signature here (i.e. with self argument).

Actually, even the documentation suggests a classmethod-signature for model_validator with mode="after". Does the solution of this PR really work with classmethod-signatures (and maybe with a @classmethod decorator) though, did you test it? If not, is anything wrong with doing it similar to field_validator? It seems to work there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, even the documentation suggests a classmethod-signature for model_validator with mode="after"

This is precisely what we wanted to avoid, the docs are wrong.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I tried around by myself but I really couldn't get a Protocol typing for instance methods to work together with the decorator. At least with mypy. Very sad, maybe it's possible in the future.

Btw, why do you even want to avoid classmethod signatures? Is there anywhere a discussion about this?

Copy link
Member Author

@adriangb adriangb Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, why do you even want to avoid classmethod signatures? Is there anywhere a discussion about this?

You're getting an instance of the model. Of course (cls: Type['Model'], v: 'Model') and add an @classmethod works but why do that instead of (self) which works great with type checkers, etc.


ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _ModelType]
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""

_AnyModelWrapValidator = Union[ModelWrapValidator, ModelWrapValidatorWithoutInfo]
_AnyModeBeforeValidator = Union[ModelBeforeValidator, ModelBeforeValidatorWithoutInfo]
_AnyModeAfterValidator = Union[ModelAfterValidator, ModelAfterValidatorWithoutInfo]
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this reparametrization do anything? It looks like the same typevar is used in the original definition of this thing. Is this for pseudo-documentation reasons or does it type check differently? Should it make use of a TypeAlias?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because generic aliases require the type var reparametrization



@overload
Expand All @@ -457,7 +442,9 @@ def model_validator(
def model_validator(
*,
mode: Literal['after'],
) -> Callable[[_AnyModeAfterValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]]:
) -> Callable[
[_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]:
...


Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataclasses.py
Expand Up @@ -1748,9 +1748,9 @@ class MyDataclass:
b: float

@model_validator(mode='after')
def double_b(cls, dc: 'MyDataclass'):
dc.b *= 2
return dc
def double_b(self) -> 'MyDataclass':
self.b *= 2
return self

d = MyDataclass('1', b='2')
assert d.a == 1
Expand Down
9 changes: 4 additions & 5 deletions tests/test_generics.py
Expand Up @@ -94,7 +94,7 @@ class Result(BaseModel, Generic[data_type]):


def test_value_validation():
T = TypeVar('T')
T = TypeVar('T', bound=Dict[Any, Any])

class Response(BaseModel, Generic[T]):
data: T
Expand All @@ -107,12 +107,11 @@ def validate_value_nonzero(cls, v: Any):
return v

@model_validator(mode='after')
@classmethod
def validate_sum(cls, m):
data = m.data
def validate_sum(self) -> 'Response[T]':
data = self.data
if sum(data.values()) > 5:
raise ValueError('sum too large')
return m
return self

assert Response[Dict[int, int]](data={1: '4'}).model_dump() == {'data': {1: 4}}
with pytest.raises(ValidationError) as exc_info:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_root_model.py
Expand Up @@ -244,9 +244,9 @@ def words(cls, v):
def test_model_validator_after():
class Model(RootModel[int]):
@model_validator(mode='after')
def double(cls, v):
v.root *= 2
return v
def double(self) -> 'Model':
self.root *= 2
return self

assert Model('1').root == 2
assert Model('21').root == 42
Expand Down
18 changes: 9 additions & 9 deletions tests/test_validators.py
Expand Up @@ -1481,8 +1481,8 @@ def test_model_validator_returns_ignore():
class Model(BaseModel):
a: int = 1

@model_validator(mode='after')
def model_validator_return_none(cls, m):
@model_validator(mode='after') # type: ignore
def model_validator_return_none(self) -> None:
return None

m = Model(a=2)
Expand Down Expand Up @@ -1727,9 +1727,9 @@ def pre_root(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

@model_validator(mode='after')
def post_root(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def post_root(self) -> 'A':
validate_stub('A', 'post')
return values
return self

class B(A):
@model_validator(mode='before')
Expand All @@ -1738,9 +1738,9 @@ def pre_root(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

@model_validator(mode='after')
def post_root(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def post_root(self) -> 'B':
validate_stub('B', 'post')
return values
return self

A(x='pika')
assert validate_stub.call_args_list == [[('A', 'pre'), {}], [('A', 'post'), {}]]
Expand Down Expand Up @@ -1858,9 +1858,9 @@ class Rectangle(BaseModel):
model_config = ConfigDict(validate_assignment=True)

@model_validator(mode='after')
def set_area(cls, m: 'Rectangle') -> 'Rectangle':
m.__dict__['area'] = m.width * m.height
return m
def set_area(self) -> 'Rectangle':
self.__dict__['area'] = self.width * self.height
return self

r = Rectangle(width=1, height=1)
assert r.area == 1
Expand Down
14 changes: 7 additions & 7 deletions tests/test_validators_dataclass.py
Expand Up @@ -150,7 +150,7 @@ def add_to_a(cls, v):


def test_model_validator():
root_val_values = []
root_val_values: list[Any] = []

@dataclass
class MyDataclass:
Expand All @@ -159,16 +159,16 @@ class MyDataclass:

@field_validator('b')
@classmethod
def repeat_b(cls, v):
def repeat_b(cls, v: str) -> str:
return v * 2

@model_validator(mode='after')
def root_validator(cls, m):
root_val_values.append(asdict(m))
if 'snap' in m.b:
def root_validator(self) -> 'MyDataclass':
root_val_values.append(asdict(self))
if 'snap' in self.b:
raise ValueError('foobar')
m.b = 'changed'
return m
self.b = 'changed'
return self

assert asdict(MyDataclass(a='123', b='bar')) == {'a': 123, 'b': 'changed'}

Expand Down