## Simple Query Parser

This parser deconstructs a complex query into its sub-queries so that each one can be analyzed or converted separately.

It discovers sub-queries by looking for CTEs defined with WITH and for parenthesized SELECT statements (for example, those used with IN or EXISTS).

It can also remove a query's column list and replace it with a '*'. This is useful when a query has a long list of columns: it reduces the query size sent to an LLM, and the columns can be analyzed separately.

The script puts all of this information into a collection. You can iterate over it to call an LLM that analyzes or migrates each piece, and then piece the results back together at the end.

So the following query:

```
with cte1 as (
        select cust_id, sum(sales) as sales_sum 
        from orders
        group by cust_id
    ),
    cte2 as (
        select cust_id, sum(expenses) as expenses
        from expenses
        group by cust_id
    )
    select cust_id, sales_sum, expenses
    from cte1
    inner join cte2 on cte1.cust_id = cte2.cust_id
    where cte1.cust_id in (select cust_id from customer_region where region = 'USA')
    and exists (select 1 from sales_person_region where region = 'NYC')
```

will be deconstructed to look like

```

Query 1 (cte1):
    select * 
        from orders
        group by cust_id
Columns:
  - cust_id
  - sum(sales) as sales_sum
----------------------------------------
Query 2 (cte2):
    select *
        from expenses
        group by cust_id
Columns:
  - cust_id
  - sum(expenses) as expenses
----------------------------------------
Query 3 (subquery):
    select * from customer_region where region = 'USA'
Columns:
  - cust_id
----------------------------------------
Query 4 (subquery):
    select * from sales_person_region where region = 'NYC'
Columns:
  - 1
----------------------------------------
Query 5 (main):
    select *
    from cte1
    inner join cte2 on cte1.cust_id = cte2.cust_id
    where cte1.cust_id in (<<query 3>>)
    and exists (<<query 4>>)
Columns:
  - cust_id
  - sales_sum
  - expenses
----------------------------------------
main query:
with
    cte1 as (
        <<query 1>>
    ),
    cte2 as (
        <<query 2>>
    )
   <<query 5>>
```

In [0]:
%pip install sqlparse
dbutils.library.restartPython()

In [0]:
import re
from typing import List, Dict, Tuple


def _find_select_list_bounds(query: str) -> Tuple[int, int]:
    """Locate the select list of the first SELECT keyword in *query*.

    Returns ``(list_start, from_pos)``:

    * ``list_start`` is the index just past the SELECT keyword.
    * ``from_pos`` is the index of the FROM keyword that terminates the select
      list. Only a FROM at parenthesis depth 0, outside quotes and delimited
      as a whole word, counts.

    Either value is -1 when the corresponding keyword is not found.
    """
    match = re.search(r'\bselect\b', query, re.IGNORECASE)
    if match is None:
        return -1, -1
    list_start = match.end()

    depth = 0
    quote = ''
    length = len(query)
    for pos in range(list_start, length):
        ch = query[pos]
        if quote:
            # Inside a literal/quoted identifier only the closing quote matters.
            # A doubled quote ('') closes and immediately reopens, which is
            # exactly how SQL escapes it.
            if ch == quote:
                quote = ''
        elif ch in ("'", '"'):
            quote = ch
        elif ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif (
            depth == 0
            and query[pos:pos + 4].lower() == 'from'
            # Whole-word check so identifiers such as from_date are skipped.
            and not (query[pos - 1].isalnum() or query[pos - 1] == '_')
            and (
                pos + 4 == length
                or not (query[pos + 4].isalnum() or query[pos + 4] == '_')
            )
        ):
            return list_start, pos
    return list_start, -1


def extract_columns(query: str) -> List[str]:
    """Return the items of the top-level select list of *query*.

    Items are split on commas at parenthesis depth 0 and outside quotes, so
    ``sum(a, b)`` and ``'x,y'`` stay intact. Each item is whitespace-stripped.

    Returns:
        The list of select-list items. It is empty when *query* has no SELECT
        or no top-level FROM ending the select list. That matches when
        replace_columns_with_star leaves the query unchanged.
    """
    query = query.strip()
    list_start, from_pos = _find_select_list_bounds(query)
    if list_start == -1 or from_pos == -1:
        return []
    select_list = query[list_start:from_pos].strip()

    columns: List[str] = []
    current: List[str] = []
    depth = 0
    quote = ''
    for ch in select_list:
        if quote:
            if ch == quote:
                quote = ''
        elif ch in ("'", '"'):
            quote = ch
        elif ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            # Top-level comma: close the current item (kept even if empty).
            columns.append(''.join(current).strip())
            current = []
            continue
        current.append(ch)

    tail = ''.join(current).strip()
    if tail:
        columns.append(tail)
    return columns


def replace_columns_with_star(query: str) -> str:
    """Replace the top-level select list of *query* with ``*``.

    Returns:
        The stripped query with its select list replaced, e.g.
        ``select a, b from t`` becomes ``select * from t``. The stripped query
        is returned unchanged when it has no SELECT or no top-level FROM.
    """
    query = query.strip()
    list_start, from_pos = _find_select_list_bounds(query)
    if list_start == -1 or from_pos == -1:
        return query
    return query[:list_start] + ' * ' + query[from_pos:]


_SELECT_KEYWORD = re.compile(r'\bselect\b', re.IGNORECASE)


def _quoted_end(text: str, start: int) -> int:
    """Return the index just past the quote that closes the one at *start*.

    Returns ``len(text)`` when the quote is unterminated.
    """
    end = text.find(text[start], start + 1)
    return len(text) if end == -1 else end + 1


def _closing_paren_index(text: str, open_idx: int) -> int:
    """Return the index of the ')' matching the '(' at *open_idx*, or -1.

    Parentheses inside single- or double-quoted text are ignored.
    """
    depth = 0
    quote = ''
    for pos in range(open_idx, len(text)):
        ch = text[pos]
        if quote:
            if ch == quote:
                quote = ''
        elif ch in ("'", '"'):
            quote = ch
        elif ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return pos
    return -1


def _has_top_level_select(text: str) -> bool:
    """Return True if a SELECT keyword occurs in *text* outside parentheses and quotes."""
    top_level_positions = set()
    depth = 0
    quote = ''
    for pos, ch in enumerate(text):
        if quote:
            if ch == quote:
                quote = ''
        elif ch in ("'", '"'):
            quote = ch
        elif ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif depth == 0:
            top_level_positions.add(pos)
    return any(m.start() in top_level_positions for m in _SELECT_KEYWORD.finditer(text))


def _replace_subqueries(text: str, subqueries: List[str]) -> str:
    """Replace parenthesized subqueries in *text* with numbered placeholders.

    Each extracted subquery is appended to *subqueries*. Parentheses that do
    not hold a subquery (function calls, grouped predicates) are kept. Their
    contents are searched recursively, so nested subqueries are still found.
    Placeholders are numbered left to right.
    """
    pieces: List[str] = []
    pos = 0
    length = len(text)
    while pos < length:
        ch = text[pos]
        if ch in ("'", '"'):
            # Copy quoted text verbatim so its parentheses are not interpreted.
            end = _quoted_end(text, pos)
            pieces.append(text[pos:end])
            pos = end
        elif ch == '(':
            close = _closing_paren_index(text, pos)
            if close == -1:
                # Unbalanced parenthesis: keep the remainder untouched.
                pieces.append(text[pos:])
                break
            content = text[pos + 1:close].strip()
            if _has_top_level_select(content):
                subqueries.append(content)
                pieces.append(f"(<<subquery_{len(subqueries)}>>)")
            else:
                pieces.append(f"({_replace_subqueries(content, subqueries)})")
            pos = close + 1
        else:
            pieces.append(ch)
            pos += 1
    return ''.join(pieces)


def extract_subqueries(query: str) -> Tuple[str, List[str]]:
    """Pull parenthesized subqueries out of *query*.

    A parenthesized group counts as a subquery when its content has a SELECT
    at its own top level. For example, ``(select ...)`` and
    ``((select 1) union all select 2)`` both count. Each subquery is replaced
    by ``(<<subquery_N>>)``, with N starting at 1 and assigned left to right.
    Extracted subqueries are kept whole; they are not split further.

    Returns:
        ``(rewritten_query, subqueries)``, where ``subqueries[N - 1]`` is the
        stripped text behind placeholder ``<<subquery_N>>``.
    """
    subqueries: List[str] = []
    rewritten = _replace_subqueries(query, subqueries)
    return rewritten, subqueries

# Group 1 matches a quoted literal/identifier (with doubled-quote escapes).
# The other alternatives match a "--" line comment or a "/* ... */" block comment.
_COMMENT_OR_QUOTED = re.compile(
    r"""('(?:[^']|'')*'|"(?:[^"]|"")*")|--[^\n]*|/\*.*?\*/""",
    re.DOTALL,
)


def strip_comments(sql: str) -> str:
    """Remove ``--`` line comments and ``/* ... */`` block comments from *sql*.

    Quoted text is matched first and kept verbatim, so comment markers inside
    string literals (e.g. ``'--x'``) survive. Each comment is replaced by a
    single space so the tokens on either side are not glued together. The
    newline ending a line comment is preserved. Removing ``--`` comments
    matters because callers later collapse newlines; a surviving line comment
    would then swallow the rest of the statement.
    """
    return _COMMENT_OR_QUOTED.sub(
        lambda m: m.group(1) if m.group(1) is not None else ' ', sql
    )

# Group 1 matches a single-quoted SQL literal (with '' escapes); otherwise the
# pattern matches a run of whitespace. Used to collapse whitespace without
# altering literal contents.
_LITERAL_OR_WHITESPACE = re.compile(r"('(?:[^']|'')*')|\s+")


def _parse_ctes(sql_body: str) -> Tuple[List[Tuple[str, str]], str]:
    """Split the text that follows WITH into CTE definitions and the main query.

    Parsing repeatedly matches ``name AS (`` followed by a balanced
    (quote-aware) parenthesized body. It stops at the first position that
    does not match.

    Returns:
        ``([(cte_name, inner_query), ...], main_query)``, with both query
        strings stripped.
    """
    ctes: List[Tuple[str, str]] = []
    idx = 0
    length = len(sql_body)
    while idx < length:
        match = re.match(r'(\w+)\s+as\s+\(', sql_body[idx:], re.IGNORECASE)
        if not match:
            break
        open_idx = idx + match.end() - 1
        close_idx = length  # stays at the end if the body is unbalanced
        depth = 0
        quote = ''
        for pos in range(open_idx, length):
            ch = sql_body[pos]
            if quote:
                if ch == quote:
                    quote = ''
            elif ch in ("'", '"'):
                quote = ch
            elif ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    close_idx = pos
                    break
        ctes.append((match.group(1), sql_body[open_idx + 1:close_idx].strip()))

        # Skip the separator between this CTE and the next one / the main query.
        idx = close_idx + 1
        while idx < length and sql_body[idx] in " ,\n\t":
            idx += 1
    return ctes, sql_body[idx:].strip()


def split_sql_view_full(sql: str, extract_columns: bool = False) -> Tuple[List[Dict[str, object]], str, str]:
    """Split a WITH/SELECT statement into CTEs, subqueries and the main query.

    Comments are removed with strip_comments(). Whitespace outside
    single-quoted literals is then collapsed to single spaces.

    Args:
        sql: The statement to split. It must start with WITH or SELECT.
        extract_columns: When True, each entry's select list is recorded in
            ``columns`` and replaced by ``*`` in ``query``.

    Returns:
        ``(queries, mode, full_main_query)``:

        * ``queries``: a list of dicts with keys ``name``, ``type``
          (``'cte'``, ``'subquery'`` or ``'main'``), ``columns`` and ``query``.
          CTEs come first, then ``subquery_N`` entries, then ``main``.
        * ``mode``: ``'with'`` or ``'select'``.
        * ``full_main_query``: the main query, with ``<<cte_name>>`` standing
          in for each CTE body and ``<<subquery_N>>`` for each subquery.

    Raises:
        ValueError: if the statement starts with neither WITH nor SELECT.
    """
    # The boolean parameter (kept for keyword compatibility) shadows the
    # module-level extract_columns() function. Fetch the function explicitly,
    # because calling the parameter raised "'bool' object is not callable"
    # whenever extract_columns=True was passed.
    column_extractor = globals()["extract_columns"]

    def build_entry(name: str, entry_type: str, text: str) -> Dict[str, object]:
        # Optionally record the select list and replace it with '*'.
        return {
            'name': name,
            'type': entry_type,
            'columns': column_extractor(text) if extract_columns else [],
            'query': replace_columns_with_star(text) if extract_columns else text,
        }

    sql = strip_comments(sql.strip()).strip()
    # Collapse whitespace so keyword checks see single spaces, but leave
    # single-quoted literals untouched ('a  b' must not become 'a b').
    sql = _LITERAL_OR_WHITESPACE.sub(
        lambda m: m.group(1) if m.group(1) is not None else ' ', sql
    )

    lowered = sql.lower()
    if lowered.startswith("with "):
        mode = 'with'
        ctes, main_query = _parse_ctes(sql[5:].lstrip())
    elif lowered.startswith("select "):
        mode = 'select'
        ctes, main_query = [], sql
    else:
        raise ValueError("SQL must start with WITH or SELECT")

    queries: List[Dict[str, object]] = [build_entry(name, 'cte', body) for name, body in ctes]

    rewritten_main, subqueries = extract_subqueries(main_query)
    queries.extend(
        build_entry(f"subquery_{number}", 'subquery', subquery)
        for number, subquery in enumerate(subqueries, start=1)
    )

    # Columns are read from the original main query. The '*' substitution is
    # applied to the rewritten text, which holds the <<subquery_N>> placeholders.
    main_columns = column_extractor(main_query) if extract_columns else []
    main_body = replace_columns_with_star(rewritten_main) if extract_columns else rewritten_main

    if mode == 'with':
        cte_list = ", ".join(f"{name} AS (<<{name}>>)" for name, _ in ctes)
        full_main_query = f"WITH {cte_list} {main_body}"
    else:
        full_main_query = main_body

    queries.append({
        'name': 'main',
        'type': 'main',
        'columns': main_columns,
        'query': full_main_query,
    })
    return queries, mode, full_main_query


def print_full_queries(sql: str):
    """Print every piece produced by split_sql_view_full for *sql*.

    Each entry is printed with its 1-based position, name, query text and
    columns, followed by a dashed separator. Then comes a skeleton of the
    statement (the WITH list of CTE placeholders, when present, and
    ``<<main>>``) and finally the full main query.
    """
    entries, mode, full_main_query = split_sql_view_full(sql)
    separator = "-" * 40

    for number, entry in enumerate(entries, start=1):
        print(f"Query {number} ({entry['name']}):")
        print(f"    {entry['query']}")
        print("Columns:")
        for column in entry['columns']:
            print(f"  - {column}")
        print(separator)

    print("main query:")
    if mode == 'with':
        cte_names = [entry['name'] for entry in entries if entry['type'] == 'cte']
        snippets = ",\n".join(
            f"    {name} as (\n        <<{name}>>\n    )" for name in cte_names
        )
        print("with\n" + snippets)
    # Both modes end the skeleton with the main-query placeholder.
    print("   <<main>>")

    print(f"\n\nFull main query:\n{full_main_query}\n")


In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
from pyspark.sql import SparkSession

def convert_query_to_databricks_sql(query: str, endpoint_name: str = "databricks-claude-sonnet-4"):
    """Ask a Databricks model-serving endpoint to convert Oracle SQL to Databricks SQL.

    Args:
        query: The Oracle SQL text to convert.
        endpoint_name: Name of the serving endpoint to query. The endpoint is
            no longer hard-coded, so this argument now takes effect.

    Returns:
        The raw response from ``WorkspaceClient().serving_endpoints.query``.
        Callers read ``response.choices[0].message.content`` for the SQL.
    """
    w = WorkspaceClient()
    response = w.serving_endpoints.query(
        name=endpoint_name,
        messages=[
            ChatMessage(
                role=ChatMessageRole.SYSTEM, content="You are a helpful assistant."
            ),
            ChatMessage(
                role=ChatMessageRole.USER, content=f"Please convert the following Oracle SQL query to Databricks SQL. Just return the query, no other content, including ```sql. I need a complete conversion, do not skip any lines:\n{query}"
            ),
        ]
    )
    return response

# Matches a whole response wrapped in a Markdown code fence, e.g.
# "```sql\nSELECT 1\n```". Group 1 is the fenced body.
_CODE_FENCE = re.compile(r"^\s*```[^\n]*\n(.*?)\n?\s*```\s*$", re.DOTALL)


def _strip_code_fences(text):
    """Return *text* stripped, with a surrounding Markdown code fence removed.

    ``None`` yields an empty string.
    """
    if text is None:
        return ''
    match = _CODE_FENCE.match(text)
    return (match.group(1) if match else text).strip()


def convert_and_get_dataframe(query_string, endpoint_name: str = "databricks-claude-sonnet-4"):
    """Split *query_string* and convert every piece with the LLM endpoint.

    Each entry from split_sql_view_full (CTEs, subqueries, main) is sent to
    convert_query_to_databricks_sql. Any Markdown fence the model adds
    despite the prompt is stripped.

    Returns:
        A Spark DataFrame with string columns ``name``, ``original`` and
        ``converted``, one row per parsed entry.
    """
    parsed, _mode, _full_main_query = split_sql_view_full(query_string)
    rows = []
    for entry in parsed:
        response = convert_query_to_databricks_sql(entry['query'], endpoint_name)
        converted_query = _strip_code_fences(response.choices[0].message.content)
        rows.append((entry['name'], entry['query'], converted_query))

    # Explicit schema: fixed column order, and no dependence on type inference.
    return spark.createDataFrame(rows, schema="name string, original string, converted string")


In [0]:
from pyspark.sql.functions import col, when
import sqlparse

def assemble_final_query(converted_df):
    """Substitute every converted piece into the converted main query.

    In the ``converted`` text of the row named ``main``, each ``<<name>>``
    placeholder is replaced by the ``converted`` text of the row with that
    name. All rows are collected once and the main row is updated once,
    rather than running a Spark job and adding a plan layer per piece.

    Args:
        converted_df: A DataFrame with ``name``, ``original`` and
            ``converted`` columns (the shape convert_and_get_dataframe builds).

    Returns:
        ``converted_df`` restricted to ``name``, ``original`` and
        ``converted``, with the main row's ``converted`` value assembled. The
        DataFrame is returned unchanged when there is no usable main row.
    """
    rows = converted_df.select("name", "converted").collect()
    main_converted = next(
        (row['converted'] for row in rows if row['name'] == 'main'), None
    )

    if main_converted is not None:
        for row in rows:
            # Skip the main row itself. Rows with a null conversion are also
            # skipped, which leaves their placeholder visible rather than
            # crashing in str.replace.
            if row['name'] == 'main' or row['converted'] is None:
                continue
            main_converted = main_converted.replace(f"<<{row['name']}>>", row['converted'])

        converted_df = converted_df.withColumn(
            "converted",
            when(col("name") == "main", main_converted).otherwise(col("converted")),
        )

    return converted_df.select("name", "original", "converted")


def prettify_final(converted_df):
    """Assemble the converted main query and pretty-print it with sqlparse.

    Args:
        converted_df: A DataFrame with ``name``, ``original`` and
            ``converted`` columns (the shape convert_and_get_dataframe builds).

    Returns:
        The assembled main query, re-indented and with upper-case keywords.

    Raises:
        ValueError: if there is no row with ``name = 'main'``.
    """
    assembled = assemble_final_query(
        converted_df.select("name", "original", "converted")
    )
    main_rows = assembled.filter("name = 'main'").select("converted").collect()
    if not main_rows:
        raise ValueError("converted_df has no row with name = 'main'")
    return sqlparse.format(main_rows[0][0], reindent=True, keyword_case='upper')