Skip to content

Commit

Permalink
Fix mypy error on free before validator (classmethod) (#8285)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Dec 4, 2023
1 parent 8d2e6f2 commit 20c0c6d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 5 deletions.
38 changes: 33 additions & 5 deletions pydantic/functional_validators.py
Expand Up @@ -417,6 +417,21 @@ def __call__( # noqa: D102
...


class FreeModelBeforeValidatorWithoutInfo(Protocol):
"""A @model_validator decorated function signature.
This is used when `mode='before'` and the function does not have info argument.
"""

def __call__( # noqa: D102
self,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
__value: Any,
) -> Any:
...


class ModelBeforeValidatorWithoutInfo(Protocol):
"""A @model_validator decorated function signature.
This is used when `mode='before'` and the function does not have info argument.
Expand All @@ -433,6 +448,20 @@ def __call__( # noqa: D102
...


class FreeModelBeforeValidator(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""

def __call__( # noqa: D102
self,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
__value: Any,
__info: _core_schema.ValidationInfo,
) -> Any:
...


class ModelBeforeValidator(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""

Expand All @@ -457,7 +486,9 @@ def __call__( # noqa: D102
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""

_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]]
_AnyModeBeforeValidator = Union[ModelBeforeValidator, ModelBeforeValidatorWithoutInfo]
_AnyModeBeforeValidator = Union[
FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo
]
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]


Expand Down Expand Up @@ -499,8 +530,6 @@ def model_validator(
Example usage:
```py
from typing import Optional
from typing_extensions import Self
from pydantic import BaseModel, ValidationError, model_validator
Expand All @@ -525,8 +554,7 @@ def verify_square(self) -> Self:
print(e)
'''
1 validation error for Square
__root__
width and height do not match (type=value_error)
Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict]
'''
```
Expand Down
13 changes: 13 additions & 0 deletions tests/mypy/modules/success.py
Expand Up @@ -42,6 +42,7 @@
WrapValidator,
create_model,
field_validator,
model_validator,
root_validator,
validate_call,
)
Expand Down Expand Up @@ -308,3 +309,15 @@ class Abstract(BaseModel):

class Concrete(Abstract):
class_id = 1


def two_dim_shape_validator(v: Dict[str, Any]) -> Dict[str, Any]:
assert 'volume' not in v, 'shape is 2d, cannot have volume'
return v


class Square(BaseModel):
width: float
height: float

free_validator = model_validator(mode='before')(two_dim_shape_validator)
13 changes: 13 additions & 0 deletions tests/mypy/outputs/1.0.1/mypy-default_ini/success.py
Expand Up @@ -42,6 +42,7 @@
WrapValidator,
create_model,
field_validator,
model_validator,
root_validator,
validate_call,
)
Expand Down Expand Up @@ -314,3 +315,15 @@ class Abstract(BaseModel):

class Concrete(Abstract):
class_id = 1


def two_dim_shape_validator(v: Dict[str, Any]) -> Dict[str, Any]:
assert 'volume' not in v, 'shape is 2d, cannot have volume'
return v


class Square(BaseModel):
width: float
height: float

free_validator = model_validator(mode='before')(two_dim_shape_validator)
13 changes: 13 additions & 0 deletions tests/mypy/outputs/1.0.1/pyproject-default_toml/success.py
Expand Up @@ -42,6 +42,7 @@
WrapValidator,
create_model,
field_validator,
model_validator,
root_validator,
validate_call,
)
Expand Down Expand Up @@ -314,3 +315,15 @@ class Abstract(BaseModel):

class Concrete(Abstract):
class_id = 1


def two_dim_shape_validator(v: Dict[str, Any]) -> Dict[str, Any]:
assert 'volume' not in v, 'shape is 2d, cannot have volume'
return v


class Square(BaseModel):
width: float
height: float

free_validator = model_validator(mode='before')(two_dim_shape_validator)

0 comments on commit 20c0c6d

Please sign in to comment.