Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions redisvl/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@
SCHEMA_VERSION = "0.1.0"


def custom_dict(model: BaseModel) -> Dict[str, Any]:
"""
Custom serialization function that converts a Pydantic model to a dict,
serializing Enum fields to their values, and handling nested models and lists.
"""

def serialize_item(item):
if isinstance(item, Enum):
return item.value.lower()
elif isinstance(item, dict):
return {key: serialize_item(value) for key, value in item.items()}
elif isinstance(item, list):
return [serialize_item(element) for element in item]
else:
return item

serialized_data = model.dict(exclude_none=True)
for key, value in serialized_data.items():
serialized_data[key] = serialize_item(value)
return serialized_data


class StorageType(Enum):
"""
Enumeration for the storage types supported in Redis.
Expand Down Expand Up @@ -63,14 +85,6 @@ class IndexInfo(BaseModel):
storage_type: StorageType = StorageType.HASH
"""The storage type used in Redis (e.g., 'hash' or 'json')."""

def dict(self, *args, **kwargs) -> Dict[str, Any]:
return {
"name": self.name,
"prefix": self.prefix,
"key_separator": self.key_separator,
"storage_type": self.storage_type.value,
}


class IndexSchema(BaseModel):
"""A schema definition for a search index in Redis, used in RedisVL for
Expand Down Expand Up @@ -428,12 +442,13 @@ def generate_fields(
return fields

def to_dict(self) -> Dict[str, Any]:
"""Convert the index schema to a dictionary.
"""Serialize the index schema model to a dictionary, handling Enums
and other special cases properly.

Returns:
Dict[str, Any]: The index schema as a dictionary.
"""
dict_schema = self.dict(exclude_none=True)
dict_schema = custom_dict(self)
# cast fields back to a pure list
dict_schema["fields"] = [
field for field_name, field in dict_schema["fields"].items()
Expand Down
32 changes: 9 additions & 23 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import pytest

from redisvl.schema.fields import NumericField, TagField, TextField
from redisvl.schema.schema import IndexSchema, StorageType
from redisvl.schema.fields import TagField, TextField
from redisvl.schema.schema import IndexSchema, StorageType, custom_dict


def get_base_path():
Expand All @@ -16,6 +16,12 @@ def create_sample_index_schema():
sample_fields = [
{"name": "example_text", "type": "text", "attrs": {"sortable": False}},
{"name": "example_numeric", "type": "numeric", "attrs": {"sortable": True}},
{"name": "example_tag", "type": "tag", "attrs": {"sortable": True}},
{
"name": "example_vector",
"type": "vector",
"attrs": {"dims": 1024, "algorithm": "flat"},
},
]
return IndexSchema.from_dict({"index": {"name": "test"}, "fields": sample_fields})

Expand Down Expand Up @@ -89,26 +95,6 @@ def test_remove_field():
assert "example_text" not in index_schema.field_names


def test_schema_compare():
"""Test schema comparisons."""
schema_1 = IndexSchema.from_dict({"index": {"name": "test"}})
# manually add the same fields as the helper method provides below
schema_1.add_fields(
[
{"name": "example_text", "type": "text", "attrs": {"sortable": False}},
{"name": "example_numeric", "type": "numeric", "attrs": {"sortable": True}},
]
)

assert "example_text" in schema_1.fields
assert "example_numeric" in schema_1.fields

schema_2 = create_sample_index_schema()
assert schema_1.fields == schema_2.fields
assert schema_1.index.name == schema_2.index.name
assert schema_1.to_dict() == schema_2.to_dict()


def test_generate_fields():
"""Test field generation."""
sample = {"name": "John", "age": 30, "tags": ["test", "test2"]}
Expand All @@ -126,7 +112,7 @@ def test_to_dict():
index_dict = index_schema.to_dict()
assert index_dict["index"]["name"] == "test"
assert isinstance(index_dict["fields"], list)
assert len(index_dict["fields"]) == 2 == len(index_schema.fields)
assert len(index_dict["fields"]) == 4 == len(index_schema.fields)


def test_from_dict():
Expand Down