Skip to content

Commit

Permalink
Fix float -> Decimal coercion precision loss (#6810)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Jul 23, 2023
1 parent ebc6019 commit 5320764
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
8 changes: 6 additions & 2 deletions pydantic/_internal/_std_types_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,21 @@ def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSche
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
Decimal = decimal.Decimal

def to_decimal(v: Any) -> decimal.Decimal:
def to_decimal(v: str) -> decimal.Decimal:
try:
return Decimal(v)
except decimal.DecimalException as e:
raise PydanticCustomError('decimal_parsing', 'Input should be a valid decimal') from e

primitive_schema = core_schema.union_schema(
[
core_schema.float_schema(strict=True),
# if it's an int keep it like that and pass it straight to Decimal
# but if it's not make it a string
# we don't use JSON -> float because parsing to any float will cause
# loss of precision
core_schema.int_schema(strict=True),
core_schema.str_schema(strict=True, strip_whitespace=True),
core_schema.no_info_plain_validator_function(str),
],
)
json_schema = core_schema.no_info_after_validator_function(to_decimal, primitive_schema)
Expand Down
12 changes: 4 additions & 8 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ class Model(BaseModel):
},
{
'type': 'decimal_parsing',
'loc': ('v', 2, 'function-after[to_decimal(), union[float,int,constrained-str]]'),
'loc': ('v', 2, 'function-after[to_decimal(), union[int,constrained-str,function-plain[str()]]]'),
'msg': 'Input should be a valid decimal',
'input': 'x',
},
Expand Down Expand Up @@ -1237,9 +1237,7 @@ class Model(BaseModel):
'type': 'is_instance_of',
'loc': (
'a',
'function-after[check_digits_validator(), json-or-python[json=function-after[to_decimal(), '
'union[float,int,constrained-str]],python=lax-or-strict[lax=union[is-instance[Decimal],'
'function-after[to_decimal(), union[float,int,constrained-str]]],strict=is-instance[Decimal]]]]',
'function-after[check_digits_validator(), json-or-python[json=function-after[to_decimal(), union[int,constrained-str,function-plain[str()]]],python=lax-or-strict[lax=union[is-instance[Decimal],function-after[to_decimal(), union[int,constrained-str,function-plain[str()]]]],strict=is-instance[Decimal]]]]',
'is-instance[Decimal]',
),
'msg': 'Input should be an instance of Decimal',
Expand All @@ -1250,10 +1248,8 @@ class Model(BaseModel):
'type': 'decimal_parsing',
'loc': (
'a',
'function-after[check_digits_validator(), json-or-python[json=function-after[to_decimal(), '
'union[float,int,constrained-str]],python=lax-or-strict[lax=union[is-instance[Decimal],'
'function-after[to_decimal(), union[float,int,constrained-str]]],strict=is-instance[Decimal]]]]',
'function-after[to_decimal(), union[float,int,constrained-str]]',
'function-after[check_digits_validator(), json-or-python[json=function-after[to_decimal(), union[int,constrained-str,function-plain[str()]]],python=lax-or-strict[lax=union[is-instance[Decimal],function-after[to_decimal(), union[int,constrained-str,function-plain[str()]]]],strict=is-instance[Decimal]]]]',
'function-after[to_decimal(), union[int,constrained-str,function-plain[str()]]]',
),
'msg': 'Input should be a valid decimal',
'input': 'foobar',
Expand Down
13 changes: 12 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,7 +1292,7 @@ def __bool__(self) -> bool:
('uuid_check', b'\x12\x34\x56\x78' * 4, UUID('12345678-1234-5678-1234-567812345678')),
('uuid_check', 'ebcdab58-6eb8-46fb-a190-', ValidationError),
('uuid_check', 123, ValidationError),
('decimal_check', 42.24, Decimal(42.24)),
('decimal_check', 42.24, Decimal('42.24')),
('decimal_check', '42.24', Decimal('42.24')),
('decimal_check', b'42.24', ValidationError),
('decimal_check', ' 42.24 ', Decimal('42.24')),
Expand Down Expand Up @@ -5615,3 +5615,14 @@ def test_string_constraints() -> None:
Annotated[str, StringConstraints(strip_whitespace=True, to_lower=True), AfterValidator(lambda x: x * 2)]
)
assert ta.validate_python(' ABC ') == 'abcabc'


def test_decimal_float_precision() -> None:
"""https://github.com/pydantic/pydantic/issues/6807"""
ta = TypeAdapter(Decimal)
assert ta.validate_json('1.1') == Decimal('1.1')
assert ta.validate_python(1.1) == Decimal('1.1')
assert ta.validate_json('"1.1"') == Decimal('1.1')
assert ta.validate_python('1.1') == Decimal('1.1')
assert ta.validate_json('1') == Decimal('1')
assert ta.validate_python(1) == Decimal('1')

0 comments on commit 5320764

Please sign in to comment.