## Simple Query Parser

This parser will deconstruct a complex query into its sub-queries so each of those sub-queries can be separated out to be analyzed or converted.

It looks for SELECT statements, CTEs defined by WITH, and IN or EXISTS clauses to discover sub-queries.

It can also strip the column list from a query that has a long list of columns and replace it with a '*'.  This reduces the query size sent to an LLM and lets the columns be analyzed separately.

The script will put all of this information into a collection which can be iterated through to call an LLM to analyze or migrate the code and then it can be pieced back together at the end.

So the following query:

```
with cte1 as (
        select cust_id, sum(sales) as sales_sum 
        from orders
        group by cust_id
    ),
    cte2 as (
        select cust_id, sum(expenses) as expenses
        from expenses
        group by cust_id
    )
    select cust_id, sales_sum, expenses
    from cte1
    inner join cte2 on cte1.cust_id = cte2.cust_id
    where cte1.cust_id in (select cust_id from customer_region where region = 'USA')
    and exists (select 1 from sales_person_region where region = 'NYC')
```

will be deconstructed to look like

```

Query 1 (cte1):
    select * 
        from orders
        group by cust_id
Columns:
  - cust_id
  - sum(sales) as sales_sum
----------------------------------------
Query 2 (cte2):
    select *
        from expenses
        group by cust_id
Columns:
  - cust_id
  - sum(expenses) as expenses
----------------------------------------
Query 3 (subquery):
    select * from customer_region where region = 'USA'
Columns:
  - cust_id
----------------------------------------
Query 4 (subquery):
    select * from sales_person_region where region = 'NYC'
Columns:
  - 1
----------------------------------------
Query 5 (main):
    select *
    from cte1
    inner join cte2 on cte1.cust_id = cte2.cust_id
    where cte1.cust_id in (<<query 3>>)
    and exists (<<query 4>>)
Columns:
  - cust_id
  - sales_sum
  - expenses
----------------------------------------
main query:
with
    cte1 as (
        <<query 1>>
    ),
    cte2 as (
        <<query 2>>
    )
   <<query 5>>
```

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv sqlparse
# Restart the Python process after the install above
# (presumably so the newly installed package versions are the ones imported below -- confirm).
dbutils.library.restartPython()

In [0]:
import os
import sys

# Put the current working directory at the front of sys.path before importing
# AGENT from the local `agent` module (presumably agent.py lives next to this
# notebook -- adjust module_path if it does not).
current_dir = os.getcwd()
module_path = os.path.join(current_dir, ".")  # replace with actual folder
sys.path.insert(0, module_path)
from agent import AGENT


In [0]:
import sqlparse

def prettify_final(query_string: str):
    """Return ``query_string`` re-indented by sqlparse with upper-cased keywords.

    NOTE(review): ``<<...>>`` placeholders are expected to pass through the
    formatter untouched -- confirm against sqlparse's behaviour.
    """
    return sqlparse.format(query_string, keyword_case='upper', reindent=True)

In [0]:
def strip_comments(sql: str) -> str:
    """Remove SQL comments from ``sql`` without touching quoted text.

    Two comment forms are removed:

    * block comments ``/* ... */`` (may span lines) are replaced by a single
      space so the tokens on either side are not fused together;
    * line comments ``-- ...`` are removed up to, but not including, the line
      break.

    Comment markers that appear inside single-quoted string literals or
    double-quoted identifiers are left alone (``''`` / ``""`` are treated as
    escaped quotes).

    Args:
        sql: The SQL text to clean.

    Returns:
        The SQL text with comments removed.
    """
    # Imported locally: this cell is defined before the notebook's
    # module-level ``import re`` cell.
    import re

    # Quoted spans are listed first so they win over the comment alternatives
    # and are copied through unchanged.
    token = re.compile(
        r"'(?:[^']|'')*'"       # single-quoted string literal
        r'|"(?:[^"]|"")*"'      # double-quoted identifier
        r"|/\*.*?\*/"           # block comment
        r"|--[^\n\r]*",         # line comment
        re.DOTALL,
    )

    def _replace(match):
        text = match.group(0)
        if text[0] in ("'", '"'):
            return text
        return " " if text.startswith("/*") else ""

    return token.sub(_replace, sql)

In [0]:
def extract_tables(content: str):
    """Return the sorted, de-duplicated names that follow FROM or JOIN.

    Line comments (``--``) and block comments (``/* */``) are removed first.
    A name is a run of letters, digits, underscores and dots that starts with
    a letter or underscore; names that are one of a small set of SQL keywords
    are discarded.
    """
    without_line_comments = re.sub(r'--.*$', '', content, flags=re.MULTILINE)
    cleaned = re.sub(r'/\*.*?\*/', '', without_line_comments, flags=re.DOTALL)

    candidates = re.findall(
        r'\b(?:FROM|JOIN)\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
        cleaned,
        re.IGNORECASE,
    )

    # Words that can directly follow FROM/JOIN without being a table name.
    reserved = {'ON', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'SELECT', 'INNER', 'LEFT', 'RIGHT', 'OUTER'}
    unique_names = {name for name in candidates if name.upper() not in reserved}
    return sorted(unique_names)

# if __name__ == '__main__':
#   if len(sys.argv) < 2:
#     print("Usage: python script.py <sql_file>")
#     sys.exit(1)
#   sql_file = sys.argv[1] # This is like $1 in bash
#   try:
#     tables = extract_tables(sql_file)
#     for table in tables:
#       print(table)
#   except FileNotFoundError:
#     print(f"Error: File '{sql_file}' not found")
#     sys.exit(1)

In [0]:
import re

def extract_cte_block(sql_text: str):
    """Split a query into its leading WITH clause and the statement after it.

    The text is scanned once, tracking parenthesis depth and skipping quoted
    spans, until the first ``SELECT`` keyword found outside every parenthesis.
    ``WITH`` and ``SELECT`` are matched as whole words, so identifiers such as
    ``selected`` or ``withdrawals`` are not mistaken for keywords.

    Args:
        sql_text: The SQL statement to split.

    Returns:
        A ``(cte_block, remainder)`` tuple:

        * ``(None, sql_text)`` when the stripped text does not start with WITH;
        * ``(with_clause, main_statement)`` when a top-level SELECT follows the
          CTE definitions (``main_statement`` starts at that SELECT);
        * ``(sql_text, "")`` when no top-level SELECT is found.
    """
    sql_text = sql_text.strip()
    if not re.match(r'WITH\b', sql_text, re.IGNORECASE):
        return None, sql_text

    select_keyword = re.compile(r'\bSELECT\b', re.IGNORECASE)
    depth = 0
    n = len(sql_text)
    i = 4  # resume right after the leading WITH

    while i < n:
        ch = sql_text[i]
        if ch in ("'", '"'):
            # Skip the quoted span so parentheses or keywords inside it are ignored.
            closing = sql_text.find(ch, i + 1)
            i = n if closing == -1 else closing + 1
            continue
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif depth == 0 and select_keyword.match(sql_text, i):
            # First SELECT that is not inside any CTE body.
            return sql_text[:i].strip(), sql_text[i:]
        i += 1

    # No top-level SELECT after the WITH block: hand back the full input.
    return sql_text, ""


def extract_top_level_selects_with_joins(sql_text: str):
    """Split ``sql_text`` into top-level SELECT statements and set operators.

    The text is scanned once while tracking parenthesis depth. Outside all
    parentheses:

    * ``UNION ALL``, ``UNION``, ``INTERSECT`` and ``EXCEPT`` end the current
      statement and are emitted as their own upper-cased entry;
    * ``SELECT`` starts a new statement (the keyword is emitted upper-cased).

    Keywords are matched as whole words only, so identifiers such as
    ``selected_flag`` or ``credit_union`` do not split a statement. Quoted
    spans (single or double quotes) are copied through verbatim and never
    inspected for keywords or parentheses. SELECTs inside parentheses are
    left inside the statement that contains them.

    Args:
        sql_text: SQL text, normally the part of a query after any WITH clause.

    Returns:
        A list of stripped strings in source order; entries that start with
        ``(`` are dropped.
    """
    set_operator = re.compile(r'\s*\b(UNION\s+ALL|UNION|INTERSECT|EXCEPT)\b', re.IGNORECASE)
    select_keyword = re.compile(r'\bSELECT\b', re.IGNORECASE)

    result = []
    buffer = []
    depth = 0
    i = 0
    length = len(sql_text)

    def flush():
        # Emit the accumulated statement (if any) and start a fresh one.
        text = ''.join(buffer).strip()
        if text:
            result.append(text)
        buffer.clear()

    while i < length:
        char = sql_text[i]

        # Copy quoted spans verbatim; an unterminated quote runs to the end.
        if char in ("'", '"'):
            closing = sql_text.find(char, i + 1)
            end = length if closing == -1 else closing + 1
            buffer.append(sql_text[i:end])
            i = end
            continue

        if depth == 0:
            operator_match = set_operator.match(sql_text, i)
            if operator_match:
                flush()
                result.append(operator_match.group(1).upper())
                i = operator_match.end()
                continue

            if select_keyword.match(sql_text, i):
                flush()
                buffer.append('SELECT')
                i += 6
                continue

        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        buffer.append(char)
        i += 1

    flush()

    # Filter out any block that is a subquery like "(SELECT ..."
    return [s for s in result if not s.startswith('(')]

def split_main_sections_with_classification(sql: str):
    """Break a query into numbered, typed top-level sections.

    Comments are stripped, the leading WITH clause (if any) becomes the first
    section, and the rest is split into SELECT statements and set operators.

    Returns a list of dicts, numbered from 1 in source order:
    [
        { "type": 'cte' | 'select' | 'join' | 'unknown', "section_id": '1', "content": str },
        ...
    ]
    'join' is used for the set operators UNION, UNION ALL, INTERSECT and EXCEPT.
    """
    set_operators = ('UNION', 'UNION ALL', 'INTERSECT', 'EXCEPT')

    def classify(fragment):
        normalized = fragment.strip().upper()
        if normalized.startswith('SELECT'):
            return "select"
        return "join" if normalized in set_operators else "unknown"

    cte_block, remainder = extract_cte_block(strip_comments(sql))

    typed_parts = [("cte", cte_block)] if cte_block else []
    typed_parts.extend(
        (classify(fragment), fragment)
        for fragment in extract_top_level_selects_with_joins(remainder)
    )

    return [
        {"type": section_type, "section_id": str(position), "content": text}
        for position, (section_type, text) in enumerate(typed_parts, start=1)
    ]



In [0]:
# def extract_and_replace_subqueries(sql_query, section:str="default", section_id:int=0):
#     def parse(sql, base_idx=1):
#         subqueries = []
#         result = ""
#         i = 0
#         n = len(sql)
#         subquery_counter = base_idx

#         while i < n:
#             # Copy normal text
#             if sql[i] != '(':
#                 result += sql[i]
#                 i += 1
#                 continue

#             # We hit a '(', skip whitespace to see if SELECT
#             j = i + 1
#             while j < n and sql[j].isspace():
#                 j += 1

#             if j + 5 <= n and sql[j:j+6].lower() == 'select':
#                 # We found a (SELECT...
#                 paren_count = 1
#                 k = j + 6
#                 buffer = '(' + sql[j:j+6]

#                 while k < n and paren_count > 0:
#                     buffer += sql[k]
#                     if sql[k] == '(':
#                         paren_count += 1
#                     elif sql[k] == ')':
#                         paren_count -= 1
#                     k += 1

#                 # Recursive replacement inside this subquery
#                 inner_sql, inner_subs, next_counter = parse(buffer[1:-1], subquery_counter)
#                 subqueries.extend(inner_subs)

#                 # Store the current subquery properly as (placeholder, content)
#                 placeholder = f"<<subquery_{section}[{section_id}]_{next_counter}>>"
#                 subqueries.append((placeholder, inner_sql))  # <-- fix here
#                 result += f"({placeholder})"
#                 subquery_counter = next_counter + 1
#                 i = k
#             else:
#                 # Just a normal '('
#                 result += sql[i]
#                 i += 1

#         return result, subqueries, subquery_counter

#     clean_sql = strip_comments(sql_query)
#     pretty_sql = prettify_final(clean_sql)
#     # print(f"pretty:\n{pretty_sql}\n")
#     final_sql, all_subqueries, _ = parse(pretty_sql)
#     return final_sql, all_subqueries

def extract_and_replace_subqueries(
    sql_query,
    section: str = "default",
    section_id: int = 0,
    scalar_funcs=None,
    scalar_operators=None
):
    """Replace every parenthesised ``(SELECT ...)`` with a placeholder.

    The query is stripped of comments, reformatted with ``prettify_final`` and
    then scanned. Each subquery is replaced in the text by
    ``(<<subquery_{section}[{section_id}]_{n}>>)``; nested subqueries are
    extracted first, so inner ones get lower numbers and the outer subquery's
    stored text already contains the inner placeholders.

    A subquery is flagged as scalar when it is either

    * the first thing inside a call to one of ``scalar_funcs``
      (e.g. ``NVL((SELECT ...), 0)``), or
    * immediately preceded by one of ``scalar_operators``
      (e.g. ``x = (SELECT ...)``).

    Only the text directly in front of the opening parenthesis is considered;
    an operator elsewhere in the statement (``a = 1 AND b IN (SELECT ...)``)
    does not make the subquery scalar. Quoted spans are copied verbatim, so
    parentheses inside string literals do not affect the matching.

    Args:
        sql_query: The SQL text of one section.
        section: Section type used in the placeholder name.
        section_id: Section number used in the placeholder name.
        scalar_funcs: Upper-case function names that take a scalar subquery.
        scalar_operators: Operators that take a scalar subquery operand.

    Returns:
        ``(final_sql, subqueries)`` where ``subqueries`` is a list of
        ``(placeholder, subquery_sql, is_scalar)`` tuples.
    """
    if scalar_funcs is None:
        scalar_funcs = {"NVL", "COALESCE", "ROUND", "ABS", "LTRIM", "RTRIM", "UPPER", "LOWER"}

    if scalar_operators is None:
        scalar_operators = {"=", ">", "<", ">=", "<=", "<>", "!=", "+", "-", "*", "/"}

    operator_suffixes = tuple(scalar_operators)

    def skip_quoted(sql, start):
        # Index just past the quoted span opening at ``start`` (end of text if unterminated).
        closing = sql.find(sql[start], start + 1)
        return len(sql) if closing == -1 else closing + 1

    def is_select_at(sql, pos):
        # True when the whole word SELECT (any case) starts at ``pos``.
        if sql[pos:pos + 6].lower() != 'select':
            return False
        following = sql[pos + 6:pos + 7]
        return not (following.isalnum() or following == '_')

    def parse(sql, base_idx=1):
        subqueries = []
        pieces = []
        i = 0
        n = len(sql)
        subquery_counter = base_idx

        while i < n:
            ch = sql[i]

            if ch in ("'", '"'):
                end = skip_quoted(sql, i)
                pieces.append(sql[i:end])
                i = end
                continue

            if ch != '(':
                pieces.append(ch)
                i += 1
                continue

            # Look past whitespace after '(' to see whether a SELECT follows.
            j = i + 1
            while j < n and sql[j].isspace():
                j += 1

            if not is_select_at(sql, j):
                pieces.append(ch)
                i += 1
                continue

            # ----- Scalar detection: inspect only what directly precedes '(' -----
            lookback = sql[max(0, i - 200):i].upper()
            func_match = re.search(r'([A-Z_]+)\s*\(\s*$', lookback)
            found_func = func_match.group(1) if func_match else ""
            is_scalar = found_func in scalar_funcs or \
                        lookback.rstrip().endswith(operator_suffixes)

            # Find the parenthesis that closes this subquery.
            paren_count = 1
            k = j + 6
            while k < n and paren_count > 0:
                c = sql[k]
                if c in ("'", '"'):
                    k = skip_quoted(sql, k)
                    continue
                if c == '(':
                    paren_count += 1
                elif c == ')':
                    paren_count -= 1
                k += 1

            # Exclude the closing ')' when there is one; keep everything otherwise.
            inner_end = k - 1 if paren_count == 0 else k
            inner_sql, inner_subs, next_counter = parse(sql[j:inner_end], subquery_counter)
            subqueries.extend(inner_subs)

            placeholder = f"<<subquery_{section}[{section_id}]_{next_counter}>>"
            subqueries.append((placeholder, inner_sql, is_scalar))
            pieces.append(f"({placeholder})")
            subquery_counter = next_counter + 1
            i = k

        return ''.join(pieces), subqueries, subquery_counter

    clean_sql = strip_comments(sql_query)
    pretty_sql = prettify_final(clean_sql)
    final_sql, all_subqueries, _ = parse(pretty_sql)
    print(all_subqueries)
    return final_sql, all_subqueries


In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, IntegerType, BooleanType

# Initialize SparkSession if needed: getOrCreate() hands back the session that
# already exists, or builds one using the app name "SQLSubqueries".
spark = SparkSession.builder.appName("SQLSubqueries").getOrCreate()

def get_schema():
    """Build the StructType shared by every DataFrame in this notebook.

    All fourteen fields are nullable. ``section_id`` is an integer,
    ``is_scalar`` a boolean, the ``columns``/``tables``/``functions``/
    ``converted_columns`` fields are arrays of strings, and the rest are strings.
    """
    field_types = [
        ("section", StringType()),
        ("section_id", IntegerType()),
        ("type", StringType()),
        ("identifier", StringType()),
        ("is_scalar", BooleanType()),
        ("original_query", StringType()),
        ("final_query", StringType()),
        ("content", StringType()),
        ("columns", ArrayType(StringType())),
        ("tables", ArrayType(StringType())),
        ("functions", ArrayType(StringType())),
        ("converted_content", StringType()),
        ("converted_columns", ArrayType(StringType())),
        ("response_error", StringType()),
    ]
    return StructType([StructField(name, data_type, True) for name, data_type in field_types])

def initialize_empty_subquery_delta_table(table_name):
    """Drop ``table_name`` and recreate it as an empty Delta table.

    The new table has the schema from ``get_schema()`` and no rows. Nothing
    happens when ``table_name`` is None, empty, or only whitespace.
    """
    if table_name is None or not table_name.strip():
        return

    schema_only_df = spark.createDataFrame([], get_schema())

    spark.sql(f"drop table if exists {table_name}")
    writer = schema_only_df.write.format("delta").mode("overwrite")
    writer.saveAsTable(table_name)
  
def subqueries_to_spark_dataframe(sql_query: str):
    """Decompose ``sql_query`` into rows of a Spark DataFrame.

    For each top-level section of the query:

    * 'cte' and 'select' sections have their subqueries replaced by
      placeholders; each subquery becomes a row of type "subquery" whose
      identifier is the placeholder name without the ``<<`` ``>>`` markers;
    * any other section contributes one row with empty type and identifier
      that carries the section text unchanged;
    * every section additionally gets a "section_main" row holding the
      original text in ``original_query`` and the (possibly rewritten) text
      in ``content``.

    The DataFrame uses the schema from ``get_schema()``.
    """
    from pyspark.sql import SparkSession
    session = SparkSession.builder.getOrCreate()

    def make_record(section_type, section_id, **overrides):
        # Start from the blank defaults shared by every row, then apply overrides.
        record = {
            "section": section_type,
            "section_id": section_id,
            "type": "",
            "identifier": "",
            "is_scalar": False,
            "original_query": None,
            "final_query": "",
            "content": None,
            "columns": [],
            "tables": [],
            "functions": [],
            "converted_content": "",
            "converted_columns": [],
            "response_error": ""
        }
        record.update(overrides)
        return record

    records = []
    for section in split_main_sections_with_classification(sql_query):
        kind = section["type"]
        number = int(section["section_id"])
        raw_text = section["content"]

        if kind in ("cte", "select"):
            rewritten, found = extract_and_replace_subqueries(raw_text, section=kind, section_id=number)
            for placeholder, body, scalar in found:
                records.append(make_record(
                    kind,
                    number,
                    type="subquery",
                    identifier=placeholder.replace("<<", "").replace(">>", ""),
                    is_scalar=scalar,
                    content=body,
                ))
        else:
            # Set-operator (and unclassified) sections are carried through as-is.
            rewritten = raw_text
            records.append(make_record(kind, number, content=raw_text))

        # One summary row per section with the placeholder-bearing text.
        records.append(make_record(
            kind,
            number,
            type="section_main",
            identifier=f"section_{number}_main",
            original_query=raw_text,
            content=rewritten,
        ))

    return session.createDataFrame(records, get_schema())

def write_subqueries_to_delta(df, table_name: str):
    """Save ``df`` as the Delta table ``table_name``, overwriting existing data.

    Nothing is written when ``table_name`` is None, empty, or only whitespace.
    """
    if table_name is None or not table_name.strip():
        return
    writer = df.write.mode("overwrite").format("delta")
    writer.saveAsTable(table_name)


In [0]:
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

def extract_columns_and_replace_select(df):
    """Move each query's first select list into ``columns`` and put ``*`` in its place.

    For every row of ``df`` the first ``SELECT`` keyword in ``content`` is
    located and the text after it is scanned up to the matching ``FROM``
    (outside parentheses and quotes). The comma-separated items found there
    are stored in ``columns`` and the select list in ``content`` is replaced
    by ``*``.

    ``SELECT`` and ``FROM`` are matched as whole words, so an alias such as
    ``valid_from`` does not end the select list, and ``FROM`` may be followed
    by any whitespace. If no ``FROM`` is found, everything after ``SELECT`` is
    treated as the select list. Rows whose content is empty or contains no
    SELECT keep their content and get an empty ``columns`` list.

    Args:
        df: Spark DataFrame with the schema from ``get_schema()``.

    Returns:
        A new Spark DataFrame with ``content`` and ``columns`` rewritten and
        ``final_query``, ``converted_content``, ``converted_columns`` and
        ``response_error`` reset to empty values.
    """
    import re

    select_keyword = re.compile(r'\bSELECT\b', re.IGNORECASE)
    from_keyword = re.compile(r'\bFROM\b', re.IGNORECASE)

    def get_columns_and_rewrite(query):
        # Rows without text (e.g. content is None) have nothing to rewrite.
        if not query:
            return [], query

        query = query.strip()
        select_match = select_keyword.search(query)
        if select_match is None:
            return [], query

        select_pos = select_match.start()
        i = select_match.end()
        n = len(query)
        depth = 0
        buffer = []
        columns = []

        while i < n:
            ch = query[i]

            # Copy quoted spans verbatim so commas/keywords inside are ignored.
            if ch in ("'", '"'):
                closing = query.find(ch, i + 1)
                end = n if closing == -1 else closing + 1
                buffer.append(query[i:end])
                i = end
                continue

            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            elif depth == 0 and ch == ',':
                columns.append(''.join(buffer).strip())
                buffer = []
                i += 1
                continue
            elif depth == 0 and from_keyword.match(query, i):
                break

            buffer.append(ch)
            i += 1

        # Last item: either the one before FROM or, with no FROM, the tail.
        columns.append(''.join(buffer).strip())

        # Reconstruct query with SELECT * instead
        prefix = query[:select_pos]
        suffix = query[i:]
        modified_query = f"{prefix}SELECT * {suffix}".strip()

        return [c for c in columns if c], modified_query

    records = []
    for row in df.collect():
        columns, modified_content = get_columns_and_rewrite(row['content'])

        records.append({
            "section": row['section'],
            "section_id": row["section_id"],
            "type": row['type'],
            "identifier": row['identifier'],
            "is_scalar": row['is_scalar'],
            "original_query": row['original_query'],
            "final_query": "",
            "content": modified_content,
            "columns": columns,
            "tables": row['tables'],
            "functions": row['functions'],
            "converted_content": "",
            "converted_columns": [],
            "response_error": ""
        })

    return spark.createDataFrame(records, get_schema())


In [0]:
from pyspark.sql.functions import col, collect_list, size, concat_ws

def extract_tables_full(sql):
    """Return the sorted, lower-cased names that follow FROM or JOIN in ``sql``.

    ``FROM`` and ``JOIN`` are matched as whole words (case-insensitive), so a
    column such as ``valid_from`` is not mistaken for the keyword. A name runs
    up to the next whitespace, semicolon, parenthesis or comma; only the first
    table of a comma-separated FROM list is captured.

    Args:
        sql: SQL text, or a falsy value (``None`` / empty string).

    Returns:
        A sorted list of unique lower-cased names; empty for falsy input.
    """
    if not sql:
        return []
    names = re.findall(r'\b(?:FROM|JOIN)\s+([^\s;(),]+)', sql, re.IGNORECASE)
    return sorted({name.lower() for name in names})

# `udf` is imported explicitly so this cell does not rely on the name being
# pre-bound by the notebook environment.
from pyspark.sql.functions import udf

# Spark UDF wrapper around extract_tables_full; yields an array of strings.
extract_tables_udf = udf(extract_tables_full, ArrayType(StringType()))

# Main function
def find_duplicate_subqueries_by_table_spark(df, content_col: str):
    """Group subquery rows that reference the same set of tables.

    Rows of type "subquery" get their tables extracted from ``content_col``;
    the table names are joined with commas into ``table_key``. The result has
    one row per ``table_key`` shared by more than one subquery, with the
    matching identifiers collected in ``identifiers``.
    """
    only_subqueries = df.filter(col("type") == "subquery")

    keyed = (
        only_subqueries
        .withColumn("tables", extract_tables_udf(col(content_col)))
        .withColumn("table_key", concat_ws(",", col("tables")))
    )

    identifiers_per_key = keyed.groupBy("table_key").agg(
        collect_list("identifier").alias("identifiers")
    )
    return identifiers_per_key.filter(size("identifiers") > 1)

In [0]:
def query_to_databricks_sql_agent(query: str, is_scalar:bool = False):
    """Send ``query`` to AGENT for conversion and return its first assistant reply.

    When ``is_scalar`` is true the prompt asks for a conversion that returns
    only one record. Returns ``None`` if the response holds no assistant message.
    """
    if is_scalar:
        query_message = f"convert the following so that it only returns one record: {query}"
    else:
        query_message = f"convert the following query: {query}"

    response = AGENT.predict({"messages": [{"role": "user", "content": query_message}]})
    return next(
        (message.content for message in response.messages if message.role == "assistant"),
        None,
    )

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
import traceback

def convert_query_to_databricks_sql(query: str, endpoint_name: str = "databricks-claude-sonnet-4"):
    """Ask a serving endpoint to convert an Oracle SQL query to Databricks SQL.

    Returns ``query`` unchanged when ``endpoint_name`` is None or blank;
    otherwise returns the stripped text of the first choice in the response.

    NOTE(review): confirm that ``serving_endpoints.query`` accepts the
    ``top_p``, ``top_k`` and ``thinking`` keyword arguments passed below.
    """
    if endpoint_name is None or not endpoint_name.strip():
        return query

    client = WorkspaceClient()
    system_message = ChatMessage(
        role=ChatMessageRole.SYSTEM, content="You are a helpful assistant."
    )
    user_message = ChatMessage(
        role=ChatMessageRole.USER, content=f"Please covert the following Oracle SQL query to Databricks SQL. Just return the query, no other content, including sql. If you see any sql that is wrapped in << >>, for example <<subquery_1>>, assume it is valid sql and leave it as is.  I need a complete conversion, do not skip any lines:\n{query}"
    )

    reply = client.serving_endpoints.query(
        name=endpoint_name,
        messages=[system_message, user_message],
        temperature=0,
        top_p=0.95,
        top_k=1,
        max_tokens=None,
        thinking=False
    )
    return reply.choices[0].message.content.strip()

In [0]:
def load_parse_log(table_name: str):
    """
    Read the table ``table_name`` through the Spark session and return it.

    Parameters:
        table_name (str): The fully qualified table name (e.g., 'db.schema.table' or 'default.my_table').

    Returns:
        DataFrame: The Spark DataFrame backed by that table.
    """
    reader = spark.read
    return reader.table(table_name)


In [0]:
def convert_sql(df, endpoint_name="databricks-claude-sonnet-4", full_refresh=True):
    """Convert each row's ``content`` through the SQL agent.

    Rows are handled as follows:

    * ``section == "join"`` rows (set operators) are copied with empty
      ``columns``, ``converted_content``, ``converted_columns`` and
      ``response_error``;
    * rows whose ``content`` is missing or blank are copied unchanged, since
      there is nothing to convert;
    * when ``full_refresh`` is false, rows that already hold a non-blank
      ``converted_content`` are copied unchanged (a missing value counts as
      "not converted yet");
    * every other row is sent to ``query_to_databricks_sql_agent``. A raised
      exception, or a response without an assistant message, leaves
      ``converted_content`` empty and records the reason in ``response_error``.

    Args:
        df: Spark DataFrame with the schema from ``get_schema()``.
        endpoint_name: Kept for backward compatibility; the agent-based
            conversion used here does not read it.
        full_refresh: Re-convert rows even if they already have a conversion.

    Returns:
        A new Spark DataFrame with ``converted_content`` filled in.
    """
    records = []

    for row in df.collect():
        if row['section'] == "join":
            records.append({
                "section": row['section'],
                "section_id": row["section_id"],
                "type": row["type"],
                "identifier": row["identifier"],
                "is_scalar": row["is_scalar"],
                "original_query": row["original_query"],
                "final_query": row["final_query"],
                "content": row['content'],
                "columns": [],
                "tables": row['tables'],
                "functions": row['functions'],
                "converted_content": "",
                "converted_columns": [],
                "response_error": ""
            })
            continue

        # Nothing to convert (e.g. a previously appended final-query row).
        content = row['content']
        if not isinstance(content, str) or not content.strip():
            records.append(row.asDict())
            continue

        # Skip conversion only for a real, non-blank earlier result; str(None)
        # must not be mistaken for converted SQL.
        existing = row['converted_content']
        if not full_refresh and isinstance(existing, str) and existing.strip():
            records.append(row.asDict())
            continue

        response_error = ""
        try:
            converted_sql = query_to_databricks_sql_agent(content, is_scalar=row['is_scalar'])
        except Exception as e:
            converted_sql = ""
            response_error = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc(limit=2)}"
        else:
            if converted_sql is None:
                # The agent answered without any assistant message.
                converted_sql = ""
                response_error = "No assistant message returned by the agent"

        records.append({
            "section": row['section'],
            "section_id": row["section_id"],
            "type": row['type'],
            "identifier": row['identifier'],
            "is_scalar": row['is_scalar'],
            "original_query": row['original_query'],
            "final_query": "",
            "content": content,
            "columns": row['columns'],
            "tables": row['tables'],
            "functions": row['functions'],
            "converted_content": converted_sql,
            "converted_columns": [],
            "response_error": response_error
        })

    return spark.createDataFrame(records, get_schema())


In [0]:
def convert_sql_on_columns(df, chunk_size=10, endpoint_name="databricks-claude-sonnet-4", full_refresh=True):
    """Convert each row's extracted columns through the SQL agent, in chunks.

    The ``columns`` of a row are split into chunks of ``chunk_size``; each
    chunk is wrapped in ``SELECT ... FROM dummy`` and sent to
    ``query_to_databricks_sql_agent``. The select list of the response is
    split on top-level commas (parentheses and quotes are respected) and
    appended to ``converted_columns``.

    If a chunk cannot be converted (the agent raises, returns no message, or
    the response has no readable select list) the chunk's original columns
    are kept so no column is dropped from the final query, and the reason is
    added to ``response_error``. Errors from all chunks are kept, one per line.

    ``section == "join"`` rows are copied with empty column fields. When
    ``full_refresh`` is false, rows that already have ``converted_columns``
    are copied unchanged.

    Args:
        df: Spark DataFrame with the schema from ``get_schema()``.
        chunk_size: Maximum number of columns per agent call.
        endpoint_name: Kept for backward compatibility; the agent-based
            conversion used here does not read it.
        full_refresh: Re-convert rows even if they already have converted columns.

    Returns:
        A new Spark DataFrame with ``converted_columns`` filled in.
    """
    import re

    select_keyword = re.compile(r'\bSELECT\b', re.IGNORECASE)
    from_keyword = re.compile(r'\bFROM\b', re.IGNORECASE)

    def split_select_list(select_sql):
        # Return the top-level items between the first SELECT and its FROM.
        start = select_keyword.search(select_sql)
        if start is None:
            return []

        items = []
        current = []
        depth = 0
        i = start.end()
        n = len(select_sql)
        while i < n:
            ch = select_sql[i]
            if ch in ("'", '"'):
                closing = select_sql.find(ch, i + 1)
                end = n if closing == -1 else closing + 1
                current.append(select_sql[i:end])
                i = end
                continue
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            elif depth == 0 and ch == ',':
                items.append(''.join(current).strip())
                current = []
                i += 1
                continue
            elif depth == 0 and from_keyword.match(select_sql, i):
                break
            current.append(ch)
            i += 1
        items.append(''.join(current).strip())
        return [item for item in items if item]

    records = []

    for row in df.collect():
        if row['section'] == "join":
            records.append({
                "section": row['section'],
                "section_id": row["section_id"],
                "type": row["type"],
                "identifier": row["identifier"],
                "is_scalar": row["is_scalar"],
                "original_query": row["original_query"],
                "final_query": row["final_query"],
                "content": row['content'],
                "columns": [],
                "tables": row['tables'],
                "functions": row['functions'],
                "converted_content": "",
                "converted_columns": [],
                "response_error": ""
            })
            continue

        # Skip conversion if converted_columns is a non-empty list
        if not full_refresh and row['converted_columns']:
            records.append(row.asDict())
            continue

        converted_columns = []
        errors = []
        source_columns = row['columns'] or []

        for start in range(0, len(source_columns), chunk_size):
            chunk = source_columns[start:start + chunk_size]
            dummy_query = f"SELECT {', '.join(chunk)} FROM dummy"

            try:
                converted_sql = query_to_databricks_sql_agent(dummy_query, is_scalar=False)
            except Exception as e:
                errors.append(f"{type(e).__name__}: {str(e)}\n{traceback.format_exc(limit=2)}")
                converted_sql = None

            cols = split_select_list(converted_sql) if isinstance(converted_sql, str) else []
            if cols:
                converted_columns.extend(cols)
            else:
                # Keep the unconverted columns rather than dropping them.
                if isinstance(converted_sql, str) or not errors:
                    errors.append(f"No columns returned for chunk starting at column {start}")
                converted_columns.extend(chunk)

        records.append({
            "section": row['section'],
            "section_id": row["section_id"],
            "type": row['type'],
            "identifier": row['identifier'],
            "is_scalar": row['is_scalar'],
            "original_query": row['original_query'],
            "final_query": "",
            "content": row['content'],
            "columns": row['columns'],
            "tables": row['tables'],
            "functions": row['functions'],
            "converted_content": row['converted_content'],
            "converted_columns": converted_columns,
            "response_error": "\n".join(errors)
        })

    return spark.createDataFrame(records, get_schema())


In [0]:
import ast

def assemble_final_query_string(df):
    """Rebuild the full query text from the rows of ``df``.

    1. Each row's text is its ``converted_content`` when non-empty, otherwise
       its ``content``. If the row has both ``columns`` and
       ``converted_columns``, the first ``select *`` in that text is replaced
       by the converted columns joined with ", ".
    2. ``<<identifier>>`` placeholders are replaced by the text of the row
       with that identifier, repeating until a pass changes nothing.
    3. Rows of type "section_main" or "join" are ordered by ``section_id``
       and their resolved, stripped texts are joined with blank lines.
    """
    def as_list(value):
        # Column lists may arrive as real lists or as their string repr.
        return ast.literal_eval(value) if isinstance(value, str) else value

    # Step 1: identifier -> text with the select list restored
    resolved = {}
    for row in df.collect():
        text = row['converted_content'] if row['converted_content'] else row['content']
        if not isinstance(text, str):
            text = ""

        if row['columns'] and row['converted_columns']:
            try:
                original_cols = as_list(row['columns'])
                new_cols = as_list(row['converted_columns'])
                if original_cols and new_cols:
                    joined_cols = ", ".join(new_cols)
                    text = re.sub(
                        r'(?i)(select\s+)\*',
                        lambda m: m.group(1) + joined_cols,
                        text,
                        count=1,
                    )
            except Exception:
                # Best effort: keep the text as it is if the lists cannot be used.
                pass

        resolved[row['identifier']] = text

    # Step 2: expand placeholders until a full pass leaves everything unchanged
    while True:
        any_change = False
        for name, text in resolved.items():
            expanded = text
            for other_name, other_text in resolved.items():
                if other_name != name and isinstance(other_text, str):
                    expanded = expanded.replace(f"<<{other_name}>>", other_text)
            if expanded != text:
                any_change = True
            resolved[name] = expanded
        if not any_change:
            break

    # Step 3: stitch the top-level sections together in section order
    main_rows = df.filter(df["type"].isin(["section_main", "join"]))
    main_rows = main_rows.withColumn("section_id", main_rows["section_id"].cast("int"))

    parts = []
    for row in main_rows.sort("section_id").collect():
        fallback = row['converted_content'] or row['content']
        text = resolved.get(row['identifier'], fallback)
        parts.append(text.strip() if isinstance(text, str) else "")

    return "\n\n".join(parts)


In [0]:
def append_final_query_row(df, final_query_str):
    """
    Appends a new row with identifier 'final_query' and the provided final query string
    into the 'final_query' column.

    The row's 'tables' and 'functions' columns are filled from extract_tables and
    extract_sql_functions run on final_query_str; the remaining fields are set to
    None, empty lists, or fixed marker values.
    Assumes DataFrame has the schema with all expected columns, because the new row
    is attached with unionByName.

    Args:
        df: DataFrame holding the parsed sections / sub-queries.
        final_query_str: The fully assembled query text.

    Returns:
        A new DataFrame: df plus one extra 'final_query' row.
    """
    # Import every type the schema below uses so the function does not depend on
    # what other notebook cells happened to import.
    from pyspark.sql.types import (
        ArrayType,
        BooleanType,
        IntegerType,
        StringType,
        StructField,
        StructType,
    )

    schema = StructType([
        StructField("section", StringType(), True),
        StructField("section_id", IntegerType(), True),
        StructField("type", StringType(), True),
        StructField("identifier", StringType(), True),
        StructField("is_scalar", BooleanType(), True),
        StructField("original_query", StringType(), True),
        StructField("final_query", StringType(), True),
        StructField("content", StringType(), True),
        StructField("columns", ArrayType(StringType()), True),
        StructField("tables", ArrayType(StringType()), True),
        StructField("functions", ArrayType(StringType()), True),
        StructField("converted_content", StringType(), True),
        StructField("converted_columns", ArrayType(StringType()), True),
        StructField("response_error", StringType(), True)
    ])

    final_query_tables = extract_tables(final_query_str)
    final_query_functions = extract_sql_functions(final_query_str)

    # build single new row
    new_row_data = [{
        "section": "final_query",
        "section_id": 0,
        "type": "final",
        "identifier": "final_query",
        "is_scalar": False,
        "original_query": None,
        "final_query": final_query_str,
        "content": None,
        "columns": [],
        "tables": final_query_tables,
        # Previously this entry repeated the "tables" key, which overwrote the
        # table list with the function list and left "functions" unset.
        "functions": final_query_functions,
        "converted_content": None,
        "converted_columns": [],
        "response_error": None
    }]
    new_row_df = spark.createDataFrame(new_row_data, schema)

    # union with existing DataFrame (matched by column name, not position)
    return df.unionByName(new_row_df)


In [0]:
from pyspark.sql.functions import explode

def populate_existing_tables_column(df, column="content"):
    """
    Sets the 'tables' column of every row to the result of extract_tables
    applied to that row's `column` value (defaults to 'content').

    Rows whose source value is null or empty receive an empty list.
    """
    def _tables_or_empty(query):
        # Do not hand null/empty SQL text to extract_tables.
        if not query:
            return []
        return extract_tables(query)

    tables_udf = udf(_tables_or_empty, ArrayType(StringType()))
    source_col = col(column)
    return df.withColumn("tables", tables_udf(source_col))


def get_unique_table_list(df):
    """
    Flattens the 'tables' array column, de-duplicates the values and
    returns them as a sorted Python list.
    """
    exploded = df.select(explode("tables").alias("table"))
    distinct_rows = exploded.distinct().collect()
    return sorted(row["table"] for row in distinct_rows)



In [0]:
def extract_sql_functions(sql: str):
    """
    Extracts unique SQL function names from a SQL string.

    A function call is taken to be an identifier followed (optionally after
    whitespace) by an opening parenthesis, e.g. NVL(col, 0), UPPER (name),
    COUNT(*). Single-quoted string literals and -- / /* */ comments are blanked
    out first so that text such as 'Above Average (' is not mistaken for a call,
    and SQL keywords/operators that are commonly followed by a parenthesis
    (IN, AND, OR, NOT, OVER, ...) are filtered out.

    Args:
        sql: SQL text to scan. None or an empty string yields [].

    Returns:
        Sorted list of unique function names, upper-cased.
    """
    if not sql:
        return []

    # Remove literals and comments in ONE left-to-right pass: whichever construct
    # starts first wins, so a "--" inside a string or a quote inside a comment
    # is handled correctly. '' inside a literal is an escaped quote.
    scrubbed = re.sub(
        r"'(?:[^']|'')*'|--[^\n]*|/\*.*?\*/",
        " ",
        sql,
        flags=re.DOTALL,
    )

    # Regex matches function calls like NVL(col, 0), UPPER(name), COUNT(*)
    pattern = r"\b([A-Z_][A-Z0-9_]*)\s*\("
    matches = re.findall(pattern, scrubbed.upper())

    # Filter out SQL keywords/operators that can precede "(" but are not functions
    keywords = {
        "SELECT", "FROM", "JOIN", "WHERE", "CASE", "WHEN", "THEN", "ELSE",
        "END", "GROUP", "ORDER", "HAVING", "ON", "UNION", "EXISTS", "IN",
        "AS", "WITH",
        # logical operators: "a AND (b OR c)", "NOT (x)"
        "AND", "OR", "NOT",
        # clause keywords followed by a parenthesised list / sub-query
        "OVER", "BY", "ALL", "DISTINCT", "VALUES", "USING",
        "INTERSECT", "EXCEPT", "MINUS",
    }

    return sorted({fn for fn in matches if fn not in keywords})

def populate_existing_functions_columns(df, content_column="content", columns_column="columns"):
    """
    Sets the 'functions' column of every row to the sorted, de-duplicated SQL
    function names found by extract_sql_functions in the row's
    `content_column` text and in each entry of its `columns_column` array.
    """
    from pyspark.sql.functions import (
        udf,
        col,
        array_union,
        flatten,
    )
    from pyspark.sql.types import ArrayType, StringType

    def _functions_for_row(content, columns):
        found = set()
        if content:
            found.update(extract_sql_functions(content))
        # A null/empty columns array contributes nothing.
        for expr in columns or []:
            if expr:
                found.update(extract_sql_functions(expr))
        return sorted(found)

    functions_udf = udf(_functions_for_row, ArrayType(StringType()))
    per_row_functions = functions_udf(col(content_column), col(columns_column))
    return df.withColumn("functions", per_row_functions)

def get_unique_function_list(df):
    """
    Flattens the 'functions' array column, de-duplicates the values and
    returns them as a sorted Python list.
    """
    exploded = df.select(explode("functions").alias("function"))
    distinct_rows = exploded.distinct().collect()
    return sorted(row["function"] for row in distinct_rows)

In [0]:
# import random

# def generate_complex_query():
#     base_query = """
# WITH
# DeptStats AS (
#     SELECT
#         DepartmentID,
#         SUM(Salary) AS TotalDeptSalary,
#         COUNT(*) AS NumEmployees,
#         AVG(Salary) AS AvgDeptSalary
#     FROM
#         Employees
#     GROUP BY
#         DepartmentID
# ),
# EmpProjects AS (
#     SELECT
#         EmployeeID,
#         COUNT(ProjectID) AS ProjectsCompleted,
#         MAX(CompletedDate) AS LastProjectDate,
#         YEAR(MAX(CompletedDate)) AS LastProjectYear
#     FROM
#         Projects
#     WHERE
#         Status = 'Completed'
#     GROUP BY
#         EmployeeID
# )

# SELECT
#     e.EmployeeID,
#     UPPER(e.Name) AS Name,
#     d.Name AS Department,
#     CASE
#         WHEN e.Salary > (
#             SELECT AVG(Salary)
#             FROM Employees
#             WHERE DepartmentID = e.DepartmentID
#         ) THEN CONCAT('Above Average (', CAST(e.Salary AS VARCHAR), ')')
#         ELSE 'Average or Below'
#     END AS SalaryStatus,
#     -- some remark here,
#     CASE
#         WHEN    rsm.investment_type = 'BL'
#             AND NVL (psah.acrd_cd, 'N') NOT IN ('Y', 'V') -- story 897300
#         THEN
#             NVL (
#                 (SELECT wacoupon
#                    FROM stg_wso_pos_acr_ame
#                   WHERE     portfolio_fund_id = psah.cal_dt
#                         AND asofdate = psah.cal_dt
#                         AND asset_primaryud = psah.asset_id
#                         AND rec_typ_cd = 'POS'),
#                 0)
#         ELSE
#             psah.int_rt
#     END
#         AS pos_int_it,  
#     ep.ProjectsCompleted,
#     YEAR(e.HireDate) AS HireYear,
#     MONTH(e.HireDate) AS HireMonth,
#     COALESCE(ep.LastProjectYear, 'N/A') AS LastProjectYear
# FROM
#     Employees e
#     JOIN Departments d ON e.DepartmentID = d.DepartmentID
#     LEFT JOIN EmpProjects ep ON e.EmployeeID = ep.EmployeeID
# WHERE
#     e.EmployeeID IN (
#         SELECT
#             e2.EmployeeID
#         FROM
#             Employees e2
#             JOIN Departments d2 ON e2.DepartmentID = d2.DepartmentID
#             LEFT JOIN EmpProjects ep2 ON e2.EmployeeID = ep2.EmployeeID
#         WHERE
#             e2.Salary > (
#                 SELECT AVG(Salary)
#                 FROM Employees
#                 WHERE DepartmentID = e2.DepartmentID
#             )
#     )
# UNION ALL
# SELECT
#     NULL AS EmployeeID,
#     NULL AS Name,
#     d.Name AS Department,
#     CONCAT('Department Total: ', CAST(ds.TotalDeptSalary AS VARCHAR)) AS SalaryStatus,
#     ds.NumEmployees AS ProjectsCompleted,
#     NULL AS HireYear,
#     NULL AS HireMonth,
#     NULL AS LastProjectYear,
# """

#     # Generate 599 columns, each randomly a literal or a CASE with subquery
#     cols = []
#     for i in range(1, 600):
#         rnd = random.random()
#         if rnd < 0.05:
#             col = (
#                 f"CASE WHEN ds.TotalDeptSalary > 100000 AND ds.TotalDeptSalary < 200000 AND NVL(a.emp_class, 'N') NOT IN ('A', 'B') -- story 1234\n"
#                 f"THEN NVL((SELECT id FROM Employees JOIN dept ON Employees.dept_id = dept.dept_id\n"
#                 f"WHERE DepartmentID = ds.DepartmentID),0)\n"
#                 f"ELSE 0 END as col{i}"
#             )
#         elif rnd >= 0.05 and rnd < 0.1:
#             col = (
#                 f"CASE WHEN ds.TotalDeptSalary > 100000 AND ds.TotalDeptSalary < 200000 AND NVL(a.emp_class, 'N') NOT IN ('A', 'B') -- story 1234\n"
#                 f"THEN NVL((SELECT id FROM Employees WHERE DepartmentID = ds.DepartmentID),2)\n"
#                 f"ELSE 0 END as col{i}"
#             )
#         elif rnd >= 0.1 and rnd < 0.2:
#             col = (
#                 f"CASE WHEN ds.TotalDeptSalary > 100000 AND ds.TotalDeptSalary < 200000 AND NVL(a.emp_class, 'N') NOT IN ('A', 'B') -- story 1234\n"
#                 f"THEN NVL((SELECT id FROM Employees WHERE DepartmentID = ds.DepartmentID),'x')\n"
#                 f"ELSE 0 END as col{i}"
#             )
#         elif rnd >= 0.2 and rnd < 0.25:
#             col = f"myfunc_from({rnd}) as col{i}"
#         elif rnd >= 0.25 and rnd < 0.3:
#             col = f"myfunc_fromchar({rnd}) as col{i}"
#         else:
#             col = f"'col{i}' as col{i}"
#         cols.append(col)
#     cols_str = ",\n".join(cols)

#     # Assemble the final query
#     final_query = f"""
# {base_query}
# {cols_str}
# FROM
#     DeptStats ds
#     JOIN Departments d ON ds.DepartmentID = d.DepartmentID;
# """
#     return final_query

# query_string = generate_complex_query()
# print(query_string)


In [0]:
# parse_log_table="users.paul_signorelli.sql_parsing_log"
# endpoint_name=""
# initialize_empty_subquery_delta_table(table_name=parse_log_table)
# spark_df = subqueries_to_spark_dataframe(query_string)
# spark_df_with_columns = extract_columns_and_replace_select(spark_df)
# spark_df_with_columns_and_tables = populate_existing_tables_column(spark_df_with_columns)
# spark_df_with_columns_and_functions = populate_existing_functions_columns(spark_df_with_columns_and_tables)

# spark_df_converted = convert_sql(spark_df_with_columns_and_functions, endpoint_name=endpoint_name)
# spark_df_with_converted_columns = convert_sql_on_columns(spark_df_converted, chunk_size=5, endpoint_name=endpoint_name)

# final_query = assemble_final_query_string(spark_df_with_converted_columns)
# pretty_final = prettify_final(final_query)

# df_final = append_final_query_row(spark_df_with_converted_columns, pretty_final)

# write_subqueries_to_delta(df_final, table_name=parse_log_table)
  

In [0]:
# unique_table_list = get_unique_table_list(spark_df_with_columns_and_functions)
# print(unique_table_list)

In [0]:
# unique_function_list = get_unique_function_list(spark_df_with_columns_and_functions)
# print(unique_function_list)

In [0]:
# %sql
# UPDATE users.paul_signorelli.sql_parsing_log 
# SET converted_content = '' 
# WHERE identifier = 'subquery_select[4]_1';

# UPDATE users.paul_signorelli.sql_parsing_log 
# SET converted_columns = array() 
# WHERE identifier = 'subquery_select[4]_2';

# select * from users.paul_signorelli.sql_parsing_log where identifier in ('subquery_select[4]_1', 'subquery_select[4]_2')

In [0]:
# df = load_parse_log(parse_log_table)

# spark_df_converted = convert_sql(df, endpoint_name=endpoint_name, full_refresh=False)
# spark_df_with_converted_columns = convert_sql_on_columns(spark_df_converted, chunk_size=5, endpoint_name=endpoint_name, full_refresh=False)

# final_query = assemble_final_query_string(spark_df_with_converted_columns)
# pretty_final = prettify_final(final_query)

# df_final = append_final_query_row(spark_df_with_converted_columns, pretty_final)
# write_subqueries_to_delta(df_final, table_name=parse_log_table)


In [0]:
# display(
#   spark.sql(f"select * from {parse_log_table}")
# )

In [0]:
# duplicates_df = find_duplicate_subqueries_by_table_spark(df_final, "content")
# display(duplicates_df)

In [0]:
# def process_sections_with_subquery_extraction(sql: str):
#     """
#     Runs extract_and_replace_subqueries on each 'cte' and 'select' section.
#     Keeps 'join' sections unchanged.

#     Returns a list of dicts:
#     [
#         {
#             "type": "cte" | "select" | "join",
#             "content": "modified content",
#             "subqueries": [ (placeholder, original_subquery), ... ]
#         },
#         ...
#     ]
#     """
#     sections = split_main_sections_with_classification(sql)
#     processed = []

#     for section in sections:
#         if section["type"] in ("cte", "select"):
#             replaced_sql, subqueries = extract_and_replace_subqueries(
#                 section["content"], 
#                 section=section["type"],
#                 section_id=section["id"]
#             )
#             processed.append({
#                 "section_id": section["id"],
#                 "type": section["type"],
#                 "content": replaced_sql,
#                 "subqueries": subqueries
#             })
#         else:  # e.g., join, unknown
#             processed.append({
#                 "section_id": section["id"],
#                 "type": section["type"],
#                 "content": section["content"],
#                 "subqueries": []
#             })

#     return processed
