In [1]:
import asyncio
import json
import os
import re
import time
from typing import Sequence, Dict, Any, Tuple, List, Optional

# Autogen imports
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_core import CancellationToken
from autogen_agentchat.messages import TextMessage

# Import from project modules
from const import (
    SELECTOR_NAME, DECOMPOSER_NAME, REFINER_NAME, SYSTEM_NAME,
    selector_template, decompose_template_bird, refiner_template
)
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

# --- Constants ---
# Default value of the `timeout` argument of the workflow step functions below.
DEFAULT_TIMEOUT = 120  # seconds
# Default value of `max_refinement_attempts` in refine_sql.
MAX_REFINEMENT_ATTEMPTS = 3  # Maximum number of refinement attempts for SQL
# Dataset location; a relative path, so it resolves against the kernel's working directory.
BIRD_DATA_PATH = "../data/bird"
# Table metadata file handed to SchemaManager.
BIRD_TABLES_JSON_PATH = os.path.join(BIRD_DATA_PATH, "dev_tables.json")
# Dataset identifier passed to both SchemaManager and SQLExecutor.
DATASET_NAME = "bird"

In [2]:
# Shared SchemaManager instance used by the schema tools defined below.
schema_manager = SchemaManager(
    data_path=BIRD_DATA_PATH,
    tables_json_path=BIRD_TABLES_JSON_PATH,
    dataset_name=DATASET_NAME,
    lazy=False  # Lazy loading is disabled: the cell output shows all database info being loaded here
)

# Shared SQLExecutor instance used by the execute_sql tool defined below.
sql_executor = SQLExecutor(
    data_path=BIRD_DATA_PATH,
    dataset_name=DATASET_NAME
)

# Tool implementation for schema selection and management
async def get_initial_database_schema(db_id: str) -> str:
    """
    Retrieves the full database schema information for a given database.
    
    Args:
        db_id: The database identifier
        
    Returns:
        JSON string with full schema information
    """
    # NOTE(review): this function is registered as an agent tool; the docstring
    # above is presumably surfaced to the model as the tool description, so it
    # is kept verbatim — confirm before rewording it.
    print(f"[Tool] Loading schema for database: {db_id}")

    # Populate the SchemaManager cache for this database on first use.
    already_cached = db_id in schema_manager.db2infos
    if not already_cached:
        schema_manager.db2infos[db_id] = schema_manager._load_single_db_info(db_id)

    # An empty/missing stats entry is reported to the caller as a JSON error.
    stats = schema_manager.db2dbjsons.get(db_id, {})
    if not stats:
        failure = {"error": f"Database '{db_id}' not found"}
        return json.dumps(failure)

    # Complexity flag tells the selector whether pruning is worthwhile.
    complexity_flag = schema_manager._is_complex_schema(db_id)

    # Empty pruning rules -> description of the complete, unpruned schema.
    schema_text, fk_text, _ignored = schema_manager.generate_schema_description(
        db_id, {}, use_gold_schema=False
    )

    # Assemble the response in the same key order callers have always seen.
    payload = {"db_id": db_id}
    for stat_name in ("table_count", "total_column_count", "avg_column_count"):
        payload[stat_name] = stats.get(stat_name, 0)
    payload["is_complex_schema"] = complexity_flag
    payload["full_schema_str"] = schema_text
    payload["full_fk_str"] = fk_text
    return json.dumps(payload)

async def prune_database_schema(db_id: str, pruning_rules: Dict) -> str:
    """
    Applies pruning rules to a database schema.
    
    Args:
        db_id: The database identifier
        pruning_rules: Dictionary with tables and columns to keep
        
    Returns:
        JSON string with pruned schema
    """
    # NOTE(review): registered as an agent tool; the docstring above is
    # presumably used as the tool description, so it is kept verbatim — confirm.
    print(f"[Tool] Pruning schema for database {db_id}")

    # Ask the SchemaManager for a description restricted by the given rules.
    description = schema_manager.generate_schema_description(
        db_id, pruning_rules, use_gold_schema=False
    )
    pruned_schema_text, pruned_fk_text, kept_schema = description

    # Echo the rules back alongside the pruned description.
    response = {
        "db_id": db_id,
        "pruning_applied": True,
        "pruning_rules": pruning_rules,
        "pruned_schema_str": pruned_schema_text,
        "pruned_fk_str": pruned_fk_text,
        "tables_columns_kept": kept_schema,
    }
    return json.dumps(response)

# Tool implementation for SQL execution
async def execute_sql(sql: str, db_id: str) -> str:
    """
    Executes a SQL query on the specified database.
    
    Args:
        sql: The SQL query to execute
        db_id: The database identifier
        
    Returns:
        JSON string with execution results
    """
    # NOTE(review): registered as an agent tool; the docstring above is
    # presumably used as the tool description, so it is kept verbatim — confirm.
    print(f"[Tool] Executing SQL on database {db_id}: {sql[:100]}...")

    # Run the query through the executor's guarded entry point.
    outcome = sql_executor.safe_execute(sql, db_id)

    # Annotate the raw outcome with the executor's own validity verdict.
    verdict = sql_executor.is_valid_result(outcome)
    valid, explanation = verdict
    for field, value in (("is_valid_result", valid), ("validation_message", explanation)):
        outcome[field] = value

    return json.dumps(outcome)

load json file from ../data/bird/dev_tables.json

Loading all database info...
Found 11 databases in bird dataset


In [3]:

# --- Helper Functions ---
def parse_json(text: str) -> Dict:
    """
    Best-effort extraction of a JSON object from free-form text.

    Candidates are tried in order, and a candidate that fails to decode never
    prevents the later ones from being tried:
      1. the whole (stripped) text, when it looks like a single object;
      2. the span from the first '{' to the last '}';
      3. the contents of each fenced code block (``` or ```json);
      4. the first decodable object found when scanning each '{' left to right,
         which recovers an object followed by trailing text containing braces.

    Args:
        text: Text that might contain JSON

    Returns:
        The first candidate that decodes to a dict, or an empty dict if the
        input is empty, not a string, or contains no decodable JSON object.
        This function never raises.
    """
    if not text or not isinstance(text, str):
        return {}

    def _decode_object(candidate: str) -> Optional[Dict]:
        """Decode `candidate`; return it only if it is a JSON object."""
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return parsed if isinstance(parsed, dict) else None

    # 1) The text itself is the object.
    stripped = text.strip()
    if stripped.startswith('{') and stripped.endswith('}'):
        parsed = _decode_object(stripped)
        if parsed is not None:
            return parsed

    # 2) Greedy span between the outermost braces.
    span = re.search(r'{.*}', text, re.DOTALL)
    if span:
        parsed = _decode_object(span.group())
        if parsed is not None:
            return parsed

    # 3) Fenced code blocks, optionally tagged as json.
    for block in re.findall(r'```(?:json)?\s*(.*?)\s*```', text, re.DOTALL):
        block = block.strip()
        if block.startswith('{') and block.endswith('}'):
            parsed = _decode_object(block)
            if parsed is not None:
                return parsed

    # 4) Scan for the first position where a complete object can be decoded.
    #    Outer objects are tried before nested ones because we go left to right.
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            parsed, _end = decoder.raw_decode(text, start)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        start = text.find('{', start + 1)

    return {}

In [4]:
# Function to extract SQL from text using regex and JSON parsing
def extract_sql_from_text(text: str) -> str:
    """
    Extract a SQL query from model output.

    Lookup order:
      1. A JSON object in the text with a 'sql' or 'final_sql' key. If either
         key is present, only those values are considered (first non-empty
         string wins); the regex fallbacks are not run on structured output,
         so a null/empty value yields "" rather than a stray match.
      2. A fenced ```sql block.
      3. A generic fenced block whose body starts with SELECT or WITH.
      4. An unfenced statement starting at the earliest SELECT keyword or
         CTE-shaped "WITH <name> AS", running to the first ';' (included)
         or to the end of the text.

    Args:
        text: Text that might contain SQL

    Returns:
        The extracted SQL, stripped of surrounding whitespace and code-fence
        markers, or an empty string if nothing SQL-like was found.
    """
    if not text or not isinstance(text, str):
        return ""

    # 1) Structured output takes precedence over pattern matching.
    data = parse_json(text)
    present_keys = [key for key in ('sql', 'final_sql') if key in data]
    if present_keys:
        for key in present_keys:
            candidate = data[key]
            if isinstance(candidate, str) and candidate.strip():
                return candidate.strip()
        return ""

    flags = re.DOTALL | re.IGNORECASE

    # 2) + 3) Fenced blocks. Both patterns capture only the body, so the fence
    # markers never leak into the result. `\b` keeps ```sqlite from being read
    # as ```sql followed by "ite".
    fenced_patterns = (
        r'```sql\b\s*(.*?)\s*```',
        r'```\s*((?:SELECT|WITH)\b.*?)\s*```',
    )
    for pattern in fenced_patterns:
        for match in re.finditer(pattern, text, flags):
            body = match.group(1).strip()
            if body:
                return body

    # 4) Unfenced statement. A single alternation makes the search pick
    # whichever keyword occurs first, so a CTE keeps its WITH clause. WITH must
    # look like a CTE header ("WITH name AS") to avoid matching English prose,
    # and SELECT must be a whole word (not "selected").
    bare_statement = (
        r'\b((?:WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^()]*\)\s*)?AS\b|SELECT\b)'
        r'.*?(?:;|$))'
    )
    match = re.search(bare_statement, text, flags)
    if match:
        # Drop anything from a closing code fence onwards (e.g. the statement
        # sat inside a block tagged with an unrecognised language).
        return match.group(1).split('```', 1)[0].strip()

    return ""

In [5]:
# Initialize model client
# One client instance is shared by the selector, decomposer and refiner agents below.
# NOTE(review): no API key is passed explicitly — presumably it is picked up from
# the environment (e.g. OPENAI_API_KEY); confirm before running on a fresh machine.
model_client = OpenAIChatCompletionClient(model="gpt-4o")

# Create Schema Selector Agent
# Pipeline stage 1. It is given the two schema tools and is instructed to answer
# with a JSON object carrying `schema_str` / `fk_str`, which generate_sql() looks for.
# The system prompt is an f-string: {selector_template} is interpolated from const,
# while the doubled braces {{ }} produce literal braces in the JSON example.
# NOTE(review): the prompt asks for two sequential tool calls followed by a final
# JSON answer. Whether a single on_messages() call performs that many tool rounds
# and then a text reply depends on AssistantAgent options not set here
# (e.g. reflect_on_tool_use) — confirm against the installed autogen_agentchat version.
selector_agent = AssistantAgent(
    name=SELECTOR_NAME,
    model_client=model_client,
    system_message=f"""You are a Database Schema Selector specialized in analyzing database schemas for text-to-SQL tasks.

Your job is to help prune large database schemas to focus on the relevant tables and columns for a given query.

TASK OVERVIEW:
1. You will receive a task with database ID, query, and evidence
2. Use the 'get_initial_database_schema' tool to retrieve the full schema
3. Analyze the schema complexity and relevance to the query
4. For complex schemas, determine which tables and columns are relevant
5. Use the 'prune_database_schema' tool to generate a focused schema
6. Return the processed schema information for the next agent

WHEN ANALYZING THE SCHEMA:
- Study the database table structure and relationships
- Identify tables directly mentioned or implied in the query
- Consider foreign key relationships that might be needed
- Follow the pruning guidelines in the following template:
{selector_template}

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "pruning_applied": true/false,
  "schema_str": "<schema_description>",
  "fk_str": "<foreign_key_information>"
}}

Remember that high-quality schema selection improves the accuracy of SQL generation.""",
    tools=[get_initial_database_schema, prune_database_schema],
)

# Create Decomposer Agent
# Pipeline stage 2. It has no tools: it receives the selector's output as its
# message and is instructed to reply with a JSON object whose `sql` field holds the
# generated query (extract_sql_from_text() reads that field).
# The system prompt is an f-string: {decompose_template_bird} is interpolated from
# const, while the doubled braces {{ }} produce literal braces in the JSON example.
decomposer_agent = AssistantAgent(
    name=DECOMPOSER_NAME,
    model_client=model_client,
    system_message=f"""You are a Query Decomposer specialized in converting natural language questions into SQL for the BIRD dataset.

Your job is to analyze a natural language query and relevant database schema, then generate the appropriate SQL query.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, and schema information
2. Study the database schema, focusing on tables, columns, and relationships
3. For complex queries, break down the problem into logical steps
4. Generate a clear, efficient SQL query that answers the question
5. Follow the specific query decomposition template for BIRD:
{decompose_template_bird}

IMPORTANT CONSIDERATIONS:
- BIRD queries often require domain knowledge and multiple steps
- Carefully use the evidence provided to understand domain-specific concepts
- Apply type conversion when comparing numeric data (CAST AS REAL/INT)
- Ensure proper handling of NULL values
- Use table aliases (T1, T2, etc.) for clarity, especially in JOINs
- Always use valid SQLite syntax

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "sql": "<generated_sql_query>",
  "decomposition": [
    "step1_description", 
    "step2_description",
    ...
  ]
}}

Your goal is to generate SQL that will execute correctly and return the precise information requested.""",
    tools=[],  # SQL generation is the primary LLM task
)

# Create Refiner Agent
# Pipeline stage 3. It is given the execute_sql tool and is instructed to reply with
# a JSON object whose `status` and `final_sql` fields are read by refine_sql().
# The prompt lists three status values; refine_sql() additionally treats
# 'EXECUTION_CONFIRMED' as a success status even though it is not listed here.
# The system prompt is an f-string: {refiner_template} is interpolated from const,
# while the doubled braces {{ }} produce literal braces in the JSON example.
refiner_agent = AssistantAgent(
    name=REFINER_NAME,
    model_client=model_client,
    system_message=f"""You are an SQL Refiner specializing in executing and fixing SQL queries for the BIRD dataset.

Your job is to test SQL queries against the database, identify errors, and refine them until they execute successfully.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, schema, and SQL
2. Use the 'execute_sql' tool to run the SQL against the database
3. Analyze execution results or errors
4. For errors, refine the SQL using the template:
{refiner_template}
5. For successful execution, validate the results are appropriate for the original query

IMPORTANT CONSIDERATIONS:
- BIRD databases have specific requirements for valid results:
  - Results should not be empty (unless that's the expected answer)
  - Results should not contain NULL values without justification
  - Results should match the expected types and formats
- Focus on SQLite-specific syntax and behaviors
- Pay special attention to:
  - Table and column name quoting (use backticks)
  - Type conversions (CAST AS)
  - JOIN conditions
  - Subquery structure and aliases

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "original_sql": "<original_sql>",
  "final_sql": "<refined_sql>",
  "status": "<EXECUTION_SUCCESSFUL|REFINEMENT_NEEDED|NO_CHANGE_NEEDED>",
  "execution_result": "<execution_result_summary>",
  "refinement_explanation": "<explanation_of_changes>"
}}

Your goal is to ensure the SQL query executes successfully and returns relevant results.""",
    tools=[execute_sql],
)

In [6]:
# Step 1: Schema Selection
async def select_schema(task_json: str, timeout: int = DEFAULT_TIMEOUT):
    """
    First workflow step: schema selection with selector_agent.

    The task is parsed and validated before the agent is contacted, and the
    agent call is bounded by `timeout`.

    Args:
        task_json: JSON string containing the task information; must decode to
            an object with non-empty 'db_id' and 'query' (optional 'evidence')
        timeout: Maximum time to wait for the agent response in seconds
            (None disables the limit)

    Returns:
        Tuple of (task dict, selector response content)

    Raises:
        ValueError: task_json is not valid JSON, is not a JSON object, or lacks
            db_id / query
        asyncio.TimeoutError: the agent did not answer within `timeout` seconds
        Exception: any error raised by the agent call is logged and re-raised
    """
    # Parse task
    try:
        task = json.loads(task_json)
    except json.JSONDecodeError as e:
        print(f"[Step 1] Error: Invalid JSON input: {str(e)}")
        raise ValueError(f"Invalid task JSON: {str(e)}") from e

    # A JSON list/number/string would otherwise fail later with AttributeError.
    if not isinstance(task, dict):
        raise ValueError("Task must be a JSON object containing db_id and query")

    db_id = task.get('db_id', '')
    query = task.get('query', '')
    evidence = task.get('evidence', '')

    if not db_id or not query:
        raise ValueError("Task must contain db_id and query")

    print(f"[Step 1] Starting schema selection for database '{db_id}'")
    print(f"[Step 1] Query: {query}")
    print(f"[Step 1] Evidence: {evidence}")

    # Forward the caller's JSON as-is when it already has the expected shape;
    # otherwise send a normalised object with just the three known fields.
    task_content = task_json
    if '"db_id"' not in task_json or not task_json.strip().startswith('{'):
        task_content = json.dumps({
            "db_id": db_id,
            "query": query,
            "evidence": evidence
        })

    user_message = TextMessage(content=task_content, source="user")

    # The token itself carries no deadline; the deadline is enforced by
    # asyncio.wait_for below, and the token is cancelled on expiry so that any
    # work linked to it is asked to stop as well.
    cancellation_token = CancellationToken()

    start_time = time.time()
    try:
        print(f"[Step 1] Requesting schema selection (timeout: {timeout}s)")
        selector_response = await asyncio.wait_for(
            selector_agent.on_messages([user_message], cancellation_token),
            timeout=timeout,
        )
        selector_content = selector_response.chat_message.content.strip()

        # Verify selector output contains schema information
        if "<database_schema>" not in selector_content and "schema_str" not in selector_content:
            print(f"[Step 1] Warning: Selector response may not contain valid schema information")

        elapsed = time.time() - start_time
        print(f"[Step 1] Schema selected successfully (took {elapsed:.1f}s)")
        return task, selector_content
    except asyncio.TimeoutError:
        cancellation_token.cancel()
        print(f"[Step 1] Error: Schema selection timed out after {timeout}s")
        raise
    except Exception as e:
        print(f"[Step 1] Error in schema selection: {str(e)}")
        raise

In [7]:
# Step 2: SQL Generation
async def generate_sql(selector_content: str, task: dict, timeout: int = DEFAULT_TIMEOUT):
    """
    Second workflow step: SQL generation with decomposer_agent.

    The selector output is forwarded to the decomposer unchanged. Before that,
    it is inspected (JSON 'schema_str' first, then a <database_schema> XML
    block) purely to log whether schema information is present.

    Args:
        selector_content: Output from the schema selection step
        task: Task dictionary containing query information
        timeout: Maximum time to wait for the agent response in seconds
            (None disables the limit)

    Returns:
        Tuple of (decomposer content, extracted SQL); the SQL is an empty
        string when none could be extracted

    Raises:
        asyncio.TimeoutError: the agent did not answer within `timeout` seconds
        Exception: any error raised by the agent call is logged and re-raised
    """
    print(f"\n[Step 2] Starting SQL generation")
    print(f"[Step 2] Query: {task.get('query', '')}")

    # Diagnostic only: find out whether the selector actually delivered a schema.
    # parse_json never raises, so the XML lookup has to be an explicit fallback
    # rather than an exception handler.
    schema_str = ""
    data = parse_json(selector_content)
    if 'schema_str' in data:
        schema_str = data['schema_str']
        print(f"[Step 2] Found schema information in JSON format")
    else:
        schema_match = re.search(r'<database_schema>.*?</database_schema>', selector_content, re.DOTALL)
        if schema_match:
            schema_str = schema_match.group()
            print(f"[Step 2] Found schema information in XML format")

    if not schema_str:
        print(f"[Step 2] Warning: Could not extract schema from selector output")

    # Create proper message object
    selector_message = TextMessage(content=selector_content, source=SELECTOR_NAME)

    # The token carries no deadline; asyncio.wait_for enforces it, and the token
    # is cancelled on expiry so linked work is asked to stop too.
    cancellation_token = CancellationToken()

    start_time = time.time()
    try:
        print(f"[Step 2] Requesting SQL generation (timeout: {timeout}s)")
        decomposer_response = await asyncio.wait_for(
            decomposer_agent.on_messages([selector_message], cancellation_token),
            timeout=timeout,
        )
        decomposer_content = decomposer_response.chat_message.content.strip()

        # Extract SQL from decomposer output
        sql = extract_sql_from_text(decomposer_content)

        elapsed = time.time() - start_time
        print(f"[Step 2] SQL generation completed (took {elapsed:.1f}s)")

        if not sql:
            print(f"[Step 2] Warning: No SQL found in decomposer output")
        else:
            print(f"[Step 2] SQL generated: {sql[:100]}...")

        return decomposer_content, sql
    except asyncio.TimeoutError:
        cancellation_token.cancel()
        print(f"[Step 2] Error: SQL generation timed out after {timeout}s")
        raise
    except Exception as e:
        print(f"[Step 2] Error in SQL generation: {str(e)}")
        raise

In [8]:
# Step 3: SQL Refinement
# NOTE(review): a later cell defines refine_sql again without the `timeout`
# parameter; on Restart & Run All that later definition replaces this one.
async def refine_sql(decomposer_content: str, sql: str, task: dict, 
                    max_refinement_attempts: int = MAX_REFINEMENT_ATTEMPTS,
                    timeout: int = DEFAULT_TIMEOUT):
    """
    Third workflow step: SQL refinement with refiner_agent.

    The refiner is asked to execute the SQL; the loop stops as soon as it
    reports a success status, and otherwise retries up to
    `max_refinement_attempts` times. Each agent call is bounded by `timeout`;
    a timed-out or failed attempt counts as an attempt and is retried with a
    simplified input.

    Args:
        decomposer_content: Output from the SQL generation step
        sql: Initial SQL query extracted from decomposer_content
        task: Task dictionary containing query information
        max_refinement_attempts: Maximum number of refinement iterations
        timeout: Maximum time to wait for each response in seconds
            (None disables the limit)

    Returns:
        Refinement results dictionary with db_id, query, evidence,
        final_output, refinement_attempts, status and final_sql. When `sql` is
        empty the dictionary instead carries an 'error' entry and status
        'ERROR_NO_SQL'. `refinement_attempts` counts attempts that did NOT end
        in a success status, so a first-try success reports 0.
    """
    # Handle case where no SQL was generated
    if not sql:
        print(f"[Step 3] No SQL to refine, skipping refinement step")
        return {
            "db_id": task.get('db_id', ''),
            "query": task.get('query', ''),
            "evidence": task.get('evidence', ''),
            "final_output": decomposer_content,
            "error": "No SQL generated",
            "refinement_attempts": 0,
            "status": "ERROR_NO_SQL"
        }

    print(f"\n[Step 3] Starting SQL refinement")
    print(f"[Step 3] Initial SQL: {sql[:100]}...")

    # Schema context for the refiner: prefer an XML block in the decomposer
    # output, fall back to a JSON 'schema_str' field. Failure here is not fatal.
    schema_info = ""
    try:
        schema_match = re.search(r'<database_schema>.*?</database_schema>', decomposer_content, re.DOTALL)
        if schema_match:
            schema_info = schema_match.group()

        if not schema_info:
            data = parse_json(decomposer_content)
            if 'schema_str' in data:
                schema_info = data['schema_str']
    except Exception as e:
        print(f"[Step 3] Warning: Could not extract schema information: {str(e)}")

    # Create a structured input for the refiner with all necessary context
    refiner_input = {
        "db_id": task.get('db_id', ''),
        "query": task.get('query', ''),
        "evidence": task.get('evidence', ''),
        "sql": sql,
        "schema_info": schema_info,  # Include schema information when available
        "instructions": "Please execute this SQL against the database and return the results in a structured JSON format with the fields: status, final_sql, and execution_result."
    }

    # `refiner_content` doubles as "next message to send" and "last response
    # received": after a successful call it holds the refiner's reply, which is
    # also what gets sent back on the next attempt unless a branch replaces it.
    refiner_content = json.dumps(refiner_input)
    last_sql = sql
    refinement_attempts = 0

    while refinement_attempts < max_refinement_attempts:
        attempt_number = refinement_attempts + 1
        print(f"[Step 3] Starting refinement attempt {attempt_number}/{max_refinement_attempts}")

        # Fresh token per attempt. It carries no deadline of its own; the
        # deadline is enforced by asyncio.wait_for below.
        cancellation_token = CancellationToken()

        # First message comes from the decomposer; follow-ups are attributed to
        # the refiner itself.
        message = TextMessage(
            content=refiner_content, 
            source=DECOMPOSER_NAME if refinement_attempts == 0 else REFINER_NAME
        )

        start_time = time.time()
        try:
            print(f"[Step 3] Requesting refinement (timeout: {timeout}s)")
            refiner_response = await asyncio.wait_for(
                refiner_agent.on_messages([message], cancellation_token),
                timeout=timeout,
            )
            refiner_content = refiner_response.chat_message.content.strip()

            elapsed = time.time() - start_time
            print(f"[Step 3] Refinement response received (took {elapsed:.1f}s)")

            # Parse refiner output (with fallback)
            data = {}
            try:
                data = parse_json(refiner_content)
            except Exception as e:
                print(f"[Step 3] Warning: Could not parse refiner response as JSON: {str(e)}")
                # Create a simple data structure if parsing failed
                data = {"error": str(e)}

            status = data.get('status', '')
            print(f"[Step 3] Refinement attempt {attempt_number}, status: {status}")

            # Check for termination conditions
            if status in ['EXECUTION_SUCCESSFUL', 'NO_CHANGE_NEEDED', 'EXECUTION_CONFIRMED']:
                print(f"[Step 3] SQL execution successful: {status}")
                break

            # Extract the refined SQL
            new_sql = extract_sql_from_text(refiner_content)

            if new_sql:
                if new_sql == last_sql:
                    # SQL didn't change
                    print(f"[Step 3] SQL unchanged in attempt {attempt_number}")
                else:
                    # SQL changed
                    print(f"[Step 3] New SQL detected: {new_sql[:100]}...")
                    last_sql = new_sql
            else:
                # No SQL found in response, force a better input format for next attempt
                print(f"[Step 3] No SQL found in refinement output, attempt {attempt_number}")

                refiner_input = {
                    "db_id": task.get('db_id', ''),
                    "query": task.get('query', ''),
                    "evidence": task.get('evidence', ''),
                    "sql": last_sql,
                    "schema_info": schema_info,
                    "refiner_instructions": "Please execute this SQL query against the database and provide the following in your response as a JSON object:\n - status: 'EXECUTION_SUCCESSFUL', 'REFINEMENT_NEEDED', or 'NO_CHANGE_NEEDED'\n - final_sql: The final SQL query (same as input if no changes needed)\n - execution_result: The result of executing the query",
                    "attempt": attempt_number
                }
                refiner_content = json.dumps(refiner_input)

            # Increment refinement attempts and check if max reached
            refinement_attempts += 1
            if refinement_attempts >= max_refinement_attempts:
                print(f"[Step 3] Max refinements ({max_refinement_attempts}) reached")
                break

        except asyncio.TimeoutError:
            # wait_for has already cancelled the call; also cancel the token so
            # any work linked to it is asked to stop.
            cancellation_token.cancel()
            print(f"[Step 3] Error: Refinement attempt {attempt_number} timed out after {timeout}s")
            refinement_attempts += 1

            # Try a simplified input for next attempt
            refiner_input = {
                "db_id": task.get('db_id', ''),
                "query": task.get('query', ''),
                "sql": last_sql,
                "timeout_error": f"Previous attempt timed out after {timeout}s",
                "instructions": "Please execute this SQL against the database and return the results."
            }
            refiner_content = json.dumps(refiner_input)

        except Exception as e:
            print(f"[Step 3] Error in refinement attempt {attempt_number}: {str(e)}")
            refinement_attempts += 1

            # Try a simplified input for next attempt
            refiner_input = {
                "db_id": task.get('db_id', ''),
                "query": task.get('query', ''),
                "sql": last_sql,
                "error": str(e),
                "instructions": "Please execute this SQL against the database and return the results."
            }
            refiner_content = json.dumps(refiner_input)

    # Extract final SQL and status from refiner output 
    final_sql = extract_sql_from_text(refiner_content) or last_sql

    # Prepare final result with fallback
    try:
        data = parse_json(refiner_content)
        status = data.get('status', 'UNKNOWN')

        # If data doesn't have a final_sql field but we extracted one, add it
        if final_sql and not data.get('final_sql'):
            data['final_sql'] = final_sql
            refiner_content = json.dumps(data)
    except Exception as e:
        print(f"[Step 3] Error parsing final result: {str(e)}")
        # Create a minimal result if parsing failed
        data = {
            "status": "ERROR_PARSING_RESULT",
            "final_sql": final_sql,
            "error": str(e)
        }
        refiner_content = json.dumps(data)
        status = "ERROR_PARSING_RESULT"

    print(f"[Step 3] Refinement complete - Final status: {status}")
    if final_sql:
        print(f"[Step 3] Final SQL: {final_sql[:100]}...")

    return {
        "db_id": task.get('db_id', ''),
        "query": task.get('query', ''),
        "evidence": task.get('evidence', ''),
        "final_output": refiner_content,
        "refinement_attempts": refinement_attempts,
        "status": status,
        "final_sql": final_sql
    }

In [9]:
# Step 3: SQL Refinement
async def refine_sql(decomposer_content: str, sql: str, task: dict, max_refinement_attempts: int = MAX_REFINEMENT_ATTEMPTS) -> dict:
    """
    Third workflow step: SQL refinement with refiner_agent.

    Sends the candidate SQL (plus any schema text recoverable from the
    decomposer output) to ``refiner_agent`` and repeats until the agent reports
    one of the success statuses or ``max_refinement_attempts`` is used up.

    Args:
        decomposer_content: Output from the SQL generation step
        sql: Initial SQL query extracted from decomposer_content
        task: Task dictionary containing query information; the keys
            'db_id', 'query' and 'evidence' are read, each defaulting to ''
        max_refinement_attempts: Maximum number of refinement iterations

    Returns:
        Refinement results dictionary. Every return path carries the keys
        'db_id', 'query', 'evidence', 'final_output', 'refinement_attempts',
        'status' and 'final_sql'; the no-SQL path additionally carries 'error'.
        'refinement_attempts' counts the attempts that did not end in a
        success status, so a first-try success reports 0.
    """
    db_id = task.get('db_id', '')
    query = task.get('query', '')
    evidence = task.get('evidence', '')

    # Nothing to refine: return a result with the same shape as the normal
    # path so callers can read 'status' / 'final_sql' unconditionally.
    if not sql:
        print("[Step 3] No SQL to refine, skipping refinement step")
        return {
            "db_id": db_id,
            "query": query,
            "evidence": evidence,
            "final_output": decomposer_content,
            "error": "No SQL generated",
            "refinement_attempts": 0,
            "status": "ERROR_NO_SQL",
            "final_sql": ""
        }

    # Create cancellation token (only needed once we know the agent will run)
    cancellation_token = CancellationToken()

    print("\n[Step 3] Starting SQL refinement")
    print(f"[Step 3] Initial SQL: {sql[:100]}...")

    # Extract schema information from decomposer content: first as an inline
    # <database_schema> block, then as a 'schema_str' field of its JSON form.
    schema_info = ""
    try:
        schema_match = re.search(r'<database_schema>.*?</database_schema>', decomposer_content, re.DOTALL)
        if schema_match:
            schema_info = schema_match.group()

        if not schema_info:
            data = parse_json(decomposer_content)
            if 'schema_str' in data:
                schema_info = data['schema_str']
    except Exception as e:
        # Schema context is optional; continue without it.
        print(f"[Step 3] Warning: Could not extract schema information: {str(e)}")

    # Create a structured input for the refiner with all necessary context
    refiner_content = json.dumps({
        "db_id": db_id,
        "query": query,
        "evidence": evidence,
        "sql": sql,
        "schema_info": schema_info  # Include schema information when available
    })
    last_sql = sql
    refinement_attempts = 0

    while refinement_attempts < max_refinement_attempts:
        attempt_number = refinement_attempts + 1
        print(f"[Step 3] Starting refinement attempt {attempt_number}/{max_refinement_attempts}")

        # The first message is attributed to the decomposer, later ones to the refiner
        message = TextMessage(
            content=refiner_content,
            source=DECOMPOSER_NAME if refinement_attempts == 0 else REFINER_NAME
        )

        try:
            # Execute refinement
            refiner_response = await refiner_agent.on_messages([message], cancellation_token)
            refiner_content = refiner_response.chat_message.content.strip()

            # Parse refiner output
            data = parse_json(refiner_content)
            status = data.get('status', '')
            print(f"[Step 3] Refinement attempt {attempt_number}, status: {status}")

            # Check for termination conditions
            if status in ['EXECUTION_SUCCESSFUL', 'NO_CHANGE_NEEDED', 'EXECUTION_CONFIRMED']:
                print(f"[Step 3] SQL execution successful: {status}")
                break

            # Extract the refined SQL
            new_sql = extract_sql_from_text(refiner_content)

            if new_sql:
                if new_sql == last_sql:
                    # SQL didn't change
                    print(f"[Step 3] SQL unchanged in attempt {attempt_number}")
                else:
                    # SQL changed
                    print(f"[Step 3] New SQL detected: {new_sql[:100]}...")
                    last_sql = new_sql
            else:
                # No SQL found in response
                print(f"[Step 3] No SQL found in refinement output, attempt {attempt_number}")

                # Build a more explicit prompt only if another attempt will
                # actually be made. On the final attempt the refiner's own
                # reply must survive so the reported status/final_output
                # reflect what the agent said rather than an unsent prompt.
                if attempt_number < max_refinement_attempts:
                    refiner_content = json.dumps({
                        "db_id": db_id,
                        "query": query,
                        "evidence": evidence,
                        "sql": last_sql,
                        "schema_info": schema_info,
                        "refiner_instructions": "Please execute this SQL and provide a status and final_sql in your response JSON."
                    })

            # Increment refinement attempts and check if max reached
            refinement_attempts += 1
            if refinement_attempts >= max_refinement_attempts:
                print(f"[Step 3] Max refinements ({max_refinement_attempts}) reached")
                break

        except Exception as e:
            print(f"[Step 3] Error in refinement attempt {attempt_number}: {str(e)}")
            refinement_attempts += 1
            # In case of error, simplify the input for the next attempt. The
            # error text travels in this payload, so if no attempt remains it
            # still surfaces through the 'error' field of final_output.
            refiner_content = json.dumps({
                "db_id": db_id,
                "query": query,
                "sql": last_sql,
                "error": str(e),
                "instructions": "Please execute this SQL against the database and return the results in JSON format."
            })

    # Extract final SQL and status from refiner output
    final_sql = extract_sql_from_text(refiner_content) or last_sql

    try:
        data = parse_json(refiner_content)
        status = data.get('status', 'UNKNOWN')

        # If data doesn't have a final_sql field but we extracted one, add it to create a proper result
        if final_sql and not data.get('final_sql'):
            data['final_sql'] = final_sql
            refiner_content = json.dumps(data)
    except Exception as e:
        print(f"[Step 3] Error parsing final result: {str(e)}")
        # Create a minimal result if parsing failed
        data = {
            "status": "ERROR_PARSING_RESULT",
            "final_sql": final_sql,
            "error": str(e)
        }
        refiner_content = json.dumps(data)
        status = "ERROR_PARSING_RESULT"

    print(f"[Step 3] Refinement complete - Final status: {status}")
    if final_sql:
        print(f"[Step 3] Final SQL: {final_sql[:100]}...")

    return {
        "db_id": db_id,
        "query": query,
        "evidence": evidence,
        "final_output": refiner_content,
        "refinement_attempts": refinement_attempts,
        "status": status,
        "final_sql": final_sql
    }

In [10]:
# Test cases for the BIRD dataset. Each case carries the target database id,
# the natural-language question, and the evidence hint that accompanies it.
# NOTE(review): confirm every db_id below actually exists under BIRD_DATA_PATH
# before pointing test_idx at it.
bird_test_cases = [
    {"db_id": db_id, "query": question, "evidence": hint}
    for db_id, question, hint in (
        # Case 1 - excellence rate calculation (basic join with aggregation)
        (
            "california_schools",
            "List school names of charter schools with an SAT excellence rate over the average.",
            "Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr",
        ),
        # Case 2 - multi-table query with numeric conditions (multiple joins)
        (
            "game_injury",
            "Show the names of players who have been injured for more than 3 matches in the 2010 season.",
            "Season info is in the game table with year 2010; injury severity is measured by the number of matches a player misses.",
        ),
        # Case 3 - complex aggregation with grouping
        (
            "formula_1",
            "What is the name of the driver who has won the most races in rainy conditions?",
            "Weather conditions are recorded in the races table; winner information is in the results table.",
        ),
        # Case 4 - temporal query with date handling
        (
            "loan_data",
            "Find the customer with the highest total payment amount for loans taken in the first quarter of 2011.",
            "First quarter means January to March (months 1-3); loan dates are stored in ISO format (YYYY-MM-DD).",
        ),
    )
]

# Zero-based index (0-3) of the case to run.
test_idx = 0
current_test = bird_test_cases[test_idx]

# Echo the selection so the cell output records which case was chosen.
print(
    f"Selected test case {test_idx + 1}:",
    f"Database: {current_test['db_id']}",
    f"Query: {current_test['query']}",
    f"Evidence: {current_test['evidence']}",
    sep="\n",
)

Selected test case 1:
Database: california_schools
Query: List school names of charter schools with an SAT excellence rate over the average.
Evidence: Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr


In [11]:
# Execute each step individually with proper error handling and timeout management
#
# Runs the three pipeline stages (schema selection -> SQL generation -> SQL
# refinement) on `current_test` and prints a human-readable summary after each.
# The cell uses top-level `await`, so it must run inside a notebook/IPython
# kernel (or be wrapped in an async function when moved to a script).
# Any exception escaping a stage is caught by the outer handler at the bottom,
# which prints the message and the full traceback instead of aborting the run.
try:
    print("\n" + "="*60)
    print("TEXT-TO-SQL PIPELINE EXECUTION")
    print("="*60)
    
    # Step 1: Schema Selection
    # select_schema receives the test case serialized as JSON and yields a
    # (task, selector_content) pair that the later steps consume.
    print("\n" + "="*50)
    print("STEP 1: SCHEMA SELECTION")
    print("="*50)
    task, selector_content = await select_schema(json.dumps(current_test))
    
    print("\nSchema selection result summary:")
    print("-" * 40)
    # Verify if we got schema information
    # (substring heuristic only: looks for a schema tag or a schema_str field)
    if "<database_schema>" in selector_content or "schema_str" in selector_content:
        print("✓ Schema information successfully extracted")
    else:
        print("⚠ Schema information may be missing or malformed")
    # Print a preview of the content
    print("\nPreview:")
    # Truncate to 200 characters so the cell output stays readable
    preview = selector_content[:200] + "..." if len(selector_content) > 200 else selector_content
    print(preview)

    # Step 2: SQL Generation
    # generate_sql yields the raw decomposer output together with the SQL it
    # contains; `sql` is falsy when nothing usable was produced.
    print("\n" + "="*50)
    print("STEP 2: SQL GENERATION")
    print("="*50)
    decomposer_content, sql = await generate_sql(selector_content, task)
    
    print("\nSQL generation result summary:")
    print("-" * 40)
    if sql:
        print(f"✓ SQL query generated ({len(sql)} chars)")
        print("\nSQL Query:")
        print(sql[:300] + "..." if len(sql) > 300 else sql)
    else:
        print("⚠ No SQL query was generated")
        print("\nDecomposer output preview:")
        print(decomposer_content[:200] + "..." if len(decomposer_content) > 200 else decomposer_content)

    # Step 3: SQL Refinement
    print("\n" + "="*50)
    print("STEP 3: SQL REFINEMENT")
    print("="*50)
    if not sql:
        # refine_sql has its own empty-SQL guard; this branch builds the result
        # here instead and tags it with an explicit ERROR_NO_SQL status.
        print("\n⚠ Skipping refinement because no SQL was generated")
        result = {
            "db_id": task.get('db_id', ''),
            "query": task.get('query', ''),
            "final_output": decomposer_content,
            "error": "No SQL generated",
            "status": "ERROR_NO_SQL"
        }
    else:
        result = await refine_sql(decomposer_content, sql, task)
    
    # Display final result
    print("\n" + "="*50)
    print("FINAL RESULT")
    print("="*50)
    
    # Display database and query info
    print(f"Database: {result.get('db_id', '')}")
    print(f"Query: {result.get('query', '')}")
    
    # Try to parse and format the result
    # `final_output` is a string; parse_json (defined elsewhere in the notebook)
    # is presumed to return a dict here - a parse failure or a non-dict value
    # falls through to the except branch below.
    try:
        final_output = parse_json(result.get("final_output", "{}"))
        
        # Display status
        # Values inside the parsed output win; the outer result is the fallback.
        status = final_output.get("status") or result.get("status", "UNKNOWN")
        print(f"\nExecution Status: {status}")
        
        # Display final SQL
        final_sql = final_output.get("final_sql") or result.get("final_sql", "")
        if final_sql:
            print(f"\nFinal SQL Query:")
            print(final_sql)
        else:
            print("\n⚠ No final SQL query available")
            
        # Display execution result if available
        if "execution_result" in final_output:
            print("\nExecution Result:")
            exec_result = final_output["execution_result"]
            # Dict results are printed one key per line; anything else verbatim
            if isinstance(exec_result, dict):
                for key, value in exec_result.items():
                    print(f"  {key}: {value}")
            else:
                print(exec_result)
                
        # Display any errors
        if "error" in final_output or "error" in result:
            error = final_output.get("error") or result.get("error", "")
            print(f"\nError: {error}")
            
    except Exception as e:
        # Fallback display: show whatever is available without structured parsing
        print(f"\nError parsing final result: {str(e)}")
        if "final_sql" in result:
            print(f"\nFinal SQL:\n{result['final_sql']}")
        print("\nRaw output preview:")
        raw_output = result.get("final_output", "")
        print(raw_output[:500] + "..." if len(raw_output) > 500 else raw_output)
        
except Exception as e:
    # Last-resort handler for failures in any stage above
    print(f"\nCRITICAL ERROR IN EXECUTION: {str(e)}")
    # Imported here rather than in the notebook's import cell; only needed on failure
    import traceback
    print(traceback.format_exc())


TEXT-TO-SQL PIPELINE EXECUTION

STEP 1: SCHEMA SELECTION
[Step 1] Starting schema selection for database 'california_schools'
[Step 1] Query: List school names of charter schools with an SAT excellence rate over the average.
[Step 1] Evidence: Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr
[Step 1] Requesting schema selection (timeout: 120s)
[Tool] Loading schema for database: california_schools
[Step 1] Schema selected successfully (took 1.5s)

Schema selection result summary:
----------------------------------------
✓ Schema information successfully extracted

Preview:
{"db_id": "california_schools", "table_count": 3, "total_column_count": 89, "avg_column_count": 29, "is_complex_schema": true, "full_schema_str": "<database_schema>\n  <table name=\"frpm\">\n    <colu...

STEP 2: SQL GENERATION

[Step 2] Starting SQL generation
[Step 2] Query: List school names of charter schools with an SAT excellence rate over the ave