### Parser Driver

To run this driver use <b>dbutils.notebook.run</b> with the following parameters:
<p>

```
dbutils.notebook.run(
    "../parser_driver", 
    0, 
    {
        "sql": "<query string>", 
        "endpoint": "<name of llm endpoint>",
        "parse_log_table": "<full namespace of table to write results to>",
        "column_chunk_size": "<number of columns per chunk>",
        "total_stages": "<total number of stages (1, 2, or 3)>"
    }
)
```

##### Stages
- Stage 1 - Parse all of the sql
- Stage 2 - Parse all of the sql and convert it with the LLM
- Stage 3 - Parse all of the sql, convert it with the LLM, and reassemble it to a final query

##### Parse Log Table
Leave the parse log table as an empty string if you don't want to write results to a table, although writing to a table is <b>highly</b> recommended.  To get the final parsed sql you can read it from the variable <b>pretty_final</b>, which gets set in stage 3.  That value is written into a cell below when the driver is run (<b>if stage 3 is run</b>).  The last row of the dataframe will also have the final_query in it.

In [0]:
%run ./sql_parser

In [0]:
# Example query used as the default value of the "sql" widget when the caller
# does not pass one. It combines scalar subqueries inside CASE expressions,
# function calls, joins, and an IN (...) subquery with a nested subquery.
# NOTE(review): the aliases `rsm` and `psah` are referenced but never declared
# in the FROM clause -- presumably fine because this text is only used as
# parser input and is not executed; confirm.
default_sql = """
    SELECT
        e.EmployeeID,
        UPPER(e.Name) AS Name,
        d.Name AS Department,
        CASE
            WHEN e.Salary > (
                SELECT AVG(Salary)
                FROM Employees
                WHERE DepartmentID = e.DepartmentID
            ) THEN CONCAT('Above Average (', CAST(e.Salary AS VARCHAR), ')')
            ELSE 'Average or Below'
        END AS SalaryStatus,
        CASE
            WHEN rsm.investment_type = 'BL'
                AND NVL (psah.acrd_cd, 'N') NOT IN ('Y', 'V')
            THEN NVL (
                (SELECT wacoupon
                   FROM stg_wso_pos_acr_ame
                  WHERE portfolio_fund_id = psah.cal_dt
                        AND asofdate = psah.cal_dt
                        AND asset_primaryud = psah.asset_id
                        AND rec_typ_cd = 'POS'),
                0)
            ELSE psah.int_rt
        END AS pos_int_it,
        ep.ProjectsCompleted,
        YEAR(e.HireDate) AS HireYear,
        MONTH(e.HireDate) AS HireMonth
    FROM
        Employees e
        JOIN Departments d ON e.DepartmentID = d.DepartmentID
        LEFT JOIN EmpProjects ep ON e.EmployeeID = ep.EmployeeID
    WHERE
        e.EmployeeID IN (
            SELECT
                e2.EmployeeID
            FROM
                Employees e2
                JOIN Departments d2 ON e2.DepartmentID = d2.DepartmentID
                LEFT JOIN EmpProjects ep2 ON e2.EmployeeID = ep2.EmployeeID
            WHERE
                e2.Salary > (
                    SELECT AVG(Salary)
                    FROM Employees
                    WHERE DepartmentID = e2.DepartmentID
                )
        )
    """

In [0]:
# Declare one text widget per driver parameter, each with its default value.
# Widget values are always strings, so numeric parameters are given as strings
# here and converted where they are read.
_widget_defaults = {
  "sql": default_sql,
  "endpoint": "databricks-claude-sonnet-4",
  "parse_log_table": "users.paul_signorelli.sql_parsing_log",
  "column_chunk_size": "5",
  "total_stages": "3",
}

# dicts preserve insertion order, so widgets are created in the order above.
for _widget_name, _widget_default in _widget_defaults.items():
  dbutils.widgets.text(_widget_name, _widget_default)

In [0]:
# Read the driver parameters back from the widgets into notebook-level
# variables used by the stage functions below.

# String parameters are used as-is.
sql, endpoint, parse_log_table = (
  dbutils.widgets.get(widget_name)
  for widget_name in ("sql", "endpoint", "parse_log_table")
)

# Numeric parameters arrive as strings and are converted with int(), which
# raises ValueError if the widget does not hold an integer literal.
column_chunk_size, total_stages = (
  int(dbutils.widgets.get(widget_name))
  for widget_name in ("column_chunk_size", "total_stages")
)

In [0]:
# Echo the SQL text read from the "sql" widget so the exact input being parsed
# is visible in the notebook output.
print(sql)

In [0]:
def stage1():
  """Stage 1: parse the SQL and log the parsed subqueries.

  Reads the notebook-level ``sql`` and ``parse_log_table`` variables. The log
  table is initialized first, then the query is turned into a DataFrame of
  subqueries, passed through ``extract_columns_and_replace_select``, and the
  result is written to ``parse_log_table``.

  Returns:
    The DataFrame returned by ``extract_columns_and_replace_select``.
  """
  # Must happen before the first write of this run.
  initialize_empty_subquery_delta_table(table_name=parse_log_table)

  parsed_subqueries = subqueries_to_spark_dataframe(sql)
  subqueries_with_columns = extract_columns_and_replace_select(parsed_subqueries)

  write_subqueries_to_delta(subqueries_with_columns, table_name=parse_log_table)
  return subqueries_with_columns


In [0]:
def stage2(spark_df_with_columns):
  """Stage 2: convert the parsed subqueries with the LLM endpoint and log them.

  Args:
    spark_df_with_columns: The DataFrame produced by ``stage1``.

  Returns:
    The DataFrame returned by ``convert_sql``, after it has been written to
    the notebook-level ``parse_log_table``.
  """
  converted_subqueries = convert_sql(spark_df_with_columns, endpoint_name=endpoint)
  write_subqueries_to_delta(converted_subqueries, table_name=parse_log_table)
  return converted_subqueries


In [0]:
def stage3(spark_df_converted):
  """Stage 3: convert the extracted columns in chunks and log the result.

  Uses the notebook-level ``column_chunk_size`` as the chunk size and
  ``endpoint`` as the LLM endpoint name.

  Args:
    spark_df_converted: The DataFrame produced by ``stage2``.

  Returns:
    The DataFrame returned by ``convert_sql_on_columns``, after it has been
    written to the notebook-level ``parse_log_table``.
  """
  converted_columns = convert_sql_on_columns(
    spark_df_converted,
    chunk_size=column_chunk_size,
    endpoint_name=endpoint,
  )
  write_subqueries_to_delta(converted_columns, table_name=parse_log_table)
  return converted_columns


In [0]:
# Run the requested number of stages. Each stage consumes the previous stage's
# DataFrame, so stage 1 always runs, stage 2 runs when total_stages >= 2, and
# stage 3 plus final reassembly runs when total_stages >= 3.
#
# stage1(), stage2() and stage3() each write their own result to
# parse_log_table, so the result is not written a second time here; only the
# final DataFrame (with the reassembled query appended) needs an extra write.
if total_stages < 1:
  # Previously this case silently did nothing; say so explicitly.
  print(f"\nWARNING: total_stages={total_stages}; nothing was run (expected 1, 2, or 3)\n")
else:
  spark_df_with_columns = stage1()
  print("\n✅ Parsed subqueries\n")

  if total_stages >= 2:
    spark_df_converted = stage2(spark_df_with_columns)
    print("\n✅ Converted subqueries\n")

  if total_stages >= 3:
    spark_df_with_converted_columns = stage3(spark_df_converted)

    final_query = assemble_final_query_string(spark_df_with_converted_columns)

    # pretty_final is read by a later cell, so it stays a notebook-level name.
    pretty_final = prettify_final(final_query)
    df_final = append_final_query_row(spark_df_with_converted_columns, pretty_final)
    write_subqueries_to_delta(df_final, table_name=parse_log_table)

In [0]:
# pretty_final is only assigned when stage 3 runs. Guard on total_stages as
# well as on the name existing, so that an interactive re-run with fewer
# stages does not print a stale query left over from an earlier stage-3 run,
# and so that unrelated errors are no longer silently swallowed.
if total_stages >= 3 and "pretty_final" in globals():
  print(f"✅ Reassembled Final Query:\n\n{pretty_final}")
else:
  print("NOTE: no reassembled final query to show (stage 3 was not run).")

In [0]:
# Show the parse log table, best effort: this is a convenience view, so a
# failure is reported but does not fail the notebook.
if not parse_log_table.strip():
  # A blank parse_log_table means there is no table to read.
  print("NOTE: parse_log_table is blank; no parse log to display.")
else:
  try:
    # spark.table() resolves the name as a table identifier, so the widget
    # value is not spliced into a SQL string.
    display(spark.table(parse_log_table))
  except Exception as err:
    # Can fail, for example, without proper permissions on the table; report
    # the reason instead of hiding it.
    print(f"WARNING: could not display {parse_log_table}: {err}")