Skip to content

Commit

Permalink
type Serializer as generic
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism committed Apr 13, 2024
1 parent bc88e94 commit 01001c6
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
59 changes: 50 additions & 9 deletions src/itsdangerous/serializer.py
Expand Up @@ -10,13 +10,20 @@
from .signer import _make_keys_list
from .signer import Signer

_TAnyStr = t.TypeVar("_TAnyStr", str, bytes, covariant=True)

def is_text_serializer(serializer: t.Any) -> bool:

class _PDataSerializer(t.Protocol[_TAnyStr]):
def loads(self, payload: str | bytes) -> t.Any: ...
def dumps(self, obj: t.Any, **kwargs: t.Any) -> _TAnyStr: ...


def is_text_serializer(serializer: _PDataSerializer[t.Any]) -> bool:
"""Checks whether a serializer generates text or binary."""
return isinstance(serializer.dumps({}), str)


class Serializer:
class Serializer(t.Generic[_TAnyStr]):
"""A serializer wraps a :class:`~itsdangerous.signer.Signer` to
enable serializing and securely signing data other than bytes. It
can unsign to verify that the data hasn't been changed.
Expand Down Expand Up @@ -71,7 +78,7 @@ class Serializer:
#: The default serialization module to use to serialize data to a
#: string internally. The default is :mod:`json`, but can be changed
#: to any object that provides ``dumps`` and ``loads`` methods.
default_serializer: t.Any = json
default_serializer: _PDataSerializer[_TAnyStr] = json # type: ignore[assignment]

#: The default ``Signer`` class to instantiate when signing data.
#: The default is :class:`itsdangerous.signer.Signer`.
Expand All @@ -82,11 +89,43 @@ class Serializer:
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
] = []

# Tell type checkers that the default type is Serializer[str] if no
# data serializer is provided.
@t.overload
def __init__(
self: Serializer[str],
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: None = None,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
]
| None = None,
): ...

@t.overload
def __init__(
self: Serializer[_TAnyStr],
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: _PDataSerializer[_TAnyStr] = ...,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
fallback_signers: list[
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
]
| None = None,
): ...

def __init__(
self,
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
salt: str | bytes | None = b"itsdangerous",
serializer: t.Any = None,
serializer: _PDataSerializer[_TAnyStr] | None = None,
serializer_kwargs: dict[str, t.Any] | None = None,
signer: type[Signer] | None = None,
signer_kwargs: dict[str, t.Any] | None = None,
Expand All @@ -111,7 +150,7 @@ def __init__(
if serializer is None:
serializer = self.default_serializer

self.serializer: t.Any = serializer
self.serializer: _PDataSerializer[_TAnyStr] = serializer
self.is_text_serializer: bool = is_text_serializer(serializer)

if signer is None:
Expand All @@ -135,7 +174,9 @@ def secret_key(self) -> bytes:
"""
return self.secret_keys[-1]

def load_payload(self, payload: bytes, serializer: t.Any | None = None) -> t.Any:
def load_payload(
self, payload: bytes, serializer: _PDataSerializer[_TAnyStr] | None = None
) -> t.Any:
"""Loads the encoded object. This function raises
:class:`.BadPayload` if the payload is not valid. The
``serializer`` parameter can be used to override the serializer
Expand Down Expand Up @@ -199,7 +240,7 @@ def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signe
for secret_key in self.secret_keys:
yield fallback(secret_key, salt=salt, **kwargs)

def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> str | bytes:
def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> _TAnyStr:
"""Returns a signed string serialized with the internal
serializer. The return value can be either a byte or unicode
string depending on the format of the internal serializer.
Expand All @@ -208,9 +249,9 @@ def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> str | bytes:
rv = self.make_signer(salt).sign(payload)

if self.is_text_serializer:
return rv.decode("utf-8")
return rv.decode("utf-8") # type: ignore[return-value]

return rv
return rv # type: ignore[return-value]

def dump(self, obj: t.Any, f: t.IO[t.Any], salt: str | bytes | None = None) -> None:
"""Like :meth:`dumps` but dumps into a file. The file handle has
Expand Down
4 changes: 3 additions & 1 deletion src/itsdangerous/timed.py
Expand Up @@ -17,6 +17,8 @@
from .serializer import Serializer
from .signer import Signer

_TAnyStr = t.TypeVar("_TAnyStr", str, bytes, covariant=True)


class TimestampSigner(Signer):
"""Works like the regular :class:`.Signer` but also records the time
Expand Down Expand Up @@ -166,7 +168,7 @@ def validate(self, signed_value: str | bytes, max_age: int | None = None) -> boo
return False


class TimedSerializer(Serializer):
class TimedSerializer(Serializer[_TAnyStr]):
"""Uses :class:`TimestampSigner` instead of the default
:class:`.Signer`.
"""
Expand Down
9 changes: 5 additions & 4 deletions src/itsdangerous/url_safe.py
Expand Up @@ -7,17 +7,18 @@
from .encoding import base64_decode
from .encoding import base64_encode
from .exc import BadPayload
from .serializer import _PDataSerializer
from .serializer import Serializer
from .timed import TimedSerializer


class URLSafeSerializerMixin(Serializer):
class URLSafeSerializerMixin(Serializer[str]):
"""Mixed in with a regular serializer it will attempt to zlib
compress the string to make it shorter if necessary. It will also
base64 encode the string so that it can safely be placed in a URL.
"""

default_serializer = _CompactJSON
default_serializer: _PDataSerializer[str] = _CompactJSON

def load_payload(
self,
Expand Down Expand Up @@ -68,14 +69,14 @@ def dump_payload(self, obj: t.Any) -> bytes:
return base64d


class URLSafeSerializer(URLSafeSerializerMixin, Serializer):
class URLSafeSerializer(URLSafeSerializerMixin, Serializer[str]):
"""Works like :class:`.Serializer` but dumps and loads into a URL
safe string consisting of the upper and lowercase character of the
alphabet as well as ``'_'``, ``'-'`` and ``'.'``.
"""


class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer):
class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer[str]):
"""Works like :class:`.TimedSerializer` but dumps and loads into a
URL safe string consisting of the upper and lowercase character of
the alphabet as well as ``'_'``, ``'-'`` and ``'.'``.
Expand Down

0 comments on commit 01001c6

Please sign in to comment.