Skip to content

Commit

Permalink
Touch up Decimal validator
Browse files Browse the repository at this point in the history
This improves performance slightly by not always going through Python, but error messages actually get less intelligle
  • Loading branch information
adriangb committed Jul 8, 2023
1 parent c5b28b8 commit d1eee58
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 99 deletions.
195 changes: 113 additions & 82 deletions pydantic/_internal/_std_types_schema.py
Expand Up @@ -136,19 +136,24 @@ def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:

@slots_dataclass
class DecimalValidator:
gt: int | decimal.Decimal | None = None
ge: int | decimal.Decimal | None = None
lt: int | decimal.Decimal | None = None
le: int | decimal.Decimal | None = None
gt: decimal.Decimal | None = None
ge: decimal.Decimal | None = None
lt: decimal.Decimal | None = None
le: decimal.Decimal | None = None
max_digits: int | None = None
decimal_places: int | None = None
multiple_of: int | decimal.Decimal | None = None
multiple_of: decimal.Decimal | None = None
allow_inf_nan: bool = False
check_digits: bool = False
strict: bool = False

def __post_init__(self) -> None:
self.check_digits = self.max_digits is not None or self.decimal_places is not None
self.gt = decimal.Decimal(self.gt) if self.gt is not None else None
self.ge = decimal.Decimal(self.ge) if self.ge is not None else None
self.lt = decimal.Decimal(self.lt) if self.lt is not None else None
self.le = decimal.Decimal(self.le) if self.le is not None else None
self.multiple_of = decimal.Decimal(self.multiple_of) if self.multiple_of is not None else None
if self.check_digits and self.allow_inf_nan:
raise ValueError('allow_inf_nan=True cannot be used with max_digits or decimal_places')

Expand All @@ -171,94 +176,120 @@ def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSche
return string_schema

def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
return core_schema.general_after_validator_function(self.validate, core_schema.any_schema())

def validate( # noqa: C901 (ignore complexity)
self, input_value: Any, info: core_schema.ValidationInfo
) -> decimal.Decimal:
if isinstance(input_value, decimal.Decimal):
value = input_value
elif self.strict or (info.config or {}).get('strict', False) and info.mode == 'python':
raise PydanticCustomError(
'decimal_type', 'Input should be a valid Decimal instance or decimal string in JSON'
)
else:
Decimal = decimal.Decimal

def to_decimal(v: Any) -> decimal.Decimal:
try:
value = decimal.Decimal(str(input_value))
except decimal.DecimalException:
raise PydanticCustomError('decimal_parsing', 'Input should be a valid decimal')
return Decimal(v)
except decimal.DecimalException as e:
raise PydanticCustomError('decimal_parsing', 'Input should be a valid decimal') from e

primitive_schema = core_schema.union_schema(
[
core_schema.float_schema(strict=True),
core_schema.int_schema(strict=True),
core_schema.str_schema(strict=True, strip_whitespace=True),
core_schema.no_info_plain_validator_function(lambda x: str(x).strip()),
],
)
json_schema = core_schema.no_info_after_validator_function(to_decimal, primitive_schema)
schema = core_schema.json_or_python_schema(
json_schema=json_schema,
python_schema=core_schema.lax_or_strict_schema(
lax_schema=core_schema.union_schema([core_schema.is_instance_schema(decimal.Decimal), json_schema]),
strict_schema=core_schema.is_instance_schema(decimal.Decimal),
),
)

if not self.allow_inf_nan or self.check_digits:
try:
normalized_value = value.normalize()
except decimal.InvalidOperation:
normalized_value = value
_1, digit_tuple, exponent = normalized_value.as_tuple()
if not self.allow_inf_nan and exponent in {'F', 'n', 'N'}:
raise PydanticKnownError('finite_number')
schema = core_schema.no_info_after_validator_function(
self.check_digits_validator,
schema,
)

if self.check_digits:
if isinstance(exponent, str):
raise PydanticKnownError('finite_number')
elif exponent >= 0:
# A positive exponent adds that many trailing zeros.
digits = len(digit_tuple) + exponent
decimals = 0
else:
# If the absolute value of the negative exponent is larger than the
# number of digits, then it's the same as the number of digits,
# because it'll consume all the digits in digit_tuple and then
# add abs(exponent) - len(digit_tuple) leading zeros after the
# decimal point.
if abs(exponent) > len(digit_tuple):
digits = decimals = abs(exponent)
else:
digits = len(digit_tuple)
decimals = abs(exponent)

if self.max_digits is not None and digits > self.max_digits:
raise PydanticCustomError(
'decimal_max_digits',
'ensure that there are no more than {max_digits} digits in total',
{'max_digits': self.max_digits},
)
if self.multiple_of is not None:
schema = core_schema.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=self.multiple_of),
schema,
)

if self.decimal_places is not None and decimals > self.decimal_places:
raise PydanticCustomError(
'decimal_max_places',
'ensure that there are no more than {decimal_places} decimal places',
{'decimal_places': self.decimal_places},
)
if self.gt is not None:
schema = core_schema.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=self.gt),
schema,
)

if self.max_digits is not None and self.decimal_places is not None:
whole_digits = digits - decimals
expected = self.max_digits - self.decimal_places
if whole_digits > expected:
raise PydanticCustomError(
'decimal_whole_digits',
'ensure that there are no more than {whole_digits} digits before the decimal point',
{'whole_digits': expected},
)
if self.ge is not None:
schema = core_schema.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=self.ge),
schema,
)

if self.multiple_of is not None:
mod = value / self.multiple_of % 1
if mod != 0:
if self.lt is not None:
schema = core_schema.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=self.lt),
schema,
)

if self.le is not None:
schema = core_schema.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=self.le),
schema,
)

return schema

def check_digits_validator(self, value: decimal.Decimal) -> decimal.Decimal:
try:
normalized_value = value.normalize()
except decimal.InvalidOperation:
normalized_value = value
_1, digit_tuple, exponent = normalized_value.as_tuple()
if not self.allow_inf_nan and exponent in {'F', 'n', 'N'}:
raise PydanticKnownError('finite_number')

if self.check_digits:
if isinstance(exponent, str):
raise PydanticKnownError('finite_number')
elif exponent >= 0:
# A positive exponent adds that many trailing zeros.
digits = len(digit_tuple) + exponent
decimals = 0
else:
# If the absolute value of the negative exponent is larger than the
# number of digits, then it's the same as the number of digits,
# because it'll consume all the digits in digit_tuple and then
# add abs(exponent) - len(digit_tuple) leading zeros after the
# decimal point.
if abs(exponent) > len(digit_tuple):
digits = decimals = abs(exponent)
else:
digits = len(digit_tuple)
decimals = abs(exponent)

if self.max_digits is not None and digits > self.max_digits:
raise PydanticCustomError(
'decimal_multiple_of',
'Input should be a multiple of {multiple_of}',
{'multiple_of': self.multiple_of},
'decimal_max_digits',
'ensure that there are no more than {max_digits} digits in total',
{'max_digits': self.max_digits},
)

if self.gt is not None and not value > self.gt: # type: ignore
raise PydanticKnownError('greater_than', {'gt': self.gt})
elif self.ge is not None and not value >= self.ge: # type: ignore
raise PydanticKnownError('greater_than_equal', {'ge': self.ge})

if self.lt is not None and not value < self.lt: # type: ignore
raise PydanticKnownError('less_than', {'lt': self.lt})
if self.le is not None and not value <= self.le: # type: ignore
raise PydanticKnownError('less_than_equal', {'le': self.le})
if self.decimal_places is not None and decimals > self.decimal_places:
raise PydanticCustomError(
'decimal_max_places',
'ensure that there are no more than {decimal_places} decimal places',
{'decimal_places': self.decimal_places},
)

if self.max_digits is not None and self.decimal_places is not None:
whole_digits = digits - decimals
expected = self.max_digits - self.decimal_places
if whole_digits > expected:
raise PydanticCustomError(
'decimal_whole_digits',
'ensure that there are no more than {whole_digits} digits before the decimal point',
{'whole_digits': expected},
)
return value


Expand Down
46 changes: 39 additions & 7 deletions tests/test_edge_cases.py
Expand Up @@ -423,20 +423,37 @@ class Model(BaseModel):

with pytest.raises(ValidationError) as exc_info:
Model(v=['x', 'y', 'x'])
# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{
'input': 'x',
'loc': ('v', 0),
'msg': 'Input should be a valid integer, unable to parse string as an ' 'integer',
'type': 'int_parsing',
'loc': ('v', 0),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'x',
},
{
'input': 'y',
'type': 'float_parsing',
'loc': ('v', 1),
'msg': 'Input should be a valid number, unable to parse string as a number',
'type': 'float_parsing',
'input': 'y',
},
{
'type': 'is_instance_of',
'loc': ('v', 2, 'is-instance[Decimal]'),
'msg': 'Input should be an instance of Decimal',
'input': 'x',
'ctx': {'class': 'Decimal'},
},
{
'type': 'decimal_parsing',
'loc': (
'v',
2,
'function-after[to_decimal(), union[float,int,constrained-str,function-plain[<lambda>()]]]',
),
'msg': 'Input should be a valid decimal',
'input': 'x',
},
{'input': 'x', 'loc': ('v', 2), 'msg': 'Input should be a valid decimal', 'type': 'decimal_parsing'},
]


Expand Down Expand Up @@ -1220,9 +1237,24 @@ class Model(BaseModel):
'msg': 'Input should be a valid number, unable to parse string as a number',
'input': 'foobar',
},
{
'type': 'is_instance_of',
'loc': (
'a',
'function-after[check_digits_validator(), json-or-python[json=function-after[to_decimal(), union[float,int,constrained-str,function-plain[<lambda>()]]],python=lax-or-strict[lax=union[is-instance[Decimal],function-after[to_decimal(), union[float,int,constrained-str,function-plain[<lambda>()]]]],strict=is-instance[Decimal]]]]', # noqa: E501
'is-instance[Decimal]',
),
'msg': 'Input should be an instance of Decimal',
'input': 'foobar',
'ctx': {'class': 'Decimal'},
},
{
'type': 'decimal_parsing',
'loc': ('a', 'function-after[validate(), any]'),
'loc': (
'a',
'function-after[check_digits_validator(), json-or-python[json=function-after[to_decimal(), union[float,int,constrained-str,function-plain[<lambda>()]]],python=lax-or-strict[lax=union[is-instance[Decimal],function-after[to_decimal(), union[float,int,constrained-str,function-plain[<lambda>()]]]],strict=is-instance[Decimal]]]]', # noqa: E501
'function-after[to_decimal(), union[float,int,constrained-str,function-plain[<lambda>()]]]',
),
'msg': 'Input should be a valid decimal',
'input': 'foobar',
},
Expand Down
2 changes: 1 addition & 1 deletion tests/test_json.py
Expand Up @@ -119,7 +119,7 @@ class Model(BaseModel):
c: Decimal
d: ModelA

m = Model(a=10.2, b='foobar', c=10.2, d={'x': 123, 'y': '123'})
m = Model(a=10.2, b='foobar', c='10.2', d={'x': 123, 'y': '123'})
assert m.model_dump() == {'a': 10.2, 'b': b'foobar', 'c': Decimal('10.2'), 'd': {'x': 123, 'y': '123'}}
assert m.model_dump_json() == '{"a":10.2,"b":"foobar","c":"10.2","d":{"x":123,"y":"123"}}'
assert m.model_dump_json(exclude={'b'}) == '{"a":10.2,"c":"10.2","d":{"x":123,"y":"123"}}'
Expand Down
28 changes: 19 additions & 9 deletions tests/test_types.py
Expand Up @@ -1026,12 +1026,15 @@ class Model(BaseModel):

with pytest.raises(ValidationError) as exc_info:
Model(v=1.23)

# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{
'type': 'decimal_type',
'type': 'is_instance_of',
'loc': ('v',),
'msg': 'Input should be a valid Decimal instance or decimal string in JSON',
'msg': 'Input should be an instance of Decimal',
'input': 1.23,
'ctx': {'class': 'Decimal'},
}
]

Expand Down Expand Up @@ -1275,7 +1278,7 @@ def __bool__(self) -> bool:
('uuid_check', b'\x12\x34\x56\x78' * 4, UUID('12345678-1234-5678-1234-567812345678')),
('uuid_check', 'ebcdab58-6eb8-46fb-a190-', ValidationError),
('uuid_check', 123, ValidationError),
('decimal_check', 42.24, Decimal('42.24')),
('decimal_check', 42.24, Decimal(42.24)),
('decimal_check', '42.24', Decimal('42.24')),
('decimal_check', b'42.24', ValidationError),
('decimal_check', ' 42.24 ', Decimal('42.24')),
Expand Down Expand Up @@ -2889,7 +2892,7 @@ class Model(BaseModel):
'loc': ('foo',),
'msg': 'Input should be greater than 42.24',
'input': Decimal('42'),
'ctx': {'gt': 42.24},
'ctx': {'gt': '42.24'},
}
],
),
Expand All @@ -2904,7 +2907,7 @@ class Model(BaseModel):
'msg': 'Input should be less than 42.24',
'input': Decimal('43'),
'ctx': {
'lt': 42.24,
'lt': '42.24',
},
},
],
Expand All @@ -2921,7 +2924,7 @@ class Model(BaseModel):
'msg': 'Input should be greater than or equal to 42.24',
'input': Decimal('42'),
'ctx': {
'ge': 42.24,
'ge': '42.24',
},
}
],
Expand All @@ -2938,7 +2941,7 @@ class Model(BaseModel):
'msg': 'Input should be less than or equal to 42.24',
'input': Decimal('43'),
'ctx': {
'le': 42.24,
'le': '42.24',
},
}
],
Expand Down Expand Up @@ -3067,7 +3070,7 @@ class Model(BaseModel):
Decimal('42'),
[
{
'type': 'decimal_multiple_of',
'type': 'multiple_of',
'loc': ('foo',),
'msg': 'Input should be a multiple of 5',
'input': Decimal('42'),
Expand Down Expand Up @@ -5069,8 +5072,15 @@ class Model(BaseModel):
assert isinstance(Model(x=1).x, PdDecimal)
with pytest.raises(ValidationError) as exc_info:
Model(x=-1)
# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{'type': 'greater_than', 'loc': ('x',), 'msg': 'Input should be greater than 0', 'input': -1, 'ctx': {'gt': 0}}
{
'type': 'greater_than',
'loc': ('x',),
'msg': 'Input should be greater than 0',
'input': -1,
'ctx': {'gt': '0'},
}
]


Expand Down

0 comments on commit d1eee58

Please sign in to comment.