From 26f4de03903cc155b7cb707edd83950d60bc5a1d Mon Sep 17 00:00:00 2001 From: MBWhite Date: Thu, 16 Oct 2025 14:36:54 +0100 Subject: [PATCH 1/3] fix: add defensive check for nullable structs in builder - if the struct is marked as NULLABLE raise an error - if the struct is defaulting to UNSPECIFIED, then make it REQUIRED Signed-off-by: MBWhite --- src/substrait/builders/plan.py | 6 ++++++ .../test_aggregate_function.py | 3 ++- tests/builders/plan/test_aggregate.py | 2 +- tests/builders/plan/test_cross.py | 4 ++-- tests/builders/plan/test_fetch.py | 2 +- tests/builders/plan/test_filter.py | 2 +- tests/builders/plan/test_join.py | 4 ++-- tests/builders/plan/test_project.py | 2 +- tests/builders/plan/test_read.py | 21 ++++++++++++++++++- tests/builders/plan/test_set.py | 2 +- tests/builders/plan/test_sort.py | 2 +- 11 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index eca32da..77362ba 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -34,6 +34,12 @@ def _merge_extensions(*objs): def read_named_table( names: Union[str, Iterable[str]], named_struct: stt.NamedStruct ) -> UnboundPlan: + + if named_struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE: + raise Exception("NamedStruct must not contain a nullable struct") + elif named_struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED: + named_struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED + def resolve(registry: ExtensionRegistry) -> stp.Plan: _names = [names] if isinstance(names, str) else names diff --git a/tests/builders/extended_expression/test_aggregate_function.py b/tests/builders/extended_expression/test_aggregate_function.py index 8415df4..3211c0c 100644 --- a/tests/builders/extended_expression/test_aggregate_function.py +++ b/tests/builders/extended_expression/test_aggregate_function.py @@ -12,7 +12,8 @@ stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), - ] + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, ) named_struct = stt.NamedStruct( diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index 9b36ad6..7f51f2e 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -28,7 +28,7 @@ registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py index 9a47ba4..2677086 100644 --- a/tests/builders/plan/test_cross.py +++ b/tests/builders/plan/test_cross.py @@ -7,13 +7,13 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( names=["fk_id", "name"], - struct=stt.Type.Struct(types=[i64(nullable=False), string()]), + struct=stt.Type.Struct(types=[i64(nullable=False), string()],nullability=stt.Type.NULLABILITY_REQUIRED), ) diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py index ebcd372..96427f6 100644 --- a/tests/builders/plan/test_fetch.py +++ b/tests/builders/plan/test_fetch.py @@ -8,7 +8,7 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py index 659f402..6a35873 100644 --- a/tests/builders/plan/test_filter.py +++ b/tests/builders/plan/test_filter.py @@ -8,7 +8,7 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py index 8d4998a..8bd88f2 100644 --- a/tests/builders/plan/test_join.py +++ b/tests/builders/plan/test_join.py @@ -8,13 +8,13 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( names=["fk_id", "name"], - struct=stt.Type.Struct(types=[i64(nullable=False), string()]), + struct=stt.Type.Struct(types=[i64(nullable=False), string()],nullability=stt.Type.NULLABILITY_REQUIRED), ) diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py index 2dd9ff1..5bdfab1 100644 --- a/tests/builders/plan/test_project.py +++ b/tests/builders/plan/test_project.py @@ -8,7 +8,7 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py index 6e380b8..9ad803e 100644 --- a/tests/builders/plan/test_read.py +++ b/tests/builders/plan/test_read.py @@ -3,8 +3,12 @@ import substrait.gen.proto.algebra_pb2 as stalg from substrait.builders.type import boolean, i64 from substrait.builders.plan import read_named_table +import pytest -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) @@ -57,3 +61,18 @@ def test_read_rel_db(): ) assert actual == expected + + +def test_read_rel_schema_nullable(): + + struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.Nullability.NULLABILITY_NULLABLE, + ) + + named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + with pytest.raises(Exception,match=r"NamedStruct must not contain a nullable struct"): + read_named_table("example_table", named_struct)(None) + + + diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py index e761707..789d6a2 100644 --- a/tests/builders/plan/test_set.py +++ b/tests/builders/plan/test_set.py @@ -7,7 +7,7 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py index 4b4f49c..e0c4a58 100644 --- a/tests/builders/plan/test_sort.py +++ b/tests/builders/plan/test_sort.py @@ -9,7 +9,7 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) From d1d4fed277a17f658865fa1914b40dcd7352dfb8 Mon Sep 17 00:00:00 2001 From: MBWhite Date: Fri, 17 Oct 2025 10:11:33 +0100 Subject: [PATCH 2/3] Replicate fix into the namedstuct builder Signed-off-by: MBWhite --- src/substrait/builders/plan.py | 1 - src/substrait/builders/type.py | 5 +++ .../test_scalar_function.py | 2 +- tests/builders/plan/test_aggregate.py | 4 +- tests/builders/plan/test_cross.py | 8 +++- tests/builders/plan/test_fetch.py | 4 +- tests/builders/plan/test_filter.py | 4 +- tests/builders/plan/test_join.py | 8 +++- tests/builders/plan/test_project.py | 4 +- tests/builders/plan/test_read.py | 8 ++-- tests/builders/plan/test_set.py | 4 +- tests/builders/plan/test_sort.py | 4 +- tests/builders/test_types.py | 42 +++++++++++++++++++ 13 files changed, 81 insertions(+), 17 deletions(-) create mode 100644 tests/builders/test_types.py diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 77362ba..ed487c7 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -34,7 +34,6 @@ def _merge_extensions(*objs): def read_named_table( names: Union[str, Iterable[str]], named_struct: stt.NamedStruct ) -> UnboundPlan: - if named_struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE: raise Exception("NamedStruct must not contain a nullable struct") elif named_struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED: diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 8405595..39ed5e6 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -257,4 +257,9 @@ def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type: def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct: + if struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE: + raise Exception("NamedStruct must not contain a nullable struct") + elif struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED: + struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED + return stt.NamedStruct(names=names, struct=struct.struct) diff --git a/tests/builders/extended_expression/test_scalar_function.py b/tests/builders/extended_expression/test_scalar_function.py index 0f3e9e8..7c0bdeb 100644 --- a/tests/builders/extended_expression/test_scalar_function.py +++ b/tests/builders/extended_expression/test_scalar_function.py @@ -68,7 +68,7 @@ def test_sclar_add(): extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, function_anchor=1, name="test_func:i8" + extension_uri_reference=1, function_anchor=1, name="test_func:i8" ) ) ], diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index 7f51f2e..7b15c67 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -28,7 +28,9 @@ registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py index 2677086..e022706 100644 --- a/tests/builders/plan/test_cross.py +++ b/tests/builders/plan/test_cross.py @@ -7,13 +7,17 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( names=["fk_id", "name"], - struct=stt.Type.Struct(types=[i64(nullable=False), string()],nullability=stt.Type.NULLABILITY_REQUIRED), + struct=stt.Type.Struct( + types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED + ), ) diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py index 96427f6..a9f8a4d 100644 --- a/tests/builders/plan/test_fetch.py +++ b/tests/builders/plan/test_fetch.py @@ -8,7 +8,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py index 6a35873..bba55ec 100644 --- a/tests/builders/plan/test_filter.py +++ b/tests/builders/plan/test_filter.py @@ -8,7 +8,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py index 8bd88f2..d6503cc 100644 --- a/tests/builders/plan/test_join.py +++ b/tests/builders/plan/test_join.py @@ -8,13 +8,17 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( names=["fk_id", "name"], - struct=stt.Type.Struct(types=[i64(nullable=False), string()],nullability=stt.Type.NULLABILITY_REQUIRED), + struct=stt.Type.Struct( + types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED + ), ) diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py index 5bdfab1..e66a499 100644 --- a/tests/builders/plan/test_project.py +++ b/tests/builders/plan/test_project.py @@ -8,7 +8,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py index 9ad803e..909216a 100644 --- a/tests/builders/plan/test_read.py +++ b/tests/builders/plan/test_read.py @@ -64,15 +64,13 @@ def test_read_rel_db(): def test_read_rel_schema_nullable(): - struct = stt.Type.Struct( types=[i64(nullable=False), boolean()], nullability=stt.Type.Nullability.NULLABILITY_NULLABLE, ) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) - with pytest.raises(Exception,match=r"NamedStruct must not contain a nullable struct"): + with pytest.raises( + Exception, match=r"NamedStruct must not contain a nullable struct" + ): read_named_table("example_table", named_struct)(None) - - - diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py index 789d6a2..e35570d 100644 --- a/tests/builders/plan/test_set.py +++ b/tests/builders/plan/test_set.py @@ -7,7 +7,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py index e0c4a58..7fee44f 100644 --- a/tests/builders/plan/test_sort.py +++ b/tests/builders/plan/test_sort.py @@ -9,7 +9,9 @@ registry = ExtensionRegistry(load_default_extensions=False) -struct = stt.Type.Struct(types=[i64(nullable=False), boolean()],nullability=stt.Type.NULLABILITY_REQUIRED) +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) diff --git a/tests/builders/test_types.py b/tests/builders/test_types.py new file mode 100644 index 0000000..1afc2d3 --- /dev/null +++ b/tests/builders/test_types.py @@ -0,0 +1,42 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, sort +from substrait.builders.extended_expression import column +from substrait.type_inference import infer_plan_schema +from substrait.builders.type import named_struct +import pytest + + +def test_named_struct_required(): + struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.NULLABILITY_REQUIRED, + ) + + named = named_struct(names=["index", "valid"], struct=stt.Type(struct=struct)) + assert named + assert named.struct.nullability == stt.Type.NULLABILITY_REQUIRED + assert named.names == ["index", "valid"] + + +def test_named_struct_unspecified(): + struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + + named = named_struct(names=["index", "valid"], struct=stt.Type(struct=struct)) + assert named + assert named.struct.nullability == stt.Type.NULLABILITY_REQUIRED + assert named.names == ["index", "valid"] + + +def test_named_struct_nullable(): + struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], + nullability=stt.Type.NULLABILITY_NULLABLE, + ) + + with pytest.raises( + Exception, match=r"NamedStruct must not contain a nullable struct" + ): + named_struct(names=["index", "valid"], struct=stt.Type(struct=struct)) From 25cb058b22a9b8daa5bca5bf8ac0f2d12459c737 Mon Sep 17 00:00:00 2001 From: MBWhite Date: Fri, 17 Oct 2025 10:16:44 +0100 Subject: [PATCH 3/3] lint fixes Signed-off-by: MBWhite --- Makefile | 3 +++ tests/builders/test_types.py | 5 ----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index b293581..54782b1 100644 --- a/Makefile +++ b/Makefile @@ -14,5 +14,8 @@ codegen-extensions: lint: uvx ruff@0.11.11 check +lint_fix: + uvx ruff@0.11.11 check --fix + format: uvx ruff@0.11.11 format diff --git a/tests/builders/test_types.py b/tests/builders/test_types.py index 1afc2d3..3d0e2d3 100644 --- a/tests/builders/test_types.py +++ b/tests/builders/test_types.py @@ -1,10 +1,5 @@ import substrait.gen.proto.type_pb2 as stt -import substrait.gen.proto.plan_pb2 as stp -import substrait.gen.proto.algebra_pb2 as stalg from substrait.builders.type import boolean, i64 -from substrait.builders.plan import read_named_table, sort -from substrait.builders.extended_expression import column -from substrait.type_inference import infer_plan_schema from substrait.builders.type import named_struct import pytest