In [0]:
%run ../parse_sql

In [0]:
def generate_large_test_query_string(num_ctes: int = 10) -> str:
    """Build a synthetic multi-CTE SQL query string for use as test input.

    The text starts with a ``WITH`` clause containing ``num_ctes`` structurally
    identical CTEs named ``cte1`` .. ``cte<num_ctes>``, followed by a main
    ``SELECT`` that joins ``cte1`` and ``cte2`` and contains scalar, ``EXISTS``
    and ``IN`` subqueries plus a derived-table join. The query ends with
    ``ORDER BY u.user_id;``.

    Args:
        num_ctes: Number of CTEs to generate. Defaults to 10. Must be at
            least 2 because the main SELECT references ``cte1`` and ``cte2``.

    Returns:
        The query as a single newline-joined string.

    Raises:
        ValueError: If ``num_ctes`` is smaller than 2.
    """
    if num_ctes < 2:
        raise ValueError(
            f"num_ctes must be >= 2 (the main SELECT joins cte1 and cte2), got {num_ctes}"
        )

    lines = ["WITH"]

    # One block per CTE; every block ends with "    )," so the blocks chain.
    for i in range(1, num_ctes + 1):
        lines.extend([
            f"    cte{i} AS (",
            "        SELECT",
            "            user_id,",
            f"            SUM(metric_{i}) AS metric_sum_{i}",
            "        FROM",
            f"            source_table_{i}",
            "        WHERE",
            "            user_id IN (",
            f"                SELECT user_id FROM user_filter_{i} WHERE active = TRUE",
            "            )",
            "        GROUP BY user_id",
            "    ),",
        ])

    # The last CTE must not be followed by a comma before the main SELECT.
    lines[-1] = lines[-1].rstrip(',')

    # Main SELECT with subqueries and joins
    lines.extend([
        "SELECT",
        "    u.user_id,",
        "    cte1.metric_sum_1,",
        "    cte2.metric_sum_2,",
        "    extra_data.extra_val,",
        "    (",
        "        SELECT MAX(score)",
        "        FROM user_scores us",
        "        WHERE us.user_id = u.user_id",
        "    ) AS max_score,",
        "    CASE",
        "        WHEN EXISTS (",
        "            SELECT 1 FROM audit_log al WHERE al.user_id = u.user_id AND al.status = 'flagged'",
        "        ) THEN 'FLAGGED'",
        "        ELSE 'OK'",
        "    END AS audit_status",
        "FROM users u",
        "LEFT JOIN (",
        "    SELECT user_id, COUNT(*) AS extra_val",
        "    FROM extra_events",
        "    GROUP BY user_id",
        ") AS extra_data ON extra_data.user_id = u.user_id",
        "JOIN cte1 ON cte1.user_id = u.user_id",
        "JOIN cte2 ON cte2.user_id = u.user_id",
        "WHERE",
        "    u.region IN (",
        "        SELECT region FROM allowed_regions WHERE region_type = 'premium'",
        "    )",
        "    AND EXISTS (",
        "        SELECT 1 FROM login_events le WHERE le.user_id = u.user_id AND le.success = TRUE",
        "    )",
    ])

    lines.append("ORDER BY u.user_id;")

    return '\n'.join(lines)


In [0]:
query_string = generate_large_test_query_string()

# Preview only the head and the tail of the generated query.
print(query_string[:100])
print("\n . . . \n")
# Last 100 characters. The previous slice, [:-100], printed everything
# EXCEPT the tail, which defeated the purpose of the "..." separator above.
print(query_string[-100:])
print("\n---------------------------------------------\n")
# print_full_queries is not defined in this notebook; presumably it comes from
# the `%run ../parse_sql` cell.
# NOTE(review): an earlier comment here claimed this call previews only the
# first 1000 characters; the helper's name suggests otherwise — confirm.
print_full_queries(query_string)

In [0]:
# Rebuild the test query, pass it to the conversion helper with the chosen
# serving endpoint, and display the name/original/converted columns.
query_string = generate_large_test_query_string()
converted_df = convert_and_get_dataframe(
    query_string,
    endpoint_name="databricks-claude-3-7-sonnet",
)
display(
    converted_df.select("name", "original", "converted")
)

In [0]:
# Feed the name/original/converted columns to assemble_final_query and show
# whatever it returns.
display(
    assemble_final_query(
        converted_df.select("name", "original", "converted")
    )
)

In [0]:
# Assemble the final query from the name/original/converted columns, run it
# through prettify_final, and print the resulting value.
prettified_value = prettify_final(
    assemble_final_query(
        converted_df.select("name", "original", "converted")
    )
)
print(prettified_value)