In [1]:
from sqlglot import parse_one, exp, parse
from sqlglot.schema import MappingSchema
from sqlglot.optimizer import optimize


# read from file tpc-ds.sql and parse all create table statements
with open("./tpc-ds.sql") as f:
    sql = f.read()
    # trim leading comments
    sql = "\n".join([line for line in sql.split("\n") if not line.strip().startswith("--")])

schema = {}
for statement in parse(sql):
    if isinstance(statement, exp.Create):
        table_name = statement.this.this.this.this

        columns = {}
        for col_def in statement.find_all(exp.ColumnDef):
            col_name = col_def.this.name
            col_type = col_def.args.get("kind").sql()
            columns[col_name] = col_type
        
        schema[table_name] = columns

schema_obj = MappingSchema(schema)

In [20]:
# read file content into query
with open("/Users/henry/CS230-Project/query1.sql") as f:
    query = f.read()
    # remove leading comments
    query = "\n".join([line for line in query.split("\n") if not line.strip().startswith("--")])

parsed = parse_one(query)
optimized = optimize(parsed, schema=schema_obj)
# print(optimized)

if isinstance(optimized, exp.CTE):
    main_query = optimized.this
else:
    main_query = optimized

# helper: choose a dummy udf name based on the column type
import inspect
from decimal import Decimal
from datetime import date, datetime
import sys
# ensure project root is on path when running in notebooks
sys.path.insert(0, "/Users/henry/CS230-Project")
import udf_insertion.dummy_udfs as dummy_udfs

def normalize_type_name(t: str) -> str:
    if not isinstance(t, str):
        return t
    return t.strip().upper()

def target_python_type_from_col_type(col_type: str):
    """Map SQL-type text (from the schema) to a token describing the target type.

    We return short tokens rather than Python classes so we can make
    finer-grained decisions (e.g. decimal with scale).
    """
    if not isinstance(col_type, str):
        return "unknown"
    t = normalize_type_name(col_type)

    # Decimal/numeric with precision/scale: extract scale if present
    if t.startswith("DECIMAL") or t.startswith("NUMERIC"):
        # try to parse DECIMAL(p,s)
        import re

        m = re.search(r"DECIMAL\s*\((\d+)\s*,\s*(\d+)\)", t)
        if m:
            precision = int(m.group(1))
            scale = int(m.group(2))
            return ("decimal", precision, scale)
        # fallback: unknown-scale decimal
        return ("decimal", None, None)

    # integers
    if any(k in t for k in ("INT", "INTEGER", "LONG", "BIGINT", "TINYINT", "SMALLINT")):
        return "int"

    # strings / char / varchar / text
    if any(k in t for k in ("CHAR", "STRING", "VARCHAR", "TEXT", "CHARACTER")):
        return "str"

    # dates and timestamps
    if "DATE" in t and "TIMESTAMP" not in t:
        return "date"
    if "TIMESTAMP" in t or "TIME" in t:
        return "datetime"

    # floating point
    if any(k in t for k in ("DOUBLE", "FLOAT", "REAL")):
        return "float"

    return "unknown"

def is_compatible_func(func, target_token):
    """Return True if `func` is a reasonable match for the desired target_token.

    `target_token` is the value returned by `target_python_type_from_col_type`.
    It may be a string like 'int'/'str'/'date'/'float' or a tuple for decimals
    like ("decimal", precision, scale).
    """
    try:
        sig = inspect.signature(func)
    except (ValueError, TypeError):
        return False

    params = [p for p in sig.parameters.values() if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
    required = [p for p in params if p.default is p.empty]
    # require at most one required positional param (the column)
    if len(required) > 1:
        return False

    # if function has a return annotation, prefer matching that
    ann = sig.return_annotation
    if ann is not inspect._empty:
        ann_name = getattr(ann, "__name__", str(ann)).lower()
        # decimal tuple handling
        if isinstance(target_token, tuple) and target_token[0] == "decimal":
            if "decimal" in ann_name or "str" in ann_name:
                return True
            return False
        # simple token cases
        if isinstance(target_token, str):
            if target_token == "int" and (ann is int or ann_name == "int"):
                return True
            if target_token == "str" and (ann is str or ann_name == "str"):
                return True
            if target_token == "float" and (ann is float or ann_name == "float"):
                return True
            if target_token == "date" and ann_name in ("date",):
                return True
            if target_token == "datetime" and ann_name in ("datetime",):
                return True
        # if annotation present but didn't match, avoid function
        return False

    # no return annotation: use name heuristics + parameter count
    name = func.__name__.lower()
    # decimal target
    if isinstance(target_token, tuple) and target_token[0] == "decimal":
        # prefer formatting helpers
        if "decimal" in name or "to_decimal" in name or "to_decimal_str" in name:
            return True
        # allow numeric helpers that accept one arg
        return len(required) <= 1

    if target_token == "int":
        if any(k in name for k in ("int", "add", "scale")):
            return True
    if target_token == "str":
        if any(k in name for k in ("str", "concat", "unescape", "regex", "replace")):
            return True
    if target_token == "date":
        if "date" in name or "timestamp" in name:
            return True
    if target_token == "datetime":
        if "timestamp" in name or "time" in name:
            return True
    if target_token == "float":
        if any(k in name for k in ("float", "double", "real")):
            return True

    # conservative fallback: allow functions that accept <=1 required arg
    return len(required) <= 1

def choose_udf_name_for_col_type(col_type: str) -> str:
    token = target_python_type_from_col_type(col_type)
    # prefer functions in the provided registry that match
    for f in dummy_udfs.all_simple_functions:
        if is_compatible_func(f, token):
            return f"dummy_{f.__name__}"
    # if we have a decimal with small scale, prefer to_decimal_str if present
    if isinstance(token, tuple) and token[0] == "decimal":
        for f in dummy_udfs.all_simple_functions:
            if f.__name__ == "to_decimal_str":
                return "dummy_to_decimal_str"
    # fallback: identity (no-op) function name
    return "dummy_identity"

for select in main_query.find_all(exp.Select):
    new_expressions = []

    for expr in select.expressions:
        if isinstance(expr, exp.Alias) and isinstance(expr.this, exp.Column):
            alias_name = expr.alias
            column_name = expr.this.name

            # try to read the resolved column type from the optimizer metadata (sqlglot stores a _type on Column nodes)
            try:
                col_type = expr.this._type.sql()
            except Exception:
                col_type = None

            udf_name = choose_udf_name_for_col_type(col_type) if col_type else 'dummy_identity'
            udf_node = exp.Anonymous(this=udf_name, expressions=[expr.this.copy()])
            alias_node = exp.Alias(this=udf_node, alias=alias_name)
            new_expressions.append(alias_node)
        else:
            new_expressions.append(expr)

    select.set("expressions", new_expressions)

print(optimized.sql(pretty=True))

WITH "customer_total_return" AS (
  SELECT
    DUMMY_KW_ONLY_SCALE("store_returns"."sr_customer_sk") AS ctr_customer_sk,
    DUMMY_KW_ONLY_SCALE("store_returns"."sr_store_sk") AS ctr_store_sk,
    SUM("store_returns"."sr_return_amt") AS "ctr_total_return"
  FROM "store_returns" AS "store_returns"
  JOIN "date_dim" AS "date_dim"
    ON "date_dim"."d_date_sk" = "store_returns"."sr_returned_date_sk"
    AND "date_dim"."d_year" = 2001
  GROUP BY
    "store_returns"."sr_customer_sk",
    "store_returns"."sr_store_sk"
), "_u_0" AS (
  SELECT
    AVG("ctr2"."ctr_total_return") * 1.2 AS "_col_0",
    DUMMY_KW_ONLY_SCALE("ctr2"."ctr_store_sk") AS _u_1
  FROM "customer_total_return" AS "ctr2"
  GROUP BY
    "ctr2"."ctr_store_sk"
)
SELECT
  DUMMY_TO_DECIMAL_STR("customer"."c_customer_id") AS c_customer_id
FROM "customer_total_return" AS "ctr1"
JOIN "store" AS "store"
  ON "ctr1"."ctr_store_sk" = "store"."s_store_sk" AND "store"."s_state" = 'TN'
JOIN "customer" AS "customer"
  ON "ctr1"."ctr_custo

### Explanation
This notebook parses a SQL schema and a query with `sqlglot`, runs the optimizer with the schema,
then walks the optimized AST and wraps aliased column expressions in an anonymous UDF call while preserving aliases.
The next cell demonstrates how to use a configurable `udf_name` and shows a before/after example on a small query.

In [13]:
# Demo: configurable udf_name and small before/after example
from sqlglot import parse_one, exp
from sqlglot.optimizer import optimize
from sqlglot.schema import MappingSchema

def wrap_columns_with_udf_in_sql(query, udf_name, schema_obj):
    parsed = parse_one(query)
    optimized = optimize(parsed, schema=schema_obj)
    if isinstance(optimized, exp.CTE):
        main_query = optimized.this
    else:
        main_query = optimized
    for select in main_query.find_all(exp.Select):
        new_expressions = []
        for expr in select.expressions:
            if isinstance(expr, exp.Alias) and isinstance(expr.this, exp.Column):
                udf_node = exp.Anonymous(this=udf_name, expressions=[expr.this.copy()])
                alias_node = exp.Alias(this=udf_node, alias=expr.alias)
                new_expressions.append(alias_node)
            else:
                new_expressions.append(expr)
        select.set("expressions", new_expressions)
    return optimized.sql(pretty=True)

# create a tiny schema for the demo (doesn't depend on tpc-ds)
demo_schema = MappingSchema({'my_table': {'a': 'INT', 'b': 'STRING'}})

sample_query = 'SELECT a AS a_alias, b FROM my_table;'
print('--- Original SQL ---')
print(sample_query)
print('\n--- Transformed SQL (udf_name=my_udf) ---')
print(wrap_columns_with_udf_in_sql(sample_query, udf_name='my_udf', schema_obj=demo_schema))
# print('--- Transformed SQL (udf_name=custom_fn) ---')
# print(wrap_columns_with_udf_in_sql(sample_query, udf_name='custom_fn', schema_obj=demo_schema))

--- Original SQL ---
SELECT a AS a_alias, b FROM my_table;

--- Transformed SQL (udf_name=my_udf) ---
SELECT
  MY_UDF("my_table"."a") AS a_alias,
  MY_UDF("my_table"."b") AS b
FROM "my_table" AS "my_table"
