Skip to content

Commit

Permalink
Added optional for root_validator to be skipped if values validation …
Browse files Browse the repository at this point in the history
…fails (#1050)

* Added optional for root_validator to be skipped if values validation fails

* cleaner usage of skip_on_failure

* skip_on_failure: documentation update
  • Loading branch information
aviramha authored and samuelcolvin committed Dec 16, 2019
1 parent f37789c commit 5510a13
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 11 deletions.
1 change: 1 addition & 0 deletions changes/1049-aviramha.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added optional for `root_validator` to be skipped if values validation fails using keyword `skip_on_failure=True`
6 changes: 4 additions & 2 deletions docs/usage/validators.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ validation occurs (and are provided with the raw input data), or `pre=False` (th
they're called after field validation.

Field validation will not occur if `pre=True` root validators raise an error. As with field validators,
"post" (i.e. `pre=False`) root validators will be called even if field validation fails; the `values` argument will
be a dict containing the values which passed field validation and field defaults where applicable.
"post" (i.e. `pre=False`) root validators by default will be called even if field validation fails; this
behaviour can be changed by setting the `skip_on_failure=True` keyword argument to the validator.
The `values` argument will be a dict containing the values which passed field validation and
field defaults where applicable.

## Field Checks

Expand Down
20 changes: 13 additions & 7 deletions pydantic/class_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class Validator:
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields'
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'

def __init__(
self,
Expand All @@ -21,12 +21,14 @@ def __init__(
each_item: bool = False,
always: bool = False,
check_fields: bool = False,
skip_on_failure: bool = False,
):
self.func = func
self.pre = pre
self.each_item = each_item
self.always = always
self.check_fields = check_fields
self.skip_on_failure = skip_on_failure


if TYPE_CHECKING:
Expand Down Expand Up @@ -105,20 +107,24 @@ def root_validator(*, pre: bool = False) -> Callable[[AnyCallable], classmethod]


def root_validator(
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
) -> Union[classmethod, Callable[[AnyCallable], classmethod]]:
"""
Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
before or after standard model parsing/validation is performed.
"""
if _func:
f_cls = _prepare_validator(_func, allow_reuse)
setattr(f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre))
setattr(
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
)
return f_cls

def dec(f: AnyCallable) -> classmethod:
f_cls = _prepare_validator(f, allow_reuse)
setattr(f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre))
setattr(
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
)
return f_cls

return dec
Expand Down Expand Up @@ -184,9 +190,9 @@ def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
return validators


def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[AnyCallable]]:
def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
pre_validators: List[AnyCallable] = []
post_validators: List[AnyCallable] = []
post_validators: List[Tuple[bool, AnyCallable]] = []
for name, value in namespace.items():
validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
if validator_config:
Expand All @@ -203,7 +209,7 @@ def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable
if validator_config.pre:
pre_validators.append(validator_config.func)
else:
post_validators.append(validator_config.func)
post_validators.append((validator_config.skip_on_failure, validator_config.func))
return pre_validators, post_validators


Expand Down
6 changes: 4 additions & 2 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class BaseModel(metaclass=ModelMetaclass):
__field_defaults__: Dict[str, Any] = {}
__validators__: Dict[str, AnyCallable] = {}
__pre_root_validators__: List[AnyCallable]
__post_root_validators__: List[AnyCallable]
__post_root_validators__: List[Tuple[bool, AnyCallable]]
__config__: Type[BaseConfig] = BaseConfig
__root__: Any = None
__json_encoder__: Callable[[Any], Any] = lambda x: x
Expand Down Expand Up @@ -859,7 +859,9 @@ def validate_model( # noqa: C901 (ignore complexity)
for f in sorted(extra):
errors.append(ErrorWrapper(ExtraError(), loc=f))

for validator in model.__post_root_validators__:
for skip_on_failure, validator in model.__post_root_validators__:
if skip_on_failure and errors:
continue
try:
values = validator(cls_, values)
except (ValueError, TypeError, AssertionError) as exc:
Expand Down
29 changes: 29 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,3 +988,32 @@ def example_root_validator(cls, values):
]

assert root_val_values == [{'a': 123, 'b': 'barbar'}, {'a': 1, 'b': 'snap dragonsnap dragon'}, {'b': 'barbar'}]


def test_root_validator_skip_on_failure():
a_called = False

class ModelA(BaseModel):
a: int

@root_validator
def example_root_validator(cls, values):
nonlocal a_called
a_called = True

with pytest.raises(ValidationError):
ModelA(a='a')
assert a_called
b_called = False

class ModelB(BaseModel):
a: int

@root_validator(skip_on_failure=True)
def example_root_validator(cls, values):
nonlocal b_called
b_called = True

with pytest.raises(ValidationError):
ModelB(a='a')
assert not b_called

0 comments on commit 5510a13

Please sign in to comment.