From a1cd907b18a6de5be5dcd48cadc62d78ee2475b9 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Fri, 19 Jan 2018 15:04:01 +0200 Subject: [PATCH] Make empty sets/lists/hashes disappear This fixes jamesls/fakeredis#155. --- fakenewsredis.py | 42 ++++++++++++++++++++++++++++++++++++++++-- test_fakenewsredis.py | 27 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/fakenewsredis.py b/fakenewsredis.py index 79930f1..56651e4 100644 --- a/fakenewsredis.py +++ b/fakenewsredis.py @@ -13,6 +13,7 @@ import time import types import re +import functools import redis from redis.exceptions import ResponseError @@ -218,6 +219,16 @@ def _patch_responses(obj): setattr(obj, attr_name, func) +def _remove_empty(func): + @functools.wraps(func) + def wrapper(self, key, *args, **kwargs): + ret = func(self, key, *args, **kwargs) + self._remove_if_empty(key) + return ret + + return wrapper + + class _Lock(object): def __init__(self, redis, name, timeout): self.redis = redis @@ -274,6 +285,15 @@ def flushall(self): del self._pubsubs[:] + def _remove_if_empty(self, key): + try: + value = self._db[key] + except KeyError: + pass + else: + if not value: + del self._db[key] + def _get_string(self, name, default=b''): value = self._db.get(name, default) # Allow None so that default can be set as None @@ -767,6 +787,7 @@ def lrange(self, name, start, end): def llen(self, name): return len(self._get_list(name)) + @_remove_empty def lrem(self, name, count, value): value = to_bytes(value) a_list = self._get_list(name) @@ -790,6 +811,7 @@ def rpush(self, name, *values): self._setdefault_list(name).extend([to_bytes(x) for x in values]) return len(self._db[name]) + @_remove_empty def lpop(self, name): try: return self._get_list(name).pop(0) @@ -827,6 +849,7 @@ def lindex(self, name, index): def lpushx(self, name, value): self._get_list(name).insert(0, to_bytes(value)) + @_remove_empty def rpop(self, name): try: return self._get_list(name).pop() @@ -875,7 +898,9 @@ def blpop(self, keys, timeout=0): for key in keys: lst = self._get_list(key) if lst: - return (key, lst.pop(0)) + ret = (key, lst.pop(0)) + self._remove_if_empty(key) + return ret def brpop(self, keys, timeout=0): if isinstance(keys, string_types): @@ -885,7 +910,9 @@ def brpop(self, keys, timeout=0): for key in keys: lst = self._get_list(key) if lst: - return (key, lst.pop()) + ret = (key, lst.pop()) + self._remove_if_empty(key) + return ret def brpoplpush(self, src, dst, timeout=0): return self.rpoplpush(src, dst) @@ -902,6 +929,7 @@ def _setdefault_hash(self, name): raise redis.ResponseError(_WRONGTYPE_MSG) return value + @_remove_empty def hdel(self, name, *keys): h = self._get_hash(name) rem = 0 @@ -1030,6 +1058,7 @@ def sdiff(self, keys, *args): diff -= self._get_set(key) return diff + @_remove_empty def sdiffstore(self, dest, keys, *args): """ Store the difference of sets specified by ``keys`` into a new @@ -1047,6 +1076,7 @@ def sinter(self, keys, *args): intersect.intersection_update(self._get_set(key)) return intersect + @_remove_empty def sinterstore(self, dest, keys, *args): """ Store the intersection of sets specified by ``keys`` into a new @@ -1064,6 +1094,7 @@ def smembers(self, name): "Return all members of the set ``name``" return self._get_set(name) + @_remove_empty def smove(self, src, dst, value): value = to_bytes(value) src_set = self._get_set(src) @@ -1075,6 +1106,7 @@ def smove(self, src, dst, value): except KeyError: return False + @_remove_empty def spop(self, name): "Remove and return a random member of set ``name``" try: @@ -1111,6 +1143,7 @@ def srandmember(self, name, number=None): in sorted(random.sample(range(len(members)), number)) ] + @_remove_empty def srem(self, name, *values): "Remove ``value`` from set ``name``" a_set = self._setdefault_set(name) @@ -1287,6 +1320,7 @@ def zincrby(self, name, value, amount=1): d[value] = score return score + @_remove_empty def zinterstore(self, dest, keys, aggregate=None): """ Intersect multiple sorted sets specified by ``keys`` into @@ -1422,6 +1456,7 @@ def zrank(self, name, value): except ValueError: return None + @_remove_empty def zrem(self, name, *values): "Remove member ``value`` from sorted set ``name``" z = self._get_zset(name) @@ -1432,6 +1467,7 @@ def zrem(self, name, *values): rem += 1 return rem + @_remove_empty def zremrangebyrank(self, name, min, max): """ Remove all elements in the sorted set ``name`` with ranks between @@ -1451,6 +1487,7 @@ def zremrangebyrank(self, name, min, max): num_deleted += 1 return num_deleted + @_remove_empty def zremrangebyscore(self, name, min, max): """ Remove all elements in the sorted set ``name`` with scores @@ -1465,6 +1502,7 @@ def zremrangebyscore(self, name, min, max): removed += 1 return removed + @_remove_empty def zremrangebylex(self, name, min, max): """ Remove all elements in the sorted set ``name`` diff --git a/test_fakenewsredis.py b/test_fakenewsredis.py index 064d24f..183591e 100644 --- a/test_fakenewsredis.py +++ b/test_fakenewsredis.py @@ -873,6 +873,11 @@ def test_rpoplpush_expiry(self): self.redis.rpoplpush('foo', 'bar') self.assertGreater(self.redis.ttl('bar'), 0) + def test_rpoplpush_one_to_self(self): + self.redis.rpush('list', 'element') + self.assertEqual(self.redis.brpoplpush('list', 'list'), b'element') + self.assertEqual(self.redis.lrange('list', 0, -1), [b'element']) + def test_rpoplpush_wrong_type(self): self.redis.set('foo', 'bar') self.redis.rpush('list', 'element') @@ -896,6 +901,7 @@ def test_blpop_test_multiple_lists(self): self.redis.rpush('baz', 'zero') self.assertEqual(self.redis.blpop(['foo', 'baz'], timeout=1), (b'baz', b'zero')) + self.assertFalse(self.redis.exists('baz')) self.redis.rpush('foo', 'one') self.redis.rpush('foo', 'two') @@ -925,6 +931,7 @@ def test_brpop_test_multiple_lists(self): self.redis.rpush('baz', 'zero') self.assertEqual(self.redis.brpop(['foo', 'baz'], timeout=1), (b'baz', b'zero')) + self.assertFalse(self.redis.exists('baz')) self.redis.rpush('foo', 'one') self.redis.rpush('foo', 'two') @@ -977,6 +984,11 @@ def test_blocking_operations_when_empty(self): self.assertEqual(self.redis.brpoplpush('foo', 'bar', timeout=1), None) + def test_empty_list(self): + self.redis.rpush('foo', 'bar') + self.redis.rpop('foo') + self.assertFalse(self.redis.exists('foo')) + # Tests for the hash type. def test_hset_then_hget(self): @@ -1183,6 +1195,11 @@ def test_hmset_wrong_type(self): with self.assertRaises(redis.ResponseError): self.redis.hmset('foo', {'key': 'value'}) + def test_empty_hash(self): + self.redis.hset('foo', 'bar', 'baz') + self.redis.hdel('foo', 'bar') + self.assertFalse(self.redis.exists('foo')) + def test_sadd(self): self.assertEqual(self.redis.sadd('foo', 'member1'), 1) self.assertEqual(self.redis.sadd('foo', 'member1'), 0) @@ -1485,6 +1502,11 @@ def test_sunionstore(self): self.redis.sadd('baz', 'member3') self.assertEqual(self.redis.scard('baz'), 3) + def test_empty_set(self): + self.redis.sadd('foo', 'bar') + self.redis.srem('foo', 'bar') + self.assertFalse(self.redis.exists('foo')) + def test_zadd(self): self.redis.zadd('foo', four=4) self.redis.zadd('foo', three=3) @@ -2251,6 +2273,11 @@ def test_zinterstore_wrong_type(self): with self.assertRaises(redis.ResponseError): self.redis.zinterstore('baz', ['foo', 'bar']) + def test_empty_zset(self): + self.redis.zadd('foo', one=1) + self.redis.zrem('foo', 'one') + self.assertFalse(self.redis.exists('foo')) + def test_multidb(self): r1 = self.create_redis(db=0) r2 = self.create_redis(db=1)