Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds optional sorting parameter to FilterQuery #148

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
return_fields: Optional[List[str]] = None,
num_results: int = 10,
dialect: int = 2,
sort_by: Optional[str] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's also add this support to vector queries too

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added to Vector and Range queries also

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

params: Optional[Dict[str, Any]] = None,
):
"""A query for a running a filtered search with a filter expression.
Expand All @@ -146,6 +147,8 @@ def __init__(
return_fields (Optional[List[str]], optional): The fields to return.
num_results (Optional[int], optional): The number of results to
return. Defaults to 10.
sort_by (Optional[str]): The field to order the results by. Defaults
to None. Results will be ordered by vector distance.
params (Optional[Dict[str, Any]], optional): The parameters for the
query. Defaults to None.

Expand All @@ -164,6 +167,7 @@ def __init__(
"""
super().__init__(return_fields, num_results, dialect)
self.set_filter(filter_expression)
self._sort_by = sort_by
self._params = params or {}

@property
Expand All @@ -180,6 +184,8 @@ def query(self) -> Query:
.paging(self._first, self._limit)
.dialect(self._dialect)
)
if self._sort_by:
query = query.sort_by(self._sort_by)
return query


Expand All @@ -201,12 +207,14 @@ def __init__(
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
):
super().__init__(return_fields, num_results, dialect)
self.set_filter(filter_expression)
self._vector = vector
self._field = vector_field_name
self._dtype = dtype.lower()
self._sort_by = sort_by

if return_score:
self._return_fields.append(self.DISTANCE_ID)
Expand All @@ -223,6 +231,7 @@ def __init__(
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
):
"""A query for running a vector search along with an optional filter
expression.
Expand All @@ -243,6 +252,8 @@ def __init__(
distance. Defaults to True.
dialect (int, optional): The RediSearch query dialect.
Defaults to 2.
sort_by (Optional[str]): The field to order the results by. Defaults
to None. Results will be ordered by vector distance.

Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
Expand All @@ -259,6 +270,7 @@ def __init__(
num_results,
return_score,
dialect,
sort_by,
)

@property
Expand All @@ -272,10 +284,13 @@ def query(self) -> Query:
query = (
Query(base_query)
.return_fields(*self._return_fields)
.sort_by(self.DISTANCE_ID)
.paging(self._first, self._limit)
.dialect(self._dialect)
)
if self._sort_by:
query = query.sort_by(self._sort_by)
else:
query = query.sort_by(self.DISTANCE_ID)
return query

@property
Expand Down Expand Up @@ -307,6 +322,7 @@ def __init__(
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
):
"""A query for running a filtered vector search based on semantic
distance threshold.
Expand All @@ -330,7 +346,8 @@ def __init__(
distance. Defaults to True.
dialect (int, optional): The RediSearch query dialect.
Defaults to 2.

sort_by (Optional[str]): The field to order the results by. Defaults
to None. Results will be ordered by vector distance.
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression

Expand All @@ -347,6 +364,7 @@ def __init__(
num_results,
return_score,
dialect,
sort_by,
)
self.set_distance_threshold(distance_threshold)

Expand Down Expand Up @@ -390,10 +408,13 @@ def query(self) -> Query:
query = (
Query(base_query)
.return_fields(*self._return_fields)
.sort_by(self.DISTANCE_ID)
.paging(self._first, self._limit)
.dialect(self._dialect)
)
if self._sort_by:
query = query.sort_by(self._sort_by)
else:
query = query.sort_by(self.DISTANCE_ID)
return query

@property
Expand Down
61 changes: 61 additions & 0 deletions tests/integration/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ def vector_query():
)


@pytest.fixture
def sorted_vector_query():
return VectorQuery(
vector=[0.1, 0.1, 0.5],
vector_field_name="user_embedding",
return_fields=["user", "credit_score", "age", "job", "location"],
sort_by="age",
)


@pytest.fixture
def filter_query():
return FilterQuery(
Expand All @@ -26,6 +36,15 @@ def filter_query():
)


@pytest.fixture
def sorted_filter_query():
return FilterQuery(
return_fields=["user", "credit_score", "age", "job", "location"],
filter_expression=Tag("credit_score") == "high",
sort_by="age",
)


@pytest.fixture
def range_query():
return RangeQuery(
Expand All @@ -36,6 +55,17 @@ def range_query():
)


@pytest.fixture
def sorted_range_query():
return RangeQuery(
vector=[0.1, 0.1, 0.5],
vector_field_name="user_embedding",
return_fields=["user", "credit_score", "age", "job", "location"],
distance_threshold=0.2,
sort_by="age",
)


@pytest.fixture
def index(sample_data, redis_url):
# construct a search index from the schema
Expand Down Expand Up @@ -160,6 +190,7 @@ def search(
age_range=None,
location=None,
distance_threshold=0.2,
sort=False,
):
"""Utility function to test filters."""

Expand Down Expand Up @@ -199,6 +230,21 @@ def search(
else:
assert len(results.docs) == expected_count

# check results are in sorted order
if sort:
if isinstance(query, RangeQuery):
assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100]
else:
assert [int(doc.age) for doc in results.docs] == [
12,
14,
15,
18,
35,
94,
100,
]


@pytest.fixture(
params=["vector_query", "filter_query", "range_query"],
Expand Down Expand Up @@ -339,3 +385,18 @@ def test_paginate_range_query(index, range_query):
assert len(all_results) == expected_count
assert i == expected_iterations
assert all(float(item["vector_distance"]) <= 0.2 for item in all_results)


def test_sort_filter_query(index, sorted_filter_query):
t = Text("job") % ""
search(sorted_filter_query, index, t, 7, sort=True)


def test_sort_vector_query(index, sorted_vector_query):
t = Text("job") % ""
search(sorted_vector_query, index, t, 7, sort=True)


def test_sort_range_query(index, sorted_range_query):
t = Text("job") % ""
search(sorted_range_query, index, t, 7, sort=True)
31 changes: 31 additions & 0 deletions tests/unit/test_query_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_filter_query():
assert isinstance(filter_query.params, dict)
assert filter_query.params == {}
assert filter_query._dialect == 2
assert filter_query._sort_by == None

# Test set_filter functionality
new_filter_expression = Tag("category") == "Sportswear"
Expand All @@ -57,6 +58,12 @@ def test_filter_query():
assert filter_query._limit == 7
assert filter_query._num_results == 10

# Test sort_by functionality
filter_query = FilterQuery(
filter_expression, return_fields, num_results=10, sort_by="price"
)
assert filter_query._sort_by == "price"


def test_vector_query():
# Create a vector query
Expand All @@ -73,6 +80,7 @@ def test_vector_query():
assert isinstance(vector_query.params, dict)
assert vector_query.params != {}
assert vector_query._dialect == 3
assert vector_query._sort_by == None

# Test set_filter functionality
new_filter_expression = Tag("category") == "Sportswear"
Expand All @@ -85,6 +93,17 @@ def test_vector_query():
assert vector_query._limit == 7
assert vector_query._num_results == 10

# Test sort_by functionality
vector_query = VectorQuery(
sample_vector,
"vector_field",
["field1", "field2"],
dialect=3,
num_results=10,
sort_by="field2",
)
assert vector_query._sort_by == "field2"


def test_range_query():
# Create a filter expression
Expand All @@ -104,6 +123,7 @@ def test_range_query():
assert isinstance(range_query.query, Query)
assert isinstance(range_query.params, dict)
assert range_query.params != {}
assert range_query._sort_by == None

# Test set_filter functionality
new_filter_expression = Tag("category") == "Outdoor"
Expand All @@ -115,3 +135,14 @@ def test_range_query():
assert range_query._first == 5
assert range_query._limit == 7
assert range_query._num_results == 10

# Test sort_by functionality
range_query = RangeQuery(
sample_vector,
"vector_field",
["field1"],
filter_expression,
num_results=10,
sort_by="field1",
)
assert range_query._sort_by == "field1"
Loading