4 changes: 2 additions & 2 deletions README.md
@@ -141,8 +141,8 @@ cache.store("What is the capital of France?", "Paris")
 cache.check("What is the capital of France?")
 ["Paris"]

-# Cache will return the result if the query is similar enough
-cache.get("What really is the capital of France?")
+# Cache will still return the result if the query is similar enough
+cache.check("What really is the capital of France?")
 ["Paris"]
 ```
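For reference, the snippet above can be run end to end with the pieces this PR touches. A minimal sketch, assuming a Redis instance on the default local port and the same provider and threshold used by the integration tests below:

```python
# Minimal, self-contained version of the README example; the provider and
# threshold mirror tests/integration/test_llmcache.py, and a local Redis
# on the default port is assumed.
from redisvl.llmcache.semantic import SemanticCache
from redisvl.providers import HuggingfaceProvider

provider = HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")
cache = SemanticCache(provider=provider, threshold=0.8)

cache.store("What is the capital of France?", "Paris")
print(cache.check("What is the capital of France?"))         # ["Paris"]
print(cache.check("What really is the capital of France?"))  # ["Paris"]
```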

33 changes: 21 additions & 12 deletions redisvl/index.py
@@ -255,15 +255,21 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs):
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
-        if not data:
-            return
-        if not isinstance(data, Iterable):
-            if not isinstance(data[0], dict):
-                raise TypeError("data must be an iterable of dictionaries")
-
-        for record in data:
-            key = f"{self._prefix}:{self._get_key_field(record)}"
-            self._redis_conn.hset(key, mapping=record)  # type: ignore
-        # TODO -- should we return a count of the upserts? or some kind of metadata?
+        if data:
+            if not isinstance(data, Iterable):
+                if not isinstance(data[0], dict):
+                    raise TypeError("data must be an iterable of dictionaries")
+
+            # Check if outer interface passes in TTL on load
+            ttl = kwargs.get("ttl")
+            pipe = self._redis_conn.pipeline(transaction=False)
+            for record in data:
+                key = f"{self._prefix}:{self._get_key_field(record)}"
+                pipe.hset(key, mapping=record)  # type: ignore
+                if ttl:
+                    pipe.expire(key, ttl)
+            pipe.execute()

     @check_connected("_redis_conn")
     def exists(self) -> bool:
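With this change, SearchIndex.load() batches the HSETs into a single non-transactional pipeline and, when a ttl keyword is supplied, queues an EXPIRE for each key. A usage sketch; the records, the 60-second TTL, and the pre-built index object are assumptions for illustration:

```python
# Hypothetical call site for the new ttl kwarg; `index` is assumed to be a
# created and connected SearchIndex whose key field matches "id".
records = [
    {"id": "doc1", "text": "hello"},
    {"id": "doc2", "text": "world"},
]
index.load(records, ttl=60)  # one pipeline: HSET + EXPIRE 60 per record
```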
@@ -338,7 +344,7 @@ async def delete(self, drop: bool = True):
         await self._redis_conn.ft(self._name).dropindex(delete_documents=drop)  # type: ignore

     @check_connected("_redis_conn")
-    async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10):
+    async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10, **kwargs):
         """Load data into Redis and index using this SearchIndex object

         Args:
@@ -348,15 +354,18 @@ async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10):
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
+        ttl = kwargs.get("ttl")
         semaphore = asyncio.Semaphore(concurrency)

-        async def load(d: dict):
+        async def _load(d: dict):
             async with semaphore:
                 key = f"{self._prefix}:{self._get_key_field(d)}"
                 await self._redis_conn.hset(key, mapping=d)  # type: ignore
+                if ttl:
+                    await self._redis_conn.expire(key, ttl)

         # gather with concurrency
-        await asyncio.gather(*[load(d) for d in data])
+        await asyncio.gather(*[_load(d) for d in data])

     @check_connected("_redis_conn")
     async def exists(self) -> bool:
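The async path gets the same ttl keyword; each write is throttled by a semaphore of size concurrency, and the inner coroutine is renamed to _load so it no longer shadows the method itself. A sketch of driving it, where the connected async index object, the records, and the TTL value are assumed:

```python
import asyncio

# Hypothetical driver for the async SearchIndex.load(); `async_index` is
# assumed to be a connected async index, and the 60-second TTL is
# illustrative.
async def ingest(async_index, records):
    # At most 10 concurrent HSET/EXPIRE pairs (the default concurrency).
    await async_index.load(records, concurrency=10, ttl=60)

# asyncio.run(ingest(async_index, records))
```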
4 changes: 4 additions & 0 deletions redisvl/llmcache/base.py
@@ -5,6 +5,10 @@
 class BaseLLMCache:
     verbose: bool = True

+    def clear(self):
+        """Clear the LLMCache and create a new underlying index."""
+        raise NotImplementedError
+
     def check(self, prompt: str) -> Optional[List[str]]:
         raise NotImplementedError
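clear() now sits alongside check() in the abstract contract, so every concrete cache must implement it. A hypothetical minimal subclass, purely to illustrate the shape of the interface (InMemoryCache and its dict store are invented for this sketch and are not part of redisvl):

```python
from typing import Dict, List, Optional

from redisvl.llmcache.base import BaseLLMCache

# Invented illustration of the BaseLLMCache contract.
class InMemoryCache(BaseLLMCache):
    def __init__(self):
        self._store: Dict[str, List[str]] = {}

    def clear(self):
        # Drop all cached entries.
        self._store = {}

    def check(self, prompt: str) -> Optional[List[str]]:
        return self._store.get(prompt)
```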
29 changes: 22 additions & 7 deletions redisvl/llmcache/semantic.py
@@ -114,6 +114,17 @@ def set_threshold(self, threshold: float):
             raise ValueError("Threshold must be between 0 and 1.")
         self._threshold = float(threshold)

+    def clear(self):
+        """Clear the LLMCache of all keys in the index"""
+        client = self._index.client
+        if client:
+            pipe = client.pipeline()
+            for key in client.scan_iter(match=f"{self._index._prefix}:*"):
+                pipe.delete(key)
+            pipe.execute()
+        else:
+            raise RuntimeError("LLMCache is not connected to a Redis instance.")
+
     def check(
         self,
         prompt: Optional[str] = None,
@@ -153,9 +164,9 @@ def check(

         cache_hits = []
         for doc in results.docs:
-            self._refresh_ttl(doc.id)
             sim = similarity(doc.vector_distance)
             if sim > self.threshold:
+                self._refresh_ttl(doc.id)
                 cache_hits.append(doc.response)
         return cache_hits

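Moving _refresh_ttl inside the threshold branch means only genuine hits keep their keys alive. With similarity(distance) = 1 - distance (defined at the bottom of this file) and the 0.8 threshold used in the tests, the arithmetic works out as follows; the distances are illustrative:

```python
# Worked example of the hit/miss decision with the default 0.8 threshold.
threshold = 0.8
for distance in (0.15, 0.30):
    sim = 1 - float(distance)  # the similarity() helper in this module
    hit = sim > threshold
    print(f"distance={distance} -> similarity={sim:.2f} -> hit={hit}")
# distance=0.15 -> similarity=0.85 -> hit=True   (TTL refreshed)
# distance=0.30 -> similarity=0.70 -> hit=False  (TTL untouched)
```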
@@ -179,18 +190,23 @@ def store(
         Raises:
             ValueError: If neither prompt nor vector is specified.
         """
+        # Prepare LLMCache inputs
         if not key:
             key = self.hash_input(prompt)

-        if vector:
-            vector = array_to_buffer(vector)
-        else:
+        if not vector:
             vector = self._provider.embed(prompt)  # type: ignore

-        payload = {"id": key, "prompt_vector": vector, "response": response}
+        payload = {
+            "id": key,
+            "prompt_vector": array_to_buffer(vector),
+            "response": response
+        }
         if metadata:
             payload.update(metadata)
-        self._index.load([payload])
+
+        # Load LLMCache entry with TTL
+        self._index.load([payload], ttl=self._ttl)

     def _refresh_ttl(self, key: str):
         """Refreshes the TTL for the specified key."""
@@ -201,6 +217,5 @@ def _refresh_ttl(self, key: str):
         else:
             raise RuntimeError("LLMCache is not connected to a Redis instance.")

-
 def similarity(distance: Union[float, str]) -> float:
     return 1 - float(distance)
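Taken together: store() now embeds only when no vector is supplied, serializes the vector once the payload is built, and hands the configured TTL to index.load(); clear() deletes every key under the cache prefix while leaving the index definition in place. A lifecycle sketch under the same assumptions as before (local Redis, the provider from the tests):

```python
# Lifecycle sketch for the new TTL and clear() behavior; assumes a local
# Redis and the provider used in the integration tests below.
from redisvl.llmcache.semantic import SemanticCache
from redisvl.providers import HuggingfaceProvider

provider = HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")
cache = SemanticCache(provider=provider, threshold=0.8, ttl=2)

cache.store("What is the capital of France?", "Paris")  # HSET + EXPIRE 2
cache.clear()                 # SCAN + DELETE; the index definition survives
print(cache.check("What is the capital of France?"))    # []
```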
44 changes: 36 additions & 8 deletions tests/integration/test_llmcache.py
@@ -1,5 +1,6 @@
 import pytest

+from time import sleep
 from redisvl.llmcache.semantic import SemanticCache
 from redisvl.providers import HuggingfaceProvider

@@ -8,48 +9,75 @@
 def provider():
     return HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")

-
 @pytest.fixture
 def cache(provider):
     return SemanticCache(provider=provider, threshold=0.8)

+@pytest.fixture
+def cache_with_ttl(provider):
+    return SemanticCache(provider=provider, threshold=0.8, ttl=2)
+
 @pytest.fixture
 def vector(provider):
     return provider.embed("This is a test sentence.")

-
-def test_store_and_check(cache, vector):
+def test_store_and_check_and_clear(cache, vector):
     # Check that we can store and retrieve a response
     prompt = "This is a test prompt."
     response = "This is a test response."
     cache.store(prompt, response, vector=vector)
     check_result = cache.check(vector=vector)
     assert len(check_result) >= 1
     assert response in check_result
-    cache.index.delete(drop=True)
+    cache.clear()
+    check_result = cache.check(vector=vector)
+    assert len(check_result) == 0
+    cache._index.delete(True)
+
+def test_ttl(cache_with_ttl, vector):
+    # Check that TTL expiration kicks in after 2 seconds
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    cache_with_ttl.store(prompt, response, vector=vector)
+    sleep(3)
+    check_result = cache_with_ttl.check(vector=vector)
+    assert len(check_result) == 0
+    cache_with_ttl._index.delete(True)

 def test_check_no_match(cache, vector):
     # Check behavior when there is no match in the cache
     # In this case, we're using a vector, but the cache is empty
     check_result = cache.check(vector=vector)
     assert len(check_result) == 0
-    cache.index.delete(drop=True)
-
+    cache._index.delete(True)

 def test_store_with_vector_and_metadata(cache, vector):
     # Test storing a response with a vector and metadata
     prompt = "This is another test prompt."
     response = "This is another test response."
     metadata = {"source": "test"}
     cache.store(prompt, response, vector=vector, metadata=metadata)
-    cache.index.delete(drop=True)
+    check_result = cache.check(vector=vector)
+    assert len(check_result) >= 1
+    assert response in check_result
+    cache._index.delete(True)

 def test_set_threshold(cache):
     # Test the getter and setter for the threshold
     assert cache.threshold == 0.8
     cache.set_threshold(0.9)
     assert cache.threshold == 0.9
-    cache.index.delete(drop=True)
+    cache._index.delete(True)
+
+def test_from_existing(cache, vector, provider):
+    prompt = "This is another test prompt."
+    response = "This is another test response."
+    metadata = {"source": "test"}
+    cache.store(prompt, response, vector=vector, metadata=metadata)
+    # connect from existing?
+    new_cache = SemanticCache(provider=provider, threshold=0.8)
+    check_result = new_cache.check(vector=vector)
+    assert len(check_result) >= 1
+    assert response in check_result
+    new_cache._index.delete(True)