From c31ba8eec8933d54da50802076d264baa081795e Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Tue, 15 Aug 2023 22:30:02 -0700 Subject: [PATCH] Add FilterQuery --- redisvl/cli/index.py | 9 +-- redisvl/cli/main.py | 4 +- redisvl/cli/stats.py | 15 ++-- redisvl/cli/utils.py | 15 ++-- redisvl/query/__init__.py | 7 +- redisvl/query/filter.py | 7 +- redisvl/query/query.py | 82 ++++++++++++++++++++++ redisvl/schema.py | 2 +- redisvl/vectorize/text/huggingface.py | 6 +- redisvl/vectorize/text/openai.py | 18 ++--- tests/integration/test_query.py | 96 ++++++++++++++++---------- tests/integration/test_simple_async.py | 2 +- 12 files changed, 184 insertions(+), 79 deletions(-) diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py index 22ad782a..f2ec2805 100644 --- a/redisvl/cli/index.py +++ b/redisvl/cli/index.py @@ -1,11 +1,11 @@ import argparse import sys -from tabulate import tabulate from argparse import Namespace +from tabulate import tabulate from redisvl.cli.log import get_logger -from redisvl.cli.utils import create_redis_url, add_index_parsing_options +from redisvl.cli.utils import add_index_parsing_options, create_redis_url from redisvl.index import SearchIndex from redisvl.utils.connection import get_redis_connection from redisvl.utils.utils import convert_bytes, make_dict @@ -36,7 +36,7 @@ def __init__(self): "--format", help="Output format for info command", type=str, - default="rounded_outline" + default="rounded_outline", ) parser = add_index_parsing_options(parser) @@ -126,6 +126,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex: return index + def _display_in_table(index_info, output_format="rounded_outline"): print("\n") attributes = index_info.get("attributes", []) @@ -183,4 +184,4 @@ def _display_in_table(index_info, output_format="rounded_outline"): headers=headers, tablefmt=output_format, ) - ) \ No newline at end of file + ) diff --git a/redisvl/cli/main.py b/redisvl/cli/main.py index 065e66c9..2dcd16af 100644 --- a/redisvl/cli/main.py +++ b/redisvl/cli/main.py @@ -3,9 +3,8 @@ from redisvl.cli.index import Index from redisvl.cli.log import get_logger -from redisvl.cli.version import Version from redisvl.cli.stats import Stats - +from redisvl.cli.version import Version logger = get_logger(__name__) @@ -50,4 +49,3 @@ def version(self): def stats(self): Stats() exit(0) - diff --git a/redisvl/cli/stats.py b/redisvl/cli/stats.py index aca2847c..9861f9e3 100644 --- a/redisvl/cli/stats.py +++ b/redisvl/cli/stats.py @@ -1,13 +1,14 @@ import argparse import sys -from tabulate import tabulate from argparse import Namespace -from redisvl.cli.utils import create_redis_url, add_index_parsing_options +from tabulate import tabulate + +from redisvl.cli.log import get_logger +from redisvl.cli.utils import add_index_parsing_options, create_redis_url from redisvl.index import SearchIndex from redisvl.utils.connection import get_redis_connection -from redisvl.cli.log import get_logger logger = get_logger("[RedisVL]") STATS_KEYS = [ @@ -32,6 +33,7 @@ "vector_index_sz_mb", ] + class Stats: usage = "\n".join( [ @@ -43,11 +45,7 @@ def __init__(self): parser = argparse.ArgumentParser(usage=self.usage) parser.add_argument( - "-f", - "--format", - help="Output format", - type=str, - default="rounded_outline" + "-f", "--format", help="Output format", type=str, default="rounded_outline" ) parser = add_index_parsing_options(parser) args = parser.parse_args(sys.argv[2:]) @@ -57,7 +55,6 @@ def __init__(self): logger.error(e) exit(0) - def stats(self, args: Namespace): """Obtain stats about an index diff --git a/redisvl/cli/utils.py b/redisvl/cli/utils.py index 3799e3c1..7bfbefda 100644 --- a/redisvl/cli/utils.py +++ b/redisvl/cli/utils.py @@ -1,5 +1,5 @@ import os -from argparse import Namespace, ArgumentParser +from argparse import ArgumentParser, Namespace def create_redis_url(args: Namespace) -> str: @@ -18,20 +18,15 @@ def create_redis_url(args: Namespace) -> str: url += args.host + ":" + str(args.port) return url + def add_index_parsing_options(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument( - "-i", "--index", help="Index name", type=str, required=False - ) + parser.add_argument("-i", "--index", help="Index name", type=str, required=False) parser.add_argument( "-s", "--schema", help="Path to schema file", type=str, required=False ) parser.add_argument("--host", help="Redis host", type=str, default="localhost") parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379) - parser.add_argument( - "--user", help="Redis username", type=str, default="default" - ) + parser.add_argument("--user", help="Redis username", type=str, default="default") parser.add_argument("--ssl", help="Use SSL", action="store_true") - parser.add_argument( - "-a", "--password", help="Redis password", type=str, default="" - ) + parser.add_argument("-a", "--password", help="Redis password", type=str, default="") return parser diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index 1f9835c2..2a6cce6f 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -1,3 +1,6 @@ -from redisvl.query.query import VectorQuery +from redisvl.query.query import FilterQuery, VectorQuery -__all__ = ["VectorQuery"] +__all__ = [ + "VectorQuery", + "FilterQuery", +] diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 1d378c59..50ec52bb 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -117,6 +117,7 @@ class Geo(FilterField): field in a Redis index. """ + OPERATORS = { FilterOperator.EQ: "==", FilterOperator.NE: "!=", @@ -174,12 +175,13 @@ def __init__(self, longitude: float, latitude: float, unit: str = "km"): class GeoRadius(GeoSpec): """A GeoRadius is a GeoSpec representing a geographic radius""" + def __init__( self, longitude: float, latitude: float, radius: Optional[int] = 1, - unit: Optional[str] = "km" + unit: Optional[str] = "km", ): """Create a GeoRadius specification (GeoSpec) @@ -202,6 +204,7 @@ def get_args(self) -> List[Union[float, int, str]]: class Num(FilterField): """A Num is a FilterField representing a numeric field in a Redis index.""" + OPERATORS = { FilterOperator.EQ: "==", FilterOperator.NE: "!=", @@ -311,6 +314,7 @@ def __le__(self, other: str) -> "FilterExpression": class Text(FilterField): """A Text is a FilterField representing a text field in a Redis index.""" + OPERATORS = { FilterOperator.EQ: "==", FilterOperator.NE: "!=", @@ -399,6 +403,7 @@ class FilterExpression: ... filter_expression=filter, ... ) """ + def __init__( self, _filter: str = None, diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 6a962c1f..03163f26 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -23,6 +23,88 @@ def params(self) -> Dict[str, Any]: pass +class FilterQuery(BaseQuery): + def __init__( + self, + return_fields: List[str], + filter_expression: FilterExpression, + num_results: Optional[int] = 10, + params: Optional[Dict[str, Any]] = None, + ): + """Query for a filter expression. + + Args: + return_fields (List[str]): The fields to return. + filter_expression (FilterExpression): The filter expression to query for. + num_results (Optional[int], optional): The number of results to return. Defaults to 10. + params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + + Examples: + >>> from redisvl.query import FilterQuery + >>> from redisvl.query.filter import Tag + >>> t = Tag("brand") == "Nike" + >>> q = FilterQuery(return_fields=["brand", "price"], filter_expression=t) + """ + + super().__init__(return_fields, num_results) + self.set_filter(filter_expression) + self._params = params + + def __str__(self) -> str: + return " ".join([str(x) for x in self.query.get_args()]) + + def set_filter(self, filter_expression: FilterExpression): + """Set the filter for the query. + + Args: + filter_expression (FilterExpression): The filter to apply to the query. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + if not isinstance(filter_expression, FilterExpression): + raise TypeError( + "filter_expression must be of type redisvl.query.FilterExpression" + ) + self._filter = str(filter_expression) + + def get_filter(self) -> FilterExpression: + """Get the filter for the query. + + Returns: + FilterExpression: The filter for the query. + """ + return self._filter + + @property + def query(self) -> Query: + """Return a Redis-Py Query object representing the query. + + Returns: + redis.commands.search.query.Query: The query object. + """ + base_query = str(self._filter) + query = ( + Query(base_query) + .return_fields(*self._return_fields) + .paging(0, self._num_results) + .dialect(2) + ) + return query + + @property + def params(self) -> Dict[str, Any]: + """Return the parameters for the query. + + Returns: + Dict[str, Any]: The parameters for the query. + """ + return self._params + + class VectorQuery(BaseQuery): dtypes = { "float32": np.float32, diff --git a/redisvl/schema.py b/redisvl/schema.py index 986c0ca3..bc186de6 100644 --- a/redisvl/schema.py +++ b/redisvl/schema.py @@ -156,4 +156,4 @@ def read_schema(file_path: str): with open(fp, "r") as f: schema = yaml.safe_load(f) - return SchemaModel(**schema) \ No newline at end of file + return SchemaModel(**schema) diff --git a/redisvl/vectorize/text/huggingface.py b/redisvl/vectorize/text/huggingface.py index debf7c7f..ff762ade 100644 --- a/redisvl/vectorize/text/huggingface.py +++ b/redisvl/vectorize/text/huggingface.py @@ -80,9 +80,9 @@ def embed_many( TypeError: If the wrong input type is passed in for the test. """ if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") + raise TypeError("Must pass in a list of str values to embed.") + if len(texts) > 0 and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): diff --git a/redisvl/vectorize/text/openai.py b/redisvl/vectorize/text/openai.py index 91c50fef..64cc3f67 100644 --- a/redisvl/vectorize/text/openai.py +++ b/redisvl/vectorize/text/openai.py @@ -13,6 +13,7 @@ class OpenAITextVectorizer(BaseVectorizer): API key to be passed in the api_config dictionary. The API key can be obtained from https://api.openai.com/. """ + def __init__(self, model: str, api_config: Optional[Dict] = None): """Initialize the OpenAI vectorizer. @@ -45,14 +46,13 @@ def __init__(self, model: str, api_config: Optional[Dict] = None): def _set_model_dims(self) -> int: try: embedding = self._model_client.create( - input=["dimension test"], - engine=self._model + input=["dimension test"], engine=self._model )["data"][0]["embedding"] except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}") except openai.error.AuthenticationError as ae: raise ValueError(f"Error authenticating with the OpenAI API: {str(ae)}") - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") return len(embedding) @@ -87,9 +87,9 @@ def embed_many( TypeError: If the wrong input type is passed in for the test. """ if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") + raise TypeError("Must pass in a list of str values to embed.") + if len(texts) > 0 and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): @@ -164,9 +164,9 @@ async def aembed_many( TypeError: If the wrong input type is passed in for the test. """ if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") + raise TypeError("Must pass in a list of str values to embed.") + if len(texts) > 0 and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index b9795b35..4ea04de5 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -5,7 +5,7 @@ from redis.commands.search.result import Result from redisvl.index import SearchIndex -from redisvl.query import VectorQuery +from redisvl.query import FilterQuery, VectorQuery from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text data = [ @@ -91,6 +91,19 @@ } +vector_query = VectorQuery( + [0.1, 0.1, 0.5], + "user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], +) + +filter_query = FilterQuery( + return_fields=["user", "credit_score", "age", "job", "location"], + # this will get set everytime + filter_expression=Tag("credit_score") == "high", +) + + @pytest.fixture(scope="module") def index(): # construct a search index from the schema @@ -163,101 +176,112 @@ def test_simple_tag_filter(index): assert len(results.docs) == 4 +@pytest.fixture(params=[vector_query, filter_query], ids=["VectorQuery", "FilterQuery"]) +def query(request): + return request.param + + def filter_test( - index, _filter, expected_count, credit_check=None, age_range=None, location=None + query, + index, + _filter, + expected_count, + credit_check=None, + age_range=None, + location=None, ): """Utility function to test filters""" - v = VectorQuery( - [0.1, 0.1, 0.5], - "user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], - filter_expression=_filter, - ) + + # set the new filter + query.set_filter(_filter) + # print(str(v) + "\n") # to print the query - results = index.query(v) + results = index.search(query.query, query_params=query.params) if credit_check: - for doc in results: - assert doc["credit_score"] == credit_check + for doc in results.docs: + assert doc.credit_score == credit_check if age_range: - for doc in results: + for doc in results.docs: if len(age_range) == 3: - assert int(doc["age"]) != age_range[2] + assert int(doc.age) != age_range[2] elif age_range[1] < age_range[0]: - assert (int(doc["age"]) <= age_range[0]) or (int(doc["age"]) >= age_range[1]) + assert (int(doc.age) <= age_range[0]) or (int(doc.age) >= age_range[1]) else: - assert age_range[0] <= int(doc["age"]) <= age_range[1] + assert age_range[0] <= int(doc.age) <= age_range[1] if location: - for doc in results: - assert doc["location"] == location - assert len(results) == expected_count + for doc in results.docs: + assert doc.location == location + assert len(results.docs) == expected_count -def test_filters(index): +def test_filters(index, query): # Simple Tag Filter t = Tag("credit_score") == "high" - filter_test(index, t, 4, credit_check="high") + filter_test(query, index, t, 4, credit_check="high") # Simple Numeric Filter n1 = Num("age") >= 18 - filter_test(index, n1, 4, age_range=(18, 100)) + filter_test(query, index, n1, 4, age_range=(18, 100)) # intersection of rules n2 = (Num("age") >= 18) & (Num("age") < 100) - filter_test(index, n2, 3, age_range=(18, 99)) + filter_test(query, index, n2, 3, age_range=(18, 99)) # union n3 = (Num("age") < 18) | (Num("age") > 94) - filter_test(index, n3, 4, age_range=(95, 17)) + filter_test(query, index, n3, 4, age_range=(95, 17)) n4 = Num("age") != 18 - filter_test(index, n4, 6, age_range=(0, 0, 18)) + filter_test(query, index, n4, 6, age_range=(0, 0, 18)) # Geographic filters g = Geo("location") == GeoRadius(-122.4194, 37.7749, 1, unit="m") - filter_test(index, g, 3, location="-122.4194,37.7749") + filter_test(query, index, g, 3, location="-122.4194,37.7749") g = Geo("location") != GeoRadius(-122.4194, 37.7749, 1, unit="m") - filter_test(index, g, 4, location="-110.0839,37.3861") + filter_test(query, index, g, 4, location="-110.0839,37.3861") # Text filters t = Text("job") == "engineer" - filter_test(index, t, 2) + filter_test(query, index, t, 2) t = Text("job") != "engineer" - filter_test(index, t, 5) + filter_test(query, index, t, 5) t = Text("job") % "enginee*" - filter_test(index, t, 2) + filter_test(query, index, t, 2) -def test_filter_combinations(index): +def test_filter_combinations(index, query): # test combinations # intersection t = Tag("credit_score") == "high" text = Text("job") == "engineer" - filter_test(index, t & text, 2, credit_check="high") + filter_test(query, index, t & text, 2, credit_check="high") # union t = Tag("credit_score") == "high" text = Text("job") == "engineer" - filter_test(index, t | text, 4, credit_check="high") + filter_test(query, index, t | text, 4, credit_check="high") # union of negated expressions _filter = (Tag("credit_score") != "high") & (Text("job") != "engineer") - filter_test(index, _filter, 3) + filter_test(query, index, _filter, 3) # geo + text g = Geo("location") == GeoRadius(-122.4194, 37.7749, 1, unit="m") text = Text("job") == "engineer" - filter_test(index, g & text, 1, location="-122.4194,37.7749") + filter_test(query, index, g & text, 1, location="-122.4194,37.7749") # geo + text g = Geo("location") != GeoRadius(-122.4194, 37.7749, 1, unit="m") text = Text("job") == "engineer" - filter_test(index, g & text, 1, location="-110.0839,37.3861") + filter_test(query, index, g & text, 1, location="-110.0839,37.3861") # num + text + geo n = (Num("age") >= 18) & (Num("age") < 100) t = Text("job") != "engineer" g = Geo("location") == GeoRadius(-122.4194, 37.7749, 1, unit="m") - filter_test(index, n & t & g, 1, age_range=(18, 99), location="-122.4194,37.7749") + filter_test( + query, index, n & t & g, 1, age_range=(18, 99), location="-122.4194,37.7749" + ) diff --git a/tests/integration/test_simple_async.py b/tests/integration/test_simple_async.py index 1a2b4867..d97a293d 100644 --- a/tests/integration/test_simple_async.py +++ b/tests/integration/test_simple_async.py @@ -102,4 +102,4 @@ async def test_simple(async_client): print("Score:", doc.vector_distance) pprint(doc) - await index.delete() \ No newline at end of file + await index.delete()