Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/builder_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from substrait.builders.plan import (
read_named_table,
project,
select,
filter,
sort,
fetch,
Expand Down Expand Up @@ -34,7 +34,7 @@ def basic_example():
expressions=[column("id"), literal(100, i64(nullable=False))],
),
)
table = project(table, expressions=[column("id")])
table = select(table, expressions=[column("id")])

print(table(registry))
pretty_print_plan(table(registry), use_colors=True)
Expand Down Expand Up @@ -177,7 +177,7 @@ def advanced_example():
expressions=[column("id"), literal(100, i64(nullable=False))],
),
)
table = project(table, expressions=[column("id")])
table = select(table, expressions=[column("id")])

print("Simple filtered table:")
pretty_print_plan(table(registry), use_colors=True)
Expand Down Expand Up @@ -212,7 +212,7 @@ def advanced_example():
)

# Project with calculated fields
enriched_users = project(
enriched_users = select(
adult_users,
expressions=[
column("user_id"),
Expand Down Expand Up @@ -322,7 +322,7 @@ def expression_only_example():
struct=struct(types=[fp64(nullable=False)], nullable=False),
)
dummy_table = read_named_table("dummy", dummy_schema)
dummy_plan = project(dummy_table, expressions=[complex_expr])
dummy_plan = select(dummy_table, expressions=[complex_expr])
pretty_print_plan(dummy_plan(registry), use_colors=True)

print("\n" + "=" * 50 + "\n")
Expand Down
4 changes: 2 additions & 2 deletions examples/duckdb_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


import duckdb
from substrait.builders.plan import read_named_table, project, filter
from substrait.builders.plan import read_named_table, select, filter
from substrait.builders.extended_expression import column, scalar_function, literal
from substrait.builders.type import i32
from substrait.extension_registry import ExtensionRegistry
Expand Down Expand Up @@ -46,7 +46,7 @@ def read_duckdb_named_table(name: str, conn):
expressions=[column("c_nationkey"), literal(3, i32())],
),
)
table = project(
table = select(
table, expressions=[column("c_name"), column("c_address"), column("c_nationkey")]
)
sql = "CALL from_substrait(?)"
Expand Down
4 changes: 2 additions & 2 deletions examples/pyarrow_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pyarrow.compute as pc
import pyarrow.substrait as pa_substrait
import substrait
from substrait.builders.plan import project, read_named_table
from substrait.builders.plan import select, read_named_table

arrow_schema = pa.schema([pa.field("x", pa.int32()), pa.field("y", pa.int32())])

Expand All @@ -24,5 +24,5 @@
pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr))

table = read_named_table("example", substrait_schema)
table = project(table, expressions=[pysubstrait_expr])(None)
table = select(table, expressions=[pysubstrait_expr])(None)
print(table)
34 changes: 34 additions & 0 deletions src/substrait/builders/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,40 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
resolve_expression(e, ns, registry) for e in expressions
]

names = list(_plan.relations[-1].root.names) + [
e.output_names[0] for ee in bound_expressions for e in ee.referred_expr
]

rel = stalg.Rel(
project=stalg.ProjectRel(
input=_plan.relations[-1].root.input,
expressions=[
e.expression for ee in bound_expressions for e in ee.referred_expr
],
advanced_extension=extension,
)
)

return stp.Plan(
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
**_merge_extensions(_plan, *bound_expressions),
)

return resolve


def select(
plan: PlanOrUnbound,
expressions: Iterable[ExtendedExpressionOrUnbound],
extension: Optional[AdvancedExtension] = None,
) -> UnboundPlan:
def resolve(registry: ExtensionRegistry) -> stp.Plan:
_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
ns = infer_plan_schema(_plan)
bound_expressions: Iterable[stee.ExtendedExpression] = [
resolve_expression(e, ns, registry) for e in expressions
]

start_index = len(_plan.relations[-1].root.names)

names = [
Expand Down
4 changes: 2 additions & 2 deletions src/substrait/sql/sql_to_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from substrait.builders.plan import (
read_named_table,
project,
select,
filter,
sort,
fetch,
Expand Down Expand Up @@ -309,7 +309,7 @@ def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionReg
if having_predicate:
relation = filter(relation, having_predicate)(registry)

return project(relation, expressions=projection)(registry)
return select(relation, expressions=projection)(registry)
elif op == "Table":
name = ast["name"][0]["Identifier"]["value"]
return read_named_table(name, schema_resolver(name))
Expand Down
37 changes: 36 additions & 1 deletion tests/builders/plan/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table, project
from substrait.builders.plan import read_named_table, select, project
from substrait.builders.extended_expression import column
from substrait.extension_registry import ExtensionRegistry

Expand All @@ -20,6 +20,41 @@ def test_project():

actual = project(table, [column("id")])(registry)

expected = stp.Plan(
relations=[
stp.PlanRel(
root=stalg.RelRoot(
input=stalg.Rel(
project=stalg.ProjectRel(
input=table(None).relations[-1].root.input,
expressions=[
stalg.Expression(
selection=stalg.Expression.FieldReference(
direct_reference=stalg.Expression.ReferenceSegment(
struct_field=stalg.Expression.ReferenceSegment.StructField(
field=0
)
),
root_reference=stalg.Expression.FieldReference.RootReference(),
)
)
],
)
),
names=["id", "is_applicable", "id"],
)
)
]
)

assert actual == expected


def test_select():
table = read_named_table("table", named_struct)

actual = select(table, [column("id")])(registry)

expected = stp.Plan(
relations=[
stp.PlanRel(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_uri_urn_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
aggregate_function,
)
from substrait.builders.type import i64
from substrait.builders.plan import read_named_table, aggregate, project, filter
from substrait.builders.plan import read_named_table, aggregate, select, filter
from substrait.extension_registry import ExtensionRegistry
from substrait.type_inference import infer_plan_schema

Expand Down Expand Up @@ -143,7 +143,7 @@ def test_project_outputs_both_uri_and_urn():
alias=["add"],
)

actual = project(table, [add_expr])(registry)
actual = select(table, [add_expr])(registry)

ns = infer_plan_schema(table(None))

Expand Down
Loading