# SQL Table Extractor using DSPy

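This notebook demonstrates how to use DSPy to extract the tables and columns referenced by SQL statements in Databricks query history, with columns grouped by how they are used (SELECT, JOIN, filter, GROUP BY, ORDER BY).
The extractor is designed to handle complex SQL, including: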
- Common Table Expressions (CTEs)
- Subqueries
- Joins
- Nested queries
- Temporary tables


## Install Dependencies


In [None]:
%pip install dspy-ai --quiet

In [None]:
# Restart Python kernel to use newly installed packages
# dbutils.library.restartPython()  # Uncomment if running in Databricks

## Setup and Imports


In [None]:
import dspy
from typing import List
import json


## Configure DSPy with Databricks LM


In [None]:
# Initialize Databricks Foundation Model
# You can choose from available models:
# - databricks-meta-llama-3-3-70b-instruct (recommended for complex tasks)
# - databricks-meta-llama-3-1-70b-instruct
# - databricks-dbrx-instruct
# - databricks-claude-opus-4-5 (most powerful, use for complex queries)

LLM_MODEL_NAME = "databricks-claude-opus-4-5"
lm = dspy.LM(model=f"databricks/{LLM_MODEL_NAME}", cache=False)
dspy.configure(lm=lm)

print(f"✓ Configured DSPy with model: {LLM_MODEL_NAME}")
print(f"✓ Model endpoint: databricks/{LLM_MODEL_NAME}")


## Define DSPy Signatures for Table Extraction


In [None]:
class SQLTableExtractor(dspy.Signature):
    """Extract all table names and column names with detailed usage metadata from a SQL statement.
    
    Tables include:
    - Base tables in FROM clauses
    - Tables in JOIN clauses
    - Tables in subqueries
    - CTE (Common Table Expression) source tables
    - Tables in nested queries
    
    Columns are categorized by usage:
    - SELECT columns (output columns, aggregations)
    - JOIN columns (columns in ON/USING clauses)
    - FILTER columns (columns in WHERE/HAVING conditions)
    - GROUP BY columns
    - ORDER BY columns
    
    Important: 
    - Exclude CTE names themselves (they are temporary)
    - Include fully qualified names when present (catalog.schema.table, table.column)
    - Return unique names only per category
    - For columns, include table prefix if present (e.g., "customers.customer_id")
    - A column may appear in multiple categories
    """
    
    sql_statement: str = dspy.InputField(desc="The SQL statement to analyze")
    tables: List[str] = dspy.OutputField(desc="List of unique table names referenced in the SQL, excluding CTE names. Include full qualification (catalog.schema.table) when present.")
    columns: List[str] = dspy.OutputField(desc="List of ALL unique column names referenced anywhere in the SQL. Include table prefix when present (e.g., 'table.column').")
    join_columns: List[str] = dspy.OutputField(desc="List of columns used in JOIN conditions (ON/USING clauses). Include table prefix (e.g., 'table.column').")
    filter_columns: List[str] = dspy.OutputField(desc="List of columns used in WHERE and HAVING filter conditions. Include table prefix (e.g., 'table.column').")
    select_columns: List[str] = dspy.OutputField(desc="List of columns in SELECT clause, including those in aggregate functions. Include table prefix (e.g., 'table.column').")
    groupby_columns: List[str] = dspy.OutputField(desc="List of columns in GROUP BY clause. Include table prefix (e.g., 'table.column').")
    orderby_columns: List[str] = dspy.OutputField(desc="List of columns in ORDER BY clause. Include table prefix (e.g., 'table.column').")
    cte_names: List[str] = dspy.OutputField(desc="List of CTE (Common Table Expression) names defined in the query (these are temporary, not actual tables)")
    explanation: str = dspy.OutputField(desc="Brief explanation of how tables and columns were identified and categorized in this query")
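
The docstring asks for unique names and for CTE names to be kept out of `tables`, but an LLM may not follow every rule on every call. Here is a minimal post-processing sketch. It assumes you first convert the list-valued output fields into a plain dict. `dedupe_preserving_order` and `normalize_extraction` are hypothetical helpers, not part of DSPy, that enforce those two rules deterministically.

```python
from typing import Dict, List


def dedupe_preserving_order(names: List[str]) -> List[str]:
    """Strip whitespace, drop blanks and exact duplicates, keep first-seen order."""
    seen = set()
    unique = []
    for name in names:
        cleaned = name.strip()
        if cleaned and cleaned not in seen:
            seen.add(cleaned)
            unique.append(cleaned)
    return unique


def normalize_extraction(fields: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Apply the signature's 'unique names' and 'exclude CTE names' rules.

    `fields` is assumed to use the same keys as the list outputs of
    SQLTableExtractor, e.g. 'tables', 'cte_names', 'join_columns'.
    """
    normalized = {key: dedupe_preserving_order(values) for key, values in fields.items()}
    # CTE names are temporary result sets, so drop them from 'tables' (case-insensitive)
    cte_lower = {name.lower() for name in normalized.get("cte_names", [])}
    normalized["tables"] = [
        table for table in normalized.get("tables", []) if table.lower() not in cte_lower
    ]
    return normalized


example = {
    "tables": ["catalog.sales.orders", "customer_orders", "catalog.sales.orders"],
    "cte_names": ["customer_orders"],
    "join_columns": ["o.customer_id", " o.customer_id "],
}
print(normalize_extraction(example))
# {'tables': ['catalog.sales.orders'], 'cte_names': ['customer_orders'], 'join_columns': ['o.customer_id']}
```

Running this after every call is cheap. It keeps later usage counts from double-counting a name or treating a CTE as a physical table.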


## Create DSPy Module for Table Extraction

In [None]:
class TableExtractorModule(dspy.Module):
    """Module to extract tables and categorized columns from SQL statements using Chain of Thought reasoning."""
    
    def __init__(self):
        super().__init__()
        self.extract = dspy.ChainOfThought(SQLTableExtractor)
    
    def forward(self, sql_statement: str):
        result = self.extract(sql_statement=sql_statement)
        return result


## Query Databricks Query History


In [None]:
%sql
SELECT 
    statement_id,
    statement_text,
    executed_by_user_id,
    start_time,
    end_time,
    statement_type
FROM system.query.history
WHERE 
    statement_type IN ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE', 'CREATE_TABLE_AS_SELECT')
    AND statement_text IS NOT NULL
    AND LENGTH(statement_text) > 50
    AND start_time >= CURRENT_DATE - INTERVAL '1' DAYS
ORDER BY start_time DESC
LIMIT 10


## Store Query Results in a DataFrame


In [None]:
# Get the query results from the previous cell
query_history_df = _sqldf  # _sqldf contains the last SQL query result
print(f"Retrieved {query_history_df.count()} queries from history")
display(query_history_df)
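
Query history often repeats the same statement. Several of the sample results later in this notebook come from near-identical queries against one table, and every extractor call is an LLM request. Here is a minimal sketch that assumes each row supports `row["statement_text"]` and `row["statement_id"]` lookups, as Spark `Row` objects do. `normalize_sql` and `group_duplicate_statements` are hypothetical helpers. They group statements by whitespace- and case-normalized text so each distinct SQL needs only one call. Lowercasing also folds string literals, so treat the grouping as approximate.

```python
import re
from typing import Dict, Iterable, List, Mapping


def normalize_sql(text: str) -> str:
    """Collapse runs of whitespace and lowercase, so trivially different copies match."""
    return re.sub(r"\s+", " ", text).strip().lower()


def group_duplicate_statements(rows: Iterable[Mapping[str, str]]) -> Dict[str, List[str]]:
    """Map each distinct normalized statement to the statement_ids that ran it.

    Dict insertion order follows the first occurrence of each statement.
    """
    groups: Dict[str, List[str]] = {}
    for row in rows:
        key = normalize_sql(row["statement_text"])
        groups.setdefault(key, []).append(row["statement_id"])
    return groups


rows = [
    {"statement_id": "a", "statement_text": "SELECT x FROM t  WHERE y = 1"},
    {"statement_id": "b", "statement_text": "select x\nFROM t WHERE y = 1"},
    {"statement_id": "c", "statement_text": "SELECT z FROM u"},
]
groups = group_duplicate_statements(rows)
print(len(groups))  # 2 distinct statements to send to the extractor
```

In the notebook you could pass `query_history_df.collect()` as `rows`, run the extractor once per key, and copy each result to every `statement_id` in that group.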

In [None]:
# Create the table extractor module
table_extractor = TableExtractorModule()

## Test with Sample Complex SQL Queries


In [21]:
# Test with a complex SQL example
test_sql = """
WITH customer_orders AS (
    SELECT 
        c.customer_id,
        c.customer_name,
        o.order_id,
        o.order_date
    FROM catalog.sales.customers c
    JOIN catalog.sales.orders o ON c.customer_id = o.customer_id
    WHERE o.order_date >= '2024-01-01'
),
order_details AS (
    SELECT 
        od.order_id,
        od.product_id,
        p.product_name,
        od.quantity * od.unit_price as total_amount
    FROM catalog.sales.order_details od
    JOIN catalog.products.product_catalog p ON od.product_id = p.product_id
)
SELECT 
    co.customer_name,
    od.product_name,
    SUM(od.total_amount) as total_spent
FROM customer_orders co
JOIN order_details od ON co.order_id = od.order_id
WHERE od.product_name IN (
    SELECT product_name 
    FROM catalog.products.featured_products
    WHERE featured_date >= '2024-01-01'
)
GROUP BY co.customer_name, od.product_name
ORDER BY total_spent DESC
"""

print("Testing with complex SQL query...\n")
result = table_extractor(sql_statement=test_sql)

print("="*80)
print("EXTRACTION RESULTS")
print("="*80)
print(f"\nActual Tables Used ({len(result.tables)}):")
for table in result.tables:
    print(f"  - {table}")

print(f"\nAll Columns Referenced ({len(result.columns)}):")
for column in result.columns:
    print(f"  - {column}")

print(f"\nJOIN Columns ({len(result.join_columns)}):")
for column in result.join_columns:
    print(f"  - {column}")

print(f"\nFILTER Columns (WHERE/HAVING) ({len(result.filter_columns)}):")
for column in result.filter_columns:
    print(f"  - {column}")

print(f"\nSELECT Columns ({len(result.select_columns)}):")
for column in result.select_columns:
    print(f"  - {column}")

print(f"\nGROUP BY Columns ({len(result.groupby_columns)}):")
for column in result.groupby_columns:
    print(f"  - {column}")

print(f"\nORDER BY Columns ({len(result.orderby_columns)}):")
for column in result.orderby_columns:
    print(f"  - {column}")

print(f"\nCTE Names (temporary) ({len(result.cte_names)}):")
for cte in result.cte_names:
    print(f"  - {cte}")

print("\nExplanation:")
print(f"  {result.explanation}")
print("="*80)


Testing with complex SQL query...

EXTRACTION RESULTS

Actual Tables Used (5):
  - catalog.sales.customers
  - catalog.sales.orders
  - catalog.sales.order_details
  - catalog.products.product_catalog
  - catalog.products.featured_products

All Columns Referenced (18):
  - c.customer_id
  - c.customer_name
  - o.order_id
  - o.order_date
  - o.customer_id
  - od.order_id
  - od.product_id
  - p.product_name
  - od.quantity
  - od.unit_price
  - p.product_id
  - co.customer_name
  - od.product_name
  - od.total_amount
  - co.order_id
  - product_name
  - featured_date
  - total_spent

JOIN Columns (6):
  - c.customer_id
  - o.customer_id
  - od.product_id
  - p.product_id
  - co.order_id
  - od.order_id

FILTER Columns (WHERE/HAVING) (4):
  - o.order_date
  - od.product_name
  - product_name
  - featured_date

SELECT Columns (14):
  - c.customer_id
  - c.customer_name
  - o.order_id
  - o.order_date
  - od.order_id
  - od.product_id
  - p.product_name
  - od.quantity
  - od.uni
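
Checking the printed lists by eye works for one query. A small hand-labeled set makes regressions visible when you change models or prompts. Here is a minimal sketch that assumes you write the expected tables for each test query by hand. `set_precision_recall` is a hypothetical helper. The expected set below contains the five physical tables in `test_sql`, with the CTEs `customer_orders` and `order_details` excluded per the signature's rules. `predicted_tables` is made-up output used only to show a miss.

```python
from typing import Iterable, Tuple


def set_precision_recall(predicted: Iterable[str], expected: Iterable[str]) -> Tuple[float, float]:
    """Compare extracted names with a hand-labeled set (case-insensitive).

    By convention an empty prediction has precision 1.0 and an empty
    ground truth has recall 1.0.
    """
    pred = {p.strip().lower() for p in predicted}
    gold = {e.strip().lower() for e in expected}
    true_positives = len(pred & gold)
    precision = true_positives / len(pred) if pred else 1.0
    recall = true_positives / len(gold) if gold else 1.0
    return precision, recall


# Ground truth for test_sql: physical tables only, CTE names excluded
expected_tables = {
    "catalog.sales.customers",
    "catalog.sales.orders",
    "catalog.sales.order_details",
    "catalog.products.product_catalog",
    "catalog.products.featured_products",
}

# Hypothetical model output that misses three tables and wrongly includes a CTE name
predicted_tables = ["catalog.sales.customers", "catalog.sales.orders", "customer_orders"]
precision, recall = set_precision_recall(predicted_tables, expected_tables)
print(f"precision={precision:.2f} recall={recall:.2f}")  # precision=0.67 recall=0.40
```

With a real run you would call `set_precision_recall(result.tables, expected_tables)`. The same helper works for any of the column categories if you label those too.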

In [None]:
# Process each query from history
results = []

for row in query_history_df.collect():
    statement_id = row['statement_id']
    statement_text = row['statement_text']
    executed_by_user_id = row['executed_by_user_id']
    
    print(f"\nProcessing query: {statement_id}")
    print(f"User: {executed_by_user_id}")
    print(f"SQL Preview: {statement_text[:100]}...")
    
    try:
        # Extract tables using DSPy
        extraction = table_extractor(sql_statement=statement_text)
        
        result_entry = {
            'statement_id': statement_id,
            'executed_by_user_id': executed_by_user_id,
            'tables': extraction.tables,
            'columns': extraction.columns,
            'join_columns': extraction.join_columns,
            'filter_columns': extraction.filter_columns,
            'select_columns': extraction.select_columns,
            'groupby_columns': extraction.groupby_columns,
            'orderby_columns': extraction.orderby_columns,
            'cte_names': extraction.cte_names,
            'explanation': extraction.explanation,
            'statement_preview': statement_text[:200]
        }
        results.append(result_entry)
        
        print(f"  ✓ Found {len(extraction.tables)} tables, {len(extraction.columns)} columns")
        print(f"    - {len(extraction.join_columns)} join cols, {len(extraction.filter_columns)} filter cols")
        
    except Exception as e:
        print(f"  ✗ Error processing query: {str(e)}")
        results.append({
            'statement_id': statement_id,
            'executed_by_user_id': executed_by_user_id,
            'tables': [],
            'columns': [],
            'join_columns': [],
            'filter_columns': [],
            'select_columns': [],
            'groupby_columns': [],
            'orderby_columns': [],
            'cte_names': [],
            'explanation': f"Error: {str(e)}",
            'statement_preview': statement_text[:200]
        })

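n_failed = sum(1 for r in results if r['explanation'].startswith("Error: "))
print(f"\n\nProcessed {len(results)} queries: {len(results) - n_failed} succeeded, {n_failed} failed.")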
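
The loop calls the model one statement at a time, so wall-clock time grows linearly with the number of queries. Here is a minimal sketch built on the standard library's `concurrent.futures`. It assumes the extractor is safe to call from several threads and that your model endpoint's rate limits allow a few concurrent requests. `extract_all` is a hypothetical helper, and `fake_extract` stands in for `table_extractor` so the example runs without a model.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Tuple


def extract_all(
    extract: Callable[[str], Any],
    statements: List[Tuple[str, str]],
    max_workers: int = 4,
) -> List[Dict[str, Any]]:
    """Run `extract` over (statement_id, sql) pairs concurrently.

    Results come back in input order; failures are recorded, not raised.
    """

    def run_one(item: Tuple[str, str]) -> Dict[str, Any]:
        statement_id, sql = item
        try:
            return {"statement_id": statement_id, "result": extract(sql), "error": None}
        except Exception as exc:  # keep going if one statement fails
            return {"statement_id": statement_id, "result": None, "error": str(exc)}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in the same order as `statements`
        return list(pool.map(run_one, statements))


def fake_extract(sql: str) -> List[str]:
    """Stand-in for table_extractor: fail on empty input, else return the last token."""
    if not sql.strip():
        raise ValueError("empty statement")
    return [sql.split()[-1]]


out = extract_all(fake_extract, [("s1", "SELECT * FROM t1"), ("s2", "   ")])
print(out)
```

Against the real module you might pass `lambda sql: table_extractor(sql_statement=sql)` and `[(r['statement_id'], r['statement_text']) for r in query_history_df.collect()]`. Keep `max_workers` small until you know the endpoint's limits.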


In [15]:
display(results)

[{'statement_id': 'a03cf69c-081e-4766-bd9a-7ddb848c766c',
  'executed_by_user_id': '7056949509210167',
  'tables': ['`{cfg.catalog}`.`{cfg.schema}`.columns'],
  'columns': ['column_name', 'data_type', 'asset_urn'],
  'join_columns': [],
  'filter_columns': ['asset_urn'],
  'select_columns': ['column_name', 'data_type'],
  'groupby_columns': [],
  'orderby_columns': [],
  'cte_names': [],
  'explanation': "This is a Spark SQL query embedded in Python code. The query selects from a single table with a parameterized fully qualified name (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause references two columns: `column_name` (wrapped in the lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on the `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query.",
  'statement_preview': 'rows = spark.sql(\n    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = \'{urn.
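
The statement preview above shows that this history entry is Python code. As a result, the extracted table name still contains f-string placeholders (`{cfg.catalog}`.`{cfg.schema}`.columns) rather than a real catalog and schema. Here is a minimal sketch that assumes you want to remove backtick quoting and set such unresolved names aside before computing usage statistics. `normalize_table_name`, `is_unresolved`, and `split_resolved` are hypothetical helpers. The pattern only recognizes `{...}`-style placeholders, and `main.sales.orders` is a made-up name.

```python
import re
from typing import List, Tuple

# Matches Python format-string placeholders such as {cfg.catalog}
PLACEHOLDER = re.compile(r"\{[^{}]*\}")


def normalize_table_name(name: str) -> str:
    """Remove backtick quoting (assumes identifiers contain no escaped backticks)."""
    return name.strip().replace("`", "")


def is_unresolved(name: str) -> bool:
    """True if the name still contains a {placeholder} that was never filled in."""
    return bool(PLACEHOLDER.search(name))


def split_resolved(names: List[str]) -> Tuple[List[str], List[str]]:
    """Partition table names into (resolved, unresolved) after normalization."""
    resolved, unresolved = [], []
    for name in names:
        cleaned = normalize_table_name(name)
        (unresolved if is_unresolved(cleaned) else resolved).append(cleaned)
    return resolved, unresolved


resolved, unresolved = split_resolved(
    ["`{cfg.catalog}`.`{cfg.schema}`.columns", "`main`.`sales`.`orders`"]  # second name is hypothetical
)
print(resolved)    # ['main.sales.orders']
print(unresolved)  # ['{cfg.catalog}.{cfg.schema}.columns']
```

Counting unresolved names separately keeps them from dominating the table-usage ranking, while still recording that code-generated SQL touched those tables.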

## Sample: Aggregate Table and Column Usage Statistics


In [22]:
# Count table and column usage across all queries
from collections import Counter

all_tables = []
all_columns = []
all_join_columns = []
all_filter_columns = []
all_select_columns = []
all_groupby_columns = []
all_orderby_columns = []

for result in results:
    all_tables.extend(result['tables'])
    all_columns.extend(result['columns'])
    all_join_columns.extend(result['join_columns'])
    all_filter_columns.extend(result['filter_columns'])
    all_select_columns.extend(result['select_columns'])
    all_groupby_columns.extend(result['groupby_columns'])
    all_orderby_columns.extend(result['orderby_columns'])

table_usage = Counter(all_tables)
column_usage = Counter(all_columns)
join_column_usage = Counter(all_join_columns)
filter_column_usage = Counter(all_filter_columns)
select_column_usage = Counter(all_select_columns)
groupby_column_usage = Counter(all_groupby_columns)
orderby_column_usage = Counter(all_orderby_columns)

print("\n" + "="*80)
print("TABLE USAGE STATISTICS")
print("="*80)
print(f"\nTotal unique tables referenced: {len(table_usage)}")
print(f"Total table references: {sum(table_usage.values())}")
print(f"\nTop 10 Most Referenced Tables:")
print("-"*80)

for table, count in table_usage.most_common(10):
    print(f"  {count:3d}x  {table}")

print("="*80)

print("\n" + "="*80)
print("COLUMN USAGE STATISTICS - ALL COLUMNS")
print("="*80)
print(f"\nTotal unique columns referenced: {len(column_usage)}")
print(f"\nTotal column references: {sum(column_usage.values())}")
print(f"\nTop 20 Most Referenced Columns:")
print("-"*80)

for column, count in column_usage.most_common(20):
    print(f"  {count:3d}x  {column}")

print("="*80)

print("\n" + "="*80)
print(" JOIN COLUMN USAGE STATISTICS")
print("="*80)
print(f"\nTotal unique join columns: {len(join_column_usage)}")
print(f"\nTotal join column references: {sum(join_column_usage.values())}")
print(f"\nTop 10 Most Used Join Columns:")
print("-"*80)

for column, count in join_column_usage.most_common(10):
    print(f"  {count:3d}x  {column}")

print("="*80)

print("\n" + "="*80)
print(" FILTER COLUMN USAGE STATISTICS (WHERE/HAVING)")
print("="*80)
print(f"\nTotal unique filter columns: {len(filter_column_usage)}")
print(f"\nTotal filter column references: {sum(filter_column_usage.values())}")
print(f"\nTop 10 Most Used Filter Columns:")
print("-"*80)

for column, count in filter_column_usage.most_common(10):
    print(f"  {count:3d}x  {column}")

print("="*80)


TABLE USAGE STATISTICS

Total unique tables referenced: 2
Total table references: 8

Top 10 Most Referenced Tables:
--------------------------------------------------------------------------------
    6x  `{cfg.catalog}`.`{cfg.schema}`.columns
    2x  information_schema.table_tags

COLUMN USAGE STATISTICS - ALL COLUMNS

Total unique columns referenced: 7

Total column references: 26

Top 20 Most Referenced Columns:
--------------------------------------------------------------------------------
    6x  column_name
    6x  data_type
    6x  asset_urn
    2x  catalog_name
    2x  schema_name
    2x  table_name
    2x  tag_name

 JOIN COLUMN USAGE STATISTICS

Total unique join columns: 0

Total join column references: 0

Top 10 Most Used Join Columns:
--------------------------------------------------------------------------------

 FILTER COLUMN USAGE STATISTICS (WHERE/HAVING)

Total unique filter columns: 4

Total filter column references: 12

Top 10 Most Used Filter Columns:
---------
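
The aggregation cell repeats the same collect/count/print pattern for each usage category. Here is a more compact sketch that assumes `results` is the list of dicts built earlier, with one dict per statement and list-valued keys such as `'tables'` and `'join_columns'`. It uses a single dictionary of `collections.Counter` objects keyed by category, plus one reporting function. `count_usage` and `print_top` are hypothetical names, and `sample_results` is made-up data.

```python
from collections import Counter
from typing import Dict, Iterable, List, Mapping, Sequence

CATEGORIES = (
    "tables", "columns", "join_columns", "filter_columns",
    "select_columns", "groupby_columns", "orderby_columns",
)


def count_usage(results: Iterable[Mapping[str, List[str]]],
                categories: Sequence[str] = CATEGORIES) -> Dict[str, Counter]:
    """Build one Counter per category across all analyzed statements."""
    usage = {category: Counter() for category in categories}
    for result in results:
        for category in categories:
            usage[category].update(result.get(category, []))
    return usage


def print_top(title: str, counter: Counter, top_n: int = 10) -> None:
    """Print totals and the most common entries, mirroring the cell's layout."""
    print("=" * 80)
    print(title)
    print("=" * 80)
    print(f"Total unique: {len(counter)}")
    print(f"Total references: {sum(counter.values())}")
    for name, count in counter.most_common(top_n):
        print(f"  {count:3d}x  {name}")


sample_results = [
    {"tables": ["a.b.t1"], "join_columns": ["t1.id"]},
    {"tables": ["a.b.t1", "a.b.t2"], "join_columns": []},
]
usage = count_usage(sample_results)
print_top("TABLE USAGE STATISTICS", usage["tables"])
```

Here `usage["tables"]` plays the role of `table_usage`, `usage["join_columns"]` that of `join_column_usage`, and so on. Adding a category then means adding one string to `CATEGORIES`.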

## Save Results to Delta Table (Optional)


In [None]:
# Convert results to Spark DataFrames and (optionally) save them
from pyspark.sql import Row

# Map each per-usage result key to the usage_type label stored in the output
USAGE_TYPES = {
    'select_columns': 'SELECT',
    'join_columns': 'JOIN',
    'filter_columns': 'FILTER',
    'groupby_columns': 'GROUP_BY',
    'orderby_columns': 'ORDER_BY',
}

# Flatten results: one row per table reference, one row per (column, usage type)
flattened_table_results = []
flattened_column_results = []

for result in results:
    # Table references
    for table in result['tables']:
        flattened_table_results.append(Row(
            statement_id=result['statement_id'],
            executed_by_user_id=result['executed_by_user_id'],
            table_name=table,
            explanation=result['explanation']
        ))

    # Column references, one row per usage category
    for key, usage_type in USAGE_TYPES.items():
        for column in result[key]:
            flattened_column_results.append(Row(
                statement_id=result['statement_id'],
                executed_by_user_id=result['executed_by_user_id'],
                column_name=column,
                usage_type=usage_type,
                explanation=result['explanation']
            ))

# Create and display table references DataFrame
if flattened_table_results:
    tables_df = spark.createDataFrame(flattened_table_results)
    
    print(f"\nCreated Tables DataFrame with {tables_df.count()} table references")
    display(tables_df)
    
    # Optionally save to a Delta table
    # tables_df.write.mode("overwrite").saveAsTable("your_catalog.your_schema.query_table_analysis")
else:
    print("No tables extracted from queries.")

# Create and display column references DataFrame with usage type
if flattened_column_results:
    columns_df = spark.createDataFrame(flattened_column_results)
    
    print(f"\nCreated Columns DataFrame with {columns_df.count()} column references (categorized by usage)")
    display(columns_df)
    
    # Show breakdown by usage type
    print("\nColumn Usage Type Distribution:")
    columns_df.groupBy("usage_type").count().orderBy("count", ascending=False).show()
    
    # Optionally save to a Delta table
    # columns_df.write.mode("overwrite").saveAsTable("your_catalog.your_schema.query_column_analysis")
else:
    print("No columns extracted from queries.")


\nüìä Created Tables DataFrame with 8 table references


| # | statement_id | executed_by_user_id | table_name | explanation |
|---|---|---|---|---|
| 0 | a03cf69c-081e-4766-bd9a-7ddb848c766c | 7056949509210167 | `{cfg.catalog}`.`{cfg.schema}`.columns | This is a Spark SQL query embedded in Python code. The query selects from a single table with a parameterized fully qualified name (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause references two columns: `column_name` (wrapped in the lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on the `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 1 | 63de4f2e-0b0b-4de0-90b2-a085c88c95f9 | 7056949509210167 | `{cfg.catalog}`.`{cfg.schema}`.columns | This is a Spark SQL query embedded in Python code. The query selects from a table named "columns" with a parameterized catalog and schema (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause retrieves `column_name` (wrapped in a lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 2 | deedd4c1-446f-48fa-bbf6-b20d33540c60 | 7056949509210167 | `{cfg.catalog}`.`{cfg.schema}`.columns | This is a Spark SQL query embedded in Python code. The query selects from a parameterized table `{cfg.catalog}`.`{cfg.schema}`.columns (a fully qualified table reference where catalog and schema are Python variables). The SELECT clause retrieves `column_name` (transformed with lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 3 | 39a8d5f6-d593-49cd-8a28-043494d3d907 | 6102504452920701 | information_schema.table_tags | This is PySpark DataFrame code rather than raw SQL, but it performs SQL-like operations. The code reads from the `information_schema.table_tags` system table (the catalog name is parameterized). Three filter conditions are applied on columns `catalog_name`, `schema_name`, and `table_name` (equivalent to WHERE clauses). The select operation retrieves only `table_name` and `tag_name` columns. The `distinct()` call is equivalent to SELECT DISTINCT. No joins, GROUP BY, or ORDER BY operations are present in this code. |
| 4 | d1a797fe-e9ff-43c0-9ec2-8ebfe9ecad1e | 7056949509210167 | `{cfg.catalog}`.`{cfg.schema}`.columns | This is a Spark SQL query embedded in Python code. The query selects from a parameterized table `{cfg.catalog}`.`{cfg.schema}`.columns (a fully qualified table reference where catalog and schema are Python variables). The SELECT clause retrieves `column_name` (transformed with lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 5 | 96b12828-899b-46a8-9c1a-e29c7dd36cb8 | 3021682792699784 | information_schema.table_tags | This is PySpark DataFrame code rather than raw SQL, but it performs SQL-like operations. The code reads from the `information_schema.table_tags` system table (the catalog name is parameterized). Three columns are used in filter conditions (equivalent to WHERE clause): `catalog_name`, `schema_name`, and `table_name`. Two columns are selected in the output: `table_name` and `tag_name`. The `.distinct()` operation is equivalent to SELECT DISTINCT. There are no JOINs, GROUP BY, or ORDER BY operations in this code. |
| 6 | 7e5f0659-9794-43c9-a7a6-4c5f24041637 | 7056949509210167 | `{cfg.catalog}`.`{cfg.schema}`.columns | This is a Spark SQL query embedded in Python code. The query selects from a parameterized table `{cfg.catalog}`.`{cfg.schema}`.columns (a fully qualified table reference where catalog and schema are Python variables). The SELECT clause retrieves `column_name` (transformed with lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 7 | 9ec3f6da-8a37-44ab-877c-345c7920e278 | 7056949509210167 | `{cfg.catalog}`.`{cfg.schema}`.columns | This is a Spark SQL query embedded in Python code. The query selects from a single table with a parameterized fully qualified name (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause references two columns: `column_name` (wrapped in the lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn`. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |


\nüìä Created Columns DataFrame with 28 column references (categorized by usage)


| # | statement_id | executed_by_user_id | column_name | usage_type | explanation |
|---|---|---|---|---|---|
| 0 | a03cf69c-081e-4766-bd9a-7ddb848c766c | 7056949509210167 | column_name | SELECT | This is a Spark SQL query embedded in Python code. The query selects from a single table with a parameterized fully qualified name (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause references two columns: `column_name` (wrapped in the lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on the `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 1 | a03cf69c-081e-4766-bd9a-7ddb848c766c | 7056949509210167 | data_type | SELECT | This is a Spark SQL query embedded in Python code. The query selects from a single table with a parameterized fully qualified name (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause references two columns: `column_name` (wrapped in the lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on the `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 2 | a03cf69c-081e-4766-bd9a-7ddb848c766c | 7056949509210167 | asset_urn | FILTER | This is a Spark SQL query embedded in Python code. The query selects from a single table with a parameterized fully qualified name (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause references two columns: `column_name` (wrapped in the lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on the `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 3 | 63de4f2e-0b0b-4de0-90b2-a085c88c95f9 | 7056949509210167 | column_name | SELECT | This is a Spark SQL query embedded in Python code. The query selects from a table named "columns" with a parameterized catalog and schema (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause retrieves `column_name` (wrapped in a lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 4 | 63de4f2e-0b0b-4de0-90b2-a085c88c95f9 | 7056949509210167 | data_type | SELECT | This is a Spark SQL query embedded in Python code. The query selects from a table named "columns" with a parameterized catalog and schema (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause retrieves `column_name` (wrapped in a lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 5 | 63de4f2e-0b0b-4de0-90b2-a085c88c95f9 | 7056949509210167 | asset_urn | FILTER | This is a Spark SQL query embedded in Python code. The query selects from a table named "columns" with a parameterized catalog and schema (`{cfg.catalog}`.`{cfg.schema}`.columns). The SELECT clause retrieves `column_name` (wrapped in a lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 6 | deedd4c1-446f-48fa-bbf6-b20d33540c60 | 7056949509210167 | column_name | SELECT | This is a Spark SQL query embedded in Python code. The query selects from a parameterized table `{cfg.catalog}`.`{cfg.schema}`.columns (a fully qualified table reference where catalog and schema are Python variables). The SELECT clause retrieves `column_name` (transformed with lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 7 | deedd4c1-446f-48fa-bbf6-b20d33540c60 | 7056949509210167 | data_type | SELECT | This is a Spark SQL query embedded in Python code. The query selects from a parameterized table `{cfg.catalog}`.`{cfg.schema}`.columns (a fully qualified table reference where catalog and schema are Python variables). The SELECT clause retrieves `column_name` (transformed with lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 8 | deedd4c1-446f-48fa-bbf6-b20d33540c60 | 7056949509210167 | asset_urn | FILTER | This is a Spark SQL query embedded in Python code. The query selects from a parameterized table `{cfg.catalog}`.`{cfg.schema}`.columns (a fully qualified table reference where catalog and schema are Python variables). The SELECT clause retrieves `column_name` (transformed with lower() function and aliased as 'n') and `data_type`. The WHERE clause filters on `asset_urn` column. There are no JOINs, GROUP BY, ORDER BY clauses, or CTEs in this simple query. |
| 9 | 39a8d5f6-d593-49cd-8a28-043494d3d907 | 6102504452920701 | table_name | SELECT | This is PySpark DataFrame code rather than raw SQL, but it performs SQL-like operations. The code reads from the `information_schema.table_tags` system table (the catalog name is parameterized). Three filter conditions are applied on columns `catalog_name`, `schema_name`, and `table_name` (equivalent to WHERE clauses). The select operation retrieves only `table_name` and `tag_name` columns. The `distinct()` call is equivalent to SELECT DISTINCT. No joins, GROUP BY, or ORDER BY operations are present in this code. |


\nüìà Column Usage Type Distribution:


+----------+-----+
|usage_type|count|
+----------+-----+
|    SELECT|   16|
|    FILTER|   12|
+----------+-----+



## Advanced: Complex Queries

In [None]:
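# For very complex queries you can temporarily switch models with dspy.context().
# `lm` is already Claude Opus (configured above), so this block reuses it; to try a
# different model, create another dspy.LM and pass it as dspy.context(lm=...).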

complex_sql = """
SELECT * FROM (
    WITH RECURSIVE org_hierarchy AS (
        SELECT employee_id, manager_id, 1 as level
        FROM hr.employees
        WHERE manager_id IS NULL
        UNION ALL
        SELECT e.employee_id, e.manager_id, oh.level + 1
        FROM hr.employees e
        JOIN org_hierarchy oh ON e.manager_id = oh.employee_id
    )
    SELECT * FROM org_hierarchy
) emp
JOIN hr.departments d ON emp.employee_id = d.manager_id
"""

print("Using advanced model (Claude Opus) for complex recursive query...\n")

with dspy.context(lm=lm):
    result = table_extractor(sql_statement=complex_sql)
    
    print("Tables extracted:")
    for table in result.tables:
        print(f"  - {table}")
    
    print("\nJoin columns:")
    for column in result.join_columns:
        print(f"  - {column}")
    
    print("\nFilter columns:")
    for column in result.filter_columns:
        print(f"  - {column}")
    
    print(f"\nExplanation: {result.explanation}")


Using advanced model (Claude Opus) for complex recursive query...

Tables extracted:
  - hr.employees
  - hr.departments

🔗 Join columns:
  - e.manager_id
  - oh.employee_id
  - emp.employee_id
  - d.manager_id

🔍 Filter columns:
  - manager_id

Explanation: This query uses a recursive CTE `org_hierarchy` to traverse an employee hierarchy. The actual tables are `hr.employees` (referenced twice in the CTE for the base and recursive cases) and `hr.departments` (joined at the outer level). The CTE name `org_hierarchy` is excluded from tables as it's a temporary result set. Columns are categorized as follows: SELECT columns include `employee_id`, `manager_id`, and `level` (with a computed value) from both the base and recursive parts of the CTE. JOIN columns include `e.manager_id` and `oh.employee_id` for the recursive join within the CTE, and `emp.employee_id` and `d.manager_id` for the outer join with departments. The FILTER column is `manager_id` used in the WHERE clause (`man
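
The `cte_names` output lets you confirm that CTEs such as `org_hierarchy` were kept out of `tables`. As a cheap cross-check of that field, here is a regex-based sketch that lists names introduced by `WITH [RECURSIVE] name AS (` or by `, name AS (`. It is a heuristic, not a SQL parser, and has these limitations:

- It assumes unquoted identifiers.
- It does not handle column lists such as `name (a, b) AS (`.
- It can also match named `WINDOW` definitions.
- It can be fooled by text inside string literals or comments.

`find_cte_names` is a hypothetical helper.

```python
import re
from typing import List

# "WITH [RECURSIVE] name AS (" opens a CTE list; ", name AS (" continues it
_CTE_PATTERN = re.compile(
    r"(?:\bWITH\s+(?:RECURSIVE\s+)?|,\s*)([A-Za-z_][A-Za-z0-9_]*)\s+AS\s*\(",
    re.IGNORECASE,
)


def find_cte_names(sql: str) -> List[str]:
    """Return likely CTE names in order of appearance, without duplicates."""
    names: List[str] = []
    seen = set()
    for match in _CTE_PATTERN.finditer(sql):
        name = match.group(1)
        if name.lower() not in seen:
            seen.add(name.lower())
            names.append(name)
    return names


two_ctes = """
WITH customer_orders AS (SELECT 1 AS x),
order_details AS (SELECT 2 AS y)
SELECT * FROM customer_orders JOIN order_details ON 1 = 1
"""
recursive = "SELECT * FROM (WITH RECURSIVE org_hierarchy AS (SELECT 1) SELECT * FROM org_hierarchy) emp"

print(find_cte_names(two_ctes))   # ['customer_orders', 'order_details']
print(find_cte_names(recursive))  # ['org_hierarchy']
```

One way to use it is to compare `set(find_cte_names(complex_sql))` with `set(result.cte_names)` and flag any disagreement for manual review rather than trusting either side blindly.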

## Export Results to JSON


In [23]:
# Export results to JSON for further processing
output_file = "sql_table_extraction_results.json"

with open(output_file, 'w') as f:
    json.dump(results, f, indent=2, default=str)

print("\n" + "="*80)
print("EXPORT SUMMARY")
print("="*80)
print(f"\nResults exported to: {output_file}")
print("\nAnalysis Summary:")
print(f"  - Queries analyzed: {len(results)}")
print(f"  - Unique tables found: {len(table_usage)}")
print(f"  - Unique columns found: {len(column_usage)}")
print(f"  - Join columns: {len(join_column_usage)}")
print(f"  - Filter columns: {len(filter_column_usage)}")
print(f"  - Select columns: {len(select_column_usage)}")
print(f"  - Group By columns: {len(groupby_column_usage)}")
print(f"  - Order By columns: {len(orderby_column_usage)}")
print("="*80)


EXPORT SUMMARY

Results exported to: sql_table_extraction_results.json

Analysis Summary:
  - Queries analyzed: 10
  - Unique tables found: 2
  - Unique columns found: 7
  - Join columns: 0
  - Filter columns: 4
  - Select columns: 4
  - Group By columns: 0
  - Order By columns: 0
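
`json.dump` writes `results` as one JSON array. If the output will be loaded into Spark or appended across runs, newline-delimited JSON (one object per line) is often more convenient, and it is the layout `spark.read.json` expects by default. Here is a minimal standard-library sketch. `write_jsonl` and `read_jsonl` are hypothetical helpers and the `.jsonl` filename is illustrative. `default=str` is a fallback for values that are not natively JSON-serializable.

```python
import json
import os
import tempfile
from typing import Any, Dict, Iterable, List


def write_jsonl(records: Iterable[Dict[str, Any]], path: str) -> int:
    """Write one JSON object per line; return the number of records written."""
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, default=str) + "\n")
            count += 1
    return count


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read records back, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Round-trip a record shaped like one entry of `results`
sample = [{"statement_id": "s1", "tables": ["a.b.t1"], "join_columns": []}]
path = os.path.join(tempfile.mkdtemp(), "sql_table_extraction_results.jsonl")
write_jsonl(sample, path)
print(read_jsonl(path) == sample)  # True
```

Because each line is independent, a later run can open the file in append mode and add new records without rewriting earlier ones.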
