# TPC-DS Runner + Transplanted UDFs

## Setting up Spark

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("TPC-DS Loader").getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/12/11 19:58:11 WARN Utils: Your hostname, LAPTOP-7ECU52TP, resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/12/11 19:58:11 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/12/11 19:58:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Setting Up Tables From Generated TPC-DS Data

We use sqlglot to read the DDL provided by TPC-DS and collect schema information. It is worth noting that sqlglot does not preserve information like primary/foreign key constraints or nullability of column values.

In [2]:
from sqlglot import parse_one, exp, parse
from pyspark.sql.types import (
	StructType, StructField, StringType, 
	DateType, TimestampType, IntegerType, 
	DecimalType
)

def to_spark_type(sql_type: str):
	if sql_type.startswith("CHAR"):
		return StringType()
	elif sql_type.startswith("VARCHAR"):
		return StringType()
	elif sql_type.startswith("INT"):
		return IntegerType()
	elif sql_type.startswith("DECIMAL"):
		return DecimalType(7, 2)
	elif sql_type.startswith("TIME"):
		return TimestampType()
	elif sql_type.startswith("DATE"):
		return DateType()
	else:
		return StringType()

ddl_path = "./tpcds-kit/tools/tpcds.sql"
with open(ddl_path) as f:
    sql = f.read()
    # trim leading comments
    sql = "\n".join([line for line in sql.split("\n") if not line.strip().startswith("--")])

schema = {}
for statement in parse(sql):
	if isinstance(statement, exp.Create):
		table_name = statement.this.this.this.this

		columns = []
		for col_def in statement.find_all(exp.ColumnDef):
			col_name = col_def.this.name
			col_type = col_def.args.get("kind").sql()
			spark_type = to_spark_type(col_type)
			columns.append(StructField(col_name, spark_type))

		schema[table_name] = StructType(columns)

all_column_types = set()
for table_cols in schema.values():
	for field in table_cols.fields:
		all_column_types.add(field.dataType)

print(len(schema))
print(all_column_types)

25
{TimestampType(), StringType(), DateType(), DecimalType(7,2), IntegerType()}


In [3]:
def create_table_from_data(table_name: str, data_path: str):
	df = spark.read \
		.option("delimiter", "|") \
		.option("timestampFormat", "HH:mm:ss") \
		.schema(schema[table_name]) \
		.csv(data_path)
	
	df.createOrReplaceTempView(table_name)



In [4]:
import glob
import os

data_dir = "./generated_data" # replace with your directory target used with dsdgen

for path in glob.glob(os.path.join(data_dir, "*.dat")):
	file_name = os.path.basename(path)
	table_name = file_name[:-4]

	create_table_from_data(table_name, path)

spark.sql("SHOW TABLES").count()

25/12/11 19:58:43 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

25

## Running Generated TPC-DS Queries

In [8]:
from pathlib import Path
import time

sql_dir = Path("./processed_queries") # replace with your dir containing split queries.

runnable = []
total_queries = 0
for sql_file in sorted(sql_dir.glob("*.sql")):
	with open(sql_file, "r") as f:
		query = f.read()
	
	total_queries += 1
	if query:
		try:
			start_time = time.time()
			df = spark.sql(query)
			runnable.append({
				"query": query,
				"elapsed_time_without_udfs": time.time() - start_time,
			})
		except Exception as e:
			print(f"Error executing query from {sql_file}")

print(f"Total queries processed: {total_queries}")
print(f"Total runnable queries: {len(runnable)}")

Error executing query from processed_queries/query_102.sql
Error executing query from processed_queries/query_15.sql
Error executing query from processed_queries/query_18.sql
Error executing query from processed_queries/query_25.sql
Error executing query from processed_queries/query_26.sql
Error executing query from processed_queries/query_32.sql
Error executing query from processed_queries/query_33.sql
Error executing query from processed_queries/query_44.sql
Error executing query from processed_queries/query_45.sql
Error executing query from processed_queries/query_61.sql
Error executing query from processed_queries/query_65.sql
Error executing query from processed_queries/query_66.sql
Error executing query from processed_queries/query_68.sql
Error executing query from processed_queries/query_7.sql
Error executing query from processed_queries/query_8.sql
Error executing query from processed_queries/query_81.sql
Error executing query from processed_queries/query_89.sql
Error executing

## Transplanting UDFs

In [9]:
from sqlglot import parse_one, exp, parse
from sqlglot.schema import MappingSchema
from sqlglot.optimizer import optimize


with open("./tpcds-kit/tools/tpcds.sql") as f:
    sql = f.read()
    sql = "\n".join([line for line in sql.split("\n") if not line.strip().startswith("--")])

map_schema = {}
for statement in parse(sql):
    if isinstance(statement, exp.Create):
        table_name = statement.this.this.this.this

        columns = {}
        for col_def in statement.find_all(exp.ColumnDef):
            col_name = col_def.this.name
            col_type = col_def.args.get("kind").sql()
            columns[col_name] = col_type
        
        map_schema[table_name] = columns

schema_obj = MappingSchema(map_schema)

In [None]:
for q in runnable:
	query = q["query"]

	parsed = parse_one(query)
	optimized = optimize(parsed, schema=schema_obj)

	if isinstance(optimized, exp.CTE):
		main_query = optimized.this
	else:
		main_query = optimized

	for select in main_query.find_all(exp.Select):
	new_expressions = []

	for expr in select.expressions:
		# print(expr)
		# continue
		if isinstance(expr, exp.Alias) and isinstance(expr.this, exp.Column):
			# print(f"Column: {expr.this}, Alias: {expr.alias}")
			alias_name = expr.alias
			column_name = expr.this.name
			
			udf_node = exp.Anonymous(this="my_udf", expressions=[expr.this.copy()])
			alias_node = exp.Alias(this=udf_node, alias=alias_name)
			new_expressions.append(alias_node)
		else:
			new_expressions.append(expr)

	select.set("expressions", new_expressions)
	

KeyboardInterrupt: 