diff --git a/pydantic/types.py b/pydantic/types.py index 5d6b3a167c..b47b917291 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -1710,37 +1710,6 @@ def validate_brand(card_number: str) -> PaymentCardBrand: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -BYTE_SIZES = { - 'b': 1, - 'kb': 10**3, - 'mb': 10**6, - 'gb': 10**9, - 'tb': 10**12, - 'pb': 10**15, - 'eb': 10**18, - 'kib': 2**10, - 'mib': 2**20, - 'gib': 2**30, - 'tib': 2**40, - 'pib': 2**50, - 'eib': 2**60, - 'bit': 1 / 8, - 'kbit': 10**3 / 8, - 'mbit': 10**6 / 8, - 'gbit': 10**9 / 8, - 'tbit': 10**12 / 8, - 'pbit': 10**15 / 8, - 'ebit': 10**18 / 8, - 'kibit': 2**10 / 8, - 'mibit': 2**20 / 8, - 'gibit': 2**30 / 8, - 'tibit': 2**40 / 8, - 'pibit': 2**50 / 8, - 'eibit': 2**60 / 8, -} -BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k}) -byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE) - class ByteSize(int): """Converts a string representing a number of bytes with units (such as `'1KB'` or `'11.5MiB'`) into an integer. @@ -1777,9 +1746,53 @@ class MyModel(BaseModel): ``` """ + byte_sizes = { + 'b': 1, + 'kb': 10**3, + 'mb': 10**6, + 'gb': 10**9, + 'tb': 10**12, + 'pb': 10**15, + 'eb': 10**18, + 'kib': 2**10, + 'mib': 2**20, + 'gib': 2**30, + 'tib': 2**40, + 'pib': 2**50, + 'eib': 2**60, + 'bit': 1 / 8, + 'kbit': 10**3 / 8, + 'mbit': 10**6 / 8, + 'gbit': 10**9 / 8, + 'tbit': 10**12 / 8, + 'pbit': 10**15 / 8, + 'ebit': 10**18 / 8, + 'kibit': 2**10 / 8, + 'mibit': 2**20 / 8, + 'gibit': 2**30 / 8, + 'tibit': 2**40 / 8, + 'pibit': 2**50 / 8, + 'eibit': 2**60 / 8, + } + byte_sizes.update({k.lower()[0]: v for k, v in byte_sizes.items() if 'i' not in k}) + + byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?' + byte_string_re = re.compile(byte_string_pattern, re.IGNORECASE) + @classmethod def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - return core_schema.with_info_plain_validator_function(cls._validate) + return core_schema.with_info_after_validator_function( + function=cls._validate, + schema=core_schema.union_schema( + [ + core_schema.str_schema(pattern=cls.byte_string_pattern), + core_schema.int_schema(ge=0), + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + int, return_schema=core_schema.int_schema(ge=0) + ), + ) @classmethod def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize: @@ -1788,7 +1801,7 @@ def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSiz except ValueError: pass - str_match = byte_string_re.match(str(__input_value)) + str_match = cls.byte_string_re.match(str(__input_value)) if str_match is None: raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string') @@ -1797,7 +1810,7 @@ def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSiz unit = 'b' try: - unit_mult = BYTE_SIZES[unit.lower()] + unit_mult = cls.byte_sizes[unit.lower()] except KeyError: raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit}) @@ -1846,7 +1859,7 @@ def to(self, unit: str) -> float: The byte size in the new unit. """ try: - unit_div = BYTE_SIZES[unit.lower()] + unit_div = self.byte_sizes[unit.lower()] except KeyError: raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit}) diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 3e61715648..7a6488ad18 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -6,7 +6,14 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import Enum, IntEnum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, +) from pathlib import Path from typing import ( Any, @@ -68,13 +75,22 @@ model_json_schema, models_json_schema, ) -from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, MultiHostUrl, NameEmail +from pydantic.networks import ( + AnyUrl, + EmailStr, + IPvAnyAddress, + IPvAnyInterface, + IPvAnyNetwork, + MultiHostUrl, + NameEmail, +) from pydantic.type_adapter import TypeAdapter from pydantic.types import ( UUID1, UUID3, UUID4, UUID5, + ByteSize, DirectoryPath, FilePath, Json, @@ -1292,6 +1308,50 @@ class MyGenerator(GenerateJsonSchema): assert model_schema['properties'] == properties +def test_byte_size_type(): + class Model(BaseModel): + a: ByteSize + b: ByteSize = Field('1MB', validate_default=True) + + model_json_schema_validation = Model.model_json_schema(mode='validation') + model_json_schema_serialization = Model.model_json_schema(mode='serialization') + + print(model_json_schema_serialization) + + assert model_json_schema_validation == { + 'properties': { + 'a': { + 'anyOf': [ + {'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'}, + {'minimum': 0, 'type': 'integer'}, + ], + 'title': 'A', + }, + 'b': { + 'anyOf': [ + {'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'}, + {'minimum': 0, 'type': 'integer'}, + ], + 'default': '1MB', + 'title': 'B', + }, + }, + 'required': ['a'], + 'title': 'Model', + 'type': 'object', + } + + assert model_json_schema_serialization == { + 'properties': { + 'a': {'minimum': 0, 'title': 'A', 'type': 'integer'}, + 'b': {'default': '1MB', 'minimum': 0, 'title': 'B', 'type': 'integer'}, + }, + 'required': ['a'], + 'title': 'Model', + 'type': 'object', + } + + @pytest.mark.parametrize( 'type_,default_value,properties', ( diff --git a/tests/test_types.py b/tests/test_types.py index e2ea29e082..d62f10e414 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -4434,6 +4434,7 @@ class FrozenSetModel(BaseModel): @pytest.mark.parametrize( 'input_value,output,human_bin,human_dec', ( + (1, 1, '1B', '1B'), ('1', 1, '1B', '1B'), ('1.0', 1, '1B', '1B'), ('1b', 1, '1B', '1B'), @@ -4476,7 +4477,7 @@ def test_bytesize_raises(): class Model(BaseModel): size: ByteSize - with pytest.raises(ValidationError, match='parse value'): + with pytest.raises(ValidationError, match='should match'): Model(size='d1MB') with pytest.raises(ValidationError, match='byte unit'):