In [ ]:
import asyncio
import json
import os
from typing import Dict, List, Any

# Autogen imports
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_core import CancellationToken
from autogen_agentchat.messages import TextMessage

# Import from project modules
from const import (
    SELECTOR_NAME, DECOMPOSER_NAME, REFINER_NAME,
    selector_template, decompose_template_bird, refiner_template
)
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

# Import utilities from our new module
from sql_utils import (
    parse_json, extract_sql_from_text, format_sql, format_json_result,
    select_schema, generate_sql, refine_sql, process_text_to_sql
)

# Set constants
MAX_REFINEMENT_ATTEMPTS = 3
DEFAULT_TIMEOUT = 120  # seconds
BIRD_DATA_PATH = "../data/bird"
BIRD_TABLES_JSON_PATH = os.path.join(BIRD_DATA_PATH, "dev_tables.json")
DATASET_NAME = "bird"

In [ ]:
# Initialize SchemaManager and SQLExecutor
schema_manager = SchemaManager(
    data_path=BIRD_DATA_PATH,
    tables_json_path=BIRD_TABLES_JSON_PATH,
    dataset_name=DATASET_NAME,
    lazy=False
)

sql_executor = SQLExecutor(
    data_path=BIRD_DATA_PATH,
    dataset_name=DATASET_NAME
)

# Tool implementation for schema selection and management
async def get_initial_database_schema(db_id: str) -> str:
    """Retrieves the full database schema information for a given database."""
    print(f"[Tool] Loading schema for database: {db_id}")
    
    # Load database information using SchemaManager
    if db_id not in schema_manager.db2infos:
        schema_manager.db2infos[db_id] = schema_manager._load_single_db_info(db_id)
    
    # Get database information
    db_info = schema_manager.db2dbjsons.get(db_id, {})
    if not db_info:
        return json.dumps({"error": f"Database '{db_id}' not found"})
    
    # Generate full schema description (without pruning)
    is_complex = schema_manager._is_complex_schema(db_id)
    full_schema_str, full_fk_str, _ = schema_manager.generate_schema_description(
        db_id, {}, use_gold_schema=False
    )
    
    # Return schema details
    return json.dumps({
        "db_id": db_id,
        "table_count": db_info.get('table_count', 0),
        "total_column_count": db_info.get('total_column_count', 0),
        "avg_column_count": db_info.get('avg_column_count', 0),
        "is_complex_schema": is_complex,
        "full_schema_str": full_schema_str,
        "full_fk_str": full_fk_str
    })

async def prune_database_schema(db_id: str, pruning_rules: Dict) -> str:
    """Applies pruning rules to a database schema."""
    print(f"[Tool] Pruning schema for database {db_id}")
    
    # Generate pruned schema description
    schema_str, fk_str, chosen_schema = schema_manager.generate_schema_description(
        db_id, pruning_rules, use_gold_schema=False
    )
    
    # Return pruned schema
    return json.dumps({
        "db_id": db_id,
        "pruning_applied": True,
        "pruning_rules": pruning_rules,
        "pruned_schema_str": schema_str,
        "pruned_fk_str": fk_str,
        "tables_columns_kept": chosen_schema
    })

# Tool implementation for SQL execution
async def execute_sql(sql: str, db_id: str) -> str:
    """Executes a SQL query on the specified database."""
    print(f"[Tool] Executing SQL on database {db_id}: {sql[:100]}...")
    
    # Execute SQL with timeout protection
    result = sql_executor.safe_execute(sql, db_id)
    
    # Add validation information
    is_valid, reason = sql_executor.is_valid_result(result)
    result["is_valid_result"] = is_valid
    result["validation_message"] = reason
    
    # Convert to JSON string
    return json.dumps(result)

In [ ]:
# Initialize model client
model_client = OpenAIChatCompletionClient(model="gpt-4o")

# Create Schema Selector Agent
selector_agent = AssistantAgent(
    name=SELECTOR_NAME,
    model_client=model_client,
    system_message=f"""You are a Database Schema Selector specialized in analyzing database schemas for text-to-SQL tasks.

Your job is to help prune large database schemas to focus on the relevant tables and columns for a given query.

TASK OVERVIEW:
1. You will receive a task with database ID, query, and evidence
2. Use the 'get_initial_database_schema' tool to retrieve the full schema
3. Analyze the schema complexity and relevance to the query
4. For complex schemas, determine which tables and columns are relevant
5. Use the 'prune_database_schema' tool to generate a focused schema
6. Return the processed schema information for the next agent

WHEN ANALYZING THE SCHEMA:
- Study the database table structure and relationships
- Identify tables directly mentioned or implied in the query
- Consider foreign key relationships that might be needed
- Follow the pruning guidelines in the following template:
{selector_template}

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "pruning_applied": true/false,
  "schema_str": "<schema_description>",
  "fk_str": "<foreign_key_information>"
}}

Remember that high-quality schema selection improves the accuracy of SQL generation.""",
    tools=[get_initial_database_schema, prune_database_schema],
)

# Create Decomposer Agent
decomposer_agent = AssistantAgent(
    name=DECOMPOSER_NAME,
    model_client=model_client,
    system_message=f"""You are a Query Decomposer specialized in converting natural language questions into SQL for the BIRD dataset.

Your job is to analyze a natural language query and relevant database schema, then generate the appropriate SQL query.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, and schema information
2. Study the database schema, focusing on tables, columns, and relationships
3. For complex queries, break down the problem into logical steps
4. Generate a clear, efficient SQL query that answers the question
5. Follow the specific query decomposition template for BIRD:
{decompose_template_bird}

IMPORTANT CONSIDERATIONS:
- BIRD queries often require domain knowledge and multiple steps
- Carefully use the evidence provided to understand domain-specific concepts
- Apply type conversion when comparing numeric data (CAST AS REAL/INT)
- Ensure proper handling of NULL values
- Use table aliases (T1, T2, etc.) for clarity, especially in JOINs
- Always use valid SQLite syntax

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "sql": "<generated_sql_query>",
  "decomposition": [
    "step1_description", 
    "step2_description",
    ...
  ]
}}

Your goal is to generate SQL that will execute correctly and return the precise information requested.""",
    tools=[],  # SQL generation is the primary LLM task
)

# Create Refiner Agent
refiner_agent = AssistantAgent(
    name=REFINER_NAME,
    model_client=model_client,
    system_message=f"""You are an SQL Refiner specializing in executing and fixing SQL queries for the BIRD dataset.

Your job is to test SQL queries against the database, identify errors, and refine them until they execute successfully.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, schema, and SQL
2. Use the 'execute_sql' tool to run the SQL against the database
3. Analyze execution results or errors
4. For errors, refine the SQL using the template:
{refiner_template}
5. For successful execution, validate the results are appropriate for the original query

IMPORTANT CONSIDERATIONS:
- BIRD databases have specific requirements for valid results:
  - Results should not be empty (unless that's the expected answer)
  - Results should not contain NULL values without justification
  - Results should match the expected types and formats
- Focus on SQLite-specific syntax and behaviors
- Pay special attention to:
  - Table and column name quoting (use backticks)
  - Type conversions (CAST AS)
  - JOIN conditions
  - Subquery structure and aliases

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "original_sql": "<original_sql>",
  "final_sql": "<refined_sql>",
  "status": "<EXECUTION_SUCCESSFUL|REFINEMENT_NEEDED|NO_CHANGE_NEEDED>",
  "execution_result": "<execution_result_summary>",
  "refinement_explanation": "<explanation_of_changes>"
}}

Your goal is to ensure the SQL query executes successfully and returns relevant results.""",
    tools=[execute_sql],
)

In [ ]:
# Define test cases for BIRD dataset
bird_test_cases = [
    # Test case 1: Excellence rate calculation (basic join with aggregation)
    {
        "db_id": "california_schools",
        "query": "List school names of charter schools with an SAT excellence rate over the average.",
        "evidence": "Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr"
    },
    
    # Test case 2: Multi-table query with numeric conditions (multiple joins)
    {
        "db_id": "game_injury",
        "query": "Show the names of players who have been injured for more than 3 matches in the 2010 season.",
        "evidence": "Season info is in the game table with year 2010; injury severity is measured by the number of matches a player misses."
    },
    
    # Test case 3: Complex aggregation with grouping
    {
        "db_id": "formula_1",
        "query": "What is the name of the driver who has won the most races in rainy conditions?",
        "evidence": "Weather conditions are recorded in the races table; winner information is in the results table."
    },
    
    # Test case 4: Temporal query with date handling
    {
        "db_id": "loan_data",
        "query": "Find the customer with the highest total payment amount for loans taken in the first quarter of 2011.",
        "evidence": "First quarter means January to March (months 1-3); loan dates are stored in ISO format (YYYY-MM-DD)."
    }
]

# Select the test case to run (0-3)
test_idx = 0
current_test = bird_test_cases[test_idx]

print(f"Selected test case {test_idx + 1}:")
print(f"Database: {current_test['db_id']}")
print(f"Query: {current_test['query']}")
print(f"Evidence: {current_test['evidence']}")

In [ ]:
# Test individual components

async def run_text_to_sql_step_by_step():
    """Run the text-to-SQL process step by step for the current test case."""
    try:
        print("\n" + "="*60)
        print("STEP-BY-STEP TEXT-TO-SQL EXECUTION")
        print("="*60)
        
        # Step 1: Schema Selection
        print("\n" + "="*50)
        print("STEP 1: SCHEMA SELECTION")
        print("="*50)
        task, selector_content = await select_schema(
            selector_agent=selector_agent,
            task_json=json.dumps(current_test),
            timeout=DEFAULT_TIMEOUT
        )
        
        print("\nSchema selection result summary:")
        print("-" * 40)
        # Verify if we got schema information
        if "<database_schema>" in selector_content or "schema_str" in selector_content:
            print("✓ Schema information successfully extracted")
        else:
            print("⚠ Schema information may be missing or malformed")
        
        # Step 2: SQL Generation
        print("\n" + "="*50)
        print("STEP 2: SQL GENERATION")
        print("="*50)
        decomposer_content, sql = await generate_sql(
            decomposer_agent=decomposer_agent,
            selector_content=selector_content,
            task=task,
            timeout=DEFAULT_TIMEOUT
        )
        
        print("\nSQL generation result summary:")
        print("-" * 40)
        if sql:
            print(f"✓ SQL query generated ({len(sql)} chars)")
            print("\nSQL Query:")
            print(sql[:300] + "..." if len(sql) > 300 else sql)
        else:
            print("⚠ No SQL query was generated")
        
        # Step 3: SQL Refinement
        print("\n" + "="*50)
        print("STEP 3: SQL REFINEMENT")
        print("="*50)
        result = await refine_sql(
            refiner_agent=refiner_agent,
            decomposer_content=decomposer_content,
            sql=sql,
            task=task,
            max_refinement_attempts=MAX_REFINEMENT_ATTEMPTS,
            timeout=DEFAULT_TIMEOUT
        )
        
        # Display final result
        print("\n" + "="*50)
        print("FINAL RESULT")
        print("="*50)
        
        # Display database and query info
        print(f"Database: {result.get('db_id', '')}")
        print(f"Query: {result.get('query', '')}")
        
        # Try to parse and format the result
        try:
            final_output = parse_json(result.get("final_output", "{}"))
            
            # Display status
            status = final_output.get("status") or result.get("status", "UNKNOWN")
            print(f"\nExecution Status: {status}")
            
            # Display final SQL
            final_sql = final_output.get("final_sql") or result.get("final_sql", "")
            if final_sql:
                print(f"\nFinal SQL Query:")
                print(final_sql)
            else:
                print("\n⚠ No final SQL query available")
                
            # Display execution result if available
            if "execution_result" in final_output:
                print("\nExecution Result:")
                exec_result = final_output["execution_result"]
                if isinstance(exec_result, dict):
                    for key, value in exec_result.items():
                        print(f"  {key}: {value}")
                else:
                    print(exec_result)
                    
        except Exception as e:
            print(f"\nError parsing final result: {str(e)}")
            if "final_sql" in result:
                print(f"\nFinal SQL:\n{result['final_sql']}")
                
        return result
            
    except Exception as e:
        print(f"\nERROR IN EXECUTION: {str(e)}")
        import traceback
        print(traceback.format_exc())
        return {"error": str(e)}

In [ ]:
# Run the complete text-to-SQL process for the current test case
async def run_complete_text_to_sql():
    """Run the complete text-to-SQL process in a single function call."""
    try:
        print("\n" + "="*60)
        print("COMPLETE TEXT-TO-SQL EXECUTION")
        print("="*60)
        
        result = await process_text_to_sql(
            selector_agent=selector_agent,
            decomposer_agent=decomposer_agent,
            refiner_agent=refiner_agent,
            task_json=json.dumps(current_test),
            max_refinement_attempts=MAX_REFINEMENT_ATTEMPTS,
            timeout=DEFAULT_TIMEOUT
        )
        
        # Display final result
        print("\n" + "="*50)
        print("RESULT SUMMARY")
        print("="*50)
        
        if "error" in result:
            print(f"Error: {result['error']}")
            return result
            
        # Display database and query info
        print(f"Database: {result.get('db_id', '')}")
        print(f"Query: {result.get('query', '')}")
        
        # Display final SQL and status
        final_sql = result.get("final_sql", "")
        status = result.get("status", "UNKNOWN")
        
        print(f"\nExecution Status: {status}")
        
        if final_sql:
            print(f"\nFinal SQL Query:")
            print(final_sql)
        else:
            print("\n⚠ No final SQL query available")
            
        return result
        
    except Exception as e:
        print(f"\nERROR IN EXECUTION: {str(e)}")
        import traceback
        print(traceback.format_exc())
        return {"error": str(e)}

In [ ]:
# Function to run tests on multiple database queries
async def run_all_tests():
    """Run text-to-SQL process on all test cases and collect results."""
    results = []
    for i, test_case in enumerate(bird_test_cases):
        print(f"\n\n{'='*80}")
        print(f"Test {i+1}: {test_case['db_id']} - {test_case['query']}")
        print(f"{'='*80}\n")
        
        # Run the test
        try:
            result = await process_text_to_sql(
                selector_agent=selector_agent,
                decomposer_agent=decomposer_agent,
                refiner_agent=refiner_agent,
                task_json=json.dumps(test_case),
                max_refinement_attempts=MAX_REFINEMENT_ATTEMPTS,
                timeout=DEFAULT_TIMEOUT
            )
            results.append(result)
            
            # Display result summary
            if "error" in result:
                print(f"\nTest {i+1} completed with error: {result['error']}")
            else:
                try:
                    status = result.get("status", "UNKNOWN")
                    print(f"\nTest {i+1} completed with status: {status}")
                    
                    if "final_sql" in result:
                        print(f"\nFinal SQL Query:")
                        print(result["final_sql"])
                except Exception:
                    print(f"\nTest {i+1} completed but couldn't parse final output")
            
            print(f"\n{'-'*40}")
            print(f"Test {i+1} completed")
            print(f"{'-'*40}\n")
        except Exception as e:
            print(f"Error in test {i+1}: {str(e)}")
            results.append({"error": str(e)})
    
    print("\nAll tests completed!")
    return results

In [ ]:
# Execute the step-by-step version
await run_text_to_sql_step_by_step()

In [ ]:
# Uncomment to execute the combined version (more concise output)
# await run_complete_text_to_sql()

In [ ]:
# Uncomment to run all test cases
# results = await run_all_tests()

In [11]:
# Execute each step individually with proper error handling and timeout management
try:
    print("\n" + "="*60)
    print("TEXT-TO-SQL PIPELINE EXECUTION")
    print("="*60)
    
    # Step 1: Schema Selection
    print("\n" + "="*50)
    print("STEP 1: SCHEMA SELECTION")
    print("="*50)
    task, selector_content = await select_schema(json.dumps(current_test))
    
    print("\nSchema selection result summary:")
    print("-" * 40)
    # Verify if we got schema information
    if "<database_schema>" in selector_content or "schema_str" in selector_content:
        print("✓ Schema information successfully extracted")
    else:
        print("⚠ Schema information may be missing or malformed")
    # Print a preview of the content
    print("\nPreview:")
    preview = selector_content[:200] + "..." if len(selector_content) > 200 else selector_content
    print(preview)

    # Step 2: SQL Generation
    print("\n" + "="*50)
    print("STEP 2: SQL GENERATION")
    print("="*50)
    decomposer_content, sql = await generate_sql(selector_content, task)
    
    print("\nSQL generation result summary:")
    print("-" * 40)
    if sql:
        print(f"✓ SQL query generated ({len(sql)} chars)")
        print("\nSQL Query:")
        print(sql[:300] + "..." if len(sql) > 300 else sql)
    else:
        print("⚠ No SQL query was generated")
        print("\nDecomposer output preview:")
        print(decomposer_content[:200] + "..." if len(decomposer_content) > 200 else decomposer_content)

    # Step 3: SQL Refinement
    print("\n" + "="*50)
    print("STEP 3: SQL REFINEMENT")
    print("="*50)
    if not sql:
        print("\n⚠ Skipping refinement because no SQL was generated")
        result = {
            "db_id": task.get('db_id', ''),
            "query": task.get('query', ''),
            "final_output": decomposer_content,
            "error": "No SQL generated",
            "status": "ERROR_NO_SQL"
        }
    else:
        result = await refine_sql(decomposer_content, sql, task)
    
    # Display final result
    print("\n" + "="*50)
    print("FINAL RESULT")
    print("="*50)
    
    # Display database and query info
    print(f"Database: {result.get('db_id', '')}")
    print(f"Query: {result.get('query', '')}")
    
    # Try to parse and format the result
    try:
        final_output = parse_json(result.get("final_output", "{}"))
        
        # Display status
        status = final_output.get("status") or result.get("status", "UNKNOWN")
        print(f"\nExecution Status: {status}")
        
        # Display final SQL
        final_sql = final_output.get("final_sql") or result.get("final_sql", "")
        if final_sql:
            print(f"\nFinal SQL Query:")
            print(final_sql)
        else:
            print("\n⚠ No final SQL query available")
            
        # Display execution result if available
        if "execution_result" in final_output:
            print("\nExecution Result:")
            exec_result = final_output["execution_result"]
            if isinstance(exec_result, dict):
                for key, value in exec_result.items():
                    print(f"  {key}: {value}")
            else:
                print(exec_result)
                
        # Display any errors
        if "error" in final_output or "error" in result:
            error = final_output.get("error") or result.get("error", "")
            print(f"\nError: {error}")
            
    except Exception as e:
        print(f"\nError parsing final result: {str(e)}")
        if "final_sql" in result:
            print(f"\nFinal SQL:\n{result['final_sql']}")
        print("\nRaw output preview:")
        raw_output = result.get("final_output", "")
        print(raw_output[:500] + "..." if len(raw_output) > 500 else raw_output)
        
except Exception as e:
    print(f"\nCRITICAL ERROR IN EXECUTION: {str(e)}")
    import traceback
    print(traceback.format_exc())


TEXT-TO-SQL PIPELINE EXECUTION

STEP 1: SCHEMA SELECTION
[Step 1] Starting schema selection for database 'california_schools'
[Step 1] Query: List school names of charter schools with an SAT excellence rate over the average.
[Step 1] Evidence: Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr
[Step 1] Requesting schema selection (timeout: 120s)
[Tool] Loading schema for database: california_schools
[Step 1] Schema selected successfully (took 1.5s)

Schema selection result summary:
----------------------------------------
✓ Schema information successfully extracted

Preview:
{"db_id": "california_schools", "table_count": 3, "total_column_count": 89, "avg_column_count": 29, "is_complex_schema": true, "full_schema_str": "<database_schema>\n  <table name=\"frpm\">\n    <colu...

STEP 2: SQL GENERATION

[Step 2] Starting SQL generation
[Step 2] Query: List school names of charter schools with an SAT excellence rate over the ave