Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ codegen-extensions:
lint:
uvx ruff@0.11.11 check

lint_fix:
uvx ruff@0.11.11 check --fix

format:
uvx ruff@0.11.11 format
5 changes: 5 additions & 0 deletions src/substrait/builders/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def _merge_extensions(*objs):
def read_named_table(
names: Union[str, Iterable[str]], named_struct: stt.NamedStruct
) -> UnboundPlan:
if named_struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE:
raise Exception("NamedStruct must not contain a nullable struct")
elif named_struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED:
named_struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED

def resolve(registry: ExtensionRegistry) -> stp.Plan:
_names = [names] if isinstance(names, str) else names

Expand Down
5 changes: 5 additions & 0 deletions src/substrait/builders/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,9 @@ def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type:


def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct:
if struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE:
raise Exception("NamedStruct must not contain a nullable struct")
elif struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED:
struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED

return stt.NamedStruct(names=names, struct=struct.struct)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
]
],
nullability=stt.Type.Nullability.NULLABILITY_REQUIRED,
)

named_struct = stt.NamedStruct(
Expand Down
2 changes: 1 addition & 1 deletion tests/builders/extended_expression/test_scalar_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_sclar_add():
extensions=[
ste.SimpleExtensionDeclaration(
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
extension_uri_reference=1, function_anchor=1, name="test_func:i8"
extension_uri_reference=1, function_anchor=1, name="test_func:i8"
)
)
],
Expand Down
4 changes: 3 additions & 1 deletion tests/builders/plan/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
registry = ExtensionRegistry(load_default_extensions=False)
registry.register_extension_dict(yaml.safe_load(content), uri="test_uri")

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down
8 changes: 6 additions & 2 deletions tests/builders/plan/test_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

named_struct_2 = stt.NamedStruct(
names=["fk_id", "name"],
struct=stt.Type.Struct(types=[i64(nullable=False), string()]),
struct=stt.Type.Struct(
types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED
),
)


Expand Down
4 changes: 3 additions & 1 deletion tests/builders/plan/test_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down
4 changes: 3 additions & 1 deletion tests/builders/plan/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down
8 changes: 6 additions & 2 deletions tests/builders/plan/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

named_struct_2 = stt.NamedStruct(
names=["fk_id", "name"],
struct=stt.Type.Struct(types=[i64(nullable=False), string()]),
struct=stt.Type.Struct(
types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED
),
)


Expand Down
4 changes: 3 additions & 1 deletion tests/builders/plan/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down
19 changes: 18 additions & 1 deletion tests/builders/plan/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table
import pytest

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()],
nullability=stt.Type.Nullability.NULLABILITY_REQUIRED,
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down Expand Up @@ -57,3 +61,16 @@ def test_read_rel_db():
)

assert actual == expected


def test_read_rel_schema_nullable():
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()],
nullability=stt.Type.Nullability.NULLABILITY_NULLABLE,
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
with pytest.raises(
Exception, match=r"NamedStruct must not contain a nullable struct"
):
read_named_table("example_table", named_struct)(None)
4 changes: 3 additions & 1 deletion tests/builders/plan/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down
4 changes: 3 additions & 1 deletion tests/builders/plan/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)

Expand Down
37 changes: 37 additions & 0 deletions tests/builders/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import substrait.gen.proto.type_pb2 as stt
from substrait.builders.type import boolean, i64
from substrait.builders.type import named_struct
import pytest


def test_named_struct_required():
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()],
nullability=stt.Type.NULLABILITY_REQUIRED,
)

named = named_struct(names=["index", "valid"], struct=stt.Type(struct=struct))
assert named
assert named.struct.nullability == stt.Type.NULLABILITY_REQUIRED
assert named.names == ["index", "valid"]


def test_named_struct_unspecified():
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])

named = named_struct(names=["index", "valid"], struct=stt.Type(struct=struct))
assert named
assert named.struct.nullability == stt.Type.NULLABILITY_REQUIRED
assert named.names == ["index", "valid"]


def test_named_struct_nullable():
struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()],
nullability=stt.Type.NULLABILITY_NULLABLE,
)

with pytest.raises(
Exception, match=r"NamedStruct must not contain a nullable struct"
):
named_struct(names=["index", "valid"], struct=stt.Type(struct=struct))