Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WiP] Rewrite zstd decoder to use an API that supports multiple frames (fix issue #3008) #3021

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/urllib3/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,37 @@ def flush(self) -> bytes:

class ZstdDecoder(ContentDecoder):
    """Incremental decoder for ``Content-Encoding: zstd`` response bodies.

    A body may consist of several Zstandard frames concatenated back to
    back.  A single ``decompressobj()`` decodes exactly one frame, so this
    decoder starts a fresh decompression object whenever the current frame
    ends and feeds it whatever input is left over.  Only the bytes of the
    chunk currently being processed are held; no compressed input is
    accumulated across calls.
    """

    def __init__(self) -> None:
        self._obj = zstd.ZstdDecompressor().decompressobj()

    def decompress(self, data: bytes) -> bytes:
        """Decode one chunk of compressed input.

        :param data: The next chunk of the compressed body; may end in the
            middle of a frame or span several frames.
        :returns: All bytes that could be decoded from ``data`` (possibly
            ``b""`` when the chunk did not complete any output).
        """
        if not data:
            return b""

        parts = []
        while data:
            if self._obj.eof:
                # The previous frame is finished (possibly exactly at the end
                # of the previous chunk).  A decompression object is
                # single-use, so the next frame needs a new one.
                self._obj = zstd.ZstdDecompressor().decompressobj()
            parts.append(self._obj.decompress(data))
            # Once a frame ends, the bytes that follow it are exposed as
            # ``unused_data``; loop so they are decoded as the next frame.
            # If the frame is still open, everything has been consumed.
            data = self._obj.unused_data if self._obj.eof else b""
        return b"".join(parts)

    def flush(self) -> bytes:
        """Finish decoding and verify that the body ended on a frame boundary.

        :returns: Any output still held by the decompression object.
        :raises DecodeError: If the input stopped in the middle of a frame.
        """
        ret = self._obj.flush()
        if not self._obj.eof:
            raise DecodeError("Zstandard data is incomplete")
        return ret  # type: ignore[no-any-return]


class MultiDecoder(ContentDecoder):
Expand Down Expand Up @@ -439,15 +458,16 @@ def _decode(
if self._decoder:
data = self._decoder.decompress(data)
self._has_decoded_content = True

if flush_decoder:
data += self._flush_decoder()
except self.DECODER_ERROR_CLASSES as e:
content_encoding = self.headers.get("content-encoding", "").lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding,
e,
) from e
if flush_decoder:
data += self._flush_decoder()

return data

Expand Down
12 changes: 12 additions & 0 deletions test/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import http.client as httplib
import os
import socket
import ssl
import sys
Expand Down Expand Up @@ -332,6 +333,17 @@ def test_decode_zstd(self) -> None:
r = HTTPResponse(fp, headers={"content-encoding": "zstd"})
assert r.data == b"foo"

@onlyZstd()
def test_decode_zstd_multiple_frames(self) -> None:
    """A body made of several concatenated zstd frames is decoded in full."""
    # Build the multi-frame payload in-process instead of reading a binary
    # fixture: each ``zstd.compress`` call emits one complete frame, and
    # joining them reproduces the "multiple frames in one body" case.
    frames = [b"foo", b"bar", b"baz"]
    data = b"".join(zstd.compress(frame) for frame in frames)

    fp = BytesIO(data)
    r = HTTPResponse(fp, headers={"content-encoding": "zstd"})
    # Compare the exact content so a decoder that stops after the first
    # frame (or drops/duplicates one) is caught, not just a short length.
    assert r.data == b"".join(frames)

@onlyZstd()
def test_chunked_decoding_zstd(self) -> None:
data = zstd.compress(b"foobarbaz")
Expand Down
Binary file added test/text.txt.zstd
Binary file not shown.