Skip to content
This repository has been archived by the owner on Sep 11, 2019. It is now read-only.

Commit

Permalink
Add eval
Browse files Browse the repository at this point in the history
  • Loading branch information
blfoster committed Feb 6, 2018
2 parents 9f881a8 + 0dd661f commit 507a125
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 1 deletion.
52 changes: 52 additions & 0 deletions fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import types
import re
import functools
from itertools import count

from lupa import LuaRuntime, lua_type

import redis
from redis.exceptions import ResponseError
Expand Down Expand Up @@ -701,6 +704,55 @@ def sort(self, name, start=None, num=None, by=None, get=None, desc=False,
except KeyError:
return []

def eval(self, script, numkeys, *keys_and_args):
"""
Execute the Lua ``script``, specifying the ``numkeys`` the script
will touch and the key names and argument values in ``keys_and_args``.
Returns the result of the script.
In practice, use the object returned by ``register_script``. This
function exists purely for Redis API completion.
"""
lua_runtime = LuaRuntime(unpack_returned_tuples=True)

raw_lua = """
function(KEYS, ARGV, callback)
redis = {{}}
redis.call = callback
{body}
end
""".format(body=script)
keys = (None,) + keys_and_args[:numkeys]
args = (None,) + keys_and_args[numkeys:]

lua_func = lua_runtime.eval(raw_lua)
result = lua_func(
keys,
args,
self._lua_callback
)
if lua_type(result) == 'table':
# Convert Lua tables into lists, starting from index 1, mimicking the behavior of StrictRedis.
result_list = []
for index in count(1):
if index not in result:
break
item = result[index]
result_list.append(
item.encode() if isinstance(item, str) and not isinstance(item, bytes) else item
)
return result_list
return result

def _lua_callback(self, op, *args):
special_cases = {
'del': self.delete,
'decrby': self.decr,
'incrby': self.incr
}
op = op.lower()
func = special_cases[op] if op in special_cases else getattr(self, op)
return func(*args)

def _retrive_data_from_sort(self, data, get):
if get is not None:
if isinstance(get, string_types):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
flake8<3.0.0
nose==1.3.4
redis==2.10.6
lupa==1.6
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='fakenewsredis',
version='0.9.4',
version='0.9.5',
description="Fake implementation of redis API for testing purposes.",
long_description=open(os.path.join(os.path.dirname(__file__),
'README.rst')).read(),
Expand Down
53 changes: 53 additions & 0 deletions test_fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3046,6 +3046,59 @@ def test_set_existing_key_persists(self):
self.redis.set('foo', 'foo')
self.assertEqual(self.redis.ttl('foo'), -1)

def test_eval_delete(self):
self.redis.set('foo', 'bar')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')
val = self.redis.eval('redis.call("DEL", KEYS[1])', 1, 'foo')
self.assertIsNone(val)

def test_eval_set_value_to_arg(self):
self.redis.eval('redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo', 'bar')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')

def test_eval_conditional(self):
lua = """
local val = redis.call("GET", KEYS[1])
if val == ARGV[1] then
redis.call("SET", KEYS[1], ARGV[2])
else
redis.call("SET", KEYS[1], ARGV[1])
end
"""
self.redis.eval(lua, 1, 'foo', 'bar', 'baz')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')
self.redis.eval(lua, 1, 'foo', 'bar', 'baz')
val = self.redis.get('foo')
self.assertEqual(val, b'baz')

def test_eval_lrange(self):
self.redis.lpush("foo", "bar")
val = self.redis.eval('return redis.call("LRANGE", KEYS[1], 0, 1)', 1, 'foo')
self.assertEqual(val, [b'bar'])

def test_eval_table(self):
lua = """
local a = {}
a[1] = "foo"
a[2] = "bar"
a[17] = "baz"
return a
"""
val = self.redis.eval(lua, 0)
self.assertEqual(val, [b'foo', b'bar'])

def test_eval_table_with_numbers(self):
lua = """
local a = {}
a[1] = 42
return a
"""
val = self.redis.eval(lua, 0)
self.assertEqual(val, [42])


class TestFakeRedis(unittest.TestCase):
decode_responses = False
Expand Down

0 comments on commit 507a125

Please sign in to comment.