diff --git a/README.rst b/README.rst index 29db4d7..18c894b 100644 --- a/README.rst +++ b/README.rst @@ -203,7 +203,6 @@ scripting * script kill * script load * evalsha - * eval * script exists @@ -266,6 +265,10 @@ they have all been tagged as 'slow' so you can skip them by running:: Revision history ================ +Development version +----- +- `#9 `_ Add support for StrictRedis.eval for Lua scripts + 0.9.4 ----- This is a minor bugfix and optimization release: diff --git a/fakenewsredis.py b/fakenewsredis.py index a1506cd..101b816 100644 --- a/fakenewsredis.py +++ b/fakenewsredis.py @@ -14,6 +14,7 @@ import types import re import functools +from itertools import count import redis from redis.exceptions import ResponseError @@ -701,6 +702,148 @@ def sort(self, name, start=None, num=None, by=None, get=None, desc=False, except KeyError: return [] + def eval(self, script, numkeys, *keys_and_args): + from lupa import LuaRuntime, LuaError + + if any( + isinstance(numkeys, t) for t in (text_type, str, bytes) + ): + try: + numkeys = int(numkeys) + except ValueError: + # Non-numeric string will be handled below. + pass + if not(isinstance(numkeys, int)): + raise ResponseError("value is not an integer or out of range") + elif numkeys > len(keys_and_args): + raise ResponseError("Number of keys can't be greater than number of args") + elif numkeys < 0: + raise ResponseError("Number of keys can't be negative") + + keys_and_args = [to_bytes(v) for v in keys_and_args] + lua_runtime = LuaRuntime(unpack_returned_tuples=True) + + set_globals = lua_runtime.eval( + """ + function(keys, argv, redis_call, redis_pcall) + redis = {} + redis.call = redis_call + redis.pcall = redis_pcall + redis.error_reply = function(msg) return {err=msg} end + redis.status_reply = function(msg) return {ok=msg} end + KEYS = keys + ARGV = argv + end + """ + ) + expected_globals = set() + set_globals( + [None] + keys_and_args[:numkeys], + [None] + keys_and_args[numkeys:], + functools.partial(self._lua_redis_call, lua_runtime, expected_globals), + functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals) + ) + expected_globals.update(lua_runtime.globals().keys()) + + try: + result = lua_runtime.execute(script) + except LuaError as ex: + raise ResponseError(ex) + + self._check_for_lua_globals(lua_runtime, expected_globals) + + return self._convert_lua_result(result, nested=False) + + def _convert_redis_result(self, result): + if isinstance(result, dict): + return [ + i + for item in result.items() + for i in item + ] + return result + + def _convert_lua_result(self, result, nested=True): + from lupa import lua_type + if lua_type(result) == 'table': + for key in ('ok', 'err'): + if key in result: + msg = self._convert_lua_result(result[key]) + if not isinstance(msg, bytes): + raise ResponseError("wrong number or type of arguments") + if key == 'ok': + return msg + elif nested: + return ResponseError(msg) + else: + raise ResponseError(msg) + # Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis. + result_list = [] + for index in count(1): + if index not in result: + break + item = result[index] + result_list.append(self._convert_lua_result(item)) + return result_list + elif isinstance(result, text_type): + return to_bytes(result) + elif isinstance(result, float): + return int(result) + elif isinstance(result, bool): + return 1 if result else None + return result + + def _check_for_lua_globals(self, lua_runtime, expected_globals): + actual_globals = set(lua_runtime.globals().keys()) + if actual_globals != expected_globals: + raise ResponseError( + "Script attempted to set a global variables: %s" % ", ".join( + actual_globals - expected_globals + ) + ) + + def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args): + try: + return self._lua_redis_call(lua_runtime, expected_globals, op, *args) + except Exception as ex: + return lua_runtime.table_from({"err": str(ex)}) + + def _lua_redis_call(self, lua_runtime, expected_globals, op, *args): + # Check if we've set any global variables before making any change. + self._check_for_lua_globals(lua_runtime, expected_globals) + # These commands aren't necessarily all implemented, but if op is not one of these commands, we expect + # a ResponseError for consistency with Redis + commands = [ + 'append', 'auth', 'bitcount', 'bitfield', 'bitop', 'bitpos', 'blpop', 'brpop', 'brpoplpush', + 'decr', 'decrby', 'del', 'dump', 'echo', 'eval', 'evalsha', 'exists', 'expire', 'expireat', + 'flushall', 'flushdb', 'geoadd', 'geodist', 'geohash', 'geopos', 'georadius', 'georadiusbymember', + 'get', 'getbit', 'getrange', 'getset', 'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', + 'hincrbyfloat', 'hkeys', 'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals', + 'incr', 'incrby', 'incrbyfloat', 'info', 'keys', 'lindex', 'linsert', 'llen', 'lpop', 'lpush', + 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'mget', 'migrate', 'move', 'mset', 'msetnx', + 'object', 'persist', 'pexpire', 'pexpireat', 'pfadd', 'pfcount', 'pfmerge', 'ping', 'psetex', + 'psubscribe', 'pttl', 'publish', 'pubsub', 'punsubscribe', 'rename', 'renamenx', 'restore', + 'rpop', 'rpoplpush', 'rpush', 'rpushx', 'sadd', 'scan', 'scard', 'sdiff', 'sdiffstore', 'select', + 'set', 'setbit', 'setex', 'setnx', 'setrange', 'shutdown', 'sinter', 'sinterstore', 'sismember', + 'slaveof', 'slowlog', 'smembers', 'smove', 'sort', 'spop', 'srandmember', 'srem', 'sscan', + 'strlen', 'subscribe', 'sunion', 'sunionstore', 'swapdb', 'touch', 'ttl', 'type', 'unlink', + 'unsubscribe', 'wait', 'watch', 'zadd', 'zcard', 'zcount', 'zincrby', 'zinterstore', 'zlexcount', + 'zrange', 'zrangebylex', 'zrangebyscore', 'zrank', 'zrem', 'zremrangebylex', 'zremrangebyrank', + 'zremrangebyscore', 'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan', + 'zscore', 'zunionstore' + ] + + op = op.lower() + if op not in commands: + raise ResponseError("Unknown Redis command called from Lua script") + special_cases = { + 'del': FakeStrictRedis.delete, + 'decrby': FakeStrictRedis.decr, + 'incrby': FakeStrictRedis.incr + } + func = special_cases[op] if op in special_cases else getattr(FakeStrictRedis, op) + return self._convert_redis_result(func(self, *args)) + def _retrive_data_from_sort(self, data, get): if get is not None: if isinstance(get, string_types): diff --git a/requirements.txt b/requirements.txt index 12cc5c9..00fab91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ flake8<3.0.0 nose==1.3.4 redis==2.10.6 +lupa==1.6 \ No newline at end of file diff --git a/setup.py b/setup.py index 3f98e43..36b76c4 100644 --- a/setup.py +++ b/setup.py @@ -30,5 +30,8 @@ ], install_requires=[ 'redis', - ] + ], + extras_require={ + "lua": ['lupa'] + } ) diff --git a/test_fakenewsredis.py b/test_fakenewsredis.py index 403d885..e6ede44 100644 --- a/test_fakenewsredis.py +++ b/test_fakenewsredis.py @@ -3046,6 +3046,231 @@ def test_set_existing_key_persists(self): self.redis.set('foo', 'foo') self.assertEqual(self.redis.ttl('foo'), -1) + def test_eval_delete(self): + self.redis.set('foo', 'bar') + val = self.redis.get('foo') + self.assertEqual(val, b'bar') + val = self.redis.eval('redis.call("DEL", KEYS[1])', 1, 'foo') + self.assertIsNone(val) + + def test_eval_set_value_to_arg(self): + self.redis.eval('redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo', 'bar') + val = self.redis.get('foo') + self.assertEqual(val, b'bar') + + def test_eval_conditional(self): + lua = """ + local val = redis.call("GET", KEYS[1]) + if val == ARGV[1] then + redis.call("SET", KEYS[1], ARGV[2]) + else + redis.call("SET", KEYS[1], ARGV[1]) + end + """ + self.redis.eval(lua, 1, 'foo', 'bar', 'baz') + val = self.redis.get('foo') + self.assertEqual(val, b'bar') + self.redis.eval(lua, 1, 'foo', 'bar', 'baz') + val = self.redis.get('foo') + self.assertEqual(val, b'baz') + + def test_eval_lrange(self): + self.redis.lpush("foo", "bar") + val = self.redis.eval('return redis.call("LRANGE", KEYS[1], 0, 1)', 1, 'foo') + self.assertEqual(val, [b'bar']) + + def test_eval_table(self): + lua = """ + local a = {} + a[1] = "foo" + a[2] = "bar" + a[17] = "baz" + return a + """ + val = self.redis.eval(lua, 0) + self.assertEqual(val, [b'foo', b'bar']) + + def test_eval_table_with_nil(self): + lua = """ + local a = {} + a[1] = "foo" + a[2] = nil + a[3] = "bar" + return a + """ + val = self.redis.eval(lua, 0) + self.assertEqual(val, [b'foo']) + + def test_eval_table_with_numbers(self): + lua = """ + local a = {} + a[1] = 42 + return a + """ + val = self.redis.eval(lua, 0) + self.assertEqual(val, [42]) + + def test_eval_nested_table(self): + lua = """ + local a = {} + a[1] = {} + a[1][1] = "foo" + return a + """ + val = self.redis.eval(lua, 0) + self.assertEqual(val, [[b'foo']]) + + def test_eval_mget(self): + self.redis.set('foo1', 'bar1') + self.redis.set('foo2', 'bar2') + val = self.redis.eval('return redis.call("mget", "foo1", "foo2")', 2, 'foo1', 'foo2') + self.assertEqual(val, [b'bar1', b'bar2']) + + def test_eval_mget_none(self): + self.redis.set('foo1', None) + self.redis.set('foo2', None) + val = self.redis.eval('return redis.call("mget", "foo1", "foo2")', 2, 'foo1', 'foo2') + self.assertEqual(val, [b'None', b'None']) + + def test_eval_mget_not_set(self): + val = self.redis.eval('return redis.call("mget", "foo1", "foo2")', 2, 'foo1', 'foo2') + self.assertEqual(val, [None, None]) + + def test_eval_hgetall(self): + self.redis.hset('foo', 'k1', 'bar') + self.redis.hset('foo', 'k2', 'baz') + val = self.redis.eval('return redis.call("hgetall", "foo")', 1, 'foo') + sorted_val = sorted([val[:2], val[2:]]) + self.assertEqual( + sorted_val, + [[b'k1', b'bar'], [b'k2', b'baz']] + ) + + def test_eval_list_with_nil(self): + self.redis.lpush('foo', 'bar') + self.redis.lpush('foo', None) + self.redis.lpush('foo', 'baz') + val = self.redis.eval('return redis.call("lrange", KEYS[1], 0, 2)', 1, 'foo') + self.assertEqual(val, [b'baz', b'None', b'bar']) + + def test_eval_invalid_command(self): + with self.assertRaises(ResponseError): + self.redis.eval( + 'return redis.call("FOO")', + 0 + ) + + def test_eval_syntax_error(self): + with self.assertRaises(ResponseError): + self.redis.eval('return "', 0) + + def test_eval_runtime_error(self): + with self.assertRaises(ResponseError): + self.redis.eval('error("CRASH")', 0) + + def test_more_keys_than_args(self): + with self.assertRaises(ResponseError): + self.redis.eval('return 1', 42) + + def test_numkeys_float_string(self): + with self.assertRaises(ResponseError): + self.redis.eval('return KEYS[1]', '0.7', 'foo') + + def test_numkeys_integer_string(self): + val = self.redis.eval('return KEYS[1]', "1", "foo") + self.assertEqual(val, b'foo') + + def test_numkeys_negative(self): + with self.assertRaises(ResponseError): + self.redis.eval('return KEYS[1]', -1, "foo") + + def test_numkeys_float(self): + with self.assertRaises(ResponseError): + self.redis.eval('return KEYS[1]', 0.7, "foo") + + def test_eval_global_variable(self): + # Redis doesn't allow script to define global variables + with self.assertRaises(ResponseError): + self.redis.eval('a=10', 0) + + def test_eval_global_and_return_ok(self): + # Redis doesn't allow script to define global variables + with self.assertRaises(ResponseError): + self.redis.eval( + ''' + a=10 + return redis.status_reply("Everything is awesome") + ''', + 0 + ) + + def test_eval_convert_number(self): + # Redis forces all Lua numbers to integer + val = self.redis.eval('return 3.2', 0) + self.assertEqual(val, 3) + val = self.redis.eval('return 3.8', 0) + self.assertEqual(val, 3) + val = self.redis.eval('return -3.8', 0) + self.assertEqual(val, -3) + + def test_eval_convert_bool(self): + # Redis converts true to 1 and false to nil (which redis-py converts to None) + val = self.redis.eval('return false', 0) + self.assertIsNone(val) + val = self.redis.eval('return true', 0) + self.assertEqual(val, 1) + self.assertNotIsInstance(val, bool) + + def test_eval_none_arg(self): + val = self.redis.eval('return ARGV[1] == "None"', 0, None) + self.assertTrue(val) + + def test_eval_return_error(self): + with self.assertRaises(redis.ResponseError) as cm: + self.redis.eval('return {err="Testing"}', 0) + self.assertIn('Testing', str(cm.exception)) + with self.assertRaises(redis.ResponseError) as cm: + self.redis.eval('return redis.error_reply("Testing")', 0) + self.assertIn('Testing', str(cm.exception)) + + def test_eval_return_ok(self): + val = self.redis.eval('return {ok="Testing"}', 0) + self.assertEqual(val, b'Testing') + val = self.redis.eval('return redis.status_reply("Testing")', 0) + self.assertEqual(val, b'Testing') + + def test_eval_return_ok_nested(self): + val = self.redis.eval( + ''' + local a = {} + a[1] = {ok="Testing"} + return a + ''', + 0 + ) + self.assertEqual(val, [b'Testing']) + + def test_eval_return_ok_wrong_type(self): + with self.assertRaises(redis.ResponseError): + self.redis.eval('return redis.status_reply(123)', 0) + + def test_eval_pcall(self): + val = self.redis.eval( + ''' + local a = {} + a[1] = redis.pcall("foo") + return a + ''', + 0 + ) + self.assertIsInstance(val, list) + self.assertEqual(len(val), 1) + self.assertIsInstance(val[0], ResponseError) + + def test_eval_pcall_return_value(self): + with self.assertRaises(ResponseError): + self.redis.eval('return redis.pcall("foo")', 0) + class TestFakeRedis(unittest.TestCase): decode_responses = False