Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@


class Encoder(object):
"Encode strings to bytes and decode bytes to strings"
"Encode strings to bytes-like and decode bytes-like to strings"

def __init__(self, encoding, encoding_errors, decode_responses):
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses

def encode(self, value):
"Return a bytestring representation of the value"
if isinstance(value, bytes):
"Return a bytestring or bytes-like representation of the value"
if isinstance(value, (bytes, memoryview)):
return value
elif isinstance(value, bool):
# special case bool since it is a subclass of int
Expand All @@ -124,9 +124,12 @@ def encode(self, value):
return value

def decode(self, value, force=False):
"Return a unicode string from the byte representation"
if (self.decode_responses or force) and isinstance(value, bytes):
value = value.decode(self.encoding, self.encoding_errors)
"Return a unicode string from the bytes-like representation"
if self.decode_responses or force:
if isinstance(value, memoryview):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you run into a case where you needed this? Are we ever actually decoding memoryview instances?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I ran into a case in the tests where a memoryview gets decoded. I think I got a bit spooked by lines like this and this, but on closer inspection those lines shouldn't cause a problem unless the user passes a memoryview kwarg -- and the docs instruct them to pass strings. Maybe it just felt right to make sure that decode(encode(x)) always works, but it seems that's not essential; I'm happy to remove this if you prefer.

value = value.tobytes()
if isinstance(value, bytes):
value = value.decode(self.encoding, self.encoding_errors)
return value


Expand Down Expand Up @@ -770,9 +773,10 @@ def pack_command(self, *args):
buffer_cutoff = self._buffer_cutoff
for arg in imap(self.encoder.encode, args):
# to avoid large string mallocs, chunk the command into the
# output list if we're sending large values
# output list if we're sending large values or memoryviews
arg_length = len(arg)
if len(buff) > buffer_cutoff or arg_length > buffer_cutoff:
if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff
or isinstance(arg, memoryview)):
buff = SYM_EMPTY.join(
(buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF))
output.append(buff)
Expand All @@ -795,12 +799,13 @@ def pack_commands(self, commands):
for cmd in commands:
for chunk in self.pack_command(*cmd):
chunklen = len(chunk)
if buffer_length > buffer_cutoff or chunklen > buffer_cutoff:
if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff
or isinstance(chunk, memoryview)):
output.append(SYM_EMPTY.join(pieces))
buffer_length = 0
pieces = []

if chunklen > self._buffer_cutoff:
if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
output.append(chunk)
else:
pieces.append(chunk)
Expand Down
46 changes: 45 additions & 1 deletion tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import redis

from redis._compat import unichr, unicode
from redis.connection import Connection
from .conftest import _get_client


Expand All @@ -11,13 +12,45 @@ class TestEncoding(object):
def r(self, request):
return _get_client(redis.Redis, request=request, decode_responses=True)

def test_simple_encoding(self, r):
@pytest.fixture()
def r_no_decode(self, request):
    "Client built with decode_responses=False"
    return _get_client(redis.Redis, request=request, decode_responses=False)

def test_simple_encoding(self, r_no_decode):
    "UTF-8 bytes stored via a non-decoding client are read back as bytes"
    key = 'unicode-string'
    text = unichr(3456) + 'abcd' + unichr(3421)
    r_no_decode[key] = text.encode('utf-8')
    fetched = r_no_decode[key]
    assert isinstance(fetched, bytes)
    assert fetched.decode('utf-8') == text

def test_simple_encoding_decoding(self, r):
    "Text stored via a decoding client round-trips as a unicode string"
    key = 'unicode-string'
    text = unichr(3456) + 'abcd' + unichr(3421)
    r[key] = text
    fetched = r[key]
    assert isinstance(fetched, unicode)
    assert fetched == text

def test_memoryview_encoding(self, r_no_decode):
    "A memoryview value is accepted and comes back as the bytes it wrapped"
    key = 'unicode-string-memoryview'
    text = unichr(3456) + 'abcd' + unichr(3421)
    r_no_decode[key] = memoryview(text.encode('utf-8'))
    fetched = r_no_decode[key]
    # Redis hands back a copy, so the reply is bytes rather than a memoryview
    assert isinstance(fetched, bytes)
    assert fetched.decode('utf-8') == text

def test_memoryview_encoding_decoding(self, r):
    "A memoryview value read via a decoding client comes back as text"
    key = 'unicode-string-memoryview'
    text = unichr(3456) + 'abcd' + unichr(3421)
    r[key] = memoryview(text.encode('utf-8'))
    fetched = r[key]
    assert isinstance(fetched, unicode)
    assert fetched == text

def test_list_encoding(self, r):
unicode_string = unichr(3456) + 'abcd' + unichr(3421)
result = [unicode_string, unicode_string, unicode_string]
Expand All @@ -39,6 +72,17 @@ def test_replace(self, request):
assert r.get('a') == 'foo\ufffd'


class TestMemoryviewsAreNotPacked(object):
    "memoryview arguments must appear in the packed output as the same object"

    def test_memoryviews_are_not_packed(self):
        # These checks live in a test method rather than the class body so
        # they do not run at import/collection time and so pytest reports
        # them as an ordinary test.
        c = Connection()
        arg = memoryview(b'some_arg')
        arg_list = ['SOME_COMMAND', arg]

        # Identity (not equality) is asserted: the packed output has to
        # contain the caller's memoryview itself, not a copy of its bytes.
        cmd = c.pack_command(*arg_list)
        assert cmd[1] is arg

        cmds = c.pack_commands([arg_list, arg_list])
        assert cmds[1] is arg
        assert cmds[3] is arg


class TestCommandsAreNotEncoded(object):
@pytest.fixture()
def r(self, request):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def test_pipeline(self, r):
[(b'z1', 2.0), (b'z2', 4)],
]

def test_pipeline_memoryview(self, r):
    "A memoryview value can be queued in a pipeline and read back as bytes"
    expected = [True, b'a1']
    with r.pipeline() as pipe:
        pipe.set('a', memoryview(b'a1')).get('a')
        replies = pipe.execute()
        assert replies == expected

def test_pipeline_length(self, r):
with r.pipeline() as pipe:
# Initially empty.
Expand Down