9 changes: 0 additions & 9 deletions conftest.py
@@ -95,15 +95,6 @@ def sample_data():
},
]

@pytest.fixture(scope="session")
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()

@pytest.fixture
def clear_db(redis):
redis.flushall()
6 changes: 3 additions & 3 deletions redisvl/redis/utils.py
@@ -21,11 +21,11 @@ def convert_bytes(data: Any) -> Any:
except:
return data
if isinstance(data, dict):
return dict(map(convert_bytes, data.items()))
return {convert_bytes(key): convert_bytes(value) for key, value in data.items()}
if isinstance(data, list):
return list(map(convert_bytes, data))
return [convert_bytes(item) for item in data]
if isinstance(data, tuple):
return map(convert_bytes, data)
return tuple(convert_bytes(item) for item in data)
return data


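A note on the convert_bytes change above: the old tuple branch returned the bare map object, a lazy one-shot iterator, rather than a tuple, which surprises callers expecting a concrete sequence. A minimal standalone sketch of the before/after behavior, using a simplified decoder rather than the library's actual helper:

# Standalone sketch (not the library helper): why returning map() from the tuple
# branch was a problem. Callers get a lazy iterator instead of a tuple.
def convert_old(data):
    if isinstance(data, tuple):
        return map(lambda b: b.decode("utf-8"), data)  # lazy map object
    return data

def convert_new(data):
    if isinstance(data, tuple):
        return tuple(b.decode("utf-8") for b in data)  # concrete tuple
    return data

old = convert_old((b"a", b"b"))
new = convert_new((b"a", b"b"))
print(type(old).__name__, type(new).__name__)  # map tuple
print(list(old), new)                           # ['a', 'b'] ('a', 'b')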
55 changes: 54 additions & 1 deletion tests/integration/test_llmcache.py
@@ -4,7 +4,8 @@

from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import HFTextVectorizer

from redisvl.index.index import SearchIndex
from collections import namedtuple

@pytest.fixture
def vectorizer():
@@ -18,6 +19,10 @@ def cache(vectorizer):
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index

@pytest.fixture
def cache_no_cleanup(vectorizer):
cache_instance = SemanticCache(vectorizer=vectorizer, distance_threshold=0.2)
yield cache_instance

@pytest.fixture
def cache_with_ttl(vectorizer):
@@ -26,6 +31,12 @@ def cache_with_ttl(vectorizer):
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index

@pytest.fixture
def cache_with_redis_client(vectorizer, client):
cache_instance = SemanticCache(vectorizer=vectorizer, redis_client=client, distance_threshold=0.2)
yield cache_instance
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index

# Test basic store and check functionality
def test_store_and_check(cache, vectorizer):
@@ -83,6 +94,10 @@ def test_check_invalid_input(cache):
with pytest.raises(TypeError):
cache.check(prompt="test", return_fields="bad value")

# Test handling invalid TTL values
def test_bad_ttl(cache):
with pytest.raises(ValueError):
cache.set_ttl(2.5)

# Test storing with metadata
def test_store_with_metadata(cache, vectorizer):
@@ -100,6 +115,16 @@ def test_store_with_metadata(cache, vectorizer):
assert check_result[0]["metadata"] == metadata
assert check_result[0]["prompt"] == prompt

# Test storing with invalid metadata
def test_store_with_invalid_metadata(cache, vectorizer):
prompt = "This is another test prompt."
response = "This is another test response."
metadata = namedtuple('metadata', 'source')(**{'source': 'test'})

vector = vectorizer.embed(prompt)

with pytest.raises(TypeError, match=r"If specified, cached metadata must be a dictionary."):
cache.store(prompt, response, vector=vector, metadata=metadata)

# Test setting and getting the distance threshold
def test_distance_threshold(cache):
@@ -110,6 +135,11 @@ def test_distance_threshold(cache):
assert cache.distance_threshold == new_threshold
assert cache.distance_threshold != initial_threshold

# Test out of range distance threshold
def test_distance_threshold_out_of_range(cache):
out_of_range_threshold = -1
with pytest.raises(ValueError):
cache.set_threshold(out_of_range_threshold)

# Test storing and retrieving multiple items
def test_multiple_items(cache, vectorizer):
@@ -130,3 +160,26 @@ def test_multiple_items(cache, vectorizer):
print(check_result, flush=True)
assert check_result[0]["response"] == expected_response
assert "metadata" not in check_result[0]

# Test retrieving underlying SearchIndex for the cache.
def test_get_index(cache):
assert isinstance(cache.index, SearchIndex)

# Test basic functionality with cache created with user-provided Redis client
def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
vector = vectorizer.embed(prompt)

cache_with_redis_client.store(prompt, response, vector=vector)
check_result = cache_with_redis_client.check(vector=vector)

assert len(check_result) == 1
print(check_result, flush=True)
assert response == check_result[0]["response"]
assert "metadata" not in check_result[0]

# Test deleting the cache
def test_delete(cache_no_cleanup, vectorizer):
cache_no_cleanup.delete()
assert not cache_no_cleanup.index.exists()
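For readers following the new cache_with_redis_client fixture: it covers the path where the caller supplies an existing Redis connection instead of letting SemanticCache create one. A rough usage sketch, assuming a local Redis at redis://localhost:6379; the model name passed to HFTextVectorizer is illustrative, and the tests' own vectorizer fixture is not shown in this diff:

import redis

from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import HFTextVectorizer

# Reuse an existing connection instead of letting the cache open its own.
client = redis.Redis.from_url("redis://localhost:6379")
vectorizer = HFTextVectorizer(model="sentence-transformers/all-mpnet-base-v2")  # model choice is illustrative

cache = SemanticCache(vectorizer=vectorizer, redis_client=client, distance_threshold=0.2)

vector = vectorizer.embed("What is the capital of France?")
cache.store("What is the capital of France?", "Paris", vector=vector)
print(cache.check(vector=vector))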
62 changes: 62 additions & 0 deletions tests/integration/test_search_results.py
@@ -0,0 +1,62 @@
import pytest

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag

@pytest.fixture
def filter_query():
return FilterQuery(
return_fields=None,
filter_expression=Tag("credit_score") == "high",
)

@pytest.fixture
def index(sample_data):
fields_spec = [
{"name": "credit_score", "type": "tag"},
{"name": "user", "type": "tag"},
{"name": "job", "type": "text"},
{"name": "age", "type": "numeric"},
{
"name": "user_embedding",
"type": "vector",
"attrs": {
"dims": 3,
"distance_metric": "cosine",
"algorithm": "flat",
"datatype": "float32",
},
},
]

json_schema = {
"index": {
"name": "user_index_json",
"prefix": "users_json",
"storage_type": "json",
},
"fields": fields_spec,
}

# construct a search index from the schema
index = SearchIndex.from_dict(json_schema)

# connect to local redis instance
index.connect("redis://localhost:6379")

# create the index (no data yet)
index.create(overwrite=True)

# Prepare and load the data
index.load(sample_data)

# run the test
yield index

# clean up
index.delete(drop=True)

def test_process_results_unpacks_json_properly(index, filter_query):
results = index.query(filter_query)
assert len(results) == 4
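Context for the new test: with storage_type "json", search hits come back from Redis as serialized JSON payloads rather than flat hash fields, so the library's result processing has to decode them into plain dictionaries, which is what this assertion exercises. A hedged sketch of the kind of unpacking involved; the key names here are illustrative, not the library's internals:

import json

# One raw search hit as it might come back for a JSON document (illustrative shape).
raw_hit = {"id": "users_json:1", "json": '{"credit_score": "high", "age": 18, "user": "john"}'}

# Merge the document id with the decoded JSON body into a flat dictionary.
processed = {"id": raw_hit["id"], **json.loads(raw_hit["json"])}
print(processed)  # {'id': 'users_json:1', 'credit_score': 'high', 'age': 18, 'user': 'john'}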
34 changes: 34 additions & 0 deletions tests/unit/test_async_search_index.py
@@ -3,6 +3,7 @@
from redisvl.index import AsyncSearchIndex
from redisvl.redis.utils import convert_bytes
from redisvl.schema import IndexSchema, StorageType
from redisvl.query import VectorQuery

fields = [{"name": "test", "type": "tag"}]

@@ -137,3 +138,36 @@ async def test_no_id_field(async_client, async_index):
# catch missing / invalid id_field
with pytest.raises(ValueError):
await async_index.load(bad_data, id_field="key")


@pytest.mark.asyncio
async def test_check_index_exists_before_delete(async_client, async_index):
async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)
with pytest.raises(ValueError):
await async_index.delete()

@pytest.mark.asyncio
async def test_check_index_exists_before_search(async_client, async_index):
async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)

query = VectorQuery(
[0.1, 0.1, 0.5],
"user_embedding",
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(ValueError):
await async_index.search(query.query, query_params=query.params)

@pytest.mark.asyncio
async def test_check_index_exists_before_info(async_client, async_index):
async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)

with pytest.raises(ValueError):
await async_index.info()
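The three tests above (and their sync counterparts in test_search_index.py below) pin the same contract: delete(), search(), and info() should raise ValueError once the index has been dropped, rather than surfacing a raw Redis error. A rough sketch of the kind of guard that implies, with illustrative names rather than the library's actual internals:

async def guarded_info(index):
    # Hypothetical guard mirroring what the tests require; not the real implementation.
    if not await index.exists():
        raise ValueError(f"Index {index.name} does not exist")
    return await index.info()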
45 changes: 44 additions & 1 deletion tests/unit/test_search_index.py
@@ -3,6 +3,7 @@
from redisvl.index import SearchIndex
from redisvl.redis.utils import convert_bytes
from redisvl.schema import IndexSchema, StorageType
from redisvl.query import VectorQuery

fields = [{"name": "test", "type": "tag"}]

@@ -11,11 +12,13 @@
def index_schema():
return IndexSchema.from_dict({"index": {"name": "my_index"}, "fields": fields})


@pytest.fixture
def index(index_schema):
return SearchIndex(schema=index_schema)

@pytest.fixture
def index_from_yaml():
return SearchIndex.from_yaml("schemas/test_json_schema.yaml")

def test_search_index_properties(index_schema, index):
assert index.schema == index_schema
@@ -28,6 +31,13 @@ def test_search_index_properties(index_schema, index):
assert index.storage_type == index_schema.index.storage_type == StorageType.HASH
assert index.key("foo").startswith(index.prefix)

def test_search_index_from_yaml(index_from_yaml):
assert index_from_yaml.name == "json-test"
assert index_from_yaml.client == None
assert index_from_yaml.prefix == "json"
assert index_from_yaml.key_separator == ":"
assert index_from_yaml.storage_type == StorageType.JSON
assert index_from_yaml.key("foo").startswith(index_from_yaml.prefix)

def test_search_index_no_prefix(index_schema):
# specify an explicitly empty prefix...
@@ -118,3 +128,36 @@ def test_no_id_field(client, index):
# catch missing / invalid id_field
with pytest.raises(ValueError):
index.load(bad_data, id_field="key")

def test_check_index_exists_before_delete(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)
with pytest.raises(ValueError):
index.delete()

def test_check_index_exists_before_search(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)

query = VectorQuery(
[0.1, 0.1, 0.5],
"user_embedding",
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(ValueError):
index.search(query.query, query_params=query.params)

def test_check_index_exists_before_info(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)

with pytest.raises(ValueError):
index.info()

def test_index_needs_valid_schema():
with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"):
index = SearchIndex(schema="Not A Valid Schema")
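For reference, test_search_index_from_yaml (added near the top of this file) loads schemas/test_json_schema.yaml, which is not part of this diff. Judging from the assertions, an equivalent schema expressed through from_dict would look roughly like this; only the index block is pinned down by the test, and the fields entry is a placeholder:

from redisvl.index import SearchIndex

# Equivalent schema via from_dict; the "fields" entry is a placeholder since the
# YAML's actual field definitions are not shown in this diff.
index = SearchIndex.from_dict({
    "index": {
        "name": "json-test",
        "prefix": "json",
        "key_separator": ":",
        "storage_type": "json",
    },
    "fields": [{"name": "test", "type": "tag"}],
})
assert index.name == "json-test" and index.prefix == "json"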