diff --git a/redisvl/index/index.py b/redisvl/index/index.py index d3919f8d..4bf928ed 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1064,6 +1064,7 @@ def _hybrid_search(self, query: HybridQuery, **kwargs) -> List[Dict[str, Any]]: if query.postprocessing_config.build_args() else None ), + params_substitution=query.params, # type: ignore[arg-type] **kwargs, ) # type: ignore return [convert_bytes(r) for r in results.results] # type: ignore[union-attr] @@ -1938,6 +1939,7 @@ async def _hybrid_search( if query.postprocessing_config.build_args() else None ), + params_substitution=query.params, # type: ignore[arg-type] **kwargs, ) # type: ignore return [convert_bytes(r) for r in results.results] # type: ignore[union-attr] diff --git a/redisvl/query/hybrid.py b/redisvl/query/hybrid.py index cfcb0c93..edc19e24 100644 --- a/redisvl/query/hybrid.py +++ b/redisvl/query/hybrid.py @@ -49,6 +49,7 @@ def __init__( text_field_name: str, vector: Union[bytes, List[float]], vector_field_name: str, + vector_param_name: str = "vector", text_scorer: str = "BM25STD", yield_text_score_as: Optional[str] = None, vector_search_method: Optional[Literal["KNN", "RANGE"]] = None, @@ -76,6 +77,7 @@ def __init__( text_field_name: The text field name to search in. vector: The vector to perform vector similarity search. vector_field_name: The vector field name to search in. + vector_param_name: The name of the parameter substitution containing the vector blob. text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM, BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more information about supported scoring algorithms, @@ -146,9 +148,18 @@ def __init__( text, text_field_name, filter_expression ) + if isinstance(vector, bytes): + vector_data = vector + else: + vector_data = array_to_buffer(vector, dtype) + + self.params = { + vector_param_name: vector_data, + } + self.query = build_base_query( text_query=query_string, - vector=vector, + vector_param_name=vector_param_name, vector_field_name=vector_field_name, text_scorer=text_scorer, yield_text_score_as=yield_text_score_as, @@ -159,7 +170,6 @@ def __init__( range_epsilon=range_epsilon, yield_vsim_score_as=yield_vsim_score_as, filter_expression=filter_expression, - dtype=dtype, ) if combination_method: @@ -178,7 +188,7 @@ def __init__( def build_base_query( text_query: str, - vector: Union[bytes, List[float]], + vector_param_name: str, vector_field_name: str, text_scorer: str = "BM25STD", yield_text_score_as: Optional[str] = None, @@ -189,13 +199,12 @@ def build_base_query( range_epsilon: Optional[float] = None, yield_vsim_score_as: Optional[str] = None, filter_expression: Optional[Union[str, FilterExpression]] = None, - dtype: str = "float32", ): """Build a Redis HybridQuery for performing hybrid search. Args: text_query: The query for the text search. - vector: The vector to perform vector similarity search. + vector_param_name: The name of the parameter substitution containing the vector blob. vector_field_name: The vector field name to search in. text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM, BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more @@ -210,7 +219,6 @@ def build_base_query( accuracy of the search. yield_vsim_score_as: The name of the field to yield the vector similarity score as. filter_expression: The filter expression to use for the vector similarity search. Defaults to None. - dtype: The data type of the vector. Defaults to "float32". Notes: If RRF combination method is used, then at least one of `rrf_window` or `rrf_constant` must be provided. @@ -242,11 +250,6 @@ def build_base_query( yield_score_as=yield_text_score_as, ) - if isinstance(vector, bytes): - vector_data = vector - else: - vector_data = array_to_buffer(vector, dtype) - # Serialize vector similarity search method and params, if specified vsim_search_method: Optional[VectorSearchMethods] = None vsim_search_method_params: Dict[str, Any] = {} @@ -284,7 +287,7 @@ def build_base_query( # Serialize the vector similarity query vsim_query = HybridVsimQuery( vector_field_name="@" + vector_field_name, - vector_data=vector_data, + vector_data="$" + vector_param_name, vsim_search_method=vsim_search_method, vsim_search_method_params=vsim_search_method_params, filter=vsim_filter, diff --git a/tests/unit/test_hybrid_types.py b/tests/unit/test_hybrid_types.py index b78e8ab5..843966ed 100644 --- a/tests/unit/test_hybrid_types.py +++ b/tests/unit/test_hybrid_types.py @@ -49,6 +49,12 @@ def get_query_pieces(query: HybridQuery) -> List[str]: pieces.extend(query.combination_method.get_args()) if query.postprocessing_config.build_args(): pieces.extend(query.postprocessing_config.build_args()) + if query.params: + params = [ + "PARAMS", + len(query.params) * 2, + ] + [item for pair in query.params.items() for item in pair] + pieces.extend(params) return pieces @@ -76,10 +82,14 @@ def test_hybrid_query_basic_initialization(): "BM25STD", "VSIM", "@embedding", - bytes_vector, + "$vector", "LIMIT", "0", "10", + "PARAMS", + 2, + "vector", + bytes_vector, ] # Verify that no combination method is set @@ -126,7 +136,7 @@ def test_hybrid_query_with_all_parameters(): "text_score", "VSIM", "@embedding", - bytes_vector, + "$vector", "KNN", 4, "K", @@ -149,11 +159,15 @@ def test_hybrid_query_with_all_parameters(): "LIMIT", "0", "10", + "PARAMS", + 2, + "vector", + bytes_vector, ] # Add post-processing and verify that it is reflected in the query hybrid_query.postprocessing_config.limit(offset=10, num=20) - assert get_query_pieces(hybrid_query)[-3:] == ["LIMIT", "10", "20"] + assert hybrid_query.postprocessing_config.build_args() == ["LIMIT", "10", "20"] # Stopwords tests @@ -376,12 +390,16 @@ def test_hybrid_query_with_string_filter(): "BM25STD", "VSIM", "@embedding", - bytes_vector, + "$vector", "FILTER", "@category:{tech|science|engineering}", "LIMIT", "0", "10", + "PARAMS", + 2, + "vector", + bytes_vector, ] @@ -405,12 +423,16 @@ def test_hybrid_query_with_tag_filter(): "BM25STD", "VSIM", "@embedding", - bytes_vector, + "$vector", "FILTER", "@genre:{comedy}", "LIMIT", "0", "10", + "PARAMS", + 2, + "vector", + bytes_vector, ] @@ -645,10 +667,14 @@ def test_hybrid_query_special_characters_in_text(): "BM25STD", "VSIM", "@embedding", - bytes_vector, + "$vector", "LIMIT", "0", "10", + "PARAMS", + 2, + "vector", + bytes_vector, ] @@ -672,10 +698,14 @@ def test_hybrid_query_unicode_text(): "BM25STD", "VSIM", "@embedding", - bytes_vector, + "$vector", "LIMIT", "0", "10", + "PARAMS", + 2, + "vector", + bytes_vector, ]