Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "uv venv --clear && uv sync --extra test",
"postCreateCommand": "uv venv --clear && uv sync --extra test --extra gen_proto",

// Configure tool-specific properties.
"customizations": {
Expand Down
44 changes: 44 additions & 0 deletions .github/workflows/codegen-check.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
name: Code Generation Check

on:
pull_request:

permissions:
contents: read

jobs:
codegen-check:
name: Verify Code Generation
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
with:
submodules: recursive

- name: Run code generation in devcontainer
uses: devcontainers/ci@v0.3
with:
runCmd: |
# Ensure dependencies are installed
uv sync --extra test --extra gen_proto
# Run all code generation steps
make antlr
./gen_proto.sh
make codegen-extensions

- name: Check for uncommitted changes
run: |
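# Fail if regeneration changed any tracked file under src/substrait/gen/ (no lines are filtered; the generator is run with --disable-timestamp so its output is reproducible)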
if ! git diff --quiet --exit-code src/substrait/gen/; then
echo "Code generation produced changes. Generated code is out of sync!"
echo ""
git diff src/substrait/gen/
echo ""
echo "To fix this, run:"
echo " make antlr"
echo " ./gen_proto.sh"
echo " make codegen-extensions"
echo "Then commit the changes."
exit 1
fi
36 changes: 36 additions & 0 deletions .github/workflows/example.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Run examples

on:
pull_request:
push:
branches: [ main ]

permissions:
contents: read

jobs:
example:
name: Run ${{ matrix.example }}
runs-on: ubuntu-latest
strategy:
matrix:
example:
- builder_example.py
- duckdb_example.py
- adbc_example.py
- pyarrow_example.py
steps:
- name: Checkout code
uses: actions/checkout@v5
with:
submodules: recursive
- name: Install uv with python
uses: astral-sh/setup-uv@v7
with:
python-version: "3.10"
- name: Install package dependencies
run: |
uv sync --frozen --extra extensions
- name: Run ${{ matrix.example }}
run: |
uv run examples/${{ matrix.example }}
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ codegen-extensions:
--input-file-type jsonschema \
--input third_party/substrait/text/simple_extensions_schema.yaml \
--output src/substrait/gen/json/simple_extensions.py \
--output-model-type dataclasses.dataclass
--output-model-type dataclasses.dataclass \
--disable-timestamp

lint:
uvx ruff@0.11.11 check
Expand Down
2 changes: 1 addition & 1 deletion examples/adbc_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def read_adbc_named_table(name: str, conn):
table = filter(
table,
expression=scalar_function(
"functions_comparison.yaml",
"extension:io.substrait:functions_comparison",
"gte",
expressions=[column("ints"), literal(3, i64())],
),
Expand Down
59 changes: 33 additions & 26 deletions examples/builder_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
def basic_example():
ns = named_struct(
names=["id", "is_applicable"],
struct=struct(types=[i64(nullable=False), boolean()]),
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
)

table = read_named_table("example_table", ns)
table = filter(table, expression=column("is_applicable"))
table = filter(
table,
expression=scalar_function(
"functions_comparison.yaml",
"extension:io.substrait:functions_comparison",
"lt",
expressions=[column("id"), literal(100, i64())],
expressions=[column("id"), literal(100, i64(nullable=False))],
),
)
table = project(table, expressions=[column("id")])
Expand All @@ -41,14 +41,15 @@ def basic_example():

"""
extension_uris {
extension_uri_anchor: 13
uri: "functions_comparison.yaml"
extension_uri_anchor: 2
uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml"
}
extensions {
extension_function {
extension_uri_reference: 13
function_anchor: 495
name: "lt"
extension_uri_reference: 2
function_anchor: 124
name: "lt:any_any"
extension_urn_reference: 2
}
}
relations {
Expand Down Expand Up @@ -84,7 +85,7 @@ def basic_example():
nullability: NULLABILITY_NULLABLE
}
}
nullability: NULLABILITY_NULLABLE
nullability: NULLABILITY_REQUIRED
}
}
named_table {
Expand All @@ -107,10 +108,10 @@ def basic_example():
}
condition {
scalar_function {
function_reference: 495
function_reference: 124
output_type {
bool {
nullability: NULLABILITY_NULLABLE
nullability: NULLABILITY_REQUIRED
}
}
arguments {
Expand All @@ -129,7 +130,6 @@ def basic_example():
value {
literal {
i64: 100
nullable: true
}
}
}
Expand All @@ -152,25 +152,29 @@ def basic_example():
names: "id"
}
}
"""
extension_urns {
extension_urn_anchor: 2
urn: "extension:io.substrait:functions_comparison"
}
"""


def advanced_example():
print("=== Simple Example ===")
# Simple example (original)
ns = named_struct(
names=["id", "is_applicable"],
struct=struct(types=[i64(nullable=False), boolean()]),
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
)

table = read_named_table("example_table", ns)
table = filter(table, expression=column("is_applicable"))
table = filter(
table,
expression=scalar_function(
"functions_comparison.yaml",
"extension:io.substrait:functions_comparison",
"lt",
expressions=[column("id"), literal(100, i64())],
expressions=[column("id"), literal(100, i64(nullable=False))],
),
)
table = project(table, expressions=[column("id")])
Expand All @@ -190,7 +194,8 @@ def advanced_example():
string(nullable=False), # name
i64(nullable=False), # age
fp64(nullable=False), # salary
]
],
nullable=False,
),
)

Expand All @@ -200,7 +205,7 @@ def advanced_example():
adult_users = filter(
users,
expression=scalar_function(
"functions_comparison.yaml",
"extension:io.substrait:functions_comparison",
"gt",
expressions=[column("age"), literal(25, i64())],
),
Expand All @@ -216,7 +221,7 @@ def advanced_example():
column("salary"),
# Add a calculated field (this would show function options if available)
scalar_function(
"functions_arithmetic.yaml",
"extension:io.substrait:functions_arithmetic",
"multiply",
expressions=[column("salary"), literal(1.1, fp64())],
alias="salary_with_bonus",
Expand All @@ -238,7 +243,8 @@ def advanced_example():
i64(nullable=False), # order_id
fp64(nullable=False), # amount
string(nullable=False), # status
]
],
nullable=False,
),
)

Expand All @@ -248,7 +254,7 @@ def advanced_example():
high_value_orders = filter(
orders,
expression=scalar_function(
"functions_comparison.yaml",
"extension:io.substrait:functions_comparison",
"gt",
expressions=[column("amount"), literal(50.0, fp64())],
),
Expand Down Expand Up @@ -280,16 +286,16 @@ def expression_only_example():
print("=== Expression-Only Example ===")
# Show complex expression structure
complex_expr = scalar_function(
"functions_arithmetic.yaml",
"extension:io.substrait:functions_arithmetic",
"multiply",
expressions=[
scalar_function(
"functions_arithmetic.yaml",
"extension:io.substrait:functions_arithmetic",
"add",
expressions=[
column("base_salary"),
scalar_function(
"functions_arithmetic.yaml",
"extension:io.substrait:functions_arithmetic",
"multiply",
expressions=[
column("base_salary"),
Expand All @@ -299,7 +305,7 @@ def expression_only_example():
],
),
scalar_function(
"functions_arithmetic.yaml",
"extension:io.substrait:functions_arithmetic",
"subtract",
expressions=[
literal(1.0, fp64()),
Expand All @@ -312,7 +318,8 @@ def expression_only_example():
print("Complex salary calculation expression:")
# Create a simple plan to wrap the expression
dummy_schema = named_struct(
names=["base_salary"], struct=struct(types=[fp64(nullable=False)])
names=["base_salary"],
struct=struct(types=[fp64(nullable=False)], nullable=False),
)
dummy_table = read_named_table("dummy", dummy_schema)
dummy_plan = project(dummy_table, expressions=[complex_expr])
Expand Down
8 changes: 3 additions & 5 deletions examples/duckdb_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from substrait.builders.extended_expression import column, scalar_function, literal
from substrait.builders.type import i32
from substrait.extension_registry import ExtensionRegistry
from substrait.json import dump_json
import pyarrow.substrait as pa_substrait

try:
Expand Down Expand Up @@ -42,14 +41,13 @@ def read_duckdb_named_table(name: str, conn):
table = filter(
table,
expression=scalar_function(
"functions_comparison.yaml",
"extension:io.substrait:functions_comparison",
"equal",
expressions=[column("c_nationkey"), literal(3, i32())],
),
)
table = project(
table, expressions=[column("c_name"), column("c_address"), column("c_nationkey")]
)

sql = f"CALL from_substrait_json('{dump_json(table(registry))}')"
print(duckdb.sql(sql))
sql = "CALL from_substrait(?)"
print(duckdb.sql(sql, params=[table(registry).SerializeToString()]))
2 changes: 1 addition & 1 deletion gen_proto.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ rm -rf "$dest_dir/proto"

# Generate the new python protobuf files
buf generate
protol --in-place --create-package --python-out "$dest_dir" buf
uv run protol --in-place --create-package --python-out "$dest_dir" buf
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running protol through `uv run` keeps CI/CD from failing with a missing `protol` executable.


# Remove the old extension files
rm -rf "$extension_dir"
Expand Down
4 changes: 2 additions & 2 deletions src/substrait/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
except ImportError:
pass

__substrait_version__ = "0.74.0"
__substrait_hash__ = "793c64b"
__substrait_version__ = "0.77.0"
__substrait_hash__ = "3c25b1b"
__minimum_substrait_version__ = "0.30.0"
Loading