In [1]:
import asyncio
import json
import os
import re
from typing import Sequence, Dict, Any, Tuple, List, Optional

# Autogen imports
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.ui import Console
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage

# Import from project modules
from const import (
    SELECTOR_NAME, DECOMPOSER_NAME, REFINER_NAME, SYSTEM_NAME,
    selector_template, decompose_template_bird, refiner_template
)
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

# --- Constants ---
MAX_REFINEMENT_ATTEMPTS = 3  # Maximum number of refinement attempts for SQL
BIRD_DATA_PATH = "../data/bird"
BIRD_TABLES_JSON_PATH = os.path.join(BIRD_DATA_PATH, "dev_tables.json")
DATASET_NAME = "bird"

# --- Helper Functions ---
def parse_json(text: str) -> Dict:
    """
    Attempt to parse JSON from text, returning an empty dict if parsing fails.
    
    Args:
        text: Text that might contain JSON
        
    Returns:
        Dictionary of parsed JSON or empty dict if parsing fails
    """
    try:
        # Find JSON-like patterns with regex
        json_pattern = r'{.*}'
        match = re.search(json_pattern, text, re.DOTALL)
        if match:
            json_str = match.group()
            return json.loads(json_str)
        return {}
    except Exception:
        return {}

def extract_sql_from_text(text: str) -> str:
    """
    Extract SQL query from text.
    
    Args:
        text: Text that might contain SQL
        
    Returns:
        Extracted SQL query or empty string if no SQL found
    """
    try:
        # Try to extract SQL from JSON
        data = parse_json(text)
        if 'sql' in data:
            return data['sql']
        if 'final_sql' in data:
            return data['final_sql']
            
        # Try to extract SQL with regex patterns
        sql_patterns = [
            r'```sql\s*(.*?)\s*```',  # SQL in code blocks
            r'SELECT.*?(?:;|$)',      # Simple SELECT statements
            r'WITH.*?(?:;|$)',        # WITH queries
        ]
        
        for pattern in sql_patterns:
            matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
            if matches:
                return matches[0].strip()
                
        return ""
    except Exception:
        return ""

In [2]:
schema_manager = SchemaManager(
    data_path=BIRD_DATA_PATH,
    tables_json_path=BIRD_TABLES_JSON_PATH,
    dataset_name=DATASET_NAME,
    lazy=False  # Use lazy loading for performance
)

# Replace the original SQLExecutor with our patched version
sql_executor = SQLExecutor(
    data_path=BIRD_DATA_PATH,
    dataset_name=DATASET_NAME
)

# Tool implementation for schema selection and management
async def get_initial_database_schema(db_id: str) -> str:
    """
    Retrieves the full database schema information for a given database.
    
    Args:
        db_id: The database identifier
        
    Returns:
        JSON string with full schema information
    """
    print(f"[Tool] Loading schema for database: {db_id}")
    
    # Load database information using SchemaManager
    if db_id not in schema_manager.db2infos:
        schema_manager.db2infos[db_id] = schema_manager._load_single_db_info(db_id)
    
    # Get database information
    db_info = schema_manager.db2dbjsons.get(db_id, {})
    if not db_info:
        return json.dumps({"error": f"Database '{db_id}' not found"})
    
    # Determine if schema is complex enough to need pruning
    is_complex = schema_manager._is_complex_schema(db_id)
    
    # Generate full schema description (without pruning)
    full_schema_str, full_fk_str, _ = schema_manager.generate_schema_description(
        db_id, {}, use_gold_schema=False
    )
    
    # Return schema details
    return json.dumps({
        "db_id": db_id,
        "table_count": db_info.get('table_count', 0),
        "total_column_count": db_info.get('total_column_count', 0),
        "avg_column_count": db_info.get('avg_column_count', 0),
        "is_complex_schema": is_complex,
        "full_schema_str": full_schema_str,
        "full_fk_str": full_fk_str
    })

async def prune_database_schema(db_id: str, pruning_rules: Dict) -> str:
    """
    Applies pruning rules to a database schema.
    
    Args:
        db_id: The database identifier
        pruning_rules: Dictionary with tables and columns to keep
        
    Returns:
        JSON string with pruned schema
    """
    print(f"[Tool] Pruning schema for database {db_id}")
    
    # Generate pruned schema description
    schema_str, fk_str, chosen_schema = schema_manager.generate_schema_description(
        db_id, pruning_rules, use_gold_schema=False
    )
    
    # Return pruned schema
    return json.dumps({
        "db_id": db_id,
        "pruning_applied": True,
        "pruning_rules": pruning_rules,
        "pruned_schema_str": schema_str,
        "pruned_fk_str": fk_str,
        "tables_columns_kept": chosen_schema
    })

# Tool implementation for SQL execution
async def execute_sql(sql: str, db_id: str) -> str:
    """
    Executes a SQL query on the specified database.
    
    Args:
        sql: The SQL query to execute
        db_id: The database identifier
        
    Returns:
        JSON string with execution results
    """
    print(f"[Tool] Executing SQL on database {db_id}: {sql[:100]}...")
    
    # Execute SQL with timeout protection
    result = sql_executor.safe_execute(sql, db_id)
    
    # Add validation information
    is_valid, reason = sql_executor.is_valid_result(result)
    result["is_valid_result"] = is_valid
    result["validation_message"] = reason
    
    # Convert to JSON string
    return json.dumps(result)

load json file from ../data/bird/dev_tables.json

Loading all database info...
Found 11 databases in bird dataset


In [3]:
from autogen_agentchat.conditions import FunctionalTermination

class SQLTerminationCondition(FunctionalTermination):
    """Custom termination condition for Text-to-SQL agent workflow."""
    
    def __init__(self, max_refinement_attempts: int = MAX_REFINEMENT_ATTEMPTS):
        self.max_refinement_attempts = max_refinement_attempts
        self._refinement_attempts = 0
        self._last_sql = None
        
        # Initialize parent with our termination function
        super().__init__(func=self._should_terminate)
    
    def reset(self) -> None:
        """Reset termination condition state."""
        self._refinement_attempts = 0
        self._last_sql = None
        print("[Termination] State reset")
        
    async def _should_terminate(self, messages):
        """
        Determine if the conversation should terminate.
        
        Args:
            messages: Sequence of messages in the conversation
            
        Returns:
            True if the conversation should terminate, False otherwise
        """
        if not messages:
            return False
            
        # Get the last message
        last_message = messages[-1]
        
        # Get the message content as text
        if hasattr(last_message, 'content'):
            content = last_message.content
            if isinstance(content, dict):
                content = str(content)  # Convert dict to string
        else:
            content = str(last_message)
            
        # Get message source name
        if hasattr(last_message, 'source') and hasattr(last_message.source, 'name'):
            source_name = last_message.source.name
        else:
            source_name = ""
            
        # Try to parse the message content as JSON
        try:
            data = parse_json(content)
        except Exception as e:
            print(f"[Termination] Error parsing JSON: {str(e)}")
            data = {}
            
        # Debug the message
        print(f"[Termination] Checking message from '{source_name}'")
        
        # Handle SQL generation from the decomposer
        if source_name == DECOMPOSER_NAME:
            sql = extract_sql_from_text(content)
            if sql and sql != self._last_sql:
                # New SQL detected
                self._last_sql = sql
                self._refinement_attempts = 0
                print(f"[Termination] New SQL detected from {DECOMPOSER_NAME}. Resetting refinement attempts.")
                return False  # Continue to refiner
                
        # Handle SQL refinement
        if source_name == REFINER_NAME:
            # Check for status in the data
            status = data.get('status', '')
            print(f"[Termination] Refiner status: {status}")
            
            # Termination conditions
            if status in ['EXECUTION_SUCCESSFUL', 'NO_CHANGE_NEEDED', 'EXECUTION_CONFIRMED']:
                print(f"[Termination] Success: {status}")
                return True
                
            # Get the current SQL
            sql = extract_sql_from_text(content)
            if sql:
                # Track attempt for this SQL
                if sql != self._last_sql:
                    self._last_sql = sql
                    self._refinement_attempts = 1
                else:
                    self._refinement_attempts += 1
                    
                print(f"[Termination] Refinement attempt {self._refinement_attempts}/{self.max_refinement_attempts}")
                
                # Terminate if max attempts reached
                if self._refinement_attempts >= self.max_refinement_attempts:
                    print(f"[Termination] Max refinements ({self.max_refinement_attempts}) reached")
                    return True
                    
        # Continue if no termination condition met
        return False

In [None]:
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
    """
    Determines which agent should receive the next message.
    
    This function implements the workflow logic:
    1. Initial message -> Schema Selector Agent
    2. Schema Selector output -> Decomposer Agent
    3. Decomposer output -> Refiner Agent
    4. Refiner output -> Either Refiner again (for refinement) or terminate
    
    Args:
        messages: Sequence of messages in the conversation
        
    Returns:
        Name of the next agent or None to terminate
    """
    if not messages:
        return None
        
    # Get the last message
    last_message = messages[-1]
    
    # Get content as text
    if hasattr(last_message, 'content'):
        content = last_message.content
        if isinstance(content, dict):
            content = str(content)
    else:
        content = str(last_message)
        
    # Get source name - with more robust checking
    source_name = ""
    if hasattr(last_message, 'source'):
        if hasattr(last_message.source, 'name'):
            source_name = last_message.source.name
        # If source exists but name is missing or empty, try to infer from the message
        elif len(messages) > 1 and hasattr(messages[-2], 'source') and hasattr(messages[-2].source, 'name'):
            # If we know previous speaker, infer current based on workflow
            prev_source = messages[-2].source.name
            if prev_source == SELECTOR_NAME:
                source_name = DECOMPOSER_NAME  # Assume response to selector is from decomposer
            elif prev_source == DECOMPOSER_NAME:
                source_name = REFINER_NAME     # Assume response to decomposer is from refiner
            # Don't try to infer refiner->refiner, as that depends on execution status
    
    # Parse content as JSON
    try:
        data = parse_json(content)
    except Exception as e:
        print(f"[Orchestrator] Error parsing JSON: {str(e)}")
        data = {}
        
    print(f"[Orchestrator] Last message from: '{source_name}', content type: {type(content)}")
    
    # Initial message handling - start with schema selector
    if len(messages) == 1:
        print(f"[Orchestrator] Starting new task, selecting {SELECTOR_NAME}")
        return SELECTOR_NAME
        
    # Check if we have a real message that appears to be from the user
    is_user_message = False
    if source_name == "user":
        is_user_message = True
    # If source is empty but it's a direct response to a message we sent
    elif source_name == "" and len(messages) > 1:
        # This is likely a response that didn't have a source properly set
        if hasattr(last_message, 'name') and isinstance(last_message.name, str) and last_message.name:
            # If the message has a name attribute, use that
            source_name = last_message.name
        else:
            # If we can't determine source and it's not the first message
            is_user_message = True
    
    # Handle user messages
    if is_user_message and len(messages) > 1:  # Not the first message
        print(f"[Orchestrator] User message received, redirecting to {SELECTOR_NAME}")
        return SELECTOR_NAME
        
    # Workflow transitions
    if source_name == SELECTOR_NAME:
        # Schema selector -> Decomposer
        print(f"[Orchestrator] Schema processing complete, selecting {DECOMPOSER_NAME}")
        return DECOMPOSER_NAME
        
    elif source_name == DECOMPOSER_NAME:
        # Decomposer -> Refiner
        sql = extract_sql_from_text(content)
        if sql:
            print(f"[Orchestrator] SQL generated, selecting {REFINER_NAME}")
            return REFINER_NAME
        else:
            print(f"[Orchestrator] No SQL found in Decomposer output, terminating")
            return None
            
    elif source_name == REFINER_NAME:
        # Check refiner status
        status = data.get('status', '')
        
        if status in ['EXECUTION_SUCCESSFUL', 'NO_CHANGE_NEEDED', 'EXECUTION_CONFIRMED']:
            # Successful execution or no change needed - terminate
            print(f"[Orchestrator] SQL execution successful, no change needed, or confirmed, terminating")
            return None
            
        elif status == 'REFINEMENT_NEEDED':
            # Continue refinement
            print(f"[Orchestrator] SQL needs refinement, selecting {REFINER_NAME} again")
            return REFINER_NAME
            
        else:
            # Default to continue refinement
            print(f"[Orchestrator] Continuing with {REFINER_NAME} (status: {status})")
            return REFINER_NAME
            
    # Default to termination for unexpected cases
    print(f"[Orchestrator] No matching rule for source '{source_name}', terminating")
    return None

In [10]:
# Initialize model client
model_client = OpenAIChatCompletionClient(model="gpt-4o")

# Create Schema Selector Agent
selector_agent = AssistantAgent(
    name=SELECTOR_NAME,
    model_client=model_client,
    system_message=f"""You are a Database Schema Selector specialized in analyzing database schemas for text-to-SQL tasks.

Your job is to help prune large database schemas to focus on the relevant tables and columns for a given query.

TASK OVERVIEW:
1. You will receive a task with database ID, query, and evidence
2. Use the 'get_initial_database_schema' tool to retrieve the full schema
3. Analyze the schema complexity and relevance to the query
4. For complex schemas, determine which tables and columns are relevant
5. Use the 'prune_database_schema' tool to generate a focused schema
6. Return the processed schema information for the next agent

WHEN ANALYZING THE SCHEMA:
- Study the database table structure and relationships
- Identify tables directly mentioned or implied in the query
- Consider foreign key relationships that might be needed
- Follow the pruning guidelines in the following template:
{selector_template}

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "pruning_applied": true/false,
  "schema_str": "<schema_description>",
  "fk_str": "<foreign_key_information>"
}}

Remember that high-quality schema selection improves the accuracy of SQL generation.""",
    tools=[get_initial_database_schema, prune_database_schema],
)

# Create Decomposer Agent
decomposer_agent = AssistantAgent(
    name=DECOMPOSER_NAME,
    model_client=model_client,
    system_message=f"""You are a Query Decomposer specialized in converting natural language questions into SQL for the BIRD dataset.

Your job is to analyze a natural language query and relevant database schema, then generate the appropriate SQL query.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, and schema information
2. Study the database schema, focusing on tables, columns, and relationships
3. For complex queries, break down the problem into logical steps
4. Generate a clear, efficient SQL query that answers the question
5. Follow the specific query decomposition template for BIRD:
{decompose_template_bird}

IMPORTANT CONSIDERATIONS:
- BIRD queries often require domain knowledge and multiple steps
- Carefully use the evidence provided to understand domain-specific concepts
- Apply type conversion when comparing numeric data (CAST AS REAL/INT)
- Ensure proper handling of NULL values
- Use table aliases (T1, T2, etc.) for clarity, especially in JOINs
- Always use valid SQLite syntax

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "sql": "<generated_sql_query>",
  "decomposition": [
    "step1_description", 
    "step2_description",
    ...
  ]
}}

Your goal is to generate SQL that will execute correctly and return the precise information requested.""",
    tools=[],  # SQL generation is the primary LLM task
)

# Create Refiner Agent
refiner_agent = AssistantAgent(
    name=REFINER_NAME,
    model_client=model_client,
    system_message=f"""You are an SQL Refiner specializing in executing and fixing SQL queries for the BIRD dataset.

Your job is to test SQL queries against the database, identify errors, and refine them until they execute successfully.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, schema, and SQL
2. Use the 'execute_sql' tool to run the SQL against the database
3. Analyze execution results or errors
4. For errors, refine the SQL using the template:
{refiner_template}
5. For successful execution, validate the results are appropriate for the original query

IMPORTANT CONSIDERATIONS:
- BIRD databases have specific requirements for valid results:
  - Results should not be empty (unless that's the expected answer)
  - Results should not contain NULL values without justification
  - Results should match the expected types and formats
- Focus on SQLite-specific syntax and behaviors
- Pay special attention to:
  - Table and column name quoting (use backticks)
  - Type conversions (CAST AS)
  - JOIN conditions
  - Subquery structure and aliases

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "original_sql": "<original_sql>",
  "final_sql": "<refined_sql>",
  "status": "<EXECUTION_SUCCESSFUL|REFINEMENT_NEEDED|NO_CHANGE_NEEDED>",
  "execution_result": "<execution_result_summary>",
  "refinement_explanation": "<explanation_of_changes>"
}}

Your goal is to ensure the SQL query executes successfully and returns relevant results.""",
    tools=[execute_sql],
)

In [11]:
# Set up the SelectorGroupChat with the agents
termination_condition = SQLTerminationCondition(max_refinement_attempts=MAX_REFINEMENT_ATTEMPTS)

# Create the team - using participants parameter instead of agents
team = SelectorGroupChat(
    participants=[selector_agent, decomposer_agent, refiner_agent],
    model_client=model_client,
    selector_func=selector_func,
    termination_condition=termination_condition,
)

In [12]:
# Define test cases for BIRD dataset
bird_test_cases = [
    # Test case 1: Excellence rate calculation (basic join with aggregation)
    {
        "db_id": "california_schools",
        "query": "List school names of charter schools with an SAT excellence rate over the average.",
        "evidence": "Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr"
    },
    
    # Test case 2: Multi-table query with numeric conditions (multiple joins)
    {
        "db_id": "game_injury",
        "query": "Show the names of players who have been injured for more than 3 matches in the 2010 season.",
        "evidence": "Season info is in the game table with year 2010; injury severity is measured by the number of matches a player misses."
    },
    
    # Test case 3: Complex aggregation with grouping
    {
        "db_id": "formula_1",
        "query": "What is the name of the driver who has won the most races in rainy conditions?",
        "evidence": "Weather conditions are recorded in the races table; winner information is in the results table."
    },
    
    # Test case 4: Temporal query with date handling
    {
        "db_id": "loan_data",
        "query": "Find the customer with the highest total payment amount for loans taken in the first quarter of 2011.",
        "evidence": "First quarter means January to March (months 1-3); loan dates are stored in ISO format (YYYY-MM-DD)."
    }
]

# Select the test case to run (0-3)
test_idx = 0
current_test = bird_test_cases[test_idx]

print(f"Selected test case {test_idx + 1}:")
print(f"Database: {current_test['db_id']}")
print(f"Query: {current_test['query']}")
print(f"Evidence: {current_test['evidence']}")

Selected test case 1:
Database: california_schools
Query: List school names of charter schools with an SAT excellence rate over the average.
Evidence: Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr


In [None]:
# Execute the task with the agent team
# This will process the BIRD query through the schema selection, 
# SQL generation, and refinement stages
await Console(team.run_stream(task=json.dumps(current_test)))

[Termination] Checking message from ''
[Orchestrator] Last message from: '', content type: <class 'str'>
[Orchestrator] Starting new task, selecting Selector
---------- TextMessage (user) ----------
{"db_id": "california_schools", "query": "List school names of charter schools with an SAT excellence rate over the average.", "evidence": "Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr"}
[Tool] Loading schema for database: california_schools
[Termination] Checking message from ''
[Orchestrator] Last message from: '', content type: <class 'str'>
[Orchestrator] User message received, redirecting to Selector
---------- ToolCallRequestEvent (Selector) ----------
[FunctionCall(id='call_SVns6gmDswL7rXjF5tnoZwe2', arguments='{"db_id":"california_schools"}', name='get_initial_database_schema')]
---------- ToolCallExecutionEvent (Selector) ----------
[FunctionExecutionResult(content='{"db_id": "california_schools", "table_count": 3,

In [None]:
# Run all test cases and collect results
async def run_all_tests():
    results = []
    for i, test_case in enumerate(bird_test_cases):
        print(f"\n\n{'='*80}")
        print(f"Test {i+1}: {test_case['db_id']} - {test_case['query']}")
        print(f"{'='*80}\n")
        
        # Reset the termination condition
        if hasattr(team.termination_condition, 'reset'):
            team.termination_condition.reset()
            
        # Run the test
        try:
            await Console(team.run_stream(task=json.dumps(test_case)))
            print(f"\n{'-'*40}")
            print(f"Test {i+1} completed")
            print(f"{'-'*40}\n")
        except Exception as e:
            print(f"Error in test {i+1}: {str(e)}")
    
    print("\nAll tests completed!")

In [None]:
# Comment this out to run a single test (cell 9)
# Uncomment to run all test cases sequentially
# await run_all_tests()

In [None]:
# Test cell to verify SQLite is properly configured
import sqlite3
print("SQLite version:", sqlite3.version)
print("SQLite library version:", sqlite3.sqlite_version)

# Test connection to a sample database
try:
    test_db_path = f"{BIRD_DATA_PATH}/dev_databases/california_schools/california_schools.sqlite"
    print(f"Attempting to connect to: {test_db_path}")
    
    conn = sqlite3.connect(test_db_path)
    cursor = conn.cursor()
    
    # Get schema information
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    print(f"Tables in database: {[table[0] for table in tables]}")
    
    # Try a simple query
    cursor.execute("SELECT COUNT(*) FROM frpm;")
    count = cursor.fetchone()[0]
    print(f"Count of rows in frpm table: {count}")
    
    conn.close()
    print("SQLite test successful!")
except Exception as e:
    print(f"SQLite test failed: {str(e)}")
    print(f"Error type: {type(e).__name__}")

## Testing Multiple BIRD Queries

The notebook has been updated to support testing multiple queries from the BIRD dataset. We've defined a set of test cases with varying complexity:

1. **California Schools**: Requires understanding of charter schools and SAT excellence rate calculation.
2. **Game Injury**: Requires joining multiple tables and filtering based on numeric conditions.
3. **Formula 1**: Complex aggregation with grouping to find a driver with most race wins under specific conditions.
4. **Loan Data**: Temporal query that requires date handling and filtering by quarter.

You can change the `test_idx` variable to select which test case to run individually, or use the `run_all_tests()` function to execute all tests sequentially.

### Expected SQL Solutions

For the first test case (California Schools):

```sql
SELECT T2.`sname`
FROM frpm AS T1
JOIN satscores AS T2 ON T1.`CDSCode` = T2.`cds`
WHERE T1.`Charter School (Y/N)` = 1 
  AND T2.`sname` IS NOT NULL
  AND CAST(T2.`NumGE1500` AS REAL) / T2.`NumTstTakr` > (
    SELECT AVG(CAST(T4.`NumGE1500` AS REAL) / T4.`NumTstTakr`)
    FROM frpm AS T3 
    JOIN satscores AS T4 ON T3.`CDSCode` = T4.`cds` 
    WHERE T3.`Charter School (Y/N)` = 1
  )
```

### Key Challenges in BIRD Dataset

The BIRD dataset presents several unique challenges:

1. **Domain Knowledge**: Requires understanding specialized concepts in different domains
2. **Complex Schemas**: Many databases have 10+ tables and 50+ columns
3. **Multi-step Reasoning**: Often requires breaking a query into multiple logical steps
4. **Data Conversion**: Often needs CAST operations for proper numeric comparisons
5. **NULL Handling**: Important to account for NULL values in results
6. **Data Validation**: Results need to be verified for correctness against requirements

### Workflow Components

The Text-to-SQL workflow includes:

1. **Schema Selection**: Prunes large schemas to focus on relevant tables/columns
2. **SQL Generation**: Decomposes complex queries into logical steps
3. **SQL Refinement**: Executes and fixes errors until valid results are obtained