diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 24d6b9d6..530f4c13 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -723,6 +723,41 @@ def make_client(self, mock_socket_values, **kwargs): client.sock = MockSocket(list(mock_socket_values)) return client + def test_get_found(self): + client = self.make_client([ + b'STORED\r\n', + b'VALUE key 0 5\r\nvalue\r\nEND\r\n', + ]) + result = client.set(b'key', b'value', noreply=False) + result = client.get(b'key') + assert result == b'value' + + def test_deserialization(self): + def _serializer(key, value): + if isinstance(value, dict): + return json.dumps(value).encode('UTF-8'), 1 + return value, 0 + + def _deserializer(key, value, flags): + if flags == 1: + return json.loads(value.decode('UTF-8')) + return value + + client = self.make_client([ + b'STORED\r\n', + b'VALUE key1 0 5\r\nhello\r\nEND\r\n', + b'STORED\r\n', + b'VALUE key2 0 18\r\n{"hello": "world"}\r\nEND\r\n', + ], serializer=_serializer, deserializer=_deserializer) + + result = client.set(b'key1', b'hello', noreply=False) + result = client.get(b'key1') + assert result == b'hello' + + result = client.set(b'key2', dict(hello='world'), noreply=False) + result = client.get(b'key2') + assert result == dict(hello='world') + class TestPrefixedClient(ClientTestMixin, unittest.TestCase): def make_client(self, mock_socket_values, **kwargs): diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index 997a58e9..f2c5a80b 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -47,13 +47,13 @@ def get(self, key, default=None): if key not in self._contents: return default - expire, value, was_serialized = self._contents[key] + expire, value, flags = self._contents[key] if expire and expire < time.time(): del self._contents[key] return default if self.deserializer: - return self.deserializer(key, value, 2 if was_serialized else 1) + return self.deserializer(key, value, flags) return value def get_many(self, keys): @@ -72,14 +72,14 @@ def set(self, key, value, expire=0, noreply=True): if isinstance(value, six.text_type): raise MemcacheIllegalInputError(value) - was_serialized = False + flags = 0 if self.serializer: - value = self.serializer(key, value) + value, flags = self.serializer(key, value) if expire: expire += time.time() - self._contents[key] = expire, value, was_serialized + self._contents[key] = expire, value, flags return True def set_many(self, values, expire=None, noreply=True):