diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index a3a31e05..109664ee 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, model_validator from redis.commands.search.aggregation import AggregateRequest, Desc +from typing_extensions import Self from redisvl.query.filter import FilterExpression from redisvl.redis.utils import array_to_buffer @@ -32,9 +33,16 @@ def validate_dtype(cls, dtype: str) -> str: raise ValueError( f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) - return dtype + @model_validator(mode="after") + def validate_vector(self) -> Self: + """If the vector passed in is an array of floats, convert it to a byte string.""" + if isinstance(self.vector, bytes): + return self + self.vector = array_to_buffer(self.vector, self.dtype) + return self + class AggregationQuery(AggregateRequest): """ @@ -364,12 +372,8 @@ def params(self) -> Dict[str, Any]: Dict[str, Any]: The parameters for the aggregation. 
""" params = {} - for i, (vector, dtype) in enumerate( - [(v.vector, v.dtype) for v in self._vectors] - ): - if isinstance(vector, list): - vector = array_to_buffer(vector, dtype=dtype) # type: ignore - params[f"vector_{i}"] = vector + for i, v in enumerate(self._vectors): + params[f"vector_{i}"] = v.vector return params def _build_query_string(self) -> str: diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index f08815a6..7f98cf54 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -365,6 +365,31 @@ def test_multivector_query(index): ) +def test_multivector_query_accepts_bytes(index): + skip_if_redis_version_below(index.client, "7.2.0") + + vector_bytes = [ + array_to_buffer([0.1, 0.1, 0.5], "float32"), + array_to_buffer([0.3, 0.4, 0.7, 0.2, -0.3, 0.25], "float64"), + ] + vector_fields = ["user_embedding", "audio_embedding"] + dtypes = ["float32", "float64"] + vectors = [] + for vector, field, dtype in zip(vector_bytes, vector_fields, dtypes): + vectors.append(Vector(vector=vector, field_name=field, dtype=dtype)) + + return_fields = ["user", "credit_score", "age", "job", "location", "description"] + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert isinstance(results, list) + assert len(results) == 7 + + def test_multivector_query_with_filter(index): skip_if_redis_version_below(index.client, "7.2.0") diff --git a/tests/unit/test_aggregation_types.py b/tests/unit/test_aggregation_types.py index f2b6be86..8694ff0d 100644 --- a/tests/unit/test_aggregation_types.py +++ b/tests/unit/test_aggregation_types.py @@ -6,6 +6,7 @@ from redisvl.index.index import process_results from redisvl.query.aggregate import HybridQuery, MultiVectorQuery, Vector from redisvl.query.filter import Tag +from redisvl.redis.utils import array_to_buffer # Sample data for testing sample_vector = [0.1, 0.2, 0.3, 0.4] @@ -314,3 
+315,20 @@ def test_vector_object_validation(): for dtype in ["bfloat16", "float16", "float32", "float64", "int8", "uint8"]: vec = Vector(vector=sample_vector, field_name="text embedding", dtype=dtype) assert isinstance(vec, Vector) + + +def test_vector_object_handles_byte_conversion(): + # test that an array of floats passed in gets converted to bytes + vec = Vector(vector=sample_vector, field_name="field 1", dtype="float16") + assert vec.vector == array_to_buffer(sample_vector, dtype="float16") + + # test we can pass an array of floats and convert to each floating-point dtype + for datatype in ["bfloat16", "float16", "float32", "float64"]: + vec = Vector(vector=sample_vector, field_name="field 1", dtype=datatype) + assert vec.vector == array_to_buffer(sample_vector, dtype=datatype) + + # test that a byte string passed in is stored unchanged + for datatype in ["bfloat16", "float16", "float32", "float64"]: + byte_string = array_to_buffer(sample_vector, datatype) + vec = Vector(vector=byte_string, field_name="field 1") + assert vec.vector == byte_string