Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 89 additions & 6 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ class TextQuery(BaseQuery):
def __init__(
self,
text: str,
text_field_name: str,
text_field_name: Union[str, Dict[str, float]],
text_scorer: str = "BM25STD",
filter_expression: Optional[Union[str, FilterExpression]] = None,
return_fields: Optional[List[str]] = None,
Expand All @@ -817,7 +817,8 @@ def __init__(

Args:
text (str): The text string to perform the text search with.
text_field_name (str): The name of the document field to perform text search on.
text_field_name (Union[str, Dict[str, float]]): The name of the document field to perform
text search on, or a dictionary mapping field names to their weights.
text_scorer (str, optional): The text scoring algorithm to use.
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}.
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
Expand Down Expand Up @@ -849,7 +850,7 @@ def __init__(
TypeError: If stopwords is not a valid iterable set of strings.
"""
self._text = text
self._text_field_name = text_field_name
self._field_weights = self._parse_field_weights(text_field_name)
self._num_results = num_results

self._set_stopwords(stopwords)
Expand Down Expand Up @@ -934,15 +935,97 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
[token for token in tokens if token and token not in self._stopwords]
)

def _parse_field_weights(
self, field_spec: Union[str, Dict[str, float]]
) -> Dict[str, float]:
"""Parse the field specification into a weights dictionary.

Args:
field_spec: Either a single field name or dictionary of field:weight mappings

Returns:
Dictionary mapping field names to their weights
"""
if isinstance(field_spec, str):
return {field_spec: 1.0}
elif isinstance(field_spec, dict):
# Validate all weights are numeric and positive
for field, weight in field_spec.items():
if not isinstance(field, str):
raise TypeError(f"Field name must be a string, got {type(field)}")
if not isinstance(weight, (int, float)):
raise TypeError(
f"Weight for field '{field}' must be numeric, got {type(weight)}"
)
if weight <= 0:
raise ValueError(
f"Weight for field '{field}' must be positive, got {weight}"
)
return field_spec
else:
raise TypeError(
"text_field_name must be a string or dictionary of field:weight mappings"
)

def set_field_weights(self, field_weights: Union[str, Dict[str, float]]):
    """Replace the field weighting configuration for this query.

    Args:
        field_weights: A single field name or a mapping of field names
            to their weights.
    """
    parsed = self._parse_field_weights(field_weights)
    self._field_weights = parsed
    # Drop the cached query string so it is rebuilt with the new weights.
    self._built_query_string = None

@property
def field_weights(self) -> Dict[str, float]:
    """Return a copy of the field -> weight mapping used by this query.

    Returns:
        Dictionary mapping field names to their weights.
    """
    # Hand back a copy so callers cannot mutate internal state.
    return dict(self._field_weights)

@property
def text_field_name(self) -> Union[str, Dict[str, float]]:
    """Backward-compatible accessor for the queried field(s).

    Returns:
        The bare field name when exactly one field is configured with the
        default weight 1.0; otherwise a copy of the field -> weight mapping.
    """
    if len(self._field_weights) == 1:
        ((only_field, only_weight),) = self._field_weights.items()
        if only_weight == 1.0:
            return only_field
    return dict(self._field_weights)

def _build_query_string(self) -> str:
    """Build the full query string for text search with optional filtering."""
    flt = self._filter_expression
    if isinstance(flt, FilterExpression):
        flt = str(flt)

    tokens = self._tokenize_and_escape_query(self._text)

    # One clause per field; non-default weights use the Redis
    # query-attribute syntax: @field:(terms) => { $weight: w }.
    clauses = []
    for field, weight in self._field_weights.items():
        if weight == 1.0:
            # Default weight needs no explicit attribute annotation.
            clauses.append(f"@{field}:({tokens})")
        else:
            clauses.append(f"@{field}:({tokens}) => {{ $weight: {weight} }}")

    # Multiple field clauses are OR-ed together and grouped.
    if len(clauses) == 1:
        query = clauses[0]
    else:
        query = "(" + " | ".join(clauses) + ")"

    if flt and flt != "*":
        query += f" AND {flt}"
    return query
46 changes: 46 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,26 @@ async def get_redis_version_async(client):
return info["redis_version"]


def has_redisearch_module(client):
    """Check if RediSearch module is available."""
    try:
        # FT._LIST is only understood when the RediSearch module is loaded,
        # so a successful call proves the module is present.
        client.execute_command("FT._LIST")
    except Exception:
        return False
    return True


async def has_redisearch_module_async(client):
    """Check if RediSearch module is available (async)."""
    try:
        # FT._LIST is only understood when the RediSearch module is loaded,
        # so a successful call proves the module is present.
        await client.execute_command("FT._LIST")
    except Exception:
        return False
    return True


def skip_if_redis_version_below(client, min_version: str, message: str = None):
"""
Skip test if Redis version is below minimum required.
Expand Down Expand Up @@ -609,3 +629,29 @@ async def skip_if_redis_version_below_async(
if not compare_versions(redis_version, min_version):
skip_msg = message or f"Redis version {redis_version} < {min_version} required"
pytest.skip(skip_msg)


def skip_if_no_redisearch(client, message: str = None):
    """
    Skip test if RediSearch module is not available.

    Args:
        client: Redis client instance
        message: Custom skip message
    """
    if has_redisearch_module(client):
        return
    pytest.skip(message or "RediSearch module not available")


async def skip_if_no_redisearch_async(client, message: str = None):
    """
    Skip test if RediSearch module is not available (async version).

    Args:
        client: Async Redis client instance
        message: Custom skip message
    """
    if await has_redisearch_module_async(client):
        return
    pytest.skip(message or "RediSearch module not available")
30 changes: 22 additions & 8 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from redisvl.index.index import AsyncSearchIndex, SearchIndex
from redisvl.query.filter import Num, Tag, Text
from redisvl.utils.vectorize import HFTextVectorizer
from tests.conftest import skip_if_no_redisearch, skip_if_no_redisearch_async


@pytest.fixture(scope="session")
Expand All @@ -19,7 +20,8 @@ def vectorizer():


@pytest.fixture
def cache(vectorizer, redis_url, worker_id):
def cache(client, vectorizer, redis_url, worker_id):
skip_if_no_redisearch(client)
cache_instance = SemanticCache(
name=f"llmcache_{worker_id}",
vectorizer=vectorizer,
Expand All @@ -31,7 +33,8 @@ def cache(vectorizer, redis_url, worker_id):


@pytest.fixture
def cache_with_filters(vectorizer, redis_url, worker_id):
def cache_with_filters(client, vectorizer, redis_url, worker_id):
skip_if_no_redisearch(client)
cache_instance = SemanticCache(
name=f"llmcache_filters_{worker_id}",
vectorizer=vectorizer,
Expand All @@ -44,7 +47,8 @@ def cache_with_filters(vectorizer, redis_url, worker_id):


@pytest.fixture
def cache_no_cleanup(vectorizer, redis_url, worker_id):
def cache_no_cleanup(client, vectorizer, redis_url, worker_id):
skip_if_no_redisearch(client)
cache_instance = SemanticCache(
name=f"llmcache_no_cleanup_{worker_id}",
vectorizer=vectorizer,
Expand All @@ -55,7 +59,8 @@ def cache_no_cleanup(vectorizer, redis_url, worker_id):


@pytest.fixture
def cache_with_ttl(vectorizer, redis_url, worker_id):
def cache_with_ttl(client, vectorizer, redis_url, worker_id):
skip_if_no_redisearch(client)
cache_instance = SemanticCache(
name=f"llmcache_ttl_{worker_id}",
vectorizer=vectorizer,
Expand All @@ -69,6 +74,7 @@ def cache_with_ttl(vectorizer, redis_url, worker_id):

@pytest.fixture
def cache_with_redis_client(vectorizer, client, worker_id):
skip_if_no_redisearch(client)
cache_instance = SemanticCache(
name=f"llmcache_client_{worker_id}",
vectorizer=vectorizer,
Expand Down Expand Up @@ -750,7 +756,8 @@ def test_cache_filtering(cache_with_filters):
assert len(results) == 0


def test_cache_bad_filters(vectorizer, redis_url, worker_id):
def test_cache_bad_filters(client, vectorizer, redis_url, worker_id):
skip_if_no_redisearch(client)
with pytest.raises(ValueError):
cache_instance = SemanticCache(
name=f"test_bad_filters_1_{worker_id}",
Expand Down Expand Up @@ -819,6 +826,7 @@ def test_complex_filters(cache_with_filters):


def test_cache_index_overwrite(client, redis_url, worker_id, hf_vectorizer):
skip_if_no_redisearch(client)
# Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly
redis_version = client.info()["redis_version"]
if redis_version.startswith("6.2"):
Expand Down Expand Up @@ -921,7 +929,8 @@ def test_no_key_collision_on_identical_prompts(redis_url, worker_id, hf_vectoriz
assert len(filtered_results) == 2


def test_create_cache_with_different_vector_types(worker_id, redis_url):
def test_create_cache_with_different_vector_types(client, worker_id, redis_url):
skip_if_no_redisearch(client)
try:
bfloat_cache = SemanticCache(
name=f"bfloat_cache_{worker_id}", dtype="bfloat16", redis_url=redis_url
Expand Down Expand Up @@ -951,6 +960,7 @@ def test_create_cache_with_different_vector_types(worker_id, redis_url):


def test_bad_dtype_connecting_to_existing_cache(client, redis_url, worker_id):
skip_if_no_redisearch(client)
# Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly
redis_version = client.info()["redis_version"]
if redis_version.startswith("6.2"):
Expand Down Expand Up @@ -1021,7 +1031,10 @@ def test_deprecated_dtype_argument(redis_url, worker_id):


@pytest.mark.asyncio
async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer):
async def test_cache_async_context_manager(
async_client, redis_url, worker_id, hf_vectorizer
):
await skip_if_no_redisearch_async(async_client)
async with SemanticCache(
name=f"test_cache_async_context_manager_{worker_id}",
redis_url=redis_url,
Expand All @@ -1034,8 +1047,9 @@ async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer):

@pytest.mark.asyncio
async def test_cache_async_context_manager_with_exception(
redis_url, worker_id, hf_vectorizer
async_client, redis_url, worker_id, hf_vectorizer
):
await skip_if_no_redisearch_async(async_client)
try:
async with SemanticCache(
name=f"test_cache_async_context_manager_with_exception_{worker_id}",
Expand Down
19 changes: 14 additions & 5 deletions tests/integration/test_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from redisvl.extensions.constants import ID_FIELD_NAME
from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory
from tests.conftest import skip_if_no_redisearch


@pytest.fixture
Expand All @@ -21,6 +22,7 @@ def standard_history(app_name, client):

@pytest.fixture
def semantic_history(app_name, client, hf_vectorizer):
skip_if_no_redisearch(client)
history = SemanticMessageHistory(
app_name, redis_client=client, overwrite=True, vectorizer=hf_vectorizer
)
Expand Down Expand Up @@ -326,6 +328,7 @@ def test_standard_clear(standard_history):

# test semantic message history
def test_semantic_specify_client(client, hf_vectorizer):
skip_if_no_redisearch(client)
history = SemanticMessageHistory(
name="test_app",
session_tag="abc",
Expand Down Expand Up @@ -616,7 +619,8 @@ def test_semantic_drop(semantic_history):
]


def test_different_vector_dtypes(redis_url):
def test_different_vector_dtypes(client, redis_url):
skip_if_no_redisearch(client)
try:
bfloat_sess = SemanticMessageHistory(
name="bfloat_history", dtype="bfloat16", redis_url=redis_url
Expand Down Expand Up @@ -647,6 +651,7 @@ def test_different_vector_dtypes(redis_url):


def test_bad_dtype_connecting_to_exiting_history(client, redis_url):
skip_if_no_redisearch(client)
# Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly
redis_version = client.info()["redis_version"]
if redis_version.startswith("6.2"):
Expand Down Expand Up @@ -674,7 +679,8 @@ def create_same_type():
)


def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
def test_vectorizer_dtype_mismatch(client, redis_url, hf_vectorizer_float16):
skip_if_no_redisearch(client)
with pytest.raises(ValueError):
SemanticMessageHistory(
name="test_dtype_mismatch",
Expand All @@ -685,7 +691,8 @@ def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
)


def test_invalid_vectorizer(redis_url):
def test_invalid_vectorizer(client, redis_url):
skip_if_no_redisearch(client)
with pytest.raises(TypeError):
SemanticMessageHistory(
name="test_invalid_vectorizer",
Expand All @@ -695,7 +702,8 @@ def test_invalid_vectorizer(redis_url):
)


def test_passes_through_dtype_to_default_vectorizer(redis_url):
def test_passes_through_dtype_to_default_vectorizer(client, redis_url):
skip_if_no_redisearch(client)
# The default is float32, so we should see float64 if we pass it in.
cache = SemanticMessageHistory(
name="test_pass_through_dtype",
Expand All @@ -706,7 +714,8 @@ def test_passes_through_dtype_to_default_vectorizer(redis_url):
assert cache._vectorizer.dtype == "float64"


def test_deprecated_dtype_argument(redis_url):
def test_deprecated_dtype_argument(client, redis_url):
skip_if_no_redisearch(client)
with pytest.warns(DeprecationWarning):
SemanticMessageHistory(
name="float64 history", dtype="float64", redis_url=redis_url, overwrite=True
Expand Down
Loading
Loading