In [26]:
from sqlglot import parse_one, exp, parse
from sqlglot.schema import MappingSchema
from sqlglot.optimizer import optimize


# Read the TPC-DS DDL from tpc-ds.sql and build a sqlglot schema from every
# CREATE TABLE statement: {table_name: {column_name: sql_type_text}}.
with open("./tpc-ds.sql", encoding="utf-8") as f:
    sql = f.read()
# Drop whole-line "--" comments wherever they appear (not only at the top);
# trailing comments after SQL on the same line are left to the parser.
sql = "\n".join(line for line in sql.split("\n") if not line.strip().startswith("--"))

schema = {}
for statement in parse(sql):
    if not isinstance(statement, exp.Create):
        continue
    # Only CREATE TABLE statements with an explicit column list carry column
    # types. CREATE VIEW / INDEX and CREATE TABLE ... AS SELECT have no
    # exp.Schema target and are skipped instead of crashing.
    if str(statement.args.get("kind") or "").upper() != "TABLE":
        continue
    target = statement.this
    if not isinstance(target, exp.Schema) or not isinstance(target.this, exp.Table):
        continue
    table_name = target.this.name

    columns = {}
    for col_def in target.find_all(exp.ColumnDef):
        kind = col_def.args.get("kind")
        # Untyped columns have no "kind" node; skip them rather than calling
        # .sql() on None.
        if kind is None:
            continue
        columns[col_def.this.name] = kind.sql()

    schema[table_name] = columns

schema_obj = MappingSchema(schema)

In [27]:
# Load the benchmark query, dropping whole-line "--" comments, then
# qualify/optimize it against the TPC-DS schema built in the previous cell.
with open("/Users/henry/CS230-Project/query1.sql") as f:
    raw_lines = f.read().split("\n")
    query = "\n".join(ln for ln in raw_lines if not ln.strip().startswith("--"))

parsed = parse_one(query)
optimized = optimize(parsed, schema=schema_obj)

# NOTE(review): optimize() returns the rewritten top-level query node; a WITH
# clause presumably hangs off that node rather than making it an exp.CTE, so
# the CTE branch may never fire — confirm before relying on it.
main_query = optimized.this if isinstance(optimized, exp.CTE) else optimized

# Helpers below map a column's SQL type to a dummy UDF name.
import inspect
# NOTE(review): Decimal, date and datetime are not referenced in the visible
# code; presumably kept for annotations in the UDF registry — confirm.
from decimal import Decimal
from datetime import date, datetime
import sys
# ensure project root is on path when running in notebooks
# NOTE(review): hard-coded absolute path; re-running the cell prepends a
# duplicate entry each time.
sys.path.insert(0, "/Users/henry/CS230-Project")
# Registry of dummy UDFs; the helpers below iterate `all_simple_functions`.
import udf_insertion.dummy_udfs as dummy_udfs

def normalize_type_name(t: str) -> str:
    """Return *t* stripped of surrounding whitespace and upper-cased.

    Non-string inputs (e.g. ``None``) are returned untouched, so optional
    values can be normalized without a separate check.
    """
    if isinstance(t, str):
        return t.strip().upper()
    return t

def target_python_type_from_col_type(col_type: str):
    """Map SQL type text (from the schema) to a token describing the target type.

    Short tokens are returned rather than Python classes, so callers can make
    finer-grained decisions (e.g. decimals with a known scale).

    Args:
        col_type: SQL type text such as ``"DECIMAL(7, 2)"`` or ``"CHAR(16)"``.
            Matching is case-insensitive. Non-strings yield ``"unknown"``.

    Returns:
        For DECIMAL/NUMERIC, ``("decimal", precision, scale)``. Precision and
        scale are ``None`` when no ``(p[, s])`` suffix is present; a bare
        ``(p)`` implies scale 0, per the SQL standard.

        For any other type, one of ``"str"``, ``"int"``, ``"date"``,
        ``"datetime"``, ``"float"`` or ``"unknown"``.
    """
    import re

    if not isinstance(col_type, str):
        return "unknown"
    # Same normalization as normalize_type_name, inlined for strings.
    t = col_type.strip().upper()

    # DECIMAL / NUMERIC, optionally with (precision[, scale]).
    if t.startswith(("DECIMAL", "NUMERIC")):
        m = re.match(r"(?:DECIMAL|NUMERIC)\s*\(\s*(\d+)\s*(?:,\s*(\d+)\s*)?\)", t)
        if m:
            precision = int(m.group(1))
            scale = int(m.group(2)) if m.group(2) is not None else 0
            return ("decimal", precision, scale)
        # No usable (p, s) suffix: scale is unknown.
        return ("decimal", None, None)

    # INTERVAL contains the substring "INT" but is not an integer type.
    if t.startswith("INTERVAL"):
        return "unknown"

    # Strings are checked before integers. Otherwise LONGTEXT (which contains
    # "LONG") would be misclassified as an integer. CHAR also covers VARCHAR
    # and CHARACTER.
    if any(k in t for k in ("CHAR", "STRING", "TEXT")):
        return "str"

    # Integers: INT, INTEGER, BIGINT, TINYINT, SMALLINT, LONG, ...
    if "INT" in t or "LONG" in t:
        return "int"

    # Temporal types. DATETIME must be tested before the bare DATE substring.
    if "TIMESTAMP" in t or "DATETIME" in t:
        return "datetime"
    if "DATE" in t:
        return "date"
    if "TIME" in t:
        return "datetime"

    # Floating point.
    if any(k in t for k in ("DOUBLE", "FLOAT", "REAL")):
        return "float"

    return "unknown"

def is_compatible_func(func, target_token):
    """Return True if `func` is a reasonable match for the desired target_token.

    This is a permissive heuristic used for choosing a dummy UDF. The wrapper
    always passes exactly one positional argument, so `func` must accept that
    call. It must have:

    * at most one required positional parameter,
    * at least one positional slot (a named positional parameter or ``*args``),
    * no required keyword-only parameters.

    If `func` carries a return annotation, that annotation must also match
    `target_token`. Functions without a return annotation are accepted whenever
    they can be called with one argument, regardless of their name.

    Args:
        func: Candidate callable.
        target_token: Token produced by ``target_python_type_from_col_type``.
            This is either a string such as ``"int"`` or a
            ``("decimal", precision, scale)`` tuple.

    Returns:
        True if `func` is an acceptable wrapper for the target type.
    """
    try:
        sig = inspect.signature(func)
    except (ValueError, TypeError):
        # Non-callables, or builtins without an introspectable signature.
        return False

    params = list(sig.parameters.values())
    positional = [
        p for p in params
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    required = [p for p in positional if p.default is p.empty]
    if len(required) > 1:
        return False
    # The single wrapped argument needs a positional slot to land in.
    if not positional and not any(p.kind == p.VAR_POSITIONAL for p in params):
        return False
    # Required keyword-only parameters would never be supplied by the wrapper.
    if any(p.kind == p.KEYWORD_ONLY and p.default is p.empty for p in params):
        return False

    ann = sig.return_annotation
    if ann is inspect.Signature.empty:
        # Permissive fallback: nothing to check, and the function is callable
        # with one argument.
        return True

    ann_name = getattr(ann, "__name__", str(ann)).lower()

    if isinstance(target_token, tuple) and target_token[0] == "decimal":
        # Decimals may be produced as Decimal, str, or float.
        return any(k in ann_name for k in ("decimal", "str", "float"))

    if not isinstance(target_token, str):
        return False

    matches = {
        "int": ann is int or ann_name == "int",
        "str": ann is str or ann_name == "str",
        "float": ann is float or ann_name == "float",
        "date": "date" in ann_name,
        "datetime": "datetime" in ann_name,
        "bool": ann is bool or "bool" in ann_name,
        "bytes": ann is bytes or "bytes" in ann_name,
        "list": "list" in ann_name,
    }
    # An annotation that does not match the target (or an unrecognized
    # target token) rejects the function.
    return matches.get(target_token, False)


# Preferred function-name substrings per target token. They are tried in
# order when no registry function passes the signature/annotation check.
_UDF_NAME_PREFERENCES = {
    "int": ["sum_vars", "add", "scale"],
    "str": ["concat", "to_decimal_str", "unescape", "regex", "replace", "bytes_to_base64"],
    "date": ["timestamp_to_date", "date"],
    "datetime": ["timestamp_to_date", "time"],
    "float": ["to_float_or_none", "to_decimal_str", "float"],
    "decimal": ["to_decimal_str", "to_float_or_none"],
    "bool": ["bool_flag", "is_"],
    "bytes": ["bytes_to_base64", "base64"],
    "list": ["split_to_list", "map_to_upper", "gen_id_list"],
}


def choose_udf_name_for_col_type(col_type: str) -> str:
    """Choose a dummy UDF name for a given SQL column type string.

    Registry functions in ``dummy_udfs.all_simple_functions`` are named
    without a prefix. The returned name is ``"dummy_" + function name``,
    matching how they are referenced in generated SQL.

    Resolution order:

    1. The first registry function accepted by ``is_compatible_func``.
    2. The first function whose name contains a preferred substring for the
       column's type token (see ``_UDF_NAME_PREFERENCES``).
    3. ``"dummy_identity"``.
    """
    token = target_python_type_from_col_type(col_type)

    # First pass: signature + return-annotation compatibility.
    for f in dummy_udfs.all_simple_functions:
        try:
            if is_compatible_func(f, token):
                return f"dummy_{f.__name__}"
        except Exception:
            # Best-effort selection: a function that cannot be inspected is
            # simply skipped.
            continue

    # Second pass: name-preference matching, ignoring signatures.
    if isinstance(token, tuple) and token[0] == "decimal":
        key = "decimal"
    elif isinstance(token, str):
        key = token
    else:
        key = None
    for pname in _UDF_NAME_PREFERENCES.get(key, ()):
        for f in dummy_udfs.all_simple_functions:
            if pname in f.__name__:
                return f"dummy_{f.__name__}"

    # Final fallback. The identity wrapper is always the answer here, whether
    # or not the registry contains an "identity" function.
    return "dummy_identity"

In [31]:
# Schema-aware "transplant" demo on the TPC-DS `item` table. It optimizes a
# sample query, then wraps every aliased SELECT expression that references a
# column in a one-argument function call. The function is either a dummy UDF
# chosen from the first referenced column's type, or a stdlib-named wrapper
# taken from `alias_stdlib_overrides`.
from sqlglot import parse_one, exp
from sqlglot.optimizer import optimize
import sys
# ensure package importable in notebooks
# NOTE(review): each run of this cell prepends the path again.
sys.path.insert(0, "/Users/henry/CS230-Project")
import udf_insertion.dummy_udfs as dummy_udfs
# NOTE(review): html/base64/json/re are not called in this cell. The override
# map below only names stdlib functions as strings.
import html
import base64
import json
import re

# richer sample query using columns and expressions from the `item` table
sample_query = '''
SELECT
  i_item_id AS item_id,
  i_current_price AS price,
  i_item_desc AS description,
  i_brand_id + 100 AS brand_id_adj,
  CONCAT(i_item_id, '-', CAST(i_brand_id AS VARCHAR)) AS id_combo,
  i_wholesale_cost * 1.10 AS cost_plus,
  i_rec_start_date AS start_date,
  REGEXP_REPLACE(i_item_desc, '\\s+', ' ') AS desc_clean,
  UPPER(i_units) AS units_up,
  (i_manager_id IS NOT NULL) AS has_manager
FROM item;
'''

print('--- Original SQL ---')
print(sample_query)

# parse + optimize with the already-built schema_obj (from earlier cell)
parsed = parse_one(sample_query)
optimized = optimize(parsed, schema=schema_obj)

# Reuse chooser if present; otherwise provide a fallback chooser.
# The fallback maps a type string to a dummy_* name by keyword, and uses a name
# only if a function with the unprefixed name exists in the registry.
try:
    choose = choose_udf_name_for_col_type
except NameError:
    def choose(col_type):
        if not col_type:
            return 'dummy_identity'
        t = col_type.upper()
        if 'DECIMAL' in t or 'NUMERIC' in t:
            return 'dummy_to_decimal_str' if any(f.__name__ == 'to_decimal_str' for f in dummy_udfs.all_simple_functions) else 'dummy_identity'
        if any(k in t for k in ('CHAR', 'STRING', 'VARCHAR', 'TEXT')):
            return 'dummy_concat' if any(f.__name__ == 'concat' for f in dummy_udfs.all_simple_functions) else 'dummy_identity'
        if any(k in t for k in ('INT', 'BIGINT', 'SMALLINT', 'TINYINT')):
            return 'dummy_add' if any(f.__name__ == 'add' for f in dummy_udfs.all_simple_functions) else 'dummy_identity'
        if any(k in t for k in ('DOUBLE','FLOAT','REAL')):
            return 'dummy_to_decimal_str' if any(f.__name__ == 'to_decimal_str' for f in dummy_udfs.all_simple_functions) else 'dummy_identity'
        return 'dummy_identity'

# We will allow some aliases to be wrapped with stdlib-like wrappers rather than dummy_* names.
# Map alias -> (module.function) to use as wrapper instead of a dummy function name.
# NOTE(review): every wrapper is emitted with exactly ONE argument (see the
# exp.Anonymous construction below). re.sub needs (pattern, repl, string), and
# str.join needs a separator plus an iterable. The generated RE_SUB(...) and
# STR_JOIN(...) calls therefore would not match those signatures if bound to
# the real stdlib functions.
alias_stdlib_overrides = {
    'description': 'html.unescape',
    'desc_clean': 're.sub',
    'id_combo': 'str.join',
    'units_up': 'str.upper',
    'price': 'float',
}

# apply schema-aware transplant: wrap aliased expressions that reference columns
# NOTE(review): optimize() presumably never returns an exp.CTE at top level, so
# main_query is likely always `optimized` — confirm.
if isinstance(optimized, exp.CTE):
    main_query = optimized.this
else:
    main_query = optimized

chosen_map = {}  # alias -> chosen wrapper ("dummy_*" or "stdlib:module.function")
used_stdlib = set()  # stdlib wrapper names actually applied
# NOTE(review): select.set('expressions', ...) replaces the projection list while
# find_all() is still walking the tree. Subsequent traversal therefore sees the
# new (copied) nodes — keep this ordering in mind before refactoring.
for select in main_query.find_all(exp.Select):
    new_expressions = []
    for expr in select.expressions:
        # We're interested in Aliases (expression AS alias). For any aliased expression that
        # references at least one Column, choose a UDF based on the first referenced column's type
        if isinstance(expr, exp.Alias):
            aliased_expr = expr.this
            # find the first Column referenced inside the expression (if any)
            cols = list(aliased_expr.find_all(exp.Column))
            if cols:
                first_col = cols[0]
                # Reads sqlglot's private `_type` (presumably filled in by the
                # optimizer's type annotation). If it is None, .sql() raises and
                # the broad except maps that to col_type = None.
                try:
                    col_type = first_col._type.sql()
                except Exception:
                    col_type = None
                # Computed even when an override applies below (then unused).
                chosen = choose(col_type) if col_type else 'dummy_identity'
                # if the alias is in our stdlib override map, prefer that wrapper
                alias_name = expr.alias
                if alias_name in alias_stdlib_overrides:
                    stdlib_f = alias_stdlib_overrides[alias_name]
                    # record and use a readable wrapper name in the SQL AST (dot replaced with _)
                    wrapper_name = stdlib_f.replace('.', '_')
                    udf_node = exp.Anonymous(this=wrapper_name, expressions=[aliased_expr.copy()])
                    chosen_map[alias_name] = f"stdlib:{stdlib_f}"
                    used_stdlib.add(stdlib_f)
                else:
                    # normal dummy function wrapping (keep the chosen name)
                    udf_node = exp.Anonymous(this=chosen, expressions=[aliased_expr.copy()])
                    chosen_map[alias_name] = chosen
                alias_node = exp.Alias(this=udf_node, alias=alias_name)
                new_expressions.append(alias_node)
            else:
                # Aliased expression with no column reference: kept unchanged.
                new_expressions.append(expr)
        else:
            # Non-aliased projections are kept unchanged.
            new_expressions.append(expr)
    select.set('expressions', new_expressions)

# Uncomment to inspect which wrapper was chosen per alias.
# print('\n--- Chosen UDF mapping (alias -> wrapper) ---')
# for k, v in chosen_map.items():
#     print(f"{k} -> {v}")

# if used_stdlib:
#     print('\n--- Stdlib wrappers used (module.function) ---')
#     for s in sorted(used_stdlib):
#         print(s)

print('\n--- Transformed SQL ---\n')
print(optimized.sql(pretty=True))

--- Original SQL ---

SELECT
  i_item_id AS item_id,
  i_current_price AS price,
  i_item_desc AS description,
  i_brand_id + 100 AS brand_id_adj,
  CONCAT(i_item_id, '-', CAST(i_brand_id AS VARCHAR)) AS id_combo,
  i_wholesale_cost * 1.10 AS cost_plus,
  i_rec_start_date AS start_date,
  REGEXP_REPLACE(i_item_desc, '\s+', ' ') AS desc_clean,
  UPPER(i_units) AS units_up,
  (i_manager_id IS NOT NULL) AS has_manager
FROM item;


--- Transformed SQL ---

SELECT
  DUMMY_TO_DECIMAL_STR("item"."i_item_id") AS item_id,
  FLOAT("item"."i_current_price") AS price,
  HTML_UNESCAPE("item"."i_item_desc") AS description,
  DUMMY_KW_ONLY_SCALE("item"."i_brand_id" + 100) AS brand_id_adj,
  STR_JOIN(CONCAT("item"."i_item_id", '-', CAST("item"."i_brand_id" AS VARCHAR))) AS id_combo,
  DUMMY_TO_DECIMAL_STR("item"."i_wholesale_cost" * 1.10) AS cost_plus,
  DUMMY_TIMESTAMP_TO_DATE("item"."i_rec_start_date") AS start_date,
  RE_SUB(REGEXP_REPLACE("item"."i_item_desc", '\s+', ' ')) AS desc_clean,
  STR_UPP