Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions redis/commands/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .commands import (
AGGREGATE_CMD,
CONFIG_CMD,
HYBRID_CMD,
INFO_CMD,
PROFILE_CMD,
SEARCH_CMD,
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self, client, index_name="idx"):
self._RESP2_MODULE_CALLBACKS = {
INFO_CMD: self._parse_info,
SEARCH_CMD: self._parse_search,
HYBRID_CMD: self._parse_hybrid_search,
AGGREGATE_CMD: self._parse_aggregate,
PROFILE_CMD: self._parse_profile,
SPELLCHECK_CMD: self._parse_spellcheck,
Expand Down
147 changes: 144 additions & 3 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
import itertools
import time
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from redis._parsers.helpers import pairs_to_dict
from redis.client import NEVER_DECODE, Pipeline
from redis.commands.search.hybrid_query import (
CombineResultsMethod,
HybridCursorQuery,
HybridPostProcessingConfig,
HybridQuery,
)
from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult
from redis.utils import deprecated_function

from ..helpers import get_protocol_version
from ._util import to_string
from .aggregation import AggregateRequest, AggregateResult, Cursor
from .aggregation import (
AggregateRequest,
AggregateResult,
Cursor,
)
from .document import Document
from .field import Field
from .index_definition import IndexDefinition
Expand Down Expand Up @@ -47,6 +59,7 @@
SUGGET_COMMAND = "FT.SUGGET"
SYNUPDATE_CMD = "FT.SYNUPDATE"
SYNDUMP_CMD = "FT.SYNDUMP"
HYBRID_CMD = "FT.HYBRID"

NOOFFSETS = "NOOFFSETS"
NOFIELDS = "NOFIELDS"
Expand Down Expand Up @@ -84,6 +97,28 @@ def _parse_search(self, res, **kwargs):
field_encodings=kwargs["query"]._return_fields_decode_as,
)

def _parse_hybrid_search(self, res, **kwargs):
res_dict = pairs_to_dict(res, decode_keys=True)
if "cursor" in kwargs:
return HybridCursorResult(
search_cursor_id=int(res_dict["SEARCH"]),
vsim_cursor_id=int(res_dict["VSIM"]),
)

results: List[Dict[str, Any]] = []
# the original results are a list of lists
# we convert them to a list of dicts
for res_item in res_dict["results"]:
item_dict = pairs_to_dict(res_item, decode_keys=True)
results.append(item_dict)

return HybridResult(
total_results=int(res_dict["total_results"]),
results=results,
warnings=res_dict["warnings"],
execution_time=float(res_dict["execution_time"]),
)

def _parse_aggregate(self, res, **kwargs):
return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"])

Expand Down Expand Up @@ -470,7 +505,7 @@ def get_params_args(
return []
args = []
if len(query_params) > 0:
args.append("params")
args.append("PARAMS")
args.append(len(query_params) * 2)
for key, value in query_params.items():
args.append(key)
Expand Down Expand Up @@ -525,6 +560,59 @@ def search(
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
)

def hybrid_search(
self,
query: HybridQuery,
combine_method: Optional[CombineResultsMethod] = None,
post_processing: Optional[HybridPostProcessingConfig] = None,
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
timeout: Optional[int] = None,
cursor: Optional[HybridCursorQuery] = None,
) -> Union[HybridResult, HybridCursorResult, Pipeline]:
"""
Execute a hybrid search using both text and vector queries

Args:
- **query**: HybridQuery object
Contains the text and vector queries
- **combine_method**: CombineResultsMethod object
Contains the combine method and parameters
- **post_processing**: HybridPostProcessingConfig object
Contains the post processing configuration
- **params_substitution**: Dict[str, Union[str, int, float, bytes]]
Contains the parameters substitution
- **timeout**: int - contains the timeout in milliseconds
- **cursor**: HybridCursorQuery object - contains the cursor configuration


For more information see `FT.SEARCH <https://redis.io/commands/ft.hybrid>`.
"""
index = self.index_name
options = {}
pieces = [HYBRID_CMD, index]
pieces.extend(query.get_args())
if combine_method:
pieces.extend(combine_method.get_args())
if post_processing:
pieces.extend(post_processing.build_args())
if params_substitution:
pieces.extend(self.get_params_args(params_substitution))
if timeout:
pieces.extend(("TIMEOUT", timeout))
if cursor:
options["cursor"] = True
pieces.extend(cursor.build_args())

if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True

res = self.execute_command(*pieces, **options)

if isinstance(res, Pipeline):
return res

return self._parse_results(HYBRID_CMD, res, **options)

def explain(
self,
query: Union[str, Query],
Expand Down Expand Up @@ -965,6 +1053,59 @@ async def search(
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
)

async def hybrid_search(
self,
query: HybridQuery,
combine_method: Optional[CombineResultsMethod] = None,
post_processing: Optional[HybridPostProcessingConfig] = None,
params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
timeout: Optional[int] = None,
cursor: Optional[HybridCursorQuery] = None,
) -> Union[HybridResult, HybridCursorResult, Pipeline]:
"""
Execute a hybrid search using both text and vector queries

Args:
- **query**: HybridQuery object
Contains the text and vector queries
- **combine_method**: CombineResultsMethod object
Contains the combine method and parameters
- **post_processing**: HybridPostProcessingConfig object
Contains the post processing configuration
- **params_substitution**: Dict[str, Union[str, int, float, bytes]]
Contains the parameters substitution
- **timeout**: int - contains the timeout in milliseconds
- **cursor**: HybridCursorQuery object - contains the cursor configuration


For more information see `FT.SEARCH <https://redis.io/commands/ft.hybrid>`.
"""
index = self.index_name
options = {}
pieces = [HYBRID_CMD, index]
pieces.extend(query.get_args())
if combine_method:
pieces.extend(combine_method.get_args())
if post_processing:
pieces.extend(post_processing.build_args())
if params_substitution:
pieces.extend(self.get_params_args(params_substitution))
if timeout:
pieces.extend(("TIMEOUT", timeout))
if cursor:
options["cursor"] = True
pieces.extend(cursor.build_args())

if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True

res = await self.execute_command(*pieces, **options)

if isinstance(res, Pipeline):
return res

return self._parse_results(HYBRID_CMD, res, **options)

async def aggregate(
self,
query: Union[AggregateResult, Cursor],
Expand Down
Loading