# SQL Generation Agent

Purpose: Generates a syntactically valid and semantically plausible SQL query based on the processed natural language query part and the linked schema elements.

Input:
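- processed_query_part_nl: (string) The natural language text for a specific query part
- linked_schema_elements: (list) The relevant schema elements from Agent 2 (schema linking)
- proposed_join_paths: (list, optional) Join paths proposed by Agent 2
- database_id: (string) Identifier for the target database
- sql_dialect: (string) Specific SQL dialect (e.g., "MySQL", "PostgreSQL", "SQLite")
- dependent_sub_query_results_context: (object, optional) Results of other query parts that this part depends on
- refinement_instructions: (string, optional) Specific instructions if this is a retry

In the `SQLGenerationInput` schema used by the code below, the query text is `processed_natural_language`, the schema elements and join paths are bundled in `linked_schema` (a `SchemaLinkingOutput` with `relevant_schema_elements` and `proposed_join_paths`), and the dependent context is `dependent_sub_query_results`.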

Output:
- generated_sql: (string) The SQL query (field `sql_query` on `SQLGenerationOutput`)
- generation_confidence: (float, 0.0 to 1.0) Confidence that the SQL correctly represents the query part and schema
- brief_explanation_of_sql_logic: (string) A short natural language explanation of what the SQL is intended to do
- validation_status_self_assessed: (string) Either "Presumed_Valid" or "Potential_Issues_Noted"

`SQLGenerationOutput` also carries an optional `potential_issues` list.
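`validation_status_self_assessed` is only the model's own opinion. As a cheap cross-check, a generated query can be compiled (not executed) against an empty in-memory copy of the schema. The sketch below assumes the schema can be written as SQLite DDL and that the query uses SQLite syntax; MySQL-specific functions such as `DATE_SUB` would be reported as errors. `check_sql_syntax` is a hypothetical helper, not part of the agent.

```python
import sqlite3


def check_sql_syntax(sql: str, ddl: str) -> tuple[bool, str]:
    """Compile `sql` against an empty in-memory SQLite database built from `ddl`.

    Returns (True, "") if SQLite can prepare the statement, else (False, error message).
    EXPLAIN makes SQLite compile the statement without running the query itself.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(ddl)         # create the (empty) tables
        conn.execute(f"EXPLAIN {sql}")  # fails on syntax errors and unknown tables/columns
        return True, ""
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()


ddl = "CREATE TABLE Customers (CustomerID INTEGER PRIMARY KEY, CustomerName TEXT);"
print(check_sql_syntax("SELECT CustomerName FROM Customers;", ddl))
print(check_sql_syntax("SELECT CustomerNam FROM Customers;", ddl))
```

A failed check could be fed back through `refinement_instructions` on a retry.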

In [ ]:
# Import necessary modules
import sys
import os
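# Make the parent of the current working directory importable (where `schemas` is expected to live)
sys.path.append(os.path.dirname(os.path.abspath('./')))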

import asyncio
from typing import List, Dict, Any, Optional, Tuple
import json
import re
import logging

# Import unified schemas
from schemas import (
    # SQL Generation types
    SQLGenerationInput,
    SQLGenerationOutput,
    SchemaLinkingOutput,
    SchemaElement,
    JoinPath,
    
    # Error types
    SQLGenerationError,
    
    # Database schemas
    SCHEMAS
)

logger = logging.getLogger(__name__)

In [ ]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_ext.models.openai import OpenAIChatCompletionClient

In [ ]:
# SQL Generation System Prompt
SQL_GENERATOR_SYSTEM_PROMPT = """You are an expert SQL generator. Given a natural language query part and linked schema elements,
generate a syntactically valid and semantically plausible SQL query.

Follow these guidelines:
1. Use only the tables and columns listed in the provided schema elements
2. Follow the syntax of the specified SQL dialect (MySQL, PostgreSQL, SQLite)
3. Build joins from the provided join paths
4. If refinement instructions are provided, incorporate them into your query
5. If dependent sub-query results are provided, use them
6. Prefer efficient SQL when possible
7. Use SQL constructs appropriate to the query requirements

For the SQL dialect:
- MySQL: Use DATE_SUB(CURDATE(), INTERVAL n DAY) for date arithmetic
- PostgreSQL: Use CURRENT_DATE - INTERVAL 'n days' for date arithmetic
- SQLite: Use date('now', '-n days') for date arithmetic

Always validate that your generated SQL:
- Uses correct table and column names from the schema
- Has proper JOIN conditions
- Has correct WHERE clause syntax
- Follows the specific dialect's syntax rules

Respond with a single JSON object and no other text:
{
  "sql_query": "your generated SQL",
  "confidence": <number between 0.0 and 1.0>,
  "explanation": "brief explanation",
  "validation_status": "Presumed_Valid" or "Potential_Issues_Noted",
  "issues": ["list of issues, empty if none"]
}"""
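The dialect rules in the prompt can also be kept in code, for example to build reference queries for tests or to compare a generated predicate against an expected one. `last_n_days_predicate` is a hypothetical helper that reproduces the three expressions listed in the prompt; the column name is interpolated as-is, so it should come from the linked schema rather than from user text.

```python
def last_n_days_predicate(column: str, n: int, dialect: str) -> str:
    """Return a 'column falls within the last n days' predicate for the given dialect.

    The expressions mirror the date-arithmetic rules in SQL_GENERATOR_SYSTEM_PROMPT.
    """
    if not isinstance(n, int) or n < 0:
        raise ValueError("n must be a non-negative integer")
    templates = {
        "mysql": "{col} >= DATE_SUB(CURDATE(), INTERVAL {n} DAY)",
        "postgresql": "{col} >= CURRENT_DATE - INTERVAL '{n} days'",
        "sqlite": "{col} >= date('now', '-{n} days')",
    }
    try:
        template = templates[dialect.lower()]
    except KeyError:
        raise ValueError(f"Unsupported dialect: {dialect}") from None
    return template.format(col=column, n=n)


print(last_n_days_predicate("OrderDate", 30, "MySQL"))
```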

In [None]:
# Helper functions

def extract_xml_content(response: str, tag: str) -> str:
    """Extract content between XML tags."""
    pattern = f"<{tag}>(.*?)</{tag}>"
    match = re.search(pattern, response, re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""

def format_schema_elements_for_prompt(elements: List[SchemaElement]) -> str:
    """Format schema elements for inclusion in the prompt."""
    formatted = []
    tables_seen = set()
    
    for elem in elements:
        # Compare element types case-insensitively ("Table"/"table", "Column"/"column")
        elem_type = (elem.element_type or "").lower()
        if elem_type == "table" and elem.table_name not in tables_seen:
            tables_seen.add(elem.table_name)
            formatted.append(f"TABLE: {elem.table_name} (relevance: {elem.relevance_score:.2f})")
        elif elem_type == "column":
            col_info = f"COLUMN: {elem.table_name}.{elem.column_name}"
            if elem.data_type:
                col_info += f" [{elem.data_type}]"
            if elem.constraints:
                col_info += f" {', '.join(elem.constraints)}"
            col_info += f" (relevance: {elem.relevance_score:.2f})"
            if elem.value_examples:
                # Example values may be non-strings (e.g. ints), so convert before joining
                col_info += f" Examples: {', '.join(str(v) for v in elem.value_examples[:3])}"
            formatted.append(col_info)
    
    return "\n".join(formatted)

def format_join_paths_for_prompt(join_paths: List[JoinPath]) -> str:
    """Format join paths for inclusion in the prompt."""
    if not join_paths:
        return "No join paths provided."
    
    formatted = []
    for join in join_paths:
        formatted.append(
            f"{join.join_type} JOIN: {join.from_table}.{join.from_column} => "
            f"{join.to_table}.{join.to_column} (confidence: {join.confidence:.2f})"
        )
    
    return "\n".join(formatted)

In [ ]:
# Typed message classes expected by autogen's ChatCompletionClient.create()
from autogen_core.models import SystemMessage, UserMessage


class SQLGenerationAgent:
    """Agent that generates SQL queries from natural language and schema elements."""
    
    def __init__(self, config: Optional[Dict] = None):
        """Initialize the SQL Generation Agent."""
        self.config = config or {}
        self.model = self.config.get('model', 'gpt-4o')
        self.model_client = OpenAIChatCompletionClient(model=self.model)
        # Not used by generate_sql below
        self.max_retries = self.config.get('max_retries', 3)
    
    async def generate_sql(self, input_data: SQLGenerationInput) -> SQLGenerationOutput:
        """
        Generate SQL query from input data.
        
        Args:
            input_data: SQLGenerationInput with query and schema information
            
        Returns:
            SQLGenerationOutput with generated SQL
            
        Raises:
            SQLGenerationError: If generation fails
        """
        try:
            # Format the prompt
            prompt = self._format_prompt(input_data)
            
            # autogen clients take typed messages, not OpenAI-style dicts
            messages = [
                SystemMessage(content=SQL_GENERATOR_SYSTEM_PROMPT),
                UserMessage(content=prompt, source="user"),
            ]
            
            response = await self.model_client.create(messages=messages)
            # CreateResult.content is a string for plain text completions
            content = response.content
            if not isinstance(content, str):
                raise SQLGenerationError("Model returned a function call instead of text")
            
            # Parse response
            sql_output = self._parse_response(content)
            
            # Validate and enhance output
            return self._create_output(sql_output, input_data)
            
        except SQLGenerationError:
            raise
        except Exception as e:
            logger.error(f"SQL generation failed: {e}")
            raise SQLGenerationError(f"Failed to generate SQL: {e}") from e
    
    def _format_prompt(self, input_data: SQLGenerationInput) -> str:
        """Format the input data into a prompt for the LLM."""
        # Format schema elements
        schema_elements_formatted = self._format_schema_elements(input_data.linked_schema.relevant_schema_elements)
        
        # Format join paths
        join_paths_formatted = self._format_join_paths(input_data.linked_schema.proposed_join_paths)
        
        prompt = f"""Generate a SQL query for the following:

Query (Natural Language): {input_data.processed_natural_language}

Database ID: {input_data.database_id}
SQL Dialect: {input_data.sql_dialect}

Relevant Schema Elements:
{schema_elements_formatted}

Join Paths:
{join_paths_formatted}

"""
        
        # Add dependent sub-query context if provided
        if input_data.dependent_sub_query_results:
            prompt += f"""
Dependent Sub-Query Results:
{json.dumps(input_data.dependent_sub_query_results, indent=2)}

Use the above sub-query results in your SQL generation.
"""
        
        # Add refinement instructions if provided
        if input_data.refinement_instructions:
            prompt += f"""
Refinement Instructions:
{input_data.refinement_instructions}

Please incorporate these refinements into your SQL query.
"""
        
        prompt += "\nGenerate the SQL query based on the above information."
        
        return prompt
    
    def _format_schema_elements(self, elements: List[SchemaElement]) -> str:
        """Format schema elements for the prompt."""
        formatted = []
        tables_seen = set()
        
        for elem in elements:
            if elem.element_type == "Table" and elem.table_name not in tables_seen:
                tables_seen.add(elem.table_name)
                formatted.append(f"TABLE: {elem.table_name} (relevance: {elem.relevance_score:.2f})")
            elif elem.element_type == "Column":
                col_info = f"COLUMN: {elem.element_name}"
                if elem.data_type:
                    col_info += f" [{elem.data_type}]"
                if elem.constraints:
                    col_info += f" {', '.join(elem.constraints)}"
                col_info += f" (relevance: {elem.relevance_score:.2f})"
                if elem.value_examples:
                    col_info += f" Examples: {', '.join(str(v) for v in elem.value_examples[:3])}"
                formatted.append(col_info)
        
        return "\n".join(formatted)
    
    def _format_join_paths(self, join_paths: List[JoinPath]) -> str:
        """Format join paths for the prompt."""
        if not join_paths:
            return "No join paths provided."
        
        formatted = []
        for join in join_paths:
            formatted.append(
                f"{join.join_type} JOIN: {join.join_condition} "
                f"(confidence: {join.confidence:.2f})"
            )
        
        return "\n".join(formatted)
    
    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM response to extract SQL components."""
        try:
            # Try to parse as JSON first
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group())
                return result
            
            # Fallback to structured parsing
            sql_match = re.search(r'<sql_query>(.*?)</sql_query>', response, re.DOTALL)
            confidence_match = re.search(r'<confidence>(.*?)</confidence>', response, re.DOTALL)
            explanation_match = re.search(r'<explanation>(.*?)</explanation>', response, re.DOTALL)
            validation_match = re.search(r'<validation_status>(.*?)</validation_status>', response, re.DOTALL)
            
            result = {
                "sql_query": sql_match.group(1).strip() if sql_match else "",
                "confidence": float(confidence_match.group(1)) if confidence_match else 0.5,
                "explanation": explanation_match.group(1).strip() if explanation_match else "",
                "validation_status": validation_match.group(1).strip() if validation_match else "Potential_Issues_Noted",
                "issues": []
            }
            
            return result
            
        except Exception as e:
            logger.warning(f"Failed to parse response: {e}")
            # Return a basic structure
            return {
                "sql_query": response.strip(),
                "confidence": 0.3,
                "explanation": "Unable to parse structured response",
                "validation_status": "Potential_Issues_Noted",
                "issues": ["Failed to parse LLM response"]
            }
    
    def _create_output(self, parsed_response: Dict, input_data: SQLGenerationInput) -> SQLGenerationOutput:
        """Create the final output object."""
        # Clean and validate SQL
        sql_query = self._clean_sql(parsed_response.get("sql_query", ""))
        
        # Determine validation status
        validation_status = parsed_response.get("validation_status", "Potential_Issues_Noted")
        if validation_status not in ["Presumed_Valid", "Potential_Issues_Noted"]:
            validation_status = "Potential_Issues_Noted"
        
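        # Compile issues (copy so the parsed response is not mutated; tolerate null or a bare string)
        raw_issues = parsed_response.get("issues") or []
        issues = [raw_issues] if isinstance(raw_issues, str) else list(raw_issues)
        if not sql_query:
            issues.append("Generated SQL is empty")
            validation_status = "Potential_Issues_Noted"
        
        # Parse confidence defensively and clamp it to [0.0, 1.0]
        try:
            confidence = float(parsed_response.get("confidence", 0.5))
        except (TypeError, ValueError):
            confidence = 0.5
        confidence = min(max(confidence, 0.0), 1.0)
        
        # Create output
        return SQLGenerationOutput(
            sql_query=sql_query,
            generation_confidence=confidence,
            brief_explanation_of_sql_logic=parsed_response.get("explanation", ""),
            validation_status_self_assessed=validation_status,
            potential_issues=issues if issues else None
        )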
    
    def _clean_sql(self, sql: str) -> str:
        """Clean and format the generated SQL."""
        # Remove any markdown code fences (```sql or bare ```)
        sql = re.sub(r'```(?:sql)?\s*', '', sql, flags=re.IGNORECASE)
        
        # Trim surrounding whitespace only: collapsing internal whitespace would join
        # lines, so a `--` line comment would swallow the rest of the query
        sql = sql.strip()
        
        # Ensure it ends with semicolon
        if sql and not sql.endswith(';'):
            sql += ';'
        
        return sql
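The parsing and cleanup steps above can be exercised without a model call. The sketch is a simplified, standalone variant of `_parse_response` and `_clean_sql` (the names `parse_llm_sql_reply` and `clean_sql` are hypothetical): it tries the outermost `{...}` span as JSON first, falls back to `<sql_query>` tags when that span is not valid JSON, and strips Markdown fences without collapsing internal whitespace.

```python
import json
import re


def parse_llm_sql_reply(reply: str) -> dict:
    """JSON-first parsing with an XML-tag fallback (simplified from _parse_response)."""
    match = re.search(r"\{.*\}", reply, re.DOTALL)  # greedy: first '{' to last '}'
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass  # fall through to tag-based parsing
    sql = re.search(r"<sql_query>(.*?)</sql_query>", reply, re.DOTALL)
    return {"sql_query": sql.group(1).strip() if sql else "", "confidence": 0.5}


def clean_sql(sql: str) -> str:
    """Strip markdown fences and outer whitespace; ensure a trailing semicolon."""
    sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE).strip()
    return sql if not sql or sql.endswith(";") else sql + ";"


reply = 'Here you go:\n```json\n{"sql_query": "SELECT CustomerName\\nFROM Customers", "confidence": 0.9}\n```'
parsed = parse_llm_sql_reply(reply)
print(repr(clean_sql(parsed["sql_query"])))
```

For a tag-style reply, `parse_llm_sql_reply("<sql_query> SELECT 1 </sql_query>")` returns `{"sql_query": "SELECT 1", "confidence": 0.5}`.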

In [None]:
# Create SQL generation workflow (AutoGen team variant, reusing SQL_GENERATOR_SYSTEM_PROMPT)

# Single assistant agent carrying the SQL generation system prompt
sql_generator_agent = AssistantAgent(
    name="sql_generator",
    model_client=OpenAIChatCompletionClient(model="gpt-4o"),
    system_message=SQL_GENERATOR_SYSTEM_PROMPT,
)


async def generate_sql(sql_input: SQLGenerationInput) -> SQLGenerationOutput:
    """Generate a SQL query by running a single-agent AutoGen team."""
    
    # Format the prompt
    schema_elements_formatted = format_schema_elements_for_prompt(sql_input.linked_schema.relevant_schema_elements)
    join_paths_formatted = format_join_paths_for_prompt(sql_input.linked_schema.proposed_join_paths)
    
    prompt = f"""Generate a SQL query for the following:

Query Part (Natural Language): {sql_input.processed_natural_language}

Database ID: {sql_input.database_id}
SQL Dialect: {sql_input.sql_dialect}

Linked Schema Elements:
{schema_elements_formatted}

Proposed Join Paths:
{join_paths_formatted}

"""
    
    if sql_input.dependent_sub_query_results:
        prompt += f"\nDependent Sub-Query Context:\n{json.dumps(sql_input.dependent_sub_query_results, indent=2)}\n"
    
    if sql_input.refinement_instructions:
        prompt += f"\nRefinement Instructions:\n{sql_input.refinement_instructions}\n"
    
    prompt += "\nGenerate the SQL query based on the above information."
    
    # Stop after two messages: the task itself and the agent's reply
    termination = MaxMessageTermination(max_messages=2)
    team = RoundRobinGroupChat([sql_generator_agent], termination_condition=termination)
    
    # `task` is a keyword-only argument of run()
    result = await team.run(task=prompt)
    last_message = result.messages[-1].content
    
    # The system prompt asks for JSON; fall back to XML-style tags if the model ignored that
    parsed: Dict[str, Any] = {}
    json_match = re.search(r'\{.*\}', last_message, re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            parsed = {}
    if not parsed:
        parsed = {tag: extract_xml_content(last_message, tag)
                  for tag in ("sql_query", "confidence", "explanation", "validation_status", "issues")}
    
    generated_sql = str(parsed.get("sql_query") or "").strip()
    explanation = str(parsed.get("explanation") or "")
    validation_status = parsed.get("validation_status")
    issues = parsed.get("issues") or []
    if isinstance(issues, str):
        issues = [issues]
    
    # Convert confidence to float
    try:
        confidence = float(parsed.get("confidence", 0.5))
    except (TypeError, ValueError):
        confidence = 0.5
    
    # Ensure validation status is valid
    if validation_status not in ["Presumed_Valid", "Potential_Issues_Noted"]:
        validation_status = "Potential_Issues_Noted" if issues else "Presumed_Valid"
    
    return SQLGenerationOutput(
        sql_query=generated_sql,
        generation_confidence=confidence,
        brief_explanation_of_sql_logic=explanation,
        validation_status_self_assessed=validation_status,
        potential_issues=issues if issues else None
    )

In [ ]:
# Test with simple example using new interface
async def test_simple_sql_generation():
    # Create mock schema linking output
    schema_elements = [
        SchemaElement(
            element_name="Customers",
            element_type="Table",
            table_name="Customers",
            relevance_score=0.9,
            mapping_rationale="Query mentions customers"
        ),
        SchemaElement(
            element_name="Customers.CustomerName",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerName",
            data_type="VARCHAR(255)",
            relevance_score=0.95,
            mapping_rationale="Query asks for customer names",
            value_examples=["John Smith", "Jane Doe", "Bob Johnson"]
        ),
        SchemaElement(
            element_name="Customers.CustomerID",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerID",
            data_type="INT",
            constraints=["PRIMARY KEY"],
            relevance_score=0.7,
            mapping_rationale="Primary key of Customers table"
        )
    ]
    
    linked_schema = SchemaLinkingOutput(
        relevant_schema_elements=schema_elements,
        proposed_join_paths=[],
        overall_linking_confidence=0.9,
        unresolved_elements_notes=[]
    )
    
    # Create input
    input_data = SQLGenerationInput(
        processed_natural_language="Show all customer names from the customers table",
        linked_schema=linked_schema,
        database_id="test_db",
        sql_dialect="MySQL"
    )
    
    # Initialize agent and generate SQL
    sql_generator = SQLGenerationAgent()
    
    try:
        result = await sql_generator.generate_sql(input_data)
        
        print("Generated SQL:", result.sql_query)
        print("Confidence:", result.generation_confidence)
        print("Explanation:", result.brief_explanation_of_sql_logic)
        print("Validation Status:", result.validation_status_self_assessed)
        if result.potential_issues:
            print("Issues:", result.potential_issues)
        
        return result
    except SQLGenerationError as e:
        print(f"SQL generation failed: {e}")
        return None

# Run the test
simple_result = await test_simple_sql_generation()

In [ ]:
# Test with joins using new interface
async def test_join_sql_generation():
    # Create schema elements for join scenario
    schema_elements = [
        SchemaElement(
            element_name="Orders",
            element_type="Table",
            table_name="Orders",
            relevance_score=0.95,
            mapping_rationale="Query mentions orders"
        ),
        SchemaElement(
            element_name="Customers",
            element_type="Table",
            table_name="Customers",
            relevance_score=0.9,
            mapping_rationale="Query mentions customers"
        ),
        SchemaElement(
            element_name="Orders.OrderDate",
            element_type="Column",
            table_name="Orders",
            column_name="OrderDate",
            data_type="DATE",
            relevance_score=0.85,
            mapping_rationale="Query mentions recent orders"
        ),
        SchemaElement(
            element_name="Orders.CustomerID",
            element_type="Column",
            table_name="Orders",
            column_name="CustomerID",
            data_type="INT",
            constraints=["FOREIGN KEY"],
            relevance_score=0.8
        ),
        SchemaElement(
            element_name="Customers.CustomerName",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerName",
            data_type="VARCHAR(255)",
            relevance_score=0.9
        ),
        SchemaElement(
            element_name="Customers.CustomerID",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerID",
            data_type="INT",
            constraints=["PRIMARY KEY"],
            relevance_score=0.8
        )
    ]
    
    # Create join paths
    join_paths = [
        JoinPath(
            from_table="Orders",
            to_table="Customers",
            from_column="CustomerID",
            to_column="CustomerID",
            join_condition="Orders.CustomerID = Customers.CustomerID",
            join_type="INNER",
            confidence=0.95
        )
    ]
    
    linked_schema = SchemaLinkingOutput(
        relevant_schema_elements=schema_elements,
        proposed_join_paths=join_paths,
        overall_linking_confidence=0.9,
        unresolved_elements_notes=[]
    )
    
    # Create input
    input_data = SQLGenerationInput(
        processed_natural_language="Find customer names who placed orders in the last 30 days",
        linked_schema=linked_schema,
        database_id="test_db",
        sql_dialect="MySQL"
    )
    
    # Generate SQL
    sql_generator = SQLGenerationAgent(config={'model': 'gpt-4o'})
    
    try:
        result = await sql_generator.generate_sql(input_data)
        
        print("Generated SQL:", result.sql_query)
        print("Confidence:", result.generation_confidence)
        print("Explanation:", result.brief_explanation_of_sql_logic)
        print("Validation Status:", result.validation_status_self_assessed)
        
        return result
    except SQLGenerationError as e:
        print(f"SQL generation failed: {e}")
        return None

# Run the test
join_result = await test_join_sql_generation()
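For the join test, the model is expected to turn the `JoinPath` into an explicit `JOIN ... ON ...` clause. The sketch below shows that mapping on a toy SQLite database; `JoinSpec` and `build_join_query` are hypothetical (`JoinSpec` stands in for the `JoinPath` fields it uses), the rows are made up for illustration, and the date filter uses SQLite syntax instead of the MySQL dialect requested in the test.

```python
import sqlite3
from dataclasses import dataclass


@dataclass
class JoinSpec:
    """Hypothetical stand-in for the JoinPath fields used here."""
    from_table: str
    to_table: str
    join_condition: str
    join_type: str = "INNER"


def build_join_query(select_cols: list[str], join: JoinSpec, where: str = "") -> str:
    """Assemble SELECT ... FROM a <type> JOIN b ON <condition> [WHERE ...] from one join path."""
    sql = (
        f"SELECT {', '.join(select_cols)} "
        f"FROM {join.from_table} "
        f"{join.join_type} JOIN {join.to_table} ON {join.join_condition}"
    )
    if where:
        sql += f" WHERE {where}"
    return sql + ";"


conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customers (CustomerID INTEGER PRIMARY KEY, CustomerName TEXT);
CREATE TABLE Orders (OrderID INTEGER PRIMARY KEY, CustomerID INTEGER, OrderDate TEXT);
INSERT INTO Customers VALUES (1, 'John Smith'), (2, 'Jane Doe');
INSERT INTO Orders VALUES (10, 1, date('now', '-3 days')), (11, 2, date('now', '-90 days'));
""")

join = JoinSpec("Orders", "Customers", "Orders.CustomerID = Customers.CustomerID")
query = build_join_query(["Customers.CustomerName"], join, "Orders.OrderDate >= date('now', '-30 days')")
rows = conn.execute(query).fetchall()  # only the customer with an order in the last 30 days
print(query)
print(rows)
```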

In [ ]:
# Test with refinement instructions
async def test_refinement_sql_generation():
    # Reuse the join scenario schema
    schema_elements = [
        SchemaElement(
            element_name="Orders",
            element_type="Table",
            table_name="Orders",
            relevance_score=0.95,
            mapping_rationale="Query mentions orders"
        ),
        SchemaElement(
            element_name="Customers",
            element_type="Table",
            table_name="Customers",
            relevance_score=0.9,
            mapping_rationale="Query mentions customers"
        ),
        SchemaElement(
            element_name="Orders.OrderDate",
            element_type="Column",
            table_name="Orders",
            column_name="OrderDate",
            data_type="DATE",
            relevance_score=0.85,
            mapping_rationale="Query mentions recent orders"
        ),
        SchemaElement(
            element_name="Customers.CustomerName",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerName",
            data_type="VARCHAR(255)",
            relevance_score=0.9
        )
    ]
    
    join_paths = [
        JoinPath(
            from_table="Orders",
            to_table="Customers",
            from_column="CustomerID",
            to_column="CustomerID",
            join_condition="Orders.CustomerID = Customers.CustomerID",
            join_type="INNER",
            confidence=0.95
        )
    ]
    
    linked_schema = SchemaLinkingOutput(
        relevant_schema_elements=schema_elements,
        proposed_join_paths=join_paths,
        overall_linking_confidence=0.9,
        unresolved_elements_notes=[]
    )
    
    # Create input with refinement instructions
    input_data = SQLGenerationInput(
        processed_natural_language="Find customer names who placed orders in the last 30 days",
        linked_schema=linked_schema,
        database_id="test_db",
        sql_dialect="MySQL",
        refinement_instructions="The previous query was missing DISTINCT to avoid duplicates. Also ensure proper date formatting for MySQL using DATE_SUB."
    )
    
    # Generate SQL
    sql_generator = SQLGenerationAgent()
    
    try:
        result = await sql_generator.generate_sql(input_data)
        
        print("Generated SQL with refinements:", result.sql_query)
        print("Confidence:", result.generation_confidence)
        print("Explanation:", result.brief_explanation_of_sql_logic)
        print("Validation Status:", result.validation_status_self_assessed)
        
        return result
    except SQLGenerationError as e:
        print(f"SQL generation failed: {e}")
        return None

# Run the test
refinement_result = await test_refinement_sql_generation()
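The refinement instruction asks for `DISTINCT` because joining customers to orders yields one row per matching order, so a customer with several recent orders is listed several times. A small SQLite illustration with made-up rows, in which one customer has two recent orders:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customers (CustomerID INTEGER PRIMARY KEY, CustomerName TEXT);
CREATE TABLE Orders (OrderID INTEGER PRIMARY KEY, CustomerID INTEGER, OrderDate TEXT);
INSERT INTO Customers VALUES (1, 'John Smith'), (2, 'Jane Doe');
-- John Smith has two recent orders, Jane Doe has one
INSERT INTO Orders VALUES
  (10, 1, date('now', '-2 days')),
  (11, 1, date('now', '-5 days')),
  (12, 2, date('now', '-7 days'));
""")

base = (
    "FROM Orders INNER JOIN Customers ON Orders.CustomerID = Customers.CustomerID "
    "WHERE Orders.OrderDate >= date('now', '-30 days') ORDER BY CustomerName"
)
without_distinct = conn.execute(f"SELECT Customers.CustomerName {base}").fetchall()
with_distinct = conn.execute(f"SELECT DISTINCT Customers.CustomerName {base}").fetchall()
print(without_distinct)  # one row per matching order
print(with_distinct)     # one row per customer
```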

In [ ]:
# Test with dependent sub-query context
async def test_dependent_sql_generation():
    # Define dependent sub-query context
    dependent_context = {
        "previous_part_result": {
            "sql": "SELECT DISTINCT CustomerID FROM Orders WHERE OrderDate >= DATE_SUB(CURDATE(), INTERVAL 30 DAY)",
            "result_summary": "Returns 42 customer IDs who ordered in last 30 days",
            "column_names": ["CustomerID"],
            "sample_rows": [[1], [2], [3]]  # First 3 rows
        }
    }
    
    # Create schema for dependent query
    schema_elements = [
        SchemaElement(
            element_name="Customers",
            element_type="Table",
            table_name="Customers",
            relevance_score=0.95
        ),
        SchemaElement(
            element_name="Customers.CustomerName",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerName",
            data_type="VARCHAR(255)",
            relevance_score=0.9
        ),
        SchemaElement(
            element_name="Customers.Email",
            element_type="Column",
            table_name="Customers",
            column_name="Email",
            data_type="VARCHAR(255)",
            relevance_score=0.9
        ),
        SchemaElement(
            element_name="Customers.CustomerID",
            element_type="Column",
            table_name="Customers",
            column_name="CustomerID",
            data_type="INT",
            constraints=["PRIMARY KEY"],
            relevance_score=0.8
        )
    ]
    
    linked_schema = SchemaLinkingOutput(
        relevant_schema_elements=schema_elements,
        proposed_join_paths=[],
        overall_linking_confidence=0.9,
        unresolved_elements_notes=[]
    )
    
    # Create input with dependent context
    input_data = SQLGenerationInput(
        processed_natural_language="Get the names and email addresses of customers from the previous query results",
        linked_schema=linked_schema,
        database_id="test_db",
        sql_dialect="MySQL",
        dependent_sub_query_results=dependent_context
    )
    
    # Generate SQL
    sql_generator = SQLGenerationAgent()
    
    try:
        result = await sql_generator.generate_sql(input_data)
        
        print("Generated SQL with dependencies:", result.sql_query)
        print("Confidence:", result.generation_confidence)
        print("Explanation:", result.brief_explanation_of_sql_logic)
        print("Validation Status:", result.validation_status_self_assessed)
        
        return result
    except SQLGenerationError as e:
        print(f"SQL generation failed: {e}")
        return None

# Run the test
dependent_result = await test_dependent_sql_generation()
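One way to use the dependent context is to nest the earlier part's SQL as an `IN (...)` subquery rather than pasting its sample rows, which cover only the first few IDs. The sketch assumes SQLite syntax and an outer query with no `WHERE`, `GROUP BY`, `ORDER BY` or `LIMIT` clause; `compose_with_dependency` is hypothetical and the data is made up.

```python
import sqlite3


def compose_with_dependency(outer_sql: str, key_column: str, dependent_sql: str) -> str:
    """Restrict `outer_sql` to rows whose key appears in the dependent query's result.

    Assumes `outer_sql` is a plain SELECT ... FROM ... with no WHERE/GROUP BY/ORDER BY/LIMIT.
    """
    dependent = dependent_sql.strip().rstrip(";")  # a subquery must not end with ';'
    outer = outer_sql.strip().rstrip(";")
    return f"{outer} WHERE {key_column} IN ({dependent});"


conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customers (CustomerID INTEGER PRIMARY KEY, CustomerName TEXT, Email TEXT);
CREATE TABLE Orders (OrderID INTEGER PRIMARY KEY, CustomerID INTEGER, OrderDate TEXT);
INSERT INTO Customers VALUES (1, 'John Smith', 'john@example.com'), (2, 'Jane Doe', 'jane@example.com');
INSERT INTO Orders VALUES (10, 1, date('now', '-3 days')), (11, 2, date('now', '-60 days'));
""")

# SQLite counterpart of the dependent part's MySQL query
previous_sql = "SELECT DISTINCT CustomerID FROM Orders WHERE OrderDate >= date('now', '-30 days')"
query = compose_with_dependency("SELECT CustomerName, Email FROM Customers", "CustomerID", previous_sql)
rows = conn.execute(query).fetchall()
print(query)
print(rows)
```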

## Summary

The SQL Generation Agent is designed to:

1. Translate a natural language query part into SQL that is intended to be syntactically valid
2. Use linked schema elements to construct appropriate queries
3. Handle joins using proposed join paths
4. Support different SQL dialects (MySQL, PostgreSQL, SQLite)
5. Incorporate refinement instructions for iterative improvement
6. Handle dependent sub-query contexts for complex queries
7. Provide a confidence score and a self-assessed validation status
8. Generate explanations of the SQL logic for transparency

The test cells above exercise:
- Simple SELECT queries
- Complex queries with JOINs
- Queries with WHERE clauses and date filtering
- Queries that depend on previous sub-query results
- Refinements based on feedback