In [21]:
from sqlglot import parse_one, exp, parse
from sqlglot.schema import MappingSchema
from sqlglot.optimizer import optimize


# Build a sqlglot schema mapping ({table_name: {column_name: type_sql}}) from the
# CREATE TABLE statements found in tpc-ds.sql, then wrap it in a MappingSchema.
with open("./tpc-ds.sql", encoding="utf-8") as f:
    sql = f.read()

# Drop whole-line "--" comments before handing the text to the parser.
sql = "\n".join(line for line in sql.split("\n") if not line.strip().startswith("--"))

schema = {}
for statement in parse(sql):
    # Only CREATE TABLE statements describe columns; other CREATE kinds
    # (index, view, ...) have a different node shape and are skipped.
    if not isinstance(statement, exp.Create):
        continue
    if str(statement.args.get("kind") or "").upper() != "TABLE":
        continue

    # With a column list the target is a Schema node wrapping the Table;
    # without one (e.g. CREATE TABLE ... AS SELECT) it is the Table itself.
    target = statement.this
    table = target.this if isinstance(target, exp.Schema) else target
    table_name = table.name

    columns = {}
    for col_def in statement.find_all(exp.ColumnDef):
        col_name = col_def.this.name
        col_type = col_def.args.get("kind")
        if col_type is None:
            # Fail with a readable message instead of an AttributeError on None.
            raise ValueError(
                f"Column {col_name!r} of table {table_name!r} has no declared type"
            )
        columns[col_name] = col_type.sql()

    schema[table_name] = columns

schema_obj = MappingSchema(schema)

In [48]:
# Read the file content into `query`, optimize it against `schema_obj`, and wrap
# every aliased plain-column projection in a my_udf(...) call.
with open("../grammar_counting/tpcds/query1.sql", encoding="utf-8") as f:
    query = f.read()

# Remove whole-line "--" comments before parsing.
query = "\n".join(line for line in query.split("\n") if not line.strip().startswith("--"))

parsed = parse_one(query)
optimized = optimize(parsed, schema=schema_obj)
print(optimized)

# The root of a parsed query is the statement node itself (a WITH clause hangs
# off it as an argument), so it is never an exp.CTE and needs no unwrapping.
main_query = optimized

# find_all walks the whole tree, so SELECTs inside CTEs and subqueries are
# rewritten as well. Snapshot the matches first: the loop mutates the tree and
# should not do so underneath a live traversal generator.
for select in list(main_query.find_all(exp.Select)):
    for expr in select.expressions:
        if isinstance(expr, exp.Alias) and isinstance(expr.this, exp.Column):
            # Swap only the aliased expression; the Alias node keeps its
            # original Identifier, so the alias stays quoted like the rest
            # of the optimized SQL.
            udf_node = exp.Anonymous(this="my_udf", expressions=[expr.this.copy()])
            expr.set("this", udf_node)

print(optimized.sql(pretty=True))


WITH "customer_total_return" AS (SELECT "store_returns"."sr_customer_sk" AS "ctr_customer_sk", "store_returns"."sr_store_sk" AS "ctr_store_sk", SUM("store_returns"."sr_return_amt") AS "ctr_total_return" FROM "store_returns" AS "store_returns" JOIN "date_dim" AS "date_dim" ON "date_dim"."d_date_sk" = "store_returns"."sr_returned_date_sk" AND "date_dim"."d_year" = 2001 GROUP BY "store_returns"."sr_customer_sk", "store_returns"."sr_store_sk"), "_u_0" AS (SELECT AVG("ctr2"."ctr_total_return") * 1.2 AS "_col_0", "ctr2"."ctr_store_sk" AS "_u_1" FROM "customer_total_return" AS "ctr2" GROUP BY "ctr2"."ctr_store_sk") SELECT "customer"."c_customer_id" AS "c_customer_id" FROM "customer_total_return" AS "ctr1" JOIN "store" AS "store" ON "ctr1"."ctr_store_sk" = "store"."s_store_sk" AND "store"."s_state" = 'TN' JOIN "customer" AS "customer" ON "ctr1"."ctr_customer_sk" = "customer"."c_customer_sk" LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."_u_1" = "ctr1"."ctr_store_sk" WHERE "_u_0"."_col_0" < "ctr1"."ctr_