diff --git a/fakenewsredis.py b/fakenewsredis.py index 6e8fdcc..79930f1 100644 --- a/fakenewsredis.py +++ b/fakenewsredis.py @@ -142,6 +142,18 @@ def expire(self, key, timestamp): value = self._dict[to_bytes(key)][0] self._dict[to_bytes(key)] = (value, timestamp) + def setx(self, key, value, src=None): + """Set a value, keeping the existing expiry time if any. If + `src` is specified, it is used as the source of the expiry + """ + if src is None: + src = key + try: + _, expiration = self._dict[to_bytes(src)] + except KeyError: + expiration = None + self._dict[to_bytes(key)] = (value, expiration) + def persist(self, key): try: value, _ = self._dict[to_bytes(key)] @@ -294,11 +306,12 @@ def bitcount(self, name, start=0, end=-1): def decr(self, name, amount=1): try: - self._db[name] = to_bytes(int(self._get_string(name, b'0')) - amount) + value = int(self._get_string(name, b'0')) - amount + self._db.setx(name, to_bytes(value)) except (TypeError, ValueError): raise redis.ResponseError("value is not an integer or out of " "range.") - return int(self._db[name]) + return value def exists(self, name): return name in self._db @@ -381,11 +394,12 @@ def incr(self, name, amount=1): if not isinstance(amount, int): raise redis.ResponseError("value is not an integer or out " "of range.") - self._db[name] = to_bytes(int(self._get_string(name, b'0')) + amount) + value = int(self._get_string(name, b'0')) + amount + self._db.setx(name, to_bytes(value)) except (TypeError, ValueError): raise redis.ResponseError("value is not an integer or out of " "range.") - return int(self._db[name]) + return value def incrby(self, name, amount=1): """ @@ -395,10 +409,11 @@ def incrby(self, name, amount=1): def incrbyfloat(self, name, amount=1.0): try: - self._db[name] = to_bytes(float(self._get_string(name, b'0')) + amount) + value = float(self._get_string(name, b'0')) + amount + self._db.setx(name, to_bytes(value)) except (TypeError, ValueError): raise redis.ResponseError("value is not a valid float.") - return float(self._db[name]) + return value def keys(self, pattern=None): return [key for key in self._db @@ -457,7 +472,7 @@ def rename(self, src, dst): value = self._db[src] except KeyError: raise redis.ResponseError("No such key: %s" % src) - self._db[dst] = value + self._db.setx(dst, value, src=src) del self._db[src] return True @@ -512,7 +527,7 @@ def setbit(self, name, offset, value): new_byte = byte_to_int(val[byte]) ^ (1 << actual_bitoffset) reconstructed = bytearray(val) reconstructed[byte] = new_byte - self._db[name] = bytes(reconstructed) + self._db.setx(name, bytes(reconstructed)) def setex(self, name, time, value): if isinstance(time, timedelta): @@ -541,7 +556,7 @@ def setrange(self, name, offset, value): if len(val) < offset: val += b'\x00' * (offset - len(val)) val = val[0:offset] + to_bytes(value) + val[offset+len(value):] - self.set(name, val) + self._db.setx(name, val) return len(val) def strlen(self, name): @@ -800,7 +815,7 @@ def ltrim(self, name, start, end): end = None else: end += 1 - self._db[name] = val[start:end] + self._db.setx(name, val[start:end]) return True def lindex(self, name, index): @@ -843,7 +858,7 @@ def rpoplpush(self, src, dst): if el is not None: el = to_bytes(el) dst_list.insert(0, el) - self._db[dst] = dst_list + self._db.setx(dst, dst_list) return el def blpop(self, keys, timeout=0): diff --git a/test_fakenewsredis.py b/test_fakenewsredis.py index d60751f..064d24f 100644 --- a/test_fakenewsredis.py +++ b/test_fakenewsredis.py @@ -201,6 +201,11 @@ def test_setbit_wrong_type(self): with self.assertRaises(redis.ResponseError): self.redis.setbit('foo', 0, 1) + def test_setbit_expiry(self): + self.redis.set('foo', b'0x00', ex=10) + self.redis.setbit('foo', 1, 1) + self.assertGreater(self.redis.ttl('foo'), 0) + def test_bitcount(self): self.redis.delete('foo') self.assertEqual(self.redis.bitcount('foo'), 0) @@ -296,6 +301,11 @@ def test_incr_preexisting_key(self): self.assertEqual(self.redis.incr('foo', 5), 20) self.assertEqual(self.redis.get('foo'), b'20') + def test_incr_expiry(self): + self.redis.set('foo', 15, ex=10) + self.redis.incr('foo', 5) + self.assertGreater(self.redis.ttl('foo'), 0) + def test_incr_bad_type(self): self.redis.set('foo', 'bar') with self.assertRaises(redis.ResponseError): @@ -326,6 +336,11 @@ def test_incrbyfloat_with_noexist(self): self.assertEqual(self.redis.incrbyfloat('foo', 1.0), 1.0) self.assertEqual(self.redis.incrbyfloat('foo', 1.0), 2.0) + def test_incrbyfloat_expiry(self): + self.redis.set('foo', 1.5, ex=10) + self.redis.incrbyfloat('foo', 2.5) + self.assertGreater(self.redis.ttl('foo'), 0) + def test_incrbyfloat_bad_type(self): self.redis.set('foo', 'bar') with self.assertRaisesRegexp(redis.ResponseError, 'not a valid float'): @@ -348,6 +363,11 @@ def test_decr_newkey(self): self.redis.decr('foo') self.assertEqual(self.redis.get('foo'), b'-1') + def test_decr_expiry(self): + self.redis.set('foo', 10, ex=10) + self.redis.decr('foo', 5) + self.assertGreater(self.redis.ttl('foo'), 0) + def test_decr_badtype(self): self.redis.set('foo', 'bar') with self.assertRaises(redis.ResponseError): @@ -389,6 +409,12 @@ def test_rename_does_exist(self): self.assertEqual(self.redis.get('foo'), b'unique value') self.assertEqual(self.redis.get('bar'), b'unique value2') + def test_rename_expiry(self): + self.redis.set('foo', 'value1', ex=10) + self.redis.set('bar', 'value2') + self.redis.rename('foo', 'bar') + self.assertGreater(self.redis.ttl('bar'), 0) + def test_mget(self): self.redis.set('foo', 'one') self.redis.set('bar', 'two') @@ -743,6 +769,12 @@ def test_ltrim(self): def test_ltrim_with_non_existent_key(self): self.assertTrue(self.redis.ltrim('foo', 0, -1)) + def test_ltrim_expiry(self): + self.redis.rpush('foo', 'one', 'two', 'three') + self.redis.expire('foo', 10) + self.redis.ltrim('foo', 1, 2) + self.assertGreater(self.redis.ttl('foo'), 0) + def test_ltrim_wrong_type(self): self.redis.set('foo', 'bar') with self.assertRaises(redis.ResponseError): @@ -834,6 +866,13 @@ def test_rpoplpush_to_nonexistent_destination(self): self.assertEqual(self.redis.rpoplpush('foo', 'bar'), b'one') self.assertEqual(self.redis.rpop('bar'), b'one') + def test_rpoplpush_expiry(self): + self.redis.rpush('foo', 'one') + self.redis.rpush('bar', 'two') + self.redis.expire('bar', 10) + self.redis.rpoplpush('foo', 'bar') + self.assertGreater(self.redis.ttl('bar'), 0) + def test_rpoplpush_wrong_type(self): self.redis.set('foo', 'bar') self.redis.rpush('list', 'element') @@ -1273,6 +1312,11 @@ def test_setrange(self): self.assertEqual(self.redis.setrange('bar', 2, 'test'), 6) self.assertEqual(self.redis.get('bar'), b'\x00\x00test') + def test_setrange_expiry(self): + self.redis.set('foo', 'test', ex=10) + self.redis.setrange('foo', 1, 'aste') + self.assertGreater(self.redis.ttl('foo'), 0) + def test_sinter(self): self.redis.sadd('foo', 'member1') self.redis.sadd('foo', 'member2')