In [0]:
%pip install -U -qqqq sqlparse
# ^ Installs/upgrades sqlparse into this notebook's Python environment
#   (NOTE(review): presumably required by parser_utils, imported in the next
#   cell — confirm). Consider pinning a version (sqlparse==X.Y.Z) instead of -U
#   so every re-run resolves the same package.
# Restart the Python process so the newly installed package is importable.
# This discards all Python state, so keep this cell first in the notebook.
dbutils.library.restartPython()

In [0]:
import os
import sys

# Folder holding parser_utils.py. Defaults to the notebook's working directory;
# point this at the real folder if the module lives elsewhere.
current_dir = os.getcwd()
module_path = os.path.abspath(os.path.join(current_dir, "."))
# Put module_path first so parser_utils resolves from it, but drop any earlier
# copies first so re-running this cell does not keep growing sys.path.
sys.path[:] = [p for p in sys.path if p != module_path]
sys.path.insert(0, module_path)
from parser_utils import extract_and_replace_subqueries, strip_comments, prettify_final, extract_all_subqueries_to_list


In [0]:
import random

def _case_subquery_col(i, default, subquery_filter=""):
    """Build one CASE column whose THEN branch wraps a scalar subquery in NVL.

    The column contains a nested ``(SELECT ...)`` inside ``NVL(...)`` and an
    inline ``-- story 1234`` line comment, giving the subquery extractor used
    later in this notebook something non-trivial to parse.

    Args:
        i: 1-based column index, used for the ``col{i}`` alias.
        default: SQL literal text used as NVL's fallback, e.g. ``"0"`` or ``"'x'"``.
        subquery_filter: Extra SQL appended to the subquery's WHERE clause.

    Returns:
        The column expression as a multi-line SQL string.
    """
    # The newline after "-- story 1234" matters: it terminates the SQL comment.
    return f"""
    CASE
        WHEN
            ds.TotalDeptSalary > 100000 AND
            ds.TotalDeptSalary < 200000 AND
            NVL(a.emp_class, 'N') NOT IN ('A', 'B') -- story 1234
        THEN
            NVL(
                (
                    SELECT emp_id FROM Employees WHERE DepartmentID = ds.DepartmentID{subquery_filter}
                )
            , {default})
        ELSE 0
    END as col{i}
    """


def generate_complex_query(num_cols=599, seed=None):
    """Build a large, deliberately messy SQL query for exercising the parser.

    The query is a fixed ``WITH ... SELECT ... UNION ALL SELECT`` prefix (CTEs,
    correlated scalar subqueries, ``--`` comments, bare
    ``SELECT region_id FROM regoins WHERE ROWNUM=1`` columns) followed by
    ``num_cols`` generated columns appended to the second UNION ALL branch.
    Each generated column is chosen at random:

    * 5%  CASE + NVL(scalar subquery with ``AND ROWNUM=1``, 0)
    * 5%  CASE + NVL(scalar subquery, 2)
    * 10% CASE + NVL(scalar subquery, 'x')
    * 5%  ``myfunc_from(<random float>)``
    * 5%  ``myfunc_fromchar(<random float>)``
    * 70% string literal ``'col{i}'``

    The SQL is parser input, not executable as written: e.g. the generated CASE
    columns reference an alias ``a`` that no FROM clause defines, and the bare
    ``SELECT region_id ...`` columns are not parenthesized.

    Args:
        num_cols: Number of generated columns (``col1`` .. ``col{num_cols}``).
            Defaults to 599, the previously hard-coded count.
        seed: Optional seed for a private ``random.Random`` so a specific query
            can be regenerated. ``None`` draws from the global ``random``
            module, exactly as before (so ``random.seed(...)`` still applies).

    Returns:
        The complete SQL query string, terminated by ``;``.

    Raises:
        ValueError: If ``num_cols`` is less than 1. The fixed prefix ends with a
            trailing comma, so at least one generated column must follow it.
    """
    if num_cols < 1:
        raise ValueError(f"num_cols must be >= 1, got {num_cols}")
    # The random module itself exposes .random(), so it can stand in for a Random.
    rng = random if seed is None else random.Random(seed)

    base_query = """
WITH
DeptStats AS (
    SELECT
        DepartmentID,
        SUM(Salary) AS TotalDeptSalary,
        COUNT(*) AS NumEmployees,
        AVG(Salary) AS AvgDeptSalary
    FROM
        Employees
    GROUP BY
        DepartmentID
),
EmpProjects AS (
    SELECT
        EmployeeID,
        COUNT(ProjectID) AS ProjectsCompleted,
        MAX(CompletedDate) AS LastProjectDate,
        YEAR(MAX(CompletedDate)) AS LastProjectYear
    FROM
        Projects
    WHERE
        Status = 'Completed'
    GROUP BY
        EmployeeID
)

SELECT
    e.EmployeeID,
    UPPER(e.Name) AS Name,
    d.Name AS Department,
    CASE
        WHEN e.Salary > (
            SELECT AVG(Salary)
            FROM Employees
            WHERE DepartmentID = e.DepartmentID
        ) THEN CONCAT('Above Average (', CAST(e.Salary AS VARCHAR), ')')
        ELSE 'Average or Below'
    END AS SalaryStatus,
    -- some remark here,
    CASE
        WHEN    rsm.investment_type = 'BL'
            AND NVL (psah.acrd_cd, 'N') NOT IN ('Y', 'V') -- story 897300
        THEN
            NVL (
                (SELECT wacoupon
                   FROM stg_wso_pos_acr_ame
                  WHERE     portfolio_fund_id = psah.cal_dt
                        AND asofdate = psah.cal_dt
                        AND asset_primaryud = psah.asset_id
                        AND rec_typ_cd = 'POS'),
                0)
        ELSE
            psah.int_rt
    END
        AS pos_int_it,  
    ep.ProjectsCompleted,
    YEAR(e.HireDate) AS HireYear,
    MONTH(e.HireDate) AS HireMonth,
    COALESCE(ep.LastProjectYear, 'N/A') AS LastProjectYear,
    SELECT region_id FROM regoins WHERE ROWNUM=1
FROM
    Employees e
    JOIN Departments d ON e.DepartmentID = d.DepartmentID
    LEFT JOIN EmpProjects ep ON e.EmployeeID = ep.EmployeeID
WHERE
    e.EmployeeID IN (
        SELECT
            e2.EmployeeID
        FROM
            Employees e2
            JOIN Departments d2 ON e2.DepartmentID = d2.DepartmentID
            LEFT JOIN EmpProjects ep2 ON e2.EmployeeID = ep2.EmployeeID
        WHERE
            e2.Salary > (
                SELECT AVG(Salary)
                FROM Employees
                WHERE DepartmentID = e2.DepartmentID
            )
    )
UNION ALL
SELECT
    NULL AS EmployeeID,
    NULL AS Name,
    d.Name AS Department,
    CONCAT('Department Total: ', CAST(ds.TotalDeptSalary AS VARCHAR)) AS SalaryStatus,
    ds.NumEmployees AS ProjectsCompleted,
    NULL AS HireYear,
    NULL AS HireMonth,
    NULL AS LastProjectYear,
    SELECT region_id FROM regoins WHERE ROWNUM=1,
"""

    cols = []
    for i in range(1, num_cols + 1):
        rnd = rng.random()
        # Each elif already excludes the ranges above it, so only the upper
        # bound needs checking.
        if rnd < 0.05:
            col = _case_subquery_col(i, "0", " AND ROWNUM=1")
        elif rnd < 0.1:
            col = _case_subquery_col(i, "2")
        elif rnd < 0.2:
            col = _case_subquery_col(i, "'x'")
        elif rnd < 0.25:
            col = f"myfunc_from({rnd}) as col{i}"
        elif rnd < 0.3:
            col = f"myfunc_fromchar({rnd}) as col{i}"
        else:
            col = f"'col{i}' as col{i}"
        cols.append(col)
    cols_str = ",\n".join(cols)

    # Generated columns close the second UNION ALL branch, followed by its FROM.
    return f"""
{base_query}
{cols_str}
FROM
    DeptStats ds
    JOIN Departments d ON ds.DepartmentID = d.DepartmentID;
"""

# generate_complex_query draws from the global `random` RNG; seed it so that
# Restart & Run All regenerates the identical query every time.
QUERY_SEED = 42
random.seed(QUERY_SEED)
query_string = generate_complex_query()
print(query_string)


In [0]:
# query_string = """
# select
#     *
# from
#     (select * from (select * from subtable1))  as subquery1
# """

In [0]:
# Pull every subquery out of the generated SQL and load the resulting list into
# a Spark DataFrame (no schema is passed, so Spark infers it from the elements).
all_subqueries = extract_all_subqueries_to_list(query_string)
df_subqueries = spark.createDataFrame(data=all_subqueries)

# Show only each subquery's id and its full text.
display(
    df_subqueries.select("id", "full_sub_query")
)