# SQL Generator/Refiner Agent Test (Updated)

This notebook implements and tests the SQL Generator/Refiner Agent, which does only two things: it generates new SQL queries, and it refines existing queries that failed with an error.

In [None]:
from dotenv import load_dotenv
import json
import re
from typing import Dict, Any, List, Optional
from dataclasses import dataclass

# Load variables from a local .env file into the process environment
# (presumably the OpenAI credentials read by OpenAIChatCompletionClient — confirm).
load_dotenv()

In [None]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

## Define SQL Generation and Refinement Prompts

In [None]:
# SQL generation prompt (copied from const.py). Filled with str.format using:
#   {desc_str}  rendered schema description of the selected tables
#   {fk_str}    foreign-key lines (or "No foreign keys")
#   {query}     the natural-language question
#   {evidence}  optional domain hints for the question
# Because str.format is used, any literal brace added to this text must be
# doubled ({{ / }}). The wording is what the model sees — edit deliberately.
SQL_GENERATION_TEMPLATE = """Given a 【Database schema】 description, a knowledge 【Evidence】 and the 【Question】, you need to use valid SQLite and understand the database and knowledge, and then generate SQL.
You can write answer in script blocks, and indicate script type in it, like this:
```sql
SELECT column_a
FROM table_b
```
When generating SQL, we should always consider constraints:
【Constraints】
- In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
- In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
- If use max or min func, `JOIN <table>` FIRST, THEN use `SELECT MAX(<column>)` or `SELECT MIN(<column>)`
- If [Value examples] of <column> has 'None' or None, use `JOIN <table>` or `WHERE <column> is NOT NULL` is better
- If use `ORDER BY <column> ASC|DESC`, add `GROUP BY <column>` before to select distinct values

Now let's start!

【Database schema】
{desc_str}
【Foreign keys】
{fk_str}
【Question】
{query}
【Evidence】
{evidence}
【Answer】"""

# SQL refinement prompt (copied from const.py). Takes the same placeholders as
# the generation prompt plus:
#   {sql}              the query that failed
#   {sqlite_error}     the error text it produced
#   {exception_class}  the exception class name (e.g. OperationalError)
# The constraint list is kept identical to the generation prompt's. Note the
# "【SQLite error】" line intentionally ends with a trailing space (as in the
# original prompt); keep it byte-for-byte.
SQL_REFINER_TEMPLATE = """【Instruction】
When executing SQL below, some errors occurred, please fix up SQL based on query and database info.
Solve the task step by step if you need to. Using SQL format in the code block, and indicate script type in the code block.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
【Constraints】
- In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
- In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
- If use max or min func, `JOIN <table>` FIRST, THEN use `SELECT MAX(<column>)` or `SELECT MIN(<column>)`
- If [Value examples] of <column> has 'None' or None, use `JOIN <table>` or `WHERE <column> is NOT NULL` is better
- If use `ORDER BY <column> ASC|DESC`, add `GROUP BY <column>` before to select distinct values
【Query】
-- {query}
【Evidence】
{evidence}
【Database info】
{desc_str}
【Foreign keys】
{fk_str}
【old SQL】
```sql
{sql}
```
【SQLite error】 
{sqlite_error}
【Exception class】
{exception_class}

Now please fixup old SQL and generate new SQL again.
【correct SQL】"""

In [None]:
# System prompt for the SQL Generator/Refiner Agent. It is used in two places:
# as the AssistantAgent's system_message, and as the system message of the
# direct model calls made inside the generate_sql / refine_sql tools.
SQL_GENERATOR_SYSTEM_PROMPT = """You are an expert SQL Generator and Refiner Agent specializing in creating and fixing SQL queries.

Your capabilities:
1. Generate SQL queries from natural language questions
2. Refine existing SQL queries to fix errors
3. Work with various SQL dialects (focusing on SQLite)
4. Handle complex joins and aggregations

Guidelines:
- Always follow the constraints specified in the prompts
- Generate clean, readable SQL with proper formatting
- Use table aliases for clarity in joins
- Avoid selecting unnecessary columns
- Handle NULL values appropriately
- Write comments for complex queries

When refining SQL:
- Carefully analyze the error message
- Fix syntax errors
- Correct table/column references
- Ensure proper join conditions
- Validate aggregate functions"""

## Implement the SQL Generator/Refiner Agent

In [None]:
class SQLGeneratorAgent:
    """Agent that generates new SQL queries and refines failing ones.

    Wraps an autogen ``AssistantAgent`` that exposes two tools:

    * ``generate_sql`` renders ``SQL_GENERATION_TEMPLATE`` for a question.
    * ``refine_sql`` renders ``SQL_REFINER_TEMPLATE`` for a query that raised
      an error.

    Both tools describe the database schema via ``SchemaManager``, call the
    chat-completion client directly, extract the SQL from the reply and
    return a JSON string. ``sql_executor`` is stored but not used by this
    class.
    """

    # Fenced SQL block: ```sql / ```SQL / ```sqlite, LF or CRLF line endings,
    # and a closing fence that need not sit on its own line.
    _SQL_FENCE_RE = re.compile(
        r"```[ \t]*(?:sqlite|sql)[ \t]*\r?\n(.*?)```", re.IGNORECASE | re.DOTALL
    )
    # Markers that may precede a bare (unfenced) SQL answer, tried in order.
    # The two bracketed ones are the closing lines of the prompt templates.
    _ANSWER_MARKERS = ("【correct SQL】", "【Answer】", "SQL:", "Query:")
    # Text that ends a bare SQL answer found after one of the markers above.
    _ANSWER_TERMINATORS = ("【", "```", "\n\n")

    def __init__(self, schema_manager: SchemaManager, sql_executor: SQLExecutor, model: str = "gpt-4o"):
        """Create the model client and the tool-equipped agent.

        Args:
            schema_manager: Provides per-database metadata (``db2dbjsons``)
                and renders schema descriptions for the prompts.
            sql_executor: Stored as ``self.sql_executor``; not used here.
            model: Model name passed to ``OpenAIChatCompletionClient``.
        """
        self.schema_manager = schema_manager
        self.sql_executor = sql_executor
        # One client backs both the AssistantAgent and the tools' direct calls.
        self.model_client = OpenAIChatCompletionClient(model=model)
        self.agent = self._create_agent()

    def _create_agent(self) -> AssistantAgent:
        """Build the AssistantAgent together with its two tools.

        The tools are closures over ``self`` so they can reach the schema
        manager and the model client. autogen uses each tool's docstring as
        its description and its annotations as the argument schema.
        """

        async def generate_sql(
            query: str,
            database_id: str,
            evidence: str = "",
            selected_tables: Optional[List[str]] = None,
        ) -> str:
            """Generate SQL from a natural language query.

            Args:
                query: The natural-language question to answer.
                database_id: Identifier of the target database.
                evidence: Optional domain knowledge / hints for the question.
                selected_tables: Tables to include in the schema prompt; all
                    tables of the database are used when omitted or empty.

            Returns:
                JSON with keys "sql", "database" and "tables_used".
            """
            selected_schema = self._select_schema(database_id, selected_tables)
            desc_str, fk_str = self._describe_schema(database_id, selected_schema)

            prompt = SQL_GENERATION_TEMPLATE.format(
                desc_str=desc_str,
                fk_str=fk_str,
                query=query,
                evidence=evidence,
            )
            sql = self._extract_sql_from_response(await self._complete(prompt))

            return json.dumps({
                "sql": sql,
                "database": database_id,
                "tables_used": list(selected_schema),
            }, indent=2)

        async def refine_sql(
            sql: str,
            query: str,
            database_id: str,
            error_message: str,
            exception_class: str = "",
            evidence: str = "",
        ) -> str:
            """Refine an existing SQL query based on errors.

            Args:
                sql: The SQL query that failed.
                query: The natural-language question the SQL should answer.
                database_id: Identifier of the target database.
                error_message: Error text produced when executing ``sql``.
                exception_class: Name of the raised exception class, if known.
                evidence: Optional domain knowledge / hints for the question.

            Returns:
                JSON with keys "original_sql", "refined_sql", "error_fixed"
                and "database".
            """
            # Refinement always shows the model every table of the database.
            selected_schema = self._select_schema(database_id, None)
            desc_str, fk_str = self._describe_schema(database_id, selected_schema)

            prompt = SQL_REFINER_TEMPLATE.format(
                query=query,
                evidence=evidence,
                desc_str=desc_str,
                fk_str=fk_str,
                sql=sql,
                sqlite_error=error_message,
                exception_class=exception_class,
            )
            refined_sql = self._extract_sql_from_response(await self._complete(prompt))

            return json.dumps({
                "original_sql": sql,
                "refined_sql": refined_sql,
                # The error that prompted refinement; nothing here verifies
                # that refined_sql actually avoids it.
                "error_fixed": error_message,
                "database": database_id,
            }, indent=2)

        # Create the agent - only generation and refinement
        return AssistantAgent(
            name="sql_generator",
            model_client=self.model_client,
            tools=[
                generate_sql,
                refine_sql,
            ],
            system_message=SQL_GENERATOR_SYSTEM_PROMPT,
            reflect_on_tool_use=True,
            model_client_stream=True,
        )

    def _select_schema(self, database_id: str, selected_tables: Optional[List[str]] = None) -> Dict[str, str]:
        """Map tables to the "keep_all" column policy expected by SchemaManager.

        Uses ``selected_tables`` when non-empty, otherwise every table listed
        under ``table_names_original`` for ``database_id`` (none if the
        database is unknown to the schema manager).
        """
        if selected_tables:
            tables = selected_tables
        else:
            db_info = self.schema_manager.db2dbjsons.get(database_id, {})
            tables = db_info.get("table_names_original", [])
        return {table: "keep_all" for table in tables}

    def _describe_schema(self, database_id: str, selected_schema: Dict[str, str]) -> tuple[str, str]:
        """Return ``(desc_str, fk_str)`` for the prompt templates."""
        # The third value (the chosen schema) is not needed by either tool.
        desc_str, fk_infos, _ = self.schema_manager.generate_schema_description(
            database_id, selected_schema, use_gold_schema=False
        )
        fk_str = "\n".join(fk_infos) if fk_infos else "No foreign keys"
        return desc_str, fk_str

    async def _complete(self, prompt: str) -> str:
        """Send ``prompt`` under the generator system prompt; return the reply text."""
        # autogen chat clients take typed LLM messages (not OpenAI-style role
        # dicts) and return a CreateResult whose text lives in ``.content``.
        from autogen_core.models import SystemMessage, UserMessage

        result = await self.model_client.create(
            [
                SystemMessage(content=SQL_GENERATOR_SYSTEM_PROMPT),
                UserMessage(content=prompt, source="user"),
            ]
        )
        content = result.content
        # No tools are passed here, so a text reply is expected; coerce
        # defensively so the extraction helpers always receive a str.
        return content if isinstance(content, str) else str(content)

    def _extract_sql_from_response(self, response: str) -> str:
        """Extract the SQL statement from an LLM reply.

        Strategies, in order:

        1. The first fenced SQL code block.
        2. The text after the last occurrence of a known answer marker, cut
           at the first terminator; skipped when that leaves nothing.
        3. From the first line containing ``SELECT`` (any case) up to and
           including the first line containing ``;`` (or the end of reply).

        Returns an empty string when nothing is found.
        """
        if not response:
            return ""

        fenced = self._SQL_FENCE_RE.search(response)
        if fenced:
            return fenced.group(1).strip()

        for marker in self._ANSWER_MARKERS:
            if marker not in response:
                continue
            candidate = response.split(marker)[-1].strip()
            for terminator in self._ANSWER_TERMINATORS:
                candidate = candidate.split(terminator, 1)[0]
            candidate = candidate.strip()
            # An empty cut (e.g. marker immediately followed by a bare ```
            # fence) falls through to the remaining strategies.
            if candidate:
                return candidate

        sql_lines: List[str] = []
        for line in response.splitlines():
            if not sql_lines and "SELECT" not in line.upper():
                continue
            sql_lines.append(line)
            if ";" in line:
                break
        return "\n".join(sql_lines).strip()

    async def query(self, task: str) -> None:
        """Run ``task`` through the agent, streaming the conversation to the console."""
        await Console(self.agent.run_stream(task=task))

    async def close(self) -> None:
        """Close the model client connection (shared with the agent)."""
        await self.model_client.close()

## Initialize Components and Test

In [None]:
# Dataset locations, resolved relative to the notebook's working directory
data_path, tables_json_path, dataset_name = (
    "../data/bird/dev_databases",
    "../data/bird/dev_tables.json",
    "bird",
)

# Schema metadata provider used to render schema descriptions for prompts
schema_manager = SchemaManager(
    data_path=data_path,
    tables_json_path=tables_json_path,
    dataset_name=dataset_name,
    lazy=True,
)

# SQL executor over the same database directory (stored by the agent)
sql_executor = SQLExecutor(data_path=data_path, dataset_name=dataset_name)

# Agent under test, using its default model
sql_generator = SQLGeneratorAgent(schema_manager=schema_manager, sql_executor=sql_executor)

### Test SQL Generation

In [None]:
# Test generating SQL from a simple query
await sql_generator.query("""
Generate SQL for this query on database 'california_schools':
Query: Show all schools in Alameda county
""")

In [None]:
# Test generating SQL for a complex query
await sql_generator.query("""
Generate SQL for this query on database 'california_schools':
Query: List school names of charter schools with an SAT excellence rate over the average
Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr
""")

### Test SQL Refinement

In [None]:
# Test refining SQL with an error
error_sql = "SELECT School FROM schools WHERE County = 'Alameda'"
error_message = "no such column: School"

await sql_generator.query(f"""
Refine this SQL that has an error:
SQL: {error_sql}
Query: Show all schools in Alameda county
Database: california_schools
Error: {error_message}
Exception: OperationalError
""")

In [None]:
# Test refining complex SQL with error
complex_error_sql = """
SELECT s.School, AVG(sat.NumGE1500 / sat.NumTstTakr) as excellence_rate
FROM schools s
JOIN satscores sat ON s.CDSCode = sat.cds
WHERE s.`Charter School (Y/N)` = 1
GROUP BY s.School
HAVING excellence_rate > AVG(excellence_rate)
"""

await sql_generator.query(f"""
Refine this SQL that has an error:
SQL: {complex_error_sql}
Query: List charter schools with SAT excellence rate over the average
Database: california_schools
Error: misuse of aggregate function AVG()
Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr
""")

In [None]:
# Close the model client. The same client instance also backs the agent's
# AssistantAgent, so re-run the initialisation cell before querying again.
await sql_generator.close()