Skip to content
This repository has been archived by the owner on Sep 11, 2019. It is now read-only.

Commit

Permalink
Make empty sets/lists/hashes disappear
Browse files Browse the repository at this point in the history
This fixes jamesls#155.
  • Loading branch information
bmerry committed Jan 19, 2018
1 parent b09b021 commit a1cd907
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
42 changes: 40 additions & 2 deletions fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import time
import types
import re
import functools

import redis
from redis.exceptions import ResponseError
Expand Down Expand Up @@ -218,6 +219,16 @@ def _patch_responses(obj):
setattr(obj, attr_name, func)


def _remove_empty(func):
@functools.wraps(func)
def wrapper(self, key, *args, **kwargs):
ret = func(self, key, *args, **kwargs)
self._remove_if_empty(key)
return ret

return wrapper


class _Lock(object):
def __init__(self, redis, name, timeout):
self.redis = redis
Expand Down Expand Up @@ -274,6 +285,15 @@ def flushall(self):

del self._pubsubs[:]

def _remove_if_empty(self, key):
try:
value = self._db[key]
except KeyError:
pass
else:
if not value:
del self._db[key]

def _get_string(self, name, default=b''):
value = self._db.get(name, default)
# Allow None so that default can be set as None
Expand Down Expand Up @@ -767,6 +787,7 @@ def lrange(self, name, start, end):
def llen(self, name):
return len(self._get_list(name))

@_remove_empty
def lrem(self, name, count, value):
value = to_bytes(value)
a_list = self._get_list(name)
Expand All @@ -790,6 +811,7 @@ def rpush(self, name, *values):
self._setdefault_list(name).extend([to_bytes(x) for x in values])
return len(self._db[name])

@_remove_empty
def lpop(self, name):
try:
return self._get_list(name).pop(0)
Expand Down Expand Up @@ -827,6 +849,7 @@ def lindex(self, name, index):
def lpushx(self, name, value):
self._get_list(name).insert(0, to_bytes(value))

@_remove_empty
def rpop(self, name):
try:
return self._get_list(name).pop()
Expand Down Expand Up @@ -875,7 +898,9 @@ def blpop(self, keys, timeout=0):
for key in keys:
lst = self._get_list(key)
if lst:
return (key, lst.pop(0))
ret = (key, lst.pop(0))
self._remove_if_empty(key)
return ret

def brpop(self, keys, timeout=0):
if isinstance(keys, string_types):
Expand All @@ -885,7 +910,9 @@ def brpop(self, keys, timeout=0):
for key in keys:
lst = self._get_list(key)
if lst:
return (key, lst.pop())
ret = (key, lst.pop())
self._remove_if_empty(key)
return ret

def brpoplpush(self, src, dst, timeout=0):
return self.rpoplpush(src, dst)
Expand All @@ -902,6 +929,7 @@ def _setdefault_hash(self, name):
raise redis.ResponseError(_WRONGTYPE_MSG)
return value

@_remove_empty
def hdel(self, name, *keys):
h = self._get_hash(name)
rem = 0
Expand Down Expand Up @@ -1030,6 +1058,7 @@ def sdiff(self, keys, *args):
diff -= self._get_set(key)
return diff

@_remove_empty
def sdiffstore(self, dest, keys, *args):
"""
Store the difference of sets specified by ``keys`` into a new
Expand All @@ -1047,6 +1076,7 @@ def sinter(self, keys, *args):
intersect.intersection_update(self._get_set(key))
return intersect

@_remove_empty
def sinterstore(self, dest, keys, *args):
"""
Store the intersection of sets specified by ``keys`` into a new
Expand All @@ -1064,6 +1094,7 @@ def smembers(self, name):
"Return all members of the set ``name``"
return self._get_set(name)

@_remove_empty
def smove(self, src, dst, value):
value = to_bytes(value)
src_set = self._get_set(src)
Expand All @@ -1075,6 +1106,7 @@ def smove(self, src, dst, value):
except KeyError:
return False

@_remove_empty
def spop(self, name):
"Remove and return a random member of set ``name``"
try:
Expand Down Expand Up @@ -1111,6 +1143,7 @@ def srandmember(self, name, number=None):
in sorted(random.sample(range(len(members)), number))
]

@_remove_empty
def srem(self, name, *values):
"Remove ``value`` from set ``name``"
a_set = self._setdefault_set(name)
Expand Down Expand Up @@ -1287,6 +1320,7 @@ def zincrby(self, name, value, amount=1):
d[value] = score
return score

@_remove_empty
def zinterstore(self, dest, keys, aggregate=None):
"""
Intersect multiple sorted sets specified by ``keys`` into
Expand Down Expand Up @@ -1422,6 +1456,7 @@ def zrank(self, name, value):
except ValueError:
return None

@_remove_empty
def zrem(self, name, *values):
"Remove member ``value`` from sorted set ``name``"
z = self._get_zset(name)
Expand All @@ -1432,6 +1467,7 @@ def zrem(self, name, *values):
rem += 1
return rem

@_remove_empty
def zremrangebyrank(self, name, min, max):
"""
Remove all elements in the sorted set ``name`` with ranks between
Expand All @@ -1451,6 +1487,7 @@ def zremrangebyrank(self, name, min, max):
num_deleted += 1
return num_deleted

@_remove_empty
def zremrangebyscore(self, name, min, max):
"""
Remove all elements in the sorted set ``name`` with scores
Expand All @@ -1465,6 +1502,7 @@ def zremrangebyscore(self, name, min, max):
removed += 1
return removed

@_remove_empty
def zremrangebylex(self, name, min, max):
"""
Remove all elements in the sorted set ``name``
Expand Down
27 changes: 27 additions & 0 deletions test_fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,11 @@ def test_rpoplpush_expiry(self):
self.redis.rpoplpush('foo', 'bar')
self.assertGreater(self.redis.ttl('bar'), 0)

def test_rpoplpush_one_to_self(self):
self.redis.rpush('list', 'element')
self.assertEqual(self.redis.brpoplpush('list', 'list'), b'element')
self.assertEqual(self.redis.lrange('list', 0, -1), [b'element'])

def test_rpoplpush_wrong_type(self):
self.redis.set('foo', 'bar')
self.redis.rpush('list', 'element')
Expand All @@ -896,6 +901,7 @@ def test_blpop_test_multiple_lists(self):
self.redis.rpush('baz', 'zero')
self.assertEqual(self.redis.blpop(['foo', 'baz'], timeout=1),
(b'baz', b'zero'))
self.assertFalse(self.redis.exists('baz'))

self.redis.rpush('foo', 'one')
self.redis.rpush('foo', 'two')
Expand Down Expand Up @@ -925,6 +931,7 @@ def test_brpop_test_multiple_lists(self):
self.redis.rpush('baz', 'zero')
self.assertEqual(self.redis.brpop(['foo', 'baz'], timeout=1),
(b'baz', b'zero'))
self.assertFalse(self.redis.exists('baz'))

self.redis.rpush('foo', 'one')
self.redis.rpush('foo', 'two')
Expand Down Expand Up @@ -977,6 +984,11 @@ def test_blocking_operations_when_empty(self):
self.assertEqual(self.redis.brpoplpush('foo', 'bar', timeout=1),
None)

def test_empty_list(self):
self.redis.rpush('foo', 'bar')
self.redis.rpop('foo')
self.assertFalse(self.redis.exists('foo'))

# Tests for the hash type.

def test_hset_then_hget(self):
Expand Down Expand Up @@ -1183,6 +1195,11 @@ def test_hmset_wrong_type(self):
with self.assertRaises(redis.ResponseError):
self.redis.hmset('foo', {'key': 'value'})

def test_empty_hash(self):
self.redis.hset('foo', 'bar', 'baz')
self.redis.hdel('foo', 'bar')
self.assertFalse(self.redis.exists('foo'))

def test_sadd(self):
self.assertEqual(self.redis.sadd('foo', 'member1'), 1)
self.assertEqual(self.redis.sadd('foo', 'member1'), 0)
Expand Down Expand Up @@ -1485,6 +1502,11 @@ def test_sunionstore(self):
self.redis.sadd('baz', 'member3')
self.assertEqual(self.redis.scard('baz'), 3)

def test_empty_set(self):
self.redis.sadd('foo', 'bar')
self.redis.srem('foo', 'bar')
self.assertFalse(self.redis.exists('foo'))

def test_zadd(self):
self.redis.zadd('foo', four=4)
self.redis.zadd('foo', three=3)
Expand Down Expand Up @@ -2251,6 +2273,11 @@ def test_zinterstore_wrong_type(self):
with self.assertRaises(redis.ResponseError):
self.redis.zinterstore('baz', ['foo', 'bar'])

def test_empty_zset(self):
self.redis.zadd('foo', one=1)
self.redis.zrem('foo', 'one')
self.assertFalse(self.redis.exists('foo'))

def test_multidb(self):
r1 = self.create_redis(db=0)
r2 = self.create_redis(db=1)
Expand Down

0 comments on commit a1cd907

Please sign in to comment.