In [None]:
import asyncio
import json
import os
import re
from typing import Sequence, Dict, Any, Tuple, List, Optional

# Autogen imports
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TerminatedException
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.ui import Console
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, StopMessage

# Import from project modules
from const import (
    SELECTOR_NAME, DECOMPOSER_NAME, REFINER_NAME, SYSTEM_NAME,
    selector_template, decompose_template_bird, refiner_template
)
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

# --- Constants ---
# Dataset identity and on-disk location of the BIRD benchmark files.
DATASET_NAME = "bird"
BIRD_DATA_PATH = "../data/bird"
# Table metadata file that lives inside the BIRD data directory.
BIRD_TABLES_JSON_PATH = os.path.join(BIRD_DATA_PATH, "dev_tables.json")

# Upper bound on SQL refinement attempts before the workflow stops refining.
MAX_REFINEMENT_ATTEMPTS = 3

In [None]:
# Dataset-level helpers used by the tool functions: schema access and SQL
# execution, both pointed at the BIRD data directory.
_bird_common_kwargs = {"data_path": BIRD_DATA_PATH, "dataset_name": DATASET_NAME}

schema_manager = SchemaManager(
    tables_json_path=BIRD_TABLES_JSON_PATH,
    lazy=True,  # Use lazy loading for better performance
    **_bird_common_kwargs,
)

sql_executor = SQLExecutor(**_bird_common_kwargs)

# Utility functions for processing agent responses
# Outermost {...} span; greedy so nested objects stay intact.
_BRACE_PATTERN = re.compile(r'\{.*\}', re.DOTALL)
# Fenced ```sql blocks (tag matched case-insensitively, e.g. ```SQL).
_SQL_FENCE_PATTERN = re.compile(r'```sql\s*(.*?)\s*```', re.DOTALL | re.IGNORECASE)
# A bare SELECT statement up to the first ';' or the end of the text.
_SELECT_PATTERN = re.compile(r'\bSELECT\s+.*?(?:;|$)', re.DOTALL | re.IGNORECASE)
# JSON fields that hold SQL in the agents' response formats, most specific first.
# 'query' is deliberately absent: in every agent format it is the natural-language question.
_SQL_FIELD_KEYS = ("final_sql", "refined_sql", "sql", "original_sql")


def parse_json(s: str) -> Dict[str, Any]:
    """Extract and parse a JSON object from agent response text.

    Tries, in order: the whole string, the first ```json fenced block, the first
    generic ``` fenced block (with and without a leading language-tag line), and
    the outermost ``{...}`` span. Only a JSON *object* counts as success, because
    callers immediately call ``.get`` on the result.

    Args:
        s: Agent response text. A dict is returned unchanged; any other
            non-string value is converted with ``str`` first.

    Returns:
        The parsed dict, or ``{"error": "Failed to parse JSON", "original_text": ...}``
        (first 500 characters of the input) when nothing parses to an object.
    """
    if isinstance(s, dict):
        return s
    if not isinstance(s, str):
        s = str(s)

    for candidate in _json_candidates(s):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # A bare string/number/list is valid JSON but useless to callers expecting a dict.
        if isinstance(parsed, dict):
            return parsed

    return {"error": "Failed to parse JSON", "original_text": s[:500]}


def _json_candidates(s: str):
    """Yield substrings of *s* that may contain a JSON object, most specific first."""
    yield s
    if "```json" in s:
        yield s.split("```json", 1)[1].split("```", 1)[0].strip()
    if "```" in s:
        block = s.split("```")[1].strip()
        yield block
        # Drop a language-tag line such as ```JSON or ```javascript before the payload.
        tag, newline, rest = block.partition("\n")
        if newline and tag and not tag.lstrip().startswith(("{", "[")):
            yield rest.strip()
    match = _BRACE_PATTERN.search(s)
    if match:
        yield match.group()


def _sql_from_fields(data: Dict[str, Any]) -> str:
    """Return the first non-empty string stored under one of the known SQL keys."""
    for key in _SQL_FIELD_KEYS:
        value = data.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""


def extract_sql_from_text(text: str) -> str:
    """Extract a SQL query from an agent response (JSON text, dict, or free text).

    Order of precedence:
    1. a SQL field of the parsed JSON object (``final_sql``, ``refined_sql``,
       ``sql``, ``original_sql``) - agents are prompted to answer in JSON;
    2. the first fenced ```sql block;
    3. the first bare ``SELECT ...`` statement, up to ``;`` or end of text.

    Args:
        text: Response text, or an already-parsed dict.

    Returns:
        The SQL string, or ``""`` when none is found (including non-text input).
    """
    if isinstance(text, dict):
        return _sql_from_fields(text)
    if not isinstance(text, str) or not text:
        return ""

    # Parse JSON replies first so the regex fallback doesn't swallow trailing JSON fields.
    sql = _sql_from_fields(parse_json(text))
    if sql:
        return sql

    fenced = _SQL_FENCE_PATTERN.findall(text)
    if fenced:
        return fenced[0].strip()

    bare = _SELECT_PATTERN.search(text)
    return bare.group().strip() if bare else ""

In [None]:
# Tool implementation for schema selection and management
async def get_initial_database_schema(db_id: str) -> str:
    """
    Return the complete (unpruned) schema of a database as a JSON string.

    Exposed as a tool to the schema selector agent.

    Args:
        db_id: The database identifier

    Returns:
        JSON string with table/column statistics, SchemaManager's complexity
        flag, and the full schema and foreign-key descriptions; or a JSON
        ``{"error": ...}`` object when SchemaManager has no entry for ``db_id``.
    """
    print(f"[Tool] Loading schema for database: {db_id}")

    manager = schema_manager

    # Load this database's info on first use.
    if db_id not in manager.db2infos:
        manager.db2infos[db_id] = manager._load_single_db_info(db_id)

    # NOTE(review): the lookup below reads db2dbjsons while the load above fills
    # db2infos - confirm SchemaManager keeps both populated for db_id.
    db_info = manager.db2dbjsons.get(db_id, {})
    if not db_info:
        return json.dumps({"error": f"Database '{db_id}' not found"})

    complex_flag = manager._is_complex_schema(db_id)

    # An empty selection dict means "no pruning": describe every table.
    schema_text, fk_text, _unused_selection = manager.generate_schema_description(
        db_id, {}, use_gold_schema=False
    )

    payload = {
        "db_id": db_id,
        "table_count": db_info.get('table_count', 0),
        "total_column_count": db_info.get('total_column_count', 0),
        "avg_column_count": db_info.get('avg_column_count', 0),
        "is_complex_schema": complex_flag,
        "full_schema_str": schema_text,
        "full_fk_str": fk_text,
    }
    return json.dumps(payload)

async def prune_database_schema(db_id: str, pruning_rules: Dict) -> str:
    """
    Build a pruned schema description for a database from the agent's selection.

    Exposed as a tool to the schema selector agent.

    Args:
        db_id: The database identifier
        pruning_rules: Tables/columns to keep; passed straight through to
            SchemaManager.generate_schema_description (format defined there).

    Returns:
        JSON string echoing the rules together with the pruned schema text,
        the pruned foreign-key text, and the tables/columns SchemaManager kept.
    """
    print(f"[Tool] Pruning schema for database {db_id}")

    description = schema_manager.generate_schema_description(
        db_id, pruning_rules, use_gold_schema=False
    )
    pruned_schema, pruned_fks, kept = description

    response = {
        "db_id": db_id,
        "pruning_applied": True,
        "pruning_rules": pruning_rules,
        "pruned_schema_str": pruned_schema,
        "pruned_fk_str": pruned_fks,
        "tables_columns_kept": kept,
    }
    return json.dumps(response)

# Tool implementation for SQL execution
async def execute_sql(sql: str, db_id: str) -> str:
    """
    Execute a SQL query on the specified database and report the outcome.

    Exposed as a tool to the refiner agent.

    Args:
        sql: The SQL query to execute
        db_id: The database identifier

    Returns:
        JSON string with the executor's result dict, extended with
        ``is_valid_result`` and ``validation_message`` from
        ``sql_executor.is_valid_result``.
    """
    print(f"[Tool] Executing SQL on database {db_id}: {sql[:100]}...")

    # Execute SQL with timeout protection
    result = sql_executor.safe_execute(sql, db_id)

    # Add validation information
    is_valid, reason = sql_executor.is_valid_result(result)
    result["is_valid_result"] = is_valid
    result["validation_message"] = reason

    # Rows come straight from the database and may hold values json cannot encode
    # (Python's sqlite3 returns BLOB columns as bytes, for example). Stringify them
    # instead of letting the tool call fail with TypeError.
    return json.dumps(result, default=str)

In [None]:
class MacSQLTerminationCondition(TextMentionTermination):
    """Stop the MAC-SQL team once the SQL workflow reaches an end state.

    autogen_agentchat awaits both ``__call__`` and ``reset`` on termination
    conditions. ``__call__`` must return a ``StopMessage`` to stop the run or
    ``None`` to continue. The run is stopped when:

    * the refiner reports ``status`` EXECUTION_SUCCESSFUL or NO_CHANGE_NEEDED;
    * the refiner has spoken ``max_refinement_attempts`` times since the
      decomposer's latest SQL without reporting success; or
    * the decomposer replies without any extractable SQL, so there is nothing
      to refine.

    The class still derives from ``TextMentionTermination`` for backward
    compatibility; the inherited text-mention check is not used.
    """

    _SUCCESS_STATUSES = ("EXECUTION_SUCCESSFUL", "NO_CHANGE_NEEDED")

    def __init__(self, max_refinement_attempts: int = MAX_REFINEMENT_ATTEMPTS):
        # TextMentionTermination.__init__ requires the text to watch for;
        # calling it without arguments raises TypeError.
        super().__init__("TERMINATE")
        self.max_refinement_attempts = max_refinement_attempts
        self._refinement_attempts = 0
        self._last_sql: Optional[str] = None
        self._stopped = False

    @property
    def terminated(self) -> bool:
        """Whether this condition has already stopped the current run."""
        return self._stopped

    async def reset(self) -> None:
        """Reset termination condition state (awaited by the team between runs)."""
        await super().reset()
        self._refinement_attempts = 0
        self._last_sql = None
        self._stopped = False
        print("[Termination] State reset")

    async def __call__(
        self, messages: Sequence[BaseAgentEvent | BaseChatMessage]
    ) -> StopMessage | None:
        """
        Inspect the newest messages and decide whether the run should stop.

        Args:
            messages: Messages produced since the previous check.

        Returns:
            A StopMessage describing why the run stops, or None to continue.

        Raises:
            TerminatedException: if called again after already stopping,
                without an intervening ``reset``.
        """
        if self._stopped:
            raise TerminatedException("Termination condition has already been reached")
        if not messages:
            return None

        # Tool-call events can be interleaved; the agent's reply is the last chat message.
        chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
        if not chat_messages:
            return None

        reason = self._check_message(chat_messages[-1])
        if reason is None:
            return None
        self._stopped = True
        return StopMessage(content=reason, source="MacSQLTerminationCondition")

    @staticmethod
    def _message_text(message: Any) -> str:
        """Return a message's content as text, whatever its content type."""
        content = getattr(message, "content", None)
        if isinstance(content, str):
            return content
        to_text = getattr(message, "to_text", None)
        if callable(to_text):
            return to_text()
        return str(message if content is None else content)

    def _check_message(self, message: Any) -> Optional[str]:
        """Update refinement state for *message*; return a stop reason or None."""
        # autogen_agentchat stores the sender's name as a plain string in `source`.
        raw_source = getattr(message, "source", "")
        source_name = raw_source if isinstance(raw_source, str) else getattr(raw_source, "name", "")
        content = self._message_text(message)

        if source_name == DECOMPOSER_NAME:
            sql = extract_sql_from_text(content)
            if not sql:
                print(f"[Termination] No SQL produced by {DECOMPOSER_NAME}, stopping")
                return f"{DECOMPOSER_NAME} produced no SQL"
            # A new decomposition starts a fresh refinement budget.
            self._last_sql = sql
            self._refinement_attempts = 0
            print(f"[Termination] New SQL detected from {DECOMPOSER_NAME}. Resetting refinement attempts.")
            return None

        if source_name == REFINER_NAME:
            try:
                data = parse_json(content)
            except Exception:
                data = {}
            status = data.get("status", "") if isinstance(data, dict) else ""

            if status in self._SUCCESS_STATUSES:
                print(f"[Termination] Success: {status}")
                return f"{REFINER_NAME} reported {status}"

            # Count every unsuccessful refiner turn. Resetting whenever the refined
            # SQL changes would let a refiner that keeps rewriting the query run forever.
            self._refinement_attempts += 1
            sql = extract_sql_from_text(content)
            if sql:
                self._last_sql = sql
            print(f"[Termination] Refinement attempt {self._refinement_attempts}/{self.max_refinement_attempts}")

            if self._refinement_attempts >= self.max_refinement_attempts:
                print(f"[Termination] Max refinements ({self.max_refinement_attempts}) reached")
                return f"Max refinements ({self.max_refinement_attempts}) reached"

        # Continue if no termination condition met
        return None

In [None]:
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
    """
    Choose the next speaker for the SelectorGroupChat.

    Workflow:
    1. Initial task message -> Schema Selector Agent
    2. Schema Selector output -> Decomposer Agent
    3. Decomposer output containing SQL -> Refiner Agent
    4. Refiner output -> Refiner again, unless it reported EXECUTION_SUCCESSFUL
       or NO_CHANGE_NEEDED

    Returning ``None`` does not end the run. SelectorGroupChat treats it as
    "no preference" and falls back to model-based speaker selection. Stopping
    is the termination condition's job; ``None`` is returned here only for
    states where the workflow itself has no next agent.

    Args:
        messages: The conversation thread so far (chat messages and agent events)

    Returns:
        Name of the next agent, or None to defer to model-based selection
    """
    if not messages:
        return None

    # Agent events (e.g. tool calls) can be interleaved; route on the latest chat message.
    last_message = next(
        (m for m in reversed(messages) if isinstance(m, BaseChatMessage)),
        messages[-1],
    )

    # Normalise content to text so the JSON/SQL helpers always receive a string.
    content = getattr(last_message, "content", None)
    if not isinstance(content, str):
        to_text = getattr(last_message, "to_text", None)
        if callable(to_text):
            content = to_text()
        else:
            content = str(last_message if content is None else content)

    # autogen_agentchat stores the sender's name as a plain string in `source`;
    # the `.name` lookup is kept for objects that expose one.
    raw_source = getattr(last_message, "source", "")
    source_name = raw_source if isinstance(raw_source, str) else getattr(raw_source, "name", "")

    # Parse content as JSON; anything that is not an object carries no status.
    try:
        data = parse_json(content)
    except Exception:
        data = {}
    if not isinstance(data, dict):
        data = {}

    print(f"[Selector] Last message from: {source_name}, content type: {type(content)}")

    # Initial message handling - start with schema selector
    if len(messages) == 1:
        print(f"[Selector] Starting new task, selecting {SELECTOR_NAME}")
        return SELECTOR_NAME

    if source_name == SELECTOR_NAME:
        # Schema selector -> Decomposer
        print(f"[Selector] Schema processing complete, selecting {DECOMPOSER_NAME}")
        return DECOMPOSER_NAME

    if source_name == DECOMPOSER_NAME:
        # Decomposer -> Refiner, but only if there is SQL to refine
        if extract_sql_from_text(content):
            print(f"[Selector] SQL generated, selecting {REFINER_NAME}")
            return REFINER_NAME
        print(f"[Selector] No SQL found in Decomposer output, terminating")
        return None

    if source_name == REFINER_NAME:
        status = data.get('status', '')
        if status in ('EXECUTION_SUCCESSFUL', 'NO_CHANGE_NEEDED'):
            print(f"[Selector] SQL execution successful or no change needed, terminating")
            return None
        if status == 'REFINEMENT_NEEDED':
            print(f"[Selector] SQL needs refinement, selecting {REFINER_NAME} again")
        else:
            # Unknown or missing status: keep refining
            print(f"[Selector] Continuing with {REFINER_NAME} (status: {status})")
        return REFINER_NAME

    # No workflow rule applies
    print(f"[Selector] No matching rule, terminating")
    return None

In [None]:
# Chat-completion client shared by every agent defined below.
_MODEL_NAME = "gpt-4o"
model_client = OpenAIChatCompletionClient(model=_MODEL_NAME)

# Create Schema Selector Agent
selector_agent = AssistantAgent(
    name=SELECTOR_NAME,
    model_client=model_client,
    system_message=f"""You are a Database Schema Selector specialized in analyzing database schemas for text-to-SQL tasks.

Your job is to help prune large database schemas to focus on the relevant tables and columns for a given query.

TASK OVERVIEW:
1. You will receive a task with database ID, query, and evidence
2. Use the 'get_initial_database_schema' tool to retrieve the full schema
3. Analyze the schema complexity and relevance to the query
4. For complex schemas, determine which tables and columns are relevant
5. Use the 'prune_database_schema' tool to generate a focused schema
6. Return the processed schema information for the next agent

WHEN ANALYZING THE SCHEMA:
- Study the database table structure and relationships
- Identify tables directly mentioned or implied in the query
- Consider foreign key relationships that might be needed
- Follow the pruning guidelines in the following template:
{selector_template}

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "pruning_applied": true/false,
  "schema_str": "<schema_description>",
  "fk_str": "<foreign_key_information>"
}}

Remember that high-quality schema selection improves the accuracy of SQL generation.""",
    tools=[get_initial_database_schema, prune_database_schema],
    # Without reflection, AssistantAgent's reply is a summary of the raw tool output,
    # so the JSON response format requested above would never be produced.
    # NOTE(review): confirm how many tool-call rounds per turn the installed autogen
    # version allows; the get-then-prune sequence above may need more than one.
    reflect_on_tool_use=True,
)

# Create Decomposer Agent
# The system prompt is kept in its own constant. The decomposer writes SQL
# directly, so it is given no tools.
_DECOMPOSER_SYSTEM_MESSAGE = f"""You are a Query Decomposer specialized in converting natural language questions into SQL for the BIRD dataset.

Your job is to analyze a natural language query and relevant database schema, then generate the appropriate SQL query.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, and schema information
2. Study the database schema, focusing on tables, columns, and relationships
3. For complex queries, break down the problem into logical steps
4. Generate a clear, efficient SQL query that answers the question
5. Follow the specific query decomposition template for BIRD:
{decompose_template_bird}

IMPORTANT CONSIDERATIONS:
- BIRD queries often require domain knowledge and multiple steps
- Carefully use the evidence provided to understand domain-specific concepts
- Apply type conversion when comparing numeric data (CAST AS REAL/INT)
- Ensure proper handling of NULL values
- Use table aliases (T1, T2, etc.) for clarity, especially in JOINs
- Always use valid SQLite syntax

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "sql": "<generated_sql_query>",
  "decomposition": [
    "step1_description", 
    "step2_description",
    ...
  ]
}}

Your goal is to generate SQL that will execute correctly and return the precise information requested."""

decomposer_agent = AssistantAgent(
    name=DECOMPOSER_NAME,
    model_client=model_client,
    system_message=_DECOMPOSER_SYSTEM_MESSAGE,
    tools=[],  # SQL generation is the primary LLM task
)

# Create Refiner Agent
refiner_agent = AssistantAgent(
    name=REFINER_NAME,
    model_client=model_client,
    system_message=f"""You are an SQL Refiner specializing in executing and fixing SQL queries for the BIRD dataset.

Your job is to test SQL queries against the database, identify errors, and refine them until they execute successfully.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, schema, and SQL
2. Use the 'execute_sql' tool to run the SQL against the database
3. Analyze execution results or errors
4. For errors, refine the SQL using the template:
{refiner_template}
5. For successful execution, validate the results are appropriate for the original query

IMPORTANT CONSIDERATIONS:
- BIRD databases have specific requirements for valid results:
  - Results should not be empty (unless that's the expected answer)
  - Results should not contain NULL values without justification
  - Results should match the expected types and formats
- Focus on SQLite-specific syntax and behaviors
- Pay special attention to:
  - Table and column name quoting (use backticks)
  - Type conversions (CAST AS)
  - JOIN conditions
  - Subquery structure and aliases

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "original_sql": "<original_sql>",
  "final_sql": "<refined_sql>",
  "status": "<EXECUTION_SUCCESSFUL|REFINEMENT_NEEDED|NO_CHANGE_NEEDED>",
  "execution_result": "<execution_result_summary>",
  "refinement_explanation": "<explanation_of_changes>"
}}

Your goal is to ensure the SQL query executes successfully and returns relevant results.""",
    tools=[execute_sql],
    # The model must reflect on the execute_sql output so its reply is the JSON above.
    # The "status" field in that reply drives both selector_func and the termination
    # condition. Without reflection the reply is a summary of the raw tool output,
    # which has no status field.
    reflect_on_tool_use=True,
)

In [None]:
# Set up the SelectorGroupChat with the agents
termination_condition = MacSQLTerminationCondition(max_refinement_attempts=MAX_REFINEMENT_ATTEMPTS)

agents = [selector_agent, decomposer_agent, refiner_agent]

# Create the team.
# SelectorGroupChat takes its agents as `participants` and requires a model client,
# which it uses for model-based speaker selection whenever selector_func returns None.
team = SelectorGroupChat(
    participants=agents,
    model_client=model_client,
    selector_func=selector_func,
    termination_condition=termination_condition,
    # Hard cap on agent turns: one selector turn, one decomposer turn and at most
    # MAX_REFINEMENT_ATTEMPTS refiner turns. Guards against a runaway model-based fallback.
    max_turns=2 + MAX_REFINEMENT_ATTEMPTS,
)

In [None]:
# Sample BIRD task on the california_schools database. It exercises the
# "excellence rate" calculation that the evidence string defines.
bird_task = dict(
    db_id="california_schools",
    query="List school names of charter schools with an SAT excellence rate over the average.",
    evidence="Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr",
)

In [None]:
# Execute the task with the agent team
# This will process the BIRD query through the schema selection, 
# SQL generation, and refinement stages
await Console(team.run_stream(task=json.dumps(bird_task)))

## Result Analysis

Running the example BIRD query exercises the full three-stage workflow:

1. **Schema Selection**: The selector agent analyzes the database schema and identifies the relevant tables (frpm and satscores) for the query about charter schools and SAT scores.

2. **SQL Generation**: The decomposer agent uses the BIRD-specific template to apply the domain knowledge supplied in the evidence, namely that the "excellence rate" is NumGE1500 / NumTstTakr. It generates SQL that first computes the average excellence rate over charter schools, then selects the charter schools whose own rate is above that average.

3. **SQL Refinement**: The refiner agent executes the SQL against the database and makes any necessary corrections to handle SQLite syntax issues, type conversions, or missing IS NOT NULL checks.

The expected SQL solution for this query is:

```sql
SELECT T2.`sname`
FROM frpm AS T1
JOIN satscores AS T2 ON T1.`CDSCode` = T2.`cds`
WHERE T1.`Charter School (Y/N)` = 1 
  AND T2.`sname` IS NOT NULL
  AND CAST(T2.`NumGE1500` AS REAL) / T2.`NumTstTakr` > (
    SELECT AVG(CAST(T4.`NumGE1500` AS REAL) / T4.`NumTstTakr`)
    FROM frpm AS T3 
    JOIN satscores AS T4 ON T3.`CDSCode` = T4.`cds` 
    WHERE T3.`Charter School (Y/N)` = 1
  )
```

This query demonstrates several key aspects of BIRD dataset queries:

1. Domain knowledge application (charter schools, excellence rate formula)
2. Multiple tables with joins
3. Type conversion (`CAST ... AS REAL`, which avoids SQLite's integer division when both operands are integers)
4. Subquery for comparison with aggregates
5. NULL handling (`sname IS NOT NULL` drops rows with no school name)