From f49727266beba790c3ea9b29923118e8355efbbd Mon Sep 17 00:00:00 2001
From: tokoko
Date: Sat, 25 Oct 2025 07:36:07 +0000
Subject: [PATCH] feat: configurable registry for sql conversion

Add an optional `registry` argument to `convert` so callers can supply
their own ExtensionRegistry instead of always getting a freshly built
default one. The default registry is only constructed when the argument
is None, so an explicitly passed registry is never replaced.

The tests now build one default registry at module level and pass it to
every conversion.
---
 src/substrait/sql/sql_to_substrait.py | 24 ++++++++++++++++++++++--
 tests/sql/test_sql_to_substrait.py    |  7 +++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/src/substrait/sql/sql_to_substrait.py b/src/substrait/sql/sql_to_substrait.py
index b11e385..40c14ef 100644
--- a/src/substrait/sql/sql_to_substrait.py
+++ b/src/substrait/sql/sql_to_substrait.py
@@ -333,7 +333,27 @@ def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionReg
         raise Exception(f"Unknown op {op}")
 
 
-def convert(query: str, dialect: str, schema_resolver: SchemaResolver):
+def convert(
+    query: str,
+    dialect: str,
+    schema_resolver: SchemaResolver,
+    registry: ExtensionRegistry = None,
+):
+    """Convert a SQL query into a Substrait plan.
+
+    Only the first parsed statement is translated.
+
+    Args:
+        query: SQL text to convert.
+        dialect: SQL dialect name handed to the parser.
+        schema_resolver: Resolver forwarded to ``translate``.
+        registry: Extension registry forwarded to ``translate``. When
+            ``None``, one with the default extensions loaded is created.
+
+    Returns:
+        The result of ``translate`` for the parsed statement.
+    """
     ast = parse_sql(sql=query, dialect=dialect)[0]
-    registry = ExtensionRegistry(load_default_extensions=True)
+    if registry is None:
+        registry = ExtensionRegistry(load_default_extensions=True)
     return translate(ast, schema_resolver=schema_resolver, registry=registry)
diff --git a/tests/sql/test_sql_to_substrait.py b/tests/sql/test_sql_to_substrait.py
index 4b36431..9b4590e 100644
--- a/tests/sql/test_sql_to_substrait.py
+++ b/tests/sql/test_sql_to_substrait.py
@@ -1,3 +1,4 @@
+from substrait.extension_registry import ExtensionRegistry
 from substrait.sql.sql_to_substrait import convert
 import pyarrow
 from google.protobuf import json_format
@@ -30,6 +31,8 @@
     ]
 )
 
+registry = ExtensionRegistry(load_default_extensions=True)
+
 
 def sort_arrow(table: pyarrow.Table):
     import pyarrow.compute as pc
@@ -52,7 +55,7 @@ def df_schema_resolver(name: str):
         pa_schema = ctx.sql(f"SELECT * FROM {name} LIMIT 0").schema()
         return pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema
 
-    plan = convert(query, "generic", df_schema_resolver)
+    plan = convert(query, "generic", df_schema_resolver, registry)
 
     sql_arrow = ctx.sql(query).to_arrow_table()
 
@@ -86,7 +89,7 @@ def duckdb_schema_resolver(name: str):
     conn.register("stores", data)
     conn.register("sales", sales_data)
 
-    plan = convert(query, "duckdb", duckdb_schema_resolver)
+    plan = convert(query, "duckdb", duckdb_schema_resolver, registry)
 
     conn.install_extension("substrait", repository="community")
     conn.load_extension("substrait")