## Simple Query Parser

This parser deconstructs a complex query into its sub-queries so that each sub-query can be separated out and analyzed or converted on its own.

It looks for SELECT statements, CTEs defined by WITH, and IN or EXISTS clauses to discover sub-queries.

It can also remove the column list from a query that has a long list of columns and replace it with a '*'.  This reduces the query size for an LLM and lets the columns be analyzed separately.

The script will put all of this information into a collection which can be iterated through to call an LLM to analyze or migrate the code and then it can be pieced back together at the end.

So the following query:

```
with cte1 as (
        select cust_id, sum(sales) as sales_sum 
        from orders
        group by cust_id
    ),
    cte2 as (
        select cust_id, sum(expenses) as expenses
        from expenses
        group by cust_id
    )
    select cust_id, sales_sum, expenses
    from cte1
    inner join cte2 on cte1.cust_id = cte2.cust_id
    where cte1.cust_id in (select cust_id from customer_region where region = 'USA')
    and exists (select 1 from sales_person_region where region = 'NYC')
```

will be deconstructed to look like

```

Query 1 (cte1):
    select * 
        from orders
        group by cust_id
Columns:
  - cust_id
  - sum(sales) as sales_sum
----------------------------------------
Query 2 (cte2):
    select *
        from expenses
        group by cust_id
Columns:
  - cust_id
  - sum(expenses) as expenses
----------------------------------------
Query 3 (subquery):
    select * from customer_region where region = 'USA'
Columns:
  - cust_id
----------------------------------------
Query 4 (subquery):
    select * from sales_person_region where region = 'NYC'
Columns:
  - 1
----------------------------------------
Query 5 (main):
    select *
    from cte1
    inner join cte2 on cte1.cust_id = cte2.cust_id
    where cte1.cust_id in (<<query 3>>)
    and exists (<<query 4>>)
Columns:
  - cust_id
  - sales_sum
  - expenses
----------------------------------------
main query:
with
    cte1 as (
        <<query 1>>
    ),
    cte2 as (
        <<query 2>>
    )
   <<query 5>>
```

In [0]:
import re
from typing import Dict, List, Tuple

# Tokens that matter while scanning a select list: single-quoted literals
# (consumed whole so their contents are ignored), parentheses (nesting depth),
# commas (item separators) and the FROM keyword that ends the list.
_SELECT_LIST_TOKEN = re.compile(r"'(?:[^']|'')*'|[(),]|\bfrom\b", re.IGNORECASE)


def extract_columns(select_body: str) -> List[str]:
    """Return the select-list items of the first SELECT in *select_body*.

    The list is split on commas and ends at FROM, but only commas and FROM
    keywords found outside parentheses and outside single-quoted literals
    count.  This keeps items such as ``coalesce(a, b)``,
    ``extract(year from d)`` or a parenthesised scalar subquery in one piece.

    Args:
        select_body: SQL text expected to contain ``select ... from``.

    Returns:
        The stripped select-list items in order, or an empty list when no
        SELECT keyword or no top-level FROM is found.
    """
    head = re.search(r'\bselect\b', select_body, flags=re.IGNORECASE)
    if head is None:
        return []

    columns = []
    depth = 0
    item_start = head.end()
    for token in _SELECT_LIST_TOKEN.finditer(select_body, item_start):
        text = token.group()
        if text == '(':
            depth += 1
        elif text == ')':
            # Clamp so a stray ')' cannot hide the rest of the list.
            depth = max(depth - 1, 0)
        elif depth == 0 and text == ',':
            columns.append(select_body[item_start:token.start()].strip())
            item_start = token.end()
        elif depth == 0 and text.lower() == 'from':
            columns.append(select_body[item_start:token.start()].strip())
            return columns
    # No top-level FROM: there is no select list to report.
    return []

def replace_columns_with_star(query: str) -> str:
    """Replace the select list of the outermost SELECT in *query* with ``*``.

    The select list runs from the first SELECT keyword to the first FROM
    keyword that is outside parentheses and outside single-quoted literals,
    so a ``from`` inside ``extract(year from d)`` or inside a scalar subquery
    does not cut the list short.  Whitespace on both sides of the list is
    kept, and selects nested in parentheses later in the query are left
    untouched.

    Args:
        query: SQL text.

    Returns:
        The query with its outermost select list replaced by ``*``, or the
        query unchanged when no SELECT keyword or no top-level FROM is found.
    """
    head = re.search(r'\bselect\b', query, flags=re.IGNORECASE)
    if head is None:
        return query

    # Quoted literals are matched whole so parentheses or "from" inside them
    # are ignored by the scan below.
    token_pattern = re.compile(r"'(?:[^']|'')*'|[()]|\bfrom\b", re.IGNORECASE)
    depth = 0
    for token in token_pattern.finditer(query, head.end()):
        text = token.group()
        if text == '(':
            depth += 1
        elif text == ')':
            depth = max(depth - 1, 0)
        elif depth == 0 and text.lower() == 'from':
            select_list = query[head.end():token.start()]
            # Preserve the original spacing/newlines around the list.
            lead = select_list[:len(select_list) - len(select_list.lstrip())]
            trail = select_list[len(select_list.rstrip()):]
            return f"{query[:head.end()]}{lead}*{trail}{query[token.start():]}"
    return query

def _matching_paren(text: str, open_idx: int) -> int:
    """Return the index of the ')' closing the '(' at *open_idx*, or -1.

    Parentheses inside single-quoted literals are ignored.
    """
    depth = 0
    in_string = False
    for pos in range(open_idx, len(text)):
        ch = text[pos]
        if ch == "'":
            in_string = not in_string
        elif in_string:
            continue
        elif ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return pos
    return -1


def extract_subqueries(query: str, start_number: int = 2) -> Tuple[str, List[str]]:
    """Pull parenthesised SELECT subqueries out of *query*.

    Every parenthesised group whose content starts with SELECT is replaced by
    a ``(<<query N>>)`` placeholder and its stripped content is collected.
    Groups that are not subqueries are copied through unchanged and are still
    scanned, so a subquery nested inside them is found as well.  The content
    of an extracted subquery is not scanned further.  Single-quoted literals
    are copied verbatim, and an opening parenthesis without a matching close
    is treated as ordinary text.

    Args:
        query: SQL text to scan.
        start_number: Number used in the first placeholder; later
            placeholders count up from it.  The default of 2 keeps the
            numbering callers of the one-argument form already receive.

    Returns:
        A tuple ``(rewritten_query, subqueries)`` where ``subqueries`` lists
        the extracted subquery texts in order of appearance.
    """
    subqueries = []
    pieces = []
    i = 0
    n = len(query)

    while i < n:
        ch = query[i]

        if ch == "'":
            # Copy the whole literal so parentheses inside it are not parsed.
            closing = query.find("'", i + 1)
            if closing == -1:
                pieces.append(query[i:])
                break
            pieces.append(query[i:closing + 1])
            i = closing + 1
            continue

        if ch == '(':
            closing = _matching_paren(query, i)
            if closing != -1:
                content = query[i + 1:closing].strip()
                if re.match(r'select\b', content, flags=re.IGNORECASE):
                    subqueries.append(content)
                    number = start_number + len(subqueries) - 1
                    pieces.append(f"(<<query {number}>>)")
                    i = closing + 1
                    continue
            # Not a subquery (or unbalanced): emit the '(' and keep scanning
            # its content character by character.

        pieces.append(ch)
        i += 1

    return ''.join(pieces), subqueries

def split_sql_view_full(sql: str) -> List[Dict[str, object]]:
    """Split a ``WITH ...`` query into its CTEs, subqueries and main query.

    Each returned entry is a dict with the keys ``'name'`` (the CTE name,
    ``'subquery'`` or ``'main'``), ``'columns'`` (the select-list items) and
    ``'query'`` (the text with its select list replaced by ``*``).  Entries
    are ordered CTEs first, then the subqueries extracted from the main
    query, then the main query itself.  A ``<<query N>>`` placeholder in the
    main query refers to the N-th entry (1-based) of the returned list.

    Args:
        sql: SQL text that starts with the WITH keyword.

    Returns:
        The list of entries described above.

    Raises:
        ValueError: If the text does not start with WITH, or a CTE body has
            unmatched parentheses.
    """
    sql = sql.strip()
    with_keyword = re.match(r'with\b', sql, re.IGNORECASE)
    if not with_keyword:
        raise ValueError("SQL must start with WITH")

    sql_body = sql[with_keyword.end():].lstrip()
    cte_header = re.compile(r'(\w+)\s+as\s+\(', re.IGNORECASE)
    queries = []
    idx = 0
    n = len(sql_body)

    while idx < n:
        match = cte_header.match(sql_body, idx)
        if not match:
            break  # Main query reached
        cte_name = match.group(1)

        # Walk to the parenthesis that closes the CTE body, ignoring
        # parentheses inside single-quoted literals.
        body_start = match.end()
        paren_count = 1
        in_string = False
        end_idx = body_start
        while end_idx < n and paren_count > 0:
            ch = sql_body[end_idx]
            if ch == "'":
                in_string = not in_string
            elif not in_string:
                if ch == '(':
                    paren_count += 1
                elif ch == ')':
                    paren_count -= 1
            end_idx += 1
        if paren_count != 0:
            raise ValueError(f"Unmatched parentheses in CTE '{cte_name}'")

        # end_idx is one past the closing ')', so the body ends at end_idx - 1.
        inner_query = sql_body[body_start:end_idx - 1].strip()
        queries.append({
            'name': cte_name,
            'columns': extract_columns(inner_query),
            'query': replace_columns_with_star(inner_query),
        })

        # Skip the separator between CTEs (any whitespace, including '\r',
        # and commas).
        idx = end_idx
        while idx < n and (sql_body[idx].isspace() or sql_body[idx] == ','):
            idx += 1

    main_query = sql_body[idx:].strip()
    modified_main, subqueries = extract_subqueries(main_query)

    # Renumber the placeholders in order of appearance so that each one
    # points at the position its subquery takes in the returned list, which
    # is right after the CTEs.
    position = len(queries)

    def renumber(_placeholder):
        nonlocal position
        position += 1
        return f"<<query {position}>>"

    modified_main = re.sub(r'<<query \d+>>', renumber, modified_main)

    # Process subqueries extracted from main
    for subquery in subqueries:
        queries.append({
            'name': 'subquery',
            'columns': extract_columns(subquery),
            'query': replace_columns_with_star(subquery),
        })

    # Main query goes last.  Columns are read from the placeholder version so
    # a scalar subquery in the select list cannot end the list early.
    queries.append({
        'name': 'main',
        'columns': extract_columns(modified_main),
        'query': replace_columns_with_star(modified_main),
    })

    return queries

def print_full_queries(sql: str) -> None:
    """Print every query part of *sql*, then a skeleton of the full query.

    The parts come from ``split_sql_view_full``: each is printed with its
    1-based number, its name, its text and its column list.  The skeleton
    that follows shows the WITH clause with each CTE body replaced by a
    ``<<query N>>`` placeholder and the main query replaced by the
    placeholder of the last entry.

    Args:
        sql: SQL text accepted by ``split_sql_view_full``.
    """
    parsed = split_sql_view_full(sql)
    for number, entry in enumerate(parsed, 1):
        print(f"Query {number} ({entry['name']}):")
        print(f"    {entry['query']}")
        print("Columns:")
        for col in entry['columns']:
            print(f"  - {col}")
        print("-" * 40)

    print("main query:")
    # Only the last entry is the main query.  Everything before it is a CTE
    # or a subquery, so slice off exactly one entry; slicing off two would
    # lose the last CTE when the main query has no subqueries.
    ctes = [f"    {q['name']} as (\n        <<query {i+1}>>\n    )"
            for i, q in enumerate(parsed[:-1]) if q['name'] != 'subquery']
    print("with\n" + ",\n".join(ctes))
    print(f"   <<query {len(parsed)}>>")


In [0]:
# Demo: deconstruct a small two-CTE query that has IN and EXISTS sub-queries.
if __name__ == "__main__":
    # The sample is assembled line by line.  The empty first entry and the
    # indentation-only last entry are part of the sample text.
    sql_query = "\n".join([
        "",
        "    with cte1 as (",
        "        select cust_id, sum(sales) as sales_sum ",
        "        from orders",
        "        group by cust_id",
        "    ),",
        "    cte2 as (",
        "        select cust_id, sum(expenses) as expenses",
        "        from expenses",
        "        group by cust_id",
        "    )",
        "    select cust_id, sales_sum, expenses",
        "    from cte1",
        "    inner join cte2 on cte1.cust_id = cte2.cust_id",
        "    where cte1.cust_id in (select cust_id from customer_region where region = 'USA')",
        "    and exists (select 1 from sales_person_region where region = 'NYC')",
        "    ",
    ])

    print_full_queries(sql_query)


In [0]:
def generate_large_test_query_string() -> str:
    """Build a multi-line SQL test query with ten CTEs and a main SELECT.

    The text starts with ``WITH``, defines ``cte1`` to ``cte10`` (each one
    filtering through an ``IN`` subquery), and continues with a main SELECT
    containing a scalar subquery, an ``EXISTS`` inside ``CASE``, a derived
    table join and ``IN``/``EXISTS`` filters.  It ends with
    ``ORDER BY u.user_id;``.

    Returns:
        The query as a single newline-joined string.
    """
    # One CTE definition; "{k}" is replaced by the CTE number.
    cte_template = (
        "    cte{k} AS (\n"
        "        SELECT\n"
        "            user_id,\n"
        "            SUM(metric_{k}) AS metric_sum_{k}\n"
        "        FROM\n"
        "            source_table_{k}\n"
        "        WHERE\n"
        "            user_id IN (\n"
        "                SELECT user_id FROM user_filter_{k} WHERE active = TRUE\n"
        "            )\n"
        "        GROUP BY user_id\n"
        "    )"
    )
    # Joining with ",\n" leaves the last CTE without a trailing comma.
    cte_section = ",\n".join(cte_template.format(k=k) for k in range(1, 11))

    # Main SELECT with subqueries and joins
    main_section = "\n".join((
        "SELECT",
        "    u.user_id,",
        "    cte1.metric_sum_1,",
        "    cte2.metric_sum_2,",
        "    extra_data.extra_val,",
        "    (",
        "        SELECT MAX(score)",
        "        FROM user_scores us",
        "        WHERE us.user_id = u.user_id",
        "    ) AS max_score,",
        "    CASE",
        "        WHEN EXISTS (",
        "            SELECT 1 FROM audit_log al WHERE al.user_id = u.user_id AND al.status = 'flagged'",
        "        ) THEN 'FLAGGED'",
        "        ELSE 'OK'",
        "    END AS audit_status",
        "FROM users u",
        "LEFT JOIN (",
        "    SELECT user_id, COUNT(*) AS extra_val",
        "    FROM extra_events",
        "    GROUP BY user_id",
        ") AS extra_data ON extra_data.user_id = u.user_id",
        "JOIN cte1 ON cte1.user_id = u.user_id",
        "JOIN cte2 ON cte2.user_id = u.user_id",
        "WHERE",
        "    u.region IN (",
        "        SELECT region FROM allowed_regions WHERE region_type = 'premium'",
        "    )",
        "    AND EXISTS (",
        "        SELECT 1 FROM login_events le WHERE le.user_id = u.user_id AND le.success = TRUE",
        "    )",
        "ORDER BY u.user_id;",
    ))

    return "\n".join(("WITH", cte_section, main_section))


In [0]:
query_string = generate_large_test_query_string()

# Preview only the first and the last 100 characters of the generated query.
print(query_string[:100])
print("\n . . . \n")
print(query_string[-100:])
print("\n---------------------------------------------\n")
# Deconstruct and print the complete query.
print_full_queries(query_string)