From 036222613b771296315470f815905c0e94cae15a Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 9 Apr 2024 10:08:19 -0400 Subject: [PATCH 1/2] address schema serialization bug --- redisvl/schema/schema.py | 36 ++++++++++++++++++++++++++---------- tests/unit/test_schema.py | 32 +++++++++----------------------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index 36c1ab80..1048b02d 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -14,6 +14,28 @@ SCHEMA_VERSION = "0.1.0" +def custom_dict(model: BaseModel) -> Dict[str, Any]: + """ + Custom serialization function that converts a Pydantic model to a dict, + serializing Enum fields to their values, and handling nested models and lists. + """ + + def serialize_item(item): + if isinstance(item, Enum): + return item.value.lower() + elif isinstance(item, dict): + return {key: serialize_item(value) for key, value in item.items()} + elif isinstance(item, list): + return [serialize_item(element) for element in item] + else: + return item + + serialized_data = model.dict(exclude_none=True) + for key, value in serialized_data.items(): + serialized_data[key] = serialize_item(value) + return serialized_data + + class StorageType(Enum): """ Enumeration for the storage types supported in Redis. @@ -63,14 +85,6 @@ class IndexInfo(BaseModel): storage_type: StorageType = StorageType.HASH """The storage type used in Redis (e.g., 'hash' or 'json').""" - def dict(self, *args, **kwargs) -> Dict[str, Any]: - return { - "name": self.name, - "prefix": self.prefix, - "key_separator": self.key_separator, - "storage_type": self.storage_type.value, - } - class IndexSchema(BaseModel): """A schema definition for a search index in Redis, used in RedisVL for @@ -428,12 +442,13 @@ def generate_fields( return fields def to_dict(self) -> Dict[str, Any]: - """Convert the index schema to a dictionary. + """Serialize the index schema model to a dictionary, handling Enums + and other special cases properly. Returns: Dict[str, Any]: The index schema as a dictionary. """ - dict_schema = self.dict(exclude_none=True) + dict_schema = custom_dict(self) # cast fields back to a pure list dict_schema["fields"] = [ field for field_name, field in dict_schema["fields"].items() @@ -456,6 +471,7 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None: with open(fp, "w") as f: yaml_data = self.to_dict() + print(yaml_data) yaml.dump(yaml_data, f, sort_keys=False) diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index a85b0763..f6e84702 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -3,8 +3,8 @@ import pytest -from redisvl.schema.fields import NumericField, TagField, TextField -from redisvl.schema.schema import IndexSchema, StorageType +from redisvl.schema.fields import TagField, TextField +from redisvl.schema.schema import IndexSchema, StorageType, custom_dict def get_base_path(): @@ -16,6 +16,12 @@ def create_sample_index_schema(): sample_fields = [ {"name": "example_text", "type": "text", "attrs": {"sortable": False}}, {"name": "example_numeric", "type": "numeric", "attrs": {"sortable": True}}, + {"name": "example_tag", "type": "tag", "attrs": {"sortable": True}}, + { + "name": "example_vector", + "type": "vector", + "attrs": {"dims": 1024, "algorithm": "flat"}, + }, ] return IndexSchema.from_dict({"index": {"name": "test"}, "fields": sample_fields}) @@ -89,26 +95,6 @@ def test_remove_field(): assert "example_text" not in index_schema.field_names -def test_schema_compare(): - """Test schema comparisons.""" - schema_1 = IndexSchema.from_dict({"index": {"name": "test"}}) - # manually add the same fields as the helper method provides below - schema_1.add_fields( - [ - {"name": "example_text", "type": "text", "attrs": {"sortable": False}}, - {"name": "example_numeric", "type": "numeric", "attrs": {"sortable": True}}, - ] - ) - - assert "example_text" in schema_1.fields - assert "example_numeric" in schema_1.fields - - schema_2 = create_sample_index_schema() - assert schema_1.fields == schema_2.fields - assert schema_1.index.name == schema_2.index.name - assert schema_1.to_dict() == schema_2.to_dict() - - def test_generate_fields(): """Test field generation.""" sample = {"name": "John", "age": 30, "tags": ["test", "test2"]} @@ -126,7 +112,7 @@ def test_to_dict(): index_dict = index_schema.to_dict() assert index_dict["index"]["name"] == "test" assert isinstance(index_dict["fields"], list) - assert len(index_dict["fields"]) == 2 == len(index_schema.fields) + assert len(index_dict["fields"]) == 4 == len(index_schema.fields) def test_from_dict(): From 7ab6a2bac9b920c7d136e23e1c85585f128e6f54 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 9 Apr 2024 11:31:29 -0400 Subject: [PATCH 2/2] remove print yaml --- redisvl/schema/schema.py | 1 - 1 file changed, 1 deletion(-) diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index 1048b02d..d165cff4 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -471,7 +471,6 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None: with open(fp, "w") as f: yaml_data = self.to_dict() - print(yaml_data) yaml.dump(yaml_data, f, sort_keys=False)