diff --git a/redis_cache/backends/base.py b/redis_cache/backends/base.py index fd146c02..ebcec8e6 100644 --- a/redis_cache/backends/base.py +++ b/redis_cache/backends/base.py @@ -4,7 +4,7 @@ from django.utils.functional import cached_property from django.utils.importlib import import_module -from redis_cache.compat import smart_bytes, DEFAULT_TIMEOUT +from redis_cache.compat import bytes_type, smart_bytes, DEFAULT_TIMEOUT try: import cPickle as pickle @@ -22,6 +22,26 @@ from redis_cache.utils import CacheKey +from functools import wraps + + +def get_client(write=False): + + def wrapper(method): + + @wraps(method) + def wrapped(self, key, *args, **kwargs): + version = kwargs.pop('version', None) + client = self.get_client(key, write=write) + key = self.make_key(key, version=version) + + return method(self, client, key, *args, **kwargs) + + return wrapped + + return wrapper + + class BaseRedisCache(BaseCache): def __init__(self, server, params): @@ -30,15 +50,20 @@ def __init__(self, server, params): """ super(BaseRedisCache, self).__init__(params) self.server = server + self.servers = self.get_servers(server) self.params = params or {} self.options = params.get('OPTIONS', {}) + self.clients = {} + self.client_list = [] self.db = self.get_db() self.password = self.get_password() self.parser_class = self.get_parser_class() self.pickle_version = self.get_pickle_version() self.connection_pool_class = self.get_connection_pool_class() - self.connection_pool_class_kwargs = self.get_connection_pool_class_kwargs() + self.connection_pool_class_kwargs = ( + self.get_connection_pool_class_kwargs() + ) def __getstate__(self): return {'params': self.params, 'server': self.server} @@ -46,6 +71,20 @@ def __getstate__(self): def __setstate__(self, state): self.__init__(**state) + def get_servers(self, server): + """returns a list of servers given the server argument passed in + from Django. + """ + if isinstance(server, bytes_type): + servers = server.split(',') + elif hasattr(server, '__iter__'): + servers = server + else: + raise ImproperlyConfigured( + '"server" must be an iterable or string' + ) + return servers + def get_db(self): _db = self.params.get('db', self.options.get('DB', 1)) try: @@ -92,6 +131,13 @@ def get_connection_pool_class(self): def get_connection_pool_class_kwargs(self): return self.options.get('CONNECTION_POOL_CLASS_KWARGS', {}) + def get_master_client(self): + """ + Get the write server:port of the master cache + """ + cache = self.options.get('MASTER_CACHE', None) + return self.client_list[0] if cache is None else self.create_client(cache) + def create_client(self, server): kwargs = { 'db': self.db, @@ -172,33 +218,27 @@ def make_keys(self, keys, version=None): # Django cache api # #################### - def _add(self, client, key, value, timeout): - return self._set(key, value, timeout, client, _add_only=True) - - def add(self, key, value, timeout=None, version=None): - """ - Add a value to the cache, failing if the key already exists. + @get_client(write=True) + def add(self, client, key, value, timeout=None): + """Add a value to the cache, failing if the key already exists. Returns ``True`` if the object was added, ``False`` if not. """ - raise NotImplementedError + return self._set(client, key, self.prep_value(value), timeout, _add_only=True) + + @get_client() + def get(self, client, key, default=None): + """Retrieve a value from the cache. - def _get(self, client, key, default=None): + Returns deserialized value if key is found, the default if not. + """ value = client.get(key) if value is None: return default value = self.get_value(value) return value - def get(self, key, default=None, version=None): - """ - Retrieve a value from the cache. - - Returns unpickled value if key is found, the default if not. - """ - raise NotImplementedError - - def __set(self, client, key, value, timeout, _add_only=False): + def _set(self, client, key, value, timeout, _add_only=False): if timeout is None or timeout == 0: if _add_only: return client.setnx(key, value) @@ -213,36 +253,24 @@ def __set(self, client, key, value, timeout, _add_only=False): else: return False - def _set(self, key, value, timeout=DEFAULT_TIMEOUT, client=None, _add_only=False): - """ - Persist a value to the cache, and set an optional expiration time. + @get_client(write=True) + def set(self, client, key, value, timeout=DEFAULT_TIMEOUT): + """Persist a value to the cache, and set an optional expiration time. """ if timeout is DEFAULT_TIMEOUT: timeout = self.default_timeout + if timeout is not None: timeout = int(timeout) - # If ``value`` is not an int, then pickle it - if not isinstance(value, int) or isinstance(value, bool): - result = self.__set(client, key, pickle.dumps(value), timeout, _add_only) - else: - result = self.__set(client, key, value, timeout, _add_only) - # result is a boolean - return result - def set(self, key, value, timeout=None, version=None, client=None): - """ - Persist a value to the cache, and set an optional expiration time. - """ - raise NotImplementedError() + result = self._set(client, key, self.prep_value(value), timeout, _add_only=False) - def _delete(self, client, key): - return client.delete(key) + return result - def delete(self, key, version=None): - """ - Remove a key from the cache. - """ - raise NotImplementedError + @get_client(write=True) + def delete(self, client, key): + """Remove a key from the cache.""" + return client.delete(key) def _delete_many(self, client, keys): return client.delete(*keys) @@ -257,8 +285,7 @@ def _clear(self, client): return client.flushdb() def clear(self, version=None): - """ - Flush cache keys. + """Flush cache keys. If version is specified, all keys belonging the version's key namespace will be deleted. Otherwise, all keys will be deleted. @@ -266,9 +293,6 @@ def clear(self, version=None): raise NotImplementedError def _get_many(self, client, original_keys, versioned_keys): - """ - Retrieve many keys. - """ recovered_data = {} map_keys = dict(zip(versioned_keys, original_keys)) @@ -282,18 +306,14 @@ def _get_many(self, client, original_keys, versioned_keys): return recovered_data def get_many(self, keys, version=None): + """Retrieve many keys.""" raise NotImplementedError def _set_many(self, client, data): - new_data = {} - for key, value in data.items(): - new_data[key] = self.prep_value(value) - - return client.mset(new_data) + return client.mset(data) def set_many(self, data, timeout=None, version=None): - """ - Set a bunch of values in the cache at once from a dict of key/value + """Set a bunch of values in the cache at once from a dict of key/value pairs. This is much more efficient than calling set() multiple times. If timeout is given, that timeout will be used for the key; otherwise @@ -301,37 +321,32 @@ def set_many(self, data, timeout=None, version=None): """ raise NotImplementedError - def _incr(self, client, key, delta=1): + @get_client(write=True) + def incr(self, client, key, delta=1): + """Add delta to value in the cache. If the key does not exist, raise a + `ValueError` exception. + """ exists = client.exists(key) if not exists: raise ValueError("Key '%s' not found" % key) try: value = client.incr(key, delta) except redis.ResponseError: - value = self._get(client, key) + delta - self._set(client, key, value, timeout=None) + key = key._original_key + value = self.get(key) + delta + self.set(key, value, timeout=None) return value - def incr(self, key, delta=1, version=None): - """ - Add delta to value in the cache. If the key does not exist, raise a - ValueError exception. - """ - raise NotImplementedError - def _incr_version(self, client, old, new, delta, version): try: client.rename(old, new) except redis.ResponseError: raise ValueError("Key '%s' not found" % old._original_key) - return version + delta def incr_version(self, key, delta=1, version=None): - """ - Adds delta to the cache version for the supplied key. Returns the + """Adds delta to the cache version for the supplied key. Returns the new version. - """ raise NotImplementedError @@ -339,28 +354,22 @@ def incr_version(self, key, delta=1, version=None): # Extra api methods # ##################### - def _has_key(self, client, key, version=None): + @get_client() + def has_key(self, client, key): """Returns True if the key is in the cache and has not expired.""" - key = self.make_key(key, version=version) return client.exists(key) - def has_key(self, key, version=None): - raise NotImplementedError - - def _ttl(self, client, key): - """ - Returns the 'time-to-live' of a key. If the key is not volitile, i.e. - it has not set expiration, then the value returned is None. Otherwise, - the value is the number of seconds remaining. If the key does not exist, - 0 is returned. + @get_client() + def ttl(self, client, key): + """Returns the 'time-to-live' of a key. If the key is not volitile, + i.e. it has not set expiration, then the value returned is None. + Otherwise, the value is the number of seconds remaining. If the key + does not exist, 0 is returned. """ if client.exists(key): return client.ttl(key) return 0 - def ttl(self, key, version=None): - raise NotImplementedError - def _delete_pattern(self, client, pattern): keys = client.keys(pattern) if len(keys): @@ -369,26 +378,24 @@ def _delete_pattern(self, client, pattern): def delete_pattern(self, pattern, version=None): raise NotImplementedError - def _get_or_set(self, client, key, func, timeout=None): + @get_client(write=True) + def get_or_set(self, client, key, func, timeout=None): if not callable(func): - raise Exception("func must be a callable") + raise Exception("Must pass in a callable") dogpile_lock_key = "_lock" + key._versioned_key dogpile_lock = client.get(dogpile_lock_key) if dogpile_lock is None: - self._set(dogpile_lock_key, 0, None, client) + self.set(dogpile_lock_key, 0, None) value = func() - self.__set(client, key, self.prep_value(value), None) - self.__set(client, dogpile_lock_key, 0, timeout) + self._set(client, key, self.prep_value(value), None) + self._set(client, dogpile_lock_key, 0, timeout) else: - value = self._get(client, key) + value = self.get(key._original_key) return value - def get_or_set(self, key, func, timeout=None, version=None): - raise NotImplementedError - def _reinsert_keys(self, client): keys = client.keys('*') for key in keys: @@ -404,25 +411,21 @@ def reinsert_keys(self): """ raise NotImplementedError - def _persist(self, client, key, version=None): - if client.exists(key): - client.persist(key) + @get_client(write=True) + def persist(self, client, key): + """Remove the timeout on a key. - def persist(self, key): - """ - Remove the timeout on a key. Equivalent to setting a timeout - of None in a set command. - """ - raise NotImplementedError + Equivalent to setting a timeout of None in a set command. - def _expire(self, client, key, timeout, version=None): - if client.exists(key): - client.expire(key, timeout) + Returns True if successful and False if not. + """ + return client.persist(key) - def expire(self, key, timeout): + @get_client(write=True) + def expire(self, client, key, timeout): """ Set the expire time on a key - Will raise an exception if the key does not exist + returns True if successful and False if not. """ - raise NotImplementedError + return client.expire(key, timeout) diff --git a/redis_cache/backends/multiple.py b/redis_cache/backends/multiple.py index bc40eb13..d8bb21d1 100644 --- a/redis_cache/backends/multiple.py +++ b/redis_cache/backends/multiple.py @@ -11,98 +11,38 @@ class ShardedRedisCache(BaseRedisCache): def __init__(self, server, params): super(ShardedRedisCache, self).__init__(server, params) - self._params = params - self._server = server - self._pickle_version = None - self.__master_client = None - self.clients = {} self.sharder = HashRing() - if not isinstance(server, (list, tuple)): - servers = [server] - else: - servers = server - - for server in servers: + for server in self.servers: client = self.create_client(server) self.clients[client.connection_pool.connection_identifier] = client self.sharder.add(client.connection_pool.connection_identifier) - @property - def master_client(self): - """ - Get the write server:port of the master cache - """ - if not hasattr(self, '_master_client') and self.__master_client is None: - cache = self.options.get('MASTER_CACHE', None) - if cache is None: - self._master_client = None - else: - self._master_client = self.create_client(cache) - return self._master_client - - def get_client(self, key, for_write=False): - if for_write and self.master_client is not None: - return self.master_client + self.client_list = self.clients.values() + + + def get_client(self, key, write=False): node = self.sharder.get_node(unicode(key)) return self.clients[node] - def shard(self, keys, for_write=False, version=None): + def shard(self, keys, write=False, version=None): """ Returns a dict of keys that belong to a cache's keyspace. """ clients = defaultdict(list) for key in keys: - clients[self.get_client(key, for_write)].append(self.make_key(key, version)) + clients[self.get_client(key, write)].append(self.make_key(key, version)) return clients #################### # Django cache api # #################### - def add(self, key, value, timeout=None, version=None): - """ - Add a value to the cache, failing if the key already exists. - - Returns ``True`` if the object was added, ``False`` if not. - """ - client = self.get_client(key) - key = self.make_key(key, version=version) - return self._add(client, key, value, timeout) - - def get(self, key, default=None, version=None): - """ - Retrieve a value from the cache. - - Returns unpickled value if key is found, the default if not. - """ - client = self.get_client(key) - key = self.make_key(key, version=version) - - return self._get(client, key, default) - - def set(self, key, value, timeout=None, version=None, client=None): - """ - Persist a value to the cache, and set an optional expiration time. - """ - if client is None: - client = self.get_client(key, for_write=True) - key = self.make_key(key, version=version) - return self._set(key, value, timeout, client=client) - - def delete(self, key, version=None): - """ - Remove a key from the cache. - """ - client = self.get_client(key, for_write=True) - key = self.make_key(key, version=version) - return self._delete(client, key) - def delete_many(self, keys, version=None): """ Remove multiple keys at once. """ - clients = self.shard(keys, for_write=True, version=version) + clients = self.shard(keys, write=True, version=version) for client, keys in clients.items(): self._delete_many(client, keys) @@ -114,11 +54,8 @@ def clear(self, version=None): namespace will be deleted. Otherwise, all keys will be deleted. """ if version is None: - if self.master_client is None: - for client in self.clients.itervalues(): - self._clear(client) - else: - self._clear(self.master_client) + for client in self.clients.itervalues(): + self._clear(client) else: self.delete_pattern('*', version=version) @@ -138,31 +75,23 @@ def set_many(self, data, timeout=None, version=None): If timeout is given, that timeout will be used for the key; otherwise the default cache timeout will be used. """ - clients = self.shard(data.keys(), for_write=True, version=version) + clients = self.shard(data.keys(), write=True, version=version) if timeout is None: for client, keys in clients.items(): subset = {} for key in keys: - subset[key] = data[key._original_key] + subset[key] = self.prep_value(data[key._original_key]) self._set_many(client, subset) return for client, keys in clients.items(): pipeline = client.pipeline() for key in keys: - self._set(key, data[key._original_key], timeout, client=pipeline) + value = self.prep_value(data[key._original_key]) + self._set(pipeline, key, value, timeout) pipeline.execute() - def incr(self, key, delta=1, version=None): - """ - Add delta to value in the cache. If the key does not exist, raise a - ValueError exception. - """ - client = self.get_client(key, for_write=True) - key = self.make_key(key, version=version) - return self._incr(client, key, delta=delta) - def incr_version(self, key, delta=1, version=None): """ Adds delta to the cache version for the supplied key. Returns the @@ -172,7 +101,7 @@ def incr_version(self, key, delta=1, version=None): if version is None: version = self.version - client = self.get_client(key, for_write=True) + client = self.get_client(key, write=True) old = self.make_key(key, version=version) new = self.make_key(key, version=version + delta) @@ -182,27 +111,10 @@ def incr_version(self, key, delta=1, version=None): # Extra api methods # ##################### - def has_key(self, key, version=None): - client = self.get_client(key, for_write=False) - return self._has_key(client, key, version) - - def ttl(self, key, version=None): - client = self.get_client(key, for_write=False) - key = self.make_key(key, version=version) - return self._ttl(client, key) - def delete_pattern(self, pattern, version=None): pattern = self.make_key(pattern, version=version) - if self.master_client is None: - for client in self.clients.itervalues(): - self._delete_pattern(client, pattern) - else: - self._delete_pattern(self.master_client, pattern) - - def get_or_set(self, key, func, timeout=None, version=None): - client = self.get_client(key, for_write=True) - key = self.make_key(key, version=version) - return self._get_or_set(client, key, func, timeout) + for client in self.clients.itervalues(): + self._delete_pattern(client, pattern) def reinsert_keys(self): """ @@ -210,14 +122,3 @@ def reinsert_keys(self): """ for client in self.clients.itervalues(): self._reinsert_keys(client) - print - - def persist(self, key, version=None): - client = self.get_client(key, for_write=True) - key = self.make_key(key, version=version) - self._persist(client, key, version) - - def expire(self, key, timeout, version=None): - client = self.get_client(key, for_write=True) - key = self.make_key(key, version=version) - self._expire(client, key, timeout, version) diff --git a/redis_cache/backends/single.py b/redis_cache/backends/single.py index cb16aa35..fee1c837 100644 --- a/redis_cache/backends/single.py +++ b/redis_cache/backends/single.py @@ -2,6 +2,7 @@ import cPickle as pickle except ImportError: import pickle +import random from redis_cache.backends.base import BaseRedisCache from redis_cache.compat import bytes_type, DEFAULT_TIMEOUT @@ -15,75 +16,41 @@ def __init__(self, server, params): """ super(RedisCache, self).__init__(server, params) - if not isinstance(server, bytes_type): - self._server, = server + for server in self.servers: + client = self.create_client(server) + self.clients[client.connection_pool.connection_identifier] = client - self.client = self.create_client(server) - self.clients = { - self.client.connection_pool.connection_identifier: self.client - } + self.client_list = self.clients.values() + self.master_client = self.get_master_client() - def get_client(self, *args): - return self.client + def get_client(self, key, write=False): + if write and self.master_client is not None: + return self.master_client + return random.choice(self.client_list) #################### # Django cache api # #################### - def add(self, key, value, timeout=None, version=None): - """ - Add a value to the cache, failing if the key already exists. - - Returns ``True`` if the object was added, ``False`` if not. - """ - key = self.make_key(key, version=version) - return self._add(self.client, key, value, timeout) - - def get(self, key, default=None, version=None): - """ - Retrieve a value from the cache. - - Returns unpickled value if key is found, the default if not. - """ - key = self.make_key(key, version=version) - return self._get(self.client, key, default) - - def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None): - """ - Persist a value to the cache, and set an optional expiration time. - """ - key = self.make_key(key, version=version) - return self._set(key, value, timeout, client=self.client) - - def delete(self, key, version=None): - """ - Remove a key from the cache. - """ - key = self.make_key(key, version=version) - return self._delete(self.client, key) - def delete_many(self, keys, version=None): - """ - Remove multiple keys at once. - """ + """Remove multiple keys at once.""" versioned_keys = self.make_keys(keys, version=version) - self._delete_many(self.client, versioned_keys) + self._delete_many(self.master_client, versioned_keys) def clear(self, version=None): - """ - Flush cache keys. + """Flush cache keys. If version is specified, all keys belonging the version's key namespace will be deleted. Otherwise, all keys will be deleted. """ if version is None: - self._clear(self.client) + self._clear(self.master_client) else: self.delete_pattern('*', version=version) def get_many(self, keys, version=None): versioned_keys = self.make_keys(keys, version=version) - return self._get_many(self.client, keys, versioned_keys=versioned_keys) + return self._get_many(self.master_client, keys, versioned_keys=versioned_keys) def set_many(self, data, timeout=None, version=None): """ @@ -97,22 +64,15 @@ def set_many(self, data, timeout=None, version=None): if timeout is None: new_data = {} for key in versioned_keys: - new_data[key] = data[key._original_key] - return self._set_many(self.client, new_data) + new_data[key] = self.prep_value(data[key._original_key]) + return self._set_many(self.master_client, new_data) - pipeline = self.client.pipeline() + pipeline = self.master_client.pipeline() for key in versioned_keys: - self._set(key, data[key._original_key], timeout, client=pipeline) + value = self.prep_value(data[key._original_key]) + self._set(pipeline, key, value, timeout) pipeline.execute() - def incr(self, key, delta=1, version=None): - """ - Add delta to value in the cache. If the key does not exist, raise a - ValueError exception. - """ - key = self.make_key(key, version=version) - return self._incr(self.client, key, delta=delta) - def incr_version(self, key, delta=1, version=None): """ Adds delta to the cache version for the supplied key. Returns the @@ -125,37 +85,18 @@ def incr_version(self, key, delta=1, version=None): old = self.make_key(key, version) new = self.make_key(key, version=version + delta) - return self._incr_version(self.client, old, new, delta, version) + return self._incr_version(self.master_client, old, new, delta, version) ##################### # Extra api methods # ##################### - def has_key(self, key, version=None): - return self._has_key(self.client, key, version) - - def ttl(self, key, version=None): - key = self.make_key(key, version=version) - return self._ttl(self.client, key) - def delete_pattern(self, pattern, version=None): pattern = self.make_key(pattern, version=version) - self._delete_pattern(self.client, pattern) - - def get_or_set(self, key, func, timeout=None, version=None): - key = self.make_key(key, version=version) - return self._get_or_set(self.client, key, func, timeout) + self._delete_pattern(self.master_client, pattern) def reinsert_keys(self): """ Reinsert cache entries using the current pickle protocol version. """ - self._reinsert_keys(self.client) - - def persist(self, key, version=None): - key = self.make_key(key, version=version) - self._persist(self.client, key, version) - - def expire(self, key, timeout, version=None): - key = self.make_key(key, version=version) - self._expire(self.client, key, timeout, version) + self._reinsert_keys(self.master_client) diff --git a/setup.py b/setup.py index 81b3cf88..8e6ff989 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ url="http://github.com/sebleier/django-redis-cache/", author="Sean Bleier", author_email="sebleier@gmail.com", - version="1.0.1", + version="1.1.0", packages=["redis_cache", "redis_cache.backends"], description="Redis Cache Backend for Django", install_requires=['redis>=2.4.5'], diff --git a/tests/testapp/tests/master_slave_tests.py b/tests/testapp/tests/master_slave_tests.py index e25af27e..e9e264d6 100644 --- a/tests/testapp/tests/master_slave_tests.py +++ b/tests/testapp/tests/master_slave_tests.py @@ -21,7 +21,7 @@ @override_settings(CACHES={ 'default': { - 'BACKEND': 'redis_cache.ShardedRedisCache', + 'BACKEND': 'redis_cache.RedisCache', 'LOCATION': LOCATIONS, 'OPTIONS': { 'DB': 1,