diff --git a/Makefile b/Makefile index b293581..54782b1 100644 --- a/Makefile +++ b/Makefile @@ -14,5 +14,8 @@ codegen-extensions: lint: uvx ruff@0.11.11 check +lint_fix: + uvx ruff@0.11.11 check --fix + format: uvx ruff@0.11.11 format diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index eca32da..ed487c7 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -34,6 +34,11 @@ def _merge_extensions(*objs): def read_named_table( names: Union[str, Iterable[str]], named_struct: stt.NamedStruct ) -> UnboundPlan: + if named_struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE: + raise Exception("NamedStruct must not contain a nullable struct") + elif named_struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED: + named_struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED + def resolve(registry: ExtensionRegistry) -> stp.Plan: _names = [names] if isinstance(names, str) else names diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 8405595..39ed5e6 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -257,4 +257,9 @@ def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type: def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct: + if struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE: + raise Exception("NamedStruct must not contain a nullable struct") + elif struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED: + struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED + return stt.NamedStruct(names=names, struct=struct.struct) diff --git a/tests/builders/extended_expression/test_aggregate_function.py b/tests/builders/extended_expression/test_aggregate_function.py index 8415df4..3211c0c 100644 --- a/tests/builders/extended_expression/test_aggregate_function.py +++ b/tests/builders/extended_expression/test_aggregate_function.py @@ -12,7 +12,8 @@ stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), - ] + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, ) named_struct = stt.NamedStruct( diff --git a/tests/builders/extended_expression/test_scalar_function.py b/tests/builders/extended_expression/test_scalar_function.py index 0f3e9e8..7c0bdeb 100644 --- a/tests/builders/extended_expression/test_scalar_function.py +++ b/tests/builders/extended_expression/test_scalar_function.py @@ -68,7 +68,7 @@ def test_sclar_add(): extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, function_anchor=1, name="test_func:i8" + extension_uri_reference=1, function_anchor=1, name="test_func:i8" ) ) ], diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index 9b36ad6..7b15c67 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -28,7 +28,9 @@ registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py index 9a47ba4..e022706 100644 --- a/tests/builders/plan/test_cross.py +++ b/tests/builders/plan/test_cross.py @@ -7,13 +7,17 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( names=["fk_id", "name"], - struct=stt.Type.Struct(types=[i64(nullable=False), string()]), + struct=stt.Type.Struct( + types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED + ), ) diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py index ebcd372..a9f8a4d 100644 --- a/tests/builders/plan/test_fetch.py +++ b/tests/builders/plan/test_fetch.py @@ -8,7 +8,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py index 659f402..bba55ec 100644 --- a/tests/builders/plan/test_filter.py +++ b/tests/builders/plan/test_filter.py @@ -8,7 +8,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py index 8d4998a..d6503cc 100644 --- a/tests/builders/plan/test_join.py +++ b/tests/builders/plan/test_join.py @@ -8,13 +8,17 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( names=["fk_id", "name"], - struct=stt.Type.Struct(types=[i64(nullable=False), string()]), + struct=stt.Type.Struct( + types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED + ), ) diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py index 2dd9ff1..e66a499 100644 --- a/tests/builders/plan/test_project.py +++ b/tests/builders/plan/test_project.py @@ -8,7 +8,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py index 6e380b8..909216a 100644 --- a/tests/builders/plan/test_read.py +++ b/tests/builders/plan/test_read.py @@ -3,8 +3,12 @@ import substrait.gen.proto.algebra_pb2 as stalg from substrait.builders.type import boolean, i64 from substrait.builders.plan import read_named_table +import pytest -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) @@ -57,3 +61,16 @@ def test_read_rel_db(): ) assert actual == expected + + +def test_read_rel_schema_nullable(): + struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.Nullability.NULLABILITY_NULLABLE, + ) + + named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + with pytest.raises( + Exception, match=r"NamedStruct must not contain a nullable struct" + ): + read_named_table("example_table", named_struct)(None) diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py index e761707..e35570d 100644 --- a/tests/builders/plan/test_set.py +++ b/tests/builders/plan/test_set.py @@ -7,7 +7,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py index 4b4f49c..7fee44f 100644 --- a/tests/builders/plan/test_sort.py +++ b/tests/builders/plan/test_sort.py @@ -9,7 +9,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/test_types.py b/tests/builders/test_types.py new file mode 100644 index 0000000..3d0e2d3 --- /dev/null +++ b/tests/builders/test_types.py @@ -0,0 +1,37 @@ +import substrait.gen.proto.type_pb2 as stt +from substrait.builders.type import boolean, i64 +from substrait.builders.type import named_struct +import pytest + + +def test_named_struct_required(): + struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.NULLABILITY_REQUIRED, + ) + + named = named_struct(names=["index", "valid"], struct=stt.Type(struct=struct)) + assert named + assert named.struct.nullability == stt.Type.NULLABILITY_REQUIRED + assert named.names == ["index", "valid"] + + +def test_named_struct_unspecified(): + struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + + named = named_struct(names=["index", "valid"], struct=stt.Type(struct=struct)) + assert named + assert named.struct.nullability == stt.Type.NULLABILITY_REQUIRED + assert named.names == ["index", "valid"] + + +def test_named_struct_nullable(): + struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.NULLABILITY_NULLABLE, + ) + + with pytest.raises( + Exception, match=r"NamedStruct must not contain a nullable struct" + ): + named_struct(names=["index", "valid"], struct=stt.Type(struct=struct))