In [0]:
%pip install sqlparse
dbutils.library.restartPython()

In [0]:
import sqlparse

def prettify_final(query_string: str):
    """Pretty-print a SQL string with sqlparse.

    The statement is re-indented and its keywords are upper-cased; the
    formatted text is returned.
    """
    formatting_options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(query_string, **formatting_options)

In [0]:
def strip_comments(sql: str) -> str:
    """Remove SQL comments while leaving quoted text untouched.

    Strips ``/* ... */`` block comments (may span lines) and ``-- ...``
    comments (up to the end of the line). Comment markers that appear inside
    single-quoted string literals or double-quoted identifiers are kept,
    because there they are data, not comments.

    Args:
        sql: Raw SQL text.

    Returns:
        The SQL text with comments removed. Line breaks that terminated a
        ``--`` comment are preserved.
    """
    # Local import: this cell is defined before the notebook's `import re` cell.
    import re

    token = re.compile(
        r"""('(?:[^']|'')*'|"(?:[^"]|"")*")"""  # group 1: quoted literal / identifier (kept)
        r"""|/\*.*?\*/"""                        # block comment (dropped)
        r"""|--[^\n\r]*""",                      # line comment (dropped)
        re.DOTALL,
    )
    # A quoted match is always at least two characters, so `or ''` only
    # triggers for comment matches, where group 1 is None.
    return token.sub(lambda m: m.group(1) or '', sql)

In [0]:
import re

def extract_cte_block(sql_text: str):
    """
    Split a query into its leading WITH clause and the main statement.

    The text is scanned for the first SELECT keyword that sits outside every
    parenthesis; everything before it is the CTE block. Quoted literals are
    skipped so parentheses or keywords inside them cannot confuse the scan,
    and SELECT is only recognised as a whole word (a CTE called ``selected``
    is not mistaken for the main query).

    Returns a tuple ``(cte_block, remainder)``:
    - ``(None, sql_text)`` when the stripped text does not start with WITH
    - ``(cte_block, remainder)`` where ``remainder`` starts at the top-level SELECT
    - ``(sql_text, "")`` when no top-level SELECT follows the WITH clause
    """
    def is_ident_char(ch: str) -> bool:
        # Also returns False for '' (start / end of text).
        return ch.isalnum() or ch == '_'

    sql_text = sql_text.strip()
    if sql_text[:4].upper() != "WITH" or is_ident_char(sql_text[4:5]):
        return None, sql_text

    depth = 0
    n = len(sql_text)
    i = 4  # resume right after the WITH keyword
    while i < n:
        ch = sql_text[i]
        if ch in ("'", '"'):
            # Jump over the quoted span; an unterminated quote swallows the rest.
            closing = sql_text.find(ch, i + 1)
            i = n if closing == -1 else closing + 1
            continue
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif (
            depth == 0
            and sql_text[i:i + 6].upper() == 'SELECT'
            and not is_ident_char(sql_text[i - 1])
            and not is_ident_char(sql_text[i + 6:i + 7])
        ):
            # Top-level SELECT: the CTE block ends here.
            return sql_text[:i].strip(), sql_text[i:]
        i += 1

    # No main SELECT after the WITH clause: hand back the full input.
    return sql_text, ""


def extract_top_level_selects_with_joins(sql_text: str):
    """
    Split a statement into its top-level SELECT blocks and set operators.

    Returns a list of strings in source order, where each item is either
    - a top-level SELECT block (its leading keyword is normalised to ``SELECT``), or
    - a set operator: ``UNION``, ``UNION ALL``, ``INTERSECT`` or ``EXCEPT``
      (upper-cased, inner whitespace collapsed to a single space).

    SELECTs and operators inside parentheses or quoted text are left inside
    their enclosing block. Keywords are matched as whole words only, so
    identifiers such as ``reunion`` or ``selected_flag`` do not split a block.
    Any piece that starts with ``(`` is dropped from the result.
    """
    # Local import keeps the function usable regardless of cell order.
    import re

    # match(text, pos) is used instead of slicing so `\b` can see the
    # character before `pos` and no substring is copied per character.
    set_operator = re.compile(r'\s*\b(UNION\s+ALL|UNION|INTERSECT|EXCEPT)\b', re.IGNORECASE)
    select_keyword = re.compile(r'\bSELECT\b', re.IGNORECASE)

    result = []
    buffer = []
    depth = 0
    i = 0
    length = len(sql_text)

    def flush():
        """Move the pending text into `result` when it is not blank."""
        pending = ''.join(buffer).strip()
        if pending:
            result.append(pending)
        buffer.clear()

    while i < length:
        char = sql_text[i]

        # Copy quoted spans verbatim; an unterminated quote runs to the end.
        if char in ("'", '"'):
            closing = sql_text.find(char, i + 1)
            end = length if closing == -1 else closing + 1
            buffer.append(sql_text[i:end])
            i = end
            continue

        if depth == 0:
            operator = set_operator.match(sql_text, i)
            if operator:
                flush()
                # Collapse "UNION \n ALL" to "UNION ALL" so callers can compare exactly.
                result.append(' '.join(operator.group(1).upper().split()))
                i = operator.end()
                continue

            if select_keyword.match(sql_text, i):
                flush()
                buffer.append('SELECT')
                i += 6
                continue

        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        buffer.append(char)
        i += 1

    flush()

    # Filter out any block that is a subquery like "(SELECT ..."
    return [s for s in result if not s.lstrip().startswith('(')]

def split_main_sections_with_classification(sql: str):
    """
    Break a query into ordered, classified sections.

    Comments are stripped first, then the WITH clause (if any) is separated
    from the rest, and the rest is split into top-level pieces.

    Returns a list of dicts, one per section, numbered from "1":
    [
        { "type": 'cte' | 'select' | 'join' | 'unknown', "section_id": '1', "content": str },
        ...
    ]
    'join' marks a bare set operator (UNION, UNION ALL, INTERSECT, EXCEPT).
    """
    set_operators = ('UNION', 'UNION ALL', 'INTERSECT', 'EXCEPT')

    without_comments = strip_comments(sql)
    cte_block, remainder = extract_cte_block(without_comments)

    # Collect (type, content) pairs first; ids are assigned in one pass below.
    classified = []
    if cte_block:
        classified.append(("cte", cte_block))

    for part in extract_top_level_selects_with_joins(remainder):
        normalized = part.strip().upper()
        if normalized.startswith('SELECT'):
            kind = "select"
        elif normalized in set_operators:
            kind = "join"
        else:
            kind = "unknown"
        classified.append((kind, part))

    return [
        {"type": kind, "section_id": str(position), "content": content}
        for position, (kind, content) in enumerate(classified, start=1)
    ]



In [0]:
def extract_and_replace_subqueries(sql_query, section:str="default", section_id:int=0):
    """Replace every parenthesised subquery with a named placeholder.

    The query is stripped of comments and pretty-printed, then scanned for
    ``( SELECT ...`` groups. Each group is replaced by
    ``(<<subquery_{section}[{section_id}]_{n}>>)`` and its body is processed
    recursively, so nested subqueries get their own placeholders.

    Only a whole-word SELECT directly after the opening parenthesis starts a
    subquery (``(selected_col + 1)`` is left alone), and quoted text is copied
    verbatim so parentheses inside literals do not affect matching.

    Args:
        sql_query: SQL text of one section.
        section: Section type used in the placeholder name.
        section_id: Section number used in the placeholder name.

    Returns:
        ``(final_sql, all_subqueries)`` where ``all_subqueries`` is a list of
        ``(placeholder, content)`` tuples. Nested subqueries are listed before
        the subquery that contains them.
    """
    quote_chars = ("'", '"')

    def is_ident_char(ch):
        # '' (end of text) is not an identifier character.
        return ch.isalnum() or ch == '_'

    def end_of_quoted(sql, start):
        """Index just past the quoted span opening at `start` (end of text if unterminated)."""
        closing = sql.find(sql[start], start + 1)
        return len(sql) if closing == -1 else closing + 1

    def parse(sql, base_idx=1):
        subqueries = []
        pieces = []
        i = 0
        n = len(sql)
        subquery_counter = base_idx

        while i < n:
            ch = sql[i]

            # Quoted text is copied as-is.
            if ch in quote_chars:
                end = end_of_quoted(sql, i)
                pieces.append(sql[i:end])
                i = end
                continue

            # Copy normal text
            if ch != '(':
                pieces.append(ch)
                i += 1
                continue

            # We hit a '(', skip whitespace to see if a SELECT keyword follows
            j = i + 1
            while j < n and sql[j].isspace():
                j += 1

            opens_subquery = (
                sql[j:j + 6].lower() == 'select'
                and not is_ident_char(sql[j + 6:j + 7])
            )
            if not opens_subquery:
                # Just a normal '('
                pieces.append(ch)
                i += 1
                continue

            # Walk to the parenthesis that closes this subquery.
            paren_count = 1
            k = j + 6
            while k < n and paren_count > 0:
                c = sql[k]
                if c in quote_chars:
                    k = end_of_quoted(sql, k)
                    continue
                if c == '(':
                    paren_count += 1
                elif c == ')':
                    paren_count -= 1
                k += 1

            # When balanced, sql[k - 1] is the closing ')', which is excluded.
            # When the text ends first, keep everything that was scanned.
            body_end = k - 1 if paren_count == 0 else k

            # Recursive replacement inside this subquery
            inner_sql, inner_subs, next_counter = parse(sql[j:body_end], subquery_counter)
            subqueries.extend(inner_subs)

            # The counter returned by the recursion numbers this subquery, so
            # nested subqueries always carry lower numbers than their parent.
            placeholder = f"<<subquery_{section}[{section_id}]_{next_counter}>>"
            subqueries.append((placeholder, inner_sql))
            pieces.append(f"({placeholder})")
            subquery_counter = next_counter + 1
            i = k

        return ''.join(pieces), subqueries, subquery_counter

    clean_sql = strip_comments(sql_query)
    pretty_sql = prettify_final(clean_sql)
    final_sql, all_subqueries, _ = parse(pretty_sql)
    return final_sql, all_subqueries



In [0]:
def process_sections_with_subquery_extraction(sql: str):
    """
    Runs extract_and_replace_subqueries on each 'cte' and 'select' section.
    Every other section type (e.g. 'join', 'unknown') is passed through unchanged.

    Returns a list of dicts:
    [
        {
            "section_id": "1",
            "type": "cte" | "select" | "join" | "unknown",
            "content": "modified content",
            "subqueries": [ (placeholder, subquery_content), ... ]
        },
        ...
    ]
    """
    sections = split_main_sections_with_classification(sql)
    processed = []

    for section in sections:
        # Sections are keyed by "section_id"; there is no "id" key.
        section_id = section["section_id"]

        if section["type"] in ("cte", "select"):
            replaced_sql, subqueries = extract_and_replace_subqueries(
                section["content"],
                section=section["type"],
                section_id=section_id
            )
        else:  # e.g., join, unknown
            replaced_sql, subqueries = section["content"], []

        processed.append({
            "section_id": section_id,
            "type": section["type"],
            "content": replaced_sql,
            "subqueries": subqueries
        })

    return processed


In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, IntegerType

# Initialize SparkSession if needed
spark = SparkSession.builder.appName("SQLSubqueries").getOrCreate()

def get_schema():
    """Return the Spark schema shared by every parsing-log DataFrame.

    All fields are nullable; `columns` and `converted_columns` are arrays of
    strings, `section_id` is an integer and the remaining fields are strings.
    """
    # (field name, Spark type) in column order.
    field_types = [
        ("section", StringType()),
        ("section_id", IntegerType()),
        ("type", StringType()),
        ("identifier", StringType()),
        ("original_query", StringType()),
        ("final_query", StringType()),
        ("content", StringType()),
        ("columns", ArrayType(StringType())),
        ("converted_content", StringType()),
        ("converted_columns", ArrayType(StringType())),
        ("response_error", StringType()),
    ]
    return StructType([StructField(name, dtype, True) for name, dtype in field_types])

def initialize_empty_subquery_delta_table(table_name):
    """
    Recreate `table_name` as an empty Delta table with the parsing-log schema.

    Does nothing when `table_name` is None, empty or only whitespace.
    """
    name_is_blank = table_name is None or not table_name.strip()
    if name_is_blank:
        return

    # Zero-row DataFrame that only carries the schema.
    schema_only_df = spark.createDataFrame([], get_schema())

    # Drop any existing table, then write the empty one in its place.
    spark.sql(f"drop table if exists {table_name}")
    writer = schema_only_df.write.format("delta").mode("overwrite")
    writer.saveAsTable(table_name)
  
def subqueries_to_spark_dataframe(sql_query: str):
    """Turn a SQL query into a Spark DataFrame of sections and subqueries.

    For every section of the query:
    - 'cte' / 'select' sections have their subqueries replaced by placeholders;
      one "subquery" row is emitted per extracted subquery.
    - any other section yields one row with empty `type` and `identifier`
      holding the section content.
    - a "section_main" row is always emitted, with the original section text in
      `original_query` and the (possibly placeholder-bearing) text in `content`.

    Returns a DataFrame built with the schema from get_schema().
    """
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()

    def build_record(section_name, sec_id, row_type, identifier, original_query, content):
        """One output row; fields not passed in start out empty."""
        return {
            "section": section_name,
            "section_id": sec_id,
            "type": row_type,
            "identifier": identifier,
            "original_query": original_query,
            "final_query": "",
            "content": content,
            "columns": [],
            "converted_content": "",
            "converted_columns": [],
            "response_error": ""
        }

    all_records = []

    for section in split_main_sections_with_classification(sql_query):
        section_type = section["type"]
        section_id = int(section["section_id"])

        if section_type not in ("cte", "select"):
            # join sections
            all_records.append(
                build_record(section_type, section_id, "", "", None, section["content"])
            )
            main_content = section["content"]
        else:
            main_content, subqueries = extract_and_replace_subqueries(
                section["content"], section=section_type, section_id=section_id
            )
            print(f"{section['content']}")

            for ident, content in subqueries:
                print(f"appending {content}")
                # Identifiers are stored without the << >> delimiters.
                bare_identifier = ident.replace("<<", "").replace(">>", "")
                all_records.append(
                    build_record(section_type, section_id, "subquery", bare_identifier, None, content)
                )

        # Add the section-level row
        all_records.append(
            build_record(
                section_type,
                section_id,
                "section_main",
                f"section_{section_id}_main",
                section["content"],
                main_content,
            )
        )

    return spark.createDataFrame(all_records, get_schema())

def write_subqueries_to_delta(df, table_name: str):
    """
    Save the DataFrame as the Delta table `table_name`, overwriting it.

    Does nothing when `table_name` is None, empty or only whitespace.
    """
    name_is_blank = table_name is None or not table_name.strip()
    if name_is_blank:
        return
    delta_writer = df.write.mode("overwrite").format("delta")
    delta_writer.saveAsTable(table_name)


In [0]:
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

def extract_columns_and_replace_select(df):
    """
    Takes a Spark DataFrame of queries, extracts the column expressions of the
    FIRST select list in each row's `content`, and replaces that select list
    with '*'. Later select lists in the same content are left untouched so no
    column text is lost.

    The scan is keyword- and nesting-aware:
    - SELECT / FROM are matched as whole words only (`valid_from` is a column);
    - a FROM inside parentheses (e.g. `EXTRACT(YEAR FROM d)`) does not end the list;
    - commas inside parentheses or quoted text do not split a column.

    Returns a new DataFrame (schema from get_schema()) with `columns` filled in;
    rows without a SELECT ... FROM keep their content and get an empty list.
    """
    quote_chars = ("'", '"')

    def is_word_char(ch):
        # '' (start / end of text) counts as a non-word character.
        return ch.isalnum() or ch == "_"

    def keyword_at(text, pos, keyword):
        """True when `keyword` (lower-case) occurs at `pos` as a whole word."""
        end = pos + len(keyword)
        return (
            text[pos:end].lower() == keyword
            and not is_word_char(text[max(pos - 1, 0):pos])
            and not is_word_char(text[end:end + 1])
        )

    def skip_quoted(text, pos):
        """Index just past the quoted span opening at `pos`."""
        closing = text.find(text[pos], pos + 1)
        return len(text) if closing == -1 else closing + 1

    def find_select_list(query):
        """Return (start, end) of the first select list, or None if absent."""
        n = len(query)
        i = 0
        list_start = None
        while i < n:
            if query[i] in quote_chars:
                i = skip_quoted(query, i)
                continue
            if keyword_at(query, i, "select"):
                list_start = i + len("select")
                break
            i += 1
        if list_start is None:
            return None

        depth = 0
        i = list_start
        while i < n:
            ch = query[i]
            if ch in quote_chars:
                i = skip_quoted(query, i)
                continue
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    # Left the scope of the SELECT without seeing its FROM.
                    return None
            elif depth == 0 and keyword_at(query, i, "from"):
                return list_start, i
            i += 1
        return None

    def split_top_level(text):
        """Split on commas that are outside parentheses and quotes."""
        parts = []
        depth = 0
        start = 0
        i = 0
        n = len(text)
        while i < n:
            ch = text[i]
            if ch in quote_chars:
                i = skip_quoted(text, i)
                continue
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif ch == "," and depth == 0:
                parts.append(text[start:i])
                start = i + 1
            i += 1
        parts.append(text[start:])
        return [part.strip() for part in parts if part.strip()]

    def get_columns_and_rewrite(query):
        if not isinstance(query, str):
            return [], query
        span = find_select_list(query)
        if span is None:
            return [], query
        list_start, list_end = span
        columns = split_top_level(query[list_start:list_end])
        # Same shape as before: "<select> * <from>...", first occurrence only.
        modified_query = query[:list_start] + " * " + query[list_end:]
        return columns, modified_query

    rows = df.collect()
    records = []

    for row in rows:
        columns, modified_content = get_columns_and_rewrite(row['content'])
        records.append({
            "section": row['section'],
            "section_id": row["section_id"],
            "type": row['type'],
            "identifier": row['identifier'],
            "original_query": row['original_query'],
            "final_query": "",
            "content": modified_content,
            "columns": columns,
            "converted_content": "",
            "converted_columns": [],
            "response_error": ""
        })

    schema = get_schema()
    return spark.createDataFrame(records, schema)



In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
import traceback

def convert_query_to_databricks_sql(query: str, endpoint_name: str = "databricks-claude-sonnet-4"):
    """Ask a Databricks serving endpoint to convert Oracle SQL to Databricks SQL.

    When `endpoint_name` is None or blank the conversion is disabled: a notice
    and the query are printed and the query is returned unchanged. Otherwise
    the query is sent as a chat request and the stripped text of the first
    choice is returned.
    """
    llm_disabled = endpoint_name is None or not endpoint_name.strip()
    if llm_disabled:
        print("\n✅ LLM endpoint disabled\n")
        print(f"converted qry = {query}")
        return query

    system_message = ChatMessage(
        role=ChatMessageRole.SYSTEM, content="You are a helpful assistant."
    )
    user_message = ChatMessage(
        role=ChatMessageRole.USER, content=f"Please covert the following Oracle SQL query to Databricks SQL. Just return the query, no other content, including ```sql. If you see any sql that is wrapped in << >>, for example <<subquery_1>>, assume it is valid sql and leave it as is.  I need a complete conversion, do not skip any lines:\n{query}"
    )

    client = WorkspaceClient()  # created without a timeout parameter
    reply = client.serving_endpoints.query(
        name=endpoint_name,
        messages=[system_message, user_message]
    )
    return reply.choices[0].message.content.strip()

In [0]:
def convert_sql(df, endpoint_name="databricks-claude-sonnet-4"):
    """
    Convert the `content` of every non-join row with convert_query_to_databricks_sql.

    Rows whose `section` is "join" are copied without calling the endpoint.
    For all other rows the converted text is stored in `converted_content`;
    if the call raises, `converted_content` is left empty and the exception
    summary is stored in `response_error`.

    Returns a new Spark DataFrame with the schema from get_schema().
    """
    converted_records = []

    for row in df.collect():
        # Fields copied verbatim for every kind of row.
        shared_fields = {
            "section": row['section'],
            "section_id": row["section_id"],
            "type": row["type"],
            "identifier": row["identifier"],
            "original_query": row["original_query"],
        }

        if row['section'] == "join":
            # Set operators (e.g. 'UNION ALL') need no conversion.
            converted_records.append({
                **shared_fields,
                "final_query": row["final_query"],
                "content": row['content'],
                "columns": [],
                "converted_content": "",
                "converted_columns": [],
                "response_error": ""
            })
            continue

        error_text = ""
        try:
            converted_sql = convert_query_to_databricks_sql(row['content'], endpoint_name=endpoint_name)
        except Exception as exc:
            converted_sql = ""
            error_text = f"{type(exc).__name__}: {str(exc)}\n{traceback.format_exc(limit=2)}"

        converted_records.append({
            **shared_fields,
            "final_query": "",
            "content": row['content'],
            "columns": row['columns'],
            "converted_content": converted_sql,
            "converted_columns": [],
            "response_error": error_text
        })

    return spark.createDataFrame(converted_records, get_schema())


In [0]:
def convert_sql_on_columns(df, chunk_size=10, endpoint_name="databricks-claude-sonnet-4"):
    """
    Convert each row's `columns` in chunks and collect the results.

    For every non-join row the columns are grouped into chunks of `chunk_size`,
    each chunk is wrapped in ``SELECT ... FROM dummy`` and sent through
    convert_query_to_databricks_sql, and the select list of the reply is split
    back into column expressions that are appended to `converted_columns`.

    A chunk whose conversion raises contributes no converted columns; its
    error is recorded. `response_error` holds the errors of ALL failed chunks
    of the row (newline separated) and is empty when nothing failed or the row
    has no columns.

    Rows whose `section` is "join" are copied without any conversion.
    Returns a new Spark DataFrame with the schema from get_schema().
    """
    # Local import keeps the function independent of the cell that imports it.
    import traceback

    quote_chars = ("'", '"')

    def is_word_char(ch):
        # '' (start / end of text) counts as a non-word character.
        return ch.isalnum() or ch == "_"

    def keyword_at(text, pos, keyword):
        """True when `keyword` (lower-case) occurs at `pos` as a whole word."""
        end = pos + len(keyword)
        return (
            text[pos:end].lower() == keyword
            and not is_word_char(text[max(pos - 1, 0):pos])
            and not is_word_char(text[end:end + 1])
        )

    def skip_quoted(text, pos):
        """Index just past the quoted span opening at `pos`."""
        closing = text.find(text[pos], pos + 1)
        return len(text) if closing == -1 else closing + 1

    def select_list_of(sql):
        """Text between the first SELECT and its top-level FROM (or end of text)."""
        n = len(sql)
        i = 0
        list_start = None
        while i < n:
            if sql[i] in quote_chars:
                i = skip_quoted(sql, i)
                continue
            if keyword_at(sql, i, "select"):
                list_start = i + len("select")
                break
            i += 1
        if list_start is None:
            return None

        depth = 0
        i = list_start
        while i < n:
            ch = sql[i]
            if ch in quote_chars:
                i = skip_quoted(sql, i)
                continue
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    break
            elif depth == 0 and keyword_at(sql, i, "from"):
                break
            i += 1
        return sql[list_start:i]

    def split_top_level(text):
        """Split on commas that are outside parentheses and quotes."""
        parts = []
        depth = 0
        start = 0
        i = 0
        n = len(text)
        while i < n:
            ch = text[i]
            if ch in quote_chars:
                i = skip_quoted(text, i)
                continue
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif ch == "," and depth == 0:
                parts.append(text[start:i])
                start = i + 1
            i += 1
        parts.append(text[start:])
        return [part.strip() for part in parts if part.strip()]

    rows = df.collect()
    records = []

    for row in rows:
        if row['section'] == "join":
            records.append({
                "section": row['section'],
                "section_id": row["section_id"],
                "type": row["type"],
                "identifier": row["identifier"],
                "original_query": row["original_query"],
                "final_query": row["final_query"],
                "content": row['content'],
                "columns": [],
                "converted_content": "",
                "converted_columns": [],
                "response_error": ""
            })
            continue

        converted_columns = []
        # Reset per row so a row without columns never reuses stale state.
        chunk_errors = []

        if row['columns']:
            # break columns into chunks
            col_chunks = [row['columns'][i:i+chunk_size] for i in range(0, len(row['columns']), chunk_size)]
            print(f"col_chunks:\n{col_chunks}")

            for chunk in col_chunks:
                dummy_query = f"SELECT {', '.join(chunk)} FROM dummy"
                print(f"dummy_query = {dummy_query}")

                try:
                    print(f"endpoint_name = {endpoint_name}")
                    converted_sql = convert_query_to_databricks_sql(dummy_query, endpoint_name=endpoint_name)
                except Exception as e:
                    # Remember the failure and move on to the next chunk.
                    chunk_errors.append(f"{type(e).__name__}: {str(e)}\n{traceback.format_exc(limit=2)}")
                    continue

                print(f"converted_sql = {converted_sql}")
                cols_section = select_list_of(converted_sql)
                if cols_section is not None:
                    print(f"cols_section = {cols_section}")
                    converted_columns.extend(split_top_level(cols_section))

        records.append({
            "section": row['section'],
            "section_id": row["section_id"],
            "type": row['type'],
            "identifier": row['identifier'],
            "original_query": row['original_query'],
            "final_query": "",
            "content": row['content'],
            "columns": row['columns'],
            "converted_content": row['converted_content'],
            "converted_columns": converted_columns,
            "response_error": "\n".join(chunk_errors)
        })

    return spark.createDataFrame(records, get_schema())


In [0]:
# def assemble_final_query_string(converted_df):
#     """
#     - Replaces SELECT * with converted columns if present.
#     - Recursively replaces <<subquery_x>> placeholders across all rows.
#     - Returns final fully assembled main query string.
#     """

#     import re
#     rows = converted_df.collect()
#     query_map = {}

#     # Build initial map with SELECT * replaced by converted columns if present
#     for row in rows:
#         name = row['identifier']
#         query_text = row['converted_content'] if row['converted_content'] else row['content']

#         if row['columns'] and row['converted_columns']:
#             all_converted_cols = ", ".join(row['converted_columns'])
#             query_text = re.sub(
#                 r'(?i)(select\s+)\*',
#                 lambda m: m.group(1) + all_converted_cols,
#                 query_text,
#                 count=1
#             )

#         query_map[name] = query_text

#     # Recursive placeholder replacement until stable
#     changed = True
#     while changed:
#         changed = False
#         for name, text in query_map.items():
#             original_text = text
#             for sub_name, sub_text in query_map.items():
#                 if sub_name != name:
#                     text = text.replace(f"<<{sub_name}>>", sub_text)
#             if text != original_text:
#                 changed = True
#             query_map[name] = text

#     # Return the final resolved main query
#     return query_map.get('main', '')


In [0]:
import ast

def assemble_final_query_string(df):
    """
    Rebuild the full query from the converted rows.

    - Replaces the first SELECT * of a row with its converted columns if available.
    - Resolves <<identifier>> placeholders by substituting the text of the row
      with that identifier, repeating until nothing changes. The number of
      passes is capped so cyclic placeholder references cannot loop forever.
    - Joins the resolved text of all rows whose type is 'section_main' or
      'join', ordered by section_id, separated by blank lines.

    Returns the assembled query string.
    """
    rows = df.collect()
    query_map = {}

    # Step 1: Build map of identifier -> content with the select list restored
    for row in rows:
        identifier = row['identifier']
        content = row['converted_content'] if row['converted_content'] else row['content']
        if not isinstance(content, str):
            content = ""

        if row['columns'] and row['converted_columns']:
            try:
                # Columns may arrive as real lists or as their string repr.
                columns = ast.literal_eval(row['columns']) if isinstance(row['columns'], str) else row['columns']
                converted_columns = ast.literal_eval(row['converted_columns']) if isinstance(row['converted_columns'], str) else row['converted_columns']
                if columns and converted_columns:
                    converted_cols_str = ", ".join(converted_columns)
                    # A callable replacement keeps backslashes in the columns literal.
                    content = re.sub(r'(?i)(select\s+)\*', lambda m: m.group(1) + converted_cols_str, content, count=1)
            except Exception as e:
                # Best effort: keep SELECT * for this row, but say so.
                print(f"Could not restore columns for '{identifier}': {type(e).__name__}: {e}")

        query_map[identifier] = content

    # Step 2: Resolve placeholders. Each pass resolves at least one nesting
    # level, so acyclic references need at most len(query_map) passes.
    for _ in range(len(query_map) + 1):
        changed = False
        for name in list(query_map):
            original_text = query_map[name]
            text = original_text
            for sub_name, sub_text in query_map.items():
                # Rows without an identifier (join rows) have no placeholder.
                if not sub_name or sub_name == name or not isinstance(sub_text, str):
                    continue
                text = text.replace(f"<<{sub_name}>>", sub_text)
            if text != original_text:
                changed = True
                query_map[name] = text
        if not changed:
            break

    # Step 3: Assemble all section_main and join entries in order
    df_main = df.filter(df["type"].isin(["section_main", "join"]))
    df_main = df_main.withColumn("section_id", df_main["section_id"].cast("int"))
    ordered_rows = df_main.sort("section_id").collect()

    final_parts = []
    for row in ordered_rows:
        identifier = row['identifier']
        resolved = query_map.get(identifier, row['converted_content'] or row['content'])
        final_parts.append(resolved.strip() if isinstance(resolved, str) else "")

    return "\n\n".join(final_parts)


In [0]:
def append_final_query_row(df, final_query_str):
    """
    Appends a new row with identifier 'final_query' and the provided final query string
    in the 'final_query' column. The array fields are empty lists and the
    remaining text fields are None.

    The row is built with the schema from get_schema(), so it always matches
    the schema used by the other DataFrames in this notebook; `df` is expected
    to have the same columns.

    Returns the union (by column name) of `df` and the new row.
    """
    final_row = {
        "section": "final_query",
        "section_id": 0,
        "type": "final",
        "identifier": "final_query",
        "original_query": None,
        "final_query": final_query_str,
        "content": None,
        "columns": [],
        "converted_content": None,
        "converted_columns": [],
        "response_error": None
    }
    new_row_df = spark.createDataFrame([final_row], get_schema())

    # union with existing DataFrame
    return df.unionByName(new_row_df)


In [0]:
# import random

# def generate_complex_query():
#     base_query = """
# WITH
# DeptStats AS (
#     SELECT
#         DepartmentID,
#         SUM(Salary) AS TotalDeptSalary,
#         COUNT(*) AS NumEmployees,
#         AVG(Salary) AS AvgDeptSalary
#     FROM
#         Employees
#     GROUP BY
#         DepartmentID
# ),
# EmpProjects AS (
#     SELECT
#         EmployeeID,
#         COUNT(ProjectID) AS ProjectsCompleted,
#         MAX(CompletedDate) AS LastProjectDate,
#         YEAR(MAX(CompletedDate)) AS LastProjectYear
#     FROM
#         Projects
#     WHERE
#         Status = 'Completed'
#     GROUP BY
#         EmployeeID
# )

# SELECT
#     e.EmployeeID,
#     UPPER(e.Name) AS Name,
#     d.Name AS Department,
#     CASE
#         WHEN e.Salary > (
#             SELECT AVG(Salary)
#             FROM Employees
#             WHERE DepartmentID = e.DepartmentID
#         ) THEN CONCAT('Above Average (', CAST(e.Salary AS VARCHAR), ')')
#         ELSE 'Average or Below'
#     END AS SalaryStatus,
#     -- some remark here,
#     CASE
#         WHEN    rsm.investment_type = 'BL'
#             AND NVL (psah.acrd_cd, 'N') NOT IN ('Y', 'V') -- story 897300
#         THEN
#             NVL (
#                 (SELECT wacoupon
#                    FROM stg_wso_pos_acr_ame
#                   WHERE     portfolio_fund_id = psah.cal_dt
#                         AND asofdate = psah.cal_dt
#                         AND asset_primaryud = psah.asset_id
#                         AND rec_typ_cd = 'POS'),
#                 0)
#         ELSE
#             psah.int_rt
#     END
#         AS pos_int_it,  
#     ep.ProjectsCompleted,
#     YEAR(e.HireDate) AS HireYear,
#     MONTH(e.HireDate) AS HireMonth,
#     COALESCE(ep.LastProjectYear, 'N/A') AS LastProjectYear
# FROM
#     Employees e
#     JOIN Departments d ON e.DepartmentID = d.DepartmentID
#     LEFT JOIN EmpProjects ep ON e.EmployeeID = ep.EmployeeID
# WHERE
#     e.EmployeeID IN (
#         SELECT
#             e2.EmployeeID
#         FROM
#             Employees e2
#             JOIN Departments d2 ON e2.DepartmentID = d2.DepartmentID
#             LEFT JOIN EmpProjects ep2 ON e2.EmployeeID = ep2.EmployeeID
#         WHERE
#             e2.Salary > (
#                 SELECT AVG(Salary)
#                 FROM Employees
#                 WHERE DepartmentID = e2.DepartmentID
#             )
#     )
# UNION ALL
# SELECT
#     NULL AS EmployeeID,
#     NULL AS Name,
#     d.Name AS Department,
#     CONCAT('Department Total: ', CAST(ds.TotalDeptSalary AS VARCHAR)) AS SalaryStatus,
#     ds.NumEmployees AS ProjectsCompleted,
#     NULL AS HireYear,
#     NULL AS HireMonth,
#     NULL AS LastProjectYear,
# """

#     # Generate 599 columns, each randomly a literal or a CASE with subquery
#     cols = []
#     for i in range(1, 10):
#         rnd = random.random()
#         if i == 1:
#             col = (
#                 f"CASE WHEN ds.TotalDeptSalary > 100000 AND ds.TotalDeptSalary < 200000 AND NVL(a.emp_class, 'N') NOT IN ('A', 'B') -- story 1234\n"
#                 f"THEN NVL((SELECT COUNT(*) FROM Employees WHERE DepartmentID = ds.DepartmentID),0)\n"
#                 f"ELSE 0 END as col{i}"
#             )
#         else:
#             col = f"'col{i}' as col{i}"
#         cols.append(col)
#     cols_str = ",\n".join(cols)

#     # Assemble the final query
#     final_query = f"""
# {base_query}
# {cols_str}
# FROM
#     DeptStats ds
#     JOIN Departments d ON ds.DepartmentID = d.DepartmentID;
# """
#     return final_query

# query_string = generate_complex_query()
# print(query_string)


In [0]:
# parse_log_table="users.paul_signorelli.sql_parsing_log"
# endpoint_name=""
# initialize_empty_subquery_delta_table(table_name=parse_log_table)
# spark_df = subqueries_to_spark_dataframe(query_string)
# spark_df_with_columns = extract_columns_and_replace_select(spark_df)
# spark_df_converted = convert_sql(spark_df_with_columns, endpoint_name=endpoint_name)
# spark_df_with_converted_columns = convert_sql_on_columns(spark_df_converted, chunk_size=5, endpoint_name=endpoint_name)

# final_query = assemble_final_query_string(spark_df_with_converted_columns)
# pretty_final = prettify_final(final_query)

# df_final = append_final_query_row(spark_df_with_converted_columns, pretty_final)
# write_subqueries_to_delta(df_final, table_name=parse_log_table)
  

In [0]:
# display(
#   spark.sql(f"select * from {parse_log_table}")
# )