From 15fcb705d4374cfa539fcd26c6fb48cab6d180c6 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 12:48:38 -0400 Subject: [PATCH 01/11] refresh tll logic --- redisvl/llmcache/semantic.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 5d20bf75..9432200c 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -195,9 +195,8 @@ def store( def _refresh_ttl(self, key: str): """Refreshes the TTL for the specified key.""" client = self._index.client - if client: - if self.ttl: - client.expire(key, self.ttl) + if client and self.ttl: + client.expire(key, self.ttl) else: raise RuntimeError("LLMCache is not connected to a Redis instance.") From 7d72775d67f0da8fd172daa41b58891c62c9336a Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 12:49:35 -0400 Subject: [PATCH 02/11] only refresh ttl if KNN matches are similar enough --- redisvl/llmcache/semantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 9432200c..e771f3d2 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -153,9 +153,9 @@ def check( cache_hits = [] for doc in results.docs: - self._refresh_ttl(doc.id) sim = similarity(doc.vector_distance) if sim > self.threshold: + self._refresh_ttl(doc.id) cache_hits.append(doc.response) return cache_hits From b202b93c94c7870acaa0e95722e3134b79f2ebb3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 12:54:55 -0400 Subject: [PATCH 03/11] convert vector to bytes string on store --- redisvl/llmcache/semantic.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index e771f3d2..611dd482 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -182,12 +182,10 @@ def store( if not key: key = self.hash_input(prompt) - if vector: - vector = array_to_buffer(vector) - else: + if not vector: vector = self._provider.embed(prompt) # type: ignore - payload = {"id": key, "prompt_vector": vector, "response": response} + payload = {"id": key, "prompt_vector": array_to_buffer(vector), "response": response} if metadata: payload.update(metadata) self._index.load([payload]) From 9dd752d20655c276cfe6f970df0120bf537120d0 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 13:00:26 -0400 Subject: [PATCH 04/11] adjust index load --- redisvl/index.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/redisvl/index.py b/redisvl/index.py index 6be401de..beb9cf16 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -254,16 +254,18 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs): containing the data to be indexed raises: redis.exceptions.ResponseError: If the index does not exist + # TODO -- should we return a count of the upserts? or some kind of metadata? """ - if not data: - return - if not isinstance(data, Iterable): - if not isinstance(data[0], dict): - raise TypeError("data must be an iterable of dictionaries") + if data: + if not isinstance(data, Iterable): + if not isinstance(data[0], dict): + raise TypeError("data must be an iterable of dictionaries") - for record in data: - key = f"{self._prefix}:{self._get_key_field(record)}" - self._redis_conn.hset(key, mapping=record) # type: ignore + pipe = self._redis_conn.pipeline(transaction=False) + for record in data: + key = f"{self._prefix}:{self._get_key_field(record)}" + pipe.hset(key, mapping=record) # type: ignore + pipe.execute() @check_connected("_redis_conn") def exists(self) -> bool: @@ -350,13 +352,13 @@ async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10): """ semaphore = asyncio.Semaphore(concurrency) - async def load(d: dict): + async def _load(d: dict): async with semaphore: key = f"{self._prefix}:{self._get_key_field(d)}" await self._redis_conn.hset(key, mapping=d) # type: ignore # gather with concurrency - await asyncio.gather(*[load(d) for d in data]) + await asyncio.gather(*[_load(d) for d in data]) @check_connected("_redis_conn") async def exists(self) -> bool: From 5bd6706c07cab6efc1f40a7284ba83a26ea92c06 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 13:05:45 -0400 Subject: [PATCH 05/11] set llmcache key ttl through load method --- redisvl/index.py | 4 ++++ redisvl/llmcache/semantic.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/redisvl/index.py b/redisvl/index.py index beb9cf16..fcbd5216 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -261,10 +261,14 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs): if not isinstance(data[0], dict): raise TypeError("data must be an iterable of dictionaries") + # Check if outer interface passes in TTL on load + ttl = kwargs.get("ttl") pipe = self._redis_conn.pipeline(transaction=False) for record in data: key = f"{self._prefix}:{self._get_key_field(record)}" pipe.hset(key, mapping=record) # type: ignore + if ttl: + pipe.expire(key, ttl) pipe.execute() @check_connected("_redis_conn") diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 611dd482..d1382e8a 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -179,16 +179,23 @@ def store( Raises: ValueError: If neither prompt nor vector is specified. """ + # Prepare LLMCache inputs if not key: key = self.hash_input(prompt) if not vector: vector = self._provider.embed(prompt) # type: ignore - payload = {"id": key, "prompt_vector": array_to_buffer(vector), "response": response} + payload = { + "id": key, + "prompt_vector": array_to_buffer(vector), + "response": response + } if metadata: payload.update(metadata) - self._index.load([payload]) + + # Load LLMCache entry with TTL + self._index.load([payload], ttl=self._ttl) def _refresh_ttl(self, key: str): """Refreshes the TTL for the specified key.""" From 4f4b64d9b363100aecd6339bdb0a2fb0c9f813c2 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 13:21:17 -0400 Subject: [PATCH 06/11] add method to clear LLMCache --- README.md | 4 ++-- redisvl/llmcache/base.py | 4 ++++ redisvl/llmcache/semantic.py | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 63a8343c..34ad8a0f 100644 --- a/README.md +++ b/README.md @@ -141,8 +141,8 @@ cache.store("What is the capital of France?", "Paris") cache.check("What is the capital of France?") ["Paris"] -# Cache will return the result if the query is similar enough -cache.get("What really is the capital of France?") +# Cache will still return the result if the query is similar enough +cache.check("What really is the capital of France?") ["Paris"] ``` diff --git a/redisvl/llmcache/base.py b/redisvl/llmcache/base.py index c0223ddd..e0064ea7 100644 --- a/redisvl/llmcache/base.py +++ b/redisvl/llmcache/base.py @@ -5,6 +5,10 @@ class BaseLLMCache: verbose: bool = True + def clear(self): + """Clear the LLMCache and create a new underlying index.""" + raise NotImplementedError + def check(self, prompt: str) -> Optional[List[str]]: raise NotImplementedError diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index d1382e8a..7dfe1d52 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -114,6 +114,10 @@ def set_threshold(self, threshold: float): raise ValueError("Threshold must be between 0 and 1.") self._threshold = float(threshold) + def clear(self): + """Clear the LLMCache and create a new underlying index.""" + self._index.create(overwrite=True) + def check( self, prompt: Optional[str] = None, From 040f6002dfd910fb2fdcd975d85975a3a5361e51 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 13:55:38 -0400 Subject: [PATCH 07/11] revert logic updated --- redisvl/llmcache/semantic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 7dfe1d52..0e82c80b 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -204,8 +204,9 @@ def store( def _refresh_ttl(self, key: str): """Refreshes the TTL for the specified key.""" client = self._index.client - if client and self.ttl: - client.expire(key, self.ttl) + if client: + if self.ttl: + client.expire(key, self.ttl) else: raise RuntimeError("LLMCache is not connected to a Redis instance.") From 5730ea8c4943149d779d73a5a72bdfb2b61019f8 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 31 Jul 2023 14:30:28 -0400 Subject: [PATCH 08/11] adjust llmcache tests --- redisvl/llmcache/semantic.py | 2 +- tests/integration/test_llmcache.py | 39 ++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 0e82c80b..aaac44b7 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -116,7 +116,7 @@ def set_threshold(self, threshold: float): def clear(self): """Clear the LLMCache and create a new underlying index.""" - self._index.create(overwrite=True) + self._index.delete(drop=True) def check( self, diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 3f8f0820..4e55ace6 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,5 +1,6 @@ import pytest +from time import sleep from redisvl.llmcache.semantic import SemanticCache from redisvl.providers import HuggingfaceProvider @@ -8,11 +9,13 @@ def provider(): return HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2") - @pytest.fixture def cache(provider): return SemanticCache(provider=provider, threshold=0.8) +@pytest.fixture +def cache_with_ttl(provider): + return SemanticCache(provider=provider, threshold=0.8, ttl=2) @pytest.fixture def vector(provider): @@ -27,16 +30,24 @@ def test_store_and_check(cache, vector): check_result = cache.check(vector=vector) assert len(check_result) >= 1 assert response in check_result - cache.index.delete(drop=True) + cache.clear() +def test_ttl(cache_with_ttl, vector): + # Check that TTL expiration kicks in after 2 seconds + prompt = "This is a test prompt." + response = "This is a test response." + cache_with_ttl.store(prompt, response, vector=vector) + sleep(3) + check_result = cache_with_ttl.check(vector=vector) + assert len(check_result) == 0 + cache_with_ttl.clear() def test_check_no_match(cache, vector): # Check behavior when there is no match in the cache # In this case, we're using a vector, but the cache is empty check_result = cache.check(vector=vector) assert len(check_result) == 0 - cache.index.delete(drop=True) - + cache.clear() def test_store_with_vector_and_metadata(cache, vector): # Test storing a response with a vector and metadata @@ -44,12 +55,26 @@ def test_store_with_vector_and_metadata(cache, vector): response = "This is another test response." metadata = {"source": "test"} cache.store(prompt, response, vector=vector, metadata=metadata) - cache.index.delete(drop=True) - + check_result = cache.check(vector=vector) + assert len(check_result) >= 1 + assert response in check_result + cache.clear() def test_set_threshold(cache): # Test the getter and setter for the threshold assert cache.threshold == 0.8 cache.set_threshold(0.9) assert cache.threshold == 0.9 - cache.index.delete(drop=True) + cache.clear() + +def test_from_existing(cache, vector, provider): + prompt = "This is another test prompt." + response = "This is another test response." + metadata = {"source": "test"} + cache.store(prompt, response, vector=vector, metadata=metadata) + # connect from existing? + new_cache = SemanticCache(provider=provider, threshold=0.8) + check_result = new_cache.check(vector=vector) + assert len(check_result) >= 1 + assert response in check_result + new_cache.clear() \ No newline at end of file From 634167cddd198e6640278c77653c3165c7e5238f Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 1 Aug 2023 14:13:49 -0400 Subject: [PATCH 09/11] update clear method and tests --- redisvl/llmcache/semantic.py | 12 +++++++++--- tests/integration/test_llmcache.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index aaac44b7..390ad7bb 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -115,8 +115,15 @@ def set_threshold(self, threshold: float): self._threshold = float(threshold) def clear(self): - """Clear the LLMCache and create a new underlying index.""" - self._index.delete(drop=True) + """Clear the LLMCache of all keys in the index""" + client = self._index.client + if client: + pipe = client.pipeline() + for key in client.scan_iter(match=f"{self._index._prefix}:*"): + pipe.delete(key) + pipe.execute() + else: + raise RuntimeError("LLMCache is not connected to a Redis instance.") def check( self, @@ -210,6 +217,5 @@ def _refresh_ttl(self, key: str): else: raise RuntimeError("LLMCache is not connected to a Redis instance.") - def similarity(distance: Union[float, str]) -> float: return 1 - float(distance) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 4e55ace6..f5656bc3 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -22,7 +22,7 @@ def vector(provider): return provider.embed("This is a test sentence.") -def test_store_and_check(cache, vector): +def test_store_and_check_and_clear(cache, vector): # Check that we can store and retrieve a response prompt = "This is a test prompt." response = "This is a test response." @@ -31,6 +31,9 @@ def test_store_and_check(cache, vector): assert len(check_result) >= 1 assert response in check_result cache.clear() + check_result = cache.check(vector=vector) + assert len(check_result) == 0 + cache._index.delete(True) def test_ttl(cache_with_ttl, vector): # Check that TTL expiration kicks in after 2 seconds @@ -40,14 +43,14 @@ def test_ttl(cache_with_ttl, vector): sleep(3) check_result = cache_with_ttl.check(vector=vector) assert len(check_result) == 0 - cache_with_ttl.clear() + cache_with_ttl._index.delete(True) def test_check_no_match(cache, vector): # Check behavior when there is no match in the cache # In this case, we're using a vector, but the cache is empty check_result = cache.check(vector=vector) assert len(check_result) == 0 - cache.clear() + cache._index.delete(True) def test_store_with_vector_and_metadata(cache, vector): # Test storing a response with a vector and metadata @@ -58,14 +61,14 @@ def test_store_with_vector_and_metadata(cache, vector): check_result = cache.check(vector=vector) assert len(check_result) >= 1 assert response in check_result - cache.clear() + cache._index.delete(True) def test_set_threshold(cache): # Test the getter and setter for the threshold assert cache.threshold == 0.8 cache.set_threshold(0.9) assert cache.threshold == 0.9 - cache.clear() + cache._index.delete(True) def test_from_existing(cache, vector, provider): prompt = "This is another test prompt." @@ -77,4 +80,4 @@ def test_from_existing(cache, vector, provider): check_result = new_cache.check(vector=vector) assert len(check_result) >= 1 assert response in check_result - new_cache.clear() \ No newline at end of file + new_cache._index.delete(True) \ No newline at end of file From 04f0869f62dbe012fd79b7b2c80218175cafae5f Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 1 Aug 2023 14:17:41 -0400 Subject: [PATCH 10/11] add ttl load for async index too: --- redisvl/index.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/redisvl/index.py b/redisvl/index.py index fcbd5216..eb7482aa 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -254,6 +254,7 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs): containing the data to be indexed raises: redis.exceptions.ResponseError: If the index does not exist + # TODO -- should we return a count of the upserts? or some kind of metadata? """ if data: @@ -344,7 +345,7 @@ async def delete(self, drop: bool = True): await self._redis_conn.ft(self._name).dropindex(delete_documents=drop) # type: ignore @check_connected("_redis_conn") - async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10): + async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10, **kwargs): """Load data into Redis and index using this SearchIndex object Args: @@ -354,12 +355,15 @@ async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10): raises: redis.exceptions.ResponseError: If the index does not exist """ + ttl = kwargs.get("ttl") semaphore = asyncio.Semaphore(concurrency) async def _load(d: dict): async with semaphore: key = f"{self._prefix}:{self._get_key_field(d)}" await self._redis_conn.hset(key, mapping=d) # type: ignore + if ttl: + await self._redis_conn.expire(key, ttl) # gather with concurrency await asyncio.gather(*[_load(d) for d in data]) From b15086a16d5914b1e33a9fa82b21b4109b51ef48 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 1 Aug 2023 16:54:36 -0400 Subject: [PATCH 11/11] fix todo in docstring --- redisvl/index.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/redisvl/index.py b/redisvl/index.py index eb7482aa..14c99bb9 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -254,9 +254,8 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs): containing the data to be indexed raises: redis.exceptions.ResponseError: If the index does not exist - - # TODO -- should we return a count of the upserts? or some kind of metadata? """ + # TODO -- should we return a count of the upserts? or some kind of metadata? if data: if not isinstance(data, Iterable): if not isinstance(data[0], dict):