# SQL Table Extractor using DSPy

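This notebook demonstrates how to use DSPy to extract the tables and columns referenced by SQL statements in the Databricks query history (`system.query.history`), including how each column is used (SELECT, JOIN, filter, GROUP BY, ORDER BY).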
It handles complex SQL including:
- Common Table Expressions (CTEs)
- Subqueries
- Joins
- Nested queries
- Temporary tables


## Install Dependencies


In [None]:
%pip install dspy-ai --quiet
# NOTE(review): dspy-ai is installed without a version pin, so a rerun may pick up a newer release; consider pinning a tested version.

In [None]:
# Restart Python kernel to use newly installed packages
# dbutils.library.restartPython()  # Uncomment if running in Databricks

## Setup and Imports


In [1]:
import dspy
from typing import List
import json


## Configure DSPy with Databricks LM


In [3]:
# Initialize the Databricks-served model that DSPy uses by default.
# Example serving endpoint names (confirm availability in your workspace):
# - databricks-meta-llama-3-3-70b-instruct (recommended for complex tasks)
# - databricks-meta-llama-3-1-70b-instruct
# - databricks-dbrx-instruct
# - databricks-claude-opus-4-5 (most powerful, use for complex queries)
# NOTE(review): the list marks llama-3-3-70b as "recommended", yet the notebook
# selects claude-opus-4-5 below — confirm which default is intended.

LLM_MODEL_NAME = "databricks-claude-opus-4-5"
# cache=False is passed so calls are not served from DSPy's response cache.
lm = dspy.LM(model=f"databricks/{LLM_MODEL_NAME}", cache=False)
dspy.configure(lm=lm)

print(f"✓ Configured DSPy with model: {LLM_MODEL_NAME}")
print(f"✓ Model endpoint: databricks/{LLM_MODEL_NAME}")


‚úì Configured DSPy with model: databricks-claude-opus-4-5
‚úì Model endpoint: databricks/databricks-claude-opus-4-5


## Define DSPy Signatures for Table Extraction


In [4]:
# DSPy signature describing the extraction task. DSPy builds the LM prompt from the class
# docstring and from each field's `desc`, so edits to that text change runtime behavior.
# NOTE(review): "Exclude CTE names themselves" applies to the `tables` output; CTE names are
# still requested separately through the `cte_names` field.
class SQLTableExtractor(dspy.Signature):
    """Extract all table names and column names with detailed usage metadata from a SQL statement.
    
    Tables include:
    - Base tables in FROM clauses
    - Tables in JOIN clauses
    - Tables in subqueries
    - CTE (Common Table Expression) source tables
    - Tables in nested queries
    
    Columns are categorized by usage:
    - SELECT columns (output columns, aggregations)
    - JOIN columns (columns in ON/USING clauses)
    - FILTER columns (columns in WHERE/HAVING conditions)
    - GROUP BY columns
    - ORDER BY columns
    
    Important: 
    - Exclude CTE names themselves (they are temporary)
    - Include fully qualified names when present (catalog.schema.table, table.column)
    - Return unique names only per category
    - For columns, include table prefix if present (e.g., "customers.customer_id")
    - A column may appear in multiple categories
    """
    
    # Single input field: the statement text to analyze.
    sql_statement: str = dspy.InputField(desc="The SQL statement to analyze")
    # Output fields: physical tables, then the full column list.
    tables: List[str] = dspy.OutputField(desc="List of unique table names referenced in the SQL, excluding CTE names. Include full qualification (catalog.schema.table) when present.")
    columns: List[str] = dspy.OutputField(desc="List of ALL unique column names referenced anywhere in the SQL. Include table prefix when present (e.g., 'table.column').")
    # Per-clause column lists; the docstring allows a column to appear in several of them.
    join_columns: List[str] = dspy.OutputField(desc="List of columns used in JOIN conditions (ON/USING clauses). Include table prefix (e.g., 'table.column').")
    filter_columns: List[str] = dspy.OutputField(desc="List of columns used in WHERE and HAVING filter conditions. Include table prefix (e.g., 'table.column').")
    select_columns: List[str] = dspy.OutputField(desc="List of columns in SELECT clause, including those in aggregate functions. Include table prefix (e.g., 'table.column').")
    groupby_columns: List[str] = dspy.OutputField(desc="List of columns in GROUP BY clause. Include table prefix (e.g., 'table.column').")
    orderby_columns: List[str] = dspy.OutputField(desc="List of columns in ORDER BY clause. Include table prefix (e.g., 'table.column').")
    # CTE names are reported on their own so they can be told apart from `tables`.
    cte_names: List[str] = dspy.OutputField(desc="List of CTE (Common Table Expression) names defined in the query (these are temporary, not actual tables)")
    # Free-text rationale from the model.
    explanation: str = dspy.OutputField(desc="Brief explanation of how tables and columns were identified and categorized in this query")


## Create DSPy Module for Table Extraction

In [5]:
class TableExtractorModule(dspy.Module):
    """DSPy module that runs the SQLTableExtractor signature through a ChainOfThought predictor."""

    def __init__(self):
        super().__init__()
        # Built once here and reused by every forward() call.
        self.extract = dspy.ChainOfThought(SQLTableExtractor)

    def forward(self, sql_statement: str):
        """Run the predictor on `sql_statement` and return its prediction unchanged."""
        return self.extract(sql_statement=sql_statement)


## Query Databricks Query History


In [9]:
%sql
-- Sample of recent statements from the query history system table.
-- NOTE(review): the recorded output shows statement_text can hold Python notebook code
-- (e.g. spark.sql(...) calls), not only bare SQL, even when statement_type is 'SELECT'.
SELECT 
    statement_id,
    statement_text,
    executed_by_user_id,
    start_time,
    end_time,
    statement_type
FROM system.query.history
WHERE 
    -- Keep only statement types that read from or write to tables
    statement_type IN ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'MERGE', 'CREATE_TABLE_AS_SELECT')
    AND statement_text IS NOT NULL
    -- Skip very short statements (50 characters or fewer)
    AND LENGTH(statement_text) > 50
    -- Only statements started on or after yesterday's date
    AND start_time >= CURRENT_DATE - INTERVAL '1' DAYS
-- Newest first, capped at 10 rows
ORDER BY start_time DESC
LIMIT 10


HBox(children=(IntProgress(value=0, bar_style='success'), Label(value='')))

Unnamed: 0,statement_id,statement_text,executed_by_user_id,start_time,end_time,statement_type
0,a03cf69c-081e-4766-bd9a-7ddb848c766c,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:27.587,2025-12-09 19:52:27.765,SELECT
1,63de4f2e-0b0b-4de0-90b2-a085c88c95f9,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:27.023,2025-12-09 19:52:27.197,SELECT
2,deedd4c1-446f-48fa-bbf6-b20d33540c60,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:26.444,2025-12-09 19:52:26.623,SELECT
3,39a8d5f6-d593-49cd-8a28-043494d3d907,"tags_rows = (\n Context.spark.table(f""`{catalog_name}`.information_schema.table_tags"")\n .filter(F.col(""catalog_name"") == catalog_name)\n .filter(F.col(""schema_name"") == schema_name)\n .filter(F.col(""table_name"").isin(table_names))\n .select(""table_name"", ""tag_name"")\n .distinct()\n .collect()\n)",6102504452920701,2025-12-09 19:52:26.085,2025-12-09 19:52:26.439,SELECT
4,d1a797fe-e9ff-43c0-9ec2-8ebfe9ecad1e,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:25.862,2025-12-09 19:52:26.044,SELECT
5,81ab9b9d-d789-448d-a604-546caae8aff5,existing_cols = {c.name for c in Context.spark.catalog.listColumns(self.logging_table_full_name)},6102504452920701,2025-12-09 19:52:25.819,2025-12-09 19:52:25.944,SELECT
6,96b12828-899b-46a8-9c1a-e29c7dd36cb8,"tags_rows = (\n Context.spark.table(f""`{catalog_name}`.information_schema.table_tags"")\n .filter(F.col(""catalog_name"") == catalog_name)\n .filter(F.col(""schema_name"") == schema_name)\n .filter(F.col(""table_name"").isin(table_names))\n .select(""table_name"", ""tag_name"")\n .distinct()\n .collect()\n)",3021682792699784,2025-12-09 19:52:25.353,2025-12-09 19:52:25.704,SELECT
7,7e5f0659-9794-43c9-a7a6-4c5f24041637,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:25.281,2025-12-09 19:52:25.454,SELECT
8,4658bb80-db95-425c-83f7-4801b67c9b87,existing_cols = {c.name for c in Context.spark.catalog.listColumns(self.logging_table_full_name)},3021682792699784,2025-12-09 19:52:24.955,2025-12-09 19:52:25.181,SELECT
9,9ec3f6da-8a37-44ab-877c-345c7920e278,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:24.650,2025-12-09 19:52:24.856,SELECT


## Store Query Results in a DataFrame


In [10]:
# Get the query results from the previous cell
# NOTE(review): `_sqldf` is not defined in this notebook — presumably supplied by the
# Databricks runtime after a %sql cell; run this cell right after the query cell so it
# still refers to the query-history result.
query_history_df = _sqldf  # _sqldf contains the last SQL query result
# Report how many rows came back, then render the DataFrame.
print(f"Retrieved {query_history_df.count()} queries from history")
display(query_history_df)

HBox(children=(IntProgress(value=0, bar_style='success'), Label(value='')))

Retrieved 10 queries from history


HBox(children=(IntProgress(value=0, bar_style='success'), Label(value='')))

Unnamed: 0,statement_id,statement_text,executed_by_user_id,start_time,end_time,statement_type
0,a03cf69c-081e-4766-bd9a-7ddb848c766c,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:27.587,2025-12-09 19:52:27.765,SELECT
1,63de4f2e-0b0b-4de0-90b2-a085c88c95f9,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:27.023,2025-12-09 19:52:27.197,SELECT
2,deedd4c1-446f-48fa-bbf6-b20d33540c60,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:26.444,2025-12-09 19:52:26.623,SELECT
3,39a8d5f6-d593-49cd-8a28-043494d3d907,"tags_rows = (\n Context.spark.table(f""`{catalog_name}`.information_schema.table_tags"")\n .filter(F.col(""catalog_name"") == catalog_name)\n .filter(F.col(""schema_name"") == schema_name)\n .filter(F.col(""table_name"").isin(table_names))\n .select(""table_name"", ""tag_name"")\n .distinct()\n .collect()\n)",6102504452920701,2025-12-09 19:52:26.085,2025-12-09 19:52:26.439,SELECT
4,d1a797fe-e9ff-43c0-9ec2-8ebfe9ecad1e,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:25.862,2025-12-09 19:52:26.044,SELECT
5,81ab9b9d-d789-448d-a604-546caae8aff5,existing_cols = {c.name for c in Context.spark.catalog.listColumns(self.logging_table_full_name)},6102504452920701,2025-12-09 19:52:25.819,2025-12-09 19:52:25.944,SELECT
6,96b12828-899b-46a8-9c1a-e29c7dd36cb8,"tags_rows = (\n Context.spark.table(f""`{catalog_name}`.information_schema.table_tags"")\n .filter(F.col(""catalog_name"") == catalog_name)\n .filter(F.col(""schema_name"") == schema_name)\n .filter(F.col(""table_name"").isin(table_names))\n .select(""table_name"", ""tag_name"")\n .distinct()\n .collect()\n)",3021682792699784,2025-12-09 19:52:25.353,2025-12-09 19:52:25.704,SELECT
7,7e5f0659-9794-43c9-a7a6-4c5f24041637,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:25.281,2025-12-09 19:52:25.454,SELECT
8,4658bb80-db95-425c-83f7-4801b67c9b87,existing_cols = {c.name for c in Context.spark.catalog.listColumns(self.logging_table_full_name)},3021682792699784,2025-12-09 19:52:24.955,2025-12-09 19:52:25.181,SELECT
9,9ec3f6da-8a37-44ab-877c-345c7920e278,"rows = spark.sql(\n f""SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}`.columns WHERE asset_urn = '{urn.replace(""'"",""''"")}'""\n).collect()",7056949509210167,2025-12-09 19:52:24.650,2025-12-09 19:52:24.856,SELECT


In [11]:
# Create the table extractor module once; the later cells call this same
# `table_extractor` object for the sample query, the history loop and the advanced example.
table_extractor = TableExtractorModule()


## Test with Sample Complex SQL Queries


In [None]:
# Test with a complex SQL example: two CTEs, joins, and an IN-subquery
test_sql = """
WITH customer_orders AS (
    SELECT 
        c.customer_id,
        c.customer_name,
        o.order_id,
        o.order_date
    FROM catalog.sales.customers c
    JOIN catalog.sales.orders o ON c.customer_id = o.customer_id
    WHERE o.order_date >= '2024-01-01'
),
order_details AS (
    SELECT 
        od.order_id,
        od.product_id,
        p.product_name,
        od.quantity * od.unit_price as total_amount
    FROM catalog.sales.order_details od
    JOIN catalog.products.product_catalog p ON od.product_id = p.product_id
)
SELECT 
    co.customer_name,
    od.product_name,
    SUM(od.total_amount) as total_spent
FROM customer_orders co
JOIN order_details od ON co.order_id = od.order_id
WHERE od.product_name IN (
    SELECT product_name 
    FROM catalog.products.featured_products
    WHERE featured_date >= '2024-01-01'
)
GROUP BY co.customer_name, od.product_name
ORDER BY total_spent DESC
"""

# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print("Testing with complex SQL query...\n")
result = table_extractor(sql_statement=test_sql)

# (heading, values) pairs in display order; one loop below replaces the
# eight copy-pasted print blocks.
report_sections = [
    ("Actual Tables Used", result.tables),
    ("All Columns Referenced", result.columns),
    ("🔗 JOIN Columns", result.join_columns),
    ("🔍 FILTER Columns (WHERE/HAVING)", result.filter_columns),
    ("📊 SELECT Columns", result.select_columns),
    ("📦 GROUP BY Columns", result.groupby_columns),
    ("📈 ORDER BY Columns", result.orderby_columns),
    ("CTE Names (temporary)", result.cte_names),
]

print("=" * 80)
print("EXTRACTION RESULTS")
print("=" * 80)

for heading, values in report_sections:
    print(f"\n{heading} ({len(values)}):")
    for value in values:
        print(f"  - {value}")

print("\nExplanation:")
print(f"  {result.explanation}")
print("=" * 80)




Testing with complex SQL query...\n
EXTRACTION RESULTS
\nActual Tables Used (5):
  - catalog.sales.customers
  - catalog.sales.orders
  - catalog.sales.order_details
  - catalog.products.product_catalog
  - catalog.products.featured_products
\nAll Columns Referenced (18):
  - c.customer_id
  - c.customer_name
  - o.order_id
  - o.order_date
  - o.customer_id
  - od.order_id
  - od.product_id
  - p.product_name
  - od.quantity
  - od.unit_price
  - p.product_id
  - co.customer_name
  - od.product_name
  - od.total_amount
  - co.order_id
  - product_name
  - featured_date
  - total_spent
\nüîó JOIN Columns (6):
  - c.customer_id
  - o.customer_id
  - od.product_id
  - p.product_id
  - co.order_id
  - od.order_id
\nüîç FILTER Columns (WHERE/HAVING) (4):
  - o.order_date
  - od.product_name
  - product_name
  - featured_date
\nüìä SELECT Columns (13):
  - c.customer_id
  - c.customer_name
  - o.order_id
  - o.order_date
  - od.order_id
  - od.product_id
  - p.product_name
  - od.quantit

In [None]:
# Process each query from history.
# Every statement yields exactly one entry in `results`; a failed extraction is
# recorded with empty lists and the error text in `explanation`.
results = []
failed_count = 0

# List-valued extraction fields, in the key order used by each result entry.
EXTRACTION_LIST_FIELDS = (
    'tables',
    'columns',
    'join_columns',
    'filter_columns',
    'select_columns',
    'groupby_columns',
    'orderby_columns',
    'cte_names',
)

for row in query_history_df.collect():
    statement_id = row['statement_id']
    statement_text = row['statement_text']
    executed_by_user_id = row['executed_by_user_id']

    print(f"\nProcessing query: {statement_id}")
    print(f"User: {executed_by_user_id}")
    print(f"SQL Preview: {statement_text[:100]}...")

    # Keys shared by success and failure entries.
    result_entry = {
        'statement_id': statement_id,
        'executed_by_user_id': executed_by_user_id,
    }

    try:
        # Extract tables and columns using DSPy
        extraction = table_extractor(sql_statement=statement_text)
        for field in EXTRACTION_LIST_FIELDS:
            # A None list from the model is stored as [] so later len()/extend() calls work.
            result_entry[field] = getattr(extraction, field) or []
        result_entry['explanation'] = extraction.explanation
    except Exception as e:
        # Best effort: keep going and record the failure for this statement.
        failed_count += 1
        print(f"  ✗ Error processing query: {str(e)}")
        for field in EXTRACTION_LIST_FIELDS:
            result_entry[field] = []
        result_entry['explanation'] = f"Error: {str(e)}"
    else:
        print(f"  ✓ Found {len(result_entry['tables'])} tables, {len(result_entry['columns'])} columns")
        print(f"    - {len(result_entry['join_columns'])} join cols, {len(result_entry['filter_columns'])} filter cols")

    result_entry['statement_preview'] = statement_text[:200]
    # Appended once, after the try block, so a statement can never be recorded twice.
    results.append(result_entry)

succeeded_count = len(results) - failed_count
print(f"\n\nProcessed {len(results)} queries ({succeeded_count} succeeded, {failed_count} failed)")


HBox(children=(IntProgress(value=0, bar_style='success'), Label(value='')))



\nProcessing query: a03cf69c-081e-4766-bd9a-7ddb848c766c
User: 7056949509210167
SQL Preview: rows = spark.sql(
    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}...




  ‚úì Found 1 tables, 3 columns
    - 0 join cols, 1 filter cols
\nProcessing query: 63de4f2e-0b0b-4de0-90b2-a085c88c95f9
User: 7056949509210167
SQL Preview: rows = spark.sql(
    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}...




  ‚úì Found 1 tables, 3 columns
    - 0 join cols, 1 filter cols
\nProcessing query: deedd4c1-446f-48fa-bbf6-b20d33540c60
User: 7056949509210167
SQL Preview: rows = spark.sql(
    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}...




  ‚úì Found 1 tables, 3 columns
    - 0 join cols, 1 filter cols
\nProcessing query: 39a8d5f6-d593-49cd-8a28-043494d3d907
User: 6102504452920701
SQL Preview: tags_rows = (
    Context.spark.table(f"`{catalog_name}`.information_schema.table_tags")
    .filter...




  ‚úì Found 1 tables, 4 columns
    - 0 join cols, 3 filter cols
\nProcessing query: d1a797fe-e9ff-43c0-9ec2-8ebfe9ecad1e
User: 7056949509210167
SQL Preview: rows = spark.sql(
    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}...




  ‚úì Found 1 tables, 3 columns
    - 0 join cols, 1 filter cols
\nProcessing query: 81ab9b9d-d789-448d-a604-546caae8aff5
User: 6102504452920701
SQL Preview: existing_cols = {c.name for c in Context.spark.catalog.listColumns(self.logging_table_full_name)}...




  ‚úì Found 0 tables, 0 columns
    - 0 join cols, 0 filter cols
\nProcessing query: 96b12828-899b-46a8-9c1a-e29c7dd36cb8
User: 3021682792699784
SQL Preview: tags_rows = (
    Context.spark.table(f"`{catalog_name}`.information_schema.table_tags")
    .filter...




  ‚úì Found 1 tables, 4 columns
    - 0 join cols, 3 filter cols
\nProcessing query: 7e5f0659-9794-43c9-a7a6-4c5f24041637
User: 7056949509210167
SQL Preview: rows = spark.sql(
    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}...




  ‚úì Found 1 tables, 3 columns
    - 0 join cols, 1 filter cols
\nProcessing query: 4658bb80-db95-425c-83f7-4801b67c9b87
User: 3021682792699784
SQL Preview: existing_cols = {c.name for c in Context.spark.catalog.listColumns(self.logging_table_full_name)}...




  ‚úì Found 0 tables, 0 columns
    - 0 join cols, 0 filter cols
\nProcessing query: 9ec3f6da-8a37-44ab-877c-345c7920e278
User: 7056949509210167
SQL Preview: rows = spark.sql(
    f"SELECT lower(column_name) AS n, data_type FROM `{cfg.catalog}`.`{cfg.schema}...
  ‚úì Found 1 tables, 3 columns
    - 0 join cols, 1 filter cols
\n\nProcessed 10 queries successfully!


In [None]:
# Render the per-statement extraction entries collected by the processing loop.
display(results)


## Aggregate Table Usage Statistics


In [None]:
# Count table and column usage across all queries
from collections import Counter


def _collect_names(entries, key):
    """Flatten the `key` list of every result entry into one list (duplicates kept).

    A missing or None value is treated as an empty list so that one malformed
    entry cannot abort the whole aggregation.
    """
    return [name for entry in entries for name in (entry.get(key) or [])]


def _print_usage_section(title, usage, summary_lines, top_n):
    """Print one banner-delimited report.

    Args:
        title: Heading shown between the two "=" rules.
        usage: Counter mapping a name to its number of references.
        summary_lines: Lines printed verbatim between the heading and the "-" rule.
        top_n: How many of the most common names to list.
    """
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)
    for line in summary_lines:
        print(line)
    print("-" * 80)
    for name, count in usage.most_common(top_n):
        print(f"  {count:3d}x  {name}")
    print("=" * 80)


all_tables = _collect_names(results, 'tables')
all_columns = _collect_names(results, 'columns')
all_join_columns = _collect_names(results, 'join_columns')
all_filter_columns = _collect_names(results, 'filter_columns')
all_select_columns = _collect_names(results, 'select_columns')
all_groupby_columns = _collect_names(results, 'groupby_columns')
all_orderby_columns = _collect_names(results, 'orderby_columns')

# The select/group-by/order-by counters are not printed here; the export
# summary cell reads them.
table_usage = Counter(all_tables)
column_usage = Counter(all_columns)
join_column_usage = Counter(all_join_columns)
filter_column_usage = Counter(all_filter_columns)
select_column_usage = Counter(all_select_columns)
groupby_column_usage = Counter(all_groupby_columns)
orderby_column_usage = Counter(all_orderby_columns)

_print_usage_section(
    "TABLE USAGE STATISTICS",
    table_usage,
    [
        f"\nTotal unique tables referenced: {len(table_usage)}",
        f"Total table references: {sum(table_usage.values())}",
        "\nTop 10 Most Referenced Tables:",
    ],
    top_n=10,
)

_print_usage_section(
    "COLUMN USAGE STATISTICS - ALL COLUMNS",
    column_usage,
    [
        f"\nTotal unique columns referenced: {len(column_usage)}",
        f"\nTotal column references: {sum(column_usage.values())}",
        "\nTop 20 Most Referenced Columns:",
    ],
    top_n=20,
)

_print_usage_section(
    "🔗 JOIN COLUMN USAGE STATISTICS",
    join_column_usage,
    [
        f"\nTotal unique join columns: {len(join_column_usage)}",
        f"\nTotal join column references: {sum(join_column_usage.values())}",
        "\nTop 10 Most Used Join Columns:",
    ],
    top_n=10,
)

_print_usage_section(
    "🔍 FILTER COLUMN USAGE STATISTICS (WHERE/HAVING)",
    filter_column_usage,
    [
        f"\nTotal unique filter columns: {len(filter_column_usage)}",
        f"\nTotal filter column references: {sum(filter_column_usage.values())}",
        "\nTop 10 Most Used Filter Columns:",
    ],
    top_n=10,
)


## Save Results to Delta Table (Optional)


In [None]:
# Convert results to Spark DataFrames (one row per table / column reference) and display them
from pyspark.sql import Row

# (result key, usage_type label) pairs, in the order rows are emitted for each statement.
# The all-columns list (`columns`) is intentionally not flattened here; only the
# per-clause lists carry a usage type.
COLUMN_USAGE_TYPES = (
    ('select_columns', 'SELECT'),
    ('join_columns', 'JOIN'),
    ('filter_columns', 'FILTER'),
    ('groupby_columns', 'GROUP_BY'),
    ('orderby_columns', 'ORDER_BY'),
)

# Flatten results for DataFrame - separate tables and columns with usage type
flattened_table_results = []
flattened_column_results = []

for result in results:
    # Table references
    for table in result['tables']:
        flattened_table_results.append(Row(
            statement_id=result['statement_id'],
            executed_by_user_id=result['executed_by_user_id'],
            table_name=table,
            explanation=result['explanation']
        ))

    # Column references, one row per (column, usage type)
    for result_key, usage_type in COLUMN_USAGE_TYPES:
        for column in result[result_key]:
            flattened_column_results.append(Row(
                statement_id=result['statement_id'],
                executed_by_user_id=result['executed_by_user_id'],
                column_name=column,
                usage_type=usage_type,
                explanation=result['explanation']
            ))

# Create and display table references DataFrame
if flattened_table_results:
    tables_df = spark.createDataFrame(flattened_table_results)

    # "\n" (not "\\n") so a real blank line is printed before the message.
    print(f"\n📊 Created Tables DataFrame with {tables_df.count()} table references")
    display(tables_df)

    # Optionally save to a Delta table
    # tables_df.write.mode("overwrite").saveAsTable("your_catalog.your_schema.query_table_analysis")
else:
    print("No tables extracted from queries.")

# Create and display column references DataFrame with usage type
if flattened_column_results:
    columns_df = spark.createDataFrame(flattened_column_results)

    print(f"\n📊 Created Columns DataFrame with {columns_df.count()} column references (categorized by usage)")
    display(columns_df)

    # Show breakdown by usage type
    print("\n📈 Column Usage Type Distribution:")
    columns_df.groupBy("usage_type").count().orderBy("count", ascending=False).show()

    # Optionally save to a Delta table
    # columns_df.write.mode("overwrite").saveAsTable("your_catalog.your_schema.query_column_analysis")
else:
    print("No columns extracted from queries.")


## Advanced: Use Different LM for Complex Queries


In [None]:
# For very complex queries, you can run the same extractor against a different model.
# Example: a recursive CTE nested in a derived table, analyzed under a per-call LM override.

complex_sql = """
SELECT * FROM (
    WITH RECURSIVE org_hierarchy AS (
        SELECT employee_id, manager_id, 1 as level
        FROM hr.employees
        WHERE manager_id IS NULL
        UNION ALL
        SELECT e.employee_id, e.manager_id, oh.level + 1
        FROM hr.employees e
        JOIN org_hierarchy oh ON e.manager_id = oh.employee_id
    )
    SELECT * FROM org_hierarchy
) emp
JOIN hr.departments d ON emp.employee_id = d.manager_id
"""

# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print("Using advanced model (Claude Opus) for complex recursive query...\n")

# Use a different model for this specific query.
# NOTE(review): this endpoint is opus-4-1 while the notebook's default LM is
# opus-4-5 — confirm which one is meant to be the "more powerful" model.
advanced_lm = dspy.LM('databricks/databricks-claude-opus-4-1', cache=False)

# dspy.context applies the override only inside this block.
with dspy.context(lm=advanced_lm):
    result = table_extractor(sql_statement=complex_sql)

    print("Tables extracted:")
    for table in result.tables:
        print(f"  - {table}")

    print("\n🔗 Join columns:")
    for column in result.join_columns:
        print(f"  - {column}")

    print("\n🔍 Filter columns:")
    for column in result.filter_columns:
        print(f"  - {column}")

    print(f"\nExplanation: {result.explanation}")


## Export Results to JSON


In [None]:
# Export results to JSON for further processing.
# Relies on `results` from the processing loop and on the usage counters
# built in the statistics cell; run those cells first.
output_file = "/dbfs/tmp/sql_table_extraction_results.json"

# Explicit encoding so the file does not depend on the platform default.
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

# "\n" (not "\\n") so real blank lines are printed instead of literal backslash-n.
print("\n" + "="*80)
print("📦 EXPORT SUMMARY")
print("="*80)
print(f"\n✓ Results exported to: {output_file}")
print("\n📊 Analysis Summary:")
print(f"  - Queries analyzed: {len(results)}")
print(f"  - Unique tables found: {len(table_usage)}")
print(f"  - Unique columns found: {len(column_usage)}")
print(f"  - Join columns: {len(join_column_usage)}")
print(f"  - Filter columns: {len(filter_column_usage)}")
print(f"  - Select columns: {len(select_column_usage)}")
print(f"  - Group By columns: {len(groupby_column_usage)}")
print(f"  - Order By columns: {len(orderby_column_usage)}")
print("="*80)
