Skip to content

Commit

Permalink
Add typing hints to CaselessDict and Headers. (#6097)
Browse files Browse the repository at this point in the history
  • Loading branch information
wRAR committed Oct 17, 2023
1 parent 064256b commit 5807970
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 60 deletions.
4 changes: 2 additions & 2 deletions scrapy/core/http2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ def _get_request_headers(self) -> List[Tuple[str, str]]:

content_length_name = self._request.headers.normkey(b"Content-Length")
for name, values in self._request.headers.items():
for value in values:
value = str(value, "utf-8")
for value_bytes in values:
value = str(value_bytes, "utf-8")
if name == content_length_name:
if value != content_length:
logger.warning(
Expand Down
93 changes: 64 additions & 29 deletions scrapy/http/headers.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,73 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)

from w3lib.http import headers_dict_to_raw

from scrapy.utils.datatypes import CaseInsensitiveDict, CaselessDict
from scrapy.utils.python import to_unicode

if TYPE_CHECKING:
# typing.Self requires Python 3.11
from typing_extensions import Self


_RawValueT = Union[bytes, str, int]


# isn't fully compatible typing-wise with either dict or CaselessDict,
# but it needs refactoring anyway, see also https://github.com/scrapy/scrapy/pull/5146
class Headers(CaselessDict):
"""Case insensitive http headers dictionary"""

def __init__(self, seq=None, encoding="utf-8"):
self.encoding = encoding
def __init__(
self,
seq: Union[Mapping[AnyStr, Any], Iterable[Tuple[AnyStr, Any]], None] = None,
encoding: str = "utf-8",
):
self.encoding: str = encoding
super().__init__(seq)

def update(self, seq):
def update( # type: ignore[override]
self, seq: Union[Mapping[AnyStr, Any], Iterable[Tuple[AnyStr, Any]]]
) -> None:
seq = seq.items() if isinstance(seq, Mapping) else seq
iseq = {}
iseq: Dict[bytes, List[bytes]] = {}
for k, v in seq:
iseq.setdefault(self.normkey(k), []).extend(self.normvalue(v))
super().update(iseq)

def normkey(self, key):
def normkey(self, key: AnyStr) -> bytes: # type: ignore[override]
"""Normalize key to bytes"""
return self._tobytes(key.title())

def normvalue(self, value):
def normvalue(self, value: Union[_RawValueT, Iterable[_RawValueT]]) -> List[bytes]:
"""Normalize values to bytes"""
_value: Iterable[_RawValueT]
if value is None:
value = []
_value = []
elif isinstance(value, (str, bytes)):
value = [value]
elif not hasattr(value, "__iter__"):
value = [value]
_value = [value]
elif hasattr(value, "__iter__"):
_value = value
else:
_value = [value]

return [self._tobytes(x) for x in value]
return [self._tobytes(x) for x in _value]

def _tobytes(self, x):
def _tobytes(self, x: _RawValueT) -> bytes:
if isinstance(x, bytes):
return x
if isinstance(x, str):
Expand All @@ -44,49 +76,52 @@ def _tobytes(self, x):
return str(x).encode(self.encoding)
raise TypeError(f"Unsupported value type: {type(x)}")

def __getitem__(self, key):
def __getitem__(self, key: AnyStr) -> Optional[bytes]:
try:
return super().__getitem__(key)[-1]
return cast(List[bytes], super().__getitem__(key))[-1]
except IndexError:
return None

def get(self, key, def_val=None):
def get(self, key: AnyStr, def_val: Any = None) -> Optional[bytes]:
try:
return super().get(key, def_val)[-1]
return cast(List[bytes], super().get(key, def_val))[-1]
except IndexError:
return None

def getlist(self, key, def_val=None):
def getlist(self, key: AnyStr, def_val: Any = None) -> List[bytes]:
try:
return super().__getitem__(key)
return cast(List[bytes], super().__getitem__(key))
except KeyError:
if def_val is not None:
return self.normvalue(def_val)
return []

def setlist(self, key, list_):
def setlist(self, key: AnyStr, list_: Iterable[_RawValueT]) -> None:
self[key] = list_

def setlistdefault(self, key, default_list=()):
def setlistdefault(
self, key: AnyStr, default_list: Iterable[_RawValueT] = ()
) -> Any:
return self.setdefault(key, default_list)

def appendlist(self, key, value):
def appendlist(self, key: AnyStr, value: Iterable[_RawValueT]) -> None:
lst = self.getlist(key)
lst.extend(self.normvalue(value))
self[key] = lst

def items(self):
def items(self) -> Iterable[Tuple[bytes, List[bytes]]]: # type: ignore[override]
return ((k, self.getlist(k)) for k in self.keys())

def values(self):
def values(self) -> List[Optional[bytes]]: # type: ignore[override]
return [self[k] for k in self.keys()]

def to_string(self):
return headers_dict_to_raw(self)
def to_string(self) -> bytes:
# cast() can be removed if the headers_dict_to_raw() hint is improved
return cast(bytes, headers_dict_to_raw(self))

def to_unicode_dict(self):
"""Return headers as a CaselessDict with unicode keys
and unicode values. Multiple values are joined with ','.
def to_unicode_dict(self) -> CaseInsensitiveDict:
"""Return headers as a CaseInsensitiveDict with str keys
and str values. Multiple values are joined with ','.
"""
return CaseInsensitiveDict(
(
Expand All @@ -96,7 +131,7 @@ def to_unicode_dict(self):
for key, value in self.items()
)

def __copy__(self):
def __copy__(self) -> Self:
return self.__class__(self)

copy = __copy__
16 changes: 14 additions & 2 deletions scrapy/http/request/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@
See documentation in docs/topics/request-response.rst
"""
import inspect
from typing import Callable, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
AnyStr,
Callable,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

from w3lib.url import safe_url_string

Expand Down Expand Up @@ -77,7 +89,7 @@ def __init__(
url: str,
callback: Optional[Callable] = None,
method: str = "GET",
headers: Optional[dict] = None,
headers: Union[Mapping[AnyStr, Any], Iterable[Tuple[AnyStr, Any]], None] = None,
body: Optional[Union[bytes, str]] = None,
cookies: Optional[Union[dict, List[dict]]] = None,
meta: Optional[dict] = None,
Expand Down
4 changes: 2 additions & 2 deletions scrapy/http/response/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
See documentation in docs/topics/request-response.rst
"""
from typing import Generator, Tuple
from typing import Any, AnyStr, Generator, Iterable, Mapping, Tuple, Union
from urllib.parse import urljoin

from scrapy.exceptions import NotSupported
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
self,
url: str,
status=200,
headers=None,
headers: Union[Mapping[AnyStr, Any], Iterable[Tuple[AnyStr, Any]], None] = None,
body=b"",
flags=None,
request=None,
Expand Down
8 changes: 4 additions & 4 deletions scrapy/http/response/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple
from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, cast
from urllib.parse import urljoin

import parsel
Expand Down Expand Up @@ -102,14 +102,14 @@ def urljoin(self, url):
return urljoin(get_base_url(self), url)

@memoizemethod_noargs
def _headers_encoding(self):
content_type = self.headers.get(b"Content-Type", b"")
def _headers_encoding(self) -> Optional[str]:
content_type = cast(bytes, self.headers.get(b"Content-Type", b""))
return http_content_type_encoding(to_unicode(content_type, encoding="latin-1"))

def _body_inferred_encoding(self):
if self._cached_benc is None:
content_type = to_unicode(
self.headers.get(b"Content-Type", b""), encoding="latin-1"
cast(bytes, self.headers.get(b"Content-Type", b"")), encoding="latin-1"
)
benc, ubody = html_to_unicode(
content_type,
Expand Down
56 changes: 39 additions & 17 deletions scrapy/utils/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,40 @@
This module must not depend on any module outside the Standard Library.
"""

from __future__ import annotations

import collections
import warnings
import weakref
from collections.abc import Mapping
from typing import Any, AnyStr, Optional, OrderedDict, Sequence, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Iterable,
Optional,
OrderedDict,
Sequence,
Tuple,
TypeVar,
Union,
)

from scrapy.exceptions import ScrapyDeprecationWarning

if TYPE_CHECKING:
# typing.Self requires Python 3.11
from typing_extensions import Self


_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


class CaselessDict(dict):
__slots__ = ()

def __new__(cls, *args, **kwargs):
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
from scrapy.http.headers import Headers

if issubclass(cls, CaselessDict) and not issubclass(cls, Headers):
Expand All @@ -32,54 +50,58 @@ def __new__(cls, *args, **kwargs):
)
return super().__new__(cls, *args, **kwargs)

def __init__(self, seq=None):
def __init__(
self,
seq: Union[Mapping[AnyStr, Any], Iterable[Tuple[AnyStr, Any]], None] = None,
):
super().__init__()
if seq:
self.update(seq)

def __getitem__(self, key):
def __getitem__(self, key: AnyStr) -> Any:
return dict.__getitem__(self, self.normkey(key))

def __setitem__(self, key, value):
def __setitem__(self, key: AnyStr, value: Any) -> None:
dict.__setitem__(self, self.normkey(key), self.normvalue(value))

def __delitem__(self, key):
def __delitem__(self, key: AnyStr) -> None:
dict.__delitem__(self, self.normkey(key))

def __contains__(self, key):
def __contains__(self, key: AnyStr) -> bool: # type: ignore[override]
return dict.__contains__(self, self.normkey(key))

has_key = __contains__

def __copy__(self):
def __copy__(self) -> Self:
return self.__class__(self)

copy = __copy__

def normkey(self, key):
def normkey(self, key: AnyStr) -> AnyStr:
"""Method to normalize dictionary key access"""
return key.lower()

def normvalue(self, value):
def normvalue(self, value: Any) -> Any:
"""Method to normalize values prior to be set"""
return value

def get(self, key, def_val=None):
def get(self, key: AnyStr, def_val: Any = None) -> Any:
return dict.get(self, self.normkey(key), self.normvalue(def_val))

def setdefault(self, key, def_val=None):
return dict.setdefault(self, self.normkey(key), self.normvalue(def_val))
def setdefault(self, key: AnyStr, def_val: Any = None) -> Any:
return dict.setdefault(self, self.normkey(key), self.normvalue(def_val)) # type: ignore[arg-type]

def update(self, seq):
# doesn't fully implement MutableMapping.update()
def update(self, seq: Union[Mapping[AnyStr, Any], Iterable[Tuple[AnyStr, Any]]]) -> None: # type: ignore[override]
seq = seq.items() if isinstance(seq, Mapping) else seq
iseq = ((self.normkey(k), self.normvalue(v)) for k, v in seq)
super().update(iseq)

@classmethod
def fromkeys(cls, keys, value=None):
return cls((k, value) for k in keys)
def fromkeys(cls, keys: Iterable[AnyStr], value: Any = None) -> Self: # type: ignore[override]
return cls((k, value) for k in keys) # type: ignore[misc]

def pop(self, key, *args):
def pop(self, key: AnyStr, *args: Any) -> Any:
return dict.pop(self, self.normkey(key), *args)


Expand Down

0 comments on commit 5807970

Please sign in to comment.