Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ def _hybrid_search(self, query: HybridQuery, **kwargs) -> List[Dict[str, Any]]:
if query.postprocessing_config.build_args()
else None
),
params_substitution=query.params, # type: ignore[arg-type]
**kwargs,
) # type: ignore
return [convert_bytes(r) for r in results.results] # type: ignore[union-attr]
Expand Down Expand Up @@ -1938,6 +1939,7 @@ async def _hybrid_search(
if query.postprocessing_config.build_args()
else None
),
params_substitution=query.params, # type: ignore[arg-type]
**kwargs,
) # type: ignore
return [convert_bytes(r) for r in results.results] # type: ignore[union-attr]
Expand Down
27 changes: 15 additions & 12 deletions redisvl/query/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
text_field_name: str,
vector: Union[bytes, List[float]],
vector_field_name: str,
vector_param_name: str = "vector",
text_scorer: str = "BM25STD",
yield_text_score_as: Optional[str] = None,
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
text_field_name: The text field name to search in.
vector: The vector to perform vector similarity search.
vector_field_name: The vector field name to search in.
vector_param_name: The name of the parameter substitution containing the vector blob.
text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more
information about supported scoring algorithms,
Expand Down Expand Up @@ -146,9 +148,18 @@ def __init__(
text, text_field_name, filter_expression
)

if isinstance(vector, bytes):
vector_data = vector
else:
vector_data = array_to_buffer(vector, dtype)

self.params = {
vector_param_name: vector_data,
}

self.query = build_base_query(
text_query=query_string,
vector=vector,
vector_param_name=vector_param_name,
vector_field_name=vector_field_name,
text_scorer=text_scorer,
yield_text_score_as=yield_text_score_as,
Expand All @@ -159,7 +170,6 @@ def __init__(
range_epsilon=range_epsilon,
yield_vsim_score_as=yield_vsim_score_as,
filter_expression=filter_expression,
dtype=dtype,
)

if combination_method:
Expand All @@ -178,7 +188,7 @@ def __init__(

def build_base_query(
text_query: str,
vector: Union[bytes, List[float]],
vector_param_name: str,
vector_field_name: str,
text_scorer: str = "BM25STD",
yield_text_score_as: Optional[str] = None,
Expand All @@ -189,13 +199,12 @@ def build_base_query(
range_epsilon: Optional[float] = None,
yield_vsim_score_as: Optional[str] = None,
filter_expression: Optional[Union[str, FilterExpression]] = None,
dtype: str = "float32",
):
"""Build a Redis HybridQuery for performing hybrid search.

Args:
text_query: The query for the text search.
vector: The vector to perform vector similarity search.
vector_param_name: The name of the parameter substitution containing the vector blob.
vector_field_name: The vector field name to search in.
text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more
Expand All @@ -210,7 +219,6 @@ def build_base_query(
accuracy of the search.
yield_vsim_score_as: The name of the field to yield the vector similarity score as.
filter_expression: The filter expression to use for the vector similarity search. Defaults to None.
dtype: The data type of the vector. Defaults to "float32".

Notes:
If RRF combination method is used, then at least one of `rrf_window` or `rrf_constant` must be provided.
Expand Down Expand Up @@ -242,11 +250,6 @@ def build_base_query(
yield_score_as=yield_text_score_as,
)

if isinstance(vector, bytes):
vector_data = vector
else:
vector_data = array_to_buffer(vector, dtype)

# Serialize vector similarity search method and params, if specified
vsim_search_method: Optional[VectorSearchMethods] = None
vsim_search_method_params: Dict[str, Any] = {}
Expand Down Expand Up @@ -284,7 +287,7 @@ def build_base_query(
# Serialize the vector similarity query
vsim_query = HybridVsimQuery(
vector_field_name="@" + vector_field_name,
vector_data=vector_data,
vector_data="$" + vector_param_name,
vsim_search_method=vsim_search_method,
vsim_search_method_params=vsim_search_method_params,
filter=vsim_filter,
Expand Down
44 changes: 37 additions & 7 deletions tests/unit/test_hybrid_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def get_query_pieces(query: HybridQuery) -> List[str]:
pieces.extend(query.combination_method.get_args())
if query.postprocessing_config.build_args():
pieces.extend(query.postprocessing_config.build_args())
if query.params:
params = [
"PARAMS",
len(query.params) * 2,
] + [item for pair in query.params.items() for item in pair]
pieces.extend(params)
return pieces


Expand Down Expand Up @@ -76,10 +82,14 @@ def test_hybrid_query_basic_initialization():
"BM25STD",
"VSIM",
"@embedding",
bytes_vector,
"$vector",
"LIMIT",
"0",
"10",
"PARAMS",
2,
"vector",
bytes_vector,
]

# Verify that no combination method is set
Expand Down Expand Up @@ -126,7 +136,7 @@ def test_hybrid_query_with_all_parameters():
"text_score",
"VSIM",
"@embedding",
bytes_vector,
"$vector",
"KNN",
4,
"K",
Expand All @@ -149,11 +159,15 @@ def test_hybrid_query_with_all_parameters():
"LIMIT",
"0",
"10",
"PARAMS",
2,
"vector",
bytes_vector,
]

# Add post-processing and verify that it is reflected in the query
hybrid_query.postprocessing_config.limit(offset=10, num=20)
assert get_query_pieces(hybrid_query)[-3:] == ["LIMIT", "10", "20"]
assert hybrid_query.postprocessing_config.build_args() == ["LIMIT", "10", "20"]


# Stopwords tests
Expand Down Expand Up @@ -376,12 +390,16 @@ def test_hybrid_query_with_string_filter():
"BM25STD",
"VSIM",
"@embedding",
bytes_vector,
"$vector",
"FILTER",
"@category:{tech|science|engineering}",
"LIMIT",
"0",
"10",
"PARAMS",
2,
"vector",
bytes_vector,
]


Expand All @@ -405,12 +423,16 @@ def test_hybrid_query_with_tag_filter():
"BM25STD",
"VSIM",
"@embedding",
bytes_vector,
"$vector",
"FILTER",
"@genre:{comedy}",
"LIMIT",
"0",
"10",
"PARAMS",
2,
"vector",
bytes_vector,
]


Expand Down Expand Up @@ -645,10 +667,14 @@ def test_hybrid_query_special_characters_in_text():
"BM25STD",
"VSIM",
"@embedding",
bytes_vector,
"$vector",
"LIMIT",
"0",
"10",
"PARAMS",
2,
"vector",
bytes_vector,
]


Expand All @@ -672,10 +698,14 @@ def test_hybrid_query_unicode_text():
"BM25STD",
"VSIM",
"@embedding",
bytes_vector,
"$vector",
"LIMIT",
"0",
"10",
"PARAMS",
2,
"vector",
bytes_vector,
]


Expand Down