From 6f674fcf47b87ebd01c3fc199bc28e95771f4d3b Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 8 Oct 2025 13:03:47 +0300 Subject: [PATCH 1/7] Adding support for hybrid search. --- redis/commands/search/__init__.py | 2 + redis/commands/search/commands.py | 136 +- redis/commands/search/hybrid_query.py | 331 + redis/commands/search/hybrid_result.py | 27 + tests/test_asyncio/test_search.py | 4871 +++++++++------ tests/test_search.py | 7953 ++++++++++++++---------- 6 files changed, 8095 insertions(+), 5225 deletions(-) create mode 100644 redis/commands/search/hybrid_query.py create mode 100644 redis/commands/search/hybrid_result.py diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 9fda9ee4d6..ff14077021 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -4,6 +4,7 @@ from .commands import ( AGGREGATE_CMD, CONFIG_CMD, + HYBRID_CMD, INFO_CMD, PROFILE_CMD, SEARCH_CMD, @@ -102,6 +103,7 @@ def __init__(self, client, index_name="idx"): self._RESP2_MODULE_CALLBACKS = { INFO_CMD: self._parse_info, SEARCH_CMD: self._parse_search, + HYBRID_CMD: self._parse_hybrid_search, AGGREGATE_CMD: self._parse_aggregate, PROFILE_CMD: self._parse_profile, SPELLCHECK_CMD: self._parse_spellcheck, diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index afe5d1c684..ae50a303ff 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,13 +1,24 @@ import itertools import time -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union +from redis._parsers.helpers import pairs_to_dict from redis.client import NEVER_DECODE, Pipeline +from redis.commands.search.hybrid_query import ( + HybridCursorQuery, + HybridPostProcessingConfig, + HybridQuery, +) +from redis.commands.search.hybrid_result import HybridCursorResult, HybridResult from redis.utils import deprecated_function from ..helpers import get_protocol_version from ._util import to_string -from .aggregation import AggregateRequest, AggregateResult, Cursor +from .aggregation import ( + AggregateRequest, + AggregateResult, + Cursor, +) from .document import Document from .field import Field from .index_definition import IndexDefinition @@ -47,6 +58,7 @@ SUGGET_COMMAND = "FT.SUGGET" SYNUPDATE_CMD = "FT.SYNUPDATE" SYNDUMP_CMD = "FT.SYNDUMP" +HYBRID_CMD = "FT.HYBRID" NOOFFSETS = "NOOFFSETS" NOFIELDS = "NOFIELDS" @@ -84,6 +96,28 @@ def _parse_search(self, res, **kwargs): field_encodings=kwargs["query"]._return_fields_decode_as, ) + def _parse_hybrid_search(self, res, **kwargs): + res_dict = pairs_to_dict(res, decode_keys=True) + if "cursor" in kwargs: + return HybridCursorResult( + search_cursor_id=int(res_dict["SEARCH"]), + vsim_cursor_id=int(res_dict["VSIM"]), + ) + + results: List[Dict[str, Any]] = [] + # the original results are a list of lists + # we convert them to a list of dicts + for res_item in res_dict["results"]: + item_dict = pairs_to_dict(res_item, decode_keys=True) + results.append(item_dict) + + return HybridResult( + total_results=int(res_dict["total_results"]), + results=results, + warnings=res_dict["warnings"], + execution_time=float(res_dict["execution_time"]), + ) + def _parse_aggregate(self, res, **kwargs): return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) @@ -470,7 +504,7 @@ def get_params_args( return [] args = [] if len(query_params) > 0: - args.append("params") + args.append("PARAMS") args.append(len(query_params) * 2) for key, value 
in query_params.items(): args.append(key) @@ -525,6 +559,54 @@ def search( SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) + def hybrid_search( + self, + query: HybridQuery, + post_processing: Optional[HybridPostProcessingConfig] = None, + params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, + timeout: Optional[int] = None, + cursor: Optional[HybridCursorQuery] = None, + ) -> Union[HybridResult, HybridCursorResult, Pipeline]: + """ + Execute a hybrid search using both text and vector queries + + Args: + - **query**: HybridQuery object + Contains the text and vector queries + - **post_processing**: HybridPostProcessingConfig object + Contains the post processing configuration + - **params_substitution**: Dict[str, Union[str, int, float, bytes]] + Contains the parameters substitution + - **timeout**: int - contains the timeout in milliseconds + - **cursor**: HybridCursorQuery object - contains the cursor configuration + + + For more information see `FT.SEARCH `. + """ + index = self.index_name + options = {} + pieces = [HYBRID_CMD, index] + pieces.extend(query.get_args()) + if post_processing: + pieces.extend(post_processing.build_args()) + if params_substitution: + pieces.extend(self.get_params_args(params_substitution)) + if timeout: + pieces.extend(("TIMEOUT", timeout)) + if cursor: + options["cursor"] = True + pieces.extend(cursor.build_args()) + + if get_protocol_version(self.client) not in ["3", 3]: + options[NEVER_DECODE] = True + + res = self.execute_command(*pieces, **options) + + if isinstance(res, Pipeline): + return res + + return self._parse_results(HYBRID_CMD, res, **options) + def explain( self, query: Union[str, Query], @@ -965,6 +1047,54 @@ async def search( SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) + async def hybrid_search( + self, + query: HybridQuery, + post_processing: Optional[HybridPostProcessingConfig] = None, + params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, + timeout: Optional[int] = None, + cursor: Optional[HybridCursorQuery] = None, + ) -> Union[HybridResult, HybridCursorResult, Pipeline]: + """ + Execute a hybrid search using both text and vector queries + + Args: + - **query**: HybridQuery object + Contains the text and vector queries + - **post_processing**: HybridPostProcessingConfig object + Contains the post processing configuration + - **params_substitution**: Dict[str, Union[str, int, float, bytes]] + Contains the parameters substitution + - **timeout**: int - contains the timeout in milliseconds + - **cursor**: HybridCursorQuery object - contains the cursor configuration + + + For more information see `FT.SEARCH `. 
+ """ + index = self.index_name + options = {} + pieces = [HYBRID_CMD, index] + pieces.extend(query.get_args()) + if post_processing: + pieces.extend(post_processing.build_args()) + if params_substitution: + pieces.extend(self.get_params_args(params_substitution)) + if timeout: + pieces.extend(("TIMEOUT", timeout)) + if cursor: + options["cursor"] = True + pieces.extend(cursor.build_args()) + + if get_protocol_version(self.client) not in ["3", 3]: + options[NEVER_DECODE] = True + + res = await self.execute_command(*pieces, **options) + + if isinstance(res, Pipeline): + return res + + return self._parse_results(HYBRID_CMD, res, **options) + async def aggregate( self, query: Union[AggregateResult, Cursor], diff --git a/redis/commands/search/hybrid_query.py b/redis/commands/search/hybrid_query.py new file mode 100644 index 0000000000..ddcb2593e9 --- /dev/null +++ b/redis/commands/search/hybrid_query.py @@ -0,0 +1,331 @@ +from typing import Any, Dict, List, Literal, Optional, Union + +try: + from typing import Self # Py 3.11+ +except ImportError: + from typing_extensions import Self + +from redis.commands.search.aggregation import Limit, Reducer +from redis.commands.search.query import Filter, SortbyField + + +class HybridSearchQuery: + def __init__( + self, + query_string: str, + scorer: Optional[str] = None, + yield_score_as: Optional[ + str + ] = None, ## TODO check if this will be supported or it should be removed! + ) -> None: + """ + Create a new hybrid search query object. + + Args: + query_string: The query string. + scorer: The scorer to use. Allowed values are "TFIDF" or "BM25". + yield_score_as: The name of the field to yield the score as. + """ + self._query_string = query_string + self._scorer = scorer + self._yield_score_as = yield_score_as + + def query_string(self) -> str: + """Return the query string of this query object.""" + return self._query_string + + def scorer(self, scorer: str) -> "HybridSearchQuery": + """ + Scoring algorithm for text search query. + Allowed values are "TFIDF" or "BM25" + """ + self._scorer = scorer + return self + + def get_args(self) -> List[str]: + args = ["SEARCH", self._query_string] + if self._scorer: + args.extend(("SCORER", self._scorer)) + if ( + self._yield_score_as + ): # TODO check if this will be supported or it should be removed! + args.extend(("YIELD_SCORE_AS", self._yield_score_as)) + return args + + +class HybridVsimQuery: + def __init__( + self, + vector_field_name: str, + vector_data: Union[bytes, str], + vsim_search_method: Optional[str] = None, + vsim_search_method_params: Optional[Dict[str, Any]] = None, + filter: Optional["Filter"] = None, + ) -> None: + """ + Create a new hybrid vsim query object. + + Args: + vector_field_name: Vector field name. + vector_data: Vector data for the search. + vsim_search_method: Search method that will be used for the vsim search. + Allowed values are "KNN" or "RANGE". + vsim_search_method_params: Search method parameters. Use the param names + for keys and the values for the values. Example: {"K": 10, "EF_RUNTIME": 100}. + filter: If defined, a filter will be applied on the vsim query results. 
+ """ + self._vector_field = vector_field_name + self._vector_data = vector_data + if vsim_search_method and vsim_search_method_params: + self.vsim_method_params(vsim_search_method, **vsim_search_method_params) + else: + self._vsim_method_params = None + self._filter = filter + + def vector_field(self) -> str: + """Return the vector field name of this query object.""" + return self._vector_field + + def vector_data(self) -> Union[bytes, str]: + """Return the vector data of this query object.""" + return self._vector_data + + def vsim_method_params( + self, + method: str, + **kwargs, + ) -> "HybridVsimQuery": + """ + Add search method parameters to the query. + + Args: + method: Vector search method name. Supported values are "KNN" or "RANGE". + kwargs: Search method parameters. Use the param names for keys and the + values for the values. Example: {"K": 10, "EF_RUNTIME": 100}. + """ + vsim_method_params: List[Union[str, int]] = [method] + if kwargs: + vsim_method_params.append(len(kwargs.items()) * 2) + for key, value in kwargs.items(): + vsim_method_params.extend((key, value)) + self._vsim_method_params = vsim_method_params + print(self._vsim_method_params) + return self + + def filter(self, flt: "HybridFilter") -> "HybridVsimQuery": + """ + Add a filter to the query. + + Args: + flt: A HybridFilter object, used on a corresponding field. + """ + self._filter = flt + return self + + def get_args(self) -> List[str]: + args = ["VSIM", self._vector_field, self._vector_data] + if self._vsim_method_params: + args.extend(self._vsim_method_params) + if self._filter: + args.extend(self._filter.args) + + return args + + +class HybridQuery: + def __init__( + self, + search_query: HybridSearchQuery, + vector_similarity_query: HybridVsimQuery, + ) -> None: + """ + Create a new hybrid query object. + + Args: + search_query: HybridSearchQuery object containing the text query. + vector_similarity_query: HybridVsimQuery object containing the vector similarity query. + """ + self._search_query = search_query + self._vector_similarity_query = vector_similarity_query + + def get_args(self) -> List[str]: + args = [] + args.extend(self._search_query.get_args()) + args.extend(self._vector_similarity_query.get_args()) + return args + + +class HybridPostProcessingConfig: + def __init__(self) -> None: + """ + Create a new hybrid post processing configuration object. + """ + self._combine = [] + self._load_fields = [] + self._groupby = [] + self._apply = [] + self._sortby_fields = [] + self._filter = None + self._limit = None + + def combine( + self, + method: Literal["RRF", "LINEAR"], + yield_score_as: Optional[ + str + ] = None, # TODO check if this will be supported or it should be removed! + **kwargs, + ) -> Self: + """ + Add combine parameters to the query. + + Args: + method: The combine method to use - RRF or LINEAR. + yield_score_as: Optional field name to yield the score as. + kwargs: Additional combine parameters. + """ + self._combine: List[Union[str, int]] = [method] + + self._combine.append(len(kwargs) * 2) + + for key, value in kwargs.items(): + self._combine.extend([key, value]) + + if ( + yield_score_as + ): # TODO check if this will be supported or it should be removed! + self._combine.extend(["YIELD_SCORE_AS", yield_score_as]) + return self + + def load(self, *fields: str) -> Self: + """ + Add load parameters to the query. + """ + self._load_fields = fields + return self + + def group_by(self, fields: List[str], *reducers: Reducer) -> Self: + """ + Specify by which fields to group the aggregation. 
+ + Args: + fields: Fields to group by. This can either be a single string or a list + of strings. In both cases, the field should be specified as `@field`. + reducers: One or more reducers. Reducers may be found in the + `aggregation` module. + """ + + fields = [fields] if isinstance(fields, str) else fields + + ret = ["GROUPBY", str(len(fields)), *fields] + for reducer in reducers: + ret += ["REDUCE", reducer.NAME, str(len(reducer.args))] + ret.extend(reducer.args) + if reducer._alias is not None: + ret += ["AS", reducer._alias] + + self._groupby.extend(ret) + return self + + def apply(self, **kwexpr) -> Self: + """ + Specify one or more projection expressions to add to each result. + + Args: + kwexpr: One or more key-value pairs for a projection. The key is + the alias for the projection, and the value is the projection + expression itself, for example `apply(square_root="sqrt(@foo)")`. + """ + for alias, expr in kwexpr.items(): + ret = ["APPLY", expr] + if alias is not None: + ret += ["AS", alias] + self._apply.extend(ret) + + return self + + def sort_by(self, *sortby: "SortbyField") -> Self: + """ + Add sortby parameters to the query. + """ + self._sortby_fields = [*sortby] + return self + + def filter(self, filter: "HybridFilter") -> Self: + """ + Add a numeric or string filter to the query. + + Currently, only one of each filter is supported by the engine. + + Args: + filter: A NumericFilter or GeoFilter object, used on a corresponding field. + """ + self._filter = filter + return self + + def limit(self, offset: int, num: int) -> Self: + """ + Add limit parameters to the query. + """ + self._limit = Limit(offset, num) + return self + + def build_args(self) -> List[str]: + args = [] + if self._combine: + args.extend(("COMBINE", *self._combine)) + if self._load_fields: + fields_str = " ".join(self._load_fields) + fields = fields_str.split(" ") + args.extend(("LOAD", len(fields), *fields)) + if self._groupby: + args.extend(self._groupby) + if self._apply: + args.extend(self._apply) + if self._sortby_fields: + sortby_args = [] + for f in self._sortby_fields: + sortby_args.extend(f.args) + args.extend(("SORTBY", len(sortby_args), *sortby_args)) + if self._filter: + args.extend(self._filter.args) + if self._limit: + args.extend(self._limit.build_args()) + + return args + + +class HybridFilter(Filter): + def __init__( + self, + conditions: str, + ) -> None: + """ + Create a new hybrid filter object. + + Args: + conditions: Filter conditions. + """ + args = [conditions] + Filter.__init__(self, "FILTER", *args) + + +class HybridCursorQuery: + def __init__(self, count: int = 0, max_idle: int = 0) -> None: + """ + Create a new hybrid cursor query object. + + Args: + count: Number of results to return per cursor iteration. + max_idle: Maximum idle time for the cursor. + """ + self.count = count + self.max_idle = max_idle + + def build_args(self): + args = ["WITHCURSOR"] + if self.count: + args += ["COUNT", str(self.count)] + if self.max_idle: + args += ["MAXIDLE", str(self.max_idle)] + return args diff --git a/redis/commands/search/hybrid_result.py b/redis/commands/search/hybrid_result.py new file mode 100644 index 0000000000..1154fc040d --- /dev/null +++ b/redis/commands/search/hybrid_result.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Union + + +@dataclass +class HybridResult: + """ + Represents the result of a hybrid search query execution + Returned by the `hybrid_search` command, when using RESP version 2. 
+ """ + + total_results: int + results: List[Dict[str, Any]] + warnings: List[Union[str, bytes]] + execution_time: float + + +class HybridCursorResult: + def __init__(self, search_cursor_id: int, vsim_cursor_id: int) -> None: + """ + Represents the result of a hybrid search query execution with cursor + + search_cursor_id: int - cursor id for the search query + vsim_cursor_id: int - cursor id for the vector similarity query + """ + self.search_cursor_id = search_cursor_id + self.vsim_cursor_id = vsim_cursor_id diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 7f11710cc4..45c9d1207b 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -3,6 +3,7 @@ import os import asyncio from io import TextIOWrapper +import random import numpy as np import pytest @@ -10,6 +11,15 @@ from redis import ResponseError import redis.asyncio as redis import redis.commands.search.aggregation as aggregations +from redis.commands.search.hybrid_query import ( + HybridCursorQuery, + HybridFilter, + HybridPostProcessingConfig, + HybridQuery, + HybridSearchQuery, + HybridVsimQuery, +) +from redis.commands.search.hybrid_result import HybridCursorResult import redis.commands.search.reducers as reducers from redis.commands.search import AsyncSearch from redis.commands.search.field import ( @@ -20,9 +30,10 @@ VectorField, ) from redis.commands.search.index_definition import IndexDefinition, IndexType -from redis.commands.search.query import GeoFilter, NumericFilter, Query +from redis.commands.search.query import GeoFilter, NumericFilter, Query, SortbyField from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion +from redis.utils import safe_str from tests.conftest import ( is_resp2_connection, skip_if_redis_enterprise, @@ -41,1972 +52,3166 @@ ) -@pytest_asyncio.fixture() -async def decoded_r(create_redis, stack_url): - return await create_redis(decode_responses=True, url=stack_url) - +class AsyncSearchTestsBase: + @pytest_asyncio.fixture() + async def decoded_r(self, create_redis, stack_url): + return await create_redis(decode_responses=True, url=stack_url) -async def waitForIndex(env, idx, timeout=None): - delay = 0.1 - while True: - try: - res = await env.execute_command("FT.INFO", idx) - if int(res[res.index("indexing") + 1]) == 0: - break - except ValueError: - break - except AttributeError: + @staticmethod + async def waitForIndex(env, idx, timeout=None): + delay = 0.1 + while True: try: - if int(res["indexing"]) == 0: + res = await env.execute_command("FT.INFO", idx) + if int(res[res.index("indexing") + 1]) == 0: break except ValueError: break - except ResponseError: - # index doesn't exist yet - # continue to sleep and try again - pass - - await asyncio.sleep(delay) - if timeout is not None: - timeout -= delay - if timeout <= 0: - break - - -def getClient(decoded_r: redis.Redis): - """ - Gets a client client attached to an index name which is ready to be - created - """ - return decoded_r - - -async def createIndex(decoded_r, num_docs=100, definition=None): - try: - await decoded_r.create_index( - (TextField("play", weight=5.0), TextField("txt"), NumericField("chapter")), - definition=definition, - ) - except redis.ResponseError: - await decoded_r.dropindex(delete_documents=True) - return createIndex(decoded_r, num_docs=num_docs, definition=definition) - - chapters = {} - bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") - - r = csv.reader(bzfp, delimiter=";") - for n, line in 
enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] - - key = f"{play}:{chapter}".lower() - d = chapters.setdefault(key, {}) - d["play"] = play - d["txt"] = d.get("txt", "") + " " + text - d["chapter"] = int(chapter or 0) - if len(chapters) == num_docs: - break - - indexer = decoded_r.batch_indexer(chunk_size=50) - assert isinstance(indexer, AsyncSearch.BatchIndexer) - assert 50 == indexer.chunk_size - - for key, doc in chapters.items(): - await indexer.client.client.hset(key, mapping=doc) - await indexer.commit() - - -@pytest.mark.redismod -async def test_client(decoded_r: redis.Redis): - num_docs = 500 - await createIndex(decoded_r.ft(), num_docs=num_docs) - await waitForIndex(decoded_r, "idx") - # verify info - info = await decoded_r.ft().info() - for k in [ - "index_name", - "index_options", - "attributes", - "num_docs", - "max_doc_id", - "num_terms", - "num_records", - "inverted_sz_mb", - "offset_vectors_sz_mb", - "doc_table_size_mb", - "key_table_size_mb", - "records_per_doc_avg", - "bytes_per_record_avg", - "offsets_per_term_avg", - "offset_bits_per_record_avg", - ]: - assert k in info - - assert decoded_r.ft().index_name == info["index_name"] - assert num_docs == int(info["num_docs"]) - - res = await decoded_r.ft().search("henry iv") - if is_resp2_connection(decoded_r): - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" - assert len(doc.txt) > 0 - - # test no content - res = await decoded_r.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = (await decoded_r.ft().search(Query("kings").no_content())).total - vtotal = ( - await decoded_r.ft().search(Query("kings").no_content().verbatim()) - ).total - assert total > vtotal - - # test in fields - txt_total = ( - await decoded_r.ft().search(Query("henry").no_content().limit_fields("txt")) - ).total - play_total = ( - await decoded_r.ft().search( - Query("henry").no_content().limit_fields("play") - ) - ).total - both_total = ( - await decoded_r.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ) - ).total - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = await decoded_r.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in (await decoded_r.ft().search(Query("henry"))).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == (await decoded_r.ft().search(Query("henry king"))).total - assert ( - 3 - == ( - await decoded_r.ft().search(Query("henry king").slop(0).in_order()) - ).total - ) - assert ( - 52 - == ( - await decoded_r.ft().search(Query("king henry").slop(0).in_order()) - ).total - ) - assert 53 == (await decoded_r.ft().search(Query("henry king").slop(0))).total - assert 167 == (await decoded_r.ft().search(Query("henry king").slop(100))).total - - # test delete document - await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - 
res = await decoded_r.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") - res = await decoded_r.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") - - await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await decoded_r.ft().search(Query("death of a salesman")) - assert 1 == res.total - await decoded_r.ft().delete_document("doc-5ghs2") - else: - assert isinstance(res, dict) - assert 225 == res["total_results"] - assert 10 == len(res["results"]) - - for doc in res["results"]: - assert doc["id"] - assert doc["extra_attributes"]["play"] == "Henry IV" - assert len(doc["extra_attributes"]["txt"]) > 0 - - # test no content - res = await decoded_r.ft().search(Query("king").no_content()) - assert 194 == res["total_results"] - assert 10 == len(res["results"]) - for doc in res["results"]: - assert "extra_attributes" not in doc.keys() - - # test verbatim vs no verbatim - total = (await decoded_r.ft().search(Query("kings").no_content()))[ - "total_results" - ] - vtotal = (await decoded_r.ft().search(Query("kings").no_content().verbatim()))[ - "total_results" - ] - assert total > vtotal - - # test in fields - txt_total = ( - await decoded_r.ft().search(Query("henry").no_content().limit_fields("txt")) - )["total_results"] - play_total = ( - await decoded_r.ft().search( - Query("henry").no_content().limit_fields("play") - ) - )["total_results"] - both_total = ( - await decoded_r.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ) - )["total_results"] - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = await decoded_r.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [ - x["id"] for x in (await decoded_r.ft().search(Query("henry")))["results"] - ] - assert 10 == len(ids) - subset = ids[:5] - docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs["total_results"] - ids = [x["id"] for x in docs["results"]] - assert set(ids) == set(subset) - - # test slop and in order - assert ( - 193 == (await decoded_r.ft().search(Query("henry king")))["total_results"] - ) - assert ( - 3 - == (await decoded_r.ft().search(Query("henry king").slop(0).in_order()))[ - "total_results" - ] - ) - assert ( - 52 - == (await decoded_r.ft().search(Query("king henry").slop(0).in_order()))[ - "total_results" - ] - ) - assert ( - 53 - == (await decoded_r.ft().search(Query("henry king").slop(0)))[ - "total_results" - ] - ) - assert ( - 167 - == (await decoded_r.ft().search(Query("henry king").slop(100)))[ - "total_results" - ] - ) - - # test delete document - await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await decoded_r.ft().search(Query("death of a salesman")) - assert 1 == res["total_results"] - - assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") - res = await decoded_r.ft().search(Query("death of a salesman")) - assert 0 == res["total_results"] - assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") - - await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await decoded_r.ft().search(Query("death of a salesman")) - assert 1 == res["total_results"] - await 
decoded_r.ft().delete_document("doc-5ghs2") - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_gte("7.9.0") -async def test_scores(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("txt"),)) - - await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) - await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) - - q = Query("foo ~bar").with_scores() - res = await decoded_r.ft().search(q) - if is_resp2_connection(decoded_r): - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - else: - assert 2 == res["total_results"] - assert "doc2" == res["results"][0]["id"] - assert 3.0 == res["results"][0]["score"] - assert "doc1" == res["results"][1]["id"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_lt("7.9.0") -async def test_scores_with_new_default_scorer(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("txt"),)) - - await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) - await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) - - q = Query("foo ~bar").with_scores() - res = await decoded_r.ft().search(q) - if is_resp2_connection(decoded_r): - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 0.87 == pytest.approx(res.docs[0].score, 0.01) - assert "doc1" == res.docs[1].id - else: - assert 2 == res["total_results"] - assert "doc2" == res["results"][0]["id"] - assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01) - assert "doc1" == res["results"][1]["id"] - - -@pytest.mark.redismod -async def test_stopwords(decoded_r: redis.Redis): - stopwords = ["foo", "bar", "baz"] - await decoded_r.ft().create_index((TextField("txt"),), stopwords=stopwords) - await decoded_r.hset("doc1", mapping={"txt": "foo bar"}) - await decoded_r.hset("doc2", mapping={"txt": "hello world"}) - await waitForIndex(decoded_r, "idx") - - q1 = Query("foo bar").no_content() - q2 = Query("foo bar hello world").no_content() - res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) - if is_resp2_connection(decoded_r): - assert 0 == res1.total - assert 1 == res2.total - else: - assert 0 == res1["total_results"] - assert 1 == res2["total_results"] - - -@pytest.mark.redismod -async def test_filters(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - (TextField("txt"), NumericField("num"), GeoField("loc")) - ) - await decoded_r.hset( - "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} - ) - await decoded_r.hset( - "doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"} - ) - - await waitForIndex(decoded_r, "idx") - # Test numerical filter - q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() - q2 = ( - Query("foo") - .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) - .no_content() - ) - res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) - - if is_resp2_connection(decoded_r): - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id - else: - assert 1 == res1["total_results"] - assert 1 == res2["total_results"] - assert "doc2" == res1["results"][0]["id"] - assert "doc1" == res2["results"][0]["id"] - - # Test geo filter - q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() - q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = await decoded_r.ft().search(q1), await 
decoded_r.ft().search(q2) - - if is_resp2_connection(decoded_r): - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id - - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res - else: - assert 1 == res1["total_results"] - assert 2 == res2["total_results"] - assert "doc1" == res1["results"][0]["id"] - - # Sort results, after RDB reload order may change - res = [res2["results"][0]["id"], res2["results"][1]["id"]] - res.sort() - assert ["doc1", "doc2"] == res - - -@pytest.mark.redismod -async def test_sort_by(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - (TextField("txt"), NumericField("num", sortable=True)) - ) - await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) - await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) - await decoded_r.hset("doc3", mapping={"txt": "foo qux", "num": 3}) - - # Test sort - q1 = Query("foo").sort_by("num", asc=True).no_content() - q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) - - if is_resp2_connection(decoded_r): - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id - else: - assert 3 == res1["total_results"] - assert "doc1" == res1["results"][0]["id"] - assert "doc2" == res1["results"][1]["id"] - assert "doc3" == res1["results"][2]["id"] - assert 3 == res2["total_results"] - assert "doc1" == res2["results"][2]["id"] - assert "doc2" == res2["results"][1]["id"] - assert "doc3" == res2["results"][0]["id"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -async def test_drop_index(decoded_r: redis.Redis): - """ - Ensure the index gets dropped by data remains by default - """ - for x in range(20): - for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: - idx = "HaveIt" - index = getClient(decoded_r) - await index.hset("index:haveit", mapping={"name": "haveit"}) - idef = IndexDefinition(prefix=["index:"]) - await index.ft(idx).create_index((TextField("name"),), definition=idef) - await waitForIndex(index, idx) - await index.ft(idx).dropindex(delete_documents=keep_docs[0]) - i = await index.hgetall("index:haveit") - assert i == keep_docs[1] - - -@pytest.mark.redismod -async def test_example(decoded_r: redis.Redis): - # Creating the index definition and schema - await decoded_r.ft().create_index( - (TextField("title", weight=5.0), TextField("body")) - ) - - # Indexing a document - await decoded_r.hset( - "doc1", - mapping={ - "title": "RediSearch", - "body": "Redisearch impements a search engine on top of redis", - }, - ) - - # Searching with complex parameters: - q = Query("search engine").verbatim().no_content().paging(0, 5) - - res = await decoded_r.ft().search(q) - assert res is not None - - -@pytest.mark.redismod -async def test_auto_complete(decoded_r: redis.Redis): - n = 0 - with open(TITLES_CSV) as f: - cr = csv.reader(f) - - for row in cr: - n += 1 - term, score = row[0], float(row[1]) - assert n == await decoded_r.ft().sugadd("ac", Suggestion(term, score=score)) - - assert n == await decoded_r.ft().suglen("ac") - ret = await decoded_r.ft().sugget("ac", "bad", with_scores=True) - assert 2 == len(ret) - assert "badger" == ret[0].string - assert isinstance(ret[0].score, float) - 
assert 1.0 != ret[0].score - assert "badalte rishtey" == ret[1].string - assert isinstance(ret[1].score, float) - assert 1.0 != ret[1].score - - ret = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) - assert 10 == len(ret) - assert 1.0 == ret[0].score - strs = {x.string for x in ret} - - for sug in strs: - assert 1 == await decoded_r.ft().sugdel("ac", sug) - # make sure a second delete returns 0 - for sug in strs: - assert 0 == await decoded_r.ft().sugdel("ac", sug) - - # make sure they were actually deleted - ret2 = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) - for sug in ret2: - assert sug.string not in strs - - # Test with payload - await decoded_r.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) - await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - - sugs = await decoded_r.ft().sugget( - "ac", "pay", with_payloads=True, with_scores=True - ) - assert 3 == len(sugs) - for sug in sugs: - assert sug.payload - assert sug.payload.startswith("pl") - - -@pytest.mark.redismod -async def test_no_index(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - ( - TextField("field"), - TextField("text", no_index=True, sortable=True), - NumericField("numeric", no_index=True, sortable=True), - GeoField("geo", no_index=True, sortable=True), - TagField("tag", no_index=True, sortable=True), - ) - ) - - await decoded_r.hset( - "doc1", - mapping={"field": "aaa", "text": "1", "numeric": "1", "geo": "1,1", "tag": "1"}, - ) - await decoded_r.hset( - "doc2", - mapping={"field": "aab", "text": "2", "numeric": "2", "geo": "2,2", "tag": "2"}, - ) - await waitForIndex(decoded_r, "idx") - - if is_resp2_connection(decoded_r): - res = await decoded_r.ft().search(Query("@text:aa*")) - assert 0 == res.total - - res = await decoded_r.ft().search(Query("@field:aa*")) - assert 2 == res.total - - res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id - - res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id - - res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id - - res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id - - res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id - else: - res = await decoded_r.ft().search(Query("@text:aa*")) - assert 0 == res["total_results"] - - res = await decoded_r.ft().search(Query("@field:aa*")) - assert 2 == res["total_results"] - - res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res["total_results"] - assert "doc2" == res["results"][0]["id"] - - res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res["results"][0]["id"] - - res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res["results"][0]["id"] - - res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res["results"][0]["id"] - - res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res["results"][0]["id"] - - # Ensure exception is raised for non-indexable, non-sortable fields - with pytest.raises(Exception): - TextField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - NumericField("name", no_index=True, 
sortable=False) - with pytest.raises(Exception): - GeoField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - TagField("name", no_index=True, sortable=False) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.4.0") -@skip_ifmodversion_lt("2.10.0", "search") -async def test_create_index_empty_or_missing_fields_with_sortable( - decoded_r: redis.Redis, -): - definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) - - fields = [ - TextField("title", sortable=True, index_empty=True), - TagField("features", index_missing=True, sortable=True), - TextField("description", no_index=True, sortable=True), - ] - - await decoded_r.ft().create_index(fields, definition=definition) - - -@pytest.mark.redismod -async def test_explain(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - (TextField("f1"), TextField("f2"), TextField("f3")) - ) - res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") - assert res - - -@pytest.mark.redismod -async def test_explaincli(decoded_r: redis.Redis): - with pytest.raises(NotImplementedError): - await decoded_r.ft().explain_cli("foo") - - -@pytest.mark.redismod -async def test_summarize(decoded_r: redis.Redis): - await createIndex(decoded_r.ft()) - await waitForIndex(decoded_r, "idx") - - q = Query('"king henry"').paging(0, 1) - q.highlight(fields=("play", "txt"), tags=("", "")) - q.summarize("txt") - - if is_resp2_connection(decoded_r): - doc = sorted((await decoded_r.ft().search(q)).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - - q = Query('"king henry"').paging(0, 1).summarize().highlight() + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break + except ResponseError: + # index doesn't exist yet + # continue to sleep and try again + pass + + await asyncio.sleep(delay) + if timeout is not None: + timeout -= delay + if timeout <= 0: + break - doc = sorted((await decoded_r.ft().search(q)).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - else: - doc = sorted((await decoded_r.ft().search(q))["results"])[0] - assert "Henry IV" == doc["extra_attributes"]["play"] - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc["extra_attributes"]["txt"] - ) + @staticmethod + def getClient(decoded_r: redis.Redis): + """ + Gets a client client attached to an index name which is ready to be + created + """ + return decoded_r - q = Query('"king henry"').paging(0, 1).summarize().highlight() + @staticmethod + async def createIndex(decoded_r, num_docs=100, definition=None): + try: + await decoded_r.create_index( + ( + TextField("play", weight=5.0), + TextField("txt"), + NumericField("chapter"), + ), + definition=definition, + ) + except redis.ResponseError: + await decoded_r.dropindex(delete_documents=True) + return await AsyncSearchTestsBase.createIndex( + decoded_r, num_docs=num_docs, definition=definition + ) - doc = sorted((await decoded_r.ft().search(q))["results"])[0] - assert "Henry ... " == doc["extra_attributes"]["play"] - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa - == doc["extra_attributes"]["txt"] - ) + chapters = {} + bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") + r = csv.reader(bzfp, delimiter=";") + for n, line in enumerate(r): + play, chapter, _, text = line[1], line[2], line[4], line[5] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -async def test_alias(decoded_r: redis.Redis): - index1 = getClient(decoded_r) - index2 = getClient(decoded_r) + key = f"{play}:{chapter}".lower() + d = chapters.setdefault(key, {}) + d["play"] = play + d["txt"] = d.get("txt", "") + " " + text + d["chapter"] = int(chapter or 0) + if len(chapters) == num_docs: + break - def1 = IndexDefinition(prefix=["index1:"]) - def2 = IndexDefinition(prefix=["index2:"]) + indexer = decoded_r.batch_indexer(chunk_size=50) + assert isinstance(indexer, AsyncSearch.BatchIndexer) + assert 50 == indexer.chunk_size - ftindex1 = index1.ft("testAlias") - ftindex2 = index2.ft("testAlias2") - await ftindex1.create_index((TextField("name"),), definition=def1) - await ftindex2.create_index((TextField("name"),), definition=def2) + for key, doc in chapters.items(): + await indexer.client.client.hset(key, mapping=doc) + await indexer.commit() - await index1.hset("index1:lonestar", mapping={"name": "lonestar"}) - await index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - if is_resp2_connection(decoded_r): - res = (await ftindex1.search("*")).docs[0] - assert "index1:lonestar" == res.id +class TestBaseSearchFunctionality(AsyncSearchTestsBase): + @pytest.mark.redismod + async def test_client(self, decoded_r: redis.Redis): + num_docs = 500 + await self.createIndex(decoded_r.ft(), num_docs=num_docs) + await self.waitForIndex(decoded_r, "idx") + # verify info + info = await decoded_r.ft().info() + for k in [ + "index_name", + "index_options", + "attributes", + "num_docs", + "max_doc_id", + "num_terms", + "num_records", + "inverted_sz_mb", + "offset_vectors_sz_mb", + "doc_table_size_mb", + "key_table_size_mb", + "records_per_doc_avg", + "bytes_per_record_avg", + "offsets_per_term_avg", + "offset_bits_per_record_avg", + ]: + assert k in info + + assert decoded_r.ft().index_name == info["index_name"] + assert num_docs == int(info["num_docs"]) + + res = await decoded_r.ft().search("henry iv") + if is_resp2_connection(decoded_r): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc.play == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = await decoded_r.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = (await decoded_r.ft().search(Query("kings").no_content())).total + vtotal = ( + await decoded_r.ft().search(Query("kings").no_content().verbatim()) + ).total + assert total > vtotal - # create alias and check for results - await ftindex1.aliasadd("spaceballs") - alias_client = getClient(decoded_r).ft("spaceballs") - res = (await alias_client.search("*")).docs[0] - assert "index1:lonestar" == res.id + # test in fields + txt_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("txt") + ) + ).total + play_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play") + ) + ).total + both_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + 
) + ).total + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await decoded_r.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await ftindex2.aliasadd("spaceballs") + # test in-keys + ids = [x.id for x in (await decoded_r.ft().search(Query("henry"))).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == (await decoded_r.ft().search(Query("henry king"))).total + assert ( + 3 + == ( + await decoded_r.ft().search(Query("henry king").slop(0).in_order()) + ).total + ) + assert ( + 52 + == ( + await decoded_r.ft().search(Query("king henry").slop(0).in_order()) + ).total + ) + assert ( + 53 == (await decoded_r.ft().search(Query("henry king").slop(0))).total + ) + assert ( + 167 + == (await decoded_r.ft().search(Query("henry king").slop(100))).total + ) - # update alias and ensure new results - await ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(decoded_r).ft("spaceballs") + # test delete document + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res.total - res = (await alias_client2.search("*")).docs[0] - assert "index2:yogurt" == res.id - else: - res = (await ftindex1.search("*"))["results"][0] - assert "index1:lonestar" == res["id"] + assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") - # create alias and check for results - await ftindex1.aliasadd("spaceballs") - alias_client = getClient(await decoded_r).ft("spaceballs") - res = (await alias_client.search("*"))["results"][0] - assert "index1:lonestar" == res["id"] + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res.total + await decoded_r.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["extra_attributes"]["play"] == "Henry IV" + assert len(doc["extra_attributes"]["txt"]) > 0 + + # test no content + res = await decoded_r.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "extra_attributes" not in doc.keys() + + # test verbatim vs no verbatim + total = (await decoded_r.ft().search(Query("kings").no_content()))[ + "total_results" + ] + vtotal = ( + await decoded_r.ft().search(Query("kings").no_content().verbatim()) + )["total_results"] + assert total > vtotal + + # test in fields + txt_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("txt") + ) + )["total_results"] + play_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play") + ) + )["total_results"] + both_total = ( + await decoded_r.ft().search( + 
Query("henry").no_content().limit_fields("play", "txt") + ) + )["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await decoded_r.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await ftindex2.aliasadd("spaceballs") + # test in-keys + ids = [ + x["id"] + for x in (await decoded_r.ft().search(Query("henry")))["results"] + ] + assert 10 == len(ids) + subset = ids[:5] + docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert ( + 193 + == (await decoded_r.ft().search(Query("henry king")))["total_results"] + ) + assert ( + 3 + == ( + await decoded_r.ft().search(Query("henry king").slop(0).in_order()) + )["total_results"] + ) + assert ( + 52 + == ( + await decoded_r.ft().search(Query("king henry").slop(0).in_order()) + )["total_results"] + ) + assert ( + 53 + == (await decoded_r.ft().search(Query("henry king").slop(0)))[ + "total_results" + ] + ) + assert ( + 167 + == (await decoded_r.ft().search(Query("henry king").slop(100)))[ + "total_results" + ] + ) - # update alias and ensure new results - await ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(await decoded_r).ft("spaceballs") + # test delete document + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] - res = (await alias_client2.search("*"))["results"][0] - assert "index2:yogurt" == res["id"] + assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") - await ftindex2.aliasdel("spaceballs") - with pytest.raises(Exception): - (await alias_client2.search("*")).docs[0] + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + await decoded_r.ft().delete_document("doc-5ghs2") + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_gte("7.9.0") + async def test_scores(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"),)) -@pytest.mark.redismod -@pytest.mark.xfail(strict=False) -async def test_alias_basic(decoded_r: redis.Redis): - # Creating a client with one index - client = getClient(decoded_r) - await client.flushdb() - index1 = getClient(decoded_r).ft("testAlias") + await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) + await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) - await index1.create_index((TextField("txt"),)) - await index1.client.hset("doc1", mapping={"txt": "text goes here"}) + q = Query("foo ~bar").with_scores() + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == 
res["results"][1]["id"] - index2 = getClient(decoded_r).ft("testAlias2") - await index2.create_index((TextField("txt"),)) - await index2.client.hset("doc2", mapping={"txt": "text goes here"}) + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.9.0") + async def test_scores_with_new_default_scorer(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"),)) - # add the actual alias and check - await index1.aliasadd("myalias") - alias_client = getClient(decoded_r).ft("myalias") - if is_resp2_connection(decoded_r): - res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id) - assert "doc1" == res[0].id + await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) + await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - await index2.aliasupdate("myalias") - alias_client2 = getClient(decoded_r).ft("myalias") - res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id) - assert "doc1" == res[0].id - else: - res = sorted((await alias_client.search("*"))["results"], key=lambda x: x["id"]) - assert "doc1" == res[0]["id"] - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await index2.aliasadd("myalias") + q = Query("foo ~bar").with_scores() + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 0.87 == pytest.approx(res.docs[0].score, 0.01) + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01) + assert "doc1" == res["results"][1]["id"] + + @pytest.mark.redismod + async def test_stopwords(self, decoded_r: redis.Redis): + stopwords = ["foo", "bar", "baz"] + await decoded_r.ft().create_index((TextField("txt"),), stopwords=stopwords) + await decoded_r.hset("doc1", mapping={"txt": "foo bar"}) + await decoded_r.hset("doc2", mapping={"txt": "hello world"}) + await self.waitForIndex(decoded_r, "idx") + + q1 = Query("foo bar").no_content() + q2 = Query("foo bar hello world").no_content() + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + if is_resp2_connection(decoded_r): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] - # update the alias and ensure we get doc2 - await index2.aliasupdate("myalias") - alias_client2 = getClient(client).ft("myalias") - res = sorted( - (await alias_client2.search("*"))["results"], key=lambda x: x["id"] + @pytest.mark.redismod + async def test_filters(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num"), GeoField("loc")) + ) + await decoded_r.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} + ) + await decoded_r.hset( + "doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"} ) - assert "doc1" == res[0]["id"] - - # delete the alias and expect an error if we try to query again - await index2.aliasdel("myalias") - with pytest.raises(Exception): - _ = (await alias_client2.search("*")).docs[0] - -@pytest.mark.redismod -async def test_tags(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("txt"), TagField("tags"))) 
- tags = "foo,foo bar,hello;world" - tags2 = "soba,ramen" + await self.waitForIndex(decoded_r, "idx") + # Test numerical filter + q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() + q2 = ( + Query("foo") + .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) + .no_content() + ) + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) - await decoded_r.hset("doc1", mapping={"txt": "fooz barz", "tags": tags}) - await decoded_r.hset("doc2", mapping={"txt": "noodles", "tags": tags2}) - await waitForIndex(decoded_r, "idx") + if is_resp2_connection(decoded_r): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] - q = Query("@tags:{foo}") - if is_resp2_connection(decoded_r): - res = await decoded_r.ft().search(q) - assert 1 == res.total + # Test geo filter + q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() + q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) - q = Query("@tags:{foo bar}") - res = await decoded_r.ft().search(q) - assert 1 == res.total + if is_resp2_connection(decoded_r): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id + + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] + + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res + + @pytest.mark.redismod + async def test_sort_by(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num", sortable=True)) + ) + await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) + await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) + await decoded_r.hset("doc3", mapping={"txt": "foo qux", "num": 3}) - q = Query("@tags:{foo\\ bar}") - res = await decoded_r.ft().search(q) - assert 1 == res.total + # Test sort + q1 = Query("foo").sort_by("num", asc=True).no_content() + q2 = Query("foo").sort_by("num", asc=False).no_content() + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) - q = Query("@tags:{hello\\;world}") - res = await decoded_r.ft().search(q) - assert 1 == res.total + if is_resp2_connection(decoded_r): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + async def test_drop_index(self, decoded_r: redis.Redis): + """ + Ensure the index gets dropped by data 
remains by default
+        """
+        for x in range(20):
+            for keep_docs in [[True, {}], [False, {"name": "haveit"}]]:
+                idx = "HaveIt"
+                index = self.getClient(decoded_r)
+                await index.hset("index:haveit", mapping={"name": "haveit"})
+                idef = IndexDefinition(prefix=["index:"])
+                await index.ft(idx).create_index((TextField("name"),), definition=idef)
+                await self.waitForIndex(index, idx)
+                await index.ft(idx).dropindex(delete_documents=keep_docs[0])
+                i = await index.hgetall("index:haveit")
+                assert i == keep_docs[1]
+
+    @pytest.mark.redismod
+    async def test_example(self, decoded_r: redis.Redis):
+        # Creating the index definition and schema
+        await decoded_r.ft().create_index(
+            (TextField("title", weight=5.0), TextField("body"))
+        )
-        q2 = await decoded_r.ft().tagvals("tags")
-        assert (tags.split(",") + tags2.split(",")).sort() == q2.sort()
-    else:
-        res = await decoded_r.ft().search(q)
-        assert 1 == res["total_results"]
+        # Indexing a document
+        await decoded_r.hset(
+            "doc1",
+            mapping={
+                "title": "RediSearch",
+                "body": "Redisearch implements a search engine on top of redis",
+            },
+        )
-        q = Query("@tags:{foo bar}")
-        res = await decoded_r.ft().search(q)
-        assert 1 == res["total_results"]
+        # Searching with complex parameters:
+        q = Query("search engine").verbatim().no_content().paging(0, 5)
-        q = Query("@tags:{foo\\ bar}")
         res = await decoded_r.ft().search(q)
-        assert 1 == res["total_results"]
+        assert res is not None
+
+    @pytest.mark.redismod
+    async def test_auto_complete(self, decoded_r: redis.Redis):
+        n = 0
+        with open(TITLES_CSV) as f:
+            cr = csv.reader(f)
+
+            for row in cr:
+                n += 1
+                term, score = row[0], float(row[1])
+                assert n == await decoded_r.ft().sugadd(
+                    "ac", Suggestion(term, score=score)
+                )
-        q = Query("@tags:{hello\\;world}")
-        res = await decoded_r.ft().search(q)
-        assert 1 == res["total_results"]
-
-        q2 = await decoded_r.ft().tagvals("tags")
-        assert set(tags.split(",") + tags2.split(",")) == set(q2)
-
-
-@pytest.mark.redismod
-async def test_textfield_sortable_nostem(decoded_r: redis.Redis):
-    # Creating the index definition with sortable and no_stem
-    await decoded_r.ft().create_index((TextField("txt", sortable=True, no_stem=True),))
-
-    # Now get the index info to confirm its contents
-    response = await decoded_r.ft().info()
-    if is_resp2_connection(decoded_r):
-        assert "SORTABLE" in response["attributes"][0]
-        assert "NOSTEM" in response["attributes"][0]
-    else:
-        assert "SORTABLE" in response["attributes"][0]["flags"]
-        assert "NOSTEM" in response["attributes"][0]["flags"]
-
-
-@pytest.mark.redismod
-async def test_alter_schema_add(decoded_r: redis.Redis):
-    # Creating the index definition and schema
-    await decoded_r.ft().create_index(TextField("title"))
-
-    # Using alter to add a field
-    await decoded_r.ft().alter_schema_add(TextField("body"))
-
-    # Indexing a document
-    await decoded_r.hset(
-        "doc1", mapping={"title": "MyTitle", "body": "Some content only in the body"}
-    )
-
-    # Searching with parameter only in the body (the added field)
-    q = Query("only in the body")
-
-    # Ensure we find the result searching on the added body field
-    res = await decoded_r.ft().search(q)
-    if is_resp2_connection(decoded_r):
-        assert 1 == res.total
-    else:
-        assert 1 == res["total_results"]
-
-
-@pytest.mark.redismod
-async def test_spell_check(decoded_r: redis.Redis):
-    await decoded_r.ft().create_index((TextField("f1"), TextField("f2")))
-
-    await decoded_r.hset(
-        "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"}
-    )
-    await decoded_r.hset("doc2", 
mapping={"f1": "very important", "f2": "lorem ipsum"}) - await waitForIndex(decoded_r, "idx") - - if is_resp2_connection(decoded_r): - # test spellcheck - res = await decoded_r.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = await decoded_r.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = await decoded_r.ft().spellcheck("vlis") - assert res == {} - res = await decoded_r.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") - res = await decoded_r.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = await decoded_r.ft().spellcheck("lorm", exclude="dict") - assert res == {} - else: - # test spellcheck - res = await decoded_r.ft().spellcheck("impornant") - assert "important" in res["results"]["impornant"][0].keys() - - res = await decoded_r.ft().spellcheck("contnt") - assert "content" in res["results"]["contnt"][0].keys() - - # test spellcheck with Levenshtein distance - res = await decoded_r.ft().spellcheck("vlis") - assert res == {"results": {"vlis": []}} - res = await decoded_r.ft().spellcheck("vlis", distance=2) - assert "valid" in res["results"]["vlis"][0].keys() - - # test spellcheck include - await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") - res = await decoded_r.ft().spellcheck("lorm", include="dict") - assert len(res["results"]["lorm"]) == 3 - assert "lorem" in res["results"]["lorm"][0].keys() - assert "lore" in res["results"]["lorm"][1].keys() - assert "lorm" in res["results"]["lorm"][2].keys() - assert ( - res["results"]["lorm"][0]["lorem"], - res["results"]["lorm"][1]["lore"], - ) == (0.5, 0) - - # test spellcheck exclude - res = await decoded_r.ft().spellcheck("lorm", exclude="dict") - assert res == {"results": {}} - - -@pytest.mark.redismod -async def test_dict_operations(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - # Add three items - res = await decoded_r.ft().dict_add("custom_dict", "item1", "item2", "item3") - assert 3 == res - - # Remove one item - res = await decoded_r.ft().dict_del("custom_dict", "item2") - assert 1 == res - - # Dump dict and inspect content - res = await decoded_r.ft().dict_dump("custom_dict") - assert res == ["item1", "item3"] - - # Remove rest of the items before reload - await decoded_r.ft().dict_del("custom_dict", *res) - - -@pytest.mark.redismod -async def test_phonetic_matcher(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("name"),)) - await decoded_r.hset("doc1", mapping={"name": "Jon"}) - await decoded_r.hset("doc2", mapping={"name": "John"}) - - res = await decoded_r.ft().search(Query("Jon")) - if is_resp2_connection(decoded_r): - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name - else: - assert 1 == res["total_results"] - assert "Jon" == res["results"][0]["extra_attributes"]["name"] - - # Drop and create index with phonetic matcher - await decoded_r.flushdb() - - await decoded_r.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - await decoded_r.hset("doc1", mapping={"name": "Jon"}) - await 
decoded_r.hset("doc2", mapping={"name": "John"}) - - res = await decoded_r.ft().search(Query("Jon")) - if is_resp2_connection(decoded_r): - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) - else: - assert 2 == res["total_results"] - assert ["John", "Jon"] == sorted( - d["extra_attributes"]["name"] for d in res["results"] - ) - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ -@skip_ifmodversion_lt("2.8.0", "search") -@skip_if_server_version_gte("7.9.0") -async def test_scorer(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("description"),)) - - await decoded_r.hset( - "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} - ) - await decoded_r.hset( - "doc2", - mapping={ - "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa - }, - ) - - if is_resp2_connection(decoded_r): - # default scorer is TFIDF - res = await decoded_r.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = await decoded_r.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores() - ) - assert 0.14285714285714285 == res.docs[0].score - res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res.docs[0].score - res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = await decoded_r.ft().search( - Query("quick").scorer("DOCSCORE").with_scores() - ) - assert 1.0 == res.docs[0].score - res = await decoded_r.ft().search( - Query("quick").scorer("HAMMING").with_scores() - ) - assert 0.0 == res.docs[0].score - else: - res = await decoded_r.ft().search(Query("quick").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = await decoded_r.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores() - ) - assert 0.14285714285714285 == res["results"][0]["score"] - res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res["results"][0]["score"] - res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res["results"][0]["score"] - res = await decoded_r.ft().search( - Query("quick").scorer("DOCSCORE").with_scores() - ) - assert 1.0 == res["results"][0]["score"] - res = await decoded_r.ft().search( - Query("quick").scorer("HAMMING").with_scores() - ) - assert 0.0 == res["results"][0]["score"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ -@skip_ifmodversion_lt("2.8.0", "search") -@skip_if_server_version_lt("7.9.0") -async def test_scorer_with_new_default_scorer(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("description"),)) - - await decoded_r.hset( - "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} - ) - await decoded_r.hset( - "doc2", - mapping={ - "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa - }, - ) - - if is_resp2_connection(decoded_r): - # default scorer is BM25STD - res = await decoded_r.ft().search(Query("quick").with_scores()) - assert 0.23 == pytest.approx(res.docs[0].score, 0.05) - res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = await decoded_r.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores() + assert n == await decoded_r.ft().suglen("ac") + ret = await decoded_r.ft().sugget("ac", "bad", with_scores=True) + assert 2 == len(ret) + assert "badger" == ret[0].string + assert isinstance(ret[0].score, float) + assert 1.0 != ret[0].score + assert "badalte rishtey" == ret[1].string + assert isinstance(ret[1].score, float) + assert 1.0 != ret[1].score + + ret = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) + assert 10 == len(ret) + assert 1.0 == ret[0].score + strs = {x.string for x in ret} + + for sug in strs: + assert 1 == await decoded_r.ft().sugdel("ac", sug) + # make sure a second delete returns 0 + for sug in strs: + assert 0 == await decoded_r.ft().sugdel("ac", sug) + + # make sure they were actually deleted + ret2 = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) + for sug in ret2: + assert sug.string not in strs + + # Test with payload + await decoded_r.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) + await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) + await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) + + sugs = await decoded_r.ft().sugget( + "ac", "pay", with_payloads=True, with_scores=True ) - assert 0.14285714285714285 == res.docs[0].score - res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res.docs[0].score - res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = await decoded_r.ft().search( - Query("quick").scorer("DOCSCORE").with_scores() + assert 3 == len(sugs) + for sug in sugs: + assert sug.payload + assert sug.payload.startswith("pl") + + @pytest.mark.redismod + async def test_no_index(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + TextField("field"), + TextField("text", no_index=True, sortable=True), + NumericField("numeric", no_index=True, sortable=True), + GeoField("geo", no_index=True, sortable=True), + TagField("tag", no_index=True, sortable=True), + ) ) - assert 1.0 == res.docs[0].score - res = await decoded_r.ft().search( - Query("quick").scorer("HAMMING").with_scores() - ) - assert 0.0 == res.docs[0].score - else: - res = await decoded_r.ft().search(Query("quick").with_scores()) - assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05) - res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = await decoded_r.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores() + + await decoded_r.hset( + "doc1", + mapping={ + "field": "aaa", + "text": "1", + "numeric": "1", + "geo": "1,1", + "tag": "1", + }, ) - assert 0.14285714285714285 == res["results"][0]["score"] - res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res["results"][0]["score"] - res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res["results"][0]["score"] - res = await decoded_r.ft().search( - Query("quick").scorer("DOCSCORE").with_scores() + await decoded_r.hset( + "doc2", + 
mapping={ + "field": "aab", + "text": "2", + "numeric": "2", + "geo": "2,2", + "tag": "2", + }, ) - assert 1.0 == res["results"][0]["score"] - res = await decoded_r.ft().search( - Query("quick").scorer("HAMMING").with_scores() - ) - assert 0.0 == res["results"][0]["score"] - - -@pytest.mark.redismod -async def test_get(decoded_r: redis.Redis): - await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - - assert [None] == await decoded_r.ft().get("doc1") - assert [None, None] == await decoded_r.ft().get("doc2", "doc1") - - await decoded_r.hset( - "doc1", mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"} - ) - await decoded_r.hset( - "doc2", mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"} - ) - - assert [ - ["f1", "some valid content dd2", "f2", "this is sample text f2"] - ] == await decoded_r.ft().get("doc2") - assert [ - ["f1", "some valid content dd1", "f2", "this is sample text f1"], - ["f1", "some valid content dd2", "f2", "this is sample text f2"], - ] == await decoded_r.ft().get("doc1", "doc2") - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_ifmodversion_lt("2.2.0", "search") -@skip_if_server_version_gte("7.9.0") -async def test_config(decoded_r: redis.Redis): - assert await decoded_r.ft().config_set("TIMEOUT", "100") - with pytest.raises(redis.ResponseError): - await decoded_r.ft().config_set("TIMEOUT", "null") - res = await decoded_r.ft().config_get("*") - assert "100" == res["TIMEOUT"] - res = await decoded_r.ft().config_get("TIMEOUT") - assert "100" == res["TIMEOUT"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_lt("7.9.0") -async def test_config_with_removed_ftconfig(decoded_r: redis.Redis): - assert await decoded_r.config_set("timeout", "100") - with pytest.raises(redis.ResponseError): - await decoded_r.config_set("timeout", "null") - res = await decoded_r.config_get("*") - assert "100" == res["timeout"] - res = await decoded_r.config_get("timeout") - assert "100" == res["timeout"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -async def test_aggregations_groupby(decoded_r: redis.Redis): - # Creating the index definition and schema - await decoded_r.ft().create_index( - ( - NumericField("random_num"), - TextField("title"), - TextField("body"), - TextField("parent"), - ) - ) - - # Indexing a document - await decoded_r.hset( - "search", - mapping={ - "title": "RediSearch", - "body": "Redisearch impements a search engine on top of redis", - "parent": "redis", - "random_num": 10, - }, - ) - await decoded_r.hset( - "ai", - mapping={ - "title": "RedisAI", - "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa - "parent": "redis", - "random_num": 3, - }, - ) - await decoded_r.hset( - "json", - mapping={ - "title": "RedisJson", - "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa - "parent": "redis", - "random_num": 8, - }, - ) - - for dialect in [1, 2]: - if is_resp2_connection(decoded_r): - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count()) - .dialect(dialect) - ) + await self.waitForIndex(decoded_r, "idx") - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + if is_resp2_connection(decoded_r): + res = await decoded_r.ft().search(Query("@text:aa*")) + assert 0 == res.total - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", 
reducers.count_distinct("@title")) - .dialect(dialect) - ) + res = await decoded_r.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinctish("@title")) - .dialect(dialect) - ) + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.sum("@random_num")) - .dialect(dialect) - ) + res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = await decoded_r.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.min("@random_num")) - .dialect(dialect) - ) + res = await decoded_r.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.max("@random_num")) - .dialect(dialect) - ) + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.avg("@random_num")) - .dialect(dialect) - ) + res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "7" # (10+3+8)/3 + res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.stddev("random_num")) - .dialect(dialect) - ) + # Ensure exception is raised for non-indexable, non-sortable fields + with pytest.raises(Exception): + TextField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + NumericField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + GeoField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + TagField("name", no_index=True, sortable=False) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.4.0") + @skip_ifmodversion_lt("2.10.0", "search") + async def test_create_index_empty_or_missing_fields_with_sortable( + 
self,
+        decoded_r: redis.Redis,
+    ):
+        definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH)
+
+        fields = [
+            TextField("title", sortable=True, index_empty=True),
+            TagField("features", index_missing=True, sortable=True),
+            TextField("description", no_index=True, sortable=True),
+        ]
-            res = (await decoded_r.ft().aggregate(req)).rows[0]
-            assert res[1] == "redis"
-            assert res[3] == "3.60555127546"
+        await decoded_r.ft().create_index(fields, definition=definition)
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by("@parent", reducers.quantile("@random_num", 0.5))
-                .dialect(dialect)
-            )
+    @pytest.mark.redismod
+    async def test_explain(self, decoded_r: redis.Redis):
+        await decoded_r.ft().create_index(
+            (TextField("f1"), TextField("f2"), TextField("f3"))
+        )
+        res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val")
+        assert res
-            res = (await decoded_r.ft().aggregate(req)).rows[0]
-            assert res[1] == "redis"
-            assert res[3] == "8"  # median of 3,8,10
+    @pytest.mark.redismod
+    async def test_explaincli(self, decoded_r: redis.Redis):
+        with pytest.raises(NotImplementedError):
+            await decoded_r.ft().explain_cli("foo")
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by("@parent", reducers.tolist("@title"))
-                .dialect(dialect)
-            )
+    @pytest.mark.redismod
+    async def test_summarize(self, decoded_r: redis.Redis):
+        await self.createIndex(decoded_r.ft())
+        await self.waitForIndex(decoded_r, "idx")
-            res = (await decoded_r.ft().aggregate(req)).rows[0]
-            assert res[1] == "redis"
-            assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"}
+        q = Query('"king henry"').paging(0, 1)
+        q.highlight(fields=("play", "txt"), tags=("<b>", "</b>"))
+        q.summarize("txt")
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by("@parent", reducers.first_value("@title").alias("first"))
-                .dialect(dialect)
+        if is_resp2_connection(decoded_r):
+            doc = sorted((await decoded_r.ft().search(q)).docs)[0]
+            assert "<b>Henry</b> IV" == doc.play
+            assert (
+                "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
+                == doc.txt
             )
-            res = (await decoded_r.ft().aggregate(req)).rows[0]
-            assert res == ["parent", "redis", "first", "RediSearch"]
+            q = Query('"king henry"').paging(0, 1).summarize().highlight()
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by(
-                    "@parent", reducers.random_sample("@title", 2).alias("random")
-                )
-                .dialect(dialect)
+            doc = sorted((await decoded_r.ft().search(q)).docs)[0]
+            assert "<b>Henry</b> ... " == doc.play
+            assert (
+                "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa + == doc.txt ) - - res = (await decoded_r.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] else: - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count()) - .dialect(dialect) - ) - - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliascount"] == "3" - - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinct("@title")) - .dialect(dialect) - ) - - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" + doc = sorted((await decoded_r.ft().search(q))["results"])[0] + assert "Henry IV" == doc["extra_attributes"]["play"] assert ( - res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] ) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinctish("@title")) - .dialect(dialect) - ) + q = Query('"king henry"').paging(0, 1).summarize().highlight() - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" + doc = sorted((await decoded_r.ft().search(q))["results"])[0] + assert "Henry ... " == doc["extra_attributes"]["play"] assert ( - res["extra_attributes"]["__generated_aliascount_distinctishtitle"] - == "3" - ) - - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.sum("@random_num")) - .dialect(dialect) + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa + == doc["extra_attributes"]["txt"] ) - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + async def test_alias(self, decoded_r: redis.Redis): + index1 = self.getClient(decoded_r) + index2 = self.getClient(decoded_r) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.min("@random_num")) - .dialect(dialect) - ) + def1 = IndexDefinition(prefix=["index1:"]) + def2 = IndexDefinition(prefix=["index2:"]) - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" + ftindex1 = index1.ft("testAlias") + ftindex2 = index2.ft("testAlias2") + await ftindex1.create_index((TextField("name"),), definition=def1) + await ftindex2.create_index((TextField("name"),), definition=def2) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.max("@random_num")) - .dialect(dialect) - ) + await index1.hset("index1:lonestar", mapping={"name": "lonestar"}) + await index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + if is_resp2_connection(decoded_r): + res = (await ftindex1.search("*")).docs[0] + assert "index1:lonestar" == res.id - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.avg("@random_num")) - .dialect(dialect) - ) + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = self.getClient(decoded_r).ft("spaceballs") + res = (await alias_client.search("*")).docs[0] + assert "index1:lonestar" == res.id - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await ftindex2.aliasadd("spaceballs") - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.stddev("random_num")) - .dialect(dialect) - ) + # update alias and ensure new results + await ftindex2.aliasupdate("spaceballs") + alias_client2 = self.getClient(decoded_r).ft("spaceballs") - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert ( - res["extra_attributes"]["__generated_aliasstddevrandom_num"] - == "3.60555127546" - ) + res = (await alias_client2.search("*")).docs[0] + assert "index2:yogurt" == res.id + else: + res = (await ftindex1.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.quantile("@random_num", 0.5)) - .dialect(dialect) - ) + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = self.getClient(await decoded_r).ft("spaceballs") + res = (await alias_client.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - res = (await decoded_r.ft().aggregate(req))["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert ( - res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] - == "8" - ) 
+            # Throw an exception when trying to add an alias that already exists
+            with pytest.raises(Exception):
+                await ftindex2.aliasadd("spaceballs")
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by("@parent", reducers.tolist("@title"))
-                .dialect(dialect)
-            )
+            # update alias and ensure new results
+            await ftindex2.aliasupdate("spaceballs")
+            alias_client2 = self.getClient(decoded_r).ft("spaceballs")
-            res = (await decoded_r.ft().aggregate(req))["results"][0]
-            assert res["extra_attributes"]["parent"] == "redis"
-            assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == {
-                "RediSearch",
-                "RedisAI",
-                "RedisJson",
-            }
+            res = (await alias_client2.search("*"))["results"][0]
+            assert "index2:yogurt" == res["id"]
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by("@parent", reducers.first_value("@title").alias("first"))
-                .dialect(dialect)
+        await ftindex2.aliasdel("spaceballs")
+        with pytest.raises(Exception):
+            (await alias_client2.search("*")).docs[0]
+
+    @pytest.mark.redismod
+    @pytest.mark.xfail(strict=False)
+    async def test_alias_basic(self, decoded_r: redis.Redis):
+        # Creating a client with one index
+        client = self.getClient(decoded_r)
+        await client.flushdb()
+        index1 = self.getClient(decoded_r).ft("testAlias")
+
+        await index1.create_index((TextField("txt"),))
+        await index1.client.hset("doc1", mapping={"txt": "text goes here"})
+
+        index2 = self.getClient(decoded_r).ft("testAlias2")
+        await index2.create_index((TextField("txt"),))
+        await index2.client.hset("doc2", mapping={"txt": "text goes here"})
+
+        # add the actual alias and check
+        await index1.aliasadd("myalias")
+        alias_client = self.getClient(decoded_r).ft("myalias")
+        if is_resp2_connection(decoded_r):
+            res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id)
+            assert "doc1" == res[0].id
+
+            # Throw an exception when trying to add an alias that already exists
+            with pytest.raises(Exception):
+                await index2.aliasadd("myalias")
+
+            # update the alias and ensure we get doc2
+            await index2.aliasupdate("myalias")
+            alias_client2 = self.getClient(decoded_r).ft("myalias")
+            res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id)
+            assert "doc1" == res[0].id
+        else:
+            res = sorted(
+                (await alias_client.search("*"))["results"], key=lambda x: x["id"]
             )
+            assert "doc1" == res[0]["id"]
-            res = (await decoded_r.ft().aggregate(req))["results"][0]
-            assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"}
+            # Throw an exception when trying to add an alias that already exists
+            with pytest.raises(Exception):
+                await index2.aliasadd("myalias")
-            req = (
-                aggregations.AggregateRequest("redis")
-                .group_by(
-                    "@parent", reducers.random_sample("@title", 2).alias("random")
-                )
-                .dialect(dialect)
+            # update the alias and ensure we get doc2
+            await index2.aliasupdate("myalias")
+            alias_client2 = self.getClient(client).ft("myalias")
+            res = sorted(
+                (await alias_client2.search("*"))["results"], key=lambda x: x["id"]
             )
+            assert "doc1" == res[0]["id"]
-            res = (await decoded_r.ft().aggregate(req))["results"][0]
-            assert res["extra_attributes"]["parent"] == "redis"
-            assert "random" in res["extra_attributes"].keys()
-            assert len(res["extra_attributes"]["random"]) == 2
-            assert res["extra_attributes"]["random"][0] in [
-                "RediSearch",
-                "RedisAI",
-                "RedisJson",
-            ]
-
-
-@pytest.mark.redismod
-async def test_aggregations_sort_by_and_limit(decoded_r: redis.Redis):
-    await decoded_r.ft().create_index((TextField("t1"), TextField("t2")))
+        # delete the
alias and expect an error if we try to query again
+        await index2.aliasdel("myalias")
+        with pytest.raises(Exception):
+            _ = (await alias_client2.search("*")).docs[0]
-    await decoded_r.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"})
-    await decoded_r.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"})
+    @pytest.mark.redismod
+    async def test_tags(self, decoded_r: redis.Redis):
+        await decoded_r.ft().create_index((TextField("txt"), TagField("tags")))
+        tags = "foo,foo bar,hello;world"
+        tags2 = "soba,ramen"
-    if is_resp2_connection(decoded_r):
-        # test sort_by using SortDirection
-        req = aggregations.AggregateRequest("*").sort_by(
-            aggregations.Asc("@t2"), aggregations.Desc("@t1")
-        )
-        res = await decoded_r.ft().aggregate(req)
-        assert res.rows[0] == ["t2", "a", "t1", "b"]
-        assert res.rows[1] == ["t2", "b", "t1", "a"]
+        await decoded_r.hset("doc1", mapping={"txt": "fooz barz", "tags": tags})
+        await decoded_r.hset("doc2", mapping={"txt": "noodles", "tags": tags2})
+        await self.waitForIndex(decoded_r, "idx")
-        # test sort_by without SortDirection
-        req = aggregations.AggregateRequest("*").sort_by("@t1")
-        res = await decoded_r.ft().aggregate(req)
-        assert res.rows[0] == ["t1", "a"]
-        assert res.rows[1] == ["t1", "b"]
+        q = Query("@tags:{foo}")
+        if is_resp2_connection(decoded_r):
+            res = await decoded_r.ft().search(q)
+            assert 1 == res.total
-        # test sort_by with max
-        req = aggregations.AggregateRequest("*").sort_by("@t1", max=1)
-        res = await decoded_r.ft().aggregate(req)
-        assert len(res.rows) == 1
+            q = Query("@tags:{foo bar}")
+            res = await decoded_r.ft().search(q)
+            assert 1 == res.total
-        # test limit
-        req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1)
-        res = await decoded_r.ft().aggregate(req)
-        assert len(res.rows) == 1
-        assert res.rows[0] == ["t1", "b"]
-    else:
-        # test sort_by using SortDirection
-        req = aggregations.AggregateRequest("*").sort_by(
-            aggregations.Asc("@t2"), aggregations.Desc("@t1")
-        )
-        res = (await decoded_r.ft().aggregate(req))["results"]
-        assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"}
-        assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"}
-
-        # test sort_by without SortDirection
-        req = aggregations.AggregateRequest("*").sort_by("@t1")
-        res = (await decoded_r.ft().aggregate(req))["results"]
-        assert res[0]["extra_attributes"] == {"t1": "a"}
-        assert res[1]["extra_attributes"] == {"t1": "b"}
-
-        # test sort_by with max
-        req = aggregations.AggregateRequest("*").sort_by("@t1", max=1)
-        res = await decoded_r.ft().aggregate(req)
-        assert len(res["results"]) == 1
+            q = Query("@tags:{foo\\ bar}")
+            res = await decoded_r.ft().search(q)
+            assert 1 == res.total
-        # test limit
-        req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1)
-        res = await decoded_r.ft().aggregate(req)
-        assert len(res["results"]) == 1
-        assert res["results"][0]["extra_attributes"] == {"t1": "b"}
+            q = Query("@tags:{hello\\;world}")
+            res = await decoded_r.ft().search(q)
+            assert 1 == res.total
+            q2 = await decoded_r.ft().tagvals("tags")
+            assert set(tags.split(",") + tags2.split(",")) == set(q2)
+        else:
+            res = await decoded_r.ft().search(q)
+            assert 1 == res["total_results"]
-@pytest.mark.redismod
-@pytest.mark.experimental
-async def test_withsuffixtrie(decoded_r: redis.Redis):
-    # create index
-    assert await decoded_r.ft().create_index((TextField("txt"),))
-    await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx"))
-    if is_resp2_connection(decoded_r):
-        info = await decoded_r.ft().info()
-        assert 
"WITHSUFFIXTRIE" not in info["attributes"][0] - assert await decoded_r.ft().dropindex() + q = Query("@tags:{foo bar}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - # create withsuffixtrie index (text field) - assert await decoded_r.ft().create_index(TextField("t", withsuffixtrie=True)) - await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) - info = await decoded_r.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert await decoded_r.ft().dropindex() + q = Query("@tags:{foo\\ bar}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - # create withsuffixtrie index (tag field) - assert await decoded_r.ft().create_index(TagField("t", withsuffixtrie=True)) - await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) - info = await decoded_r.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - else: - info = await decoded_r.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] - assert await decoded_r.ft().dropindex() + q = Query("@tags:{hello\\;world}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - # create withsuffixtrie index (text fields) - assert await decoded_r.ft().create_index(TextField("t", withsuffixtrie=True)) - await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) - info = await decoded_r.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] - assert await decoded_r.ft().dropindex() + q2 = await decoded_r.ft().tagvals("tags") + assert set(tags.split(",") + tags2.split(",")) == set(q2) - # create withsuffixtrie index (tag field) - assert await decoded_r.ft().create_index(TagField("t", withsuffixtrie=True)) - await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) - info = await decoded_r.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.10.05", "search") -async def test_aggregations_add_scores(decoded_r: redis.Redis): - assert await decoded_r.ft().create_index( - ( - TextField("name", sortable=True, weight=5.0), - NumericField("age", sortable=True), - ) - ) - - assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"}) - assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"}) - - req = aggregations.AggregateRequest("*").add_scores() - res = await decoded_r.ft().aggregate(req) - - if isinstance(res, dict): - assert len(res["results"]) == 2 - assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} - assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} - else: - assert len(res.rows) == 2 - assert res.rows[0] == ["__score", "0.2"] - assert res.rows[1] == ["__score", "0.2"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.10.05", "search") -async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): - assert await decoded_r.ft().create_index( - ( - TextField("name", sortable=True, weight=5.0), - TextField("description", sortable=True, weight=5.0), - VectorField( - "vector", - "HNSW", - {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, - ), - ) - ) - - assert await decoded_r.hset( - "doc1", - mapping={ - "name": "cat book", - "description": "an animal book about cats", - "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), - }, - ) - assert await decoded_r.hset( - "doc2", - mapping={ - "name": "dog book", - "description": "an animal book about dogs", - "vector": np.array([0.2, 
0.1]).astype(np.float32).tobytes(), - }, - ) - - query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" - req = ( - aggregations.AggregateRequest(query_string) - .scorer("BM25") - .add_scores() - .apply(hybrid_score="@__score + @dist") - .load("*") - .dialect(4) - ) - - res = await decoded_r.ft().aggregate( - req, - query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()}, - ) - - if isinstance(res, dict): - assert len(res["results"]) == 2 - else: - assert len(res.rows) == 2 - for row in res.rows: - len(row) == 6 - - -@pytest.mark.redismod -@skip_if_redis_enterprise() -async def test_search_commands_in_pipeline(decoded_r: redis.Redis): - p = await decoded_r.ft().pipeline() - p.create_index((TextField("txt"),)) - p.hset("doc1", mapping={"txt": "foo bar"}) - p.hset("doc2", mapping={"txt": "foo bar"}) - q = Query("foo bar").with_payloads() - await p.search(q) - res = await p.execute() - if is_resp2_connection(decoded_r): - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] - else: - assert res[:3] == ["OK", True, True] - assert 2 == res[3]["total_results"] - assert "doc1" == res[3]["results"][0]["id"] - assert "doc2" == res[3]["results"][1]["id"] - assert res[3]["results"][0]["payload"] is None - assert ( - res[3]["results"][0]["extra_attributes"] - == res[3]["results"][1]["extra_attributes"] - == {"txt": "foo bar"} - ) - - -@pytest.mark.redismod -async def test_query_timeout(decoded_r: redis.Redis): - q1 = Query("foo").timeout(5000) - assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] - q2 = Query("foo").timeout("not_a_number") - with pytest.raises(redis.ResponseError): - await decoded_r.ft().search(q2) - - -@pytest.mark.redismod -@skip_if_resp_version(3) -async def test_binary_and_text_fields(decoded_r: redis.Redis): - fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) - - index_name = "mixed_index" - mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()} - await decoded_r.hset(f"{index_name}:1", mapping=mixed_data) - - schema = [ - TagField("first_name"), - VectorField( - "embeddings_bio", - algorithm="HNSW", - attributes={ - "TYPE": "FLOAT32", - "DIM": 4, - "DISTANCE_METRIC": "COSINE", - }, - ), - ] - - await decoded_r.ft(index_name).create_index( - fields=schema, - definition=IndexDefinition( - prefix=[f"{index_name}:"], index_type=IndexType.HASH - ), - ) - await waitForIndex(decoded_r, index_name) - - query = ( - Query("*") - .return_field("vector_emb", decode_field=False) - .return_field("first_name") - ) - result = await decoded_r.ft(index_name).search(query=query, query_params={}) - docs = result.docs - - if len(docs) == 0: - hash_content = await decoded_r.hget(f"{index_name}:1", "first_name") - assert len(docs) > 0, ( - f"Returned search results are empty. 
Result: {result}; Hash: {hash_content}" - ) - - decoded_vec_from_search_results = np.frombuffer( - docs[0]["vector_emb"], dtype=np.float32 - ) - - assert np.array_equal(decoded_vec_from_search_results, fake_vec), ( - "The vectors are not equal" - ) - - assert docs[0]["first_name"] == mixed_data["first_name"], ( - "The text field is not decoded correctly" - ) - - -# SVS-VAMANA Async Tests -@pytest.mark.redismod -@skip_if_server_version_lt("8.1.224") -async def test_async_svs_vamana_basic_functionality(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, - ), - ) - ) - - vectors = [ - [1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [10.0, 11.0, 12.0, 13.0], - ] - - for i, vec in enumerate(vectors): - await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - - query = "*=>[KNN 3 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True) - res = await decoded_r.ft().search( - q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - ) - - if is_resp2_connection(decoded_r): - assert res.total == 3 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] - - -@pytest.mark.redismod -@skip_if_server_version_lt("8.1.224") -async def test_async_svs_vamana_distance_metrics(decoded_r: redis.Redis): - # Test COSINE distance - await decoded_r.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, - ), + @pytest.mark.redismod + async def test_textfield_sortable_nostem(self, decoded_r: redis.Redis): + # Creating the index definition with sortable and no_stem + await decoded_r.ft().create_index( + (TextField("txt", sortable=True, no_stem=True),) ) - ) - vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + # Now get the index info to confirm its contents + response = await decoded_r.ft().info() + if is_resp2_connection(decoded_r): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] - for i, vec in enumerate(vectors): - await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + @pytest.mark.redismod + async def test_alter_schema_add(self, decoded_r: redis.Redis): + # Creating the index definition and schema + await decoded_r.ft().create_index(TextField("title")) - query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + # Using alter to add a field + await decoded_r.ft().alter_schema_add(TextField("body")) - res = await decoded_r.ft().search(query, query_params=query_params) - if is_resp2_connection(decoded_r): - assert res.total == 2 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 2 - assert "doc0" == res["results"][0]["id"] + # Indexing a document + await decoded_r.hset( + "doc1", + mapping={"title": "MyTitle", "body": "Some content only in the body"}, + ) + # Searching with parameter only in the body (the added field) + q = Query("only in the body") -@pytest.mark.redismod -@skip_if_server_version_lt("8.1.224") -async def test_async_svs_vamana_vector_types(decoded_r: redis.Redis): - # Test FLOAT16 - await decoded_r.ft("idx16").create_index( - ( - 
VectorField( - "v16", - "SVS-VAMANA", - {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, - ), - ) - ) + # Ensure we find the result searching on the added body field + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 1 == res.total + else: + assert 1 == res["total_results"] - vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + @pytest.mark.redismod + async def test_spell_check(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - for i, vec in enumerate(vectors): await decoded_r.hset( - f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes() + "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) + await decoded_r.hset( + "doc2", mapping={"f1": "very important", "f2": "lorem ipsum"} + ) + await self.waitForIndex(decoded_r, "idx") - query = Query("*=>[KNN 2 @v16 $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} - - res = await decoded_r.ft("idx16").search(query, query_params=query_params) - if is_resp2_connection(decoded_r): - assert res.total == 2 - assert "doc16_0" == res.docs[0].id - else: - assert res["total_results"] == 2 - assert "doc16_0" == res["results"][0]["id"] + if is_resp2_connection(decoded_r): + # test spellcheck + res = await decoded_r.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = await decoded_r.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = await decoded_r.ft().spellcheck("vlis") + assert res == {} + res = await decoded_r.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await decoded_r.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = await decoded_r.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = await decoded_r.ft().spellcheck("impornant") + assert "important" in res["results"]["impornant"][0].keys() + + res = await decoded_r.ft().spellcheck("contnt") + assert "content" in res["results"]["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = await decoded_r.ft().spellcheck("vlis") + assert res == {"results": {"vlis": []}} + res = await decoded_r.ft().spellcheck("vlis", distance=2) + assert "valid" in res["results"]["vlis"][0].keys() + + # test spellcheck include + await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await decoded_r.ft().spellcheck("lorm", include="dict") + assert len(res["results"]["lorm"]) == 3 + assert "lorem" in res["results"]["lorm"][0].keys() + assert "lore" in res["results"]["lorm"][1].keys() + assert "lorm" in res["results"]["lorm"][2].keys() + assert ( + res["results"]["lorm"][0]["lorem"], + res["results"]["lorm"][1]["lore"], + ) == (0.5, 0) + + # test spellcheck exclude + res = await decoded_r.ft().spellcheck("lorm", exclude="dict") + assert res == {"results": {}} + + @pytest.mark.redismod + async def test_dict_operations(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), 
TextField("f2"))) + # Add three items + res = await decoded_r.ft().dict_add("custom_dict", "item1", "item2", "item3") + assert 3 == res + + # Remove one item + res = await decoded_r.ft().dict_del("custom_dict", "item2") + assert 1 == res + + # Dump dict and inspect content + res = await decoded_r.ft().dict_dump("custom_dict") + assert res == ["item1", "item3"] + + # Remove rest of the items before reload + await decoded_r.ft().dict_del("custom_dict", *res) + + @pytest.mark.redismod + async def test_phonetic_matcher(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("name"),)) + await decoded_r.hset("doc1", mapping={"name": "Jon"}) + await decoded_r.hset("doc2", mapping={"name": "John"}) + + res = await decoded_r.ft().search(Query("Jon")) + if is_resp2_connection(decoded_r): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] + # Drop and create index with phonetic matcher + await decoded_r.flushdb() -@pytest.mark.redismod -@skip_if_server_version_lt("8.1.224") -async def test_async_svs_vamana_compression(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 8, - "DISTANCE_METRIC": "L2", - "COMPRESSION": "LVQ8", - "TRAINING_THRESHOLD": 1024, - }, - ), + await decoded_r.ft().create_index( + (TextField("name", phonetic_matcher="dm:en"),) ) - ) + await decoded_r.hset("doc1", mapping={"name": "Jon"}) + await decoded_r.hset("doc2", mapping={"name": "John"}) - vectors = [] - for i in range(20): - vec = [float(i + j) for j in range(8)] - vectors.append(vec) - await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + res = await decoded_r.ft().search(Query("Jon")) + if is_resp2_connection(decoded_r): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted( + d["extra_attributes"]["name"] for d in res["results"] + ) - query = Query("*=>[KNN 5 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + @pytest.mark.redismod + async def test_get(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - res = await decoded_r.ft().search(query, query_params=query_params) - if is_resp2_connection(decoded_r): - assert res.total == 5 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 5 - assert "doc0" == res["results"][0]["id"] + assert [None] == await decoded_r.ft().get("doc1") + assert [None, None] == await decoded_r.ft().get("doc2", "doc1") + await decoded_r.hset( + "doc1", + mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"}, + ) + await decoded_r.hset( + "doc2", + mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"}, + ) -@pytest.mark.redismod -@skip_if_server_version_lt("8.1.224") -async def test_async_svs_vamana_build_parameters(decoded_r: redis.Redis): - await decoded_r.ft().create_index( - ( + assert [ + ["f1", "some valid content dd2", "f2", "this is sample text f2"] + ] == await decoded_r.ft().get("doc2") + assert [ + ["f1", "some valid content dd1", "f2", "this is sample text f1"], + ["f1", "some valid content dd2", "f2", "this is sample text f2"], + ] == await decoded_r.ft().get("doc1", "doc2") + + @pytest.mark.redismod + async def test_query_timeout(self, decoded_r: 
redis.Redis): + q1 = Query("foo").timeout(5000) + assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] + q2 = Query("foo").timeout("not_a_number") + with pytest.raises(redis.ResponseError): + await decoded_r.ft().search(q2) + + @pytest.mark.redismod + @skip_if_resp_version(3) + async def test_binary_and_text_fields(self, decoded_r: redis.Redis): + fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + + index_name = "mixed_index" + mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()} + await decoded_r.hset(f"{index_name}:1", mapping=mixed_data) + + schema = [ + TagField("first_name"), VectorField( - "v", - "SVS-VAMANA", - { + "embeddings_bio", + algorithm="HNSW", + attributes={ "TYPE": "FLOAT32", - "DIM": 6, + "DIM": 4, "DISTANCE_METRIC": "COSINE", - "CONSTRUCTION_WINDOW_SIZE": 300, - "GRAPH_MAX_DEGREE": 64, - "SEARCH_WINDOW_SIZE": 20, - "EPSILON": 0.05, }, ), + ] + + await decoded_r.ft(index_name).create_index( + fields=schema, + definition=IndexDefinition( + prefix=[f"{index_name}:"], index_type=IndexType.HASH + ), ) - ) + await self.waitForIndex(decoded_r, index_name) - vectors = [] - for i in range(15): - vec = [float(i + j) for j in range(6)] - vectors.append(vec) - await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + query = ( + Query("*") + .return_field("vector_emb", decode_field=False) + .return_field("first_name") + ) + result = await decoded_r.ft(index_name).search(query=query, query_params={}) + docs = result.docs - query = Query("*=>[KNN 3 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + if len(docs) == 0: + hash_content = await decoded_r.hget(f"{index_name}:1", "first_name") + assert len(docs) > 0, ( + f"Returned search results are empty. Result: {result}; Hash: {hash_content}" + ) - res = await decoded_r.ft().search(query, query_params=query_params) - if is_resp2_connection(decoded_r): - assert res.total == 3 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] + decoded_vec_from_search_results = np.frombuffer( + docs[0]["vector_emb"], dtype=np.float32 + ) + + assert np.array_equal(decoded_vec_from_search_results, fake_vec), ( + "The vectors are not equal" + ) + + assert docs[0]["first_name"] == mixed_data["first_name"], ( + "The text field is not decoded correctly" + ) + + +class TestScorers(AsyncSearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + # NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ + @skip_ifmodversion_lt("2.8.0", "search") + @skip_if_server_version_gte("7.9.0") + async def test_scorer(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("description"),)) + + await decoded_r.hset( + "doc1", + mapping={"description": "The quick brown fox jumps over the lazy dog"}, + ) + await decoded_r.hset( + "doc2", + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa + }, + ) + + if is_resp2_connection(decoded_r): + # default scorer is TFIDF + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("BM25").with_scores() + ) + assert 0.22471909420069797 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DISMAX").with_scores() + ) + assert 2.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res.docs[0].score + else: + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("BM25").with_scores() + ) + assert 0.22471909420069797 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DISMAX").with_scores() + ) + assert 2.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res["results"][0]["score"] + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + # NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ + @skip_ifmodversion_lt("2.8.0", "search") + @skip_if_server_version_lt("7.9.0") + async def test_scorer_with_new_default_scorer(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("description"),)) + + await decoded_r.hset( + "doc1", + mapping={"description": "The quick brown fox jumps over the lazy dog"}, + ) + await decoded_r.hset( + "doc2", + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa + }, + ) + + if is_resp2_connection(decoded_r): + # default scorer is BM25STD + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res.docs[0].score, 0.05) + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("BM25").with_scores() + ) + assert 0.22471909420069797 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DISMAX").with_scores() + ) + assert 2.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res.docs[0].score + else: + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05) + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("BM25").with_scores() + ) + assert 0.22471909420069797 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DISMAX").with_scores() + ) + assert 2.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res["results"][0]["score"] + + +class TestConfig(AsyncSearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_ifmodversion_lt("2.2.0", "search") + @skip_if_server_version_gte("7.9.0") + async def test_config(self, decoded_r: redis.Redis): + assert await decoded_r.ft().config_set("TIMEOUT", "100") + with pytest.raises(redis.ResponseError): + await decoded_r.ft().config_set("TIMEOUT", "null") + res = await decoded_r.ft().config_get("*") + assert "100" == res["TIMEOUT"] + res = await decoded_r.ft().config_get("TIMEOUT") + assert "100" == res["TIMEOUT"] + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.9.0") + async def test_config_with_removed_ftconfig(self, decoded_r: redis.Redis): + assert await decoded_r.config_set("timeout", "100") + with pytest.raises(redis.ResponseError): + await decoded_r.config_set("timeout", "null") + res = await decoded_r.config_get("*") + assert "100" == res["timeout"] + res = await decoded_r.config_get("timeout") + assert "100" == res["timeout"] + + +class TestAggregations(AsyncSearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + async def test_aggregations_groupby(self, decoded_r: redis.Redis): + # Creating the index definition and schema + await decoded_r.ft().create_index( + ( + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ) + ) + + # Indexing a document + await decoded_r.hset( + "search", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, + ) 
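+        # two more documents share parent="redis", so the group_by reducers
+        # below always aggregate over exactly three records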
+ await decoded_r.hset( + "ai", + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + await decoded_r.hset( + "json", + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, + ) + + for dialect in [1, 2]: + if is_resp2_connection(decoded_r): + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "7" # (10+3+8)/3 + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] + + req = ( + aggregations.AggregateRequest("redis") + .group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + 
assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount"] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distincttitle"] + == "3" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distinctishtitle"] + == "3" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasstddevrandom_num"] + == "3.60555127546" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] + == "8" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) + + res = (await 
decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"] == { + "parent": "redis", + "first": "RediSearch", + } + + req = ( + aggregations.AggregateRequest("redis") + .group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert "random" in res["extra_attributes"].keys() + assert len(res["extra_attributes"]["random"]) == 2 + assert res["extra_attributes"]["random"][0] in [ + "RediSearch", + "RedisAI", + "RedisJson", + ] + + @pytest.mark.redismod + async def test_aggregations_sort_by_and_limit(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("t1"), TextField("t2"))) + + await decoded_r.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) + await decoded_r.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) + + if is_resp2_connection(decoded_r): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = await decoded_r.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = await decoded_r.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = (await decoded_r.ft().aggregate(req))["results"] + assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} + assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = (await decoded_r.ft().aggregate(req))["results"] + assert res[0]["extra_attributes"] == {"t1": "a"} + assert res[1]["extra_attributes"] == {"t1": "b"} + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await decoded_r.ft().aggregate(req) + assert len(res["results"]) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await decoded_r.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"t1": "b"} + + @pytest.mark.redismod + @pytest.mark.experimental + async def test_withsuffixtrie(self, decoded_r: redis.Redis): + # create index + assert await decoded_r.ft().create_index((TextField("txt"),)) + await self.waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + if is_resp2_connection(decoded_r): + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert await decoded_r.ft().dropindex() + + # create withsuffixtrie index (text field) + assert await decoded_r.ft().create_index( + TextField("t", withsuffixtrie=True) + ) + await self.waitForIndex( + decoded_r, getattr(decoded_r.ft(), "index_name", "idx") + ) + info = await decoded_r.ft().info() + assert 
"WITHSUFFIXTRIE" in info["attributes"][0] + assert await decoded_r.ft().dropindex() + + # create withsuffixtrie index (tag field) + assert await decoded_r.ft().create_index(TagField("t", withsuffixtrie=True)) + await self.waitForIndex( + decoded_r, getattr(decoded_r.ft(), "index_name", "idx") + ) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert await decoded_r.ft().dropindex() + + # create withsuffixtrie index (text fields) + assert await decoded_r.ft().create_index( + TextField("t", withsuffixtrie=True) + ) + await self.waitForIndex( + decoded_r, getattr(decoded_r.ft(), "index_name", "idx") + ) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert await decoded_r.ft().dropindex() + + # create withsuffixtrie index (tag field) + assert await decoded_r.ft().create_index(TagField("t", withsuffixtrie=True)) + await self.waitForIndex( + decoded_r, getattr(decoded_r.ft(), "index_name", "idx") + ) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.10.05", "search") + async def test_aggregations_add_scores(self, decoded_r: redis.Redis): + assert await decoded_r.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + NumericField("age", sortable=True), + ) + ) + + assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"}) + assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"}) + + req = aggregations.AggregateRequest("*").add_scores() + res = await decoded_r.ft().aggregate(req) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} + assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} + else: + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.10.05", "search") + async def test_aggregations_hybrid_scoring(self, decoded_r: redis.Redis): + assert await decoded_r.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + TextField("description", sortable=True, weight=5.0), + VectorField( + "vector", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + assert await decoded_r.hset( + "doc1", + mapping={ + "name": "cat book", + "description": "an animal book about cats", + "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), + }, + ) + assert await decoded_r.hset( + "doc2", + mapping={ + "name": "dog book", + "description": "an animal book about dogs", + "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), + }, + ) + + query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" + req = ( + aggregations.AggregateRequest(query_string) + .scorer("BM25") + .add_scores() + .apply(hybrid_score="@__score + @dist") + .load("*") + .dialect(4) + ) + + res = await decoded_r.ft().aggregate( + req, + query_params={ + "vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes() + }, + ) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + else: + assert len(res.rows) == 2 + for row in res.rows: + len(row) == 6 + + +class TestPipeline(AsyncSearchTestsBase): + @pytest.mark.redismod + @skip_if_redis_enterprise() + async def test_search_commands_in_pipeline(self, 
decoded_r: redis.Redis): + p = await decoded_r.ft().pipeline() + p.create_index((TextField("txt"),)) + p.hset("doc1", mapping={"txt": "foo bar"}) + p.hset("doc2", mapping={"txt": "foo bar"}) + q = Query("foo bar").with_payloads() + await p.search(q) + res = await p.execute() + if is_resp2_connection(decoded_r): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["extra_attributes"] + == res[3]["results"][1]["extra_attributes"] + == {"txt": "foo bar"} + ) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_pipeline(self, decoded_r: redis.Redis): + p = decoded_r.ft().pipeline() + p.create_index( + ( + TextField("txt"), + VectorField( + "embedding", + "FLAT", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + p.hset( + "doc1", + mapping={ + "txt": "foo bar", + "embedding": np.array([1, 2, 3, 4], dtype=np.float32).tobytes(), + }, + ) + p.hset( + "doc2", + mapping={ + "txt": "foo bar", + "embedding": np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + }, + ) + + # set search query + search_query = HybridSearchQuery("foo") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([2, 2, 3, 3], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + await p.hybrid_search(query=hybrid_query) + res = await p.execute() + + # the default results count limit is 10 + assert res[:3] == ["OK", 2, 2] + hybrid_search_res = res[3] + if is_resp2_connection(decoded_r): + # it doesn't get parsed to object in pipeline + assert hybrid_search_res[0] == "total_results" + assert hybrid_search_res[1] == 2 + assert hybrid_search_res[2] == "results" + assert len(hybrid_search_res[3]) == 2 + assert hybrid_search_res[4] == "warnings" + assert hybrid_search_res[5] == [] + assert hybrid_search_res[6] == "execution_time" + assert float(hybrid_search_res[7]) > 0 + else: + assert hybrid_search_res["total_results"] == 2 + assert len(hybrid_search_res["results"]) == 2 + assert hybrid_search_res["warnings"] == [] + assert hybrid_search_res["execution_time"] > 0 + + +class TestSearchWithVamana(AsyncSearchTestsBase): + # SVS-VAMANA Async Tests + @pytest.mark.redismod + @skip_if_server_version_lt("8.1.224") + async def test_async_svs_vamana_basic_functionality(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + await decoded_r.hset( + f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes() + ) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = await decoded_r.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + 
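+    # The remaining SVS-VAMANA tests follow the same pattern: index a small
+    # set of vectors, query with the first vector as the KNN target, and
+    # expect the first document (distance zero to itself) to rank first.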
+ @pytest.mark.redismod + @skip_if_server_version_lt("8.1.224") + async def test_async_svs_vamana_distance_metrics(self, decoded_r: redis.Redis): + # Test COSINE distance + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [ + [1.0, 0.0, 0.0], + [0.707, 0.707, 0.0], + [0.0, 1.0, 0.0], + [-1.0, 0.0, 0.0], + ] + + for i, vec in enumerate(vectors): + await decoded_r.hset( + f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes() + ) + + query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.1.224") + async def test_async_svs_vamana_vector_types(self, decoded_r: redis.Redis): + # Test FLOAT16 + await decoded_r.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + await decoded_r.hset( + f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes() + ) + + query = Query("*=>[KNN 2 @v16 $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = await decoded_r.ft("idx16").search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc16_0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc16_0" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.1.224") + async def test_async_svs_vamana_compression(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + await decoded_r.hset( + f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes() + ) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.1.224") + async def test_async_svs_vamana_build_parameters(self, decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 20, + "EPSILON": 0.05, + }, + ), + ) + ) + + vectors = [] + for i in range(15): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + await decoded_r.hset( + f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes() + ) + + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], 
dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +class TestHybridSearch(AsyncSearchTestsBase): + async def _create_hybrid_search_index(self, decoded_r: redis.Redis, dim=4): + await decoded_r.ft().create_index( + ( + TextField("description"), + NumericField("price"), + TagField("color"), + TagField("item_type"), + NumericField("size"), + VectorField( + "embedding", + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": dim, + "DISTANCE_METRIC": "L2", + }, + ), + VectorField( + "embedding-hnsw", + "HNSW", + { + "TYPE": "FLOAT32", + "DIM": dim, + "DISTANCE_METRIC": "L2", + }, + ), + ), + definition=IndexDefinition(prefix=["item:"]), + ) + await AsyncSearchTestsBase.waitForIndex(decoded_r, "idx") + + @staticmethod + def _generate_random_vector(dim): + return [random.random() for _ in range(dim)] + + @staticmethod + def _generate_random_str_data(dim): + chars = "abcdefgh12345678" + return "".join(random.choice(chars) for _ in range(dim)) + + @staticmethod + async def _add_data_for_hybrid_search( + client: redis.Redis, + items_sets=1, + randomize_data=False, + dim_for_random_data=4, + use_random_str_data=False, + ): + if randomize_data or use_random_str_data: + generate_data_func = ( + TestHybridSearch._generate_random_str_data + if use_random_str_data + else TestHybridSearch._generate_random_vector + ) + + dim_for_random_data = ( + dim_for_random_data * 4 if use_random_str_data else dim_for_random_data + ) + + items = [ + (generate_data_func(dim_for_random_data), "red shoes"), + (generate_data_func(dim_for_random_data), "green shoes with red laces"), + (generate_data_func(dim_for_random_data), "red dress"), + (generate_data_func(dim_for_random_data), "orange dress"), + (generate_data_func(dim_for_random_data), "black shoes"), + ] + else: + items = [ + ([1.0, 2.0, 7.0, 8.0], "red shoes"), + ([1.0, 4.0, 7.0, 8.0], "green shoes with red laces"), + ([1.0, 2.0, 6.0, 5.0], "red dress"), + ([2.0, 3.0, 6.0, 5.0], "orange dress"), + ([5.0, 6.0, 7.0, 8.0], "black shoes"), + ] + items = items * items_sets + pipeline = client.pipeline() + for i, vec in enumerate(items): + vec, description = vec + mapping = { + "description": description, + "embedding": np.array(vec, dtype=np.float32).tobytes() + if not use_random_str_data + else vec, + "embedding-hnsw": np.array(vec, dtype=np.float32).tobytes() + if not use_random_str_data + else vec, + "price": 15 + i % 4, + "color": description.split(" ")[0], + "item_type": description.split(" ")[1], + "size": 10 + i % 3, + } + pipeline.hset(f"item:{i}", mapping=mapping) + await pipeline.execute() # Execute all at once + + @staticmethod + def _convert_dict_values_to_str(list_of_dicts): + res = [] + for d in list_of_dicts: + res_dict = {} + for k, v in d.items(): + if isinstance(v, list): + res_dict[k] = [safe_str(x) for x in v] + else: + res_dict[k] = safe_str(v) + res.append(res_dict) + return res + + @staticmethod + def compare_list_of_dicts(actual, expected): + assert len(actual) == len(expected), ( + f"List of dicts length mismatch: {len(actual)} != {len(expected)}. 
" + f"Full dicts: actual:{actual}; expected:{expected}" + ) + for expected_dict_item in expected: + found = False + for actual_dict_item in actual: + if actual_dict_item == expected_dict_item: + found = True + break + if not found: + assert False, ( + f"Dict {expected_dict_item} not found in actual list of dicts: {actual}. " + f"All expected:{expected}" + ) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_basic_hybrid_search(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=5) + + # set search query + search_query = HybridSearchQuery("@color:{red} @color:{green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([-100, -200, -200, -300], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + res = await decoded_r.ft().hybrid_search(query=hybrid_query) + + # the default results count limit is 10 + if is_resp2_connection(decoded_r): + assert res.total_results == 10 + assert len(res.results) == 10 + assert res.warnings == [] + assert res.execution_time > 0 + assert all(isinstance(res.results[i]["__score"], bytes) for i in range(10)) + assert all(isinstance(res.results[i]["__key"], bytes) for i in range(10)) + else: + assert res["total_results"] == 10 + assert len(res["results"]) == 10 + assert res["warnings"] == [] + assert res["execution_time"] > 0 + assert all(isinstance(res["results"][i]["__score"], str) for i in range(10)) + assert all(isinstance(res["results"][i]["__key"], str) for i in range(10)) + + @pytest.mark.redismod + # @pytest.mark.timeout(900) + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_scorer(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + search_query = HybridSearchQuery("shoes") + search_query.scorer("TFIDF") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("LINEAR", ALPHA=1, BETA=0) + posprocessing_config.load( + "@description", "@color", "@price", "@size", "@__score", "@__item" + ) + posprocessing_config.limit(0, 2) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + expected_results_tfidf = [ + { + "description": b"red shoes", + "color": b"red", + "price": b"15", + "size": b"10", + "__score": b"2", + }, + { + "description": b"green shoes with red laces", + "color": b"green", + "price": b"16", + "size": b"11", + "__score": b"2", + }, + ] + + if is_resp2_connection(decoded_r): + assert res.total_results >= 2 + assert len(res.results) == 2 + assert res.results == expected_results_tfidf + assert res.warnings == [] + else: + assert res["total_results"] >= 2 + assert len(res["results"]) == 2 + assert res["results"] == self._convert_dict_values_to_str( + expected_results_tfidf + ) + assert res["warnings"] == [] + + search_query.scorer("BM25") + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + expected_results_bm25 = [ + { + "description": b"red shoes", + "color": b"red", + "price": b"15", + "size": b"10", + 
"__score": b"0.657894719299", + }, + { + "description": b"green shoes with red laces", + "color": b"green", + "price": b"16", + "size": b"11", + "__score": b"0.657894719299", + }, + ] + if is_resp2_connection(decoded_r): + assert res.total_results >= 2 + assert len(res.results) == 2 + assert res.results == expected_results_bm25 + assert res.warnings == [] + else: + assert res["total_results"] >= 2 + assert len(res["results"]) == 2 + assert res["results"] == self._convert_dict_values_to_str( + expected_results_bm25 + ) + assert res["warnings"] == [] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_vsim_filter(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search( + decoded_r, items_sets=5, use_random_str_data=True + ) + + search_query = HybridSearchQuery("@color:{missing}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data="abcd1234efgh5678", + ) + vsim_query.filter(HybridFilter("@price:[15 16] @size:[10 11]")) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@price", "@size") + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + if is_resp2_connection(decoded_r): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + assert item["price"] in [b"15", b"16"] + assert item["size"] in [b"10", b"11"] + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in res["results"]: + assert item["price"] in ["15", "16"] + assert item["size"] in ["10", "11"] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_vsim_knn(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + # this query won't have results, so we will be able to validate vsim results + search_query = HybridSearchQuery("@color:{none}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + ) + + vsim_query.vsim_method_params("KNN", K=3) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + expected_results = [ + {"__key": b"item:2", "__score": b"0.016393442623"}, + {"__key": b"item:7", "__score": b"0.0161290322581"}, + {"__key": b"item:12", "__score": b"0.015873015873"}, + ] + if is_resp2_connection(decoded_r): + assert res.total_results == 3 # KNN top-k value + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] == 3 # KNN top-k value + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + vsim_query_with_hnsw = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + ) + vsim_query_with_hnsw.vsim_method_params("KNN", K=3, EF_RUNTIME=1) + hybrid_query_with_hnsw = HybridQuery(search_query, 
vsim_query_with_hnsw) + + res2 = await decoded_r.ft().hybrid_search( + query=hybrid_query_with_hnsw, timeout=10 + ) + + expected_results2 = [ + {"__key": b"item:12", "__score": b"0.016393442623"}, + {"__key": b"item:22", "__score": b"0.0161290322581"}, + {"__key": b"item:27", "__score": b"0.015873015873"}, + ] + if is_resp2_connection(decoded_r): + assert res2.total_results == 3 # KNN top-k value + assert len(res2.results) == 3 + assert res2.results == expected_results2 + assert res2.warnings == [] + assert res2.execution_time > 0 + else: + assert res2["total_results"] == 3 # KNN top-k value + assert len(res2["results"]) == 3 + assert res2["results"] == self._convert_dict_values_to_str( + expected_results2 + ) + assert res2["warnings"] == [] + assert res2["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_vsim_range(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + # this query won't have results, so we will be able to validate vsim results + search_query = HybridSearchQuery("@color:{none}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + vsim_query.vsim_method_params("RANGE", RADIUS=2) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.limit(0, 3) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + expected_results = [ + {"__key": b"item:2", "__score": b"0.016393442623"}, + {"__key": b"item:7", "__score": b"0.0161290322581"}, + {"__key": b"item:12", "__score": b"0.015873015873"}, + ] + if is_resp2_connection(decoded_r): + assert res.total_results >= 3 # at least 3 results + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + vsim_query_with_hnsw = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + vsim_query_with_hnsw.vsim_method_params("RANGE", RADIUS=2, EPSILON=0.5) + + hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query_with_hnsw, + post_processing=posprocessing_config, + timeout=10, + ) + + expected_results_hnsw = [ + {"__key": b"item:27", "__score": b"0.016393442623"}, + {"__key": b"item:12", "__score": b"0.0161290322581"}, + {"__key": b"item:22", "__score": b"0.015873015873"}, + ] + + if is_resp2_connection(decoded_r): + assert res.total_results >= 3 + assert len(res.results) == 3 + assert res.results == expected_results_hnsw + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str( + expected_results_hnsw + ) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_combine(self, decoded_r): + # Create 
index and add data
+        await self._create_hybrid_search_index(decoded_r)
+        await self._add_data_for_hybrid_search(decoded_r, items_sets=10)
+
+        # set search query
+        search_query = HybridSearchQuery("@color:{red}")
+
+        vsim_query = HybridVsimQuery(
+            vector_field_name="@embedding",
+            vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
+        )
+
+        hybrid_query = HybridQuery(search_query, vsim_query)
+
+        posprocessing_config = HybridPostProcessingConfig()
+        posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5)
+        posprocessing_config.limit(0, 3)
+
+        res = await decoded_r.ft().hybrid_search(
+            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+        )
+
+        expected_results = [
+            {"__key": b"item:2", "__score": b"0.166666666667"},
+            {"__key": b"item:7", "__score": b"0.166666666667"},
+            {"__key": b"item:12", "__score": b"0.166666666667"},
+        ]
+        if is_resp2_connection(decoded_r):
+            assert res.total_results >= 3
+            assert len(res.results) == 3
+            assert res.results == expected_results
+            assert res.warnings == []
+            assert res.execution_time > 0
+        else:
+            assert res["total_results"] >= 3
+            assert len(res["results"]) == 3
+            assert res["results"] == self._convert_dict_values_to_str(expected_results)
+            assert res["warnings"] == []
+            assert res["execution_time"] > 0
+
+        # combine with RRF and WINDOW + CONSTANT
+        posprocessing_config.combine("RRF", WINDOW=3, CONSTANT=0.5)
+        res = await decoded_r.ft().hybrid_search(
+            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+        )
+
+        expected_results = [
+            {"__key": b"item:2", "__score": b"1.06666666667"},
+            {"__key": b"item:0", "__score": b"0.666666666667"},
+            {"__key": b"item:7", "__score": b"0.4"},
+        ]
+        if is_resp2_connection(decoded_r):
+            assert res.total_results >= 3
+            assert len(res.results) == 3
+            assert res.results == expected_results
+            assert res.warnings == []
+            assert res.execution_time > 0
+        else:
+            assert res["total_results"] >= 3
+            assert len(res["results"]) == 3
+            assert res["results"] == self._convert_dict_values_to_str(expected_results)
+            assert res["warnings"] == []
+            assert res["execution_time"] > 0
+
+        # combine with RRF, not all possible params provided
+        posprocessing_config.combine("RRF", WINDOW=3)
+        res = await decoded_r.ft().hybrid_search(
+            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+        )
+
+        expected_results = [
+            {"__key": b"item:2", "__score": b"0.032522474881"},
+            {"__key": b"item:0", "__score": b"0.016393442623"},
+            {"__key": b"item:7", "__score": b"0.0161290322581"},
+        ]
+        if is_resp2_connection(decoded_r):
+            assert res.total_results >= 3
+            assert len(res.results) == 3
+            assert res.results == expected_results
+            assert res.warnings
== [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_load(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red|green|black}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5) + posprocessing_config.load( + "@description", "@color", "@price", "@size", "@__key AS item_key" + ) + posprocessing_config.limit(0, 1) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + expected_results = [ + { + "description": b"red dress", + "color": b"red", + "price": b"17", + "size": b"12", + "item_key": b"item:2", + } + ] + if is_resp2_connection(decoded_r): + assert res.total_results >= 1 + assert len(res.results) == 1 + self.compare_list_of_dicts(res.results, expected_results) + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 1 + assert len(res["results"]) == 1 + self.compare_list_of_dicts( + res["results"], self._convert_dict_values_to_str(expected_results) + ) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_load_and_apply(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price", "@size") + posprocessing_config.apply( + price_discount="@price - (@price * 0.1)", + tax_discount="@price_discount * 0.2", + ) + posprocessing_config.limit(0, 3) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + expected_results = [ + { + "color": b"red", + "price": b"15", + "size": b"10", + "price_discount": b"13.5", + "tax_discount": b"2.7", + }, + { + "color": b"red", + "price": b"17", + "size": b"12", + "price_discount": b"15.3", + "tax_discount": b"3.06", + }, + { + "color": b"red", + "price": b"18", + "size": b"11", + "price_discount": b"16.2", + "tax_discount": b"3.24", + }, + ] + if is_resp2_connection(decoded_r): + assert len(res.results) == 3 + self.compare_list_of_dicts(res.results, expected_results) + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert len(res["results"]) == 3 + self.compare_list_of_dicts( + res["results"], self._convert_dict_values_to_str(expected_results) + ) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + 
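+    # Sanity check on the expected values above: APPLY steps chain, so for
+    # price 15 the discount is 15 - 15 * 0.1 = 13.5 and the tax on it is
+    # 13.5 * 0.2 = 2.7; prices 17 and 18 yield 15.3/3.06 and 16.2/3.24.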
@pytest.mark.redismod
+    @skip_if_server_version_lt("8.3.224")
+    async def test_hybrid_search_query_with_load_and_filter(self, decoded_r):
+        # Create index and add data
+        await self._create_hybrid_search_index(decoded_r)
+        await self._add_data_for_hybrid_search(decoded_r, items_sets=10)
+
+        # set search query
+        search_query = HybridSearchQuery("@color:{red|green|black}")
+
+        vsim_query = HybridVsimQuery(
+            vector_field_name="@embedding",
+            vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
+        )
+
+        hybrid_query = HybridQuery(search_query, vsim_query)
+
+        posprocessing_config = HybridPostProcessingConfig()
+        posprocessing_config.load("@description", "@color", "@price", "@size")
+        # for the postprocessing filter we need to filter on the loaded fields,
+        # expecting all of them to be interpreted as strings - the initial field
+        # types are not preserved
+        posprocessing_config.filter(HybridFilter('@price=="15"'))
+        posprocessing_config.limit(0, 3)
+
+        res = await decoded_r.ft().hybrid_search(
+            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+        )
+
+        if is_resp2_connection(decoded_r):
+            assert len(res.results) == 3
+            for item in res.results:
+                assert item["price"] == b"15"
+            assert res.warnings == []
+            assert res.execution_time > 0
+        else:
+            assert len(res["results"]) == 3
+            for item in res["results"]:
+                assert item["price"] == "15"
+            assert res["warnings"] == []
+            assert res["execution_time"] > 0
+
+    @pytest.mark.redismod
+    @skip_if_server_version_lt("8.3.224")
+    async def test_hybrid_search_query_with_load_apply_and_params(self, decoded_r):
+        # Create index and add data
+        await self._create_hybrid_search_index(decoded_r)
+        await self._add_data_for_hybrid_search(
+            decoded_r, items_sets=5, use_random_str_data=True
+        )
+
+        # set search query
+        search_query = HybridSearchQuery("@color:{$color_criteria}")
+
+        vsim_query = HybridVsimQuery(
+            vector_field_name="@embedding",
+            vector_data="$vector",
+        )
+
+        hybrid_query = HybridQuery(search_query, vsim_query)
+
+        posprocessing_config = HybridPostProcessingConfig()
+        posprocessing_config.load("@description", "@color", "@price")
+        posprocessing_config.apply(price_discount="@price - (@price * 0.1)")
+        posprocessing_config.limit(0, 3)
+
+        params_substitution = {
+            "vector": "abcd1234abcd5678",
+            "color_criteria": "red",
+        }
+
+        res = await decoded_r.ft().hybrid_search(
+            query=hybrid_query,
+            post_processing=posprocessing_config,
+            params_substitution=params_substitution,
+            timeout=10,
+        )
+
+        expected_results = [
+            {
+                "description": b"red shoes",
+                "color": b"red",
+                "price": b"15",
+                "price_discount": b"13.5",
+            },
+            {
+                "description": b"red dress",
+                "color": b"red",
+                "price": b"17",
+                "price_discount": b"15.3",
+            },
+            {
+                "description": b"red shoes",
+                "color": b"red",
+                "price": b"16",
+                "price_discount": b"14.4",
+            },
+        ]
+        if is_resp2_connection(decoded_r):
+            assert len(res.results) == 3
+            assert res.results == expected_results
+            assert res.warnings == []
+            assert res.execution_time > 0
+        else:
+            assert len(res["results"]) == 3
+            assert res["results"] == self._convert_dict_values_to_str(expected_results)
+            assert res["warnings"] == []
+            assert res["execution_time"] > 0
+
+    @pytest.mark.redismod
+    @skip_if_server_version_lt("8.3.224")
+    async def test_hybrid_search_query_with_limit(self, decoded_r):
+        # Create index and add data
+        await self._create_hybrid_search_index(decoded_r)
+        await self._add_data_for_hybrid_search(decoded_r, items_sets=10)
+
+        # set search query
+        search_query = 
HybridSearchQuery("@color:{red}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.limit(0, 3) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + if is_resp2_connection(decoded_r): + assert len(res.results) == 3 + assert res.warnings == [] + else: + assert len(res["results"]) == 3 + assert res["warnings"] == [] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_load_apply_and_sortby(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=1) + + # set search query + search_query = HybridSearchQuery("@color:{red|green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price") + posprocessing_config.apply(price_discount="@price - (@price * 0.1)") + posprocessing_config.sort_by( + SortbyField("@price_discount", asc=False), SortbyField("@color", asc=True) + ) + posprocessing_config.limit(0, 5) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + expected_results = [ + {"color": b"orange", "price": b"18", "price_discount": b"16.2"}, + {"color": b"red", "price": b"17", "price_discount": b"15.3"}, + {"color": b"green", "price": b"16", "price_discount": b"14.4"}, + {"color": b"black", "price": b"15", "price_discount": b"13.5"}, + {"color": b"red", "price": b"15", "price_discount": b"13.5"}, + ] + if is_resp2_connection(decoded_r): + assert res.total_results >= 5 + assert len(res.results) == 5 + # the order here should match because of the sort + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 5 + assert len(res["results"]) == 5 + # the order here should match because of the sort + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_timeout(self, decoded_r): + dim = 128 + # Create index and add data + await self._create_hybrid_search_index(decoded_r, dim=dim) + await self._add_data_for_hybrid_search( + decoded_r, + items_sets=5000, + dim_for_random_data=dim, + use_random_str_data=True, + ) + + # set search query + search_query = HybridSearchQuery("*") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data="abcd" * dim, + ) + vsim_query.vsim_method_params("KNN", K=1000) + vsim_query.filter( + HybridFilter( + "((@price:[15 16] @size:[10 11]) | (@price:[13 15] @size:[11 12])) @description:(shoes) -@description:(green)" + ) + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("RRF", WINDOW=1000) + + timeout = 5000 # 5 second timeout + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, 
post_processing=posprocessing_config, timeout=timeout + ) + + if is_resp2_connection(decoded_r): + assert len(res.results) > 0 + assert res.warnings == [] + assert res.execution_time > 0 and res.execution_time < timeout + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + assert res["execution_time"] > 0 and res["execution_time"] < timeout + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, timeout=1 + ) # 1 ms timeout + if is_resp2_connection(decoded_r): + assert ( + b"Timeout limit was reached (VSIM)" in res.warnings + or b"Timeout limit was reached (SEARCH)" in res.warnings + ) + else: + assert ( + "Timeout limit was reached (VSIM)" in res["warnings"] + or "Timeout limit was reached (SEARCH)" in res["warnings"] + ) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_load_and_groupby(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red|green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price", "@size", "@item_type") + posprocessing_config.limit(0, 4) + + posprocessing_config.group_by( + ["@price"], + reducers.count_distinct("@color").alias("colors_count"), + ) + + posprocessing_config.sort_by(SortbyField("@price", asc=True)) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + expected_results = [ + {"price": b"15", "colors_count": b"2"}, + {"price": b"16", "colors_count": b"2"}, + {"price": b"17", "colors_count": b"2"}, + {"price": b"18", "colors_count": b"2"}, + ] + + if is_resp2_connection(decoded_r): + assert len(res.results) == 4 + assert res.results == expected_results + assert res.warnings == [] + else: + assert len(res["results"]) == 4 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price", "@size", "@item_type") + posprocessing_config.limit(0, 6) + posprocessing_config.sort_by( + SortbyField("@price", asc=True), + SortbyField("@item_type", asc=True), + ) + + posprocessing_config.group_by( + ["@price", "@item_type"], + reducers.count_distinct("@color").alias("unique_colors_count"), + ) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=1000 + ) + + expected_results = [ + {"price": b"15", "item_type": b"dress", "unique_colors_count": b"1"}, + {"price": b"15", "item_type": b"shoes", "unique_colors_count": b"2"}, + {"price": b"16", "item_type": b"dress", "unique_colors_count": b"1"}, + {"price": b"16", "item_type": b"shoes", "unique_colors_count": b"2"}, + {"price": b"17", "item_type": b"dress", "unique_colors_count": b"1"}, + {"price": b"17", "item_type": b"shoes", "unique_colors_count": b"2"}, + ] + if is_resp2_connection(decoded_r): + assert len(res.results) == 6 + assert res.results == expected_results + assert res.warnings == [] + else: + assert len(res["results"]) == 6 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == 
[] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_cursor(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search(decoded_r, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red|green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, + cursor=HybridCursorQuery(count=5, max_idle=100), + timeout=10, + ) + if is_resp2_connection(decoded_r): + assert isinstance(res, HybridCursorResult) + assert res.search_cursor_id > 0 + assert res.vsim_cursor_id > 0 + search_cursor = aggregations.Cursor(res.search_cursor_id) + vsim_cursor = aggregations.Cursor(res.vsim_cursor_id) + else: + assert res["SEARCH"] > 0 + assert res["VSIM"] > 0 + search_cursor = aggregations.Cursor(res["SEARCH"]) + vsim_cursor = aggregations.Cursor(res["VSIM"]) + + search_res_from_cursor = await decoded_r.ft().aggregate(query=search_cursor) + if is_resp2_connection(decoded_r): + assert len(search_res_from_cursor.rows) == 5 + else: + assert len(search_res_from_cursor[0]["results"]) == 5 + + vsim_res_from_cursor = await decoded_r.ft().aggregate(query=vsim_cursor) + if is_resp2_connection(decoded_r): + assert len(vsim_res_from_cursor.rows) == 5 + else: + assert len(vsim_res_from_cursor[0]["results"]) == 5 diff --git a/tests/test_search.py b/tests/test_search.py index b76baf10aa..25d86d0f36 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,6 +1,7 @@ import bz2 import csv import os +import random import time from io import TextIOWrapper @@ -8,8 +9,17 @@ import pytest from redis import ResponseError import redis -import redis.commands.search + import redis.commands.search.aggregation as aggregations +from redis.commands.search.hybrid_query import ( + HybridCursorQuery, + HybridFilter, + HybridPostProcessingConfig, + HybridQuery, + HybridSearchQuery, + HybridVsimQuery, +) +from redis.commands.search.hybrid_result import HybridCursorResult import redis.commands.search.reducers as reducers from redis.commands.json.path import Path from redis.commands.search import Search @@ -22,9 +32,15 @@ VectorField, ) from redis.commands.search.index_definition import IndexDefinition, IndexType -from redis.commands.search.query import GeoFilter, NumericFilter, Query +from redis.commands.search.query import ( + GeoFilter, + NumericFilter, + Query, + SortbyField, +) from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion +from redis.utils import safe_str from .conftest import ( _get_client, @@ -45,3824 +61,4983 @@ ) -def waitForIndex(env, idx, timeout=None): - delay = 0.1 - while True: - try: - res = env.execute_command("FT.INFO", idx) - if int(res[res.index("indexing") + 1]) == 0: - break - except ValueError: - break - except AttributeError: +def _assert_search_result(client, result, expected_doc_ids): + """ + Make sure the result of a geo search is as expected, taking into account the RESP + version being used. 
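+    Ordering is ignored: only the set of returned document ids is compared.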
+ """ + if is_resp2_connection(client): + assert set([doc.id for doc in result.docs]) == set(expected_doc_ids) + else: + assert set([doc["id"] for doc in result["results"]]) == set(expected_doc_ids) + + +class SearchTestsBase: + @staticmethod + def waitForIndex(env, idx, timeout=None): + delay = 0.1 + while True: try: - if int(res["indexing"]) == 0: + res = env.execute_command("FT.INFO", idx) + if int(res[res.index("indexing") + 1]) == 0: break except ValueError: break - except ResponseError: - # index doesn't exist yet - # continue to sleep and try again - pass - - time.sleep(delay) - if timeout is not None: - timeout -= delay - if timeout <= 0: + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break + except ResponseError: + # index doesn't exist yet + # continue to sleep and try again + pass + + time.sleep(delay) + if timeout is not None: + timeout -= delay + if timeout <= 0: + break + + @staticmethod + def getClient(client): + """ + Gets a client client attached to an index name which is ready to be + created + """ + return client + + @staticmethod + def createIndex(client, num_docs=100, definition=None): + try: + client.create_index( + ( + TextField("play", weight=5.0), + TextField("txt"), + NumericField("chapter"), + ), + definition=definition, + ) + except redis.ResponseError: + client.dropindex(delete_documents=True) + return SearchTestsBase.createIndex( + client, num_docs=num_docs, definition=definition + ) + + chapters = {} + bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") + + r = csv.reader(bzfp, delimiter=";") + for n, line in enumerate(r): + play, chapter, _, text = line[1], line[2], line[4], line[5] + + key = f"{play}:{chapter}".lower() + d = chapters.setdefault(key, {}) + d["play"] = play + d["txt"] = d.get("txt", "") + " " + text + d["chapter"] = int(chapter or 0) + if len(chapters) == num_docs: break + indexer = client.batch_indexer(chunk_size=50) + assert isinstance(indexer, Search.BatchIndexer) + assert 50 == indexer.chunk_size -def getClient(client): - """ - Gets a client client attached to an index name which is ready to be - created - """ - return client + for key, doc in chapters.items(): + indexer.client.client.hset(key, mapping=doc) + indexer.commit() + @pytest.fixture + def client(self, request, stack_url): + r = _get_client(redis.Redis, request, decode_responses=True, from_url=stack_url) + r.flushdb() + return r -def createIndex(client, num_docs=100, definition=None): - try: - client.create_index( - (TextField("play", weight=5.0), TextField("txt"), NumericField("chapter")), - definition=definition, - ) - except redis.ResponseError: - client.dropindex(delete_documents=True) - return createIndex(client, num_docs=num_docs, definition=definition) - - chapters = {} - bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") - - r = csv.reader(bzfp, delimiter=";") - for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] - - key = f"{play}:{chapter}".lower() - d = chapters.setdefault(key, {}) - d["play"] = play - d["txt"] = d.get("txt", "") + " " + text - d["chapter"] = int(chapter or 0) - if len(chapters) == num_docs: - break - - indexer = client.batch_indexer(chunk_size=50) - assert isinstance(indexer, Search.BatchIndexer) - assert 50 == indexer.chunk_size - - for key, doc in chapters.items(): - indexer.client.client.hset(key, mapping=doc) - indexer.commit() - - -@pytest.fixture -def client(request, stack_url): - r = _get_client(redis.Redis, request, 
decode_responses=True, from_url=stack_url) - r.flushdb() - return r - - -@pytest.mark.redismod -def test_client(client): - num_docs = 500 - createIndex(client.ft(), num_docs=num_docs) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - # verify info - info = client.ft().info() - for k in [ - "index_name", - "index_options", - "attributes", - "num_docs", - "max_doc_id", - "num_terms", - "num_records", - "inverted_sz_mb", - "offset_vectors_sz_mb", - "doc_table_size_mb", - "key_table_size_mb", - "records_per_doc_avg", - "bytes_per_record_avg", - "offsets_per_term_avg", - "offset_bits_per_record_avg", - ]: - assert k in info - - assert client.ft().index_name == info["index_name"] - assert num_docs == int(info["num_docs"]) - - res = client.ft().search("henry iv") - if is_resp2_connection(client): - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc["id"] - assert doc.play == "Henry IV" - assert doc["play"] == "Henry IV" - assert len(doc.txt) > 0 - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search(Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search(Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search(Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") - else: - assert isinstance(res, dict) - assert 225 == res["total_results"] - assert 10 == len(res["results"]) - - for doc in res["results"]: - assert doc["id"] - assert 
doc["extra_attributes"]["play"] == "Henry IV" - assert len(doc["extra_attributes"]["txt"]) > 0 - - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res["total_results"] - assert 10 == len(res["results"]) - for doc in res["results"]: - assert "extra_attributes" not in doc.keys() - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content())["total_results"] - vtotal = client.ft().search(Query("kings").no_content().verbatim())[ - "total_results" - ] - assert total > vtotal +class TestBaseSearchFunctionality(SearchTestsBase): + @pytest.mark.redismod + def test_client(self, client): + num_docs = 500 + self.createIndex(client.ft(), num_docs=num_docs) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + # verify info + info = client.ft().info() + for k in [ + "index_name", + "index_options", + "attributes", + "num_docs", + "max_doc_id", + "num_terms", + "num_records", + "inverted_sz_mb", + "offset_vectors_sz_mb", + "doc_table_size_mb", + "key_table_size_mb", + "records_per_doc_avg", + "bytes_per_record_avg", + "offsets_per_term_avg", + "offset_bits_per_record_avg", + ]: + assert k in info + + assert client.ft().index_name == info["index_name"] + assert num_docs == int(info["num_docs"]) + + res = client.ft().search("henry iv") + if is_resp2_connection(client): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc["id"] + assert doc.play == "Henry IV" + assert doc["play"] == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content()).total + vtotal = client.ft().search(Query("kings").no_content().verbatim()).total + assert total > vtotal + + # test in fields + txt_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("txt")) + .total + ) + play_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play")) + .total + ) + both_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play", "txt")) + .total + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 - # test in fields - txt_total = client.ft().search(Query("henry").no_content().limit_fields("txt"))[ - "total_results" - ] - play_total = client.ft().search( - Query("henry").no_content().limit_fields("play") - )["total_results"] - both_total = client.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - )["total_results"] - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x["id"] for x in client.ft().search(Query("henry"))["results"]] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == 
docs["total_results"] - ids = [x["id"] for x in docs["results"]] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king"))["total_results"] - assert ( - 3 - == client.ft().search(Query("henry king").slop(0).in_order())[ - "total_results" - ] - ) - assert ( - 52 - == client.ft().search(Query("king henry").slop(0).in_order())[ + # test in-keys + ids = [x.id for x in client.ft().search(Query("henry")).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king")).total + assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total + assert ( + 52 == client.ft().search(Query("king henry").slop(0).in_order()).total + ) + assert 53 == client.ft().search(Query("henry king").slop(0)).total + assert 167 == client.ft().search(Query("henry king").slop(100)).total + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + client.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["extra_attributes"]["play"] == "Henry IV" + assert len(doc["extra_attributes"]["txt"]) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "extra_attributes" not in doc.keys() + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content())["total_results"] + vtotal = client.ft().search(Query("kings").no_content().verbatim())[ "total_results" ] - ) - assert 53 == client.ft().search(Query("henry king").slop(0))["total_results"] - assert 167 == client.ft().search(Query("henry king").slop(100))["total_results"] + assert total > vtotal + + # test in fields + txt_total = client.ft().search( + Query("henry").no_content().limit_fields("txt") + )["total_results"] + play_total = client.ft().search( + Query("henry").no_content().limit_fields("play") + )["total_results"] + both_total = client.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + )["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 - # test delete document - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res["total_results"] + # test in-keys + ids = [x["id"] for x in client.ft().search(Query("henry"))["results"]] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert 
len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king"))["total_results"] + assert ( + 3 + == client.ft().search(Query("henry king").slop(0).in_order())[ + "total_results" + ] + ) + assert ( + 52 + == client.ft().search(Query("king henry").slop(0).in_order())[ + "total_results" + ] + ) + assert ( + 53 == client.ft().search(Query("henry king").slop(0))["total_results"] + ) + assert ( + 167 + == client.ft().search(Query("henry king").slop(100))["total_results"] + ) - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res["total_results"] - assert 0 == client.ft().delete_document("doc-5ghs2") + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res["total_results"] - client.ft().delete_document("doc-5ghs2") + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == client.ft().delete_document("doc-5ghs2") + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + client.ft().delete_document("doc-5ghs2") -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_gte("7.9.0") -def test_scores(client): - client.ft().create_index((TextField("txt"),)) + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_gte("7.9.0") + def test_scores(self, client): + client.ft().create_index((TextField("txt"),)) - client.hset("doc1", mapping={"txt": "foo baz"}) - client.hset("doc2", mapping={"txt": "foo bar"}) + client.hset("doc1", mapping={"txt": "foo baz"}) + client.hset("doc2", mapping={"txt": "foo bar"}) - q = Query("foo ~bar").with_scores() - res = client.ft().search(q) - if is_resp2_connection(client): - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - else: - assert 2 == res["total_results"] - assert "doc2" == res["results"][0]["id"] - assert 3.0 == res["results"][0]["score"] - assert "doc1" == res["results"][1]["id"] + q = Query("foo ~bar").with_scores() + res = client.ft().search(q) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.9.0") + def test_scores_with_new_default_scorer(self, client): + client.ft().create_index((TextField("txt"),)) + + client.hset("doc1", mapping={"txt": "foo baz"}) + client.hset("doc2", mapping={"txt": "foo bar"}) + + q = Query("foo ~bar").with_scores() + res = client.ft().search(q) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 0.87 == pytest.approx(res.docs[0].score, 0.01) + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" 
== res["results"][0]["id"] + assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01) + assert "doc1" == res["results"][1]["id"] + + @pytest.mark.redismod + def test_stopwords(self, client): + client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) + client.hset("doc1", mapping={"txt": "foo bar"}) + client.hset("doc2", mapping={"txt": "hello world"}) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + + q1 = Query("foo bar").no_content() + q2 = Query("foo bar hello world").no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) + if is_resp2_connection(client): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] + + @pytest.mark.redismod + def test_filters(self, client): + client.ft().create_index( + (TextField("txt"), NumericField("num"), GeoField("loc")) + ) + client.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} + ) + client.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + # Test numerical filter + q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() + q2 = ( + Query("foo") + .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) + .no_content() + ) + res1, res2 = client.ft().search(q1), client.ft().search(q2) + if is_resp2_connection(client): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] + # Test geo filter + q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() + q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_lt("7.9.0") -def test_scores_with_new_default_scorer(client): - client.ft().create_index((TextField("txt"),)) + if is_resp2_connection(client): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id + + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] + + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res + + @pytest.mark.redismod + def test_sort_by(self, client): + client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) + client.hset("doc1", mapping={"txt": "foo bar", "num": 1}) + client.hset("doc2", mapping={"txt": "foo baz", "num": 2}) + client.hset("doc3", mapping={"txt": "foo qux", "num": 3}) + + # Test sort + q1 = Query("foo").sort_by("num", asc=True).no_content() + q2 = Query("foo").sort_by("num", asc=False).no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) - client.hset("doc1", mapping={"txt": "foo baz"}) - client.hset("doc2", mapping={"txt": "foo bar"}) + if is_resp2_connection(client): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + 
assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + def test_drop_index(self, client): + """ + Ensure the index gets dropped by data remains by default + """ + for x in range(20): + for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: + idx = "HaveIt" + index = self.getClient(client) + index.hset("index:haveit", mapping={"name": "haveit"}) + idef = IndexDefinition(prefix=["index:"]) + index.ft(idx).create_index((TextField("name"),), definition=idef) + self.waitForIndex(index, idx) + index.ft(idx).dropindex(delete_documents=keep_docs[0]) + i = index.hgetall("index:haveit") + assert i == keep_docs[1] + + @pytest.mark.redismod + def test_example(self, client): + # Creating the index definition and schema + client.ft().create_index((TextField("title", weight=5.0), TextField("body"))) + + # Indexing a document + client.hset( + "doc1", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + }, + ) - q = Query("foo ~bar").with_scores() - res = client.ft().search(q) - if is_resp2_connection(client): - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 0.87 == pytest.approx(res.docs[0].score, 0.01) - assert "doc1" == res.docs[1].id - else: - assert 2 == res["total_results"] - assert "doc2" == res["results"][0]["id"] - assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01) - assert "doc1" == res["results"][1]["id"] - - -@pytest.mark.redismod -def test_stopwords(client): - client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) - client.hset("doc1", mapping={"txt": "foo bar"}) - client.hset("doc2", mapping={"txt": "hello world"}) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - - q1 = Query("foo bar").no_content() - q2 = Query("foo bar hello world").no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - if is_resp2_connection(client): - assert 0 == res1.total - assert 1 == res2.total - else: - assert 0 == res1["total_results"] - assert 1 == res2["total_results"] - - -@pytest.mark.redismod -def test_filters(client): - client.ft().create_index((TextField("txt"), NumericField("num"), GeoField("loc"))) - client.hset( - "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} - ) - client.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) - - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - # Test numerical filter - q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() - q2 = ( - Query("foo") - .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) - .no_content() - ) - res1, res2 = client.ft().search(q1), client.ft().search(q2) - if is_resp2_connection(client): - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id - else: - assert 1 == res1["total_results"] - assert 1 == res2["total_results"] - assert "doc2" == res1["results"][0]["id"] - assert "doc1" == res2["results"][0]["id"] + # Searching with complex parameters: + q = Query("search 
engine").verbatim().no_content().paging(0, 5) + + res = client.ft().search(q) + assert res is not None + + @pytest.mark.redismod + @skip_if_redis_enterprise() + def test_auto_complete(self, client): + n = 0 + with open(TITLES_CSV) as f: + cr = csv.reader(f) + + for row in cr: + n += 1 + term, score = row[0], float(row[1]) + assert n == client.ft().sugadd("ac", Suggestion(term, score=score)) + + assert n == client.ft().suglen("ac") + ret = client.ft().sugget("ac", "bad", with_scores=True) + assert 2 == len(ret) + assert "badger" == ret[0].string + assert isinstance(ret[0].score, float) + assert 1.0 != ret[0].score + assert "badalte rishtey" == ret[1].string + assert isinstance(ret[1].score, float) + assert 1.0 != ret[1].score + + ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) + assert 10 == len(ret) + assert 1.0 == ret[0].score + strs = {x.string for x in ret} + + for sug in strs: + assert 1 == client.ft().sugdel("ac", sug) + # make sure a second delete returns 0 + for sug in strs: + assert 0 == client.ft().sugdel("ac", sug) + + # make sure they were actually deleted + ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) + for sug in ret2: + assert sug.string not in strs + + # Test with payload + client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) + client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) + client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) + + sugs = client.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) + assert 3 == len(sugs) + for sug in sugs: + assert sug.payload + assert sug.payload.startswith("pl") + + @pytest.mark.redismod + def test_no_index(self, client): + client.ft().create_index( + ( + TextField("field"), + TextField("text", no_index=True, sortable=True), + NumericField("numeric", no_index=True, sortable=True), + GeoField("geo", no_index=True, sortable=True), + TagField("tag", no_index=True, sortable=True), + ) + ) - # Test geo filter - q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() - q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) + client.hset( + "doc1", + mapping={ + "field": "aaa", + "text": "1", + "numeric": "1", + "geo": "1,1", + "tag": "1", + }, + ) + client.hset( + "doc2", + mapping={ + "field": "aab", + "text": "2", + "numeric": "2", + "geo": "2,2", + "tag": "2", + }, + ) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - if is_resp2_connection(client): - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id - - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res - else: - assert 1 == res1["total_results"] - assert 2 == res2["total_results"] - assert "doc1" == res1["results"][0]["id"] + if is_resp2_connection(client): + res = client.ft().search(Query("@text:aa*")) + assert 0 == res.total - # Sort results, after RDB reload order may change - res = [res2["results"][0]["id"], res2["results"][1]["id"]] - res.sort() - assert ["doc1", "doc2"] == res + res = client.ft().search(Query("@field:aa*")) + assert 2 == res.total + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id -@pytest.mark.redismod -def test_sort_by(client): - client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) - client.hset("doc1", mapping={"txt": "foo bar", "num": 1}) - 
client.hset("doc2", mapping={"txt": "foo baz", "num": 2}) - client.hset("doc3", mapping={"txt": "foo qux", "num": 3}) + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - # Test sort - q1 = Query("foo").sort_by("num", asc=True).no_content() - q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - if is_resp2_connection(client): - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id - else: - assert 3 == res1["total_results"] - assert "doc1" == res1["results"][0]["id"] - assert "doc2" == res1["results"][1]["id"] - assert "doc3" == res1["results"][2]["id"] - assert 3 == res2["total_results"] - assert "doc1" == res2["results"][2]["id"] - assert "doc2" == res2["results"][1]["id"] - assert "doc3" == res2["results"][0]["id"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_drop_index(client): - """ - Ensure the index gets dropped by data remains by default - """ - for x in range(20): - for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: - idx = "HaveIt" - index = getClient(client) - index.hset("index:haveit", mapping={"name": "haveit"}) - idef = IndexDefinition(prefix=["index:"]) - index.ft(idx).create_index((TextField("name"),), definition=idef) - waitForIndex(index, idx) - index.ft(idx).dropindex(delete_documents=keep_docs[0]) - i = index.hgetall("index:haveit") - assert i == keep_docs[1] - - -@pytest.mark.redismod -def test_example(client): - # Creating the index definition and schema - client.ft().create_index((TextField("title", weight=5.0), TextField("body"))) - - # Indexing a document - client.hset( - "doc1", - mapping={ - "title": "RediSearch", - "body": "Redisearch impements a search engine on top of redis", - }, - ) - - # Searching with complex parameters: - q = Query("search engine").verbatim().no_content().paging(0, 5) - - res = client.ft().search(q) - assert res is not None - - -@pytest.mark.redismod -@skip_if_redis_enterprise() -def test_auto_complete(client): - n = 0 - with open(TITLES_CSV) as f: - cr = csv.reader(f) - - for row in cr: - n += 1 - term, score = row[0], float(row[1]) - assert n == client.ft().sugadd("ac", Suggestion(term, score=score)) - - assert n == client.ft().suglen("ac") - ret = client.ft().sugget("ac", "bad", with_scores=True) - assert 2 == len(ret) - assert "badger" == ret[0].string - assert isinstance(ret[0].score, float) - assert 1.0 != ret[0].score - assert "badalte rishtey" == ret[1].string - assert isinstance(ret[1].score, float) - assert 1.0 != ret[1].score - - ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - assert 10 == len(ret) - assert 1.0 == ret[0].score - strs = {x.string for x in ret} - - for sug in strs: - assert 1 == client.ft().sugdel("ac", sug) - # make sure a second delete returns 0 - for sug in strs: - assert 0 == client.ft().sugdel("ac", sug) - - # make sure they were actually deleted - ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - for sug in ret2: - assert sug.string not in strs - - # Test with payload - client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) - client.ft().sugadd("ac", 
Suggestion("pay3", payload="pl3")) - - sugs = client.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) - assert 3 == len(sugs) - for sug in sugs: - assert sug.payload - assert sug.payload.startswith("pl") - - -@pytest.mark.redismod -def test_no_index(client): - client.ft().create_index( - ( - TextField("field"), - TextField("text", no_index=True, sortable=True), - NumericField("numeric", no_index=True, sortable=True), - GeoField("geo", no_index=True, sortable=True), - TagField("tag", no_index=True, sortable=True), - ) - ) - - client.hset( - "doc1", - mapping={"field": "aaa", "text": "1", "numeric": "1", "geo": "1,1", "tag": "1"}, - ) - client.hset( - "doc2", - mapping={"field": "aab", "text": "2", "numeric": "2", "geo": "2,2", "tag": "2"}, - ) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - if is_resp2_connection(client): - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = client.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = client.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id - else: - res = client.ft().search(Query("@text:aa*")) - assert 0 == res["total_results"] + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] - res = client.ft().search(Query("@field:aa*")) - assert 2 == res["total_results"] + # Ensure exception is raised for non-indexable, non-sortable fields + with pytest.raises(Exception): + TextField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + NumericField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + GeoField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + TagField("name", no_index=True, sortable=False) - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res["total_results"] - assert "doc2" == res["results"][0]["id"] + @pytest.mark.redismod + def test_explain(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) + res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") + assert res - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res["results"][0]["id"] + @pytest.mark.redismod + def 
test_explaincli(self, client): + with pytest.raises(NotImplementedError): + client.ft().explain_cli("foo") - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res["results"][0]["id"] + @pytest.mark.redismod + def test_summarize(self, client): + self.createIndex(client.ft()) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res["results"][0]["id"] + q = Query('"king henry"').paging(0, 1) + q.highlight(fields=("play", "txt"), tags=("", "")) + q.summarize("txt") - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res["results"][0]["id"] + if is_resp2_connection(client): + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - # Ensure exception is raised for non-indexable, non-sortable fields - with pytest.raises(Exception): - TextField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - NumericField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - GeoField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - TagField("name", no_index=True, sortable=False) + q = Query('"king henry"').paging(0, 1).summarize().highlight() + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + else: + doc = sorted(client.ft().search(q)["results"])[0] + assert "Henry IV" == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) -@pytest.mark.redismod -def test_explain(client): - client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") - assert res + q = Query('"king henry"').paging(0, 1).summarize().highlight() + doc = sorted(client.ft().search(q)["results"])[0] + assert "Henry ... " == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa + == doc["extra_attributes"]["txt"] + ) -@pytest.mark.redismod -def test_explaincli(client): - with pytest.raises(NotImplementedError): - client.ft().explain_cli("foo") + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + def test_alias(self, client): + index1 = self.getClient(client) + index2 = self.getClient(client) + def1 = IndexDefinition(prefix=["index1:"]) + def2 = IndexDefinition(prefix=["index2:"]) -@pytest.mark.redismod -def test_summarize(client): - createIndex(client.ft()) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + ftindex1 = index1.ft("testAlias") + ftindex2 = index2.ft("testAlias2") + ftindex1.create_index((TextField("name"),), definition=def1) + ftindex2.create_index((TextField("name"),), definition=def2) - q = Query('"king henry"').paging(0, 1) - q.highlight(fields=("play", "txt"), tags=("", "")) - q.summarize("txt") + index1.hset("index1:lonestar", mapping={"name": "lonestar"}) + index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - if is_resp2_connection(client): - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + if is_resp2_connection(client): + res = ftindex1.search("*").docs[0] + assert "index1:lonestar" == res.id - q = Query('"king henry"').paging(0, 1).summarize().highlight() + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = self.getClient(client).ft("spaceballs") + res = alias_client.search("*").docs[0] + assert "index1:lonestar" == res.id - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - else: - doc = sorted(client.ft().search(q)["results"])[0] - assert "Henry IV" == doc["extra_attributes"]["play"] - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc["extra_attributes"]["txt"] - ) + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") - q = Query('"king henry"').paging(0, 1).summarize().highlight() + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = self.getClient(client).ft("spaceballs") - doc = sorted(client.ft().search(q)["results"])[0] - assert "Henry ... " == doc["extra_attributes"]["play"] - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa - == doc["extra_attributes"]["txt"] - ) + res = alias_client2.search("*").docs[0] + assert "index2:yogurt" == res.id + else: + res = ftindex1.search("*")["results"][0] + assert "index1:lonestar" == res["id"] + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = self.getClient(client).ft("spaceballs") + res = alias_client.search("*")["results"][0] + assert "index1:lonestar" == res["id"] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_alias(client): - index1 = getClient(client) - index2 = getClient(client) + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") - def1 = IndexDefinition(prefix=["index1:"]) - def2 = IndexDefinition(prefix=["index2:"]) + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = self.getClient(client).ft("spaceballs") - ftindex1 = index1.ft("testAlias") - ftindex2 = index2.ft("testAlias2") - ftindex1.create_index((TextField("name"),), definition=def1) - ftindex2.create_index((TextField("name"),), definition=def2) + res = alias_client2.search("*")["results"][0] + assert "index2:yogurt" == res["id"] - index1.hset("index1:lonestar", mapping={"name": "lonestar"}) - index2.hset("index2:yogurt", mapping={"name": "yogurt"}) + ftindex2.aliasdel("spaceballs") + with pytest.raises(Exception): + alias_client2.search("*").docs[0] - if is_resp2_connection(client): - res = ftindex1.search("*").docs[0] - assert "index1:lonestar" == res.id + @pytest.mark.redismod + @pytest.mark.xfail(strict=False) + def test_alias_basic(self, client): + # Creating a client with one index + index1 = self.getClient(client).ft("testAlias") - # create alias and check for results - ftindex1.aliasadd("spaceballs") - alias_client = getClient(client).ft("spaceballs") - res = alias_client.search("*").docs[0] - assert "index1:lonestar" == res.id + index1.create_index((TextField("txt"),)) + index1.client.hset("doc1", mapping={"txt": "text goes here"}) - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") + index2 = self.getClient(client).ft("testAlias2") + index2.create_index((TextField("txt"),)) + index2.client.hset("doc2", mapping={"txt": "text goes here"}) - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(client).ft("spaceballs") + # add the actual alias and check + index1.aliasadd("myalias") + alias_client = self.getClient(client).ft("myalias") + if is_resp2_connection(client): + res = sorted(alias_client.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = self.getClient(client).ft("myalias") + res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + else: + res = sorted(alias_client.search("*")["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] - res = alias_client2.search("*").docs[0] - assert "index2:yogurt" == res.id - else: - res = ftindex1.search("*")["results"][0] - assert "index1:lonestar" == res["id"] + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") - # create alias and check for results - 
ftindex1.aliasadd("spaceballs") - alias_client = getClient(client).ft("spaceballs") - res = alias_client.search("*")["results"][0] - assert "index1:lonestar" == res["id"] + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = self.getClient(client).ft("myalias") + res = sorted(alias_client2.search("*")["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] - # Throw an exception when trying to add an alias that already exists + # delete the alias and expect an error if we try to query again + index2.aliasdel("myalias") with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") + _ = alias_client2.search("*").docs[0] + + @pytest.mark.redismod + def test_textfield_sortable_nostem(self, client): + # Creating the index definition with sortable and no_stem + client.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) + + # Now get the index info to confirm its contents + response = client.ft().info() + if is_resp2_connection(client): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(client).ft("spaceballs") + @pytest.mark.redismod + def test_alter_schema_add(self, client): + # Creating the index definition and schema + client.ft().create_index(TextField("title")) - res = alias_client2.search("*")["results"][0] - assert "index2:yogurt" == res["id"] + # Using alter to add a field + client.ft().alter_schema_add(TextField("body")) - ftindex2.aliasdel("spaceballs") - with pytest.raises(Exception): - alias_client2.search("*").docs[0] + # Indexing a document + client.hset( + "doc1", + mapping={"title": "MyTitle", "body": "Some content only in the body"}, + ) + # Searching with parameter only in the body (the added field) + q = Query("only in the body") -@pytest.mark.redismod -@pytest.mark.xfail(strict=False) -def test_alias_basic(client): - # Creating a client with one index - index1 = getClient(client).ft("testAlias") + # Ensure we find the result searching on the added body field + res = client.ft().search(q) + if is_resp2_connection(client): + assert 1 == res.total + else: + assert 1 == res["total_results"] - index1.create_index((TextField("txt"),)) - index1.client.hset("doc1", mapping={"txt": "text goes here"}) + @pytest.mark.redismod + def test_spell_check(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"))) - index2 = getClient(client).ft("testAlias2") - index2.create_index((TextField("txt"),)) - index2.client.hset("doc2", mapping={"txt": "text goes here"}) + client.hset( + "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} + ) + client.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - # add the actual alias and check - index1.aliasadd("myalias") - alias_client = getClient(client).ft("myalias") - if is_resp2_connection(client): - res = sorted(alias_client.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id + if is_resp2_connection(client): + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = client.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res 
= client.ft().spellcheck("vlis") + assert res == {} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" in res["results"]["impornant"][0].keys() + + res = client.ft().spellcheck("contnt") + assert "content" in res["results"]["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {"results": {"vlis": []}} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" in res["results"]["vlis"][0].keys() + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["results"]["lorm"]) == 3 + assert "lorem" in res["results"]["lorm"][0].keys() + assert "lore" in res["results"]["lorm"][1].keys() + assert "lorm" in res["results"]["lorm"][2].keys() + assert ( + res["results"]["lorm"][0]["lorem"], + res["results"]["lorm"][1]["lore"], + ) == (0.5, 0) + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {"results": {}} + + @pytest.mark.redismod + def test_dict_operations(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"))) + # Add three items + res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") + assert 3 == res + + # Remove one item + res = client.ft().dict_del("custom_dict", "item2") + assert 1 == res + + # Dump dict and inspect content + res = client.ft().dict_dump("custom_dict") + assert res == ["item1", "item3"] + + # Remove rest of the items before reload + client.ft().dict_del("custom_dict", *res) + + @pytest.mark.redismod + def test_phonetic_matcher(self, client): + client.ft().create_index((TextField("name"),)) + client.hset("doc1", mapping={"name": "Jon"}) + client.hset("doc2", mapping={"name": "John"}) + + res = client.ft().search(Query("Jon")) + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") + # Drop and create index with phonetic matcher + client.flushdb() - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = getClient(client).ft("myalias") - res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - else: - res = sorted(alias_client.search("*")["results"], key=lambda x: x["id"]) - assert "doc1" == res[0]["id"] + client.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) + client.hset("doc1", mapping={"name": "Jon"}) + client.hset("doc2", mapping={"name": "John"}) - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") + res = client.ft().search(Query("Jon")) + if 
is_resp2_connection(client): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted( + d["extra_attributes"]["name"] for d in res["results"] + ) - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = getClient(client).ft("myalias") - res = sorted(alias_client2.search("*")["results"], key=lambda x: x["id"]) - assert "doc1" == res[0]["id"] + @pytest.mark.redismod + def test_get(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"))) - # delete the alias and expect an error if we try to query again - index2.aliasdel("myalias") - with pytest.raises(Exception): - _ = alias_client2.search("*").docs[0] + assert [None] == client.ft().get("doc1") + assert [None, None] == client.ft().get("doc2", "doc1") + client.hset( + "doc1", + mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"}, + ) + client.hset( + "doc2", + mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"}, + ) -@pytest.mark.redismod -def test_textfield_sortable_nostem(client): - # Creating the index definition with sortable and no_stem - client.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) + assert [ + ["f1", "some valid content dd2", "f2", "this is sample text f2"] + ] == client.ft().get("doc2") + assert [ + ["f1", "some valid content dd1", "f2", "this is sample text f1"], + ["f1", "some valid content dd2", "f2", "this is sample text f2"], + ] == client.ft().get("doc1", "doc2") + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + def test_index_definition(self, client): + """ + Create definition and test its args + """ + with pytest.raises(RuntimeError): + IndexDefinition(prefix=["hset:", "henry"], index_type="json") + + definition = IndexDefinition( + prefix=["hset:", "henry"], + filter="@f1==32", + language="English", + language_field="play", + score_field="chapter", + score=0.5, + payload_field="txt", + index_type=IndexType.JSON, + ) - # Now get the index info to confirm its contents - response = client.ft().info() - if is_resp2_connection(client): - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] - else: - assert "SORTABLE" in response["attributes"][0]["flags"] - assert "NOSTEM" in response["attributes"][0]["flags"] + assert [ + "ON", + "JSON", + "PREFIX", + 2, + "hset:", + "henry", + "FILTER", + "@f1==32", + "LANGUAGE_FIELD", + "play", + "LANGUAGE", + "English", + "SCORE_FIELD", + "chapter", + "SCORE", + 0.5, + "PAYLOAD_FIELD", + "txt", + ] == definition.args + + self.createIndex(client.ft(), num_docs=500, definition=definition) + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_redis_enterprise() + @skip_if_server_version_gte("7.9.0") + def test_expire(self, client): + client.ft().create_index((TextField("txt", sortable=True),), temporary=4) + ttl = client.execute_command("ft.debug", "TTL", "idx") + assert ttl > 2 + while ttl > 2: + ttl = client.execute_command("ft.debug", "TTL", "idx") + time.sleep(0.01) -@pytest.mark.redismod -def test_alter_schema_add(client): - # Creating the index definition and schema - client.ft().create_index(TextField("title")) + @pytest.mark.redismod + def test_skip_initial_scan(self, client): + client.hset("doc1", "foo", "bar") + q = Query("@foo:bar") - # Using alter to add a field - client.ft().alter_schema_add(TextField("body")) + client.ft().create_index((TextField("foo"),), 
skip_initial_scan=True) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 - # Indexing a document - client.hset( - "doc1", mapping={"title": "MyTitle", "body": "Some content only in the body"} - ) + @pytest.mark.redismod + def test_summarize_disabled_nooffset(self, client): + client.ft().create_index((TextField("txt"),), no_term_offsets=True) + client.hset("doc1", mapping={"txt": "foo bar"}) + with pytest.raises(Exception): + client.ft().search(Query("foo").summarize(fields=["txt"])) - # Searching with parameter only in the body (the added field) - q = Query("only in the body") + @pytest.mark.redismod + def test_summarize_disabled_nohl(self, client): + client.ft().create_index((TextField("txt"),), no_highlight=True) + client.hset("doc1", mapping={"txt": "foo bar"}) + with pytest.raises(Exception): + client.ft().search(Query("foo").summarize(fields=["txt"])) + + @pytest.mark.redismod + def test_max_text_fields(self, client): + # Creating the index definition + client.ft().create_index((TextField("f0"),)) + for x in range(1, 32): + client.ft().alter_schema_add((TextField(f"f{x}"),)) + + # Should be too many indexes + with pytest.raises(redis.ResponseError): + client.ft().alter_schema_add((TextField(f"f{x}"),)) + + client.ft().dropindex() + # Creating the index definition + client.ft().create_index((TextField("f0"),), max_text_fields=True) + # Fill the index with fields + for x in range(1, 50): + client.ft().alter_schema_add((TextField(f"f{x}"),)) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + def test_create_client_definition(self, client): + """ + Create definition with no index type provided, + and use hset to test the client definition (the default is HASH). + """ + definition = IndexDefinition(prefix=["hset:", "henry"]) + self.createIndex(client.ft(), num_docs=500, definition=definition) - # Ensure we find the result searching on the added body field - res = client.ft().search(q) - if is_resp2_connection(client): - assert 1 == res.total - else: - assert 1 == res["total_results"] + info = client.ft().info() + assert 494 == int(info["num_docs"]) + client.ft().client.hset("hset:1", "f1", "v1") + info = client.ft().info() + assert 495 == int(info["num_docs"]) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.0.0", "search") + def test_create_client_definition_hash(self, client): + """ + Create definition with IndexType.HASH as index type (ON HASH), + and use hset to test the client definition. + """ + definition = IndexDefinition( + prefix=["hset:", "henry"], index_type=IndexType.HASH + ) + self.createIndex(client.ft(), num_docs=500, definition=definition) -@pytest.mark.redismod -def test_spell_check(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) + info = client.ft().info() + assert 494 == int(info["num_docs"]) - client.hset( - "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} - ) - client.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + client.ft().client.hset("hset:1", "f1", "v1") + info = client.ft().info() + assert 495 == int(info["num_docs"]) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.2.0", "search") + def test_create_client_definition_json(self, client): + """ + Create definition with IndexType.JSON as index type (ON JSON), + and use json client to test it. 
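# Compact sketch of the ON JSON flow the docstring above describes: the
# schema addresses JSONPath expressions, and under RESP3 the whole document
# comes back as the "$" attribute (see the assertion in the test below).

definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON)
client.ft().create_index((TextField("$.name"),), definition=definition)
client.json().set("king:1", Path.root_path(), {"name": "henry"})
res = client.ft().search("henry")  # matches king:1 via the $.name text field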
+ """ + definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) + client.ft().create_index((TextField("$.name"),), definition=definition) + + client.json().set("king:1", Path.root_path(), {"name": "henry"}) + client.json().set("king:2", Path.root_path(), {"name": "james"}) + + res = client.ft().search("henry") + if is_resp2_connection(client): + assert res.docs[0].id == "king:1" + assert res.docs[0].payload is None + assert res.docs[0].json == '{"name":"henry"}' + assert res.total == 1 + else: + assert res["results"][0]["id"] == "king:1" + assert res["results"][0]["extra_attributes"]["$"] == '{"name":"henry"}' + assert res["total_results"] == 1 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.2.0", "search") + def test_fields_as_name(self, client): + # create index + SCHEMA = ( + TextField("$.name", sortable=True, as_name="name"), + NumericField("$.age", as_name="just_a_number"), + ) + definition = IndexDefinition(index_type=IndexType.JSON) + client.ft().create_index(SCHEMA, definition=definition) - if is_resp2_connection(client): - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {} - else: - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" in res["results"]["impornant"][0].keys() - - res = client.ft().spellcheck("contnt") - assert "content" in res["results"]["contnt"][0].keys() - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {"results": {"vlis": []}} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" in res["results"]["vlis"][0].keys() - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["results"]["lorm"]) == 3 - assert "lorem" in res["results"]["lorm"][0].keys() - assert "lore" in res["results"]["lorm"][1].keys() - assert "lorm" in res["results"]["lorm"][2].keys() - assert ( - res["results"]["lorm"][0]["lorem"], - res["results"]["lorm"][1]["lore"], - ) == (0.5, 0) - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {"results": {}} - - -@pytest.mark.redismod -def test_dict_operations(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - # Add three items - res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") - assert 3 == res - - # Remove one item - res = client.ft().dict_del("custom_dict", "item2") - assert 1 == res - - # Dump dict and inspect content - res = client.ft().dict_dump("custom_dict") - assert res == ["item1", "item3"] - - # Remove rest of the items before reload - client.ft().dict_del("custom_dict", *res) - - 
-@pytest.mark.redismod -def test_phonetic_matcher(client): - client.ft().create_index((TextField("name"),)) - client.hset("doc1", mapping={"name": "Jon"}) - client.hset("doc2", mapping={"name": "John"}) - - res = client.ft().search(Query("Jon")) - if is_resp2_connection(client): - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name - else: - assert 1 == res["total_results"] - assert "Jon" == res["results"][0]["extra_attributes"]["name"] + # insert json data + res = client.json().set("doc:1", Path.root_path(), {"name": "Jon", "age": 25}) + assert res - # Drop and create index with phonetic matcher - client.flushdb() + res = client.ft().search(Query("Jon").return_fields("name", "just_a_number")) + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "doc:1" == res.docs[0].id + assert "Jon" == res.docs[0].name + assert "25" == res.docs[0].just_a_number + else: + assert 1 == len(res["results"]) + assert "doc:1" == res["results"][0]["id"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] + assert "25" == res["results"][0]["extra_attributes"]["just_a_number"] - client.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - client.hset("doc1", mapping={"name": "Jon"}) - client.hset("doc2", mapping={"name": "John"}) + @pytest.mark.redismod + def test_casesensitive(self, client): + # create index + SCHEMA = (TagField("t", case_sensitive=False),) + client.ft().create_index(SCHEMA) + client.ft().client.hset("1", "t", "HELLO") + client.ft().client.hset("2", "t", "hello") - res = client.ft().search(Query("Jon")) - if is_resp2_connection(client): - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) - else: - assert 2 == res["total_results"] - assert ["John", "Jon"] == sorted( - d["extra_attributes"]["name"] for d in res["results"] - ) - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ -@skip_ifmodversion_lt("2.8.0", "search") -@skip_if_server_version_gte("7.9.0") -def test_scorer(client): - client.ft().create_index((TextField("description"),)) - - client.hset( - "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} - ) - client.hset( - "doc2", - mapping={ - "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa - }, - ) - - # default scorer is TFIDF - if is_resp2_connection(client): - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.14285714285714285 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score - else: - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.14285714285714285 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res["results"][0]["score"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_lt("7.9.0") -def test_scorer_with_new_default_scorer(client): - client.ft().create_index((TextField("description"),)) - - client.hset( - "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} - ) - client.hset( - "doc2", - mapping={ - "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa - }, - ) - - # default scorer is BM25STD - if is_resp2_connection(client): - res = client.ft().search(Query("quick").with_scores()) - assert 0.23 == pytest.approx(res.docs[0].score, 0.05) - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.14285714285714285 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score - else: - res = client.ft().search(Query("quick").with_scores()) - assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05) - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.14285714285714285 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.22471909420069797 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res["results"][0]["score"] - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res["results"][0]["score"] - - -@pytest.mark.redismod -def test_get(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - - assert [None] == client.ft().get("doc1") - assert [None, None] == client.ft().get("doc2", "doc1") - - client.hset( - "doc1", mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"} - ) - client.hset( - "doc2", mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"} - ) - - assert [ - ["f1", "some valid content dd2", "f2", "this is sample text f2"] - ] == client.ft().get("doc2") - assert [ - ["f1", "some valid content dd1", "f2", "this is sample text f1"], - ["f1", "some valid content dd2", "f2", "this is sample text f2"], - ] == client.ft().get("doc1", "doc2") - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_ifmodversion_lt("2.2.0", "search") -@skip_if_server_version_gte("7.9.0") -def test_config(client): - assert client.ft().config_set("TIMEOUT", "100") - with pytest.raises(redis.ResponseError): - client.ft().config_set("TIMEOUT", "null") - res = client.ft().config_get("*") - assert "100" == res["TIMEOUT"] - res = client.ft().config_get("TIMEOUT") - assert "100" == res["TIMEOUT"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_lt("7.9.0") -def test_config_with_removed_ftconfig(client): - assert client.config_set("timeout", "100") - with pytest.raises(redis.ResponseError): - client.config_set("timeout", "null") - res = client.config_get("*") - assert "100" == res["timeout"] - res = client.config_get("timeout") - assert "100" == res["timeout"] - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -def test_aggregations_groupby(client): - # Creating the index definition and schema - client.ft().create_index( - ( - NumericField("random_num"), - TextField("title"), - 
TextField("body"), - TextField("parent"), - ) - ) - - # Indexing a document - client.hset( - "search", - mapping={ - "title": "RediSearch", - "body": "Redisearch impements a search engine on top of redis", - "parent": "redis", - "random_num": 10, - }, - ) - client.hset( - "ai", - mapping={ - "title": "RedisAI", - "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa - "parent": "redis", - "random_num": 3, - }, - ) - client.hset( - "json", - mapping={ - "title": "RedisJson", - "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa - "parent": "redis", - "random_num": 8, - }, - ) + res = client.ft().search("@t:{HELLO}") - if is_resp2_connection(client): - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count() - ) + if is_resp2_connection(client): + assert 2 == len(res.docs) + assert "1" == res.docs[0].id + assert "2" == res.docs[1].id + else: + assert 2 == len(res["results"]) + assert "1" == res["results"][0]["id"] + assert "2" == res["results"][1]["id"] - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + # create casesensitive index + client.ft().dropindex() + SCHEMA = (TagField("t", case_sensitive=True),) + client.ft().create_index(SCHEMA) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinct("@title") + res = client.ft().search("@t:{HELLO}") + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "1" == res.docs[0].id + else: + assert 1 == len(res["results"]) + assert "1" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.2.0", "search") + def test_search_return_fields(self, client): + res = client.json().set( + "doc:1", + Path.root_path(), + {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, ) + assert res - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + # create index on + definition = IndexDefinition(index_type=IndexType.JSON) + SCHEMA = (TextField("$.t"), NumericField("$.flt")) + client.ft().create_index(SCHEMA, definition=definition) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinctish("@title") - ) + if is_resp2_connection(client): + total = ( + client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs + ) + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "riceratops" == total[0].txt - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + total = ( + client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs + ) + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "telmatosaurus" == total[0].txt + else: + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "riceratops" == total["results"][0]["extra_attributes"]["txt"] + + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"] + + @pytest.mark.redismod + @skip_if_resp_version(3) + def test_binary_and_text_fields(self, client): + fake_vec = np.array([0.1, 0.2, 
0.3, 0.4], dtype=np.float32) + + index_name = "mixed_index" + mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()} + client.hset(f"{index_name}:1", mapping=mixed_data) + + schema = ( + TagField("first_name"), + VectorField( + "embeddings_bio", + algorithm="HNSW", + attributes={ + "TYPE": "FLOAT32", + "DIM": 4, + "DISTANCE_METRIC": "COSINE", + }, + ), + ) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.sum("@random_num") + client.ft(index_name).create_index( + fields=schema, + definition=IndexDefinition( + prefix=[f"{index_name}:"], index_type=IndexType.HASH + ), ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + self.waitForIndex(client, index_name) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.min("@random_num") + query = ( + Query("*") + .return_field("vector_emb", decode_field=False) + .return_field("first_name") + ) + docs = client.ft(index_name).search(query=query, query_params={}).docs + decoded_vec_from_search_results = np.frombuffer( + docs[0]["vector_emb"], dtype=np.float32 ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.max("@random_num") + assert np.array_equal(decoded_vec_from_search_results, fake_vec), ( + "The vectors are not equal" ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + assert docs[0]["first_name"] == mixed_data["first_name"], ( + "The text field is not decoded correctly" + ) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.avg("@random_num") + @pytest.mark.redismod + def test_synupdate(self, client): + definition = IndexDefinition(index_type=IndexType.HASH) + client.ft().create_index( + (TextField("title"), TextField("body")), definition=definition ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - index = res.index("__generated_aliasavgrandom_num") - assert res[index + 1] == "7" # (10+3+8)/3 + client.ft().synupdate("id1", True, "boy", "child", "offspring") + client.hset("doc1", mapping={"title": "he is a baby", "body": "this is a test"}) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.stddev("random_num") + client.ft().synupdate("id1", True, "baby") + client.hset( + "doc2", mapping={"title": "he is another baby", "body": "another test"} ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = client.ft().search(Query("child").expander("SYNONYM")) + if is_resp2_connection(client): + assert res.docs[0].id == "doc2" + assert res.docs[0].title == "he is another baby" + assert res.docs[0].body == "another test" + else: + assert res["results"][0]["id"] == "doc2" + assert ( + res["results"][0]["extra_attributes"]["title"] == "he is another baby" + ) + assert res["results"][0]["extra_attributes"]["body"] == "another test" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.quantile("@random_num", 0.5) + @pytest.mark.redismod + def test_syndump(self, client): + definition = IndexDefinition(index_type=IndexType.HASH) + client.ft().create_index( + (TextField("title"), TextField("body")), definition=definition ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + client.ft().synupdate("id1", False, "boy", 
"child", "offspring") + client.ft().synupdate("id2", False, "baby", "child") + client.ft().synupdate("id3", False, "tree", "wood") + res = client.ft().syndump() + assert res == { + "boy": ["id1"], + "tree": ["id3"], + "wood": ["id3"], + "child": ["id1", "id2"], + "baby": ["id2"], + "offspring": ["id1"], + } - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.tolist("@title") - ) + @pytest.mark.redismod + def test_expire_while_search(self, client: redis.Redis): + client.ft().create_index((TextField("txt"),)) + client.hset("hset:1", "txt", "a") + client.hset("hset:2", "txt", "b") + client.hset("hset:3", "txt", "c") + if is_resp2_connection(client): + assert 3 == client.ft().search(Query("*")).total + client.pexpire("hset:2", 300) + for _ in range(500): + client.ft().search(Query("*")).docs[1] + time.sleep(1) + assert 2 == client.ft().search(Query("*")).total + else: + assert 3 == client.ft().search(Query("*"))["total_results"] + client.pexpire("hset:2", 300) + for _ in range(500): + client.ft().search(Query("*"))["results"][1] + time.sleep(1) + assert 2 == client.ft().search(Query("*"))["total_results"] + + @pytest.mark.redismod + @pytest.mark.experimental + def test_withsuffixtrie(self, client: redis.Redis): + # create index + assert client.ft().create_index((TextField("txt"),)) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + if is_resp2_connection(client): + info = client.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert client.ft().dropindex() + + # create withsuffixtrie index (text fields) + assert client.ft().create_index(TextField("t", withsuffixtrie=True)) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert client.ft().dropindex() + + # create withsuffixtrie index (tag field) + assert client.ft().create_index(TagField("t", withsuffixtrie=True)) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + info = client.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert client.ft().dropindex() + + # create withsuffixtrie index (text fields) + assert client.ft().create_index(TextField("t", withsuffixtrie=True)) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert client.ft().dropindex() + + # create withsuffixtrie index (tag field) + assert client.ft().create_index(TagField("t", withsuffixtrie=True)) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + + @pytest.mark.redismod + def test_query_timeout(self, r: redis.Redis): + q1 = Query("foo").timeout(5000) + assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] + q1 = Query("foo").timeout(0) + assert q1.get_args() == ["foo", "TIMEOUT", 0, "DIALECT", 2, "LIMIT", 0, 10] + q2 = Query("foo").timeout("not_a_number") + with pytest.raises(redis.ResponseError): + r.ft().search(q2) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.2.0") + @skip_ifmodversion_lt("2.8.4", "search") + def test_geoshape(self, client: redis.Redis): + client.ft().create_index(GeoShapeField("geom", GeoShapeField.FLAT)) + self.waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + 
client.hset("small", "geom", "POLYGON((1 1, 1 100, 100 100, 100 1, 1 1))") + client.hset("large", "geom", "POLYGON((1 1, 1 200, 200 200, 200 1, 1 1))") + q1 = Query("@geom:[WITHIN $poly]").dialect(3) + qp1 = {"poly": "POLYGON((0 0, 0 150, 150 150, 150 0, 0 0))"} + q2 = Query("@geom:[CONTAINS $poly]").dialect(3) + qp2 = {"poly": "POLYGON((2 2, 2 50, 50 50, 50 2, 2 2))"} + result = client.ft().search(q1, query_params=qp1) + _assert_search_result(client, result, ["small"]) + result = client.ft().search(q2, query_params=qp2) + _assert_search_result(client, result, ["small", "large"]) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.4.0") + @skip_ifmodversion_lt("2.10.0", "search") + def test_search_missing_fields(self, client): + definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) + + fields = [ + TextField("title", sortable=True), + TagField("features", index_missing=True), + TextField("description", index_missing=True), + ] - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + client.ft().create_index(fields, definition=definition) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.first_value("@title").alias("first") + # All fields present + client.hset( + "property:1", + mapping={ + "title": "Luxury Villa in Malibu", + "features": "pool,sea view,modern", + "description": "A stunning modern villa overlooking the Pacific Ocean.", + }, ) - res = client.ft().aggregate(req).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.random_sample("@title", 2).alias("random") + # Missing features + client.hset( + "property:2", + mapping={ + "title": "Downtown Flat", + "description": "Modern flat in central Paris with easy access to metro.", + }, ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] - else: - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count() + # Missing description + client.hset( + "property:3", + mapping={ + "title": "Beachfront Bungalow", + "features": "beachfront,sun deck", + }, ) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliascount"] == "3" + with pytest.raises(redis.exceptions.ResponseError): + client.ft().search( + Query("ismissing(@title)").return_field("id").no_content() + ) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinct("@title") + res = client.ft().search( + Query("ismissing(@features)").return_field("id").no_content() ) + _assert_search_result(client, res, ["property:2"]) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinctish("@title") + res = client.ft().search( + Query("-ismissing(@features)").return_field("id").no_content() ) + _assert_search_result(client, res, ["property:1", "property:3"]) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliascount_distinctishtitle"] == "3" + res = 
client.ft().search( + Query("ismissing(@description)").return_field("id").no_content() + ) + _assert_search_result(client, res, ["property:3"]) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.sum("@random_num") + res = client.ft().search( + Query("-ismissing(@description)").return_field("id").no_content() ) + _assert_search_result(client, res, ["property:1", "property:2"]) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.4.0") + @skip_ifmodversion_lt("2.10.0", "search") + def test_create_index_empty_or_missing_fields_with_sortable(self, client): + definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) + + fields = [ + TextField("title", sortable=True, index_empty=True), + TagField("features", index_missing=True, sortable=True), + TextField("description", no_index=True, sortable=True), + ] - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + client.ft().create_index(fields, definition=definition) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.min("@random_num") - ) + @pytest.mark.redismod + @skip_if_server_version_lt("7.4.0") + @skip_ifmodversion_lt("2.10.0", "search") + def test_search_empty_fields(self, client): + definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) + + fields = [ + TextField("title", sortable=True), + TagField("features", index_empty=True), + TextField("description", index_empty=True), + ] - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" + client.ft().create_index(fields, definition=definition) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.max("@random_num") + # All fields present + client.hset( + "property:1", + mapping={ + "title": "Luxury Villa in Malibu", + "features": "pool,sea view,modern", + "description": "A stunning modern villa overlooking the Pacific Ocean.", + }, ) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + # Empty features + client.hset( + "property:2", + mapping={ + "title": "Downtown Flat", + "features": "", + "description": "Modern flat in central Paris with easy access to metro.", + }, + ) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.avg("@random_num") + # Empty description + client.hset( + "property:3", + mapping={ + "title": "Beachfront Bungalow", + "features": "beachfront,sun deck", + "description": "", + }, ) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + with pytest.raises(redis.exceptions.ResponseError) as e: + client.ft().search(Query("@title:''").return_field("id").no_content()) + assert "Use `INDEXEMPTY` in field creation" in e.value.args[0] - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.stddev("random_num") + res = client.ft().search( + Query("@features:{$empty}").return_field("id").no_content(), + query_params={"empty": ""}, ) + _assert_search_result(client, res, ["property:2"]) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert ( - 
res["extra_attributes"]["__generated_aliasstddevrandom_num"] - == "3.60555127546" + res = client.ft().search( + Query("-@features:{$empty}").return_field("id").no_content(), + query_params={"empty": ""}, ) + _assert_search_result(client, res, ["property:1", "property:3"]) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.quantile("@random_num", 0.5) + res = client.ft().search( + Query("@description:''").return_field("id").no_content() ) + _assert_search_result(client, res, ["property:3"]) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] == "8" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.tolist("@title") + res = client.ft().search( + Query("-@description:''").return_field("id").no_content() ) + _assert_search_result(client, res, ["property:1", "property:2"]) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.4.0") + @skip_ifmodversion_lt("2.10.0", "search") + def test_special_characters_in_fields(self, client): + definition = IndexDefinition(prefix=["resource:"], index_type=IndexType.HASH) + + fields = [ + TagField("uuid"), + TagField("tags", separator="|"), + TextField("description"), + NumericField("rating"), + ] - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { - "RediSearch", - "RedisAI", - "RedisJson", - } + client.ft().create_index(fields, definition=definition) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.first_value("@title").alias("first") + client.hset( + "resource:1", + mapping={ + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "tags": "finance|crypto|$btc|blockchain", + "description": "Analysis of blockchain technologies & Bitcoin's potential.", + "rating": 5, + }, ) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"} + client.hset( + "resource:2", + mapping={ + "uuid": "987e6543-e21c-12d3-a456-426614174999", + "tags": "health|well-being|fitness|new-year's-resolutions", + "description": "Health trends for the new year, including fitness regimes.", + "rating": 4, + }, + ) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.random_sample("@title", 2).alias("random") + # no need to escape - when using params + res = client.ft().search( + Query("@uuid:{$uuid}"), + query_params={"uuid": "123e4567-e89b-12d3-a456-426614174000"}, ) + _assert_search_result(client, res, ["resource:1"]) - res = client.ft().aggregate(req)["results"][0] - assert res["extra_attributes"]["parent"] == "redis" - assert "random" in res["extra_attributes"].keys() - assert len(res["extra_attributes"]["random"]) == 2 - assert res["extra_attributes"]["random"][0] in [ - "RediSearch", - "RedisAI", - "RedisJson", - ] + # with double quotes exact match no need to escape the - even without params + res = client.ft().search( + Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}') + ) + _assert_search_result(client, res, ["resource:1"]) + + res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}')) + _assert_search_result(client, res, ["resource:2"]) + + # possible to search numeric fields by single value + res = client.ft().search(Query("@rating:[4]")) + _assert_search_result(client, res, ["resource:2"]) + + # some chars still need escaping + res = 
client.ft().search(Query(r"@tags:{\$btc}")) + _assert_search_result(client, res, ["resource:1"]) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_vector_search_with_default_dialect(self, client): + client.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") -@pytest.mark.redismod -def test_aggregations_sort_by_and_limit(client): - client.ft().create_index((TextField("t1"), TextField("t2"))) + query = "*=>[KNN 2 @v $vec]" + q = Query(query) - client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) - client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) + assert "DIALECT" in q.get_args() + assert 2 in q.get_args() - if is_resp2_connection(client): - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_search_query_with_different_dialects(self, client): + client.ft().create_index( + (TextField("name"), TextField("lastname")), + definition=IndexDefinition(prefix=["test:"]), ) - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + client.hset("test:1", "name", "James") + client.hset("test:1", "lastname", "Brown") - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 + # Query with default DIALECT 2 + query = "@name: James Brown" + q = Query(query) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 1 + else: + assert res["total_results"] == 1 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] - else: - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = client.ft().aggregate(req)["results"] - assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} - assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} - - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = client.ft().aggregate(req)["results"] - assert res[0]["extra_attributes"] == {"t1": "a"} - assert res[1]["extra_attributes"] == {"t1": "b"} - - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = client.ft().aggregate(req) - assert len(res["results"]) == 1 + # Query with explicit DIALECT 1 + query = "@name: James Brown" + q = Query(query).dialect(1) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = client.ft().aggregate(req) - assert len(res["results"]) == 1 - assert 
res["results"][0]["extra_attributes"] == {"t1": "b"} + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_info_exposes_search_info(self, client): + assert len(client.info("search")) > 0 -@pytest.mark.redismod -def test_aggregations_load(client): - client.ft().create_index((TextField("t1"), TextField("t2"))) +class TestScorers(SearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + # NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ + @skip_ifmodversion_lt("2.8.0", "search") + @skip_if_server_version_gte("7.9.0") + def test_scorer(self, client): + client.ft().create_index((TextField("description"),)) - client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) + client.hset( + "doc1", + mapping={"description": "The quick brown fox jumps over the lazy dog"}, + ) + client.hset( + "doc2", + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa + }, + ) - if is_resp2_connection(client): - # load t1 - req = aggregations.AggregateRequest("*").load("t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello"] + # default scorer is TFIDF + if is_resp2_connection(client): + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + else: + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res["results"][0]["score"] + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.9.0") + def test_scorer_with_new_default_scorer(self, client): + client.ft().create_index((TextField("description"),)) - # load t2 - req = aggregations.AggregateRequest("*").load("t2") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "world"] + client.hset( + "doc1", + mapping={"description": "The quick brown fox jumps over the lazy dog"}, + ) + client.hset( + "doc2", + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa + }, + ) - # load all - req = aggregations.AggregateRequest("*").load() - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello", "t2", "world"] - else: - # load t1 - req = aggregations.AggregateRequest("*").load("t1") - res = client.ft().aggregate(req) - assert res["results"][0]["extra_attributes"] == {"t1": "hello"} + # default scorer is BM25STD + if is_resp2_connection(client): + res = client.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res.docs[0].score, 0.05) + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + else: + res = client.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05) + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res["results"][0]["score"] + + +class TestConfig(SearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_ifmodversion_lt("2.2.0", "search") + @skip_if_server_version_gte("7.9.0") + def test_config(self, client): + assert client.ft().config_set("TIMEOUT", "100") + with pytest.raises(redis.ResponseError): + client.ft().config_set("TIMEOUT", "null") + res = client.ft().config_get("*") + assert "100" == res["TIMEOUT"] + res = client.ft().config_get("TIMEOUT") + assert "100" == res["TIMEOUT"] + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.9.0") + def test_config_with_removed_ftconfig(self, client): + assert client.config_set("timeout", "100") + with pytest.raises(redis.ResponseError): + client.config_set("timeout", "null") + res = client.config_get("*") + assert "100" == res["timeout"] + res = client.config_get("timeout") + assert "100" == res["timeout"] + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_ifmodversion_lt("2.4.3", "search") + def test_dialect_config(self, client): + assert client.ft().config_get("DEFAULT_DIALECT") + client.ft().config_set("DEFAULT_DIALECT", 2) + assert client.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} + with pytest.raises(redis.ResponseError): + client.ft().config_set("DEFAULT_DIALECT", 0) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_dialect(self, client): + client.ft().create_index( + ( + TagField("title"), 
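+                # a mix of tag, text, numeric and vector fields, so the
+                # dialect checks below can exercise each syntax form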
+ TextField("t1"), + TextField("t2"), + NumericField("num"), + VectorField( + "v", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 1, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + client.hset("h", "t1", "hello") + with pytest.raises(redis.ResponseError) as err: + client.ft().explain(Query("(*)").dialect(1)) + assert "Syntax error" in str(err) + assert "WILDCARD" in client.ft().explain(Query("(*)")) + + with pytest.raises(redis.ResponseError) as err: + client.ft().explain(Query("$hello").dialect(1)) + assert "Syntax error" in str(err) + q = Query("$hello") + expected = "UNION {\n hello\n +hello(expanded)\n}\n" + assert expected in client.ft().explain(q, query_params={"hello": "hello"}) + + expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n" + assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) + with pytest.raises(redis.ResponseError) as err: + client.ft().explain(Query("@title:(@num:[0 10])")) + assert "Syntax error" in str(err) + + +class TestAggregations(SearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + def test_aggregations_groupby(self, client): + # Creating the index definition and schema + client.ft().create_index( + ( + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ) + ) - # load t2 - req = aggregations.AggregateRequest("*").load("t2") - res = client.ft().aggregate(req) - assert res["results"][0]["extra_attributes"] == {"t2": "world"} + # Indexing a document + client.hset( + "search", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, + ) + client.hset( + "ai", + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + client.hset( + "json", + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, + ) - # load all - req = aggregations.AggregateRequest("*").load() - res = client.ft().aggregate(req) - assert res["results"][0]["extra_attributes"] == {"t1": "hello", "t2": "world"} - - -@pytest.mark.redismod -def test_aggregations_apply(client): - client.ft().create_index( - ( - TextField("PrimaryKey", sortable=True), - NumericField("CreatedDateTimeUTC", sortable=True), - ) - ) - - client.ft().client.hset( - "doc1", - mapping={"PrimaryKey": "9::362330", "CreatedDateTimeUTC": "637387878524969984"}, - ) - client.ft().client.hset( - "doc2", - mapping={"PrimaryKey": "9::362329", "CreatedDateTimeUTC": "637387875859270016"}, - ) - - req = aggregations.AggregateRequest("*").apply( - CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" - ) - res = client.ft().aggregate(req) - if is_resp2_connection(client): - res_set = {res.rows[0][1], res.rows[1][1]} - assert res_set == {"6373878785249699840", "6373878758592700416"} - else: - res_set = { - res["results"][0]["extra_attributes"]["CreatedDateTimeUTC"], - res["results"][1]["extra_attributes"]["CreatedDateTimeUTC"], - } - assert res_set == {"6373878785249699840", "6373878758592700416"} + if is_resp2_connection(client): + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) 
-@pytest.mark.redismod -def test_aggregations_filter(client): - client.ft().create_index( - (TextField("name", sortable=True), NumericField("age", sortable=True)) - ) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) - client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + index = res.index("__generated_aliasavgrandom_num") + assert res[index + 1] == "7" # (10+3+8)/3 + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) + + res = client.ft().aggregate(req).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" + ) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + 
res["extra_attributes"]["__generated_aliascount_distinctishtitle"] + == "3" + ) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasstddevrandom_num"] + == "3.60555127546" + ) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] + == "8" + ) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"} + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert "random" in res["extra_attributes"].keys() + assert len(res["extra_attributes"]["random"]) == 2 + assert res["extra_attributes"]["random"][0] in [ + "RediSearch", + "RedisAI", + "RedisJson", + ] + + @pytest.mark.redismod + def test_aggregations_sort_by_and_limit(self, client): + client.ft().create_index((TextField("t1"), TextField("t2"))) + + client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) + client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - for dialect in [1, 2]: - req = ( - aggregations.AggregateRequest("*") - .filter("@name=='foo' && @age < 20") - .dialect(dialect) - ) - res = client.ft().aggregate(req) if is_resp2_connection(client): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", 
"t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) assert len(res.rows) == 1 - assert res.rows[0] == ["name", "foo", "age", "19"] - req = ( - aggregations.AggregateRequest("*") - .filter("@age > 15") - .sort_by("@age") - .dialect(dialect) - ) + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["age", "19"] - assert res.rows[1] == ["age", "25"] + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req)["results"] + assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} + assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req)["results"] + assert res[0]["extra_attributes"] == {"t1": "a"} + assert res[1]["extra_attributes"] == {"t1": "b"} + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) assert len(res["results"]) == 1 - assert res["results"][0]["extra_attributes"] == {"name": "foo", "age": "19"} + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"t1": "b"} + + @pytest.mark.redismod + def test_aggregations_load(self, client): + client.ft().create_index((TextField("t1"), TextField("t2"))) + + client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) + + if is_resp2_connection(client): + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello"] + + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "world"] + + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello", "t2", "world"] + else: + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t1": "hello"} + + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t2": "world"} + + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == { + "t1": "hello", + "t2": "world", + } + + @pytest.mark.redismod + def test_aggregations_apply(self, client): + client.ft().create_index( + ( + TextField("PrimaryKey", sortable=True), + NumericField("CreatedDateTimeUTC", sortable=True), + ) + ) + + client.ft().client.hset( + "doc1", + mapping={ + "PrimaryKey": "9::362330", + "CreatedDateTimeUTC": "637387878524969984", + }, + ) + client.ft().client.hset( + "doc2", + mapping={ + "PrimaryKey": "9::362329", + "CreatedDateTimeUTC": 
"637387875859270016", + }, + ) + + req = aggregations.AggregateRequest("*").apply( + CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" + ) + res = client.ft().aggregate(req) + if is_resp2_connection(client): + res_set = {res.rows[0][1], res.rows[1][1]} + assert res_set == {"6373878785249699840", "6373878758592700416"} + else: + res_set = { + res["results"][0]["extra_attributes"]["CreatedDateTimeUTC"], + res["results"][1]["extra_attributes"]["CreatedDateTimeUTC"], + } + assert res_set == {"6373878785249699840", "6373878758592700416"} + + @pytest.mark.redismod + def test_aggregations_filter(self, client): + client.ft().create_index( + (TextField("name", sortable=True), NumericField("age", sortable=True)) + ) + + client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) + client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) + + for dialect in [1, 2]: req = ( aggregations.AggregateRequest("*") - .filter("@age > 15") - .sort_by("@age") + .filter("@name=='foo' && @age < 20") .dialect(dialect) ) res = client.ft().aggregate(req) + if is_resp2_connection(client): + assert len(res.rows) == 1 + assert res.rows[0] == ["name", "foo", "age", "19"] + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["age", "19"] + assert res.rows[1] == ["age", "25"] + else: + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == { + "name": "foo", + "age": "19", + } + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"age": "19"} + assert res["results"][1]["extra_attributes"] == {"age": "25"} + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.10.05", "search") + def test_aggregations_add_scores(self, client): + client.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + NumericField("age", sortable=True), + ) + ) + + client.hset("doc1", mapping={"name": "bar", "age": "25"}) + client.hset("doc2", mapping={"name": "foo", "age": "19"}) + + req = aggregations.AggregateRequest("*").add_scores() + res = client.ft().aggregate(req) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} + assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} + else: + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.10.05", "search") + async def test_aggregations_hybrid_scoring(self, client): + client.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + TextField("description", sortable=True, weight=5.0), + VectorField( + "vector", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + client.hset( + "doc1", + mapping={ + "name": "cat book", + "description": "an animal book about cats", + "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), + }, + ) + client.hset( + "doc2", + mapping={ + "name": "dog book", + "description": "an animal book about dogs", + "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), + }, + ) + + query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" + req = ( + aggregations.AggregateRequest(query_string) + .scorer("BM25") + 
+ res = client.ft().aggregate( + req, + query_params={ + "vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes() + }, + ) + + if isinstance(res, dict): assert len(res["results"]) == 2 - assert res["results"][0]["extra_attributes"] == {"age": "19"} - assert res["results"][1]["extra_attributes"] == {"age": "25"} + else: + assert len(res.rows) == 2 + for row in res.rows: + assert len(row) == 6 + + +class TestSearchWithJsonIndex(SearchTestsBase): + @pytest.mark.redismod + @skip_ifmodversion_lt("2.2.0", "search") + def test_create_json_with_alias(self, client): + """ + Create definition with IndexType.JSON as index type (ON JSON) with two + fields with aliases, and use json client to test it. + """ + definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) + client.ft().create_index( + (TextField("$.name", as_name="name"), NumericField("$.num", as_name="num")), + definition=definition, + ) + client.json().set("king:1", Path.root_path(), {"name": "henry", "num": 42}) + client.json().set("king:2", Path.root_path(), {"name": "james", "num": 3.14}) - if is_resp2_connection(client): + res = client.ft().search("@name:henry") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","num":42}' + assert res.total == 1 + + res = client.ft().search("@num:[0 10]") + assert res.docs[0].id == "king:2" + assert res.docs[0].json == '{"name":"james","num":3.14}' + assert res.total == 1 + else: + res = client.ft().search("@name:henry") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","num":42}' + ) + assert res["total_results"] == 1 + + res = client.ft().search("@num:[0 10]") + assert res["results"][0]["id"] == "king:2" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"james","num":3.14}' + ) + assert res["total_results"] == 1 + + # Test that an error is returned if the path contains special characters + # (the user should use an alias) + with pytest.raises(Exception): + client.ft().search("@$.name:henry") + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.2.0", "search") + def test_json_with_multipath(self, client): + """ + Create definition with IndexType.JSON as index type (ON JSON), + and use json client to test it.
+ """ + definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) + client.ft().create_index( + (TagField("$..name", as_name="name")), definition=definition ) - ) - client.hset("doc1", mapping={"name": "bar", "age": "25"}) - client.hset("doc2", mapping={"name": "foo", "age": "19"}) + client.json().set( + "king:1", + Path.root_path(), + {"name": "henry", "country": {"name": "england"}}, + ) - req = aggregations.AggregateRequest("*").add_scores() - res = client.ft().aggregate(req) + if is_resp2_connection(client): + res = client.ft().search("@name:{henry}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + + res = client.ft().search("@name:{england}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + else: + res = client.ft().search("@name:{henry}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 - if isinstance(res, dict): - assert len(res["results"]) == 2 - assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} - assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} - else: - assert len(res.rows) == 2 - assert res.rows[0] == ["__score", "0.2"] - assert res.rows[1] == ["__score", "0.2"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.10.05", "search") -async def test_aggregations_hybrid_scoring(client): - client.ft().create_index( - ( - TextField("name", sortable=True, weight=5.0), - TextField("description", sortable=True, weight=5.0), - VectorField( - "vector", - "HNSW", - {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + res = client.ft().search("@name:{england}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.2.0", "search") + def test_json_with_jsonpath(self, client): + definition = IndexDefinition(index_type=IndexType.JSON) + client.ft().create_index( + ( + TextField('$["prod:name"]', as_name="name"), + TextField("$.prod:name", as_name="name_unsupported"), ), + definition=definition, ) - ) - - client.hset( - "doc1", - mapping={ - "name": "cat book", - "description": "an animal book about cats", - "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), - }, - ) - client.hset( - "doc2", - mapping={ - "name": "dog book", - "description": "an animal book about dogs", - "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), - }, - ) - - query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" - req = ( - aggregations.AggregateRequest(query_string) - .scorer("BM25") - .add_scores() - .apply(hybrid_score="@__score + @dist") - .load("*") - .dialect(4) - ) - - res = client.ft().aggregate( - req, - query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()}, - ) - - if isinstance(res, dict): - assert len(res["results"]) == 2 - else: - assert len(res.rows) == 2 - for row in res.rows: - len(row) == 6 + client.json().set("doc:1", Path.root_path(), {"prod:name": "RediSearch"}) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_index_definition(client): - """ - Create definition and test its args - """ - with pytest.raises(RuntimeError): - 
IndexDefinition(prefix=["hset:", "henry"], index_type="json") - - definition = IndexDefinition( - prefix=["hset:", "henry"], - filter="@f1==32", - language="English", - language_field="play", - score_field="chapter", - score=0.5, - payload_field="txt", - index_type=IndexType.JSON, - ) - - assert [ - "ON", - "JSON", - "PREFIX", - 2, - "hset:", - "henry", - "FILTER", - "@f1==32", - "LANGUAGE_FIELD", - "play", - "LANGUAGE", - "English", - "SCORE_FIELD", - "chapter", - "SCORE", - 0.5, - "PAYLOAD_FIELD", - "txt", - ] == definition.args - - createIndex(client.ft(), num_docs=500, definition=definition) - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_redis_enterprise() -@skip_if_server_version_gte("7.9.0") -def test_expire(client): - client.ft().create_index((TextField("txt", sortable=True),), temporary=4) - ttl = client.execute_command("ft.debug", "TTL", "idx") - assert ttl > 2 - - while ttl > 2: - ttl = client.execute_command("ft.debug", "TTL", "idx") - time.sleep(0.01) + if is_resp2_connection(client): + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].json == '{"prod:name":"RediSearch"}' + + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res.total == 1 + + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].name == "RediSearch" + else: + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"prod:name":"RediSearch"}' + ) + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res["total_results"] == 1 + + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert res["results"][0]["extra_attributes"]["name"] == "RediSearch" + + +class TestProfile(SearchTestsBase): + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_redis_enterprise() + @skip_if_server_version_gte("7.9.0") + @skip_if_server_version_lt("6.3.0") + def test_profile(self, client): + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "world") + + # check using Query + q = Query("hello|world").no_content() + if is_resp2_connection(client): + res, det = client.ft().profile(q) + det = det.info -@pytest.mark.redismod -def test_skip_initial_scan(client): - client.hset("doc1", "foo", "bar") - q = Query("@foo:bar") + assert isinstance(det, list) + assert len(res.docs) == 2 # check also the search result - client.ft().create_index((TextField("foo"),), skip_initial_scan=True) - res = client.ft().search(q) - if is_resp2_connection(client): - assert res.total == 0 - else: - assert res["total_results"] == 0 + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res, det = client.ft().profile(req) + det = det.info + assert isinstance(det, list) + assert len(res.rows) == 2 # check also the search result + else: + res = client.ft().profile(q) + res = res.info + assert isinstance(res, 
dict) + assert len(res["results"]) == 2 # check also the search result -@pytest.mark.redismod -def test_summarize_disabled_nooffset(client): - client.ft().create_index((TextField("txt"),), no_term_offsets=True) - client.hset("doc1", mapping={"txt": "foo bar"}) - with pytest.raises(Exception): - client.ft().search(Query("foo").summarize(fields=["txt"])) + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res = client.ft().profile(req) + res = res.info + + assert isinstance(res, dict) + assert len(res["results"]) == 2 # check also the search result + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_redis_enterprise() + @skip_if_server_version_lt("7.9.0") + def test_profile_with_coordinator(self, client): + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "world") + + # check using Query + q = Query("hello|world").no_content() + if is_resp2_connection(client): + res, det = client.ft().profile(q) + det = det.info + assert isinstance(det, list) + assert len(res.docs) == 2 # check also the search result -@pytest.mark.redismod -def test_summarize_disabled_nohl(client): - client.ft().create_index((TextField("txt"),), no_highlight=True) - client.hset("doc1", mapping={"txt": "foo bar"}) - with pytest.raises(Exception): - client.ft().search(Query("foo").summarize(fields=["txt"])) + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res, det = client.ft().profile(req) + det = det.info + assert isinstance(det, list) + assert det[0] == "Shards" + assert det[2] == "Coordinator" + assert len(res.rows) == 2 # check also the search result + else: + res = client.ft().profile(q) + res = res.info -@pytest.mark.redismod -def test_max_text_fields(client): - # Creating the index definition - client.ft().create_index((TextField("f0"),)) - for x in range(1, 32): - client.ft().alter_schema_add((TextField(f"f{x}"),)) + assert isinstance(res, dict) + assert len(res["Results"]["results"]) == 2 # check also the search result - # Should be too many indexes - with pytest.raises(redis.ResponseError): - client.ft().alter_schema_add((TextField(f"f{x}"),)) + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res = client.ft().profile(req) + res = res.info + + assert isinstance(res, dict) + assert len(res["Results"]["results"]) == 2 # check also the search result + + @pytest.mark.redismod + @pytest.mark.onlynoncluster + @skip_if_server_version_gte("7.9.0") + @skip_if_server_version_lt("6.3.0") + def test_profile_limited(self, client): + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "hell") + client.ft().client.hset("3", "t", "help") + client.ft().client.hset("4", "t", "helowa") + + q = Query("%hell% hel*") + if is_resp2_connection(client): + res, det = client.ft().profile(q, limited=True) + det = det.info + assert det[4][1][7][9] == "The number of iterators in the union is 3" + assert det[4][1][8][9] == "The number of iterators in the union is 4" + assert det[4][1][1] == "INTERSECT" + assert len(res.docs) == 3 # check also the search result + else: + res = client.ft().profile(q, limited=True) + res = res.info + iterators_profile = res["profile"]["Iterators profile"] + assert ( + 
iterators_profile[0]["Child iterators"][0]["Child iterators"] + == "The number of iterators in the union is 3" + ) + assert ( + iterators_profile[0]["Child iterators"][1]["Child iterators"] + == "The number of iterators in the union is 4" + ) + assert iterators_profile[0]["Type"] == "INTERSECT" + assert len(res["results"]) == 3 # check also the search result + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_gte("7.9.0") + @skip_if_server_version_lt("6.3.0") + def test_profile_query_params(self, client): + client.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + if is_resp2_connection(client): + res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + det = det.info + assert det[4][1][5] == 2.0 + assert det[4][1][1] == "VECTOR" + assert res.total == 2 + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + res = res.info + assert res["profile"]["Iterators profile"][0]["Counter"] == 2 + assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR" + assert res["total_results"] == 2 + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + + +class TestDifferentFieldTypesSearch(SearchTestsBase): + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_vector_field(self, client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") - client.ft().dropindex() - # Creating the index definition - client.ft().create_index((TextField("f0"),), max_text_fields=True) - # Fill the index with fields - for x in range(1, 50): - client.ft().alter_schema_add((TextField(f"f{x}"),)) + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + if is_resp2_connection(client): + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_create_client_definition(client): - """ - Create definition with no index type provided, - and use hset to test the client definition (the default is HASH). 
- """ - definition = IndexDefinition(prefix=["hset:", "henry"]) - createIndex(client.ft(), num_docs=500, definition=definition) + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_vector_field_error(self, r): + r.flushdb() + + # sortable tag + with pytest.raises(Exception): + r.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) - info = client.ft().info() - assert 494 == int(info["num_docs"]) + # not supported algorithm + with pytest.raises(Exception): + r.ft().create_index((VectorField("v", "SORT", {}),)) - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_text_params(self, client): + client.flushdb() + client.ft().create_index((TextField("name"),)) + client.hset("doc1", mapping={"name": "Alice"}) + client.hset("doc2", mapping={"name": "Bob"}) + client.hset("doc3", mapping={"name": "Carol"}) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_create_client_definition_hash(client): - """ - Create definition with IndexType.HASH as index type (ON HASH), - and use hset to test the client definition. - """ - definition = IndexDefinition(prefix=["hset:", "henry"], index_type=IndexType.HASH) - createIndex(client.ft(), num_docs=500, definition=definition) + params_dict = {"name1": "Alice", "name2": "Bob"} + q = Query("@name:($name1 | $name2 )") + res = client.ft().search(q, query_params=params_dict) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_numeric_params(self, client): + client.flushdb() + client.ft().create_index((NumericField("numval"),)) + + client.hset("doc1", mapping={"numval": 101}) + client.hset("doc2", mapping={"numval": 102}) + client.hset("doc3", mapping={"numval": 103}) + + params_dict = {"min": 101, "max": 102} + q = Query("@numval:[$min $max]") + res = client.ft().search(q, query_params=params_dict) + + if is_resp2_connection(client): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + def test_geo_params(self, client): + client.ft().create_index(GeoField("g")) + client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) + client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) + client.hset("doc3", mapping={"g": "29.68746, 34.94882"}) + + params_dict = { + "lat": "34.95126", + "lon": "29.69465", + "radius": 1000, + "units": "km", + } + q = Query("@g:[$lon $lat $radius $units]") + res = client.ft().search(q, query_params=params_dict) + _assert_search_result(client, res, ["doc1", "doc2", "doc3"]) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.4.0") + @skip_ifmodversion_lt("2.10.0", "search") + def test_geoshapes_query_intersects_and_disjoint(self, client): + client.ft().create_index((GeoShapeField("g", coord_system=GeoShapeField.FLAT))) + client.hset("doc_point1", mapping={"g": "POINT (10 10)"}) + client.hset("doc_point2", mapping={"g": "POINT (50 50)"}) + client.hset( + "doc_polygon1", mapping={"g": "POLYGON ((20 20, 25 35, 35 25, 20 20))"} + ) + 
client.hset( + "doc_polygon2", + mapping={"g": "POLYGON ((60 60, 65 75, 70 70, 65 55, 60 60))"}, + ) + + intersection = client.ft().search( + Query("@g:[intersects $shape]").dialect(3), + query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"}, + ) + _assert_search_result(client, intersection, ["doc_point2", "doc_polygon1"]) - info = client.ft().info() - assert 494 == int(info["num_docs"]) + disjunction = client.ft().search( + Query("@g:[disjoint $shape]").dialect(3), + query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"}, + ) + _assert_search_result(client, disjunction, ["doc_point1", "doc_polygon2"]) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.10.0", "search") + def test_geoshapes_query_contains_and_within(self, client): + client.ft().create_index((GeoShapeField("g", coord_system=GeoShapeField.FLAT))) + client.hset("doc_point1", mapping={"g": "POINT (10 10)"}) + client.hset("doc_point2", mapping={"g": "POINT (50 50)"}) + client.hset( + "doc_polygon1", mapping={"g": "POLYGON ((20 20, 25 35, 35 25, 20 20))"} + ) + client.hset( + "doc_polygon2", + mapping={"g": "POLYGON ((60 60, 65 75, 70 70, 65 55, 60 60))"}, + ) + + contains_a = client.ft().search( + Query("@g:[contains $shape]").dialect(3), + query_params={"shape": "POINT(25 25)"}, + ) + _assert_search_result(client, contains_a, ["doc_polygon1"]) + + contains_b = client.ft().search( + Query("@g:[contains $shape]").dialect(3), + query_params={"shape": "POLYGON((24 24, 24 26, 25 25, 24 24))"}, + ) + _assert_search_result(client, contains_b, ["doc_polygon1"]) + + within = client.ft().search( + Query("@g:[within $shape]").dialect(3), + query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"}, + ) + _assert_search_result(client, within, ["doc_point2", "doc_polygon1"]) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_vector_search_with_int8_type(self, client): + client.ft().create_index( + ( + VectorField( + "v", "FLAT", {"TYPE": "INT8", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + + a = [1.5, 10] + b = [123, 100] + c = [1, 1] + + client.hset("a", "v", np.array(a, dtype=np.int8).tobytes()) + client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) + client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} + + assert 2 in query.get_args() + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_vector_search_with_uint8_type(self, client): + client.ft().create_index( + ( + VectorField( + "v", "FLAT", {"TYPE": "UINT8", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + + a = [1.5, 10] + b = [123, 100] + c = [1, 1] + + client.hset("a", "v", np.array(a, dtype=np.uint8).tobytes()) + client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) + client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} + + assert 2 in query.get_args() + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +class TestPipeline(SearchTestsBase): + @pytest.mark.redismod + @skip_if_redis_enterprise() + def 
test_search_commands_in_pipeline(self, client): + p = client.ft().pipeline() + p.create_index((TextField("txt"),)) + p.hset("doc1", mapping={"txt": "foo bar"}) + p.hset("doc2", mapping={"txt": "foo bar"}) + q = Query("foo bar").with_payloads() + p.search(q) + res = p.execute() + if is_resp2_connection(client): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["extra_attributes"] + == res[3]["results"][1]["extra_attributes"] + == {"txt": "foo bar"} + ) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_pipeline(self, client): + p = client.ft().pipeline() + p.create_index( + ( + TextField("txt"), + VectorField( + "embedding", + "FLAT", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + p.hset( + "doc1", + mapping={ + "txt": "foo bar", + "embedding": np.array([1, 2, 3, 4], dtype=np.float32).tobytes(), + }, + ) + p.hset( + "doc2", + mapping={ + "txt": "foo bar", + "embedding": np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + }, + ) - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) + # set search query + search_query = HybridSearchQuery("foo") + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([2, 2, 3, 3], dtype=np.float32).tobytes(), + ) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_create_client_definition_json(client): - """ - Create definition with IndexType.JSON as index type (ON JSON), - and use json client to test it. 
- """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index((TextField("$.name"),), definition=definition) + hybrid_query = HybridQuery(search_query, vsim_query) - client.json().set("king:1", Path.root_path(), {"name": "henry"}) - client.json().set("king:2", Path.root_path(), {"name": "james"}) + p.hybrid_search(query=hybrid_query) + res = p.execute() - res = client.ft().search("henry") - if is_resp2_connection(client): - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 - else: - assert res["results"][0]["id"] == "king:1" - assert res["results"][0]["extra_attributes"]["$"] == '{"name":"henry"}' - assert res["total_results"] == 1 - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_fields_as_name(client): - # create index - SCHEMA = ( - TextField("$.name", sortable=True, as_name="name"), - NumericField("$.age", as_name="just_a_number"), - ) - definition = IndexDefinition(index_type=IndexType.JSON) - client.ft().create_index(SCHEMA, definition=definition) - - # insert json data - res = client.json().set("doc:1", Path.root_path(), {"name": "Jon", "age": 25}) - assert res - - res = client.ft().search(Query("Jon").return_fields("name", "just_a_number")) - if is_resp2_connection(client): - assert 1 == len(res.docs) - assert "doc:1" == res.docs[0].id - assert "Jon" == res.docs[0].name - assert "25" == res.docs[0].just_a_number - else: - assert 1 == len(res["results"]) - assert "doc:1" == res["results"][0]["id"] - assert "Jon" == res["results"][0]["extra_attributes"]["name"] - assert "25" == res["results"][0]["extra_attributes"]["just_a_number"] + # the default results count limit is 10 + assert res[:3] == ["OK", 2, 2] + hybrid_search_res = res[3] + if is_resp2_connection(client): + # it doesn't get parsed to object in pipeline + assert hybrid_search_res[0] == "total_results" + assert hybrid_search_res[1] == 2 + assert hybrid_search_res[2] == "results" + assert len(hybrid_search_res[3]) == 2 + assert hybrid_search_res[4] == "warnings" + assert hybrid_search_res[5] == [] + assert hybrid_search_res[6] == "execution_time" + assert float(hybrid_search_res[7]) > 0 + else: + assert hybrid_search_res["total_results"] == 2 + assert len(hybrid_search_res["results"]) == 2 + assert hybrid_search_res["warnings"] == [] + assert hybrid_search_res["execution_time"] > 0 + + +class TestSearchWithVamana(SearchTestsBase): + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_l2_distance_metric(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + # L2 distance test vectors + vectors = [[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 0.0, 0.0]] -@pytest.mark.redismod -def test_casesensitive(client): - # create index - SCHEMA = (TagField("t", case_sensitive=False),) - client.ft().create_index(SCHEMA) - client.ft().client.hset("1", "t", "HELLO") - client.ft().client.hset("2", "t", "hello") + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - res = client.ft().search("@t:{HELLO}") + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - if is_resp2_connection(client): - assert 2 == len(res.docs) - assert "1" == 
res.docs[0].id - assert "2" == res.docs[1].id - else: - assert 2 == len(res["results"]) - assert "1" == res["results"][0]["id"] - assert "2" == res["results"][1]["id"] + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] - # create casesensitive index - client.ft().dropindex() - SCHEMA = (TagField("t", case_sensitive=True),) - client.ft().create_index(SCHEMA) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_cosine_distance_metric(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) - res = client.ft().search("@t:{HELLO}") - if is_resp2_connection(client): - assert 1 == len(res.docs) - assert "1" == res.docs[0].id - else: - assert 1 == len(res["results"]) - assert "1" == res["results"][0]["id"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_search_return_fields(client): - res = client.json().set( - "doc:1", - Path.root_path(), - {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, - ) - assert res - - # create index on - definition = IndexDefinition(index_type=IndexType.JSON) - SCHEMA = (TextField("$.t"), NumericField("$.flt")) - client.ft().create_index(SCHEMA, definition=definition) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + vectors = [ + [1.0, 0.0, 0.0], + [0.707, 0.707, 0.0], + [0.0, 1.0, 0.0], + [-1.0, 0.0, 0.0], + ] - if is_resp2_connection(client): - total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt - - total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt - else: - total = client.ft().search(Query("*").return_field("$.t", as_field="txt")) - assert 1 == len(total["results"]) - assert "doc:1" == total["results"][0]["id"] - assert "riceratops" == total["results"][0]["extra_attributes"]["txt"] - - total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")) - assert 1 == len(total["results"]) - assert "doc:1" == total["results"][0]["id"] - assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"] - - -@pytest.mark.redismod -@skip_if_resp_version(3) -def test_binary_and_text_fields(client): - fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) - - index_name = "mixed_index" - mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()} - client.hset(f"{index_name}:1", mapping=mixed_data) - - schema = ( - TagField("first_name"), - VectorField( - "embeddings_bio", - algorithm="HNSW", - attributes={ - "TYPE": "FLOAT32", - "DIM": 4, - "DISTANCE_METRIC": "COSINE", - }, - ), - ) - - client.ft(index_name).create_index( - fields=schema, - definition=IndexDefinition( - prefix=[f"{index_name}:"], index_type=IndexType.HASH - ), - ) - - waitForIndex(client, index_name) - - query = ( - Query("*") - .return_field("vector_emb", decode_field=False) - .return_field("first_name") - ) - docs = client.ft(index_name).search(query=query, query_params={}).docs - decoded_vec_from_search_results = np.frombuffer( - 
docs[0]["vector_emb"], dtype=np.float32 - ) - - assert np.array_equal(decoded_vec_from_search_results, fake_vec), ( - "The vectors are not equal" - ) - - assert docs[0]["first_name"] == mixed_data["first_name"], ( - "The text field is not decoded correctly" - ) - - -@pytest.mark.redismod -def test_synupdate(client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - (TextField("title"), TextField("body")), definition=definition - ) - - client.ft().synupdate("id1", True, "boy", "child", "offspring") - client.hset("doc1", mapping={"title": "he is a baby", "body": "this is a test"}) - - client.ft().synupdate("id1", True, "baby") - client.hset("doc2", mapping={"title": "he is another baby", "body": "another test"}) - - res = client.ft().search(Query("child").expander("SYNONYM")) - if is_resp2_connection(client): - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" - else: - assert res["results"][0]["id"] == "doc2" - assert res["results"][0]["extra_attributes"]["title"] == "he is another baby" - assert res["results"][0]["extra_attributes"]["body"] == "another test" - - -@pytest.mark.redismod -def test_syndump(client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - (TextField("title"), TextField("body")), definition=definition - ) - - client.ft().synupdate("id1", False, "boy", "child", "offspring") - client.ft().synupdate("id2", False, "baby", "child") - client.ft().synupdate("id3", False, "tree", "wood") - res = client.ft().syndump() - assert res == { - "boy": ["id1"], - "tree": ["id3"], - "wood": ["id3"], - "child": ["id1", "id2"], - "baby": ["id2"], - "offspring": ["id1"], - } - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_create_json_with_alias(client): - """ - Create definition with IndexType.JSON as index type (ON JSON) with two - fields with aliases, and use json client to test it. 
- """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index( - (TextField("$.name", as_name="name"), NumericField("$.num", as_name="num")), - definition=definition, - ) + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - client.json().set("king:1", Path.root_path(), {"name": "henry", "num": 42}) - client.json().set("king:2", Path.root_path(), {"name": "james", "num": 3.14}) + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - if is_resp2_connection(client): - res = client.ft().search("@name:henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","num":42}' - assert res.total == 1 - - res = client.ft().search("@num:[0 10]") - assert res.docs[0].id == "king:2" - assert res.docs[0].json == '{"name":"james","num":3.14}' - assert res.total == 1 - else: - res = client.ft().search("@name:henry") - assert res["results"][0]["id"] == "king:1" - assert res["results"][0]["extra_attributes"]["$"] == '{"name":"henry","num":42}' - assert res["total_results"] == 1 + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] - res = client.ft().search("@num:[0 10]") - assert res["results"][0]["id"] == "king:2" - assert ( - res["results"][0]["extra_attributes"]["$"] == '{"name":"james","num":3.14}' + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_ip_distance_metric(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "IP"}, + ), + ) ) - assert res["total_results"] == 1 - - # Tests returns an error if path contain special characters (user should - # use an alias) - with pytest.raises(Exception): - client.ft().search("@$.name:henry") + vectors = [[1.0, 2.0, 3.0], [2.0, 1.0, 1.0], [3.0, 3.0, 3.0], [0.1, 0.1, 0.1]] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_json_with_multipath(client): - """ - Create definition with IndexType.JSON as index type (ON JSON), - and use json client to test it. 
- """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index( - (TagField("$..name", as_name="name")), definition=definition - ) + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - client.json().set( - "king:1", Path.root_path(), {"name": "henry", "country": {"name": "england"}} - ) + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - if is_resp2_connection(client): - res = client.ft().search("@name:{henry}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 - - res = client.ft().search("@name:{england}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 - else: - res = client.ft().search("@name:{henry}") - assert res["results"][0]["id"] == "king:1" - assert ( - res["results"][0]["extra_attributes"]["$"] - == '{"name":"henry","country":{"name":"england"}}' + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc2" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc2" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_basic_functionality(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) ) - assert res["total_results"] == 1 - res = client.ft().search("@name:{england}") - assert res["results"][0]["id"] == "king:1" - assert ( - res["results"][0]["extra_attributes"]["$"] - == '{"name":"henry","country":{"name":"england"}}' - ) - assert res["total_results"] == 1 + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_json_with_jsonpath(client): - definition = IndexDefinition(index_type=IndexType.JSON) - client.ft().create_index( - ( - TextField('$["prod:name"]', as_name="name"), - TextField("$.prod:name", as_name="name_unsupported"), - ), - definition=definition, - ) + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) - client.json().set("doc:1", Path.root_path(), {"prod:name": "RediSearch"}) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id # Should be closest to itself + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_float16_type(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) - if is_resp2_connection(client): - # query for a supported field succeeds - res = 
client.ft().search(Query("@name:RediSearch")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].json == '{"prod:name":"RediSearch"}' - - # query for an unsupported field - res = client.ft().search("@name_unsupported:RediSearch") - assert res.total == 1 - - # return of a supported field succeeds - res = client.ft().search(Query("@name:RediSearch").return_field("name")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].name == "RediSearch" - else: - # query for a supported field succeeds - res = client.ft().search(Query("@name:RediSearch")) - assert res["total_results"] == 1 - assert res["results"][0]["id"] == "doc:1" - assert ( - res["results"][0]["extra_attributes"]["$"] == '{"prod:name":"RediSearch"}' - ) - - # query for an unsupported field - res = client.ft().search("@name_unsupported:RediSearch") - assert res["total_results"] == 1 - - # return of a supported field succeeds - res = client.ft().search(Query("@name:RediSearch").return_field("name")) - assert res["total_results"] == 1 - assert res["results"][0]["id"] == "doc:1" - assert res["results"][0]["extra_attributes"]["name"] == "RediSearch" - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_redis_enterprise() -@skip_if_server_version_gte("7.9.0") -@skip_if_server_version_lt("6.3.0") -def test_profile(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "world") - - # check using Query - q = Query("hello|world").no_content() - if is_resp2_connection(client): - res, det = client.ft().profile(q) - det = det.info + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] - assert isinstance(det, list) - assert len(res.docs) == 2 # check also the search result + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res, det = client.ft().profile(req) - det = det.info - assert isinstance(det, list) - assert len(res.rows) == 2 # check also the search result - else: - res = client.ft().profile(q) - res = res.info + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} - assert isinstance(res, dict) - assert len(res["results"]) == 2 # check also the search result + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_float32_type(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) ) - res = client.ft().profile(req) - res = res.info - - assert isinstance(res, dict) - assert len(res["results"]) == 2 # check also the search result + vectors = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [3.0, 4.0, 5.0, 6.0]] -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_redis_enterprise() -@skip_if_server_version_lt("7.9.0") -def test_profile_with_coordinator(client): 
- client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "world") + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - # check using Query - q = Query("hello|world").no_content() - if is_resp2_connection(client): - res, det = client.ft().profile(q) - det = det.info + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - assert isinstance(det, list) - assert len(res.docs) == 2 # check also the search result + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_vector_search_with_default_dialect(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}, + ), + ) ) - res, det = client.ft().profile(req) - det = det.info - assert isinstance(det, list) - assert det[0] == "Shards" - assert det[2] == "Coordinator" - assert len(res.rows) == 2 # check also the search result - else: - res = client.ft().profile(q) - res = res.info + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") - assert isinstance(res, dict) - assert len(res["Results"]["results"]) == 2 # check also the search result + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res = client.ft().profile(req) - res = res.info - - assert isinstance(res, dict) - assert len(res["Results"]["results"]) == 2 # check also the search result - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_redis_enterprise() -@skip_if_server_version_gte("6.3.0") -def test_profile_with_no_warnings(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "world") - - # check using Query - q = Query("hello|world").no_content() - res, det = client.ft().profile(q) - det = det.info - - assert isinstance(det, list) - assert len(res.docs) == 2 # check also the search result - - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res, det = client.ft().profile(req) - det = det.info - - assert isinstance(det, list) - assert len(res.rows) == 2 # check also the search result - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_if_server_version_gte("7.9.0") -@skip_if_server_version_lt("6.3.0") -def test_profile_limited(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "hell") - client.ft().client.hset("3", "t", "help") - client.ft().client.hset("4", "t", "helowa") - - q = Query("%hell% hel*") - if is_resp2_connection(client): - res, det = client.ft().profile(q, limited=True) 
- det = det.info - assert det[4][1][7][9] == "The number of iterators in the union is 3" - assert det[4][1][8][9] == "The number of iterators in the union is 4" - assert det[4][1][1] == "INTERSECT" - assert len(res.docs) == 3 # check also the search result - else: - res = client.ft().profile(q, limited=True) - res = res.info - iterators_profile = res["profile"]["Iterators profile"] - assert ( - iterators_profile[0]["Child iterators"][0]["Child iterators"] - == "The number of iterators in the union is 3" - ) - assert ( - iterators_profile[0]["Child iterators"][1]["Child iterators"] - == "The number of iterators in the union is 4" - ) - assert iterators_profile[0]["Type"] == "INTERSECT" - assert len(res["results"]) == 3 # check also the search result - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_gte("7.9.0") -@skip_if_server_version_lt("6.3.0") -def test_profile_query_params(client): - client.ft().create_index( - ( - VectorField( - "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} - ), - ) - ) - client.hset("a", "v", "aaaaaaaa") - client.hset("b", "v", "aaaabaaa") - client.hset("c", "v", "aaaaabaa") - query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True) - if is_resp2_connection(client): - res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - det = det.info - assert det[4][1][5] == 2.0 - assert det[4][1][1] == "VECTOR" - assert res.total == 2 - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") - else: - res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - res = res.info - assert res["profile"]["Iterators profile"][0]["Counter"] == 2 - assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR" - assert res["total_results"] == 2 - assert "a" == res["results"][0]["id"] - assert "0" == res["results"][0]["extra_attributes"]["__v_score"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_vector_field(client): - client.flushdb() - client.ft().create_index( - ( - VectorField( - "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} - ), + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_vector_field_basic(self): + field = VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"}, ) - ) - client.hset("a", "v", "aaaaaaaa") - client.hset("b", "v", "aaaabaaa") - client.hset("c", "v", "aaaaabaa") - - query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True) - res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) - - if is_resp2_connection(client): - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") - else: - assert "a" == res["results"][0]["id"] - assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + # Check that the field was created successfully + assert field.name == "v" + assert field.args[0] == "VECTOR" + assert field.args[1] == "SVS-VAMANA" + assert field.args[2] == 6 + assert "TYPE" in field.args + assert "FLOAT32" in field.args + assert "DIM" in field.args + assert 128 in field.args + assert "DISTANCE_METRIC" in field.args + assert "COSINE" in field.args + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + 
@skip_if_server_version_lt("8.1.224") + def test_svs_vamana_lvq8_compression(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_vector_field_error(r): - r.flushdb() + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - # sortable tag - with pytest.raises(Exception): - r.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - # not supported algorithm - with pytest.raises(Exception): - r.ft().create_index((VectorField("v", "SORT", {}),)) + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_compression_with_both_vector_types(self, client): + # Test FLOAT16 with LVQ8 + client.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_text_params(client): - client.flushdb() - client.ft().create_index((TextField("name"),)) + # Test FLOAT32 with LVQ8 + client.ft("idx32").create_index( + ( + VectorField( + "v32", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) - client.hset("doc1", mapping={"name": "Alice"}) - client.hset("doc2", mapping={"name": "Bob"}) - client.hset("doc3", mapping={"name": "Carol"}) + # Add data to both indices + for i in range(15): + vec = [float(i + j) for j in range(8)] + client.hset(f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()) + client.hset(f"doc32_{i}", "v32", np.array(vec, dtype=np.float32).tobytes()) + + # Test both indices + query = Query("*=>[KNN 3 @v16 $vec as score]").no_content() + res16 = client.ft("idx16").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float16 + ).tobytes() + }, + ) - params_dict = {"name1": "Alice", "name2": "Bob"} - q = Query("@name:($name1 | $name2 )") - res = client.ft().search(q, query_params=params_dict) - if is_resp2_connection(client): - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - else: - assert 2 == res["total_results"] - assert "doc1" == res["results"][0]["id"] - assert "doc2" == res["results"][1]["id"] + query = Query("*=>[KNN 3 @v32 $vec as score]").no_content() + res32 = client.ft("idx32").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float32 + ).tobytes() + }, + ) + if is_resp2_connection(client): + assert res16.total == 3 + assert res32.total == 3 + else: + assert res16["total_results"] == 3 + assert res32["total_results"] == 3 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + 
@skip_if_server_version_lt("8.1.224") + def test_svs_vamana_construction_window_size(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + }, + ), + ) + ) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_numeric_params(client): - client.flushdb() - client.ft().create_index((NumericField("numval"),)) + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - client.hset("doc1", mapping={"numval": 101}) - client.hset("doc2", mapping={"numval": 102}) - client.hset("doc3", mapping={"numval": 103}) + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - params_dict = {"min": 101, "max": 102} - q = Query("@numval:[$min $max]") - res = client.ft().search(q, query_params=params_dict) + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] - if is_resp2_connection(client): - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - else: - assert 2 == res["total_results"] - assert "doc1" == res["results"][0]["id"] - assert "doc2" == res["results"][1]["id"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_geo_params(client): - client.ft().create_index(GeoField("g")) - client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) - client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) - client.hset("doc3", mapping={"g": "29.68746, 34.94882"}) - - params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} - q = Query("@g:[$lon $lat $radius $units]") - res = client.ft().search(q, query_params=params_dict) - _assert_search_result(client, res, ["doc1", "doc2", "doc3"]) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.4.0") -@skip_ifmodversion_lt("2.10.0", "search") -def test_geoshapes_query_intersects_and_disjoint(client): - client.ft().create_index((GeoShapeField("g", coord_system=GeoShapeField.FLAT))) - client.hset("doc_point1", mapping={"g": "POINT (10 10)"}) - client.hset("doc_point2", mapping={"g": "POINT (50 50)"}) - client.hset("doc_polygon1", mapping={"g": "POLYGON ((20 20, 25 35, 35 25, 20 20))"}) - client.hset( - "doc_polygon2", mapping={"g": "POLYGON ((60 60, 65 75, 70 70, 65 55, 60 60))"} - ) - - intersection = client.ft().search( - Query("@g:[intersects $shape]").dialect(3), - query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"}, - ) - _assert_search_result(client, intersection, ["doc_point2", "doc_polygon1"]) - - disjunction = client.ft().search( - Query("@g:[disjoint $shape]").dialect(3), - query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"}, - ) - _assert_search_result(client, disjunction, ["doc_point1", "doc_polygon2"]) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.10.0", "search") -def test_geoshapes_query_contains_and_within(client): - client.ft().create_index((GeoShapeField("g", coord_system=GeoShapeField.FLAT))) - client.hset("doc_point1", mapping={"g": "POINT (10 10)"}) - client.hset("doc_point2", mapping={"g": "POINT (50 50)"}) - client.hset("doc_polygon1", mapping={"g": "POLYGON ((20 20, 25 35, 35 25, 20 20))"}) - 
client.hset( - "doc_polygon2", mapping={"g": "POLYGON ((60 60, 65 75, 70 70, 65 55, 60 60))"} - ) - - contains_a = client.ft().search( - Query("@g:[contains $shape]").dialect(3), - query_params={"shape": "POINT(25 25)"}, - ) - _assert_search_result(client, contains_a, ["doc_polygon1"]) - - contains_b = client.ft().search( - Query("@g:[contains $shape]").dialect(3), - query_params={"shape": "POLYGON((24 24, 24 26, 25 25, 24 24))"}, - ) - _assert_search_result(client, contains_b, ["doc_polygon1"]) - - within = client.ft().search( - Query("@g:[within $shape]").dialect(3), - query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"}, - ) - _assert_search_result(client, within, ["doc_point2", "doc_polygon1"]) - - -@pytest.mark.redismod -@skip_if_redis_enterprise() -def test_search_commands_in_pipeline(client): - p = client.ft().pipeline() - p.create_index((TextField("txt"),)) - p.hset("doc1", mapping={"txt": "foo bar"}) - p.hset("doc2", mapping={"txt": "foo bar"}) - q = Query("foo bar").with_payloads() - p.search(q) - res = p.execute() - if is_resp2_connection(client): - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] - else: - assert res[:3] == ["OK", True, True] - assert 2 == res[3]["total_results"] - assert "doc1" == res[3]["results"][0]["id"] - assert "doc2" == res[3]["results"][1]["id"] - assert res[3]["results"][0]["payload"] is None - assert ( - res[3]["results"][0]["extra_attributes"] - == res[3]["results"][1]["extra_attributes"] - == {"txt": "foo bar"} - ) - - -@pytest.mark.redismod -@pytest.mark.onlynoncluster -@skip_ifmodversion_lt("2.4.3", "search") -def test_dialect_config(client): - assert client.ft().config_get("DEFAULT_DIALECT") - client.ft().config_set("DEFAULT_DIALECT", 2) - assert client.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} - with pytest.raises(redis.ResponseError): - client.ft().config_set("DEFAULT_DIALECT", 0) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_dialect(client): - client.ft().create_index( - ( - TagField("title"), - TextField("t1"), - TextField("t2"), - NumericField("num"), - VectorField( - "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 1, "DISTANCE_METRIC": "COSINE"} - ), + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_graph_max_degree(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) ) - ) - client.hset("h", "t1", "hello") - with pytest.raises(redis.ResponseError) as err: - client.ft().explain(Query("(*)").dialect(1)) - assert "Syntax error" in str(err) - assert "WILDCARD" in client.ft().explain(Query("(*)")) - - with pytest.raises(redis.ResponseError) as err: - client.ft().explain(Query("$hello").dialect(1)) - assert "Syntax error" in str(err) - q = Query("$hello") - expected = "UNION {\n hello\n +hello(expanded)\n}\n" - assert expected in client.ft().explain(q, query_params={"hello": "hello"}) - - expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n" - assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) - with pytest.raises(redis.ResponseError) as err: - client.ft().explain(Query("@title:(@num:[0 10])")) - assert "Syntax error" in str(err) - - -@pytest.mark.redismod -def 
test_expire_while_search(client: redis.Redis): - client.ft().create_index((TextField("txt"),)) - client.hset("hset:1", "txt", "a") - client.hset("hset:2", "txt", "b") - client.hset("hset:3", "txt", "c") - if is_resp2_connection(client): - assert 3 == client.ft().search(Query("*")).total - client.pexpire("hset:2", 300) - for _ in range(500): - client.ft().search(Query("*")).docs[1] - time.sleep(1) - assert 2 == client.ft().search(Query("*")).total - else: - assert 3 == client.ft().search(Query("*"))["total_results"] - client.pexpire("hset:2", 300) - for _ in range(500): - client.ft().search(Query("*"))["results"][1] - time.sleep(1) - assert 2 == client.ft().search(Query("*"))["total_results"] - - -@pytest.mark.redismod -@pytest.mark.experimental -def test_withsuffixtrie(client: redis.Redis): - # create index - assert client.ft().create_index((TextField("txt"),)) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - if is_resp2_connection(client): - info = client.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert client.ft().dropindex() - # create withsuffixtrie index (text fields) - assert client.ft().create_index(TextField("t", withsuffixtrie=True)) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - info = client.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert client.ft().dropindex() + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - # create withsuffixtrie index (tag field) - assert client.ft().create_index(TagField("t", withsuffixtrie=True)) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - info = client.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - else: - info = client.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] - assert client.ft().dropindex() + query = Query("*=>[KNN 6 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - # create withsuffixtrie index (text fields) - assert client.ft().create_index(TextField("t", withsuffixtrie=True)) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - info = client.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] - assert client.ft().dropindex() + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 6 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 6 + assert "doc0" == res["results"][0]["id"] - # create withsuffixtrie index (tag field) - assert client.ft().create_index(TagField("t", withsuffixtrie=True)) - waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - info = client.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] - - -@pytest.mark.redismod -def test_query_timeout(r: redis.Redis): - q1 = Query("foo").timeout(5000) - assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] - q1 = Query("foo").timeout(0) - assert q1.get_args() == ["foo", "TIMEOUT", 0, "DIALECT", 2, "LIMIT", 0, 10] - q2 = Query("foo").timeout("not_a_number") - with pytest.raises(redis.ResponseError): - r.ft().search(q2) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.2.0") -@skip_ifmodversion_lt("2.8.4", "search") -def test_geoshape(client: redis.Redis): - client.ft().create_index(GeoShapeField("geom", GeoShapeField.FLAT)) - waitForIndex(client, 
getattr(client.ft(), "index_name", "idx")) - client.hset("small", "geom", "POLYGON((1 1, 1 100, 100 100, 100 1, 1 1))") - client.hset("large", "geom", "POLYGON((1 1, 1 200, 200 200, 200 1, 1 1))") - q1 = Query("@geom:[WITHIN $poly]").dialect(3) - qp1 = {"poly": "POLYGON((0 0, 0 150, 150 150, 150 0, 0 0))"} - q2 = Query("@geom:[CONTAINS $poly]").dialect(3) - qp2 = {"poly": "POLYGON((2 2, 2 50, 50 50, 50 2, 2 2))"} - result = client.ft().search(q1, query_params=qp1) - _assert_search_result(client, result, ["small"]) - result = client.ft().search(q2, query_params=qp2) - _assert_search_result(client, result, ["small", "large"]) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.4.0") -@skip_ifmodversion_lt("2.10.0", "search") -def test_search_missing_fields(client): - definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) - - fields = [ - TextField("title", sortable=True), - TagField("features", index_missing=True), - TextField("description", index_missing=True), - ] - - client.ft().create_index(fields, definition=definition) - - # All fields present - client.hset( - "property:1", - mapping={ - "title": "Luxury Villa in Malibu", - "features": "pool,sea view,modern", - "description": "A stunning modern villa overlooking the Pacific Ocean.", - }, - ) - - # Missing features - client.hset( - "property:2", - mapping={ - "title": "Downtown Flat", - "description": "Modern flat in central Paris with easy access to metro.", - }, - ) - - # Missing description - client.hset( - "property:3", - mapping={ - "title": "Beachfront Bungalow", - "features": "beachfront,sun deck", - }, - ) - - with pytest.raises(redis.exceptions.ResponseError): - client.ft().search(Query("ismissing(@title)").return_field("id").no_content()) - - res = client.ft().search( - Query("ismissing(@features)").return_field("id").no_content() - ) - _assert_search_result(client, res, ["property:2"]) - - res = client.ft().search( - Query("-ismissing(@features)").return_field("id").no_content() - ) - _assert_search_result(client, res, ["property:1", "property:3"]) - - res = client.ft().search( - Query("ismissing(@description)").return_field("id").no_content() - ) - _assert_search_result(client, res, ["property:3"]) - - res = client.ft().search( - Query("-ismissing(@description)").return_field("id").no_content() - ) - _assert_search_result(client, res, ["property:1", "property:2"]) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.4.0") -@skip_ifmodversion_lt("2.10.0", "search") -def test_create_index_empty_or_missing_fields_with_sortable(client): - definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) - - fields = [ - TextField("title", sortable=True, index_empty=True), - TagField("features", index_missing=True, sortable=True), - TextField("description", no_index=True, sortable=True), - ] - - client.ft().create_index(fields, definition=definition) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.4.0") -@skip_ifmodversion_lt("2.10.0", "search") -def test_search_empty_fields(client): - definition = IndexDefinition(prefix=["property:"], index_type=IndexType.HASH) - - fields = [ - TextField("title", sortable=True), - TagField("features", index_empty=True), - TextField("description", index_empty=True), - ] - - client.ft().create_index(fields, definition=definition) - - # All fields present - client.hset( - "property:1", - mapping={ - "title": "Luxury Villa in Malibu", - "features": "pool,sea view,modern", - "description": "A stunning modern villa overlooking the Pacific 
Ocean.", - }, - ) - - # Empty features - client.hset( - "property:2", - mapping={ - "title": "Downtown Flat", - "features": "", - "description": "Modern flat in central Paris with easy access to metro.", - }, - ) - - # Empty description - client.hset( - "property:3", - mapping={ - "title": "Beachfront Bungalow", - "features": "beachfront,sun deck", - "description": "", - }, - ) - - with pytest.raises(redis.exceptions.ResponseError) as e: - client.ft().search(Query("@title:''").return_field("id").no_content()) - assert "Use `INDEXEMPTY` in field creation" in e.value.args[0] - - res = client.ft().search( - Query("@features:{$empty}").return_field("id").no_content(), - query_params={"empty": ""}, - ) - _assert_search_result(client, res, ["property:2"]) - - res = client.ft().search( - Query("-@features:{$empty}").return_field("id").no_content(), - query_params={"empty": ""}, - ) - _assert_search_result(client, res, ["property:1", "property:3"]) - - res = client.ft().search(Query("@description:''").return_field("id").no_content()) - _assert_search_result(client, res, ["property:3"]) - - res = client.ft().search(Query("-@description:''").return_field("id").no_content()) - _assert_search_result(client, res, ["property:1", "property:2"]) - - -@pytest.mark.redismod -@skip_if_server_version_lt("7.4.0") -@skip_ifmodversion_lt("2.10.0", "search") -def test_special_characters_in_fields(client): - definition = IndexDefinition(prefix=["resource:"], index_type=IndexType.HASH) - - fields = [ - TagField("uuid"), - TagField("tags", separator="|"), - TextField("description"), - NumericField("rating"), - ] - - client.ft().create_index(fields, definition=definition) - - client.hset( - "resource:1", - mapping={ - "uuid": "123e4567-e89b-12d3-a456-426614174000", - "tags": "finance|crypto|$btc|blockchain", - "description": "Analysis of blockchain technologies & Bitcoin's potential.", - "rating": 5, - }, - ) - - client.hset( - "resource:2", - mapping={ - "uuid": "987e6543-e21c-12d3-a456-426614174999", - "tags": "health|well-being|fitness|new-year's-resolutions", - "description": "Health trends for the new year, including fitness regimes.", - "rating": 4, - }, - ) - - # no need to escape - when using params - res = client.ft().search( - Query("@uuid:{$uuid}"), - query_params={"uuid": "123e4567-e89b-12d3-a456-426614174000"}, - ) - _assert_search_result(client, res, ["resource:1"]) - - # with double quotes exact match no need to escape the - even without params - res = client.ft().search(Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}')) - _assert_search_result(client, res, ["resource:1"]) - - res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}')) - _assert_search_result(client, res, ["resource:2"]) - - # possible to search numeric fields by single value - res = client.ft().search(Query("@rating:[4]")) - _assert_search_result(client, res, ["resource:2"]) - - # some chars still need escaping - res = client.ft().search(Query(r"@tags:{\$btc}")) - _assert_search_result(client, res, ["resource:1"]) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_vector_search_with_default_dialect(client): - client.ft().create_index( - ( - VectorField( - "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} - ), + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_search_window_size(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, 
+ "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) ) - ) - - client.hset("a", "v", "aaaaaaaa") - client.hset("b", "v", "aaaabaaa") - client.hset("c", "v", "aaaaabaa") - - query = "*=>[KNN 2 @v $vec]" - q = Query(query) - assert "DIALECT" in q.get_args() - assert 2 in q.get_args() + vectors = [] + for i in range(30): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) - if is_resp2_connection(client): - assert res.total == 2 - else: - assert res["total_results"] == 2 + query = Query("*=>[KNN 8 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 8 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 8 + assert "doc0" == res["results"][0]["id"] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_l2_distance_metric(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "L2"}, - ), + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_epsilon_parameter(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "EPSILON": 0.05, + }, + ), + ) ) - ) - - # L2 distance test vectors - vectors = [[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 0.0, 0.0]] - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 3 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_cosine_distance_metric(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, - ), + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_all_build_parameters_combined(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "IP", + "CONSTRUCTION_WINDOW_SIZE": 250, + "GRAPH_MAX_DEGREE": 48, + "SEARCH_WINDOW_SIZE": 15, + "EPSILON": 0.02, + }, + ), + ) ) - ) - vectors = [[1.0, 0.0, 0.0], 
[0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + vectors = [] + for i in range(35): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + query = Query("*=>[KNN 7 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 7 + doc_ids = [doc.id for doc in res.docs] + assert len(doc_ids) == 7 + else: + assert res["total_results"] == 7 + doc_ids = [doc["id"] for doc in res["results"]] + assert len(doc_ids) == 7 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_comprehensive_configuration(self, client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 32, + "DISTANCE_METRIC": "COSINE", + "COMPRESSION": "LVQ8", + "CONSTRUCTION_WINDOW_SIZE": 400, + "GRAPH_MAX_DEGREE": 96, + "SEARCH_WINDOW_SIZE": 25, + "EPSILON": 0.03, + "TRAINING_THRESHOLD": 2048, + }, + ), + ) + ) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 3 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] + vectors = [] + for i in range(60): + vec = [float(i + j) for j in range(32)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + query = Query("*=>[KNN 10 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_ip_distance_metric(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "IP"}, - ), + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 10 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 10 + assert "doc0" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_hybrid_text_vector_search(self, client): + client.flushdb() + client.ft().create_index( + ( + TextField("title"), + TextField("content"), + VectorField( + "embedding", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) ) - ) - vectors = [[1.0, 2.0, 3.0], [2.0, 1.0, 1.0], [3.0, 3.0, 3.0], [0.1, 0.1, 0.1]] + docs = [ + { + "title": "AI Research", + "content": "machine learning algorithms", + "embedding": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + }, + { + "title": "Data Science", + "content": "statistical analysis methods", + "embedding": [2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + }, + { + "title": "Deep Learning", + "content": "neural network architectures", + "embedding": [3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + }, + { + "title": "Computer Vision", + "content": "image processing techniques", + "embedding": [10.0, 11.0, 12.0, 13.0, 14.0, 
15.0], + }, + ] - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + for i, doc in enumerate(docs): + client.hset( + f"doc{i}", + mapping={ + "title": doc["title"], + "content": doc["content"], + "embedding": np.array(doc["embedding"], dtype=np.float32).tobytes(), + }, + ) - query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + # Hybrid query - text filter + vector similarity + query = "(@title:AI|@content:machine)=>[KNN 2 @embedding $vec]" + q = ( + Query(query) + .return_field("__embedding_score") + .sort_by("__embedding_score", True) + ) + res = client.ft().search( + q, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32 + ).tobytes() + }, + ) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 3 - assert "doc2" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc2" == res["results"][0]["id"] + if is_resp2_connection(client): + assert res.total >= 1 + doc_ids = [doc.id for doc in res.docs] + assert "doc0" in doc_ids + else: + assert res["total_results"] >= 1 + doc_ids = [doc["id"] for doc in res["results"]] + assert "doc0" in doc_ids + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_large_dimension_vectors(self, client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 512, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) + ) + vectors = [] + for i in range(10): + vec = [float(i + j) for j in range(512)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) -@pytest.mark.redismod -@skip_if_server_version_lt("7.9.0") -def test_vector_search_with_int8_type(client): - client.ft().create_index( - (VectorField("v", "FLAT", {"TYPE": "INT8", "DIM": 2, "DISTANCE_METRIC": "L2"}),) - ) + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - a = [1.5, 10] - b = [123, 100] - c = [1, 1] + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] - client.hset("a", "v", np.array(a, dtype=np.int8).tobytes()) - client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) - client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_training_threshold_behavior(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) - query = Query("*=>[KNN 2 @v $vec as score]").no_content() - query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + if i >= 5: + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], 
dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + + if is_resp2_connection(client): + assert res.total >= 1 + else: + assert res["total_results"] >= 1 + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_different_k_values(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 15, + }, + ), + ) + ) - assert 2 in query.get_args() + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 2 - else: - assert res["total_results"] == 2 + for k in [1, 3, 5, 10, 15]: + query = Query(f"*=>[KNN {k} @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == k + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == k + assert "doc0" == res["results"][0]["id"] + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_vector_field_error(self, client): + # sortable tag + with pytest.raises(Exception): + client.ft().create_index( + (VectorField("v", "SVS-VAMANA", {}, sortable=True),) + ) -@pytest.mark.redismod -@skip_if_server_version_lt("7.9.0") -def test_vector_search_with_uint8_type(client): - client.ft().create_index( - ( - VectorField( - "v", "FLAT", {"TYPE": "UINT8", "DIM": 2, "DISTANCE_METRIC": "L2"} - ), + # no_index tag + with pytest.raises(Exception): + client.ft().create_index( + (VectorField("v", "SVS-VAMANA", {}, no_index=True),) + ) + + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_vector_search_with_parameters(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 4, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 200, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 40, + "EPSILON": 0.01, + }, + ), + ) ) - ) - a = [1.5, 10] - b = [123, 100] - c = [1, 1] + # Create test vectors + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + ] - client.hset("a", "v", np.array(a, dtype=np.uint8).tobytes()) - client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) - client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]").no_content() - query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - assert 2 in query.get_args() + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - 
assert res.total == 2 - else: - assert res["total_results"] == 2 + @pytest.mark.redismod + @skip_ifmodversion_lt("2.4.3", "search") + @skip_if_server_version_lt("8.1.224") + def test_svs_vamana_vector_search_with_parameters_leanvec(self, client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LeanVec8x8", # LeanVec compression required for REDUCE + "CONSTRUCTION_WINDOW_SIZE": 200, + "GRAPH_MAX_DEGREE": 32, + "SEARCH_WINDOW_SIZE": 15, + "EPSILON": 0.01, + "TRAINING_THRESHOLD": 1024, + "REDUCE": 4, # Half of DIM (8/2 = 4) + }, + ), + ) + ) + # Create test vectors (8-dimensional to match DIM) + vectors = [ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], + [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + ] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -def test_search_query_with_different_dialects(client): - client.ft().create_index( - (TextField("name"), TextField("lastname")), - definition=IndexDefinition(prefix=["test:"]), - ) + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - client.hset("test:1", "name", "James") - client.hset("test:1", "lastname", "Brown") + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - # Query with default DIALECT 2 - query = "@name: James Brown" - q = Query(query) - res = client.ft().search(q) - if is_resp2_connection(client): - assert res.total == 1 - else: - assert res["total_results"] == 1 + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] - # Query with explicit DIALECT 1 - query = "@name: James Brown" - q = Query(query).dialect(1) - res = client.ft().search(q) - if is_resp2_connection(client): - assert res.total == 0 - else: - assert res["total_results"] == 0 +class TestHybridSearch(SearchTestsBase): + def _create_hybrid_search_index(self, client, dim=4): + client.ft().create_index( + ( + TextField("description"), + NumericField("price"), + TagField("color"), + TagField("item_type"), + NumericField("size"), + VectorField( + "embedding", + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": dim, + "DISTANCE_METRIC": "L2", + }, + ), + VectorField( + "embedding-hnsw", + "HNSW", + { + "TYPE": "FLOAT32", + "DIM": dim, + "DISTANCE_METRIC": "L2", + }, + ), + ), + definition=IndexDefinition(prefix=["item:"]), + ) + SearchTestsBase.waitForIndex(client, "idx") + + @staticmethod + def _generate_random_vector(dim): + return [random.random() for _ in range(dim)] + + @staticmethod + def _generate_random_str_data(dim): + chars = "abcdefgh12345678" + return "".join(random.choice(chars) for _ in range(dim)) + + @staticmethod + def _add_data_for_hybrid_search( + client, + items_sets=1, + randomize_data=False, + dim_for_random_data=4, + use_random_str_data=False, + ): + if randomize_data or use_random_str_data: + generate_data_func = ( + TestHybridSearch._generate_random_str_data + if use_random_str_data + else TestHybridSearch._generate_random_vector + ) -@pytest.mark.redismod -@skip_if_server_version_lt("7.9.0") -def test_info_exposes_search_info(client): - assert len(client.info("search")) > 0 + 
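+            # A FLOAT32 component occupies 4 bytes, so when random *strings*
+            # stand in for vector blobs the string must be dim * 4 characters
+            # long to decode as a dim-dimensional FLOAT32 vector (e.g. the
+            # 16-char "abcd1234efgh5678" used by these tests maps to DIM=4).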
dim_for_random_data = ( + dim_for_random_data * 4 if use_random_str_data else dim_for_random_data + ) + items = [ + (generate_data_func(dim_for_random_data), "red shoes"), + (generate_data_func(dim_for_random_data), "green shoes with red laces"), + (generate_data_func(dim_for_random_data), "red dress"), + (generate_data_func(dim_for_random_data), "orange dress"), + (generate_data_func(dim_for_random_data), "black shoes"), + ] + else: + items = [ + ([1.0, 2.0, 7.0, 8.0], "red shoes"), + ([1.0, 4.0, 7.0, 8.0], "green shoes with red laces"), + ([1.0, 2.0, 6.0, 5.0], "red dress"), + ([2.0, 3.0, 6.0, 5.0], "orange dress"), + ([5.0, 6.0, 7.0, 8.0], "black shoes"), + ] + items = items * items_sets + + pipeline = client.pipeline() + for i, vec in enumerate(items): + vec, description = vec + mapping = { + "description": description, + "embedding": np.array(vec, dtype=np.float32).tobytes() + if not use_random_str_data + else vec, + "embedding-hnsw": np.array(vec, dtype=np.float32).tobytes() + if not use_random_str_data + else vec, + "price": 15 + i % 4, + "color": description.split(" ")[0], + "item_type": description.split(" ")[1], + "size": 10 + i % 3, + } + + pipeline.hset( + f"item:{i}", + mapping=mapping, + ) + pipeline.execute() # Execute all at once + + @staticmethod + def _convert_dict_values_to_str(list_of_dicts): + res = [] + for d in list_of_dicts: + res_dict = {} + for k, v in d.items(): + if isinstance(v, list): + res_dict[k] = [safe_str(x) for x in v] + else: + res_dict[k] = safe_str(v) + res.append(res_dict) + return res + + @staticmethod + def compare_list_of_dicts(actual, expected): + assert len(actual) == len(expected), ( + f"List of dicts length mismatch: {len(actual)} != {len(expected)}. " + f"Full dicts: actual:{actual}; expected:{expected}" + ) + for expected_dict_item in expected: + found = False + for actual_dict_item in actual: + if actual_dict_item == expected_dict_item: + found = True + break + if not found: + assert False, ( + f"Dict {expected_dict_item} not found in actual list of dicts: {actual}. " + f"All expected:{expected}" + ) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_basic_hybrid_search(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=5) + + # set search query + search_query = HybridSearchQuery("@color:{red} @color:{green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([-100, -200, -200, -300], dtype=np.float32).tobytes(), + ) -def _assert_search_result(client, result, expected_doc_ids): - """ - Make sure the result of a geo search is as expected, taking into account the RESP - version being used. 
- """ - if is_resp2_connection(client): - assert set([doc.id for doc in result.docs]) == set(expected_doc_ids) - else: - assert set([doc["id"] for doc in result["results"]]) == set(expected_doc_ids) + hybrid_query = HybridQuery(search_query, vsim_query) + res = client.ft().hybrid_search(query=hybrid_query) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_basic_functionality(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, - ), + # the default results count limit is 10 + if is_resp2_connection(client): + assert res.total_results == 10 + assert len(res.results) == 10 + assert res.warnings == [] + assert res.execution_time > 0 + assert all(isinstance(res.results[i]["__score"], bytes) for i in range(10)) + assert all(isinstance(res.results[i]["__key"], bytes) for i in range(10)) + else: + assert res["total_results"] == 10 + assert len(res["results"]) == 10 + assert res["warnings"] == [] + assert res["execution_time"] > 0 + assert all(isinstance(res["results"][i]["__score"], str) for i in range(10)) + assert all(isinstance(res["results"][i]["__key"], str) for i in range(10)) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_scorer(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + search_query = HybridSearchQuery("shoes") + search_query.scorer("TFIDF") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), ) - ) - vectors = [ - [1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [10.0, 11.0, 12.0, 13.0], - ] + hybrid_query = HybridQuery(search_query, vsim_query) - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("LINEAR", ALPHA=1, BETA=0) + posprocessing_config.load( + "@description", "@color", "@price", "@size", "@__score", "@__item" + ) + posprocessing_config.limit(0, 2) - query = "*=>[KNN 3 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True) - res = client.ft().search( - q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - ) + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) - if is_resp2_connection(client): - assert res.total == 3 - assert "doc0" == res.docs[0].id # Should be closest to itself - assert "0" == res.docs[0].__getattribute__("__v_score") - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] - assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + expected_results_tfidf = [ + { + "description": b"red shoes", + "color": b"red", + "price": b"15", + "size": b"10", + "__score": b"2", + }, + { + "description": b"green shoes with red laces", + "color": b"green", + "price": b"16", + "size": b"11", + "__score": b"2", + }, + ] + if is_resp2_connection(client): + assert res.total_results >= 2 + assert len(res.results) == 2 + assert res.results == expected_results_tfidf + assert res.warnings == [] + else: + assert res["total_results"] >= 2 + assert len(res["results"]) == 2 + assert res["results"] == self._convert_dict_values_to_str( + expected_results_tfidf + ) + assert 
res["warnings"] == [] -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_float16_type(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, - ), + search_query.scorer("BM25") + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + expected_results_bm25 = [ + { + "description": b"red shoes", + "color": b"red", + "price": b"15", + "size": b"10", + "__score": b"0.657894719299", + }, + { + "description": b"green shoes with red laces", + "color": b"green", + "price": b"16", + "size": b"11", + "__score": b"0.657894719299", + }, + ] + if is_resp2_connection(client): + assert res.total_results >= 2 + assert len(res.results) == 2 + assert res.results == expected_results_bm25 + assert res.warnings == [] + else: + assert res["total_results"] >= 2 + assert len(res["results"]) == 2 + assert res["results"] == self._convert_dict_values_to_str( + expected_results_bm25 + ) + assert res["warnings"] == [] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_vsim_method_defined_query_init(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=5, use_random_str_data=True) + # set search query + search_query = HybridSearchQuery("shoes") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data="abcd1234efgh5678", + vsim_search_method="KNN", + vsim_search_method_params={"K": 3, "EF_RUNTIME": 1}, ) - ) - - vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + hybrid_query = HybridQuery(search_query, vsim_query) - query = Query("*=>[KNN 2 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + res = client.ft().hybrid_search(query=hybrid_query, timeout=10) + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 2 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 2 - assert "doc0" == res["results"][0]["id"] + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_vsim_filter(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=5, use_random_str_data=True) + search_query = HybridSearchQuery("@color:{missing}") -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_float32_type(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, - ), + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data="abcd1234efgh5678", ) - ) + vsim_query.filter(HybridFilter("@price:[15 16] @size:[10 11]")) - vectors = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [3.0, 4.0, 5.0, 6.0]] + hybrid_query = HybridQuery(search_query, vsim_query) - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, 
dtype=np.float32).tobytes()) + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@price", "@size") - query = Query("*=>[KNN 2 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + assert item["price"] in [b"15", b"16"] + assert item["size"] in [b"10", b"11"] + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in res["results"]: + assert item["price"] in ["15", "16"] + assert item["size"] in ["10", "11"] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_vsim_knn(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + # this query won't have results, so we will be able to validate vsim results + search_query = HybridSearchQuery("@color:{none}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + ) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 2 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 2 - assert "doc0" == res["results"][0]["id"] + vsim_query.vsim_method_params("KNN", K=3) + hybrid_query = HybridQuery(search_query, vsim_query) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_vector_search_with_default_dialect(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}, - ), - ) - ) + posprocessing_config = HybridPostProcessingConfig() - client.hset("a", "v", "aaaaaaaa") - client.hset("b", "v", "aaaabaaa") - client.hset("c", "v", "aaaaabaa") + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + expected_results = [ + {"__key": b"item:2", "__score": b"0.016393442623"}, + {"__key": b"item:7", "__score": b"0.0161290322581"}, + {"__key": b"item:12", "__score": b"0.015873015873"}, + ] + if is_resp2_connection(client): + assert res.total_results == 3 # KNN top-k value + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] == 3 # KNN top-k value + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + vsim_query_with_hnsw = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), + ) + vsim_query_with_hnsw.vsim_method_params("KNN", K=3, EF_RUNTIME=1) + hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw) - query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True) - res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + res2 = client.ft().hybrid_search(query=hybrid_query_with_hnsw, timeout=10) - if is_resp2_connection(client): - assert res.total == 2 - else: - assert res["total_results"] == 2 - - -@pytest.mark.redismod 
-@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_vector_field_basic(): - field = VectorField( - "v", "SVS-VAMANA", {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"} - ) - - # Check that the field was created successfully - assert field.name == "v" - assert field.args[0] == "VECTOR" - assert field.args[1] == "SVS-VAMANA" - assert field.args[2] == 6 - assert "TYPE" in field.args - assert "FLOAT32" in field.args - assert "DIM" in field.args - assert 128 in field.args - assert "DISTANCE_METRIC" in field.args - assert "COSINE" in field.args - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_lvq8_compression(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 8, - "DISTANCE_METRIC": "L2", - "COMPRESSION": "LVQ8", - "TRAINING_THRESHOLD": 1024, - }, - ), + expected_results2 = [ + {"__key": b"item:12", "__score": b"0.016393442623"}, + {"__key": b"item:22", "__score": b"0.0161290322581"}, + {"__key": b"item:27", "__score": b"0.015873015873"}, + ] + if is_resp2_connection(client): + assert res2.total_results == 3 # KNN top-k value + assert len(res2.results) == 3 + assert res2.results == expected_results2 + assert res2.warnings == [] + assert res2.execution_time > 0 + else: + assert res2["total_results"] == 3 # KNN top-k value + assert len(res2["results"]) == 3 + assert res2["results"] == self._convert_dict_values_to_str( + expected_results2 + ) + assert res2["warnings"] == [] + assert res2["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_vsim_range(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + # this query won't have results, so we will be able to validate vsim results + search_query = HybridSearchQuery("@color:{none}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), ) - ) - - vectors = [] - for i in range(20): - vec = [float(i + j) for j in range(8)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - query = Query("*=>[KNN 5 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + vsim_query.vsim_method_params("RANGE", RADIUS=2) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 5 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 5 - assert "doc0" == res["results"][0]["id"] + hybrid_query = HybridQuery(search_query, vsim_query) + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.limit(0, 3) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_compression_with_both_vector_types(client): - # Test FLOAT16 with LVQ8 - client.ft("idx16").create_index( - ( - VectorField( - "v16", - "SVS-VAMANA", - { - "TYPE": "FLOAT16", - "DIM": 8, - "DISTANCE_METRIC": "L2", - "COMPRESSION": "LVQ8", - "TRAINING_THRESHOLD": 1024, - }, - ), + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 ) - ) - # Test FLOAT32 with LVQ8 - client.ft("idx32").create_index( - ( - VectorField( - "v32", - "SVS-VAMANA", - { - 
"TYPE": "FLOAT32", - "DIM": 8, - "DISTANCE_METRIC": "L2", - "COMPRESSION": "LVQ8", - "TRAINING_THRESHOLD": 1024, - }, - ), + expected_results = [ + {"__key": b"item:2", "__score": b"0.016393442623"}, + {"__key": b"item:7", "__score": b"0.0161290322581"}, + {"__key": b"item:12", "__score": b"0.015873015873"}, + ] + if is_resp2_connection(client): + assert res.total_results >= 3 # at least 3 results + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + vsim_query_with_hnsw = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), ) - ) - - # Add data to both indices - for i in range(15): - vec = [float(i + j) for j in range(8)] - client.hset(f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()) - client.hset(f"doc32_{i}", "v32", np.array(vec, dtype=np.float32).tobytes()) - - # Test both indices - query = Query("*=>[KNN 3 @v16 $vec as score]").no_content() - res16 = client.ft("idx16").search( - query, - query_params={ - "vec": np.array( - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float16 - ).tobytes() - }, - ) - - query = Query("*=>[KNN 3 @v32 $vec as score]").no_content() - res32 = client.ft("idx32").search( - query, - query_params={ - "vec": np.array( - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float32 - ).tobytes() - }, - ) - if is_resp2_connection(client): - assert res16.total == 3 - assert res32.total == 3 - else: - assert res16["total_results"] == 3 - assert res32["total_results"] == 3 + vsim_query_with_hnsw.vsim_method_params("RANGE", RADIUS=2, EPSILON=0.5) + hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_construction_window_size(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 6, - "DISTANCE_METRIC": "L2", - "CONSTRUCTION_WINDOW_SIZE": 300, - }, - ), + res = client.ft().hybrid_search( + query=hybrid_query_with_hnsw, + post_processing=posprocessing_config, + timeout=10, ) - ) - vectors = [] - for i in range(20): - vec = [float(i + j) for j in range(6)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + expected_results_hnsw = [ + {"__key": b"item:27", "__score": b"0.016393442623"}, + {"__key": b"item:12", "__score": b"0.0161290322581"}, + {"__key": b"item:22", "__score": b"0.015873015873"}, + ] - query = Query("*=>[KNN 5 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + if is_resp2_connection(client): + assert res.total_results >= 3 + assert len(res.results) == 3 + assert res.results == expected_results_hnsw + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str( + expected_results_hnsw + ) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_combine(self, client): + # Create index and add data + 
self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 5 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 5 - assert "doc0" == res["results"][0]["id"] + hybrid_query = HybridQuery(search_query, vsim_query) + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5) + posprocessing_config.limit(0, 3) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_graph_max_degree(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 6, - "DISTANCE_METRIC": "COSINE", - "GRAPH_MAX_DEGREE": 64, - }, - ), + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 ) - ) - vectors = [] - for i in range(25): - vec = [float(i + j) for j in range(6)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + expected_results = [ + {"__key": b"item:2", "__score": b"0.166666666667"}, + {"__key": b"item:7", "__score": b"0.166666666667"}, + {"__key": b"item:12", "__score": b"0.166666666667"}, + ] + if is_resp2_connection(client): + assert res.total_results >= 3 + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + # combine with RRF and WINDOW + CONSTANT + posprocessing_config.combine("RRF", WINDOW=3, CONSTANT=0.5) + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) - query = Query("*=>[KNN 6 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + expected_results = [ + {"__key": b"item:2", "__score": b"1.06666666667"}, + {"__key": b"item:0", "__score": b"0.666666666667"}, + {"__key": b"item:7", "__score": b"0.4"}, + ] + if is_resp2_connection(client): + assert res.total_results >= 3 + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + # # LINEAR combine with no params + # posprocessing_config.combine("LINEAR", ALPHA=0.5) + # res = client.ft().hybrid_search( + # query=hybrid_query, post_processing=posprocessing_config, timeout=10 + # ) + + # expected_results = [ + # {"__key": b"item:2", "__score": b"0.166666666667"}, + # {"__key": b"item:7", "__score": b"0.166666666667"}, + # {"__key": b"item:12", "__score": b"0.166666666667"}, + # ] + # if is_resp2_connection(client): + # assert res.total_results >= 3 + # assert len(res.results) == 3 + # assert res.results == expected_results + # assert res.warnings == [] + # assert res.execution_time > 0 + # else: + 
# assert res["total_results"] >= 3 + # assert len(res["results"]) == 3 + # assert res["results"] == self._convert_dict_values_to_str(expected_results) + # assert res["warnings"] == [] + # assert res["execution_time"] > 0 + + # combine with RRF, not all possible params provided + posprocessing_config.combine("RRF", WINDOW=3) + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 6 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 6 - assert "doc0" == res["results"][0]["id"] + expected_results = [ + {"__key": b"item:2", "__score": b"0.032522474881"}, + {"__key": b"item:0", "__score": b"0.016393442623"}, + {"__key": b"item:7", "__score": b"0.0161290322581"}, + ] + if is_resp2_connection(client): + assert res.total_results >= 3 + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 3 + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_load(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red|green|black}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) + hybrid_query = HybridQuery(search_query, vsim_query) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_search_window_size(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 6, - "DISTANCE_METRIC": "L2", - "SEARCH_WINDOW_SIZE": 20, - }, - ), + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5) + posprocessing_config.load( + "@description", "@color", "@price", "@size", "@__key AS item_key" ) - ) + posprocessing_config.limit(0, 1) - vectors = [] - for i in range(30): - vec = [float(i + j) for j in range(6)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - - query = Query("*=>[KNN 8 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 8 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 8 - assert "doc0" == res["results"][0]["id"] + expected_results = [ + { + "description": b"red dress", + "color": b"red", + "price": b"17", + "size": b"12", + "item_key": b"item:2", + } + ] + if is_resp2_connection(client): + assert res.total_results >= 1 + assert len(res.results) == 1 + self.compare_list_of_dicts(res.results, expected_results) + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 1 + assert len(res["results"]) == 1 + self.compare_list_of_dicts( + res["results"], 
self._convert_dict_values_to_str(expected_results)
+            )
+            assert res["warnings"] == []
+            assert res["execution_time"] > 0
+
+    @pytest.mark.redismod
+    @skip_if_server_version_lt("8.3.224")
+    def test_hybrid_search_query_with_load_and_apply(self, client):
+        # Create index and add data
+        self._create_hybrid_search_index(client)
+        self._add_data_for_hybrid_search(client, items_sets=10)
+
+        # set search query
+        search_query = HybridSearchQuery("@color:{red}")
+
+        vsim_query = HybridVsimQuery(
+            vector_field_name="@embedding",
+            vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
+        )
+        hybrid_query = HybridQuery(search_query, vsim_query)

-@pytest.mark.redismod
-@skip_ifmodversion_lt("2.4.3", "search")
-@skip_if_server_version_lt("8.1.224")
-def test_svs_vamana_epsilon_parameter(client):
-    client.ft().create_index(
-        (
-            VectorField(
-                "v",
-                "SVS-VAMANA",
-                {"TYPE": "FLOAT32", "DIM": 6, "DISTANCE_METRIC": "L2", "EPSILON": 0.05},
-            ),
+        posprocessing_config = HybridPostProcessingConfig()
+        posprocessing_config.load("@color", "@price", "@size")
+        posprocessing_config.apply(
+            price_discount="@price - (@price * 0.1)",
+            tax_discount="@price_discount * 0.2",
         )
-    )
+        posprocessing_config.limit(0, 3)

-    vectors = []
-    for i in range(20):
-        vec = [float(i + j) for j in range(6)]
-        vectors.append(vec)
-        client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
+        res = client.ft().hybrid_search(
+            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+        )

-    query = Query("*=>[KNN 5 @v $vec as score]").no_content()
-    query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
+        expected_results = [
+            {
+                "color": b"red",
+                "price": b"15",
+                "size": b"10",
+                "price_discount": b"13.5",
+                "tax_discount": b"2.7",
+            },
+            {
+                "color": b"red",
+                "price": b"17",
+                "size": b"12",
+                "price_discount": b"15.3",
+                "tax_discount": b"3.06",
+            },
+            {
+                "color": b"red",
+                "price": b"18",
+                "size": b"11",
+                "price_discount": b"16.2",
+                "tax_discount": b"3.24",
+            },
+        ]
+        if is_resp2_connection(client):
+            assert len(res.results) == 3
+            self.compare_list_of_dicts(res.results, expected_results)
+            assert res.warnings == []
+            assert res.execution_time > 0
+        else:
+            assert len(res["results"]) == 3
+            self.compare_list_of_dicts(
+                res["results"], self._convert_dict_values_to_str(expected_results)
+            )
+            assert res["warnings"] == []
+            assert res["execution_time"] > 0
+
+    @pytest.mark.redismod
+    @skip_if_server_version_lt("8.3.224")
+    def test_hybrid_search_query_with_load_and_filter(self, client):
+        # Create index and add data
+        self._create_hybrid_search_index(client)
+        self._add_data_for_hybrid_search(client, items_sets=10)
+
+        # set search query
+        search_query = HybridSearchQuery("@color:{red|green|black}")
+
+        vsim_query = HybridVsimQuery(
+            vector_field_name="@embedding",
+            vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
+        )

-    res = client.ft().search(query, query_params=query_params)
-    if is_resp2_connection(client):
-        assert res.total == 5
-        assert "doc0" == res.docs[0].id
-    else:
-        assert res["total_results"] == 5
-        assert "doc0" == res["results"][0]["id"]
+        hybrid_query = HybridQuery(search_query, vsim_query)

+        posprocessing_config = HybridPostProcessingConfig()
+        posprocessing_config.load("@description", "@color", "@price", "@size")
+        # the post-processing filter operates on the loaded fields, and all of
+        # them are interpreted as strings - the original field types are not
+        # preserved
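+        # (illustrative, hypothetical example: a numeric comparison such as
+        # HybridFilter('@price > 14') may not behave as expected on these
+        # string-typed loaded values, which is why the exact string match
+        # below is used)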
posprocessing_config.filter(HybridFilter('@price=="15"')) + posprocessing_config.limit(0, 3) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_all_build_parameters_combined(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 8, - "DISTANCE_METRIC": "IP", - "CONSTRUCTION_WINDOW_SIZE": 250, - "GRAPH_MAX_DEGREE": 48, - "SEARCH_WINDOW_SIZE": 15, - "EPSILON": 0.02, - }, - ), + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 ) - ) - - vectors = [] - for i in range(35): - vec = [float(i + j) for j in range(8)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - query = Query("*=>[KNN 7 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 7 - doc_ids = [doc.id for doc in res.docs] - assert len(doc_ids) == 7 - else: - assert res["total_results"] == 7 - doc_ids = [doc["id"] for doc in res["results"]] - assert len(doc_ids) == 7 - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_comprehensive_configuration(client): - client.flushdb() - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT16", - "DIM": 32, - "DISTANCE_METRIC": "COSINE", - "COMPRESSION": "LVQ8", - "CONSTRUCTION_WINDOW_SIZE": 400, - "GRAPH_MAX_DEGREE": 96, - "SEARCH_WINDOW_SIZE": 25, - "EPSILON": 0.03, - "TRAINING_THRESHOLD": 2048, - }, - ), + if is_resp2_connection(client): + assert len(res.results) == 3 + for item in res.results: + assert item["price"] == b"15" + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert len(res["results"]) == 3 + for item in res["results"]: + assert item["price"] == "15" + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_load_apply_and_params(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=5, use_random_str_data=True) + + # set search query + search_query = HybridSearchQuery("@color:{$color_criteria}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data="$vector", ) - ) - vectors = [] - for i in range(60): - vec = [float(i + j) for j in range(32)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + hybrid_query = HybridQuery(search_query, vsim_query) - query = Query("*=>[KNN 10 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@description", "@color", "@price") + posprocessing_config.apply(price_discount="@price - (@price * 0.1)") + posprocessing_config.limit(0, 3) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 10 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 10 - assert "doc0" == res["results"][0]["id"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def 
test_svs_vamana_hybrid_text_vector_search(client): - client.flushdb() - client.ft().create_index( - ( - TextField("title"), - TextField("content"), - VectorField( - "embedding", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 6, - "DISTANCE_METRIC": "COSINE", - "SEARCH_WINDOW_SIZE": 20, - }, - ), + params_substitution = { + "vector": "abcd1234abcd5678", + "color_criteria": "red", + } + + res = client.ft().hybrid_search( + query=hybrid_query, + post_processing=posprocessing_config, + params_substitution=params_substitution, + timeout=10, ) - ) - - docs = [ - { - "title": "AI Research", - "content": "machine learning algorithms", - "embedding": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - }, - { - "title": "Data Science", - "content": "statistical analysis methods", - "embedding": [2.0, 3.0, 4.0, 5.0, 6.0, 7.0], - }, - { - "title": "Deep Learning", - "content": "neural network architectures", - "embedding": [3.0, 4.0, 5.0, 6.0, 7.0, 8.0], - }, - { - "title": "Computer Vision", - "content": "image processing techniques", - "embedding": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0], - }, - ] - - for i, doc in enumerate(docs): - client.hset( - f"doc{i}", - mapping={ - "title": doc["title"], - "content": doc["content"], - "embedding": np.array(doc["embedding"], dtype=np.float32).tobytes(), + + expected_results = [ + { + "description": b"red shoes", + "color": b"red", + "price": b"15", + "price_discount": b"13.5", + }, + { + "description": b"red dress", + "color": b"red", + "price": b"17", + "price_discount": b"15.3", }, + { + "description": b"red shoes", + "color": b"red", + "price": b"16", + "price_discount": b"14.4", + }, + ] + if is_resp2_connection(client): + assert len(res.results) == 3 + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert len(res["results"]) == 3 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_limit(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), ) - # Hybrid query - text filter + vector similarity - query = "(@title:AI|@content:machine)=>[KNN 2 @embedding $vec]" - q = ( - Query(query) - .return_field("__embedding_score") - .sort_by("__embedding_score", True) - ) - res = client.ft().search( - q, - query_params={ - "vec": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).tobytes() - }, - ) + hybrid_query = HybridQuery(search_query, vsim_query) - if is_resp2_connection(client): - assert res.total >= 1 - doc_ids = [doc.id for doc in res.docs] - assert "doc0" in doc_ids - else: - assert res["total_results"] >= 1 - doc_ids = [doc["id"] for doc in res["results"]] - assert "doc0" in doc_ids - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_large_dimension_vectors(client): - client.flushdb() - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 512, - "DISTANCE_METRIC": "L2", - "CONSTRUCTION_WINDOW_SIZE": 300, - "GRAPH_MAX_DEGREE": 64, - }, - ), + posprocessing_config = HybridPostProcessingConfig() + 
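+        # LIMIT is plain pagination over the fused result set: limit(0, 3)
+        # asks for three results starting at offset 0, applied to the
+        # combined output of the SEARCH and VSIM legs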
posprocessing_config.limit(0, 3) + + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 ) - ) - vectors = [] - for i in range(10): - vec = [float(i + j) for j in range(512)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + if is_resp2_connection(client): + assert len(res.results) == 3 + assert res.warnings == [] + else: + assert len(res["results"]) == 3 + assert res["warnings"] == [] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_load_apply_and_sortby(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=1) + + # set search query + search_query = HybridSearchQuery("@color:{red|green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) - query = Query("*=>[KNN 5 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + hybrid_query = HybridQuery(search_query, vsim_query) - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 5 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 5 - assert "doc0" == res["results"][0]["id"] + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price") + posprocessing_config.apply(price_discount="@price - (@price * 0.1)") + posprocessing_config.sort_by( + SortbyField("@price_discount", asc=False), SortbyField("@color", asc=True) + ) + posprocessing_config.limit(0, 5) + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_training_threshold_behavior(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 8, - "DISTANCE_METRIC": "L2", - "COMPRESSION": "LVQ8", - "TRAINING_THRESHOLD": 1024, - }, - ), + expected_results = [ + {"color": b"orange", "price": b"18", "price_discount": b"16.2"}, + {"color": b"red", "price": b"17", "price_discount": b"15.3"}, + {"color": b"green", "price": b"16", "price_discount": b"14.4"}, + {"color": b"black", "price": b"15", "price_discount": b"13.5"}, + {"color": b"red", "price": b"15", "price_discount": b"13.5"}, + ] + if is_resp2_connection(client): + assert res.total_results >= 5 + assert len(res.results) == 5 + # the order here should match because of the sort + assert res.results == expected_results + assert res.warnings == [] + assert res.execution_time > 0 + else: + assert res["total_results"] >= 5 + assert len(res["results"]) == 5 + # the order here should match because of the sort + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + assert res["execution_time"] > 0 + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_timeout(self, client): + dim = 128 + # Create index and add data + self._create_hybrid_search_index(client, dim=dim) + self._add_data_for_hybrid_search( + client, + items_sets=5000, + dim_for_random_data=dim, + use_random_str_data=True, ) - ) - vectors = [] - for i in range(20): - vec = [float(i + j) for j in range(8)] - vectors.append(vec) - client.hset(f"doc{i}", "v", 
np.array(vec, dtype=np.float32).tobytes()) + # set search query + search_query = HybridSearchQuery("*") - if i >= 5: - query = Query("*=>[KNN 3 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - res = client.ft().search(query, query_params=query_params) + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data="abcd" * dim, + ) + vsim_query.vsim_method_params("KNN", K=1000) + vsim_query.filter( + HybridFilter( + "((@price:[15 16] @size:[10 11]) | (@price:[13 15] @size:[11 12])) @description:(shoes) -@description:(green)" + ) + ) - if is_resp2_connection(client): - assert res.total >= 1 - else: - assert res["total_results"] >= 1 + hybrid_query = HybridQuery(search_query, vsim_query) + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine("RRF", WINDOW=1000) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_different_k_values(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 6, - "DISTANCE_METRIC": "L2", - "SEARCH_WINDOW_SIZE": 15, - }, - ), + timeout = 5000 # 5 second timeout + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=timeout ) - ) - vectors = [] - for i in range(25): - vec = [float(i + j) for j in range(6)] - vectors.append(vec) - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) - - for k in [1, 3, 5, 10, 15]: - query = Query(f"*=>[KNN {k} @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} - res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + assert res.execution_time > 0 and res.execution_time < timeout + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + assert res["execution_time"] > 0 and res["execution_time"] < timeout + res = client.ft().hybrid_search(query=hybrid_query, timeout=1) # 1 ms timeout if is_resp2_connection(client): - assert res.total == k - assert "doc0" == res.docs[0].id + assert ( + b"Timeout limit was reached (VSIM)" in res.warnings + or b"Timeout limit was reached (SEARCH)" in res.warnings + ) else: - assert res["total_results"] == k - assert "doc0" == res["results"][0]["id"] + assert ( + "Timeout limit was reached (VSIM)" in res["warnings"] + or "Timeout limit was reached (SEARCH)" in res["warnings"] + ) + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_load_and_groupby(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + # set search query + search_query = HybridSearchQuery("@color:{red|green}") -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_vector_field_error(client): - # sortable tag - with pytest.raises(Exception): - client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, sortable=True),)) + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), + ) - # no_index tag - with pytest.raises(Exception): - client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, no_index=True),)) + hybrid_query = HybridQuery(search_query, vsim_query) + posprocessing_config = 
HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price", "@size", "@item_type") + posprocessing_config.limit(0, 4) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_vector_search_with_parameters(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 4, - "DISTANCE_METRIC": "L2", - "CONSTRUCTION_WINDOW_SIZE": 200, - "GRAPH_MAX_DEGREE": 64, - "SEARCH_WINDOW_SIZE": 40, - "EPSILON": 0.01, - }, - ), + posprocessing_config.group_by( + ["@price"], + reducers.count_distinct("@color").alias("colors_count"), ) - ) - # Create test vectors - vectors = [ - [1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [4.0, 5.0, 6.0, 7.0], - [5.0, 6.0, 7.0, 8.0], - ] + posprocessing_config.sort_by(SortbyField("@price", asc=True)) - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) - query = Query("*=>[KNN 3 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + expected_results = [ + {"price": b"15", "colors_count": b"2"}, + {"price": b"16", "colors_count": b"2"}, + {"price": b"17", "colors_count": b"2"}, + {"price": b"18", "colors_count": b"2"}, + ] - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 3 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] + if is_resp2_connection(client): + assert len(res.results) == 4 + assert res.results == expected_results + assert res.warnings == [] + else: + assert len(res["results"]) == 4 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.load("@color", "@price", "@size", "@item_type") + posprocessing_config.limit(0, 6) + posprocessing_config.sort_by( + SortbyField("@price", asc=True), + SortbyField("@item_type", asc=True), + ) + posprocessing_config.group_by( + ["@price", "@item_type"], + reducers.count_distinct("@color").alias("unique_colors_count"), + ) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.4.3", "search") -@skip_if_server_version_lt("8.1.224") -def test_svs_vamana_vector_search_with_parameters_leanvec(client): - client.ft().create_index( - ( - VectorField( - "v", - "SVS-VAMANA", - { - "TYPE": "FLOAT32", - "DIM": 8, - "DISTANCE_METRIC": "L2", - "COMPRESSION": "LeanVec8x8", # LeanVec compression required for REDUCE - "CONSTRUCTION_WINDOW_SIZE": 200, - "GRAPH_MAX_DEGREE": 32, - "SEARCH_WINDOW_SIZE": 15, - "EPSILON": 0.01, - "TRAINING_THRESHOLD": 1024, - "REDUCE": 4, # Half of DIM (8/2 = 4) - }, - ), + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=1000 + ) + + expected_results = [ + {"price": b"15", "item_type": b"dress", "unique_colors_count": b"1"}, + {"price": b"15", "item_type": b"shoes", "unique_colors_count": b"2"}, + {"price": b"16", "item_type": b"dress", "unique_colors_count": b"1"}, + {"price": b"16", "item_type": b"shoes", "unique_colors_count": b"2"}, + {"price": b"17", "item_type": b"dress", "unique_colors_count": b"1"}, + {"price": b"17", "item_type": b"shoes", "unique_colors_count": b"2"}, + ] + if is_resp2_connection(client): + assert 
len(res.results) == 6 + assert res.results == expected_results + assert res.warnings == [] + else: + assert len(res["results"]) == 6 + assert res["results"] == self._convert_dict_values_to_str(expected_results) + assert res["warnings"] == [] + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_cursor(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=10) + + # set search query + search_query = HybridSearchQuery("@color:{red|green}") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), ) - ) - # Create test vectors (8-dimensional to match DIM) - vectors = [ - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], - [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], - [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], - [5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], - ] + hybrid_query = HybridQuery(search_query, vsim_query) - for i, vec in enumerate(vectors): - client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + res = client.ft().hybrid_search( + query=hybrid_query, + cursor=HybridCursorQuery(count=5, max_idle=100), + timeout=10, + ) + if is_resp2_connection(client): + assert isinstance(res, HybridCursorResult) + assert res.search_cursor_id > 0 + assert res.vsim_cursor_id > 0 + search_cursor = aggregations.Cursor(res.search_cursor_id) + vsim_cursor = aggregations.Cursor(res.vsim_cursor_id) + else: + assert res["SEARCH"] > 0 + assert res["VSIM"] > 0 + search_cursor = aggregations.Cursor(res["SEARCH"]) + vsim_cursor = aggregations.Cursor(res["VSIM"]) - query = Query("*=>[KNN 3 @v $vec as score]").no_content() - query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + search_res_from_cursor = client.ft().aggregate(query=search_cursor) + if is_resp2_connection(client): + assert len(search_res_from_cursor.rows) == 5 + else: + assert len(search_res_from_cursor[0]["results"]) == 5 - res = client.ft().search(query, query_params=query_params) - if is_resp2_connection(client): - assert res.total == 3 - assert "doc0" == res.docs[0].id - else: - assert res["total_results"] == 3 - assert "doc0" == res["results"][0]["id"] + vsim_res_from_cursor = client.ft().aggregate(query=vsim_cursor) + if is_resp2_connection(client): + assert len(vsim_res_from_cursor.rows) == 5 + else: + assert len(vsim_res_from_cursor[0]["results"]) == 5 From b5fb7d85e280da88d929a301d2cb30e872364742 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 4 Nov 2025 12:54:49 +0200 Subject: [PATCH 2/7] Adding YIELD_SCORE_AS tests and clearing up support for the keyword --- redis/commands/search/hybrid_query.py | 25 ++-- tests/test_asyncio/test_search.py | 64 +++++++++ tests/test_search.py | 187 ++++++++++++++++++++++++++ 3 files changed, 261 insertions(+), 15 deletions(-) diff --git a/redis/commands/search/hybrid_query.py b/redis/commands/search/hybrid_query.py index ddcb2593e9..fc758f92d4 100644 --- a/redis/commands/search/hybrid_query.py +++ b/redis/commands/search/hybrid_query.py @@ -14,9 +14,7 @@ def __init__( self, query_string: str, scorer: Optional[str] = None, - yield_score_as: Optional[ - str - ] = None, ## TODO check if this will be supported or it should be removed! + yield_score_as: Optional[str] = None, ) -> None: """ Create a new hybrid search query object. 
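For orientation, a minimal sketch of how the cleaned-up yield_score_as builder
(added in the next hunk) is meant to be used. The query text, field name and
alias are illustrative, borrowed from the tests later in this patch, and the
sketch assumes numpy imported as np plus a client and index set up as those
tests do:

    search_query = HybridSearchQuery("shoes").yield_score_as("search_score")
    vsim_query = HybridVsimQuery(
        vector_field_name="@embedding",
        vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
    )
    res = client.ft().hybrid_search(
        query=HybridQuery(search_query, vsim_query), timeout=10
    )
    # items matched by the text leg now carry a "search_score" entry
    # alongside the usual "__key" and "__score"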
@@ -42,13 +40,18 @@ def scorer(self, scorer: str) -> "HybridSearchQuery": self._scorer = scorer return self + def yield_score_as(self, alias: str) -> "HybridSearchQuery": + """ + Yield the score as a field. + """ + self._yield_score_as = alias + return self + def get_args(self) -> List[str]: args = ["SEARCH", self._query_string] if self._scorer: args.extend(("SCORER", self._scorer)) - if ( - self._yield_score_as - ): # TODO check if this will be supported or it should be removed! + if self._yield_score_as: args.extend(("YIELD_SCORE_AS", self._yield_score_as)) return args @@ -109,7 +112,7 @@ def vsim_method_params( for key, value in kwargs.items(): vsim_method_params.extend((key, value)) self._vsim_method_params = vsim_method_params - print(self._vsim_method_params) + return self def filter(self, flt: "HybridFilter") -> "HybridVsimQuery": @@ -171,9 +174,6 @@ def __init__(self) -> None: def combine( self, method: Literal["RRF", "LINEAR"], - yield_score_as: Optional[ - str - ] = None, # TODO check if this will be supported or it should be removed! **kwargs, ) -> Self: """ @@ -181,7 +181,6 @@ def combine( Args: method: The combine method to use - RRF or LINEAR. - yield_score_as: Optional field name to yield the score as. kwargs: Additional combine parameters. """ self._combine: List[Union[str, int]] = [method] @@ -191,10 +190,6 @@ def combine( for key, value in kwargs.items(): self._combine.extend([key, value]) - if ( - yield_score_as - ): # TODO check if this will be supported or it should be removed! - self._combine.extend(["YIELD_SCORE_AS", yield_score_as]) return self def load(self, *fields: str) -> Self: diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 45c9d1207b..6bc1c33ffb 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -2605,6 +2605,70 @@ async def test_hybrid_search_query_with_vsim_range(self, decoded_r): assert res["warnings"] == [] assert res["execution_time"] > 0 + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + async def test_hybrid_search_query_with_combine_all_score_aliases(self, decoded_r): + # Create index and add data + await self._create_hybrid_search_index(decoded_r) + await self._add_data_for_hybrid_search( + decoded_r, items_sets=1, use_random_str_data=True + ) + + search_query = HybridSearchQuery("shoes") + search_query.yield_score_as("search_score") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data="abcd1234efgh5678", + vsim_search_method="KNN", + vsim_search_method_params={ + "K": 3, + "EF_RUNTIME": 1, + "YIELD_SCORE_AS": "vsim_score", + }, + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine( + "LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score" + ) + + res = await decoded_r.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + if is_resp2_connection(decoded_r): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + assert item["combined_score"] is not None + assert "__score" not in item + if item["__key"] in [b"item:0", b"item:1", b"item:4"]: + assert item["search_score"] is not None + else: + assert "search_score" not in item + if item["__key"] in [b"item:0", b"item:1", b"item:2"]: + assert item["vsim_score"] is not None + else: + assert "vsim_score" not in item + + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in 
res["results"]: + assert item["combined_score"] is not None + assert "__score" not in item + if item["__key"] in ["item:0", "item:1", "item:4"]: + assert item["search_score"] is not None + else: + assert "search_score" not in item + if item["__key"] in ["item:0", "item:1", "item:2"]: + assert item["vsim_score"] is not None + else: + assert "vsim_score" not in item + @pytest.mark.redismod @skip_if_server_version_lt("8.3.224") async def test_hybrid_search_query_with_combine(self, decoded_r): diff --git a/tests/test_search.py b/tests/test_search.py index 25d86d0f36..b2046d56e8 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -4281,6 +4281,193 @@ def test_hybrid_search_query_with_vsim_filter(self, client): assert item["price"] in ["15", "16"] assert item["size"] in ["10", "11"] + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_search_score_aliases(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True) + + search_query = HybridSearchQuery("shoes") + search_query.yield_score_as("search_score") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding", + vector_data="abcd1234efgh5678", + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + res = client.ft().hybrid_search(query=hybrid_query, timeout=10) + + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + if item["__key"] in [b"item:0", b"item:1", b"item:4"]: + assert item["search_score"] is not None + assert item["__score"] is not None + else: + assert "search_score" not in item + assert item["__score"] is not None + + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in res["results"]: + if item["__key"] in ["item:0", "item:1", "item:4"]: + assert item["search_score"] is not None + assert item["__score"] is not None + else: + assert "search_score" not in item + assert item["__score"] is not None + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_vsim_score_aliases(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True) + + search_query = HybridSearchQuery("shoes") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data="abcd1234efgh5678", + vsim_search_method="KNN", + vsim_search_method_params={ + "K": 3, + "EF_RUNTIME": 1, + "YIELD_SCORE_AS": "vsim_score", + }, + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + res = client.ft().hybrid_search(query=hybrid_query, timeout=10) + + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + if item["__key"] in [b"item:0", b"item:1", b"item:2"]: + assert item["vsim_score"] is not None + assert item["__score"] is not None + else: + assert "vsim_score" not in item + assert item["__score"] is not None + + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in res["results"]: + if item["__key"] in ["item:0", "item:1", "item:2"]: + assert item["vsim_score"] is not None + assert item["__score"] is not None + else: + assert "vsim_score" not in item + assert item["__score"] is not None + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_combine_score_aliases(self, 
client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True) + + search_query = HybridSearchQuery("shoes") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", vector_data="abcd1234efgh5678" + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine( + "LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score" + ) + + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + assert item["combined_score"] is not None + assert "__score" not in item + + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in res["results"]: + assert item["combined_score"] is not None + assert "__score" not in item + + @pytest.mark.redismod + @skip_if_server_version_lt("8.3.224") + def test_hybrid_search_query_with_combine_all_score_aliases(self, client): + # Create index and add data + self._create_hybrid_search_index(client) + self._add_data_for_hybrid_search(client, items_sets=1, use_random_str_data=True) + + search_query = HybridSearchQuery("shoes") + search_query.yield_score_as("search_score") + + vsim_query = HybridVsimQuery( + vector_field_name="@embedding-hnsw", + vector_data="abcd1234efgh5678", + vsim_search_method="KNN", + vsim_search_method_params={ + "K": 3, + "EF_RUNTIME": 1, + "YIELD_SCORE_AS": "vsim_score", + }, + ) + + hybrid_query = HybridQuery(search_query, vsim_query) + + posprocessing_config = HybridPostProcessingConfig() + posprocessing_config.combine( + "LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score" + ) + + res = client.ft().hybrid_search( + query=hybrid_query, post_processing=posprocessing_config, timeout=10 + ) + + if is_resp2_connection(client): + assert len(res.results) > 0 + assert res.warnings == [] + for item in res.results: + assert item["combined_score"] is not None + assert "__score" not in item + if item["__key"] in [b"item:0", b"item:1", b"item:4"]: + assert item["search_score"] is not None + else: + assert "search_score" not in item + if item["__key"] in [b"item:0", b"item:1", b"item:2"]: + assert item["vsim_score"] is not None + else: + assert "vsim_score" not in item + + else: + assert len(res["results"]) > 0 + assert res["warnings"] == [] + for item in res["results"]: + assert item["combined_score"] is not None + assert "__score" not in item + if item["__key"] in ["item:0", "item:1", "item:4"]: + assert item["search_score"] is not None + else: + assert "search_score" not in item + if item["__key"] in ["item:0", "item:1", "item:2"]: + assert item["vsim_score"] is not None + else: + assert "vsim_score" not in item + @pytest.mark.redismod @skip_if_server_version_lt("8.3.224") def test_hybrid_search_query_with_vsim_knn(self, client): From bd777abad1a99deb786a5cc75e567b2ef899baff Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 4 Nov 2025 14:49:50 +0200 Subject: [PATCH 3/7] Removing commented test code. 
--- tests/test_asyncio/test_search.py | 24 ------------------------ tests/test_search.py | 24 ------------------------ 2 files changed, 48 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 6bc1c33ffb..6c16cc722b 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -2736,30 +2736,6 @@ async def test_hybrid_search_query_with_combine(self, decoded_r): assert res["warnings"] == [] assert res["execution_time"] > 0 - # # LINEAR combine with no params - # posprocessing_config.combine("LINEAR", ALPHA=0.5) - # res = client.ft().hybrid_search( - # query=hybrid_query, post_processing=posprocessing_config, timeout=10 - # ) - - # expected_results = [ - # {"__key": b"item:2", "__score": b"0.166666666667"}, - # {"__key": b"item:7", "__score": b"0.166666666667"}, - # {"__key": b"item:12", "__score": b"0.166666666667"}, - # ] - # if is_resp2_connection(client): - # assert res.total_results >= 3 - # assert len(res.results) == 3 - # assert res.results == expected_results - # assert res.warnings == [] - # assert res.execution_time > 0 - # else: - # assert res["total_results"] >= 3 - # assert len(res["results"]) == 3 - # assert res["results"] == self._convert_dict_values_to_str(expected_results) - # assert res["warnings"] == [] - # assert res["execution_time"] > 0 - # combine with RRF, not all possible params provided posprocessing_config.combine("RRF", WINDOW=3) res = await decoded_r.ft().hybrid_search( diff --git a/tests/test_search.py b/tests/test_search.py index b2046d56e8..6a19ec1ae9 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -4688,30 +4688,6 @@ def test_hybrid_search_query_with_combine(self, client): assert res["warnings"] == [] assert res["execution_time"] > 0 - # # LINEAR combine with no params - # posprocessing_config.combine("LINEAR", ALPHA=0.5) - # res = client.ft().hybrid_search( - # query=hybrid_query, post_processing=posprocessing_config, timeout=10 - # ) - - # expected_results = [ - # {"__key": b"item:2", "__score": b"0.166666666667"}, - # {"__key": b"item:7", "__score": b"0.166666666667"}, - # {"__key": b"item:12", "__score": b"0.166666666667"}, - # ] - # if is_resp2_connection(client): - # assert res.total_results >= 3 - # assert len(res.results) == 3 - # assert res.results == expected_results - # assert res.warnings == [] - # assert res.execution_time > 0 - # else: - # assert res["total_results"] >= 3 - # assert len(res["results"]) == 3 - # assert res["results"] == self._convert_dict_values_to_str(expected_results) - # assert res["warnings"] == [] - # assert res["execution_time"] > 0 - # combine with RRF, not all possible params provided posprocessing_config.combine("RRF", WINDOW=3) res = client.ft().hybrid_search( From a0985b4072ed5934f26b484eefa78edd5c270c51 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 4 Nov 2025 20:03:56 +0200 Subject: [PATCH 4/7] Applying review comments - part 1 --- redis/commands/search/commands.py | 11 +++ redis/commands/search/hybrid_query.py | 79 ++++++++++------ tests/test_asyncio/test_search.py | 86 +++++++++++++----- tests/test_search.py | 126 +++++++++++++++++++------- 4 files changed, 216 insertions(+), 86 deletions(-) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index ae50a303ff..9064f8b7ee 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -5,6 +5,7 @@ from redis._parsers.helpers import pairs_to_dict from redis.client import NEVER_DECODE, Pipeline from 
redis.commands.search.hybrid_query import ( + CombineResultsMethod, HybridCursorQuery, HybridPostProcessingConfig, HybridQuery, @@ -562,6 +563,7 @@ def search( def hybrid_search( self, query: HybridQuery, + combine_method: Optional[CombineResultsMethod] = None, post_processing: Optional[HybridPostProcessingConfig] = None, params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, timeout: Optional[int] = None, @@ -573,6 +575,8 @@ def hybrid_search( Args: - **query**: HybridQuery object Contains the text and vector queries + - **combine_method**: CombineResultsMethod object + Contains the combine method and parameters - **post_processing**: HybridPostProcessingConfig object Contains the post processing configuration - **params_substitution**: Dict[str, Union[str, int, float, bytes]] @@ -587,6 +591,8 @@ def hybrid_search( options = {} pieces = [HYBRID_CMD, index] pieces.extend(query.get_args()) + if combine_method: + pieces.extend(combine_method.get_args()) if post_processing: pieces.extend(post_processing.build_args()) if params_substitution: @@ -1050,6 +1056,7 @@ async def search( async def hybrid_search( self, query: HybridQuery, + combine_method: Optional[CombineResultsMethod] = None, post_processing: Optional[HybridPostProcessingConfig] = None, params_substitution: Optional[Dict[str, Union[str, int, float, bytes]]] = None, timeout: Optional[int] = None, @@ -1061,6 +1068,8 @@ async def hybrid_search( Args: - **query**: HybridQuery object Contains the text and vector queries + - **combine_method**: CombineResultsMethod object + Contains the combine method and parameters - **post_processing**: HybridPostProcessingConfig object Contains the post processing configuration - **params_substitution**: Dict[str, Union[str, int, float, bytes]] @@ -1075,6 +1084,8 @@ async def hybrid_search( options = {} pieces = [HYBRID_CMD, index] pieces.extend(query.get_args()) + if combine_method: + pieces.extend(combine_method.get_args()) if post_processing: pieces.extend(post_processing.build_args()) if params_substitution: diff --git a/redis/commands/search/hybrid_query.py b/redis/commands/search/hybrid_query.py index fc758f92d4..6b172a05c9 100644 --- a/redis/commands/search/hybrid_query.py +++ b/redis/commands/search/hybrid_query.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union try: @@ -35,7 +36,7 @@ def query_string(self) -> str: def scorer(self, scorer: str) -> "HybridSearchQuery": """ Scoring algorithm for text search query. - Allowed values are "TFIDF" or "BM25" + Allowed values are "TFIDF", "DISMAX", "DOCSCORE", "BM25", etc. """ self._scorer = scorer return self @@ -56,12 +57,17 @@ def get_args(self) -> List[str]: return args +class VectorSearchMethods(Enum): + KNN = "KNN" + RANGE = "RANGE" + + class HybridVsimQuery: def __init__( self, vector_field_name: str, vector_data: Union[bytes, str], - vsim_search_method: Optional[str] = None, + vsim_search_method: Optional[VectorSearchMethods] = None, vsim_search_method_params: Optional[Dict[str, Any]] = None, filter: Optional["Filter"] = None, ) -> None: @@ -70,11 +76,22 @@ def __init__( Args: vector_field_name: Vector field name. + vector_data: Vector data for the search. + vsim_search_method: Search method that will be used for the vsim search. - Allowed values are "KNN" or "RANGE". + vsim_search_method_params: Search method parameters. Use the param names - for keys and the values for the values. Example: {"K": 10, "EF_RUNTIME": 100}. + for keys and the values for the values. 
+                Example for KNN: {"K": 10, "EF_RUNTIME": 100}
+                where K is mandatory and defines the number of results
+                and EF_RUNTIME is optional and defines the exploration factor.
+                Example for RANGE: {"RADIUS": 10, "EPSILON": 0.1}
+                where RADIUS is mandatory and defines the radius of the search
+                and EPSILON is optional and defines the accuracy of the search.
+                For both KNN and RANGE, the following parameter is optional:
+                    YIELD_SCORE_AS: The name of the field to yield the calculated score as.
+
             filter: If defined, a filter will be applied on the vsim query results.
         """
         self._vector_field = vector_field_name
@@ -95,7 +112,7 @@ def vector_data(self) -> Union[bytes, str]:

     def vsim_method_params(
         self,
-        method: str,
+        method: VectorSearchMethods,
         **kwargs,
     ) -> "HybridVsimQuery":
         """
@@ -106,7 +123,7 @@ def vsim_method_params(
             kwargs: Search method parameters. Use the param names for keys and
                 the values for the values. Example: {"K": 10, "EF_RUNTIME": 100}.
         """
-        vsim_method_params: List[Union[str, int]] = [method]
+        vsim_method_params: List[Union[str, int]] = [method.value]
         if kwargs:
             vsim_method_params.append(len(kwargs.items()) * 2)
             for key, value in kwargs.items():
@@ -158,12 +175,37 @@ def get_args(self) -> List[str]:
         return args


+class CombinationMethods(Enum):
+    RRF = "RRF"
+    LINEAR = "LINEAR"
+
+
+class CombineResultsMethod:
+    def __init__(self, method: CombinationMethods, **kwargs) -> None:
+        """
+        Create a new combine results method object.
+
+        Args:
+            method: The combine method to use - RRF or LINEAR.
+            kwargs: Additional combine parameters.
+        """
+        self._method = method
+        self._kwargs = kwargs
+
+    def get_args(self) -> List[Union[str, int]]:
+        args: List[Union[str, int]] = ["COMBINE", self._method.value]
+        if self._kwargs:
+            args.append(len(self._kwargs.items()) * 2)
+            for key, value in self._kwargs.items():
+                args.extend((key, value))
+        return args
+
+
 class HybridPostProcessingConfig:
     def __init__(self) -> None:
         """
         Create a new hybrid post processing configuration object.
         """
-        self._combine = []
         self._load_fields = []
         self._groupby = []
         self._apply = []
@@ -171,27 +213,6 @@ def __init__(self) -> None:
         self._filter = None
         self._limit = None

-    def combine(
-        self,
-        method: Literal["RRF", "LINEAR"],
-        **kwargs,
-    ) -> Self:
-        """
-        Add combine parameters to the query.
-
-        Args:
-            method: The combine method to use - RRF or LINEAR.
-            kwargs: Additional combine parameters.
-        """
-        self._combine: List[Union[str, int]] = [method]
-
-        self._combine.append(len(kwargs) * 2)
-
-        for key, value in kwargs.items():
-            self._combine.extend([key, value])
-
-        return self
-
     def load(self, *fields: str) -> Self:
         """
         Add load parameters to the query.
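A short sketch of how the new combine class serializes, following the
get_args implementation above (the count token is the number of key/value
tokens that follow; the argument values are illustrative):

    combine = CombineResultsMethod(CombinationMethods.RRF, WINDOW=3, CONSTANT=0.5)
    combine.get_args()
    # -> ["COMBINE", "RRF", 4, "WINDOW", 3, "CONSTANT", 0.5]

hybrid_search() then splices these tokens into the FT.HYBRID command right
after the SEARCH and VSIM clauses, as the commands.py hunks above show.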
@@ -267,8 +288,6 @@ def limit(self, offset: int, num: int) -> Self: def build_args(self) -> List[str]: args = [] - if self._combine: - args.extend(("COMBINE", *self._combine)) if self._load_fields: fields_str = " ".join(self._load_fields) fields = fields_str.split(" ") diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 6c16cc722b..6b83fc05fe 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -12,12 +12,15 @@ import redis.asyncio as redis import redis.commands.search.aggregation as aggregations from redis.commands.search.hybrid_query import ( + CombinationMethods, + CombineResultsMethod, HybridCursorQuery, HybridFilter, HybridPostProcessingConfig, HybridQuery, HybridSearchQuery, HybridVsimQuery, + VectorSearchMethods, ) from redis.commands.search.hybrid_result import HybridCursorResult import redis.commands.search.reducers as reducers @@ -2320,7 +2323,6 @@ async def test_basic_hybrid_search(self, decoded_r): assert all(isinstance(res["results"][i]["__key"], str) for i in range(10)) @pytest.mark.redismod - # @pytest.mark.timeout(900) @skip_if_server_version_lt("8.3.224") async def test_hybrid_search_query_with_scorer(self, decoded_r): # Create index and add data @@ -2338,15 +2340,21 @@ async def test_hybrid_search_query_with_scorer(self, decoded_r): hybrid_query = HybridQuery(search_query, vsim_query) + combine_method = CombineResultsMethod( + CombinationMethods.LINEAR, ALPHA=1, BETA=0 + ) + posprocessing_config = HybridPostProcessingConfig() - posprocessing_config.combine("LINEAR", ALPHA=1, BETA=0) posprocessing_config.load( "@description", "@color", "@price", "@size", "@__score", "@__item" ) posprocessing_config.limit(0, 2) res = await decoded_r.ft().hybrid_search( - query=hybrid_query, post_processing=posprocessing_config, timeout=10 + query=hybrid_query, + combine_method=combine_method, + post_processing=posprocessing_config, + timeout=10, ) expected_results_tfidf = [ @@ -2381,7 +2389,10 @@ async def test_hybrid_search_query_with_scorer(self, decoded_r): search_query.scorer("BM25") res = await decoded_r.ft().hybrid_search( - query=hybrid_query, post_processing=posprocessing_config, timeout=10 + query=hybrid_query, + combine_method=combine_method, + post_processing=posprocessing_config, + timeout=10, ) expected_results_bm25 = [ { @@ -2466,7 +2477,7 @@ async def test_hybrid_search_query_with_vsim_knn(self, decoded_r): vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), ) - vsim_query.vsim_method_params("KNN", K=3) + vsim_query.vsim_method_params(VectorSearchMethods.KNN, K=3) hybrid_query = HybridQuery(search_query, vsim_query) @@ -2497,7 +2508,9 @@ async def test_hybrid_search_query_with_vsim_knn(self, decoded_r): vector_field_name="@embedding-hnsw", vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(), ) - vsim_query_with_hnsw.vsim_method_params("KNN", K=3, EF_RUNTIME=1) + vsim_query_with_hnsw.vsim_method_params( + VectorSearchMethods.KNN, K=3, EF_RUNTIME=1 + ) hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw) res2 = await decoded_r.ft().hybrid_search( @@ -2540,7 +2553,7 @@ async def test_hybrid_search_query_with_vsim_range(self, decoded_r): vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(), ) - vsim_query.vsim_method_params("RANGE", RADIUS=2) + vsim_query.vsim_method_params(VectorSearchMethods.RANGE, RADIUS=2) hybrid_query = HybridQuery(search_query, vsim_query) @@ -2574,7 +2587,9 @@ async def test_hybrid_search_query_with_vsim_range(self, decoded_r): 
            vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
        )
 
-        vsim_query_with_hnsw.vsim_method_params("RANGE", RADIUS=2, EPSILON=0.5)
+        vsim_query_with_hnsw.vsim_method_params(
+            VectorSearchMethods.RANGE, RADIUS=2, EPSILON=0.5
+        )
 
         hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw)
 
@@ -2620,7 +2635,7 @@ async def test_hybrid_search_query_with_combine_all_score_aliases(self, decoded_
         vsim_query = HybridVsimQuery(
             vector_field_name="@embedding-hnsw",
             vector_data="abcd1234efgh5678",
-            vsim_search_method="KNN",
+            vsim_search_method=VectorSearchMethods.KNN,
             vsim_search_method_params={
                 "K": 3,
                 "EF_RUNTIME": 1,
@@ -2630,13 +2645,15 @@
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine(
-            "LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score"
+        combine_method = CombineResultsMethod(
+            CombinationMethods.LINEAR,
+            ALPHA=0.5,
+            BETA=0.5,
+            YIELD_SCORE_AS="combined_score",
         )
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, combine_method=combine_method, timeout=10
         )
 
         if is_resp2_connection(decoded_r):
@@ -2686,12 +2703,18 @@ async def test_hybrid_search_query_with_combine(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
+        combine_method_linear = CombineResultsMethod(
+            CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
+        )
+
         posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5)
         posprocessing_config.limit(0, 3)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method_linear,
+            post_processing=posprocessing_config,
+            timeout=10,
         )
 
         expected_results = [
@@ -2713,9 +2736,14 @@
         assert res["execution_time"] > 0
 
         # combine with RRF and WINDOW + CONSTANT
-        posprocessing_config.combine("RRF", WINDOW=3, CONSTANT=0.5)
+        combine_method_rrf = CombineResultsMethod(
+            CombinationMethods.RRF, WINDOW=3, CONSTANT=0.5
+        )
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method_rrf,
+            post_processing=posprocessing_config,
+            timeout=10,
         )
 
         expected_results = [
@@ -2737,9 +2765,12 @@
         assert res["execution_time"] > 0
 
         # combine with RRF, not all possible params provided
-        posprocessing_config.combine("RRF", WINDOW=3)
+        combine_method_rrf_2 = CombineResultsMethod(CombinationMethods.RRF, WINDOW=3)
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method_rrf_2,
+            post_processing=posprocessing_config,
+            timeout=10,
        )
 
         expected_results = [
@@ -2777,15 +2808,21 @@ async def test_hybrid_search_query_with_load(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
+        combine_method = CombineResultsMethod(
+            CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
+        )
+
         posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5)
         posprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__key AS item_key"
         )
         posprocessing_config.limit(0, 1)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method,
+            post_processing=posprocessing_config,
+            timeout=10,
        )
 
         expected_results = [
@@ -3088,7 +3125,7 @@ async def test_hybrid_search_query_with_timeout(self, decoded_r):
             vector_field_name="@embedding-hnsw",
             vector_data="abcd" * dim,
         )
-        vsim_query.vsim_method_params("KNN", K=1000)
+        vsim_query.vsim_method_params(VectorSearchMethods.KNN, K=1000)
         vsim_query.filter(
             HybridFilter(
                 "((@price:[15 16] @size:[10 11]) | (@price:[13 15] @size:[11 12])) @description:(shoes) -@description:(green)"
             )
         )
@@ -3097,12 +3134,11 @@
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("RRF", WINDOW=1000)
+        combine_method = CombineResultsMethod(CombinationMethods.RRF, WINDOW=1000)
 
         timeout = 5000  # 5 second timeout
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=timeout
+            query=hybrid_query, combine_method=combine_method, timeout=timeout
        )
 
         if is_resp2_connection(decoded_r):
diff --git a/tests/test_search.py b/tests/test_search.py
index 6a19ec1ae9..16e7dd8f4b 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -12,12 +12,15 @@
 import redis.commands.search.aggregation as aggregations
 from redis.commands.search.hybrid_query import (
+    CombinationMethods,
+    CombineResultsMethod,
     HybridCursorQuery,
     HybridFilter,
     HybridPostProcessingConfig,
     HybridQuery,
     HybridSearchQuery,
     HybridVsimQuery,
+    VectorSearchMethods,
 )
 from redis.commands.search.hybrid_result import HybridCursorResult
 import redis.commands.search.reducers as reducers
@@ -4092,6 +4095,32 @@ def compare_list_of_dicts(actual, expected):
             f"All expected:{expected}"
         )
 
+    @pytest.mark.redismod
+    @skip_if_server_version_lt("8.3.224")
+    def test_review_feedback_hybrid_search(self, client):
+        # Create index and add data
+        self._create_hybrid_search_index(client)
+        self._add_data_for_hybrid_search(client, items_sets=5)
+
+        # set search query
+        search_query = HybridSearchQuery("@color:{red} @color:{green}")
+        search_query.scorer("TFIDF")
+
+        vsim_query = HybridVsimQuery(
+            vector_field_name="@embedding",
+            vector_data=np.array([-100, -200, -200, -300], dtype=np.float32).tobytes(),
+        )
+
+        hybrid_query = HybridQuery(search_query, vsim_query)
+
+        res = client.ft().hybrid_search(query=hybrid_query)
+        if is_resp2_connection(client):
+            assert len(res.results) > 0
+            assert res.warnings == []
+        else:
+            assert len(res["results"]) > 0
+            assert res["warnings"] == []
+
     @pytest.mark.redismod
     @skip_if_server_version_lt("8.3.224")
     def test_basic_hybrid_search(self, client):
@@ -4145,15 +4174,21 @@ def test_hybrid_search_query_with_scorer(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
+        combine_config = CombineResultsMethod(
+            CombinationMethods.LINEAR, ALPHA=1, BETA=0
+        )
+
         posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("LINEAR", ALPHA=1, BETA=0)
         posprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__score", "@__item"
         )
         posprocessing_config.limit(0, 2)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_config,
+            post_processing=posprocessing_config,
+            timeout=10,
         )
 
         expected_results_tfidf = [
@@ -4188,7 +4223,10 @@ def test_hybrid_search_query_with_scorer(self, client):
         search_query.scorer("BM25")
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_config,
+            post_processing=posprocessing_config,
+            timeout=10,
         )
         expected_results_bm25 = [
             {
@@ -4231,7 +4269,7 @@ def test_hybrid_search_query_with_vsim_method_defined_query_init(self, client):
         vsim_query = HybridVsimQuery(
             vector_field_name="@embedding-hnsw",
             vector_data="abcd1234efgh5678",
-            vsim_search_method="KNN",
+            vsim_search_method=VectorSearchMethods.KNN,
             vsim_search_method_params={"K": 3, "EF_RUNTIME": 1},
         )
 
@@ -4334,7 +4372,7 @@ def test_hybrid_search_query_with_vsim_score_aliases(self, client):
         vsim_query = HybridVsimQuery(
             vector_field_name="@embedding-hnsw",
             vector_data="abcd1234efgh5678",
-            vsim_search_method="KNN",
+            vsim_search_method=VectorSearchMethods.KNN,
             vsim_search_method_params={
                 "K": 3,
                 "EF_RUNTIME": 1,
@@ -4382,14 +4420,15 @@ def test_hybrid_search_query_with_combine_score_aliases(self, client):
         )
 
         hybrid_query = HybridQuery(search_query, vsim_query)
-
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine(
-            "LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score"
+        combine_method = CombineResultsMethod(
+            CombinationMethods.LINEAR,
+            ALPHA=0.5,
+            BETA=0.5,
+            YIELD_SCORE_AS="combined_score",
         )
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, combine_method=combine_method, timeout=10
         )
 
         if is_resp2_connection(client):
@@ -4419,7 +4458,7 @@ def test_hybrid_search_query_with_combine_all_score_aliases(self, client):
         vsim_query = HybridVsimQuery(
             vector_field_name="@embedding-hnsw",
             vector_data="abcd1234efgh5678",
-            vsim_search_method="KNN",
+            vsim_search_method=VectorSearchMethods.KNN,
             vsim_search_method_params={
                 "K": 3,
                 "EF_RUNTIME": 1,
@@ -4429,13 +4468,15 @@
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine(
-            "LINEAR", ALPHA=0.5, BETA=0.5, YIELD_SCORE_AS="combined_score"
+        combine_method = CombineResultsMethod(
+            CombinationMethods.LINEAR,
+            ALPHA=0.5,
+            BETA=0.5,
+            YIELD_SCORE_AS="combined_score",
         )
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, combine_method=combine_method, timeout=10
         )
 
         if is_resp2_connection(client):
@@ -4484,7 +4525,7 @@ def test_hybrid_search_query_with_vsim_knn(self, client):
             vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(),
         )
 
-        vsim_query.vsim_method_params("KNN", K=3)
+        vsim_query.vsim_method_params(VectorSearchMethods.KNN, K=3)
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
@@ -4515,7 +4556,9 @@ def test_hybrid_search_query_with_vsim_knn(self, client):
             vector_field_name="@embedding-hnsw",
             vector_data=np.array([1, 2, 2, 3], dtype=np.float32).tobytes(),
         )
-        vsim_query_with_hnsw.vsim_method_params("KNN", K=3, EF_RUNTIME=1)
+        vsim_query_with_hnsw.vsim_method_params(
+            VectorSearchMethods.KNN, K=3, EF_RUNTIME=1
+        )
         hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw)
 
         res2 = client.ft().hybrid_search(query=hybrid_query_with_hnsw, timeout=10)
@@ -4556,7 +4599,7 @@ def test_hybrid_search_query_with_vsim_range(self, client):
             vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
         )
 
-        vsim_query.vsim_method_params("RANGE", RADIUS=2)
+        vsim_query.vsim_method_params(VectorSearchMethods.RANGE, RADIUS=2)
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
@@ -4590,7 +4633,9 @@ def test_hybrid_search_query_with_vsim_range(self, client):
             vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
         )
 
-        vsim_query_with_hnsw.vsim_method_params("RANGE", RADIUS=2, EPSILON=0.5)
+        vsim_query_with_hnsw.vsim_method_params(
+            VectorSearchMethods.RANGE, RADIUS=2, EPSILON=0.5
+        )
 
         hybrid_query_with_hnsw = HybridQuery(search_query, vsim_query_with_hnsw)
 
@@ -4638,12 +4683,18 @@ def test_hybrid_search_query_with_combine(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
+        combine_method_linear = CombineResultsMethod(
+            CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
+        )
+
         posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5)
         posprocessing_config.limit(0, 3)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method_linear,
+            post_processing=posprocessing_config,
+            timeout=10,
         )
 
         expected_results = [
@@ -4665,9 +4716,14 @@
         assert res["execution_time"] > 0
 
         # combine with RRF and WINDOW + CONSTANT
-        posprocessing_config.combine("RRF", WINDOW=3, CONSTANT=0.5)
+        combine_method_rrf = CombineResultsMethod(
+            CombinationMethods.RRF, WINDOW=3, CONSTANT=0.5
+        )
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method_rrf,
+            post_processing=posprocessing_config,
+            timeout=10,
        )
 
         expected_results = [
@@ -4689,9 +4745,12 @@
         assert res["execution_time"] > 0
 
         # combine with RRF, not all possible params provided
-        posprocessing_config.combine("RRF", WINDOW=3)
+        combine_method_rrf_2 = CombineResultsMethod(CombinationMethods.RRF, WINDOW=3)
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method_rrf_2,
+            post_processing=posprocessing_config,
+            timeout=10,
         )
 
         expected_results = [
@@ -4729,15 +4788,21 @@ def test_hybrid_search_query_with_load(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
+        combine_method = CombineResultsMethod(
+            CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
+        )
+
         posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("LINEAR", ALPHA=0.5, BETA=0.5)
         posprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__key AS item_key"
         )
         posprocessing_config.limit(0, 1)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query,
+            combine_method=combine_method,
+            post_processing=posprocessing_config,
+            timeout=10,
        )
 
         expected_results = [
@@ -5039,7 +5104,7 @@ def test_hybrid_search_query_with_timeout(self, client):
             vector_field_name="@embedding-hnsw",
             vector_data="abcd" * dim,
         )
-        vsim_query.vsim_method_params("KNN", K=1000)
+        vsim_query.vsim_method_params(VectorSearchMethods.KNN, K=1000)
         vsim_query.filter(
             HybridFilter(
                 "((@price:[15 16] @size:[10 11]) | (@price:[13 15] @size:[11 12])) @description:(shoes) -@description:(green)"
             )
         )
@@ -5048,12 +5113,11 @@
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.combine("RRF", WINDOW=1000)
+        combine_method = CombineResultsMethod(CombinationMethods.RRF, WINDOW=1000)
 
         timeout = 5000  # 5 second timeout
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=timeout
+            query=hybrid_query, combine_method=combine_method, timeout=timeout
        )
 
         if is_resp2_connection(client):

From 12b153e0799a2bd1d98458d65720ed2e92a6d222 Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Tue, 4 Nov 2025 20:07:29 +0200
Subject: [PATCH 5/7] Fixing linters

---
 redis/commands/search/hybrid_query.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/redis/commands/search/hybrid_query.py b/redis/commands/search/hybrid_query.py
index 6b172a05c9..21405e0e5e 100644
--- a/redis/commands/search/hybrid_query.py
+++ b/redis/commands/search/hybrid_query.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 try:
     from typing import Self  # Py 3.11+

From e12304efe2b6cc67d88a2e3c50ec2ea78449052d Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Wed, 5 Nov 2025 08:43:08 +0200
Subject: [PATCH 6/7] vset tests cause crashes of the test servers in the
 pipeline - changing the problematic tests to use less data (sync and async
 tests)

---
 tests/test_asyncio/test_vsets.py | 12 ++++++------
 tests/test_vsets.py              | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/tests/test_asyncio/test_vsets.py b/tests/test_asyncio/test_vsets.py
index 447e474e9c..fa37a517c7 100644
--- a/tests/test_asyncio/test_vsets.py
+++ b/tests/test_asyncio/test_vsets.py
@@ -451,23 +451,23 @@ async def test_vsim_with_filter(d_client):
 
 @skip_if_server_version_lt("7.9.0")
 async def test_vsim_truth_no_thread_enabled(d_client):
-    elements_count = 5000
+    elements_count = 1000
     vector_dim = 50
     for i in range(1, elements_count + 1):
-        float_array = [i for _ in range(vector_dim)]
+        float_array = [i * vector_dim for _ in range(vector_dim)]
         await d_client.vset().vadd("myset", float_array, f"elem_{i}")
 
     await d_client.vset().vadd("myset", [-22 for _ in range(vector_dim)], "elem_man_2")
 
     sim_without_truth = await d_client.vset().vsim(
-        "myset", input="elem_man_2", with_scores=True
+        "myset", input="elem_man_2", count=30, with_scores=True
     )
     sim_truth = await d_client.vset().vsim(
-        "myset", input="elem_man_2", with_scores=True, truth=True
+        "myset", input="elem_man_2", count=30, with_scores=True, truth=True
     )
 
-    assert len(sim_without_truth) == 10
-    assert len(sim_truth) == 10
+    assert len(sim_without_truth) == 30
+    assert len(sim_truth) == 30
 
     assert isinstance(sim_without_truth, dict)
     assert isinstance(sim_truth, dict)
diff --git a/tests/test_vsets.py b/tests/test_vsets.py
index 6376eeb898..bb1ed269a0 100644
--- a/tests/test_vsets.py
+++ b/tests/test_vsets.py
@@ -453,23 +453,23 @@ def test_vsim_truth_no_thread_enabled(d_client):
-    elements_count = 5000
+    elements_count = 1000
     vector_dim = 50
     for i in range(1, elements_count + 1):
-        float_array = [i for _ in range(vector_dim)]
+        float_array = [i * vector_dim for _ in range(vector_dim)]
         d_client.vset().vadd("myset", float_array, f"elem_{i}")
 
     d_client.vset().vadd("myset", [-22 for _ in range(vector_dim)], "elem_man_2")
 
     sim_without_truth = d_client.vset().vsim(
-        "myset", input="elem_man_2", with_scores=True
+        "myset", input="elem_man_2", with_scores=True, count=30
     )
     sim_truth = d_client.vset().vsim(
-        "myset", input="elem_man_2", with_scores=True, truth=True
+        "myset", input="elem_man_2", with_scores=True, count=30, truth=True
     )
 
-    assert len(sim_without_truth) == 10
-    assert len(sim_truth) == 10
+    assert len(sim_without_truth) == 30
+    assert len(sim_truth) == 30
 
     assert isinstance(sim_without_truth, dict)
     assert isinstance(sim_truth, dict)

From d1cea545124ad861f84222b75873c8ae45d20c55 Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Wed, 5 Nov 2025 10:35:18 +0200
Subject: [PATCH 7/7] Update list concatenation to use extend. Fix spelling
 error in tests. Extend a test to use two reducers.

---
 redis/commands/search/hybrid_query.py |   6 +-
 tests/test_asyncio/test_search.py     | 118 ++++++++++----------
 tests/test_search.py                  | 149 +++++++++++++++-----------
 3 files changed, 147 insertions(+), 126 deletions(-)

diff --git a/redis/commands/search/hybrid_query.py b/redis/commands/search/hybrid_query.py
index 21405e0e5e..30cce8e06f 100644
--- a/redis/commands/search/hybrid_query.py
+++ b/redis/commands/search/hybrid_query.py
@@ -235,10 +235,10 @@ def group_by(self, fields: List[str], *reducers: Reducer) -> Self:
         ret = ["GROUPBY", str(len(fields)), *fields]
 
         for reducer in reducers:
-            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
+            ret.extend(("REDUCE", reducer.NAME, str(len(reducer.args))))
             ret.extend(reducer.args)
             if reducer._alias is not None:
-                ret += ["AS", reducer._alias]
+                ret.extend(("AS", reducer._alias))
 
         self._groupby.extend(ret)
 
         return self
@@ -255,7 +255,7 @@ def apply(self, **kwexpr) -> Self:
         for alias, expr in kwexpr.items():
             ret = ["APPLY", expr]
             if alias is not None:
-                ret += ["AS", alias]
+                ret.extend(("AS", alias))
             self._apply.extend(ret)
 
         return self
diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py
index 6b83fc05fe..f4b26ae82c 100644
--- a/tests/test_asyncio/test_search.py
+++ b/tests/test_asyncio/test_search.py
@@ -2344,16 +2344,16 @@ async def test_hybrid_search_query_with_scorer(self, decoded_r):
             CombinationMethods.LINEAR, ALPHA=1, BETA=0
         )
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__score", "@__item"
         )
-        posprocessing_config.limit(0, 2)
+        postprocessing_config.limit(0, 2)
 
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -2391,7 +2391,7 @@ async def test_hybrid_search_query_with_scorer(self, decoded_r):
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
         expected_results_bm25 = [
@@ -2442,11 +2442,11 @@ async def test_hybrid_search_query_with_vsim_filter(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@price", "@size")
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@price", "@size")
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
         if is_resp2_connection(decoded_r):
             assert len(res.results) > 0
@@ -2481,10 +2481,10 @@ async def test_hybrid_search_query_with_vsim_knn(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config = HybridPostProcessingConfig()
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
         expected_results = [
             {"__key": b"item:2", "__score": b"0.016393442623"},
@@ -2557,11 +2557,11 @@ async def test_hybrid_search_query_with_vsim_range(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.limit(0, 3)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -2595,7 +2595,7 @@
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query_with_hnsw,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -2707,13 +2707,13 @@
             CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
         )
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.limit(0, 3)
 
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method_linear,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -2742,7 +2742,7 @@
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method_rrf,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -2769,7 +2769,7 @@
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method_rrf_2,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -2812,16 +2812,16 @@
             CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
         )
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__key AS item_key"
         )
-        posprocessing_config.limit(0, 1)
+        postprocessing_config.limit(0, 1)
 
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -2866,16 +2866,16 @@ async def test_hybrid_search_query_with_load_and_apply(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price", "@size")
-        posprocessing_config.apply(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price", "@size")
+        postprocessing_config.apply(
             price_discount="@price - (@price * 0.1)",
             tax_discount="@price_discount * 0.2",
         )
-        posprocessing_config.limit(0, 3)
+        postprocessing_config.limit(0, 3)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -2931,16 +2931,16 @@ async def test_hybrid_search_query_with_load_and_filter(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@description", "@color", "@price", "@size")
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@description", "@color", "@price", "@size")
         # for the postprocessing filter we need to filter on the loaded fields
         # expecting all of them to be interpreted as strings - the initial field types
         # are not preserved
-        posprocessing_config.filter(HybridFilter('@price=="15"'))
-        posprocessing_config.limit(0, 3)
+        postprocessing_config.filter(HybridFilter('@price=="15"'))
+        postprocessing_config.limit(0, 3)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         if is_resp2_connection(decoded_r):
@@ -2975,10 +2975,10 @@ async def test_hybrid_search_query_with_load_apply_and_params(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@description", "@color", "@price")
-        posprocessing_config.apply(price_discount="@price - (@price * 0.1)")
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@description", "@color", "@price")
+        postprocessing_config.apply(price_discount="@price - (@price * 0.1)")
+        postprocessing_config.limit(0, 3)
 
         params_substitution = {
             "vector": "abcd1234abcd5678",
@@ -2987,7 +2987,7 @@
         res = await decoded_r.ft().hybrid_search(
             query=hybrid_query,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             params_substitution=params_substitution,
             timeout=10,
         )
@@ -3040,11 +3040,11 @@ async def test_hybrid_search_query_with_limit(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.limit(0, 3)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         if is_resp2_connection(decoded_r):
@@ -3071,16 +3071,16 @@ async def test_hybrid_search_query_with_load_apply_and_sortby(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price")
-        posprocessing_config.apply(price_discount="@price - (@price * 0.1)")
-        posprocessing_config.sort_by(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price")
+        postprocessing_config.apply(price_discount="@price - (@price * 0.1)")
+        postprocessing_config.sort_by(
             SortbyField("@price_discount", asc=False), SortbyField("@color", asc=True)
         )
-        posprocessing_config.limit(0, 5)
+        postprocessing_config.limit(0, 5)
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -3181,19 +3181,19 @@ async def test_hybrid_search_query_with_load_and_groupby(self, decoded_r):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price", "@size", "@item_type")
-        posprocessing_config.limit(0, 4)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price", "@size", "@item_type")
+        postprocessing_config.limit(0, 4)
 
-        posprocessing_config.group_by(
+        postprocessing_config.group_by(
             ["@price"],
             reducers.count_distinct("@color").alias("colors_count"),
         )
-        posprocessing_config.sort_by(SortbyField("@price", asc=True))
+        postprocessing_config.sort_by(SortbyField("@price", asc=True))
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -3212,21 +3212,21 @@
         assert res["results"] == self._convert_dict_values_to_str(expected_results)
         assert res["warnings"] == []
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price", "@size", "@item_type")
-        posprocessing_config.limit(0, 6)
-        posprocessing_config.sort_by(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price", "@size", "@item_type")
+        postprocessing_config.limit(0, 6)
+        postprocessing_config.sort_by(
             SortbyField("@price", asc=True),
             SortbyField("@item_type", asc=True),
         )
-        posprocessing_config.group_by(
+        postprocessing_config.group_by(
             ["@price", "@item_type"],
             reducers.count_distinct("@color").alias("unique_colors_count"),
         )
 
         res = await decoded_r.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=1000
+            query=hybrid_query, post_processing=postprocessing_config, timeout=1000
         )
 
         expected_results = [
diff --git a/tests/test_search.py b/tests/test_search.py
index 16e7dd8f4b..c5f5679f90 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -4178,16 +4178,16 @@ def test_hybrid_search_query_with_scorer(self, client):
             CombinationMethods.LINEAR, ALPHA=1, BETA=0
         )
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__score", "@__item"
         )
-        posprocessing_config.limit(0, 2)
+        postprocessing_config.limit(0, 2)
 
         res = client.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_config,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -4225,7 +4225,7 @@
         res = client.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_config,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
         expected_results_bm25 = [
@@ -4300,11 +4300,11 @@ def test_hybrid_search_query_with_vsim_filter(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@price", "@size")
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@price", "@size")
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
         if is_resp2_connection(client):
             assert len(res.results) > 0
@@ -4529,10 +4529,10 @@ def test_hybrid_search_query_with_vsim_knn(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config = HybridPostProcessingConfig()
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
         expected_results = [
             {"__key": b"item:2", "__score": b"0.016393442623"},
@@ -4603,11 +4603,11 @@ def test_hybrid_search_query_with_vsim_range(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.limit(0, 3)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -4641,7 +4641,7 @@
         res = client.ft().hybrid_search(
             query=hybrid_query_with_hnsw,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -4687,13 +4687,13 @@ def test_hybrid_search_query_with_combine(self, client):
             CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
         )
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.limit(0, 3)
 
         res = client.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method_linear,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -4722,7 +4722,7 @@
         res = client.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method_rrf,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -4749,7 +4749,7 @@
         res = client.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method_rrf_2,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -4792,16 +4792,16 @@ def test_hybrid_search_query_with_load(self, client):
             CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5
         )
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load(
             "@description", "@color", "@price", "@size", "@__key AS item_key"
         )
-        posprocessing_config.limit(0, 1)
+        postprocessing_config.limit(0, 1)
 
         res = client.ft().hybrid_search(
             query=hybrid_query,
             combine_method=combine_method,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             timeout=10,
         )
 
@@ -4847,16 +4847,16 @@ def test_hybrid_search_query_with_load_and_apply(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price", "@size")
-        posprocessing_config.apply(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price", "@size")
+        postprocessing_config.apply(
             price_discount="@price - (@price * 0.1)",
             tax_discount="@price_discount * 0.2",
         )
-        posprocessing_config.limit(0, 3)
+        postprocessing_config.limit(0, 3)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -4912,16 +4912,16 @@ def test_hybrid_search_query_with_load_and_filter(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@description", "@color", "@price", "@size")
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@description", "@color", "@price", "@size")
         # for the postprocessing filter we need to filter on the loaded fields
         # expecting all of them to be interpreted as strings - the initial field types
         # are not preserved
-        posprocessing_config.filter(HybridFilter('@price=="15"'))
-        posprocessing_config.limit(0, 3)
+        postprocessing_config.filter(HybridFilter('@price=="15"'))
+        postprocessing_config.limit(0, 3)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         if is_resp2_connection(client):
@@ -4954,10 +4954,10 @@ def test_hybrid_search_query_with_load_apply_and_params(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@description", "@color", "@price")
-        posprocessing_config.apply(price_discount="@price - (@price * 0.1)")
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@description", "@color", "@price")
+        postprocessing_config.apply(price_discount="@price - (@price * 0.1)")
+        postprocessing_config.limit(0, 3)
 
         params_substitution = {
             "vector": "abcd1234abcd5678",
@@ -4966,7 +4966,7 @@
         res = client.ft().hybrid_search(
             query=hybrid_query,
-            post_processing=posprocessing_config,
+            post_processing=postprocessing_config,
             params_substitution=params_substitution,
             timeout=10,
         )
@@ -5019,11 +5019,11 @@ def test_hybrid_search_query_with_limit(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.limit(0, 3)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.limit(0, 3)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         if is_resp2_connection(client):
@@ -5050,16 +5050,16 @@ def test_hybrid_search_query_with_load_apply_and_sortby(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price")
-        posprocessing_config.apply(price_discount="@price - (@price * 0.1)")
-        posprocessing_config.sort_by(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price")
+        postprocessing_config.apply(price_discount="@price - (@price * 0.1)")
+        postprocessing_config.sort_by(
             SortbyField("@price_discount", asc=False), SortbyField("@color", asc=True)
         )
-        posprocessing_config.limit(0, 5)
+        postprocessing_config.limit(0, 5)
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
@@ -5158,26 +5158,47 @@ def test_hybrid_search_query_with_load_and_groupby(self, client):
 
         hybrid_query = HybridQuery(search_query, vsim_query)
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price", "@size", "@item_type")
-        posprocessing_config.limit(0, 4)
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price", "@size", "@item_type")
+        postprocessing_config.limit(0, 4)
 
-        posprocessing_config.group_by(
-            ["@price"],
+        postprocessing_config.group_by(
+            ["@item_type", "@price"],
             reducers.count_distinct("@color").alias("colors_count"),
+            reducers.min("@size"),
         )
-        posprocessing_config.sort_by(SortbyField("@price", asc=True))
+        postprocessing_config.sort_by(SortbyField("@price", asc=True))
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=10
+            query=hybrid_query, post_processing=postprocessing_config, timeout=10
         )
 
         expected_results = [
-            {"price": b"15", "colors_count": b"2"},
-            {"price": b"16", "colors_count": b"2"},
-            {"price": b"17", "colors_count": b"2"},
-            {"price": b"18", "colors_count": b"2"},
+            {
+                "item_type": b"dress",
+                "price": b"15",
+                "colors_count": b"1",
+                "__generated_aliasminsize": b"10",
+            },
+            {
+                "item_type": b"shoes",
+                "price": b"15",
+                "colors_count": b"2",
+                "__generated_aliasminsize": b"10",
+            },
+            {
+                "item_type": b"shoes",
+                "price": b"16",
+                "colors_count": b"2",
+                "__generated_aliasminsize": b"10",
+            },
+            {
+                "item_type": b"dress",
+                "price": b"16",
+                "colors_count": b"1",
+                "__generated_aliasminsize": b"11",
+            },
         ]
 
         if is_resp2_connection(client):
@@ -5189,21 +5210,21 @@ def test_hybrid_search_query_with_load_and_groupby(self, client):
         assert res["results"] == self._convert_dict_values_to_str(expected_results)
         assert res["warnings"] == []
 
-        posprocessing_config = HybridPostProcessingConfig()
-        posprocessing_config.load("@color", "@price", "@size", "@item_type")
-        posprocessing_config.limit(0, 6)
-        posprocessing_config.sort_by(
+        postprocessing_config = HybridPostProcessingConfig()
+        postprocessing_config.load("@color", "@price", "@size", "@item_type")
+        postprocessing_config.limit(0, 6)
+        postprocessing_config.sort_by(
             SortbyField("@price", asc=True),
             SortbyField("@item_type", asc=True),
         )
-        posprocessing_config.group_by(
+        postprocessing_config.group_by(
             ["@price", "@item_type"],
             reducers.count_distinct("@color").alias("unique_colors_count"),
         )
 
         res = client.ft().hybrid_search(
-            query=hybrid_query, post_processing=posprocessing_config, timeout=1000
+            query=hybrid_query, post_processing=postprocessing_config, timeout=1000
         )
 
         expected_results = [