From 26f9d24d6d8f8dca0852dbc08718da7765e57ebf Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Thu, 27 Mar 2025 17:03:40 -0700
Subject: [PATCH 1/8] Add batch_search to sync Index

---
 redisvl/index/index.py                 | 78 +++++++++++++++++++++++++-
 tests/integration/test_search_index.py | 49 ++++++++++++++++
 2 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index f52de03d..ee2b6e25 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import threading
+import time
 import warnings
 import weakref
 from typing import (
@@ -26,6 +27,8 @@
 
 import redis
 import redis.asyncio as aredis
+from redis.client import NEVER_DECODE
+from redis.commands.helpers import get_protocol_version  # type: ignore
 from redis.commands.search.indexDefinition import IndexDefinition
 
 from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
@@ -349,7 +352,7 @@ def client(self) -> Optional[redis.Redis]:
         return self.__redis_client
 
     @property
-    def _redis_client(self) -> Optional[redis.Redis]:
+    def _redis_client(self) -> redis.Redis:
         """
         Get a Redis client instance.
 
@@ -652,6 +655,79 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult":
         except Exception as e:
             raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
 
+    def batch_search(
+        self, queries: List[str], batch_size: int = 100, **query_params
+    ) -> List[List[Dict[str, Any]]]:
+        """Perform a search against the index for multiple queries.
+
+        This method takes a list of queries and returns a list of search results.
+        The results are returned in the same order as the queries.
+
+        Args:
+            queries (List[str]): The queries to search for.
+            batch_size (int, optional): The number of queries to search for at a time.
+                Defaults to 100.
+            query_params (dict, optional): The query parameters to pass to the search
+                for each query.
+
+        Returns:
+            List[List[Dict[str, Any]]]: The search results.
+        """
+        all_parsed = []
+        search = self._redis_client.ft(self.schema.index.name)
+        options = {}
+        if get_protocol_version(self._redis_client) not in ["3", 3]:
+            options[NEVER_DECODE] = True
+
+        for i in range(0, len(queries), batch_size):
+            batch_queries = queries[i : i + batch_size]
+            print("batch queries", batch_queries)
+
+            # redis-py doesn't support calling `search` in a pipeline,
+            # so we need to manually execute each command in a pipeline
+            # and parse the results
+            with self._redis_client.pipeline(transaction=False) as pipe:
+                batch_built_queries = []
+                for query in batch_queries:
+                    query_args, q = search._mk_query_args(  # type: ignore
+                        query, query_params=query_params
+                    )
+                    batch_built_queries.append(q)
+                    print("query", query_args, options)
+                    pipe.execute_command(
+                        "FT.SEARCH",
+                        *query_args,
+                        **options,
+                    )
+
+                st = time.time()
+                # One list of results per query
+                print("query stack", pipe.command_stack)
+                results = pipe.execute()
+                print("SUCCESS")
+
+                # We don't know how long each query took, so we'll use the total time
+                # for all queries in the batch as the duration for each query
+                duration = (time.time() - st) * 1000.0
+
+                for i, query_results in enumerate(results):
+                    _built_query = batch_built_queries[i]
+                    parsed_raw = search._parse_search(  # type: ignore
+                        query_results,
+                        query=_built_query,
+                        duration=duration,
+                    )
+                    parsed = process_results(
+                        parsed_raw,
+                        query=_built_query,
+                        storage_type=self.schema.index.storage_type,
+                    )
+                    # Create separate lists of parsed results for each query
+                    # passed in to the batch_search method, so that callers can
+                    # access the results for each query individually
+                    all_parsed.append(parsed)
+        return all_parsed
+
     def search(self, *args, **kwargs) -> "Result":
         """Perform a search against the index.
 
diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py
index 368c048a..fc270336 100644
--- a/tests/integration/test_search_index.py
+++ b/tests/integration/test_search_index.py
@@ -389,3 +389,52 @@ def test_search_index_validates_redis_modules(redis_url):
         index.create(overwrite=True, drop=True)
 
         mock_validate_sync_redis.assert_called_once()
+
+
+def test_batch_search(index):
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    index.load(data, id_field="id")
+
+    results = index.batch_search(["@test:{foo}", "@test:{bar}"])
+    assert len(results) == 2
+    assert results[0][0]["id"] == "rvl:1"
+    assert results[1][0]["id"] == "rvl:2"
+
+
+def test_batch_search_with_multiple_batches(index):
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    index.load(data, id_field="id")
+
+    results = index.batch_search(["@test:{foo}", "@test:{bar}"])
+    assert len(results) == 2
+    assert len(results[0]) == 1
+    assert len(results[1]) == 1
+
+    results = index.batch_search(
+        [
+            "@test:{foo}",
+            "@test:{bar}",
+            "@test:{baz}",
+            "@test:{foo}",
+            "@test:{bar}",
+            "@test:{baz}",
+        ],
+        batch_size=2,
+    )
+    assert len(results) == 6
+
+    # First (and only) result for the first query
+    assert results[0][0]["id"] == "rvl:1"
+
+    # Second (and only) result for the second query
+    assert results[1][0]["id"] == "rvl:2"
+
+    # Third query has no results
+    assert len(results[2]) == 0
+
+    # Then the pattern repeats
+    assert results[3][0]["id"] == "rvl:1"
+    assert results[4][0]["id"] == "rvl:2"
+    assert len(results[5]) == 0

From 9b88f04551e3c3eb9ff7c520a76e2f294c37147b Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Thu, 27 Mar 2025 17:06:53 -0700
Subject: [PATCH 2/8] Remove debug prints

---
 redisvl/index/index.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index ee2b6e25..0d20ff68 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -681,7 +681,6 @@ def batch_search(
 
         for i in range(0, len(queries), batch_size):
             batch_queries = queries[i : i + batch_size]
-            print("batch queries", batch_queries)
 
             # redis-py doesn't support calling `search` in a pipeline,
             # so we need to manually execute each command in a pipeline
@@ -693,7 +692,6 @@ def batch_search(
                         query, query_params=query_params
                     )
                     batch_built_queries.append(q)
-                    print("query", query_args, options)
                     pipe.execute_command(
                         "FT.SEARCH",
                         *query_args,
@@ -701,10 +699,7 @@ def batch_search(
                     )
 
                 st = time.time()
-                # One list of results per query
-                print("query stack", pipe.command_stack)
                 results = pipe.execute()
-                print("SUCCESS")
 
                 # We don't know how long each query took, so we'll use the total time
                 # for all queries in the batch as the duration for each query
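Usage sketch for the new sync method (not part of the diff). It assumes a connected SearchIndex over the schema the integration tests above use: a tag field "test" with key prefix "rvl". Each batch is queued on one non-transactional pipeline, so a batch of N queries costs a single network round trip instead of N:

    # Sketch only: `index` is assumed to be a connected SearchIndex,
    # as in the test fixtures above.
    index.create(overwrite=True, drop=True)
    index.load(
        [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}],
        id_field="id",
    )

    # One list of parsed result dicts per query, in query order.
    results = index.batch_search(["@test:{foo}", "@test:{bar}"], batch_size=100)
    assert results[0][0]["id"] == "rvl:1"
    assert results[1][0]["id"] == "rvl:2"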
From 980f51a3a8886e1396baeef636917f552d57e897 Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Thu, 27 Mar 2025 17:14:15 -0700
Subject: [PATCH 3/8] Add async batch_search

---
 redisvl/index/index.py                       | 69 ++++++++++++++++++++
 tests/integration/test_async_search_index.py | 51 +++++++++++++++
 tests/integration/test_search_index.py       |  2 +-
 3 files changed, 121 insertions(+), 1 deletion(-)

diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index 0d20ff68..5179435c 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -1282,6 +1282,75 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
         except Exception as e:
             raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
 
+    async def batch_search(
+        self, queries: List[str], batch_size: int = 100, **query_params
+    ) -> List[List[Dict[str, Any]]]:
+        """Perform a search against the index for multiple queries.
+
+        This method takes a list of queries and returns a list of search results.
+        The results are returned in the same order as the queries.
+
+        Args:
+            queries (List[str]): The queries to search for.
+            batch_size (int, optional): The number of queries to search for at a time.
+                Defaults to 100.
+            query_params (dict, optional): The query parameters to pass to the search
+                for each query.
+
+        Returns:
+            List[List[Dict[str, Any]]]: The search results.
+        """
+        all_parsed = []
+        client = await self._get_client()
+        search = client.ft(self.schema.index.name)
+        options = {}
+        if get_protocol_version(client) not in ["3", 3]:
+            options[NEVER_DECODE] = True
+
+        for i in range(0, len(queries), batch_size):
+            batch_queries = queries[i : i + batch_size]
+
+            # redis-py doesn't support calling `search` in a pipeline,
+            # so we need to manually execute each command in a pipeline
+            # and parse the results
+            async with client.pipeline(transaction=False) as pipe:
+                batch_built_queries = []
+                for query in batch_queries:
+                    query_args, q = search._mk_query_args(  # type: ignore
+                        query, query_params=query_params
+                    )
+                    batch_built_queries.append(q)
+                    pipe.execute_command(
+                        "FT.SEARCH",
+                        *query_args,
+                        **options,
+                    )
+
+                st = time.time()
+                results = await pipe.execute()
+
+                # We don't know how long each query took, so we'll use the total time
+                # for all queries in the batch as the duration for each query
+                duration = (time.time() - st) * 1000.0
+
+                for i, query_results in enumerate(results):
+                    _built_query = batch_built_queries[i]
+                    parsed_raw = search._parse_search(  # type: ignore
+                        query_results,
+                        query=_built_query,
+                        duration=duration,
+                    )
+                    parsed = process_results(
+                        parsed_raw,
+                        query=_built_query,
+                        storage_type=self.schema.index.storage_type,
+                    )
+                    # Create separate lists of parsed results for each query
+                    # passed in to the batch_search method, so that callers can
+                    # access the results for each query individually
+                    all_parsed.append(parsed)
+        return all_parsed
+
     async def search(self, *args, **kwargs) -> "Result":
         """Perform a search on this index.
 
diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py
index d1b42235..0ed931d6 100644
--- a/tests/integration/test_async_search_index.py
+++ b/tests/integration/test_async_search_index.py
@@ -436,3 +436,54 @@ async def test_async_search_index_validates_redis_modules(redis_url):
         await index.create(overwrite=True, drop=True)
 
         mock_validate_async_redis.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_batch_search(async_index):
+    await async_index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    await async_index.load(data, id_field="id")
+
+    results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
+    assert len(results) == 2
+    assert results[0][0]["id"] == "rvl:1"
+    assert results[1][0]["id"] == "rvl:2"
+
+
+@pytest.mark.asyncio
+async def test_batch_search_with_multiple_batches(async_index):
+    await async_index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    await async_index.load(data, id_field="id")
+
+    results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
+    assert len(results) == 2
+    assert len(results[0]) == 1
+    assert len(results[1]) == 1
+
+    results = await async_index.batch_search(
+        [
+            "@test:{foo}",
+            "@test:{bar}",
+            "@test:{baz}",
+            "@test:{foo}",
+            "@test:{bar}",
+            "@test:{baz}",
+        ],
+        batch_size=2,
+    )
+    assert len(results) == 6
+
+    # First (and only) result for the first query
+    assert results[0][0]["id"] == "rvl:1"
+
+    # Second (and only) result for the second query
+    assert results[1][0]["id"] == "rvl:2"
+
+    # Third query should have zero results because there is no baz
+    assert len(results[2]) == 0
+
+    # Then the pattern repeats
+    assert results[3][0]["id"] == "rvl:1"
+    assert results[4][0]["id"] == "rvl:2"
+    assert len(results[5]) == 0
diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py
index fc270336..1b210b1d 100644
--- a/tests/integration/test_search_index.py
+++ b/tests/integration/test_search_index.py
@@ -431,7 +431,7 @@ def test_batch_search_with_multiple_batches(index):
     # Second (and only) result for the second query
     assert results[1][0]["id"] == "rvl:2"
 
-    # Third query has no results
+    # Third query should have zero results because there is no baz
     assert len(results[2]) == 0
 
     # Then the pattern repeats
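The async variant mirrors the sync one, swapping in an async pipeline that is awaited once per batch. A sketch, assuming `async_index` is a connected AsyncSearchIndex over the same schema as the sync example above:

    async def run_batches(async_index):
        await async_index.create(overwrite=True, drop=True)
        await async_index.load(
            [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}],
            id_field="id",
        )
        # batch_size=1 forces two pipeline round trips; results still come
        # back as one list of parsed dicts per query, in query order.
        results = await async_index.batch_search(
            ["@test:{foo}", "@test:{bar}"], batch_size=1
        )
        assert results[0][0]["id"] == "rvl:1"
        assert results[1][0]["id"] == "rvl:2"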
From 64a1ba9bfdfab31fa63e5998c220b1eb2e384bed Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Thu, 27 Mar 2025 17:31:51 -0700
Subject: [PATCH 4/8] WIP on async batch_query

---
 redisvl/index/index.py                       | 48 ++++++++++++++------
 tests/integration/test_async_search_index.py | 13 ++++++
 2 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index 5179435c..de024fd3 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -1283,8 +1283,8 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
             raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
 
     async def batch_search(
-        self, queries: List[str], batch_size: int = 100, **query_params
-    ) -> List[List[Dict[str, Any]]]:
+        self, queries: List[BaseQuery], batch_size: int = 100, **query_params
+    ) -> List["Result"]:
         """Perform a search against the index for multiple queries.
 
         This method takes a list of queries and returns a list of search results.
         The results are returned in the same order as the queries.
 
         Args:
             queries (List[str]): The queries to search for.
             batch_size (int, optional): The number of queries to search for at a time.
                 Defaults to 100.
             query_params (dict, optional): The query parameters to pass to the search
                 for each query.
 
         Returns:
-            List[List[Dict[str, Any]]]: The search results.
+            List[Result]: The search results.
         """
-        all_parsed = []
+        all_results = []
         client = await self._get_client()
         search = client.ft(self.schema.index.name)
         options = {}
@@ -1340,16 +1340,8 @@ async def batch_search(
                 for i, query_results in enumerate(results):
                     _built_query = batch_built_queries[i]
                     parsed_raw = search._parse_search(  # type: ignore
                         query_results,
                         query=_built_query,
                         duration=duration,
                     )
-                    parsed = process_results(
-                        parsed_raw,
-                        query=_built_query,
-                        storage_type=self.schema.index.storage_type,
-                    )
-                    # Create separate lists of parsed results for each query
-                    # passed in to the batch_search method, so that callers can
-                    # access the results for each query individually
-                    all_parsed.append(parsed)
-        return all_parsed
+                    all_results.append(parsed_raw)
+        return all_results
 
     async def search(self, *args, **kwargs) -> "Result":
         """Perform a search on this index.
@@ -1367,6 +1359,34 @@ async def search(self, *args, **kwargs) -> "Result":
         except Exception as e:
             raise RedisSearchError(f"Error while searching: {str(e)}") from e
 
+    async def _batch_query(
+        self, queries: List[BaseQuery], batch_size: int = 100
+    ) -> List[List[Dict[str, Any]]]:
+        """Asynchronously execute a batch of queries and process results."""
+        results = await self.batch_search(queries, batch_size=batch_size)
+        all_parsed = []
+        for query, batch_results in zip(queries, results):
+            parsed = process_results(
+                batch_results,
+                query=query,
+                storage_type=self.schema.index.storage_type,
+            )
+            # Create separate lists of parsed results for each query
+            # passed in to the batch_search method, so that callers can
+            # access the results for each query individually
+            all_parsed.append(parsed)
+
+        return all_parsed
+
+    async def batch_query(
+        self, queries: List[BaseQuery], batch_size: int = 100
+    ) -> List[List[Dict[str, Any]]]:
+        """Asynchronously execute a batch of queries and process results."""
+        return await self._batch_query(
+            [query.query for query in queries],
+            batch_size=batch_size,
+        )
+
     async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Asynchronously execute a query and process results."""
         results = await self.search(query.query, query_params=query.params)
diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py
index 0ed931d6..d1d00e61 100644
--- a/tests/integration/test_async_search_index.py
+++ b/tests/integration/test_async_search_index.py
@@ -8,6 +8,7 @@
 from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
 from redisvl.index import AsyncSearchIndex
 from redisvl.query import VectorQuery
+from redisvl.query.query import FilterQuery
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
 
@@ -487,3 +488,15 @@ async def test_batch_search_with_multiple_batches(async_index):
     assert results[3][0]["id"] == "rvl:1"
     assert results[4][0]["id"] == "rvl:2"
     assert len(results[5]) == 0
+
+
+@pytest.mark.asyncio
+async def test_batch_query(async_index):
+    await async_index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    await async_index.load(data, id_field="id")
+
+    query = FilterQuery(filter_expression="@test:{foo}")
+    results = await async_index.batch_query([query])
+
+    assert len(results) == 1
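Note on this WIP: `batch_query` passes `[query.query for query in queries]` down to `_batch_query`, so each query's `params` never reach FT.SEARCH; patch 5 below fixes this by threading `(query.query, query.params)` tuples through `batch_search`. The intended call shape, which patch 5 keeps, looks like this (a sketch; `async_index` as assumed above):

    from redisvl.query.query import FilterQuery

    async def run_batch_query(async_index):
        queries = [FilterQuery(filter_expression="@test:{foo}")]
        # Returns one list of parsed result dicts per input query.
        parsed = await async_index.batch_query(queries, batch_size=100)
        assert parsed[0][0]["id"] == "rvl:1"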
From 039f259bab2032aea5bc30580ddfe421f0b028d1 Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Fri, 28 Mar 2025 17:07:38 -0700
Subject: [PATCH 5/8] Refactor batch search and add batch_query

---
 redisvl/index/index.py                       | 123 +++++++++++--------
 tests/integration/test_async_search_index.py |  91 +++++++++++---
 tests/integration/test_search_index.py       |  98 ++++++++++++---
 3 files changed, 223 insertions(+), 89 deletions(-)

diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index de024fd3..c4e5de62 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -14,6 +14,7 @@
     Iterable,
     List,
     Optional,
+    Tuple,
     Union,
 )
 
@@ -51,6 +52,14 @@
     {"name": "searchlight", "ver": 20810},
 ]
 
+SearchParams = Union[
+    Tuple[
+        Union[str, BaseQuery],
+        Union[Dict[str, Union[str, int, float, bytes]], None],
+    ],
+    Union[str, BaseQuery],
+]
+
 
 def process_results(
     results: "Result", query: BaseQuery, storage_type: StorageType
@@ -656,22 +665,23 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult":
         raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
 
     def batch_search(
-        self, queries: List[str], batch_size: int = 100, **query_params
-    ) -> List[List[Dict[str, Any]]]:
+        self,
+        queries: List[SearchParams],
+        batch_size: int = 10,
+    ) -> List["Result"]:
         """Perform a search against the index for multiple queries.
 
-        This method takes a list of queries and returns a list of search results.
-        The results are returned in the same order as the queries.
+        This method takes a list of queries, each optionally paired with query
+        params, and returns a list of Result objects, one per query. Results
+        are returned in the same order as the queries.
 
         Args:
-            queries (List[str]): The queries to search for.
-            batch_size (int, optional): The number of queries to search for at a time.
-                Defaults to 100.
-            query_params (dict, optional): The query parameters to pass to the search
-                for each query.
+            queries (List[SearchParams]): The queries to search for.
+            batch_size (int, optional): The number of queries to search for
+                at a time. Defaults to 10.
 
         Returns:
-            List[List[Dict[str, Any]]]: The search results.
+            List[Result]: The search results for each query.
         """
         all_parsed = []
         search = self._redis_client.ft(self.schema.index.name)
@@ -688,9 +698,14 @@ def batch_search(
             with self._redis_client.pipeline(transaction=False) as pipe:
                 batch_built_queries = []
                 for query in batch_queries:
-                    query_args, q = search._mk_query_args(  # type: ignore
-                        query, query_params=query_params
-                    )
+                    if isinstance(query, tuple):
+                        query_args, q = search._mk_query_args(  # type: ignore
+                            query[0], query_params=query[1]
+                        )
+                    else:
+                        query_args, q = search._mk_query_args(  # type: ignore
+                            query, query_params=None
+                        )
                     batch_built_queries.append(q)
                     pipe.execute_command(
                         "FT.SEARCH",
@@ -707,20 +722,13 @@ def batch_search(
 
                 for i, query_results in enumerate(results):
                     _built_query = batch_built_queries[i]
-                    parsed_raw = search._parse_search(  # type: ignore
+                    parsed_result = search._parse_search(  # type: ignore
                         query_results,
                         query=_built_query,
                         duration=duration,
                     )
-                    parsed = process_results(
-                        parsed_raw,
-                        query=_built_query,
-                        storage_type=self.schema.index.storage_type,
-                    )
-                    # Create separate lists of parsed results for each query
-                    # passed in to the batch_search method, so that callers can
-                    # access the results for each query individually
-                    all_parsed.append(parsed)
+                    # Return a parsed Result object for each query
+                    all_parsed.append(parsed_result)
         return all_parsed
 
     def search(self, *args, **kwargs) -> "Result":
@@ -740,6 +748,26 @@ def search(self, *args, **kwargs) -> "Result":
         except Exception as e:
             raise RedisSearchError(f"Error while searching: {str(e)}") from e
 
+    def batch_query(
+        self, queries: List[BaseQuery], batch_size: int = 10
+    ) -> List[List[Dict[str, Any]]]:
+        """Execute a batch of queries and process results."""
+        results = self.batch_search(
+            [(query.query, query.params) for query in queries], batch_size=batch_size
+        )
+        all_parsed = []
+        for query, batch_results in zip(queries, results):
+            parsed = process_results(
+                batch_results,
+                query=query,
+                storage_type=self.schema.index.storage_type,
+            )
+            # Create separate lists of parsed results for each query
+            # passed in to the batch_search method, so that callers can
+            # access the results for each query individually
+            all_parsed.append(parsed)
+        return all_parsed
+
     def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Execute a query and process results."""
         results = self.search(query.query, query_params=query.params)
@@ -1283,22 +1311,20 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
             raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
 
     async def batch_search(
-        self, queries: List[BaseQuery], batch_size: int = 100, **query_params
+        self, queries: List[SearchParams], batch_size: int = 10
     ) -> List["Result"]:
         """Perform a search against the index for multiple queries.
 
-        This method takes a list of queries and returns a list of search results.
-        The results are returned in the same order as the queries.
+        This method takes a list of queries and returns a list of Result objects
+        for each query. Results are returned in the same order as the queries.
 
         Args:
-            queries (List[str]): The queries to search for.
-            batch_size (int, optional): The number of queries to search for at a time.
-                Defaults to 100.
-            query_params (dict, optional): The query parameters to pass to the search
-                for each query.
+            queries (List[SearchParams]): The queries to search for.
+            batch_size (int, optional): The number of queries to search for
+                at a time. Defaults to 10.
 
         Returns:
-            List[Result]: The search results.
+            List[Result]: The search results for each query.
         """
         all_results = []
         client = await self._get_client()
@@ -1316,9 +1342,14 @@ async def batch_search(
             async with client.pipeline(transaction=False) as pipe:
                 batch_built_queries = []
                 for query in batch_queries:
-                    query_args, q = search._mk_query_args(  # type: ignore
-                        query, query_params=query_params
-                    )
+                    if isinstance(query, tuple):
+                        query_args, q = search._mk_query_args(  # type: ignore
+                            query[0], query_params=query[1]
+                        )
+                    else:
+                        query_args, q = search._mk_query_args(  # type: ignore
+                            query, query_params=None
+                        )
                     batch_built_queries.append(q)
                     pipe.execute_command(
                         "FT.SEARCH",
@@ -1335,12 +1366,13 @@ async def batch_search(
                 for i, query_results in enumerate(results):
                     _built_query = batch_built_queries[i]
-                    parsed_raw = search._parse_search(  # type: ignore
+                    parsed_result = search._parse_search(  # type: ignore
                         query_results,
                         query=_built_query,
                         duration=duration,
                     )
-                    all_results.append(parsed_raw)
+                    # Return a parsed Result object for each query
+                    all_results.append(parsed_result)
         return all_results
 
     async def search(self, *args, **kwargs) -> "Result":
@@ -1359,11 +1391,13 @@ async def search(self, *args, **kwargs) -> "Result":
         except Exception as e:
             raise RedisSearchError(f"Error while searching: {str(e)}") from e
 
-    async def _batch_query(
-        self, queries: List[BaseQuery], batch_size: int = 100
+    async def batch_query(
+        self, queries: List[BaseQuery], batch_size: int = 10
     ) -> List[List[Dict[str, Any]]]:
         """Asynchronously execute a batch of queries and process results."""
-        results = await self.batch_search(queries, batch_size=batch_size)
+        results = await self.batch_search(
+            [(query.query, query.params) for query in queries], batch_size=batch_size
+        )
         all_parsed = []
         for query, batch_results in zip(queries, results):
             parsed = process_results(
@@ -1378,15 +1412,6 @@ async def _batch_query(
 
         return all_parsed
 
-    async def batch_query(
-        self, queries: List[BaseQuery], batch_size: int = 100
-    ) -> List[List[Dict[str, Any]]]:
-        """Asynchronously execute a batch of queries and process results."""
-        return await self._batch_query(
-            [query.query for query in queries],
-            batch_size=batch_size,
-        )
-
     async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Asynchronously execute a query and process results."""
         results = await self.search(query.query, query_params=query.params)
diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py
index d1d00e61..edc6c01a 100644
--- a/tests/integration/test_async_search_index.py
+++ b/tests/integration/test_async_search_index.py
@@ -447,47 +447,81 @@ async def test_batch_search(async_index):
 
     results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
     assert len(results) == 2
-    assert results[0][0]["id"] == "rvl:1"
-    assert results[1][0]["id"] == "rvl:2"
+    assert results[0].total == 1
+    assert results[0].docs[0]["id"] == "rvl:1"
+    assert results[1].total == 1
+    assert results[1].docs[0]["id"] == "rvl:2"
 
 
+@pytest.mark.parametrize(
+    "queries",
+    [
+        [
+            [
+                FilterQuery(filter_expression="@test:{foo}"),
+                FilterQuery(filter_expression="@test:{bar}"),
+            ],
+            [
+                FilterQuery(filter_expression="@test:{foo}"),
+                FilterQuery(filter_expression="@test:{bar}"),
+                FilterQuery(filter_expression="@test:{baz}"),
+                FilterQuery(filter_expression="@test:{foo}"),
+                FilterQuery(filter_expression="@test:{bar}"),
+                FilterQuery(filter_expression="@test:{baz}"),
+            ],
+        ],
+        [
+            [
+                "@test:{foo}",
+                "@test:{bar}",
+            ],
+            [
+                "@test:{foo}",
+                "@test:{bar}",
+                "@test:{baz}",
+                "@test:{foo}",
+                "@test:{bar}",
+                "@test:{baz}",
+            ],
+        ],
+    ],
+)
 @pytest.mark.asyncio
-async def test_batch_search_with_multiple_batches(async_index):
+async def test_batch_search_with_multiple_batches(async_index, queries):
     await async_index.create(overwrite=True, drop=True)
     data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
     await async_index.load(data, id_field="id")
 
-    results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"])
+    results = await async_index.batch_search(queries[0])
     assert len(results) == 2
-    assert len(results[0]) == 1
-    assert len(results[1]) == 1
+    assert results[0].total == 1
+    assert results[0].docs[0]["id"] == "rvl:1"
+    assert results[1].total == 1
+    assert results[1].docs[0]["id"] == "rvl:2"
 
     results = await async_index.batch_search(
-        [
-            "@test:{foo}",
-            "@test:{bar}",
-            "@test:{baz}",
-            "@test:{foo}",
-            "@test:{bar}",
-            "@test:{baz}",
-        ],
+        queries[1],
         batch_size=2,
    )
     assert len(results) == 6
 
     # First (and only) result for the first query
-    assert results[0][0]["id"] == "rvl:1"
+    assert results[0].total == 1
+    assert results[0].docs[0]["id"] == "rvl:1"
 
     # Second (and only) result for the second query
-    assert results[1][0]["id"] == "rvl:2"
+    assert results[1].total == 1
+    assert results[1].docs[0]["id"] == "rvl:2"
 
     # Third query should have zero results because there is no baz
-    assert len(results[2]) == 0
+    assert results[2].total == 0
 
     # Then the pattern repeats
-    assert results[3][0]["id"] == "rvl:1"
-    assert results[4][0]["id"] == "rvl:2"
-    assert len(results[5]) == 0
+    assert results[3].total == 1
+    assert results[3].docs[0]["id"] == "rvl:1"
+    assert results[4].total == 1
+    assert results[4].docs[0]["id"] == "rvl:2"
+    assert results[5].total == 0
 
 
 @pytest.mark.asyncio
@@ -500,3 +534,20 @@ async def test_batch_query(async_index):
     results = await async_index.batch_query([query])
 
     assert len(results) == 1
+    assert results[0][0]["id"] == "rvl:1"
+
+
+@pytest.mark.asyncio
+async def test_batch_query_with_multiple_batches(async_index):
+    await async_index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    await async_index.load(data, id_field="id")
+
+    queries = [
+        FilterQuery(filter_expression="@test:{foo}"),
+        FilterQuery(filter_expression="@test:{bar}"),
+    ]
+    results = await async_index.batch_query(queries, batch_size=1)
+    assert len(results) == 2
+    assert results[0][0]["id"] == "rvl:1"
+    assert results[1][0]["id"] == "rvl:2"
diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py
index 1b210b1d..800f6a06 100644
--- a/tests/integration/test_search_index.py
+++ b/tests/integration/test_search_index.py
@@ -7,6 +7,7 @@
 from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
 from redisvl.index import SearchIndex
 from redisvl.query import VectorQuery
+from redisvl.query.query import FilterQuery
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
 
@@ -398,43 +399,100 @@ def test_batch_search(index):
 
     results = index.batch_search(["@test:{foo}", "@test:{bar}"])
     assert len(results) == 2
-    assert results[0][0]["id"] == "rvl:1"
-    assert results[1][0]["id"] == "rvl:2"
+    assert results[0].total == 1
+    assert results[0].docs[0]["id"] == "rvl:1"
+    assert results[1].total == 1
+    assert results[1].docs[0]["id"] == "rvl:2"
 
 
-def test_batch_search_with_multiple_batches(index):
+@pytest.mark.parametrize(
+    "queries",
+    [
+        [
+            [
+                FilterQuery(filter_expression="@test:{foo}"),
+                FilterQuery(filter_expression="@test:{bar}"),
+            ],
+            [
+                FilterQuery(filter_expression="@test:{foo}"),
+                FilterQuery(filter_expression="@test:{bar}"),
+                FilterQuery(filter_expression="@test:{baz}"),
+                FilterQuery(filter_expression="@test:{foo}"),
+                FilterQuery(filter_expression="@test:{bar}"),
+                FilterQuery(filter_expression="@test:{baz}"),
+            ],
+        ],
+        [
+            [
+                "@test:{foo}",
+                "@test:{bar}",
+            ],
+            [
+                "@test:{foo}",
+                "@test:{bar}",
+                "@test:{baz}",
+                "@test:{foo}",
+                "@test:{bar}",
+                "@test:{baz}",
+            ],
+        ],
+    ],
+)
+def test_batch_search_with_multiple_batches(index, queries):
     index.create(overwrite=True, drop=True)
     data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
     index.load(data, id_field="id")
 
-    results = index.batch_search(["@test:{foo}", "@test:{bar}"])
+    results = index.batch_search(queries[0])
     assert len(results) == 2
-    assert len(results[0]) == 1
-    assert len(results[1]) == 1
+    assert results[0].total == 1
+    assert results[0].docs[0]["id"] == "rvl:1"
+    assert results[1].total == 1
+    assert results[1].docs[0]["id"] == "rvl:2"
 
     results = index.batch_search(
-        [
-            "@test:{foo}",
-            "@test:{bar}",
-            "@test:{baz}",
-            "@test:{foo}",
-            "@test:{bar}",
-            "@test:{baz}",
-        ],
+        queries[1],
         batch_size=2,
     )
     assert len(results) == 6
 
     # First (and only) result for the first query
-    assert results[0][0]["id"] == "rvl:1"
+    assert results[0].docs[0]["id"] == "rvl:1"
 
     # Second (and only) result for the second query
-    assert results[1][0]["id"] == "rvl:2"
+    assert results[1].docs[0]["id"] == "rvl:2"
 
     # Third query should have zero results because there is no baz
-    assert len(results[2]) == 0
+    assert results[2].total == 0
 
     # Then the pattern repeats
-    assert results[3][0]["id"] == "rvl:1"
-    assert results[4][0]["id"] == "rvl:2"
-    assert len(results[5]) == 0
+    assert results[3].docs[0]["id"] == "rvl:1"
+    assert results[4].docs[0]["id"] == "rvl:2"
+    assert results[5].total == 0
+
+
+def test_batch_query(index):
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    index.load(data, id_field="id")
+
+    query = FilterQuery(filter_expression="@test:{foo}")
+    results = index.batch_query([query])
+
+    assert len(results) == 1
+    assert results[0][0]["id"] == "rvl:1"
+
+
+def test_batch_query_with_multiple_batches(index):
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
+    index.load(data, id_field="id")
+
+    queries = [
+        FilterQuery(filter_expression="@test:{foo}"),
+        FilterQuery(filter_expression="@test:{bar}"),
+    ]
+    results = index.batch_query(queries, batch_size=1)
+    assert len(results) == 2
+    assert results[0][0]["id"] == "rvl:1"
+    assert results[1][0]["id"] == "rvl:2"
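After this refactor, `batch_search` accepts SearchParams entries -- a raw query string, a `BaseQuery`, or a `(query, params)` tuple -- and returns raw redis-py `Result` objects, while `batch_query` is the high-level wrapper that runs `process_results` for you. A sketch under the same schema assumptions as the earlier examples:

    from redisvl.query.query import FilterQuery

    # Low level: one Result object per query, in query order.
    raw = index.batch_search(["@test:{foo}", ("@test:{bar}", None)])
    assert raw[0].total == 1
    assert raw[0].docs[0]["id"] == "rvl:1"
    assert raw[1].docs[0]["id"] == "rvl:2"

    # High level: one list of parsed result dicts per query.
    parsed = index.batch_query(
        [FilterQuery(filter_expression="@test:{foo}")], batch_size=10
    )
    assert parsed[0][0]["id"] == "rvl:1"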
From d8b1c190b895e789d36e854ce07e6679d03a5888 Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Fri, 28 Mar 2025 18:01:12 -0700
Subject: [PATCH 6/8] Try caching HuggingFace models with a primer

---
 .github/workflows/test.yml | 77 ++++++++++++++++++++++++++++++--------
 1 file changed, 61 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f2349633..b2ad2df3 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -17,14 +17,53 @@ env:
   POETRY_VERSION: "1.8.3"
 
 jobs:
+  prime-cache:
+    name: Prime HuggingFace Model Cache
+    runs-on: ubuntu-latest
+    env:
+      HF_HOME: ${{ github.workspace }}/hf_cache
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v3
+
+      - name: Cache HuggingFace Models
+        id: hf-cache
+        uses: actions/cache@v3
+        with:
+          path: hf_cache
+          key: ${{ runner.os }}-hf-cache
+
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+          cache: pip
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: ${{ env.POETRY_VERSION }}
+
+      - name: Install dependencies
+        run: |
+          poetry install --all-extras
+
+      - name: Run full test suite to prime cache
+        env:
+          HF_HOME: ${{ github.workspace }}/hf_cache
+        run: |
+          make test-all
+
   test:
     name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis ${{ matrix.redis-version }}]
     runs-on: ubuntu-latest
-
+    needs: prime-cache
+    env:
+      HF_HOME: ${{ github.workspace }}/hf_cache
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.9, '3.10', 3.11, 3.12, 3.13]
+        python-version: ['3.10', '3.11', 3.12, 3.13]
         connection: ['hiredis', 'plain']
         redis-version: ['6.2.6-v9', 'latest', '8.0-M03']
 
@@ -32,11 +71,17 @@ jobs:
       - name: Check out repository
        uses: actions/checkout@v3
 
+      - name: Cache HuggingFace Models
+        uses: actions/cache@v3
+        with:
+          path: hf_cache
+          key: ${{ runner.os }}-hf-cache
+
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
-          cache: 'pip'
+          cache: pip
 
      - name: Install Poetry
        uses: snok/install-poetry@v1
@@ -74,16 +119,16 @@ jobs:
          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
          VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
-          AZURE_OPENAI_API_KEY: ${{secrets.AZURE_OPENAI_API_KEY}}
-          AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}}
-          AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}}
-          OPENAI_API_VERSION: ${{secrets.OPENAI_API_VERSION}}
+          AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+          AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
+          AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }}
+          OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }}
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        run: |
          make test-all
 
-      - name: Run tests
+      - name: Run tests (alternate)
        if: matrix.connection != 'plain' || matrix.redis-version != 'latest'
        run: |
          make test
@@ -97,15 +142,15 @@ jobs:
          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
          VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
-          AZURE_OPENAI_API_KEY: ${{secrets.AZURE_OPENAI_API_KEY}}
-          AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}}
-          AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}}
-          OPENAI_API_VERSION: ${{secrets.OPENAI_API_VERSION}}
+          AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+          AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
+          AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }}
+          OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }}
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        run: |
          docker run -d --name redis -p 6379:6379 redis/redis-stack-server:latest
-          make test-notebooks 
+          make test-notebooks
 
  docs:
    runs-on: ubuntu-latest
@@ -117,17 +162,17 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: ${{ env.PYTHON_VERSION }}
-          cache: 'pip'
+          cache: pip
 
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: ${{ env.POETRY_VERSION }}
-      
+
      - name: Install dependencies
        run: |
          poetry install --all-extras
 
      - name: Build docs
        run: |
-          make docs-build
\ No newline at end of file
+          make docs-build

From f6b08b1a3022479a6c5920224d3cc409bb85152f Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Fri, 28 Mar 2025 18:08:22 -0700
Subject: [PATCH 7/8] WIP

---
 .github/workflows/test.yml | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index b2ad2df3..58d6f09a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -51,6 +51,18 @@ jobs:
      - name: Run full test suite to prime cache
        env:
          HF_HOME: ${{ github.workspace }}/hf_cache
+          OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }}
+          GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
+          GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
+          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
+          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
+          VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
+          AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+          AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
+          AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }}
+          OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }}
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        run: |
          make test-all
 
@@ -113,6 +125,7 @@ jobs:
      - name: Run tests
        if: matrix.connection == 'plain' && matrix.redis-version == 'latest'
        env:
+          HF_HOME: ${{ github.workspace }}/hf_cache
          OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }}
          GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
          GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
@@ -136,6 +149,7 @@ jobs:
      - name: Run notebooks
        if: matrix.connection == 'plain' && matrix.redis-version == 'latest'
        env:
+          HF_HOME: ${{ github.workspace }}/hf_cache
          OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }}
          GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
          GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}

From e36342542e82d21d1d38d7ea3810ea4c8bf758b7 Mon Sep
17 00:00:00 2001
From: Andrew Brookins
Date: Fri, 28 Mar 2025 18:19:22 -0700
Subject: [PATCH 8/8] WIP

---
 .github/workflows/test.yml | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 58d6f09a..ce226199 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -33,6 +33,11 @@ jobs:
          path: hf_cache
          key: ${{ runner.os }}-hf-cache
 
+      - name: Set HuggingFace token
+        run: |
+          mkdir -p ~/.huggingface
+          echo '{"token":"${{ secrets.HF_TOKEN }}"}' > ~/.huggingface/token
+
      - name: Set up Python 3.9
        uses: actions/setup-python@v4
        with:
@@ -48,8 +53,14 @@ jobs:
        run: |
          poetry install --all-extras
 
+      - name: Authenticate to Google Cloud
+        uses: google-github-actions/auth@v1
+        with:
+          credentials_json: ${{ secrets.GOOGLE_CREDENTIALS }}
+
      - name: Run full test suite to prime cache
        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          HF_HOME: ${{ github.workspace }}/hf_cache
          OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }}
          GCP_LOCATION: ${{ secrets.GCP_LOCATION }}