Skip to content

Commit

Permalink
Store canonical values internally.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Mar 15, 2024
1 parent 26645c4 commit d837488
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/twisted/web/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,7 +1530,7 @@ def _handleResponse(self, response):
return response


_canonicalHeaderName = Headers()._canonicalNameCaps
_canonicalHeaderName = Headers()._encodeName
_defaultSensitiveHeaders = frozenset(
[
b"Authorization",
Expand Down
73 changes: 37 additions & 36 deletions src/twisted/web/http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ class Headers:
and values as opaque byte strings.
@cvar _caseMappings: A L{dict} that maps lowercase header names
to their canonicalized representation. Additional entries may be added,
as it is used as a cache.
to their canonicalized representation, for headers with unconventional
capitalization.
@cvar _canonicalHeaderCache: A L{dict} that maps header names to their
canonicalized representation.
@ivar _rawHeaders: A L{dict} mapping header names as L{bytes} to L{list}s of
header values as L{bytes}.
Expand All @@ -71,6 +74,8 @@ class Headers:
b"x-xss-protection": b"X-XSS-Protection",
}

_canonicalHeaderCache: Dict[Union[bytes, str], bytes] = {}

__slots__ = ["_rawHeaders"]

def __init__(
Expand Down Expand Up @@ -104,16 +109,39 @@ def __cmp__(self, other):

def _encodeName(self, name: Union[str, bytes]) -> bytes:
"""
Encode the name of a header (eg 'Content-Type') to an ISO-8859-1 encoded
bytestring if required.
Encode the name of a header (eg 'Content-Type') to an ISO-8859-1
encoded bytestring if required. It will be canonicalized and
whitespace-sanitized.
@param name: A HTTP header name
@return: C{name}, encoded if required, lowercased
"""
if canonicalName := self._canonicalHeaderCache.get(name, None):
return canonicalName

if isinstance(name, str):
return name.lower().encode("iso-8859-1")
return name.lower()
bytes_name = name.encode("iso-8859-1")
else:
bytes_name = name

if bytes_name.lower() in self._caseMappings:
# Some headers have special capitalization:
result = self._caseMappings[bytes_name.lower()]
else:
result = _sanitizeLinearWhitespace(
b"-".join([word.capitalize() for word in bytes_name.split(b"-")])
)

# In general, we should only see a very small number of header
# variations in the real world, so caching them is fine. However, an
# attacker could generate infinite header variations to fill up RAM, so
# we cap how many we cache. The performance degradation from lack of
# caching won't be that bad, and legit traffic won't hit it.
if len(self._canonicalHeaderCache) < 10_000:
self._canonicalHeaderCache[name] = result

return result

def copy(self):
"""
Expand Down Expand Up @@ -171,7 +199,7 @@ def setRawHeaders(self, name: Union[str, bytes], values: object) -> None:
@return: L{None}
"""
_name = _sanitizeLinearWhitespace(self._encodeName(name))
_name = self._encodeName(name)
encodedValues: List[bytes] = []
for v in values:
if isinstance(v, str):
Expand All @@ -190,9 +218,7 @@ def addRawHeader(self, name: Union[str, bytes], value: Union[str, bytes]) -> Non
@param value: The value to set for the named header.
"""
self._rawHeaders.setdefault(
_sanitizeLinearWhitespace(self._encodeName(name)), []
).append(
self._rawHeaders.setdefault(self._encodeName(name), []).append(
_sanitizeLinearWhitespace(
value.encode("utf8") if isinstance(value, str) else value
)
Expand Down Expand Up @@ -236,32 +262,7 @@ def getAllRawHeaders(self) -> Iterator[Tuple[bytes, Sequence[bytes]]]:
object, as L{bytes}. The keys are capitalized in canonical
capitalization.
"""
for k, v in self._rawHeaders.items():
yield self._canonicalNameCaps(k), v

def _canonicalNameCaps(self, name: bytes) -> bytes:
"""
Return the canonical name for the given header.
@param name: The all-lowercase header name to capitalize in its
canonical form.
@return: The canonical name of the header.
"""
if canonicalName := self._caseMappings.get(name, None):
return canonicalName

result = b"-".join([word.capitalize() for word in name.split(b"-")])

# In general, we should only see a very small number of header
# variations in the real world, so caching them is fine. However, an
# attacker could generate infinite header variations to fill up RAM, so
# we cap how many we cache. The performance degradation from lack of
# caching won't be that bad, and legit traffic won't hit it.
if len(self._caseMappings) < 10_000:
self._caseMappings[name] = result

return result
return iter(self._rawHeaders.items())


__all__ = ["Headers"]
34 changes: 18 additions & 16 deletions src/twisted/web/test/test_http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,21 +176,23 @@ def test_removeHeaderDoesntExist(self) -> None:
h.removeHeader(b"test")
self.assertEqual(list(h.getAllRawHeaders()), [])

def test_canonicalNameCaps(self) -> None:
def test_encodeName(self) -> None:
"""
L{Headers._canonicalNameCaps} returns the canonical capitalization for
L{Headers._encodeName} returns the canonical capitalization for
the given header.
"""
h = Headers()
self.assertEqual(h._canonicalNameCaps(b"test"), b"Test")
self.assertEqual(h._canonicalNameCaps(b"test-stuff"), b"Test-Stuff")
self.assertEqual(h._canonicalNameCaps(b"content-md5"), b"Content-MD5")
self.assertEqual(h._canonicalNameCaps(b"dnt"), b"DNT")
self.assertEqual(h._canonicalNameCaps(b"etag"), b"ETag")
self.assertEqual(h._canonicalNameCaps(b"p3p"), b"P3P")
self.assertEqual(h._canonicalNameCaps(b"te"), b"TE")
self.assertEqual(h._canonicalNameCaps(b"www-authenticate"), b"WWW-Authenticate")
self.assertEqual(h._canonicalNameCaps(b"x-xss-protection"), b"X-XSS-Protection")
self.assertEqual(h._encodeName(b"test"), b"Test")
self.assertEqual(h._encodeName(b"test-stuff"), b"Test-Stuff")
self.assertEqual(h._encodeName(b"content-md5"), b"Content-MD5")
self.assertEqual(h._encodeName(b"dnt"), b"DNT")
self.assertEqual(h._encodeName(b"etag"), b"ETag")
self.assertEqual(h._encodeName(b"p3p"), b"P3P")
self.assertEqual(h._encodeName(b"te"), b"TE")
self.assertEqual(h._encodeName(b"www-authenticate"), b"WWW-Authenticate")
self.assertEqual(h._encodeName(b"WWW-authenticate"), b"WWW-Authenticate")
self.assertEqual(h._encodeName(b"Www-Authenticate"), b"WWW-Authenticate")
self.assertEqual(h._encodeName(b"x-xss-protection"), b"X-XSS-Protection")

def test_getAllRawHeaders(self) -> None:
"""
Expand Down Expand Up @@ -244,7 +246,7 @@ def test_repr(self) -> None:
baz = b"baz"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
f"Headers({{{foo!r}: [{bar!r}, {baz!r}]}})",
f"Headers({{{foo.capitalize()!r}: [{bar!r}, {baz!r}]}})",
)

def test_reprWithRawBytes(self) -> None:
Expand All @@ -261,7 +263,7 @@ def test_reprWithRawBytes(self) -> None:
baz = b"baz\xe1"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
f"Headers({{{foo!r}: [{bar!r}, {baz!r}]}})",
f"Headers({{{foo.capitalize()!r}: [{bar!r}, {baz!r}]}})",
)

def test_subclassRepr(self) -> None:
Expand All @@ -278,7 +280,7 @@ class FunnyHeaders(Headers):

self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
f"FunnyHeaders({{{foo!r}: [{bar!r}, {baz!r}]}})",
f"FunnyHeaders({{{foo.capitalize()!r}: [{bar!r}, {baz!r}]}})",
)

def test_copy(self) -> None:
Expand Down Expand Up @@ -551,7 +553,7 @@ def test_repr(self) -> None:
foo = "foo\u00E1"
bar = "bar\u2603"
baz = "baz"
fooEncoded = "'foo\\xe1'"
fooEncoded = "'Foo\\xe1'"
barEncoded = "'bar\\xe2\\x98\\x83'"
fooEncoded = "b" + fooEncoded
barEncoded = "b" + barEncoded
Expand All @@ -570,7 +572,7 @@ def test_subclassRepr(self) -> None:
foo = "foo\u00E1"
bar = "bar\u2603"
baz = "baz"
fooEncoded = "b'foo\\xe1'"
fooEncoded = "b'Foo\\xe1'"
barEncoded = "b'bar\\xe2\\x98\\x83'"

class FunnyHeaders(Headers):
Expand Down

0 comments on commit d837488

Please sign in to comment.