# SQL Execution and Coherence Agent

Purpose: Executes the generated SQL query, retrieves the results, and then analyzes these results for coherence against the original query part's intent and known data characteristics.

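Input (`SQLExecutionAndCoherenceInput`):
- sql_to_execute: (string) The SQL query produced by the SQL generation agent (Agent 3)
- database_id: (string) Identifier for the target database
- original_query_part: (QueryPart) The query part being answered; provides processed_natural_language and extracted_entities_and_intent
- generated_sql_context: (SQLGenerationOutput) The SQL generation output the query came from
- database_schema: (object) Schema description used as context during coherence analysis

Output (`SQLExecutionAndCoherenceOutput`):
- execution_result: (SQLExecutionResult)
  - status: (string) "Success" or "Error"
  - error_type: (string, if execution error) "Syntax", "Semantic", "Timeout", or "Other"
  - error_message: (string, if execution error)
  - query_results: (object, if execution successful)
    - columns: (list of strings) Column names
    - rows: (list of lists/dicts) The data returned
    - row_count: (integer)
    - data_types: (list) Column data types, when the executor reports them
  - performance_metrics: (object) execution_time_ms, rows_returned, query_length
- coherence_assessment: (object)
  - is_result_satisfactory: (boolean)
  - coherence_score: (float, 0.0–1.0)
  - explanation: (string)
  - issues_found: (list of strings)
- is_final_result_satisfactory_for_part: (boolean)
- refinement_proposals_for_orchestrator: (list of objects with target_agent, feedback, priority; if not satisfactory)

Note: earlier drafts of this agent used the field names original_query_part_context, database_schema_description, execution_status, execution_error_details, and refinement_suggestions_for_orchestrator. The module-level helper `classify_execution_error` still returns the legacy "Execution_Error_Syntax" / "Execution_Error_Semantic" / "Execution_Error_Timeout" / "Execution_Error_Other" labels.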

In [None]:
# Import necessary modules
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath('./')))

import asyncio
from typing import List, Dict, Any, Optional, Tuple
import json
import re
import time
import logging

# Import unified schemas
from schemas import (
    # SQL Execution types
    SQLExecutionAndCoherenceInput,
    SQLExecutionAndCoherenceOutput,
    SQLExecutionResult,
    QueryResults,
    PerformanceMetrics,
    CoherenceAssessment,
    RefinementSuggestion,
    QueryPart,
    SQLGenerationOutput,
    
    # Error types
    SQLExecutionError,
    CoherenceError,
    
    # Database schemas
    SCHEMAS
)

from sql_executor import SQLExecutor  # Import the SQL executor module

logger = logging.getLogger(__name__)

In [None]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_ext.models.openai import OpenAIChatCompletionClient

In [None]:
# Initialize the model client and SQL executor.
# These module-level instances back the standalone helpers in the next cell
# (execute_sql_safely reads `sql_executor`).
# NOTE(review): MODEL_NAME and API_KEY are not defined in any visible cell;
# presumably they come from an earlier config step or the kernel — confirm.
model_client = OpenAIChatCompletionClient(model=MODEL_NAME, api_key=API_KEY)

sql_executor = SQLExecutor()

In [None]:
# Helper functions

def classify_execution_error(error_message: Optional[str]) -> str:
    """Map a raw database error message to a legacy ``execution_status`` label.

    Substring checks run in priority order: syntax/parse first, then timeout,
    then schema-object words (column/table/database). Anything else is "Other".

    Args:
        error_message: Error text reported by the executor. ``None`` is
            treated as an empty message.

    Returns:
        One of "Execution_Error_Syntax", "Execution_Error_Timeout",
        "Execution_Error_Semantic" or "Execution_Error_Other".
    """
    error_lower = (error_message or "").lower()

    if "syntax" in error_lower or "parse" in error_lower:
        return "Execution_Error_Syntax"
    # "timed out" covers the common "... timed out" phrasing, which contains
    # neither "timeout" nor "time out".
    if any(marker in error_lower for marker in ("timeout", "time out", "timed out")):
        return "Execution_Error_Timeout"
    if any(marker in error_lower for marker in ("column", "table", "database")):
        return "Execution_Error_Semantic"
    return "Execution_Error_Other"

def execute_sql_safely(sql: str, database_id: str, timeout: float = 30.0) -> Tuple[Optional[QueryResults], Optional[str], float]:
    """Execute SQL through the module-level ``sql_executor`` and time it.

    Args:
        sql: Query text passed to ``sql_executor.safe_execute``.
        database_id: Identifier of the target database.
        timeout: Forwarded to ``safe_execute``; enforcement is up to the executor.

    Returns:
        ``(query_results, error_message, execution_time_ms)``. On success the
        error message is ``None``; on failure the results are ``None`` and the
        error message is a non-empty string. Exceptions are caught and
        reported through the error message instead of being raised.
    """
    # perf_counter is monotonic, so a wall-clock adjustment mid-query cannot
    # produce a negative or skewed duration (time.time() can).
    start_time = time.perf_counter()

    try:
        result = sql_executor.safe_execute(sql, database_id, timeout)
        execution_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

        if result["status"] == "success":
            rows = result["rows"]
            query_results = QueryResults(
                columns=result["columns"],
                rows=rows,
                row_count=len(rows)
            )
            return query_results, None, execution_time

        # Guarantee a usable error string even if the executor omitted it.
        return None, str(result.get("error") or "Unknown execution error"), execution_time

    except Exception as e:
        execution_time = (time.perf_counter() - start_time) * 1000
        # Some exceptions (e.g. a bare TimeoutError()) stringify to "".
        return None, str(e) or type(e).__name__, execution_time

def format_schema_for_prompt(
    schema_desc: "DatabaseSchemaDescription",
    max_tables: int = 10,
    max_columns: int = 5,
) -> str:
    """Format a database schema description as compact text for an LLM prompt.

    ``DatabaseSchemaDescription`` is written as a string annotation because the
    name is not imported in this notebook; a bare name would raise NameError
    when this ``def`` runs.

    Args:
        schema_desc: Object with a ``database_id`` attribute and a ``tables``
            sequence of dicts carrying ``'name'`` and optionally ``'columns'``
            (a list of dicts with ``'name'`` / ``'type'``).
        max_tables: Maximum number of tables to list (default 10).
        max_columns: Maximum number of columns to list per table (default 5).

    Returns:
        Text of the form ``"Database: <id>\\n\\nTables:\\n  - <table>: <col> (<type>), ...\\n"``.
        A trailing ``", ..."`` marks truncated columns, and a final line counts
        omitted tables.
    """
    tables = list(schema_desc.tables or [])
    lines = [f"Database: {schema_desc.database_id}", "", "Tables:"]

    for table in tables[:max_tables]:
        entry = f"  - {table.get('name', 'Unknown')}: "
        columns = table.get("columns") or []  # tolerate a missing or None column list
        if columns:
            entry += ", ".join(
                f"{col.get('name', 'Unknown')} ({col.get('type', 'Unknown')})"
                for col in columns[:max_columns]
            )
            if len(columns) > max_columns:
                entry += ", ..."
        lines.append(entry)

    if len(tables) > max_tables:
        lines.append(f"  ... and {len(tables) - max_tables} more tables")

    # Every line, including the last, is newline-terminated.
    return "\n".join(lines) + "\n"

In [None]:
# SQL Coherence Analyzer System Prompt.
# The JSON keys requested here must match what
# SQLExecutionAndCoherenceAgent._parse_coherence_response reads:
# "is_satisfactory", "coherence_score", "explanation" and "issues".
COHERENCE_ANALYZER_SYSTEM_PROMPT = """You are an expert at analyzing SQL query results for coherence with the original query intent.

Given SQL execution results and the original query context, assess whether the results are satisfactory.
Consider:
1. Do the results align with the query intent?
2. Are the columns and data types appropriate?
3. Is the number of rows reasonable given the query?
4. Are there any anomalies in the data?
5. Does the result match the expected data characteristics?

Respond with a single JSON object only (no markdown fences or extra text) with:
{
  "is_satisfactory": true/false,
  "coherence_score": 0.0-1.0,
  "explanation": "detailed explanation",
  "issues": ["one short description per concrete problem found; empty list if none"],
  "refinement_suggestions": [
    {
      "target_agent": "SQLGenerationAgent or SchemaLinkingAgent",
      "feedback": "specific feedback",
      "priority": 1-5
    }
  ]
}

In "issues", mention "column" or "schema" when a problem comes from wrong table or column choices."""

In [None]:
class SQLExecutionAndCoherenceAgent:
    """Agent that executes SQL and analyzes result coherence.

    The workflow has three steps:

    1. Run the SQL through ``self.sql_executor`` and wrap the outcome in an
       ``SQLExecutionResult``.
    2. When execution succeeded, ask an LLM whether the rows answer the
       original query part.
    3. Turn failures and unsatisfactory results into ``RefinementSuggestion``
       objects for the orchestrator.

    Config keys read by ``__init__``: ``model`` (default ``'gpt-4o'``),
    ``sql_executor`` (a pre-built executor) and ``sql_timeout`` (default 30.0,
    forwarded to ``safe_execute``).
    """

    # Used only by the free-text fallback in _parse_coherence_response. Word
    # boundaries keep "unsatisfactory" from counting as "satisfactory" and
    # keep words such as "note" or "nothing" from counting as a negation.
    _SATISFACTORY_PATTERN = re.compile(r"\bsatisfactory\b")
    _NEGATION_PATTERN = re.compile(r"\b(?:not|cannot|never|unsatisfactory)\b|n't\b")
    # Accepts "0.8", ".8" or "8", but not a trailing sentence period ("0.8.").
    _SCORE_PATTERN = re.compile(r"score[:\s]+([0-9]*\.?[0-9]+)", re.IGNORECASE)

    # Score used when the LLM verdict is missing or unparseable.
    _DEFAULT_SCORE = 0.7

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the SQL Execution and Coherence Agent.

        Args:
            config: Optional settings; see the class docstring for the keys read.
        """
        self.config = config or {}
        self.model = self.config.get('model', 'gpt-4o')
        self.model_client = OpenAIChatCompletionClient(model=self.model)

        # Prefer an injected executor; otherwise build one from the raw
        # config argument.
        if 'sql_executor' in self.config:
            self.sql_executor = self.config['sql_executor']
        else:
            self.sql_executor = SQLExecutor(config)

        self.timeout = self.config.get('sql_timeout', 30.0)

    async def execute_and_check_coherence(
        self,
        input_data: SQLExecutionAndCoherenceInput
    ) -> SQLExecutionAndCoherenceOutput:
        """
        Execute SQL and check coherence of results.

        Args:
            input_data: SQLExecutionAndCoherenceInput with the SQL to run, the
                target database id, the original query part and schema context.

        Returns:
            SQLExecutionAndCoherenceOutput with the execution result, the
            coherence assessment, the overall satisfactory flag and, when not
            satisfactory, refinement proposals for the orchestrator.

        Raises:
            SQLExecutionError: If building the result fails unexpectedly. SQL
                errors reported by the executor and LLM failures during the
                coherence check do not raise; they are reported in the output.
        """
        try:
            # Execute the SQL
            execution_result = await self._execute_sql(
                input_data.sql_to_execute,
                input_data.database_id
            )

            # If execution failed, create appropriate output
            if execution_result.status != "Success":
                return self._create_failure_output(execution_result, input_data)

            # Analyze coherence
            coherence_assessment = await self._analyze_coherence(
                execution_result,
                input_data
            )
            is_satisfactory = coherence_assessment.is_result_satisfactory

            # Create refinement suggestions only if needed
            refinement_suggestions = None
            if not is_satisfactory:
                refinement_suggestions = self._create_refinement_suggestions(
                    coherence_assessment,
                    execution_result,
                    input_data
                )

            return SQLExecutionAndCoherenceOutput(
                execution_result=execution_result,
                coherence_assessment=coherence_assessment,
                is_final_result_satisfactory_for_part=is_satisfactory,
                refinement_proposals_for_orchestrator=refinement_suggestions
            )

        except Exception as e:
            logger.exception("Execution and coherence check failed: %s", e)
            # Chain the cause so the original traceback is not lost.
            raise SQLExecutionError(f"Failed to execute and analyze: {str(e)}") from e

    async def _execute_sql(self, sql: str, database_id: str) -> SQLExecutionResult:
        """Execute ``sql`` against ``database_id`` and wrap the outcome.

        Never raises. Executor-reported errors and unexpected exceptions both
        come back as ``status="Error"`` results with an ``error_type`` of
        "Syntax", "Semantic", "Timeout" or "Other".

        NOTE(review): ``safe_execute`` is a synchronous call, so a slow query
        blocks the event loop while it runs.
        """
        # perf_counter is monotonic; time.time() can jump with wall-clock
        # adjustments and yield negative or skewed durations.
        start_time = time.perf_counter()

        try:
            result = self.sql_executor.safe_execute(sql, database_id, self.timeout)
            execution_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

            if result["status"] == "success":
                rows = result["rows"]
                query_results = QueryResults(
                    columns=result["columns"],
                    rows=rows,
                    row_count=len(rows),
                    data_types=result.get("data_types", [])
                )
                return SQLExecutionResult(
                    status="Success",
                    query_results=query_results,
                    performance_metrics=self._build_metrics(execution_time, query_results.row_count, sql)
                )

            # A missing/None "error" entry would otherwise crash
            # _classify_error and surface as a misleading "Other" error.
            error_message = str(result.get("error") or "Unknown execution error")
            return SQLExecutionResult(
                status="Error",
                error_message=error_message,
                error_type=self._classify_error(error_message),
                performance_metrics=self._build_metrics(execution_time, 0, sql)
            )

        except Exception as e:
            execution_time = (time.perf_counter() - start_time) * 1000
            # str(TimeoutError()) is empty, so fall back to the exception type name.
            error_message = str(e) or type(e).__name__
            error_type = "Timeout" if isinstance(e, TimeoutError) else "Other"
            return SQLExecutionResult(
                status="Error",
                error_message=error_message,
                error_type=error_type,
                performance_metrics=self._build_metrics(execution_time, 0, sql)
            )

    @staticmethod
    def _build_metrics(execution_time_ms: float, rows_returned: int, sql: str) -> PerformanceMetrics:
        """Bundle timing, row count and query length into PerformanceMetrics."""
        return PerformanceMetrics(
            execution_time_ms=execution_time_ms,
            rows_returned=rows_returned,
            query_length=len(sql)
        )

    def _classify_error(self, error_message: Optional[str]) -> str:
        """Classify an execution error message as Syntax/Timeout/Semantic/Other.

        Substring checks run in that priority order; ``None`` counts as "".
        """
        error_lower = (error_message or "").lower()

        if "syntax" in error_lower or "parse" in error_lower:
            return "Syntax"
        # "timed out" covers the common "... timed out" phrasing.
        if any(marker in error_lower for marker in ("timeout", "time out", "timed out")):
            return "Timeout"
        if any(marker in error_lower for marker in ("column", "table", "database")):
            return "Semantic"
        return "Other"

    @staticmethod
    def _describe_intent(entities: Any) -> Dict[str, Any]:
        """Pick the intent fields shown to the LLM from an entities object or dict."""
        fields = ("metrics", "dimensions", "filters", "primary_goal")
        if isinstance(entities, dict):
            return {name: entities.get(name) for name in fields}
        return {name: getattr(entities, name, None) for name in fields}

    async def _analyze_coherence(
        self,
        execution_result: SQLExecutionResult,
        input_data: SQLExecutionAndCoherenceInput
    ) -> CoherenceAssessment:
        """Ask the LLM whether a successful result matches the query intent.

        If the LLM call or response handling fails, this returns a default
        assessment (satisfactory, score 0.7) so that an unavailable model does
        not block the pipeline.
        """
        query_part = input_data.original_query_part
        query_results = execution_result.query_results

        # default=str: intent fields may hold model objects that json cannot
        # serialise natively; the prompt only needs a readable rendering.
        intent_json = json.dumps(
            self._describe_intent(query_part.extracted_entities_and_intent),
            indent=2,
            default=str
        )
        sample_rows = query_results.rows[:5] if query_results.rows else 'No rows'

        prompt = f"""Analyze the coherence of these SQL execution results:

Original Query: {query_part.processed_natural_language}
Query Intent: {intent_json}

SQL Executed: {input_data.sql_to_execute}

Execution Results:
- Columns: {query_results.columns}
- Row Count: {query_results.row_count}
- Sample Rows (first 5): {sample_rows}

Database Schema Context:
{self._format_schema_for_prompt(input_data.database_schema)}

Assess whether these results are coherent with the original query intent.
"""

        try:
            # OpenAIChatCompletionClient.create takes autogen LLMMessage objects
            # (not OpenAI-style dicts) and returns a result with `.content`
            # (not `.choices`). AssistantAgent builds the messages correctly.
            # A fresh agent per call keeps earlier assessments out of the
            # model context.
            analyzer = AssistantAgent(
                name="coherence_analyzer",
                model_client=self.model_client,
                system_message=COHERENCE_ANALYZER_SYSTEM_PROMPT,
            )
            task_result = await analyzer.run(task=prompt)
            last_message = task_result.messages[-1] if task_result.messages else None
            content = getattr(last_message, "content", "")
            if not isinstance(content, str):
                content = str(content)

            result = self._parse_coherence_response(content)

            return CoherenceAssessment(
                is_result_satisfactory=result["is_satisfactory"],
                coherence_score=result["coherence_score"],
                explanation=result["explanation"],
                issues_found=result.get("issues", [])
            )

        except Exception as e:
            logger.warning("Coherence analysis failed: %s, using default", e)
            return CoherenceAssessment(
                is_result_satisfactory=True,
                coherence_score=self._DEFAULT_SCORE,
                explanation="Default assessment (analysis failed)",
                issues_found=["Coherence analysis failed"]
            )

    def _parse_coherence_response(self, response: str) -> Dict[str, Any]:
        """Parse the coherence analysis response.

        First tries the outermost ``{...}`` span as JSON. If there is none, or
        it is not valid JSON, falls back to keyword heuristics on the raw text.

        Returns:
            Dict with ``is_satisfactory`` (bool), ``coherence_score`` (float
            clamped to [0, 1]), ``explanation`` (str) and ``issues`` (list of str).
        """
        response = response or ""
        try:
            parsed = self._extract_json_object(response)
            if parsed is not None:
                return {
                    "is_satisfactory": self._coerce_bool(parsed.get("is_satisfactory", True)),
                    "coherence_score": self._clamp_score(parsed.get("coherence_score", self._DEFAULT_SCORE)),
                    "explanation": str(parsed.get("explanation", "")),
                    "issues": self._extract_issues(parsed)
                }

            # Fallback: infer the verdict from free text.
            lowered = response.lower()
            is_satisfactory = (
                bool(self._SATISFACTORY_PATTERN.search(lowered))
                and not self._NEGATION_PATTERN.search(lowered)
            )
            score_match = self._SCORE_PATTERN.search(response)
            score = self._clamp_score(score_match.group(1)) if score_match else self._DEFAULT_SCORE

            return {
                "is_satisfactory": is_satisfactory,
                "coherence_score": score,
                "explanation": response.strip(),
                "issues": []
            }

        except Exception as e:
            logger.warning("Failed to parse coherence response: %s", e)
            return {
                "is_satisfactory": True,
                "coherence_score": self._DEFAULT_SCORE,
                "explanation": "Default assessment",
                "issues": []
            }

    @staticmethod
    def _extract_json_object(response: str) -> Optional[Dict[str, Any]]:
        """Return the first-to-last brace span parsed as a dict, or None."""
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if not json_match:
            return None
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _coerce_bool(value: Any) -> bool:
        """Interpret JSON booleans and string spellings such as "false"."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    @classmethod
    def _clamp_score(cls, value: Any) -> float:
        """Convert a score to float in [0, 1]; unparseable or NaN gives the default."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return cls._DEFAULT_SCORE
        if score != score:  # NaN is the only float not equal to itself
            return cls._DEFAULT_SCORE
        return min(1.0, max(0.0, score))

    @staticmethod
    def _extract_issues(parsed: Dict[str, Any]) -> List[str]:
        """Collect issue strings, falling back to refinement_suggestions feedback."""
        issues = parsed.get("issues")
        if not issues:
            suggestions = parsed.get("refinement_suggestions") or []
            if isinstance(suggestions, list):
                issues = [s.get("feedback") for s in suggestions if isinstance(s, dict)]
        if isinstance(issues, str):
            issues = [issues]
        if not isinstance(issues, list):
            return []
        # Downstream code calls .lower() on each issue, so force strings.
        return [str(issue) for issue in issues if issue]

    def _create_failure_output(
        self,
        execution_result: SQLExecutionResult,
        input_data: SQLExecutionAndCoherenceInput
    ) -> SQLExecutionAndCoherenceOutput:
        """Create output for failed execution, routing feedback by error type."""
        error_message = execution_result.error_message or "Unknown execution error"

        coherence_assessment = CoherenceAssessment(
            is_result_satisfactory=False,
            coherence_score=0.0,
            explanation=f"Execution failed: {error_message}",
            issues_found=[error_message]
        )

        error_type = execution_result.error_type
        if error_type == "Syntax":
            suggestion = RefinementSuggestion(
                target_agent="SQLGenerationAgent",
                feedback=f"SQL syntax error: {error_message}. Please fix syntax issues.",
                priority=1
            )
        elif error_type == "Semantic":
            # Unknown table/column names usually stem from schema linking.
            suggestion = RefinementSuggestion(
                target_agent="SchemaLinkingAgent",
                feedback=f"Schema-related error: {error_message}. Please verify table/column names.",
                priority=1
            )
        elif error_type == "Timeout":
            suggestion = RefinementSuggestion(
                target_agent="SQLGenerationAgent",
                feedback=f"Query timed out: {error_message}. Please simplify or optimize the query (e.g. add selective filters, avoid unnecessary joins).",
                priority=2
            )
        else:
            suggestion = RefinementSuggestion(
                target_agent="SQLGenerationAgent",
                feedback=f"Execution error: {error_message}. Please revise the query.",
                priority=2
            )

        return SQLExecutionAndCoherenceOutput(
            execution_result=execution_result,
            coherence_assessment=coherence_assessment,
            is_final_result_satisfactory_for_part=False,
            refinement_proposals_for_orchestrator=[suggestion]
        )

    def _create_refinement_suggestions(
        self,
        coherence_assessment: CoherenceAssessment,
        execution_result: SQLExecutionResult,
        input_data: SQLExecutionAndCoherenceInput
    ) -> List[RefinementSuggestion]:
        """Create refinement suggestions based on coherence analysis.

        Always returns at least one suggestion, because it is only called for
        results judged unsatisfactory.
        """
        suggestions = []

        # Basic suggestion if coherence is low
        if coherence_assessment.coherence_score < 0.5:
            suggestions.append(RefinementSuggestion(
                target_agent="SQLGenerationAgent",
                feedback="The query results don't align well with the intended query. Please revise to better match the requirements.",
                priority=1
            ))

        # Route each reported issue by keyword: schema problems go to schema linking.
        for issue in coherence_assessment.issues_found or []:
            issue_text = str(issue)
            issue_lower = issue_text.lower()
            if "column" in issue_lower or "schema" in issue_lower:
                suggestions.append(RefinementSuggestion(
                    target_agent="SchemaLinkingAgent",
                    feedback=f"Schema issue detected: {issue_text}",
                    priority=1
                ))
            else:
                suggestions.append(RefinementSuggestion(
                    target_agent="SQLGenerationAgent",
                    feedback=f"Query issue: {issue_text}",
                    priority=2
                ))

        if not suggestions:
            # An unsatisfactory verdict with a moderate score and no listed
            # issues would otherwise reach the orchestrator with nothing to act on.
            explanation = coherence_assessment.explanation or "no explanation provided"
            suggestions.append(RefinementSuggestion(
                target_agent="SQLGenerationAgent",
                feedback=f"The results were judged unsatisfactory: {explanation}. Please revise the query.",
                priority=2
            ))

        return suggestions

    @staticmethod
    def _column_entries(columns: Any) -> List[Tuple[str, str]]:
        """Normalise a table's column description into (name, type) pairs."""
        # Unwrap a {"columns": {...}} / {"columns": [...]} table wrapper, but
        # not a real column named "columns" (whose info carries a "type").
        if isinstance(columns, dict):
            inner = columns.get("columns")
            if isinstance(inner, list) or (isinstance(inner, dict) and "type" not in inner):
                columns = inner

        entries = []
        if isinstance(columns, dict):
            for col_name, col_info in columns.items():
                if isinstance(col_info, dict):
                    col_type = col_info.get("type", "Unknown")
                else:
                    col_type = col_info or "Unknown"  # e.g. {"id": "INT"}
                entries.append((str(col_name), str(col_type)))
        elif isinstance(columns, list):
            for col_info in columns:
                if isinstance(col_info, dict):
                    entries.append((str(col_info.get("name", "Unknown")), str(col_info.get("type", "Unknown"))))
                elif isinstance(col_info, str):
                    entries.append((col_info, "Unknown"))
        return entries

    def _format_schema_for_prompt(self, schema_data: Dict, max_columns: int = 5) -> str:
        """Format database schema for inclusion in prompt.

        Accepted shapes of ``schema_data["tables"]``, or of a ``.tables``
        attribute on a schema object:

        * ``{table: {col: {"type": ...}}}`` or ``{table: {col: "TYPE"}}``
        * ``{table: {"columns": {...} or [...]}}``
        * ``{table: [{"name": ..., "type": ...}, ...]}``
        * ``[{"name": table, "columns": [...]}, ...]``

        At most ``max_columns`` columns are listed per table; an extra line
        counts the omitted ones.
        """
        if not schema_data:
            return "Schema information not available"

        if isinstance(schema_data, dict):
            tables = schema_data.get("tables")
        else:
            tables = getattr(schema_data, "tables", None)
        if not tables:
            return "Schema information not available"

        if isinstance(tables, dict):
            table_items = list(tables.items())
        else:
            table_items = [
                (table.get("name", "Unknown"), table.get("columns", []))
                if isinstance(table, dict) else (str(table), [])
                for table in tables
            ]

        formatted = []
        for table_name, columns in table_items:
            formatted.append(f"Table: {table_name}")
            entries = self._column_entries(columns)
            for col_name, col_type in entries[:max_columns]:
                formatted.append(f"  - {col_name}: {col_type}")
            if len(entries) > max_columns:
                formatted.append(f"  - ... ({len(entries) - max_columns} more columns)")

        return "\n".join(formatted)

In [None]:
# Test with successful execution using new interface
async def test_successful_execution():
    # Create mock query part
    query_part = QueryPart(
        part_id="test_part_1",
        processed_natural_language="Show all customer names from the customers table",
        extracted_entities_and_intent=ExtractedEntitiesAndIntent(
            metrics=["list"],
            dimensions=["customer names"],
            filters=[],
            primary_goal="retrieve",
            confidence=0.9
        ),
        dependencies=[],
        complexity_level="simple"
    )
    
    # Create mock SQL generation output
    sql_generation_output = SQLGenerationOutput(
        sql_query="SELECT customer_name FROM customers",
        generation_confidence=0.95,
        brief_explanation_of_sql_logic="Retrieve all customer names",
        validation_status_self_assessed="Presumed_Valid"
    )
    
    # Create mock database schema
    database_schema = {
        "tables": {
            "customers": {
                "columns": {
                    "customer_id": {"type": "INT"},
                    "customer_name": {"type": "VARCHAR(255)"},
                    "email": {"type": "VARCHAR(255)"}
                }
            }
        }
    }
    
    # Create input
    input_data = SQLExecutionAndCoherenceInput(
        sql_to_execute="SELECT customer_name FROM customers",
        database_id="test_db",
        original_query_part=query_part,
        generated_sql_context=sql_generation_output,
        database_schema=database_schema
    )
    
    # Initialize agent and execute
    executor_agent = SQLExecutionAndCoherenceAgent()
    
    try:
        result = await executor_agent.execute_and_check_coherence(input_data)
        
        print("Execution Status:", result.execution_result.status)
        if result.execution_result.query_results:
            print("Row Count:", result.execution_result.query_results.row_count)
            print("Columns:", result.execution_result.query_results.columns)
        
        print("\nCoherence Assessment:")
        print("Is Satisfactory:", result.coherence_assessment.is_result_satisfactory)
        print("Coherence Score:", result.coherence_assessment.coherence_score)
        print("Explanation:", result.coherence_assessment.explanation)
        
        print("\nFinal Result Satisfactory:", result.is_final_result_satisfactory_for_part)
        
        if result.refinement_proposals_for_orchestrator:
            print("\nRefinement Suggestions:")
            for suggestion in result.refinement_proposals_for_orchestrator:
                print(f"- Target: {suggestion.target_agent}")
                print(f"  Feedback: {suggestion.feedback}")
                print(f"  Priority: {suggestion.priority}")
        
        return result
    except SQLExecutionError as e:
        print(f"Execution failed: {e}")
        return None

# Run the test
success_result = await test_successful_execution()

In [None]:
# Test with syntax error

error_input = ExecutorCoherenceInput(
    sql_to_execute="SELCT customer_name FROM customers",  # Intentional typo
    database_id="test_db",
    original_query_part_context=sample_query_context,
    database_schema_description=sample_schema
)

error_result = await execute_and_assess_coherence(error_input)
print("Execution Status:", error_result.execution_status)
print("Error Details:", error_result.execution_error_details)
if error_result.refinement_suggestions_for_orchestrator:
    for suggestion in error_result.refinement_suggestions_for_orchestrator:
        print(f"Suggestion: Target {suggestion.suggested_target_agent_for_retry}")
        print(f"Feedback: {suggestion.feedback_for_retry}")

In [None]:
# Test with semantic error

semantic_error_input = ExecutorCoherenceInput(
    sql_to_execute="SELECT customer_name FROM nonexistent_table",
    database_id="test_db",
    original_query_part_context=sample_query_context,
    database_schema_description=sample_schema
)

semantic_result = await execute_and_assess_coherence(semantic_error_input)
print("Execution Status:", semantic_result.execution_status)
print("Error Details:", semantic_result.execution_error_details)
if semantic_result.refinement_suggestions_for_orchestrator:
    for suggestion in semantic_result.refinement_suggestions_for_orchestrator:
        print(f"Suggestion: Target {suggestion.suggested_target_agent_for_retry}")
        print(f"Feedback: {suggestion.feedback_for_retry}")

In [None]:
# Test with potential coherence issue

coherence_test_context = QueryPartContext(
    processed_natural_language="Show total sales for each customer",
    extracted_entities_and_intent={
        "intent": "aggregate",
        "entities": {
            "columns": ["sales", "customer"],
            "aggregation": "SUM",
            "grouping": "customer"
        }
    }
)

coherence_test_input = ExecutorCoherenceInput(
    sql_to_execute="SELECT customer_name FROM customers",  # Wrong query for the intent
    database_id="test_db",
    original_query_part_context=coherence_test_context,
    database_schema_description=sample_schema
)

coherence_result = await execute_and_assess_coherence(coherence_test_input)
print("Execution Status:", coherence_result.execution_status)
if coherence_result.coherence_assessment:
    print("Coherence Score:", coherence_result.coherence_assessment.coherence_score)
    print("Is Satisfactory:", coherence_result.coherence_assessment.is_result_satisfactory)
    print("Explanation:", coherence_result.coherence_assessment.explanation)
if coherence_result.refinement_suggestions_for_orchestrator:
    for suggestion in coherence_result.refinement_suggestions_for_orchestrator:
        print(f"Suggestion: Target {suggestion.suggested_target_agent_for_retry}")
        print(f"Feedback: {suggestion.feedback_for_retry}")

## Summary

The SQL Execution and Coherence Agent is designed to:

1. Execute SQL queries against the specified database
2. Handle various error types (syntax, semantic, timeout)
3. Collect performance metrics
4. Analyze result coherence with the original query intent
5. Provide refinement suggestions when results are unsatisfactory

Key features:
- Safe SQL execution with timeout handling
- Error classification and detailed error reporting
- Coherence analysis using LLM assessment
- Structured refinement suggestions for retry attempts
- Performance metrics collection

The agent can handle:
- Successful query execution with coherence analysis
- Syntax errors with suggestions for SQL generation fixes
- Semantic errors with suggestions for schema linking fixes
- Coherence issues with targeted refinement suggestions