From b4df1f2eeece2ba7106adaa2b1ad2db25dc7cc51 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:49:46 -0600 Subject: [PATCH] Add Base64Url types --- docs/usage/types/encoded.md | 45 +++++++++++++++- pydantic/__init__.py | 2 + pydantic/types.py | 52 ++++++++++++++++-- tests/test_types.py | 102 ++++++++++++++++++++++++++++++++++++ 4 files changed, 197 insertions(+), 4 deletions(-) diff --git a/docs/usage/types/encoded.md b/docs/usage/types/encoded.md index dcd0d4864c..19205f7989 100644 --- a/docs/usage/types/encoded.md +++ b/docs/usage/types/encoded.md @@ -86,7 +86,8 @@ except ValidationError as e: Internally, Pydantic uses the [`EncodedBytes`][pydantic.types.EncodedBytes] and [`EncodedStr`][pydantic.types.EncodedStr] annotations with [`Base64Encoder`][pydantic.types.Base64Encoder] to implement base64 encoding/decoding in the -[`Base64Bytes`][pydantic.types.Base64Bytes] and [`Base64Str`][pydantic.types.Base64Str] types, respectively. +[`Base64Bytes`][pydantic.types.Base64Bytes], [`Base64UrlBytes`][pydantic.types.Base64UrlBytes], +[`Base64Str`][pydantic.types.Base64Str], and [`Base64UrlStr`][pydantic.types.Base64Str] types. ```py from typing import Optional @@ -131,3 +132,45 @@ except ValidationError as e: Base64 decoding error: 'Incorrect padding' [type=base64_decode, input_value=b'undecodable', input_type=bytes] """ ``` + +If you need url-safe base64 encoding, you can use the `Base64UrlBytes` and `Base64UrlStr` types. The following snippet +demonstrates the difference in alphabets used by the url-safe and non-url-safe encodings: + +```py +from pydantic import ( + Base64Bytes, + Base64Str, + Base64UrlBytes, + Base64UrlStr, + BaseModel, +) + + +class Model(BaseModel): + base64_bytes: Base64Bytes + base64_str: Base64Str + base64url_bytes: Base64UrlBytes + base64url_str: Base64UrlStr + + +# Initialize the model with base64 data +m = Model( + base64_bytes=b'SHc/dHc+TXc==', + base64_str='SHc/dHc+TXc==', + base64url_bytes=b'SHc_dHc-TXc==', + base64url_str='SHc_dHc-TXc==', +) +print(m) +""" +base64_bytes=b'Hw?tw>Mw' base64_str='Hw?tw>Mw' base64url_bytes=b'Hw?tw>Mw' base64url_str='Hw?tw>Mw' +""" +``` + +!!! note + Under the hood, `Base64Bytes` and `Base64Str` use the standard library `base64.encodebytes` and `base64.decodebytes` + functions, while `Base64UrlBytes` and `Base64UrlStr` use the `base64.urlsafe_b64encode` and + `base64.urlsafe_b64decode` functions. + + As a result, the `Base64UrlBytes` and `Base64UrlStr` types can be used to faithfully decode "vanilla" base64 data + (using `'+'` and `'/'`), but the reverse is not true — attempting to decode url-safe base64 data using the + `Base64Bytes` and `Base64Str` types may fail or produce an incorrect decoding. diff --git a/pydantic/__init__.py b/pydantic/__init__.py index 24590b0de4..9aa865cb80 100644 --- a/pydantic/__init__.py +++ b/pydantic/__init__.py @@ -179,6 +179,8 @@ 'Base64Encoder', 'Base64Bytes', 'Base64Str', + 'Base64UrlBytes', + 'Base64UrlStr', 'SkipValidation', 'InstanceOf', 'WithJsonSchema', diff --git a/pydantic/types.py b/pydantic/types.py index 3c6a227f78..d3df0d8d00 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -96,6 +96,8 @@ 'Base64Encoder', 'Base64Bytes', 'Base64Str', + 'Base64UrlBytes', + 'Base64UrlStr', 'GetPydanticSchema', 'StringConstraints', ) @@ -1233,7 +1235,7 @@ def get_json_format(cls) -> str: class Base64Encoder(EncoderProtocol): - """Base64 encoder.""" + """Standard (non-URL-safe) Base64 encoder.""" @classmethod def decode(cls, data: bytes) -> bytes: @@ -1272,6 +1274,46 @@ def get_json_format(cls) -> Literal['base64']: return 'base64' +class Base64UrlEncoder(EncoderProtocol): + """URL-safe Base64 encoder.""" + + @classmethod + def decode(cls, data: bytes) -> bytes: + """Decode the data from base64 encoded bytes to original bytes data. + + Args: + data: The data to decode. + + Returns: + The decoded data. + """ + try: + return base64.urlsafe_b64decode(data) + except ValueError as e: + raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)}) + + @classmethod + def encode(cls, value: bytes) -> bytes: + """Encode the data from bytes to a base64 encoded bytes. + + Args: + value: The data to encode. + + Returns: + The encoded data. + """ + return base64.urlsafe_b64encode(value) + + @classmethod + def get_json_format(cls) -> Literal['base64url']: + """Get the JSON format for the encoded data. + + Returns: + The JSON format for the encoded data. + """ + return 'base64url' + + @_dataclasses.dataclass(**_internal_dataclass.slots_true) class EncodedBytes: """A bytes type that is encoded and decoded using the specified encoder.""" @@ -1356,9 +1398,13 @@ def encode_str(self, value: str) -> str: Base64Bytes = Annotated[bytes, EncodedBytes(encoder=Base64Encoder)] -"""A bytes type that is encoded and decoded using the base64 encoder.""" +"""A bytes type that is encoded and decoded using the standard (non-URL-safe) base64 encoder.""" Base64Str = Annotated[str, EncodedStr(encoder=Base64Encoder)] -"""A str type that is encoded and decoded using the base64 encoder.""" +"""A str type that is encoded and decoded using the standard (non-URL-safe) base64 encoder.""" +Base64UrlBytes = Annotated[bytes, EncodedBytes(encoder=Base64UrlEncoder)] +"""A bytes type that is encoded and decoded using the URL-safe base64 encoder.""" +Base64UrlStr = Annotated[str, EncodedStr(encoder=Base64UrlEncoder)] +"""A str type that is encoded and decoded using the URL-safe base64 encoder.""" __getattr__ = getattr_migration(__name__) diff --git a/tests/test_types.py b/tests/test_types.py index 4d06c8c666..d7dd44da50 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -50,6 +50,8 @@ AwareDatetime, Base64Bytes, Base64Str, + Base64UrlBytes, + Base64UrlStr, BaseModel, ByteSize, ConfigDict, @@ -4882,6 +4884,13 @@ class Model(BaseModel): pytest.param( Base64Str, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-bytearray-input' ), + pytest.param( + Base64Bytes, + b'BCq+6+1/Paun/Q==', + b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd', + b'BCq+6+1/Paun/Q==\n', + id='Base64Bytes-bytes-alphabet-vanilla', + ), ], ) def test_base64(field_type, input_data, expected_value, serialized_data): @@ -4946,6 +4955,99 @@ class Model(BaseModel): ] +@pytest.mark.parametrize( + ('field_type', 'input_data', 'expected_value', 'serialized_data'), + [ + pytest.param(Base64UrlBytes, b'Zm9vIGJhcg==\n', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-reversible'), + pytest.param(Base64UrlStr, 'Zm9vIGJhcg==\n', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-reversible'), + pytest.param(Base64UrlBytes, b'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-bytes-input'), + pytest.param(Base64UrlBytes, 'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-str-input'), + pytest.param( + Base64UrlBytes, bytearray(b'Zm9vIGJhcg=='), b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-bytearray-input' + ), + pytest.param(Base64UrlStr, b'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-bytes-input'), + pytest.param(Base64UrlStr, 'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-str-input'), + pytest.param( + Base64UrlStr, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-bytearray-input' + ), + pytest.param( + Base64UrlBytes, + b'BCq-6-1_Paun_Q==', + b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd', + b'BCq-6-1_Paun_Q==', + id='Base64UrlBytes-bytes-alphabet-url', + ), + pytest.param( + Base64UrlBytes, + b'BCq+6+1/Paun/Q==', + b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd', + b'BCq-6-1_Paun_Q==', + id='Base64UrlBytes-bytes-alphabet-vanilla', + ), + ], +) +def test_base64url(field_type, input_data, expected_value, serialized_data): + class Model(BaseModel): + base64url_value: field_type + base64url_value_or_none: Optional[field_type] = None + + m = Model(base64url_value=input_data) + assert m.base64url_value == expected_value + + m = Model.model_construct(base64url_value=expected_value) + assert m.base64url_value == expected_value + + assert m.model_dump() == { + 'base64url_value': serialized_data, + 'base64url_value_or_none': None, + } + + assert Model.model_json_schema() == { + 'properties': { + 'base64url_value': { + 'format': 'base64url', + 'title': 'Base64Url Value', + 'type': 'string', + }, + 'base64url_value_or_none': { + 'anyOf': [{'type': 'string', 'format': 'base64url'}, {'type': 'null'}], + 'default': None, + 'title': 'Base64Url Value Or None', + }, + }, + 'required': ['base64url_value'], + 'title': 'Model', + 'type': 'object', + } + + +@pytest.mark.parametrize( + ('field_type', 'input_data'), + [ + pytest.param(Base64UrlBytes, b'Zm9vIGJhcg', id='Base64UrlBytes-invalid-base64-bytes'), + pytest.param(Base64UrlBytes, 'Zm9vIGJhcg', id='Base64UrlBytes-invalid-base64-str'), + pytest.param(Base64UrlStr, b'Zm9vIGJhcg', id='Base64UrlStr-invalid-base64-bytes'), + pytest.param(Base64UrlStr, 'Zm9vIGJhcg', id='Base64UrlStr-invalid-base64-str'), + ], +) +def test_base64url_invalid(field_type, input_data): + class Model(BaseModel): + base64url_value: field_type + + with pytest.raises(ValidationError) as e: + Model(base64url_value=input_data) + + assert e.value.errors(include_url=False) == [ + { + 'ctx': {'error': 'Incorrect padding'}, + 'input': input_data, + 'loc': ('base64url_value',), + 'msg': "Base64 decoding error: 'Incorrect padding'", + 'type': 'base64_decode', + }, + ] + + def test_sequence_subclass_without_core_schema() -> None: class MyList(List[int]): # The point of this is that subclasses can do arbitrary things