Skip to content
This repository has been archived by the owner on Sep 11, 2019. It is now read-only.

Improve emulation of redis -> Lua returns #13

Merged
merged 3 commits into from
Feb 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ def _patch_responses(obj):
setattr(obj, attr_name, func)


def _lua_bool_ok(lua_runtime, value):
# Inverse of bool_ok wrapper from redis-py
return lua_runtime.table(ok='OK')


def _lua_reply(converter):
def decorator(func):
func._lua_reply = converter
return func

return decorator


def _remove_empty(func):
@functools.wraps(func)
def wrapper(self, key, *args, **kwargs):
Expand Down Expand Up @@ -281,15 +294,18 @@ def __init__(self, db=0, charset='utf-8', errors='strict',
if decode_responses:
_patch_responses(self)

@_lua_reply(_lua_bool_ok)
def flushdb(self):
DATABASES[self._db_num].clear()
return True

@_lua_reply(_lua_bool_ok)
def flushall(self):
for db in DATABASES:
DATABASES[db].clear()

del self._pubsubs[:]
return True

def _remove_if_empty(self, key):
try:
Expand Down Expand Up @@ -460,6 +476,7 @@ def mget(self, keys, *args):
found.append(value)
return found

@_lua_reply(_lua_bool_ok)
def mset(self, *args, **kwargs):
if args:
if len(args) != 1 or not isinstance(args[0], dict):
Expand Down Expand Up @@ -493,6 +510,7 @@ def ping(self):
def randomkey(self):
pass

@_lua_reply(_lua_bool_ok)
def rename(self, src, dst):
try:
value = self._db[src]
Expand Down Expand Up @@ -636,9 +654,11 @@ def type(self, name):
assert key is None
return b'none'

@_lua_reply(_lua_bool_ok)
def watch(self, *names):
pass

@_lua_reply(_lua_bool_ok)
def unwatch(self):
pass

Expand Down Expand Up @@ -754,14 +774,33 @@ def eval(self, script, numkeys, *keys_and_args):

return self._convert_lua_result(result, nested=False)

def _convert_redis_result(self, result):
def _convert_redis_result(self, lua_runtime, result):
if isinstance(result, dict):
return [
i
for item in result.items()
for i in item
]
return result
elif isinstance(result, set):
converted = sorted(
self._convert_redis_result(lua_runtime, item)
for item in result
)
return lua_runtime.table_from(converted)
elif isinstance(result, (list, set, tuple)):
converted = [
self._convert_redis_result(lua_runtime, item)
for item in result
]
return lua_runtime.table_from(converted)
elif isinstance(result, bool):
return int(result)
elif isinstance(result, float):
return to_bytes(result)
elif result is None:
return False
else:
return result

def _convert_lua_result(self, result, nested=True):
from lupa import lua_type
Expand Down Expand Up @@ -842,7 +881,9 @@ def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
'incrby': FakeStrictRedis.incr
}
func = special_cases[op] if op in special_cases else getattr(FakeStrictRedis, op)
return self._convert_redis_result(func(self, *args))
result = func(self, *args)
converter = getattr(func, '_lua_reply', self._convert_redis_result)
return converter(lua_runtime, result)

def _retrieve_data_from_sort(self, data, get):
if get is not None:
Expand Down Expand Up @@ -965,6 +1006,7 @@ def lpop(self, name):
except IndexError:
return None

@_lua_reply(_lua_bool_ok)
def lset(self, name, index, value):
try:
lst = self._get_list_or_none(name)
Expand All @@ -973,10 +1015,12 @@ def lset(self, name, index, value):
lst[index] = to_bytes(value)
except IndexError:
raise redis.ResponseError("index out of range")
return True

def rpushx(self, name, value):
self._get_list(name).append(to_bytes(value))

@_lua_reply(_lua_bool_ok)
def ltrim(self, name, start, end):
val = self._get_list_or_none(name)
if val is not None:
Expand Down Expand Up @@ -1882,6 +1926,7 @@ def pfcount(self, *sources):
"""
return len(self.sunion(*sources))

@_lua_reply(_lua_bool_ok)
def pfmerge(self, dest, *sources):
"Merge N different HyperLogLogs into a single one."
self.sunionstore(dest, sources)
Expand Down
123 changes: 102 additions & 21 deletions test_fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@ def test_rpush_then_lrange_with_nested_list1(self):
self.assertEqual(self.redis.lrange(
'foo', 0, -1), ['[12345L, 6789L]', '[54321L, 9876L]'] if PY2 else
[b'[12345, 6789]', b'[54321, 9876]'])
self.redis.flushall()

def test_rpush_then_lrange_with_nested_list2(self):
self.assertEqual(self.redis.rpush('foo', [long(12345), 'banana']), 1)
Expand All @@ -597,7 +596,6 @@ def test_rpush_then_lrange_with_nested_list2(self):
'foo', 0, -1),
['[12345L, \'banana\']', '[54321L, \'elephant\']'] if PY2 else
[b'[12345, \'banana\']', b'[54321, \'elephant\']'])
self.redis.flushall()

def test_rpush_then_lrange_with_nested_list3(self):
self.assertEqual(self.redis.rpush('foo', [long(12345), []]), 1)
Expand All @@ -606,7 +604,6 @@ def test_rpush_then_lrange_with_nested_list3(self):
self.assertEqual(self.redis.lrange(
'foo', 0, -1), ['[12345L, []]', '[54321L, []]'] if PY2 else
[b'[12345, []]', b'[54321, []]'])
self.redis.flushall()

def test_lpush_then_lrange_all(self):
self.assertEqual(self.redis.lpush('foo', 'bar'), 1)
Expand Down Expand Up @@ -2309,7 +2306,7 @@ def test_multidb(self):
self.assertEqual(r1['r1'], b'r1')
self.assertEqual(r2['r2'], b'r2')

r1.flushall()
self.assertEqual(r1.flushall(), True)

self.assertTrue('r1' not in r1)
self.assertTrue('r2' not in r2)
Expand Down Expand Up @@ -3046,13 +3043,6 @@ def test_set_existing_key_persists(self):
self.redis.set('foo', 'foo')
self.assertEqual(self.redis.ttl('foo'), -1)

def test_eval_delete(self):
self.redis.set('foo', 'bar')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')
val = self.redis.eval('redis.call("DEL", KEYS[1])', 1, 'foo')
self.assertIsNone(val)

def test_eval_set_value_to_arg(self):
self.redis.eval('redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo', 'bar')
val = self.redis.get('foo')
Expand All @@ -3074,11 +3064,6 @@ def test_eval_conditional(self):
val = self.redis.get('foo')
self.assertEqual(val, b'baz')

def test_eval_lrange(self):
self.redis.lpush("foo", "bar")
val = self.redis.eval('return redis.call("LRANGE", KEYS[1], 0, 1)', 1, 'foo')
self.assertEqual(val, [b'bar'])

def test_eval_table(self):
lua = """
local a = {}
Expand Down Expand Up @@ -3168,23 +3153,23 @@ def test_eval_runtime_error(self):
with self.assertRaises(ResponseError):
self.redis.eval('error("CRASH")', 0)

def test_more_keys_than_args(self):
def test_eval_more_keys_than_args(self):
with self.assertRaises(ResponseError):
self.redis.eval('return 1', 42)

def test_numkeys_float_string(self):
def test_eval_numkeys_float_string(self):
with self.assertRaises(ResponseError):
self.redis.eval('return KEYS[1]', '0.7', 'foo')

def test_numkeys_integer_string(self):
def test_eval_numkeys_integer_string(self):
val = self.redis.eval('return KEYS[1]', "1", "foo")
self.assertEqual(val, b'foo')

def test_numkeys_negative(self):
def test_eval_numkeys_negative(self):
with self.assertRaises(ResponseError):
self.redis.eval('return KEYS[1]', -1, "foo")

def test_numkeys_float(self):
def test_eval_numkeys_float(self):
with self.assertRaises(ResponseError):
self.redis.eval('return KEYS[1]', 0.7, "foo")

Expand Down Expand Up @@ -3271,6 +3256,102 @@ def test_eval_pcall_return_value(self):
with self.assertRaises(ResponseError):
self.redis.eval('return redis.pcall("foo")', 0)

def test_eval_delete(self):
self.redis.set('foo', 'bar')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')
val = self.redis.eval('redis.call("DEL", KEYS[1])', 1, 'foo')
self.assertIsNone(val)

def test_eval_exists(self):
val = self.redis.eval('return redis.call("exists", KEYS[1]) == 0', 1, 'foo')
self.assertEqual(val, 1)

def test_eval_flushdb(self):
self.redis.set('foo', 'bar')
val = self.redis.eval(
'''
local value = redis.call("FLUSHDB");
return type(value) == "table" and value.ok == "OK";
''', 0
)
self.assertEqual(val, 1)

def test_eval_flushall(self):
r1 = self.create_redis(db=0)
r2 = self.create_redis(db=1)

r1['r1'] = 'r1'
r2['r2'] = 'r2'

val = self.redis.eval(
'''
local value = redis.call("FLUSHALL");
return type(value) == "table" and value.ok == "OK";
''', 0
)

self.assertEqual(val, 1)
self.assertNotIn('r1', r1)
self.assertNotIn('r2', r2)

def test_eval_incrbyfloat(self):
self.redis.set('foo', 0.5)
val = self.redis.eval(
'''
local value = redis.call("INCRBYFLOAT", KEYS[1], 2.0);
return type(value) == "string" and tonumber(value) == 2.5;
''', 1, 'foo'
)
self.assertEqual(val, 1)

def test_eval_lrange(self):
self.redis.rpush('foo', 'a', 'b')
val = self.redis.eval(
'''
local value = redis.call("LRANGE", KEYS[1], 0, -1);
return type(value) == "table" and value[1] == "a" and value[2] == "b";
''', 1, 'foo'
)
self.assertEqual(val, 1)

def test_eval_ltrim(self):
self.redis.rpush('foo', 'a', 'b', 'c', 'd')
val = self.redis.eval(
'''
local value = redis.call("LTRIM", KEYS[1], 1, 2);
return type(value) == "table" and value.ok == "OK";
''', 1, 'foo'
)
self.assertEqual(val, 1)
self.assertEqual(self.redis.lrange('foo', 0, -1), [b'b', b'c'])

def test_eval_lset(self):
self.redis.rpush('foo', 'a', 'b')
val = self.redis.eval(
'''
local value = redis.call("LSET", KEYS[1], 0, "z");
return type(value) == "table" and value.ok == "OK";
''', 1, 'foo'
)
self.assertEqual(val, 1)
self.assertEqual(self.redis.lrange('foo', 0, -1), [b'z', b'b'])

def test_eval_sdiff(self):
self.redis.sadd('foo', 'a', 'b', 'c', 'f', 'e', 'd')
self.redis.sadd('bar', 'b')
val = self.redis.eval(
'''
local value = redis.call("SDIFF", KEYS[1], KEYS[2]);
if type(value) ~= "table" then
return redis.error_reply(type(value) .. ", should be table");
else
return value;
end
''', 2, 'foo', 'bar')
# Lua must receive the set *sorted*
self.assertEqual(val, [b'a', b'c', b'd', b'e', b'f'])


class TestFakeRedis(unittest.TestCase):
decode_responses = False
Expand Down