From 69a3bca7582859bc8ce83ccc28b0cf33f7909388 Mon Sep 17 00:00:00 2001 From: David Lord Date: Sat, 13 Apr 2024 11:53:08 -0700 Subject: [PATCH] improve typing * use deferred annotations * use collections.abc * check with pyright * verify exported API with pyright --- src/itsdangerous/_json.py | 8 ++- src/itsdangerous/encoding.py | 14 ++--- src/itsdangerous/exc.py | 29 +++++---- src/itsdangerous/serializer.py | 106 ++++++++++++++++----------------- src/itsdangerous/signer.py | 55 ++++++++--------- src/itsdangerous/timed.py | 77 +++++++++++------------- src/itsdangerous/url_safe.py | 14 +++-- tox.ini | 5 +- 8 files changed, 154 insertions(+), 154 deletions(-) diff --git a/src/itsdangerous/_json.py b/src/itsdangerous/_json.py index c70d37a..fc23fea 100644 --- a/src/itsdangerous/_json.py +++ b/src/itsdangerous/_json.py @@ -1,16 +1,18 @@ +from __future__ import annotations + import json as _json -import typing as _t +import typing as t class _CompactJSON: """Wrapper around json module that strips whitespace.""" @staticmethod - def loads(payload: _t.Union[str, bytes]) -> _t.Any: + def loads(payload: str | bytes) -> t.Any: return _json.loads(payload) @staticmethod - def dumps(obj: _t.Any, **kwargs: _t.Any) -> str: + def dumps(obj: t.Any, **kwargs: t.Any) -> str: kwargs.setdefault("ensure_ascii", False) kwargs.setdefault("separators", (",", ":")) return _json.dumps(obj, **kwargs) diff --git a/src/itsdangerous/encoding.py b/src/itsdangerous/encoding.py index edb04d1..f5ca80f 100644 --- a/src/itsdangerous/encoding.py +++ b/src/itsdangerous/encoding.py @@ -1,15 +1,15 @@ +from __future__ import annotations + import base64 import string import struct -import typing as _t +import typing as t from .exc import BadData -_t_str_bytes = _t.Union[str, bytes] - def want_bytes( - s: _t_str_bytes, encoding: str = "utf-8", errors: str = "strict" + s: str | bytes, encoding: str = "utf-8", errors: str = "strict" ) -> bytes: if isinstance(s, str): s = s.encode(encoding, errors) @@ -17,7 +17,7 @@ def want_bytes( return s -def base64_encode(string: _t_str_bytes) -> bytes: +def base64_encode(string: str | bytes) -> bytes: """Base64 encode a string of bytes or text. The resulting bytes are safe to use in URLs. """ @@ -25,7 +25,7 @@ def base64_encode(string: _t_str_bytes) -> bytes: return base64.urlsafe_b64encode(string).rstrip(b"=") -def base64_decode(string: _t_str_bytes) -> bytes: +def base64_decode(string: str | bytes) -> bytes: """Base64 decode a URL-safe string of bytes or text. The result is bytes. """ @@ -43,7 +43,7 @@ def base64_decode(string: _t_str_bytes) -> bytes: _int64_struct = struct.Struct(">Q") _int_to_bytes = _int64_struct.pack -_bytes_to_int = _t.cast("_t.Callable[[bytes], _t.Tuple[int]]", _int64_struct.unpack) +_bytes_to_int = t.cast("t.Callable[[bytes], tuple[int]]", _int64_struct.unpack) def int_to_bytes(num: int) -> bytes: diff --git a/src/itsdangerous/exc.py b/src/itsdangerous/exc.py index c38a6af..a75adcd 100644 --- a/src/itsdangerous/exc.py +++ b/src/itsdangerous/exc.py @@ -1,8 +1,7 @@ -import typing as _t -from datetime import datetime +from __future__ import annotations -_t_opt_any = _t.Optional[_t.Any] -_t_opt_exc = _t.Optional[Exception] +import typing as t +from datetime import datetime class BadData(Exception): @@ -23,7 +22,7 @@ def __str__(self) -> str: class BadSignature(BadData): """Raised if a signature does not match.""" - def __init__(self, message: str, payload: _t_opt_any = None): + def __init__(self, message: str, payload: t.Any | None = None): super().__init__(message) #: The payload that failed the signature test. In some @@ -31,7 +30,7 @@ def __init__(self, message: str, payload: _t_opt_any = None): #: you know it was tampered with. #: #: .. versionadded:: 0.14 - self.payload: _t_opt_any = payload + self.payload: t.Any | None = payload class BadTimeSignature(BadSignature): @@ -42,8 +41,8 @@ class BadTimeSignature(BadSignature): def __init__( self, message: str, - payload: _t_opt_any = None, - date_signed: _t.Optional[datetime] = None, + payload: t.Any | None = None, + date_signed: datetime | None = None, ): super().__init__(message, payload) @@ -75,19 +74,19 @@ class BadHeader(BadSignature): def __init__( self, message: str, - payload: _t_opt_any = None, - header: _t_opt_any = None, - original_error: _t_opt_exc = None, + payload: t.Any | None = None, + header: t.Any | None = None, + original_error: Exception | None = None, ): super().__init__(message, payload) #: If the header is actually available but just malformed it #: might be stored here. - self.header: _t_opt_any = header + self.header: t.Any | None = header #: If available, the error that indicates why the payload was #: not valid. This might be ``None``. - self.original_error: _t_opt_exc = original_error + self.original_error: Exception | None = original_error class BadPayload(BadData): @@ -99,9 +98,9 @@ class BadPayload(BadData): .. versionadded:: 0.15 """ - def __init__(self, message: str, original_error: _t_opt_exc = None): + def __init__(self, message: str, original_error: Exception | None = None): super().__init__(message) #: If available, the error that indicates why the payload was #: not valid. This might be ``None``. - self.original_error: _t_opt_exc = original_error + self.original_error: Exception | None = original_error diff --git a/src/itsdangerous/serializer.py b/src/itsdangerous/serializer.py index 1a66789..362bf79 100644 --- a/src/itsdangerous/serializer.py +++ b/src/itsdangerous/serializer.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +import collections.abc as cabc import json -import typing as _t +import typing as t from .encoding import want_bytes from .exc import BadPayload @@ -7,17 +10,8 @@ from .signer import _make_keys_list from .signer import Signer -_t_str_bytes = _t.Union[str, bytes] -_t_opt_str_bytes = _t.Optional[_t_str_bytes] -_t_kwargs = _t.Dict[str, _t.Any] -_t_opt_kwargs = _t.Optional[_t_kwargs] -_t_signer = _t.Type[Signer] -_t_fallbacks = _t.List[_t.Union[_t_kwargs, _t.Tuple[_t_signer, _t_kwargs], _t_signer]] -_t_load_unsafe = _t.Tuple[bool, _t.Any] -_t_secret_key = _t.Union[_t.Iterable[_t_str_bytes], _t_str_bytes] - -def is_text_serializer(serializer: _t.Any) -> bool: +def is_text_serializer(serializer: t.Any) -> bool: """Checks whether a serializer generates text or binary.""" return isinstance(serializer.dumps({}), str) @@ -77,31 +71,36 @@ class Serializer: #: The default serialization module to use to serialize data to a #: string internally. The default is :mod:`json`, but can be changed #: to any object that provides ``dumps`` and ``loads`` methods. - default_serializer: _t.Any = json + default_serializer: t.Any = json #: The default ``Signer`` class to instantiate when signing data. #: The default is :class:`itsdangerous.signer.Signer`. - default_signer: _t_signer = Signer + default_signer: type[Signer] = Signer #: The default fallback signers to try when unsigning fails. - default_fallback_signers: _t_fallbacks = [] + default_fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] = [] def __init__( self, - secret_key: _t_secret_key, - salt: _t_opt_str_bytes = b"itsdangerous", - serializer: _t.Any = None, - serializer_kwargs: _t_opt_kwargs = None, - signer: _t.Optional[_t_signer] = None, - signer_kwargs: _t_opt_kwargs = None, - fallback_signers: _t.Optional[_t_fallbacks] = None, + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None = b"itsdangerous", + serializer: t.Any = None, + serializer_kwargs: dict[str, t.Any] | None = None, + signer: type[Signer] | None = None, + signer_kwargs: dict[str, t.Any] | None = None, + fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] + | None = None, ): #: The list of secret keys to try for verifying signatures, from #: oldest to newest. The newest (last) key is used for signing. #: #: This allows a key rotation system to keep a list of allowed #: keys and remove expired ones. - self.secret_keys: _t.List[bytes] = _make_keys_list(secret_key) + self.secret_keys: list[bytes] = _make_keys_list(secret_key) if salt is not None: salt = want_bytes(salt) @@ -112,20 +111,22 @@ def __init__( if serializer is None: serializer = self.default_serializer - self.serializer: _t.Any = serializer + self.serializer: t.Any = serializer self.is_text_serializer: bool = is_text_serializer(serializer) if signer is None: signer = self.default_signer - self.signer: _t_signer = signer - self.signer_kwargs: _t_kwargs = signer_kwargs or {} + self.signer: type[Signer] = signer + self.signer_kwargs: dict[str, t.Any] = signer_kwargs or {} if fallback_signers is None: - fallback_signers = list(self.default_fallback_signers or ()) + fallback_signers = list(self.default_fallback_signers) - self.fallback_signers: _t_fallbacks = fallback_signers - self.serializer_kwargs: _t_kwargs = serializer_kwargs or {} + self.fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] = fallback_signers + self.serializer_kwargs: dict[str, t.Any] = serializer_kwargs or {} @property def secret_key(self) -> bytes: @@ -134,9 +135,7 @@ def secret_key(self) -> bytes: """ return self.secret_keys[-1] - def load_payload( - self, payload: bytes, serializer: _t.Optional[_t.Any] = None - ) -> _t.Any: + def load_payload(self, payload: bytes, serializer: t.Any | None = None) -> t.Any: """Loads the encoded object. This function raises :class:`.BadPayload` if the payload is not valid. The ``serializer`` parameter can be used to override the serializer @@ -144,16 +143,17 @@ def load_payload( bytes. """ if serializer is None: - serializer = self.serializer + use_serializer = self.serializer is_text = self.is_text_serializer else: + use_serializer = serializer is_text = is_text_serializer(serializer) try: if is_text: - return serializer.loads(payload.decode("utf-8")) + return use_serializer.loads(payload.decode("utf-8")) - return serializer.loads(payload) + return use_serializer.loads(payload) except Exception as e: raise BadPayload( "Could not load the payload because an exception" @@ -161,14 +161,14 @@ def load_payload( original_error=e, ) from e - def dump_payload(self, obj: _t.Any) -> bytes: + def dump_payload(self, obj: t.Any) -> bytes: """Dumps the encoded object. The return value is always bytes. If the internal serializer returns text, the value will be encoded as UTF-8. """ return want_bytes(self.serializer.dumps(obj, **self.serializer_kwargs)) - def make_signer(self, salt: _t_opt_str_bytes = None) -> Signer: + def make_signer(self, salt: str | bytes | None = None) -> Signer: """Creates a new instance of the signer to be used. The default implementation uses the :class:`.Signer` base class. """ @@ -177,7 +177,7 @@ def make_signer(self, salt: _t_opt_str_bytes = None) -> Signer: return self.signer(self.secret_keys, salt=salt, **self.signer_kwargs) - def iter_unsigners(self, salt: _t_opt_str_bytes = None) -> _t.Iterator[Signer]: + def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signer]: """Iterates over all signers to be tried for unsigning. Starts with the configured signer, then constructs each signer specified in ``fallback_signers``. @@ -199,7 +199,7 @@ def iter_unsigners(self, salt: _t_opt_str_bytes = None) -> _t.Iterator[Signer]: for secret_key in self.secret_keys: yield fallback(secret_key, salt=salt, **kwargs) - def dumps(self, obj: _t.Any, salt: _t_opt_str_bytes = None) -> _t_str_bytes: + def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> str | bytes: """Returns a signed string serialized with the internal serializer. The return value can be either a byte or unicode string depending on the format of the internal serializer. @@ -212,17 +212,15 @@ def dumps(self, obj: _t.Any, salt: _t_opt_str_bytes = None) -> _t_str_bytes: return rv - def dump( - self, obj: _t.Any, f: _t.IO[_t.Any], salt: _t_opt_str_bytes = None - ) -> None: + def dump(self, obj: t.Any, f: t.IO[t.Any], salt: str | bytes | None = None) -> None: """Like :meth:`dumps` but dumps into a file. The file handle has to be compatible with what the internal serializer expects. """ f.write(self.dumps(obj, salt)) def loads( - self, s: _t_str_bytes, salt: _t_opt_str_bytes = None, **kwargs: _t.Any - ) -> _t.Any: + self, s: str | bytes, salt: str | bytes | None = None, **kwargs: t.Any + ) -> t.Any: """Reverse of :meth:`dumps`. Raises :exc:`.BadSignature` if the signature validation fails. """ @@ -235,15 +233,15 @@ def loads( except BadSignature as err: last_exception = err - raise _t.cast(BadSignature, last_exception) + raise t.cast(BadSignature, last_exception) - def load(self, f: _t.IO[_t.Any], salt: _t_opt_str_bytes = None) -> _t.Any: + def load(self, f: t.IO[t.Any], salt: str | bytes | None = None) -> t.Any: """Like :meth:`loads` but loads from a file.""" return self.loads(f.read(), salt) def loads_unsafe( - self, s: _t_str_bytes, salt: _t_opt_str_bytes = None - ) -> _t_load_unsafe: + self, s: str | bytes, salt: str | bytes | None = None + ) -> tuple[bool, t.Any]: """Like :meth:`loads` but without verifying the signature. This is potentially very dangerous to use depending on how your serializer works. The return value is ``(signature_valid, @@ -261,11 +259,11 @@ def loads_unsafe( def _loads_unsafe_impl( self, - s: _t_str_bytes, - salt: _t_opt_str_bytes, - load_kwargs: _t_opt_kwargs = None, - load_payload_kwargs: _t_opt_kwargs = None, - ) -> _t_load_unsafe: + s: str | bytes, + salt: str | bytes | None, + load_kwargs: dict[str, t.Any] | None = None, + load_payload_kwargs: dict[str, t.Any] | None = None, + ) -> tuple[bool, t.Any]: """Low level helper function to implement :meth:`loads_unsafe` in serializer subclasses. """ @@ -290,8 +288,8 @@ def _loads_unsafe_impl( return False, None def load_unsafe( - self, f: _t.IO[_t.Any], salt: _t_opt_str_bytes = None - ) -> _t_load_unsafe: + self, f: t.IO[t.Any], salt: str | bytes | None = None + ) -> tuple[bool, t.Any]: """Like :meth:`loads_unsafe` but loads from a file. .. versionadded:: 0.15 diff --git a/src/itsdangerous/signer.py b/src/itsdangerous/signer.py index aa12005..6410e65 100644 --- a/src/itsdangerous/signer.py +++ b/src/itsdangerous/signer.py @@ -1,6 +1,9 @@ +from __future__ import annotations + +import collections.abc as cabc import hashlib import hmac -import typing as _t +import typing as t from .encoding import _base64_alphabet from .encoding import base64_decode @@ -8,10 +11,6 @@ from .encoding import want_bytes from .exc import BadSignature -_t_str_bytes = _t.Union[str, bytes] -_t_opt_str_bytes = _t.Optional[_t_str_bytes] -_t_secret_key = _t.Union[_t.Iterable[_t_str_bytes], _t_str_bytes] - class SigningAlgorithm: """Subclasses must implement :meth:`get_signature` to provide @@ -44,24 +43,26 @@ class HMACAlgorithm(SigningAlgorithm): #: The digest method to use with the MAC algorithm. This defaults to #: SHA1, but can be changed to any other function in the hashlib #: module. - default_digest_method: _t.Any = staticmethod(hashlib.sha1) + default_digest_method: t.Any = staticmethod(hashlib.sha1) - def __init__(self, digest_method: _t.Any = None): + def __init__(self, digest_method: t.Any = None): if digest_method is None: digest_method = self.default_digest_method - self.digest_method: _t.Any = digest_method + self.digest_method: t.Any = digest_method def get_signature(self, key: bytes, value: bytes) -> bytes: mac = hmac.new(key, msg=value, digestmod=self.digest_method) return mac.digest() -def _make_keys_list(secret_key: _t_secret_key) -> _t.List[bytes]: +def _make_keys_list( + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], +) -> list[bytes]: if isinstance(secret_key, (str, bytes)): return [want_bytes(secret_key)] - return [want_bytes(s) for s in secret_key] + return [want_bytes(s) for s in secret_key] # pyright: ignore class Signer: @@ -108,7 +109,7 @@ class Signer: #: doesn't apply when used intermediately in HMAC. #: #: .. versionadded:: 0.14 - default_digest_method: _t.Any = staticmethod(hashlib.sha1) + default_digest_method: t.Any = staticmethod(hashlib.sha1) #: The default scheme to use to derive the signing key from the #: secret key and salt. The default is ``django-concat``. Possible @@ -119,19 +120,19 @@ class Signer: def __init__( self, - secret_key: _t_secret_key, - salt: _t_opt_str_bytes = b"itsdangerous.Signer", - sep: _t_str_bytes = b".", - key_derivation: _t.Optional[str] = None, - digest_method: _t.Optional[_t.Any] = None, - algorithm: _t.Optional[SigningAlgorithm] = None, + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None = b"itsdangerous.Signer", + sep: str | bytes = b".", + key_derivation: str | None = None, + digest_method: t.Any | None = None, + algorithm: SigningAlgorithm | None = None, ): #: The list of secret keys to try for verifying signatures, from #: oldest to newest. The newest (last) key is used for signing. #: #: This allows a key rotation system to keep a list of allowed #: keys and remove expired ones. - self.secret_keys: _t.List[bytes] = _make_keys_list(secret_key) + self.secret_keys: list[bytes] = _make_keys_list(secret_key) self.sep: bytes = want_bytes(sep) if self.sep in _base64_alphabet: @@ -156,7 +157,7 @@ def __init__( if digest_method is None: digest_method = self.default_digest_method - self.digest_method: _t.Any = digest_method + self.digest_method: t.Any = digest_method if algorithm is None: algorithm = HMACAlgorithm(self.digest_method) @@ -170,7 +171,7 @@ def secret_key(self) -> bytes: """ return self.secret_keys[-1] - def derive_key(self, secret_key: _t_opt_str_bytes = None) -> bytes: + def derive_key(self, secret_key: str | bytes | None = None) -> bytes: """This method is called to derive the key. The default key derivation choices can be overridden here. Key derivation is not intended to be used as a security method to make a complex key @@ -189,9 +190,9 @@ def derive_key(self, secret_key: _t_opt_str_bytes = None) -> bytes: secret_key = want_bytes(secret_key) if self.key_derivation == "concat": - return _t.cast(bytes, self.digest_method(self.salt + secret_key).digest()) + return t.cast(bytes, self.digest_method(self.salt + secret_key).digest()) elif self.key_derivation == "django-concat": - return _t.cast( + return t.cast( bytes, self.digest_method(self.salt + b"signer" + secret_key).digest() ) elif self.key_derivation == "hmac": @@ -203,19 +204,19 @@ def derive_key(self, secret_key: _t_opt_str_bytes = None) -> bytes: else: raise TypeError("Unknown key derivation method") - def get_signature(self, value: _t_str_bytes) -> bytes: + def get_signature(self, value: str | bytes) -> bytes: """Returns the signature for the given value.""" value = want_bytes(value) key = self.derive_key() sig = self.algorithm.get_signature(key, value) return base64_encode(sig) - def sign(self, value: _t_str_bytes) -> bytes: + def sign(self, value: str | bytes) -> bytes: """Signs the given string.""" value = want_bytes(value) return value + self.sep + self.get_signature(value) - def verify_signature(self, value: _t_str_bytes, sig: _t_str_bytes) -> bool: + def verify_signature(self, value: str | bytes, sig: str | bytes) -> bool: """Verifies the signature for the given value.""" try: sig = base64_decode(sig) @@ -232,7 +233,7 @@ def verify_signature(self, value: _t_str_bytes, sig: _t_str_bytes) -> bool: return False - def unsign(self, signed_value: _t_str_bytes) -> bytes: + def unsign(self, signed_value: str | bytes) -> bytes: """Unsigns the given string.""" signed_value = want_bytes(signed_value) @@ -246,7 +247,7 @@ def unsign(self, signed_value: _t_str_bytes) -> bytes: raise BadSignature(f"Signature {sig!r} does not match", payload=value) - def validate(self, signed_value: _t_str_bytes) -> bool: + def validate(self, signed_value: str | bytes) -> bool: """Only validates the given signed value. Returns ``True`` if the signature exists and is valid. """ diff --git a/src/itsdangerous/timed.py b/src/itsdangerous/timed.py index e4d88cd..8188b97 100644 --- a/src/itsdangerous/timed.py +++ b/src/itsdangerous/timed.py @@ -1,6 +1,8 @@ +from __future__ import annotations + +import collections.abc as cabc import time -import typing -import typing as _t +import typing as t from datetime import datetime from datetime import timezone @@ -15,13 +17,6 @@ from .serializer import Serializer from .signer import Signer -_t_str_bytes = _t.Union[str, bytes] -_t_opt_str_bytes = _t.Optional[_t_str_bytes] -_t_opt_int = _t.Optional[int] - -if _t.TYPE_CHECKING: - import typing_extensions as _te - class TimestampSigner(Signer): """Works like the regular :class:`.Signer` but also records the time @@ -46,7 +41,7 @@ def timestamp_to_datetime(self, ts: int) -> datetime: """ return datetime.fromtimestamp(ts, tz=timezone.utc) - def sign(self, value: _t_str_bytes) -> bytes: + def sign(self, value: str | bytes) -> bytes: """Signs the given string and also attaches time information.""" value = want_bytes(value) timestamp = base64_encode(int_to_bytes(self.get_timestamp())) @@ -57,28 +52,28 @@ def sign(self, value: _t_str_bytes) -> bytes: # Ignore overlapping signatures check, return_timestamp is the only # parameter that affects the return type. - @typing.overload - def unsign( # type: ignore + @t.overload + def unsign( # type: ignore[overload-overlap] self, - signed_value: _t_str_bytes, - max_age: _t_opt_int = None, - return_timestamp: "_te.Literal[False]" = False, + signed_value: str | bytes, + max_age: int | None = None, + return_timestamp: t.Literal[False] = False, ) -> bytes: ... - @typing.overload + @t.overload def unsign( self, - signed_value: _t_str_bytes, - max_age: _t_opt_int = None, - return_timestamp: "_te.Literal[True]" = True, - ) -> _t.Tuple[bytes, datetime]: ... + signed_value: str | bytes, + max_age: int | None = None, + return_timestamp: t.Literal[True] = True, + ) -> tuple[bytes, datetime]: ... def unsign( self, - signed_value: _t_str_bytes, - max_age: _t_opt_int = None, + signed_value: str | bytes, + max_age: int | None = None, return_timestamp: bool = False, - ) -> _t.Union[_t.Tuple[bytes, datetime], bytes]: + ) -> tuple[bytes, datetime] | bytes: """Works like the regular :meth:`.Signer.unsign` but can also validate the time. See the base docstring of the class for the general behavior. If ``return_timestamp`` is ``True`` the @@ -110,8 +105,8 @@ def unsign( raise BadTimeSignature("timestamp missing", payload=result) value, ts_bytes = result.rsplit(sep, 1) - ts_int: _t_opt_int = None - ts_dt: _t.Optional[datetime] = None + ts_int: int | None = None + ts_dt: datetime | None = None try: ts_int = bytes_to_int(base64_decode(ts_bytes)) @@ -161,7 +156,7 @@ def unsign( return value - def validate(self, signed_value: _t_str_bytes, max_age: _t_opt_int = None) -> bool: + def validate(self, signed_value: str | bytes, max_age: int | None = None) -> bool: """Only validates the given signed value. Returns ``True`` if the signature exists and is valid.""" try: @@ -176,23 +171,23 @@ class TimedSerializer(Serializer): :class:`.Signer`. """ - default_signer: _t.Type[TimestampSigner] = TimestampSigner + default_signer: type[TimestampSigner] = TimestampSigner def iter_unsigners( - self, salt: _t_opt_str_bytes = None - ) -> _t.Iterator[TimestampSigner]: - return _t.cast("_t.Iterator[TimestampSigner]", super().iter_unsigners(salt)) + self, salt: str | bytes | None = None + ) -> cabc.Iterator[TimestampSigner]: + return t.cast("cabc.Iterator[TimestampSigner]", super().iter_unsigners(salt)) # TODO: Signature is incompatible because parameters were added # before salt. - def loads( # type: ignore + def loads( # type: ignore[override] self, - s: _t_str_bytes, - max_age: _t_opt_int = None, + s: str | bytes, + max_age: int | None = None, return_timestamp: bool = False, - salt: _t_opt_str_bytes = None, - ) -> _t.Any: + salt: str | bytes | None = None, + ) -> t.Any: """Reverse of :meth:`dumps`, raises :exc:`.BadSignature` if the signature validation fails. If a ``max_age`` is provided it will ensure the signature is not older than that time in seconds. In @@ -221,12 +216,12 @@ def loads( # type: ignore except BadSignature as err: last_exception = err - raise _t.cast(BadSignature, last_exception) + raise t.cast(BadSignature, last_exception) - def loads_unsafe( # type: ignore + def loads_unsafe( # type: ignore[override] self, - s: _t_str_bytes, - max_age: _t_opt_int = None, - salt: _t_opt_str_bytes = None, - ) -> _t.Tuple[bool, _t.Any]: + s: str | bytes, + max_age: int | None = None, + salt: str | bytes | None = None, + ) -> tuple[bool, t.Any]: return self._loads_unsafe_impl(s, salt, load_kwargs={"max_age": max_age}) diff --git a/src/itsdangerous/url_safe.py b/src/itsdangerous/url_safe.py index d5a9b0c..e33b241 100644 --- a/src/itsdangerous/url_safe.py +++ b/src/itsdangerous/url_safe.py @@ -1,4 +1,6 @@ -import typing as _t +from __future__ import annotations + +import typing as t import zlib from ._json import _CompactJSON @@ -20,10 +22,10 @@ class URLSafeSerializerMixin(Serializer): def load_payload( self, payload: bytes, - *args: _t.Any, - serializer: _t.Optional[_t.Any] = None, - **kwargs: _t.Any, - ) -> _t.Any: + *args: t.Any, + serializer: t.Any | None = None, + **kwargs: t.Any, + ) -> t.Any: decompress = False if payload.startswith(b"."): @@ -49,7 +51,7 @@ def load_payload( return super().load_payload(json, *args, **kwargs) - def dump_payload(self, obj: _t.Any) -> bytes: + def dump_payload(self, obj: t.Any) -> bytes: json = super().dump_payload(obj) is_compressed = False compressed = zlib.compress(json) diff --git a/tox.ini b/tox.ini index f7bc0b3..a385dc7 100644 --- a/tox.ini +++ b/tox.ini @@ -22,7 +22,10 @@ commands = pre-commit run --all-files [testenv:typing] deps = -r requirements/typing.txt -commands = mypy +commands = + mypy + pyright + pyright --verifytypes itsdangerous --ignoreexternal [testenv:docs] deps = -r requirements/docs.txt