diff --git a/README.md b/README.md
index 63a8343c..34ad8a0f 100644
--- a/README.md
+++ b/README.md
@@ -141,8 +141,8 @@
 cache.store("What is the capital of France?", "Paris")
 cache.check("What is the capital of France?")
 ["Paris"]
 
-# Cache will return the result if the query is similar enough
-cache.get("What really is the capital of France?")
+# Cache will still return the result if the query is similar enough
+cache.check("What really is the capital of France?")
 ["Paris"]
 ```
diff --git a/redisvl/index.py b/redisvl/index.py
index 6be401de..14c99bb9 100644
--- a/redisvl/index.py
+++ b/redisvl/index.py
@@ -255,15 +255,21 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs):
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
-        if not data:
-            return
-        if not isinstance(data, Iterable):
-            if not isinstance(data[0], dict):
-                raise TypeError("data must be an iterable of dictionaries")
-
-        for record in data:
-            key = f"{self._prefix}:{self._get_key_field(record)}"
-            self._redis_conn.hset(key, mapping=record)  # type: ignore
+        # TODO -- should we return a count of the upserts? or some kind of metadata?
+        if data:
+            # Validate the input: an iterable of dictionaries is expected
+            if not isinstance(data, Iterable):
+                raise TypeError("data must be an iterable of dictionaries")
+
+            # Honor an optional TTL passed through by the caller on load
+            ttl = kwargs.get("ttl")
+            pipe = self._redis_conn.pipeline(transaction=False)
+            for record in data:
+                key = f"{self._prefix}:{self._get_key_field(record)}"
+                pipe.hset(key, mapping=record)  # type: ignore
+                if ttl:
+                    pipe.expire(key, ttl)
+            pipe.execute()
 
     @check_connected("_redis_conn")
     def exists(self) -> bool:
@@ -338,7 +344,7 @@ async def delete(self, drop: bool = True):
         await self._redis_conn.ft(self._name).dropindex(delete_documents=drop)  # type: ignore
 
     @check_connected("_redis_conn")
-    async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10):
+    async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10, **kwargs):
         """Load data into Redis and index using this SearchIndex object
 
         Args:
@@ -348,15 +354,18 @@
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
+        ttl = kwargs.get("ttl")
         semaphore = asyncio.Semaphore(concurrency)
 
-        async def load(d: dict):
+        async def _load(d: dict):
             async with semaphore:
                 key = f"{self._prefix}:{self._get_key_field(d)}"
                 await self._redis_conn.hset(key, mapping=d)  # type: ignore
+                if ttl:
+                    await self._redis_conn.expire(key, ttl)
 
         # gather with concurrency
-        await asyncio.gather(*[load(d) for d in data])
+        await asyncio.gather(*[_load(d) for d in data])
 
     @check_connected("_redis_conn")
     async def exists(self) -> bool:
diff --git a/redisvl/llmcache/base.py b/redisvl/llmcache/base.py
index c0223ddd..e0064ea7 100644
--- a/redisvl/llmcache/base.py
+++ b/redisvl/llmcache/base.py
@@ -5,6 +5,10 @@ class BaseLLMCache:
     verbose: bool = True
 
+    def clear(self):
+        """Clear the LLMCache of all entries."""
+        raise NotImplementedError
+
     def check(self, prompt: str) -> Optional[List[str]]:
         raise NotImplementedError
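The synchronous `load` above now batches each record's `HSET` with an optional `EXPIRE` on a single pipeline, so TTL writes add no extra round trips. A minimal sketch of that write pattern in plain `redis-py`, assuming a local Redis instance and a made-up `doc:` key prefix (neither is part of this change):

```python
import redis

# Sketch of the pipelined HSET + EXPIRE pattern used by SearchIndex.load.
# Assumes Redis on localhost:6379; "doc" is a hypothetical prefix.
client = redis.Redis()
records = [{"id": "1", "response": "Paris"}, {"id": "2", "response": "Berlin"}]
ttl = 60  # seconds; when no TTL is given, the EXPIRE is skipped entirely

pipe = client.pipeline(transaction=False)
for record in records:
    key = f"doc:{record['id']}"
    pipe.hset(key, mapping=record)
    if ttl:
        pipe.expire(key, ttl)  # queued alongside the HSET in the same batch
pipe.execute()
```

`transaction=False` keeps the pipeline a plain command batch rather than a MULTI/EXEC transaction, which is all the loader needs here.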
diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py
index 5d20bf75..390ad7bb 100644
--- a/redisvl/llmcache/semantic.py
+++ b/redisvl/llmcache/semantic.py
@@ -114,6 +114,17 @@ def set_threshold(self, threshold: float):
             raise ValueError("Threshold must be between 0 and 1.")
         self._threshold = float(threshold)
 
+    def clear(self):
+        """Clear the LLMCache of all keys in the index."""
+        client = self._index.client
+        if client:
+            pipe = client.pipeline()
+            for key in client.scan_iter(match=f"{self._index._prefix}:*"):
+                pipe.delete(key)
+            pipe.execute()
+        else:
+            raise RuntimeError("LLMCache is not connected to a Redis instance.")
+
     def check(
         self,
         prompt: Optional[str] = None,
@@ -153,9 +164,9 @@ def check(
         cache_hits = []
 
         for doc in results.docs:
-            self._refresh_ttl(doc.id)
             sim = similarity(doc.vector_distance)
             if sim > self.threshold:
+                self._refresh_ttl(doc.id)
                 cache_hits.append(doc.response)
 
         return cache_hits
@@ -179,18 +190,23 @@ def store(
         Raises:
             ValueError: If neither prompt nor vector is specified.
         """
+        # Prepare LLMCache inputs
         if not key:
             key = self.hash_input(prompt)
 
-        if vector:
-            vector = array_to_buffer(vector)
-        else:
+        if not vector:
             vector = self._provider.embed(prompt)  # type: ignore
 
-        payload = {"id": key, "prompt_vector": vector, "response": response}
+        payload = {
+            "id": key,
+            "prompt_vector": array_to_buffer(vector),
+            "response": response
+        }
         if metadata:
             payload.update(metadata)
-        self._index.load([payload])
+
+        # Load LLMCache entry with TTL
+        self._index.load([payload], ttl=self._ttl)
 
     def _refresh_ttl(self, key: str):
         """Refreshes the TTL for the specified key."""
@@ -201,6 +217,5 @@ def _refresh_ttl(self, key: str):
         else:
             raise RuntimeError("LLMCache is not connected to a Redis instance.")
 
-
 def similarity(distance: Union[float, str]) -> float:
     return 1 - float(distance)
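The new `clear()` walks the cache's keyspace with `SCAN` and queues deletions on a pipeline, leaving the index definition itself intact. A rough standalone equivalent in plain `redis-py`, where the `llmcache:` prefix is an assumption standing in for the configured index prefix:

```python
import redis

# Sketch of the SCAN-and-delete pattern behind SemanticCache.clear().
# "llmcache" stands in for the index prefix; adjust to match your index.
client = redis.Redis()
pipe = client.pipeline()
for key in client.scan_iter(match="llmcache:*"):
    pipe.delete(key)  # queued; sent in one batch below
pipe.execute()
```

`scan_iter` walks the keyspace incrementally, so clearing a large cache avoids the server-blocking behavior `KEYS` would have.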
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 3f8f0820..f5656bc3 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -1,5 +1,6 @@
 import pytest
 
+from time import sleep
 from redisvl.llmcache.semantic import SemanticCache
 from redisvl.providers import HuggingfaceProvider
 
@@ -8,18 +9,20 @@ def provider():
     return HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")
 
-
 @pytest.fixture
 def cache(provider):
     return SemanticCache(provider=provider, threshold=0.8)
 
+@pytest.fixture
+def cache_with_ttl(provider):
+    return SemanticCache(provider=provider, threshold=0.8, ttl=2)
 
 @pytest.fixture
 def vector(provider):
     return provider.embed("This is a test sentence.")
 
 
-def test_store_and_check(cache, vector):
+def test_store_and_check_and_clear(cache, vector):
     # Check that we can store and retrieve a response
     prompt = "This is a test prompt."
     response = "This is a test response."
@@ -27,16 +30,27 @@ def test_store_and_check(cache, vector):
     check_result = cache.check(vector=vector)
     assert len(check_result) >= 1
     assert response in check_result
-    cache.index.delete(drop=True)
+    cache.clear()
+    check_result = cache.check(vector=vector)
+    assert len(check_result) == 0
+    cache._index.delete(True)
 
+def test_ttl(cache_with_ttl, vector):
+    # Check that entries expire once the 2 second TTL has elapsed
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    cache_with_ttl.store(prompt, response, vector=vector)
+    sleep(3)
+    check_result = cache_with_ttl.check(vector=vector)
+    assert len(check_result) == 0
+    cache_with_ttl._index.delete(True)
 
 def test_check_no_match(cache, vector):
     # Check behavior when there is no match in the cache
     # In this case, we're using a vector, but the cache is empty
     check_result = cache.check(vector=vector)
     assert len(check_result) == 0
-    cache.index.delete(drop=True)
-
+    cache._index.delete(True)
 
 def test_store_with_vector_and_metadata(cache, vector):
@@ -44,12 +58,26 @@ def test_store_with_vector_and_metadata(cache, vector):
     response = "This is another test response."
     metadata = {"source": "test"}
     cache.store(prompt, response, vector=vector, metadata=metadata)
-    cache.index.delete(drop=True)
-
+    check_result = cache.check(vector=vector)
+    assert len(check_result) >= 1
+    assert response in check_result
+    cache._index.delete(True)
 
 def test_set_threshold(cache):
     # Test the getter and setter for the threshold
     assert cache.threshold == 0.8
     cache.set_threshold(0.9)
     assert cache.threshold == 0.9
-    cache.index.delete(drop=True)
+    cache._index.delete(True)
+
+def test_from_existing(cache, vector, provider):
+    prompt = "This is another test prompt."
+    response = "This is another test response."
+    metadata = {"source": "test"}
+    cache.store(prompt, response, vector=vector, metadata=metadata)
+    # Connect to the same underlying index from a fresh cache instance
+    new_cache = SemanticCache(provider=provider, threshold=0.8)
+    check_result = new_cache.check(vector=vector)
+    assert len(check_result) >= 1
+    assert response in check_result
+    new_cache._index.delete(True)
\ No newline at end of file
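Taken together, the changes add an optional per-entry TTL and a `clear()` that empties the cache without dropping the index. A usage sketch of the surface these tests exercise, assuming a Redis instance is reachable with default connection settings:

```python
from time import sleep

from redisvl.llmcache.semantic import SemanticCache
from redisvl.providers import HuggingfaceProvider

provider = HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")
cache = SemanticCache(provider=provider, threshold=0.8, ttl=2)

cache.store("What is the capital of France?", "Paris")
print(cache.check("What really is the capital of France?"))  # ["Paris"]

# Entries written with a TTL age out on their own...
sleep(3)
print(cache.check("What is the capital of France?"))  # []

# ...and clear() removes all cached entries immediately, keeping the index.
cache.store("What is the capital of France?", "Paris")
cache.clear()
print(cache.check("What is the capital of France?"))  # []
```

Note that `check()` now refreshes an entry's TTL only on an actual hit (similarity above the threshold), so near-miss queries no longer keep unrelated entries alive.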