In [1]:
import asyncio
import json
import os
from typing import Dict, Any, List

# Autogen imports
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_core import CancellationToken
from autogen_agentchat.messages import TextMessage

# Import from project modules
from const import (
    SELECTOR_NAME, DECOMPOSER_NAME, REFINER_NAME, SYSTEM_NAME,
    selector_template, decompose_template_bird, refiner_template
)
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

# Import utilities from utility modules
from workflow_utils import (
    display_task_info, display_sql_result,
    select_schema, generate_sql, refine_sql, process_text_to_sql
)
from workflow_runners import (
    run_text_to_sql_step_by_step, run_complete_text_to_sql, run_all_tests,
    DEFAULT_TIMEOUT, MAX_REFINEMENT_ATTEMPTS
)

# Dataset configuration shared by every cell below.
# BIRD_DATA_PATH is a relative path, so it resolves against the kernel's
# working directory (presumably the notebook's own folder - TODO confirm).
BIRD_DATA_PATH = "../data/bird"
# Table-metadata JSON handed to SchemaManager as `tables_json_path`.
BIRD_TABLES_JSON_PATH = os.path.join(BIRD_DATA_PATH, "dev_tables.json")
# Dataset label passed as `dataset_name` to both SchemaManager and SQLExecutor.
DATASET_NAME = "bird"

In [2]:
# Schema metadata for the BIRD databases, built from the constants above.
schema_manager = SchemaManager(
    data_path=BIRD_DATA_PATH,
    tables_json_path=BIRD_TABLES_JSON_PATH,
    dataset_name=DATASET_NAME,
    # lazy=False turns lazy loading OFF: the recorded output of this cell shows
    # "Loading all database info..." at construction time.
    # NOTE(review): the previous comment said "Use lazy loading for
    # performance"; if that was the intent this should be lazy=True - confirm.
    lazy=False
)

# Executor for running SQL against the BIRD databases.
# NOTE(review): the previous comment called this "our patched version", but the
# class used here is the SQLExecutor imported from sql_executor - confirm that
# no other implementation was meant.
sql_executor = SQLExecutor(
    data_path=BIRD_DATA_PATH,
    dataset_name=DATASET_NAME
)

load json file from ../data/bird/dev_tables.json

Loading all database info...
Found 11 databases in bird dataset


In [3]:
# Define test cases for BIRD dataset.
# Each case is a dict with three string fields:
#   db_id    - database identifier the agents' tools are called with
#   query    - the natural-language question
#   evidence - extra hint text supplied alongside the question
# NOTE(review): the schema manager reports 11 databases for this dataset.
# `game_injury` and `loan_data` do not look like BIRD dev database ids -
# verify they exist in dev_tables.json before selecting cases 1 or 3.
bird_test_cases = [
    # Test case 1: Excellence rate calculation (basic join with aggregation)
    {
        "db_id": "california_schools",
        "query": "List school names of charter schools with an SAT excellence rate over the average.",
        "evidence": "Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr"
    },
    
    # Test case 2: Multi-table query with numeric conditions (multiple joins)
    {
        "db_id": "game_injury",
        "query": "Show the names of players who have been injured for more than 3 matches in the 2010 season.",
        "evidence": "Season info is in the game table with year 2010; injury severity is measured by the number of matches a player misses."
    },
    
    # Test case 3: Complex aggregation with grouping
    # NOTE(review): the evidence assumes the races table records weather -
    # confirm against the actual formula_1 schema.
    {
        "db_id": "formula_1",
        "query": "What is the name of the driver who has won the most races in rainy conditions?",
        "evidence": "Weather conditions are recorded in the races table; winner information is in the results table."
    },
    
    # Test case 4: Temporal query with date handling
    {
        "db_id": "loan_data",
        "query": "Find the customer with the highest total payment amount for loans taken in the first quarter of 2011.",
        "evidence": "First quarter means January to March (months 1-3); loan dates are stored in ISO format (YYYY-MM-DD)."
    }
]

# Select the test case to run (0-3); this is a zero-based index into
# bird_test_cases, so 0 picks "Test case 1" above.
test_idx = 0
current_test = bird_test_cases[test_idx]

# Display test case information
display_task_info(current_test)


DATABASE: california_schools
QUERY: List school names of charter schools with an SAT excellence rate over the average.
EVIDENCE: Charter schools refers to `Charter School (Y/N)` = 1 in the table frpm; Excellence rate = NumGE1500 / NumTstTakr



In [4]:
# Initialize SchemaManager and SQLExecutor.
# These are the module-level objects the tool functions in this cell read as
# globals (`schema_manager`, `sql_executor`).
# NOTE(review): this repeats the construction already done in an earlier cell
# with identical arguments, so the dataset metadata is loaded a second time
# (the "Loading all database info..." output appears again below). Consider
# keeping only one of the two initializations.
schema_manager = SchemaManager(
    data_path=BIRD_DATA_PATH,
    tables_json_path=BIRD_TABLES_JSON_PATH,
    dataset_name=DATASET_NAME,
    lazy=False
)

sql_executor = SQLExecutor(
    data_path=BIRD_DATA_PATH,
    dataset_name=DATASET_NAME
)

# Tool implementation for schema selection and management
async def get_initial_database_schema(db_id: str) -> str:
    """Return the full, unpruned schema description of a database as JSON.

    Reads the module-level ``schema_manager``.

    Args:
        db_id: Database identifier, looked up in ``schema_manager.db2dbjsons``.

    Returns:
        A JSON object encoded as a string. For a known database it contains
        ``db_id``, the ``table_count`` / ``total_column_count`` /
        ``avg_column_count`` values read from the database's metadata entry
        (0 when a value is absent), ``is_complex_schema``, ``full_schema_str``
        and ``full_fk_str``. For an unknown database it is
        ``{"error": "Database '<db_id>' not found"}``.
    """
    print(f"[Tool] Loading schema for database: {db_id}")

    # Validate the identifier BEFORE touching the loader. This tool is called
    # by an LLM, so a misspelled or invented db_id is a realistic input; it
    # must produce the structured error below instead of reaching
    # _load_single_db_info (which may raise or cache a bogus entry).
    db_info = schema_manager.db2dbjsons.get(db_id, {})
    if not db_info:
        return json.dumps({"error": f"Database '{db_id}' not found"})

    # Populate the per-database info cache on first use only.
    if db_id not in schema_manager.db2infos:
        schema_manager.db2infos[db_id] = schema_manager._load_single_db_info(db_id)

    # Full schema description: an empty rule dict means nothing is pruned.
    is_complex = schema_manager._is_complex_schema(db_id)
    full_schema_str, full_fk_str, _ = schema_manager.generate_schema_description(
        db_id, {}, use_gold_schema=False
    )

    return json.dumps({
        "db_id": db_id,
        "table_count": db_info.get('table_count', 0),
        "total_column_count": db_info.get('total_column_count', 0),
        "avg_column_count": db_info.get('avg_column_count', 0),
        "is_complex_schema": is_complex,
        "full_schema_str": full_schema_str,
        "full_fk_str": full_fk_str
    })

async def prune_database_schema(db_id: str, pruning_rules: Dict) -> str:
    """Build a schema description for ``db_id`` with ``pruning_rules`` applied.

    Reads the module-level ``schema_manager`` and forwards both arguments to
    its ``generate_schema_description`` with ``use_gold_schema=False``.

    Args:
        db_id: Database identifier.
        pruning_rules: Rule mapping handed through unchanged; it is also
            echoed back in the result.

    Returns:
        A JSON object encoded as a string with the keys ``db_id``,
        ``pruning_applied`` (always true), ``pruning_rules``,
        ``pruned_schema_str``, ``pruned_fk_str`` and ``tables_columns_kept``.
    """
    print(f"[Tool] Pruning schema for database {db_id}")

    # The manager returns (schema text, foreign-key text, kept tables/columns).
    description = schema_manager.generate_schema_description(
        db_id, pruning_rules, use_gold_schema=False
    )
    pruned_schema, pruned_fks, kept = description

    payload = {
        "db_id": db_id,
        "pruning_applied": True,
        "pruning_rules": pruning_rules,
        "pruned_schema_str": pruned_schema,
        "pruned_fk_str": pruned_fks,
        "tables_columns_kept": kept,
    }
    return json.dumps(payload)

# Tool implementation for SQL execution
async def execute_sql(sql: str, db_id: str) -> str:
    """Run ``sql`` against database ``db_id`` and return the outcome as JSON.

    Reads the module-level ``sql_executor``: the query goes through its
    ``safe_execute`` and the returned result is then checked with
    ``is_valid_result``.

    Args:
        sql: SQL text to execute.
        db_id: Identifier of the target database.

    Returns:
        A JSON string of the executor's result dict, extended with
        ``is_valid_result`` and ``validation_message``. Values that JSON
        cannot represent natively (for example ``bytes`` from a BLOB column)
        are rendered with ``str()`` instead of raising ``TypeError``.
    """
    # Only the first 100 characters of the query are logged.
    print(f"[Tool] Executing SQL on database {db_id}: {sql[:100]}...")

    # The original comment describes safe_execute as timeout-protected; that
    # behavior lives in SQLExecutor and is not visible from this cell.
    result = sql_executor.safe_execute(sql, db_id)

    # Attach the validation verdict to the same dict the executor returned.
    is_valid, reason = sql_executor.is_valid_result(result)
    result["is_valid_result"] = is_valid
    result["validation_message"] = reason

    # default=str keeps the tool call alive when a row contains a value the
    # json module cannot encode; already-serializable results are unaffected.
    return json.dumps(result, default=str)

load json file from ../data/bird/dev_tables.json

Loading all database info...
Found 11 databases in bird dataset


In [5]:
# Initialize model client.
# One client instance is shared by all three agents created below.
# NOTE(review): no API key is passed here, so the client presumably picks it
# up from the environment (e.g. OPENAI_API_KEY) - confirm before running.
model_client = OpenAIChatCompletionClient(model="gpt-4o")

# Create Schema Selector Agent.
# The system message is an f-string: `{selector_template}` (imported from
# const) is interpolated once, when this cell runs, and the doubled braces
# `{{` / `}}` render as literal braces in the JSON example shown to the model.
# The agent is given the two schema tools defined earlier in this notebook.
selector_agent = AssistantAgent(
    name=SELECTOR_NAME,
    model_client=model_client,
    system_message=f"""You are a Database Schema Selector specialized in analyzing database schemas for text-to-SQL tasks.

Your job is to help prune large database schemas to focus on the relevant tables and columns for a given query.

TASK OVERVIEW:
1. You will receive a task with database ID, query, and evidence
2. Use the 'get_initial_database_schema' tool to retrieve the full schema
3. Analyze the schema complexity and relevance to the query
4. For complex schemas, determine which tables and columns are relevant
5. Use the 'prune_database_schema' tool to generate a focused schema
6. Return the processed schema information for the next agent

WHEN ANALYZING THE SCHEMA:
- Study the database table structure and relationships
- Identify tables directly mentioned or implied in the query
- Consider foreign key relationships that might be needed
- Follow the pruning guidelines in the following template:
{selector_template}

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "pruning_applied": true/false,
  "schema_str": "<schema_description>",
  "fk_str": "<foreign_key_information>"
}}

Remember that high-quality schema selection improves the accuracy of SQL generation.""",
    tools=[get_initial_database_schema, prune_database_schema],
)

# Create Decomposer Agent.
# Turns the selector's output (question, evidence, schema) into a SQL query.
# The system message is an f-string: `{decompose_template_bird}` (imported from
# const) is interpolated when this cell runs; `{{` / `}}` are literal braces.
# This agent has no tools - it only generates text.
decomposer_agent = AssistantAgent(
    name=DECOMPOSER_NAME,
    model_client=model_client,
    system_message=f"""You are a Query Decomposer specialized in converting natural language questions into SQL for the BIRD dataset.

Your job is to analyze a natural language query and relevant database schema, then generate the appropriate SQL query.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, and schema information
2. Study the database schema, focusing on tables, columns, and relationships
3. For complex queries, break down the problem into logical steps
4. Generate a clear, efficient SQL query that answers the question
5. Follow the specific query decomposition template for BIRD:
{decompose_template_bird}

IMPORTANT CONSIDERATIONS:
- BIRD queries often require domain knowledge and multiple steps
- Carefully use the evidence provided to understand domain-specific concepts
- Apply type conversion when comparing numeric data (CAST AS REAL/INT)
- Ensure proper handling of NULL values
- Use table aliases (T1, T2, etc.) for clarity, especially in JOINs
- Always use valid SQLite syntax

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "sql": "<generated_sql_query>",
  "decomposition": [
    "step1_description", 
    "step2_description",
    ...
  ]
}}

Your goal is to generate SQL that will execute correctly and return the precise information requested.""",
    tools=[],  # SQL generation is the primary LLM task
)

# Create Refiner Agent.
# Runs the generated SQL through the `execute_sql` tool defined earlier in this
# notebook and rewrites the query when execution fails or looks invalid.
# The system message is an f-string: `{refiner_template}` (imported from const)
# is interpolated when this cell runs; `{{` / `}}` are literal braces.
# The `status` field in the response format offers three values:
# EXECUTION_SUCCESSFUL, REFINEMENT_NEEDED and NO_CHANGE_NEEDED.
refiner_agent = AssistantAgent(
    name=REFINER_NAME,
    model_client=model_client,
    system_message=f"""You are an SQL Refiner specializing in executing and fixing SQL queries for the BIRD dataset.

Your job is to test SQL queries against the database, identify errors, and refine them until they execute successfully.

TASK OVERVIEW:
1. You will receive a JSON with db_id, query, evidence, schema, and SQL
2. Use the 'execute_sql' tool to run the SQL against the database
3. Analyze execution results or errors
4. For errors, refine the SQL using the template:
{refiner_template}
5. For successful execution, validate the results are appropriate for the original query

IMPORTANT CONSIDERATIONS:
- BIRD databases have specific requirements for valid results:
  - Results should not be empty (unless that's the expected answer)
  - Results should not contain NULL values without justification
  - Results should match the expected types and formats
- Focus on SQLite-specific syntax and behaviors
- Pay special attention to:
  - Table and column name quoting (use backticks)
  - Type conversions (CAST AS)
  - JOIN conditions
  - Subquery structure and aliases

FORMAT YOUR FINAL RESPONSE AS JSON:
{{
  "db_id": "<database_id>",
  "query": "<natural_language_query>",
  "evidence": "<any_evidence_provided>",
  "original_sql": "<original_sql>",
  "final_sql": "<refined_sql>",
  "status": "<EXECUTION_SUCCESSFUL|REFINEMENT_NEEDED|NO_CHANGE_NEEDED>",
  "execution_result": "<execution_result_summary>",
  "refinement_explanation": "<explanation_of_changes>"
}}

Your goal is to ensure the SQL query executes successfully and returns relevant results.""",
    tools=[execute_sql],
)