diff --git a/redis/connection.py b/redis/connection.py index 2323365210..c13c7fc025 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -94,7 +94,7 @@ class Encoder(object): - "Encode strings to bytes and decode bytes to strings" + "Encode strings to bytes-like and decode bytes-like to strings" def __init__(self, encoding, encoding_errors, decode_responses): self.encoding = encoding @@ -102,8 +102,8 @@ def __init__(self, encoding, encoding_errors, decode_responses): self.decode_responses = decode_responses def encode(self, value): - "Return a bytestring representation of the value" - if isinstance(value, bytes): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): return value elif isinstance(value, bool): # special case bool since it is a subclass of int @@ -124,9 +124,12 @@ def encode(self, value): return value def decode(self, value, force=False): - "Return a unicode string from the byte representation" - if (self.decode_responses or force) and isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) return value @@ -770,9 +773,10 @@ def pack_command(self, *args): buffer_cutoff = self._buffer_cutoff for arg in imap(self.encoder.encode, args): # to avoid large string mallocs, chunk the command into the - # output list if we're sending large values + # output list if we're sending large values or memoryviews arg_length = len(arg) - if len(buff) > buffer_cutoff or arg_length > buffer_cutoff: + if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff + or isinstance(arg, memoryview)): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)) output.append(buff) @@ -795,12 +799,13 @@ def pack_commands(self, commands): for cmd in commands: for chunk in self.pack_command(*cmd): chunklen = len(chunk) - if buffer_length > buffer_cutoff or chunklen > buffer_cutoff: + if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff + or isinstance(chunk, memoryview)): output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] - if chunklen > self._buffer_cutoff: + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): output.append(chunk) else: pieces.append(chunk) diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 3f43006447..9a687c5380 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -3,6 +3,7 @@ import redis from redis._compat import unichr, unicode +from redis.connection import Connection from .conftest import _get_client @@ -11,13 +12,45 @@ class TestEncoding(object): def r(self, request): return _get_client(redis.Redis, request=request, decode_responses=True) - def test_simple_encoding(self, r): + @pytest.fixture() + def r_no_decode(self, request): + return _get_client( + redis.Redis, + request=request, + decode_responses=False, + ) + + def test_simple_encoding(self, r_no_decode): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + r_no_decode['unicode-string'] = unicode_string.encode('utf-8') + cached_val = r_no_decode['unicode-string'] + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode('utf-8') + + def test_simple_encoding_decoding(self, r): unicode_string = unichr(3456) + 'abcd' + unichr(3421) r['unicode-string'] = unicode_string cached_val = r['unicode-string'] assert isinstance(cached_val, unicode) assert unicode_string == cached_val + def test_memoryview_encoding(self, r_no_decode): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + unicode_string_view = memoryview(unicode_string.encode('utf-8')) + r_no_decode['unicode-string-memoryview'] = unicode_string_view + cached_val = r_no_decode['unicode-string-memoryview'] + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode('utf-8') + + def test_memoryview_encoding_decoding(self, r): + unicode_string = unichr(3456) + 'abcd' + unichr(3421) + unicode_string_view = memoryview(unicode_string.encode('utf-8')) + r['unicode-string-memoryview'] = unicode_string_view + cached_val = r['unicode-string-memoryview'] + assert isinstance(cached_val, unicode) + assert unicode_string == cached_val + def test_list_encoding(self, r): unicode_string = unichr(3456) + 'abcd' + unichr(3421) result = [unicode_string, unicode_string, unicode_string] @@ -39,6 +72,17 @@ def test_replace(self, request): assert r.get('a') == 'foo\ufffd' +class TestMemoryviewsAreNotPacked(object): + c = Connection() + arg = memoryview(b'some_arg') + arg_list = ['SOME_COMMAND', arg] + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + class TestCommandsAreNotEncoded(object): @pytest.fixture() def r(self, request): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 867666b523..828b9898e4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -29,6 +29,16 @@ def test_pipeline(self, r): [(b'z1', 2.0), (b'z2', 4)], ] + def test_pipeline_memoryview(self, r): + with r.pipeline() as pipe: + (pipe.set('a', memoryview(b'a1')) + .get('a')) + assert pipe.execute() == \ + [ + True, + b'a1', + ] + def test_pipeline_length(self, r): with r.pipeline() as pipe: # Initially empty.