Skip to content

Commit

Permalink
A better API
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed May 2, 2024
1 parent 239fedd commit cdd5e3d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
15 changes: 9 additions & 6 deletions src/twisted/web/_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def flowControlBlocked(self):
self._producerProducing = False

# Methods called by the consumer (usually an IRequest).
def writeHeaders(self, version, code, reason, headers):
def writeHeadersObject(self, version, code, reason, headers):
"""
Called by the consumer to write headers to the stream.
Expand All @@ -1073,12 +1073,15 @@ def writeHeaders(self, version, code, reason, headers):
@type reason: L{bytes}
@param headers: The HTTP response headers.
@type headers: Any iterable of two-tuples of L{bytes}, representing header
names and header values.
@type headers: L{twisted.web.http_headers.Headers}
"""
self._conn.writeHeaders(version, code, reason, headers, self.streamID)

writeHeadersPresanitized = writeHeaders
self._conn.writeHeaders(
version,
code,
reason,
[(k, v) for (k, values) in headers.getAllRawHeaders() for v in values],
self.streamID,
)

def requestDone(self, request):
"""
Expand Down
30 changes: 14 additions & 16 deletions src/twisted/web/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,6 @@ def write(self, data):
version = self.clientproto
code = b"%d" % (self.code,)
reason = self.code_message
headers = []

# if we don't have a content length, we send data in
# chunked mode, so that we can support pipelining in
Expand All @@ -1211,7 +1210,7 @@ def write(self, data):
and self.method != b"HEAD"
and self.code not in NO_BODY_CODES
):
headers.append((b"Transfer-Encoding", b"chunked"))
self.responseHeaders.setRawHeaders("Transfer-Encoding", [b"chunked"])
self.chunked = 1

if self.lastModified is not None:
Expand All @@ -1231,11 +1230,7 @@ def write(self, data):
if self.cookies:
self.responseHeaders.setRawHeaders(b"Set-Cookie", self.cookies)

for name, values in self.responseHeaders.getAllRawHeaders():
for value in values:
headers.append((name, value))

self.channel.writeHeadersPresanitized(version, code, reason, headers)
self.channel.writeHeadersObject(version, code, reason, self.responseHeaders)

# if this is a "HEAD" request, we shouldn't return any data
if self.method == b"HEAD":
Expand Down Expand Up @@ -2667,11 +2662,14 @@ def writeHeaders(self, version, code, reason, headers):
headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)

def writeHeadersPresanitized(self, version, code, reason, headers):
def writeHeadersObject(self, version, code, reason, headers):
"""
Called by L{Request} objects to write a complete set of HTTP
headers to a transport that are already trusted to be sanitized and not
subject to injection attacks.
headers to a transport. Because they're given as a C{Headers} instance
we can make sure we're not subject to injection attacks.
This is faster than C{writeHeaders} if you already have a C{Headers}
instance.
@param version: The HTTP version in use.
@type version: L{bytes}
Expand All @@ -2682,14 +2680,14 @@ def writeHeadersPresanitized(self, version, code, reason, headers):
@param reason: The HTTP reason phrase to write.
@type reason: L{bytes}
@param headers: The headers to write to the transport, presumed to
already have been sanitized and deduplicated.
@type headers: Any iterable of two-tuples of L{bytes}, representing header
names and header values.
@param headers: The headers to write to the transport.
@type headers: L{twisted.web.http_headers.Headers}
"""
headerSequence = [version, b" ", code, b" ", reason, b"\r\n"]
for name, value in headers:
headerSequence.extend((name, b": ", value, b"\r\n"))
for name, values in headers.getAllRawHeaders():
for value in values:
headerSequence.extend((name, b": ", value, b"\r\n"))
headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)

Expand Down
8 changes: 7 additions & 1 deletion src/twisted/web/test/requesthelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ def writeHeaders(self, version, code, reason, headers):
headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)

writeHeadersPresanitized = writeHeaders
def writeHeadersObject(self, version, code, reason, headers):
self.writeHeaders(
version,
code,
reason,
[(k, v) for (k, values) in headers.getAllRawHeaders() for v in values],
)

def getPeer(self):
return self.transport.getPeer()
Expand Down

0 comments on commit cdd5e3d

Please sign in to comment.