# SQL Evaluation Agent Test

This notebook implements and tests the SQL Evaluation Agent that validates and optimizes SQL queries.

In [None]:
from dotenv import load_dotenv
import json
import re
from typing import Dict, Any, List, Optional
from dataclasses import dataclass

# Load environment variables from a .env file, if one is present. Presumably
# this supplies the OpenAI API key that OpenAIChatCompletionClient reads from
# the environment -- confirm against the project's .env template.
load_dotenv()

In [None]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from schema_manager import SchemaManager
from sql_executor import SQLExecutor

## Define SQL Evaluation System Prompt

In [None]:
# System prompt shared by the evaluator. It is the `system_message` of the
# `sql_evaluator` AssistantAgent, and also the system message for the one-shot
# LLM analyses that run inside each tool. Editing it therefore changes both
# which tools the agent chooses and what every tool's LLM analysis says.
# NOTE(review): "Suggest better indexing strategies" and "Suggest appropriate
# use of indexes" overlap; consider merging them.
SQL_EVALUATOR_SYSTEM_PROMPT = """You are an expert SQL Evaluation Agent specializing in validating and optimizing SQL queries.

Your capabilities:
1. Validate SQL syntax without execution
2. Check SQL queries for common errors and issues
3. Optimize SQL queries for better performance
4. Provide recommendations for query improvements
5. Analyze query complexity and efficiency

Guidelines for validation:
- Check for syntax errors
- Verify table and column references
- Ensure proper JOIN conditions
- Validate aggregate function usage
- Check for SQL injection vulnerabilities
- Verify data type compatibility

Guidelines for optimization:
- Identify inefficient query patterns
- Suggest better indexing strategies
- Recommend query rewrites for performance
- Identify unnecessary operations
- Suggest appropriate use of indexes
- Recommend optimal JOIN orders

Always provide actionable recommendations with explanations."""

## Implement the SQL Evaluation Agent

In [None]:
class SQLEvaluationAgent:
    """Agent that validates and optimizes SQL queries.

    Wraps an autogen ``AssistantAgent`` that exposes five tools: syntax
    validation, complexity analysis, optimization, index suggestions and
    performance evaluation. Each tool combines cheap local regex heuristics
    with a one-shot LLM analysis, and returns its findings as a JSON string.

    The heuristics strip single-quoted string literals and match whole-word,
    case-insensitive keywords. They do not parse SQL, so treat their output
    as hints rather than verdicts.
    """

    # --- Regex heuristics, compiled once and shared by all instances --------

    # Aggregate *calls* such as ``COUNT(``. Requiring the parenthesis stops
    # identifiers like ``County`` or ``Summary`` from counting as aggregates.
    _AGGREGATE_RE = re.compile(r'\b(COUNT|SUM|AVG|MAX|MIN)\s*\(', re.IGNORECASE)
    # A parenthesised sub-select, in any case and with any inner whitespace.
    _SUBQUERY_RE = re.compile(r'\(\s*SELECT\b', re.IGNORECASE)
    # ``SELECT *`` / ``SELECT DISTINCT *`` with arbitrary whitespace.
    _SELECT_STAR_RE = re.compile(r'\bSELECT\s+(?:DISTINCT\s+)?\*', re.IGNORECASE)
    # A single-quoted SQL string literal. ``''`` is an escaped quote inside it.
    _STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
    # A query that stops right after a keyword, comma or comparison operator
    # (e.g. ``... WHERE`` or ``... AND``) is missing its final operand.
    _DANGLING_CLAUSE_RE = re.compile(
        r'(?:\b(?:SELECT|FROM|WHERE|AND|OR|NOT|ON|JOIN|HAVING|BY)|,|[=<>])\s*;?\s*$',
        re.IGNORECASE,
    )
    # Pairs of (label reported to the user, pattern searched in the raw SQL).
    # Identifier-like patterns are word-bounded, so ``resp_code`` does not
    # trigger ``sp_`` and ``executive`` does not trigger ``exec``.
    _INJECTION_PATTERNS = (
        ('--', re.compile(r'--')),
        ('/*', re.compile(r'/\*')),
        ('*/', re.compile(r'\*/')),
        ('xp_', re.compile(r'\bxp_', re.IGNORECASE)),
        ('sp_', re.compile(r'\bsp_', re.IGNORECASE)),
        ('exec', re.compile(r'\bexec\b', re.IGNORECASE)),
        ('execute', re.compile(r'\bexecute\b', re.IGNORECASE)),
    )
    # A table reference right after FROM/JOIN. It may be qualified or quoted:
    # ``db.table``, ``"table"``, ``[table]``. A derived table ``FROM (...``
    # never matches, because ``(`` is not in the character class.
    _TABLE_REF_RE = re.compile(r'\b(?:FROM|JOIN)\s+([\w.`"\[\]]+)', re.IGNORECASE)
    # Names introduced by a CTE: ``WITH name AS (`` or ``, name(cols) AS (``.
    _CTE_NAME_RE = re.compile(
        r'(?:\bWITH(?:\s+RECURSIVE)?\s+|,\s*)(\w+)\s*(?:\([^)]*\))?\s+AS\s*\(',
        re.IGNORECASE,
    )
    # Clause bodies used for index suggestions. Each body ends at the next
    # clause keyword (word-bounded) or at the end of the query.
    _WHERE_CLAUSE_RE = re.compile(
        r'\bWHERE\s+(.+?)(?=\bGROUP\s+BY\b|\bORDER\s+BY\b|\bHAVING\b|\bLIMIT\b|$)',
        re.IGNORECASE | re.DOTALL,
    )
    _JOIN_CONDITION_RE = re.compile(
        r'\bON\s+(.+?)(?=\bWHERE\b|\bGROUP\s+BY\b|\bORDER\s+BY\b|\bJOIN\b|$)',
        re.IGNORECASE | re.DOTALL,
    )
    _ORDER_CLAUSE_RE = re.compile(
        r'\bORDER\s+BY\s+(.+?)(?=\bLIMIT\b|\bOFFSET\b|;|$)',
        re.IGNORECASE | re.DOTALL,
    )
    # The left operand of a comparison or predicate (=, <, >, !=, LIKE, IN,
    # BETWEEN, IS) inside a WHERE clause.
    _PREDICATE_COLUMN_RE = re.compile(
        r'([\w.]+)(?=\s*(?:[=<>]|!=)|\s+(?:NOT\s+)?(?:LIKE|IN|BETWEEN|IS)\b)',
        re.IGNORECASE,
    )
    # ``a.x = b.y`` inside an ON condition.
    _EQUI_JOIN_RE = re.compile(r'([\w.]+)\s*=\s*([\w.]+)')
    # Tokens the predicate regex can capture that are never column names.
    _NON_COLUMN_WORDS = frozenset({'AND', 'OR', 'NOT', 'NULL', 'IS', 'IN', 'LIKE', 'BETWEEN'})
    # Weight of each complexity factor in the overall complexity score.
    _COMPLEXITY_WEIGHTS = {
        "joins": 2,
        "subqueries": 3,
        "aggregations": 1,
        "group_by": 2,
        "having": 2,
        "distinct": 1,
        "union": 3,
        "case_when": 2,
        "cte": 2,
    }

    # String annotations keep this class importable without the project modules.
    def __init__(self, schema_manager: "SchemaManager", model: str = "gpt-4o"):
        """Create the evaluator.

        Args:
            schema_manager: Provides per-database schema JSON. It is read
                through its ``db2dbjsons`` mapping when table references are
                validated.
            model: Model name passed to ``OpenAIChatCompletionClient``.
        """
        self.schema_manager = schema_manager
        self.model_client = OpenAIChatCompletionClient(model=model)
        self.agent = self._create_agent()

    # --- Local, deterministic heuristics ------------------------------------

    @classmethod
    def _strip_string_literals(cls, sql: str) -> str:
        """Replace every single-quoted literal with ``''``.

        Keyword and parenthesis heuristics then ignore text inside literals,
        such as ``'Join Street'`` or ``'f(x'``.
        """
        return cls._STRING_LITERAL_RE.sub("''", sql)

    @staticmethod
    def _count_keyword(sql: str, keyword: str) -> int:
        """Count whole-word, case-insensitive occurrences of ``keyword``.

        Multi-word keywords such as ``GROUP BY`` match with any whitespace,
        including newlines, between the words.
        """
        pattern = r'\b' + r'\s+'.join(re.escape(word) for word in keyword.split()) + r'\b'
        return len(re.findall(pattern, sql, re.IGNORECASE))

    @classmethod
    def _is_column_token(cls, token: str) -> bool:
        """Return False for SQL keywords and numeric literals captured by the regexes."""
        return token.upper() not in cls._NON_COLUMN_WORDS and not re.fullmatch(r'[\d.]+', token)

    @classmethod
    def _syntax_checks(cls, sql: str) -> tuple[list[str], list[str]]:
        """Run structural checks on ``sql`` that do not need a schema.

        Returns:
            ``(issues, warnings)``. Issues make the query invalid. Warnings
            are advisory; currently they cover injection-style tokens.
        """
        issues: list[str] = []
        warnings: list[str] = []
        code = cls._strip_string_literals(sql)

        if not cls._count_keyword(code, 'SELECT'):
            issues.append("Missing SELECT clause")
        if not cls._count_keyword(code, 'FROM'):
            issues.append("Missing FROM clause")
        # Parentheses are counted outside literals, so ``'f(x'`` is not flagged.
        if code.count('(') != code.count(')'):
            issues.append("Unmatched parentheses")
        # An escaped quote ('') adds two to the count, so an odd total still
        # means a stray quote.
        if sql.count("'") % 2 != 0:
            issues.append("Unmatched quotes")
        if cls._DANGLING_CLAUSE_RE.search(code):
            issues.append("Query ends with an incomplete clause")

        # Injection tokens are searched in the raw SQL, literals included,
        # because payloads usually hide inside literals.
        for label, pattern in cls._INJECTION_PATTERNS:
            if pattern.search(sql):
                warnings.append(f"Potential SQL injection pattern detected: {label}")
        return issues, warnings

    def _schema_checks(self, sql: str, database_id: str) -> tuple[list[str], list[str]]:
        """Check FROM/JOIN table references against the schema of ``database_id``.

        Returns:
            ``(issues, warnings)``. An unknown database, or any error while
            reading the schema, produces a warning instead of an issue. This
            prevents a missing schema from being reported as "every table is
            missing". Column references are not validated, because that needs
            a real SQL parser (aliases, expressions, functions).
        """
        issues: list[str] = []
        warnings: list[str] = []
        try:
            db_info = self.schema_manager.db2dbjsons.get(database_id, {})
            if not db_info:
                warnings.append(
                    f"Database '{database_id}' not found in schema; "
                    "table references were not checked"
                )
                return issues, warnings

            available_tables = {t.lower() for t in db_info.get('table_names_original', [])}
            code = self._strip_string_literals(sql)
            # CTE names are valid FROM targets even though they are not tables.
            cte_names = {name.lower() for name in self._CTE_NAME_RE.findall(code)}

            for table in self._TABLE_REF_RE.findall(code):
                # Drop identifier quoting and keep the last part of db.table.
                table_name = re.sub(r'[`"\[\]]', '', table).split('.')[-1].lower()
                if table_name and table_name not in available_tables and table_name not in cte_names:
                    issues.append(f"Table '{table}' not found in database schema")
        except Exception as e:
            # Best effort: a schema lookup failure must not abort validation.
            warnings.append(f"Unable to validate against schema: {str(e)}")
        return issues, warnings

    @classmethod
    def _complexity_factors(cls, sql: str) -> dict[str, int]:
        """Count the structural features that contribute to query complexity.

        ``aggregations`` counts distinct aggregate function *types* used.
        ``group_by``, ``having`` and ``distinct`` are 0/1 flags. The other
        entries are occurrence counts. ``case_when`` counts CASE expressions.
        """
        code = cls._strip_string_literals(sql)

        def count(keyword: str) -> int:
            return cls._count_keyword(code, keyword)

        return {
            "joins": count('JOIN'),
            "subqueries": len(cls._SUBQUERY_RE.findall(code)),
            "aggregations": len({name.upper() for name in cls._AGGREGATE_RE.findall(code)}),
            "group_by": 1 if count('GROUP BY') else 0,
            "having": 1 if count('HAVING') else 0,
            "distinct": 1 if count('DISTINCT') else 0,
            "union": count('UNION'),
            "case_when": count('CASE'),
            "cte": count('WITH'),
        }

    @classmethod
    def _score_complexity(cls, factors: dict[str, int]) -> tuple[int, str]:
        """Weight ``factors`` into a score and bucket it into a level name."""
        score = sum(weight * factors[name] for name, weight in cls._COMPLEXITY_WEIGHTS.items())
        if score <= 2:
            level = "Simple"
        elif score <= 6:
            level = "Moderate"
        elif score <= 12:
            level = "Complex"
        else:
            level = "Very Complex"
        return score, level

    @classmethod
    def _performance_factors(cls, sql: str) -> dict[str, bool]:
        """Flag query features that usually affect execution cost."""
        code = cls._strip_string_literals(sql)

        def has(keyword: str) -> bool:
            return cls._count_keyword(code, keyword) > 0

        return {
            "has_select_star": bool(cls._SELECT_STAR_RE.search(code)),
            "has_where_clause": has('WHERE'),
            "has_joins": has('JOIN'),
            "has_subqueries": bool(cls._SUBQUERY_RE.search(code)),
            "has_grouping": has('GROUP BY'),
            "has_ordering": has('ORDER BY'),
            "has_distinct": has('DISTINCT'),
            "has_functions": bool(cls._AGGREGATE_RE.search(code)),
        }

    @staticmethod
    def _score_performance(factors: dict[str, bool]) -> int:
        """Turn performance flags into a simplified 0-100 score (higher is better)."""
        score = 100
        if factors['has_select_star']:
            score -= 10
        # A join with no filter at all is likely to scan both tables fully.
        if factors['has_joins'] and not factors['has_where_clause']:
            score -= 20
        if factors['has_subqueries']:
            score -= 15
        if factors['has_distinct']:
            score -= 5
        return score

    @classmethod
    def _quick_optimizations(cls, sql: str) -> list[str]:
        """Return cheap, pattern-based optimization hints for ``sql``."""
        code = cls._strip_string_literals(sql)

        def has(keyword: str) -> bool:
            return cls._count_keyword(code, keyword) > 0

        suggestions: list[str] = []
        if cls._SELECT_STAR_RE.search(code):
            suggestions.append("Replace SELECT * with specific columns to reduce data transfer")
        if has('JOIN') and has('ON') and not has('WHERE'):
            suggestions.append("Consider adding WHERE clause to filter results and improve performance")
        if has('DISTINCT') and has('GROUP BY'):
            suggestions.append("DISTINCT might be redundant with GROUP BY")
        if has('WHERE') and has('OR'):
            suggestions.append("Consider using IN clause or UNION instead of OR for better index usage")
        return suggestions

    @classmethod
    def _extract_index_columns(cls, sql: str) -> tuple[list[str], list[str], list[str]]:
        """Collect the columns used in WHERE predicates, JOIN ... ON equalities and ORDER BY.

        Only the first WHERE and ORDER BY clauses are inspected. Lists are
        de-duplicated in first-seen order, and ASC/DESC is dropped from
        ORDER BY items.
        """
        code = cls._strip_string_literals(sql)

        where_columns: list[str] = []
        where_match = cls._WHERE_CLAUSE_RE.search(code)
        if where_match:
            where_columns = [
                col for col in cls._PREDICATE_COLUMN_RE.findall(where_match.group(1))
                if cls._is_column_token(col)
            ]

        join_columns = [
            col
            for clause in cls._JOIN_CONDITION_RE.findall(code)
            for pair in cls._EQUI_JOIN_RE.findall(clause)
            for col in pair
            if cls._is_column_token(col)
        ]

        order_columns: list[str] = []
        order_match = cls._ORDER_CLAUSE_RE.search(code)
        if order_match:
            for item in order_match.group(1).split(','):
                column = re.sub(r'\s+(?:ASC|DESC)\s*$', '', item.strip(), flags=re.IGNORECASE)
                if column:
                    order_columns.append(column)

        return (
            list(dict.fromkeys(where_columns)),
            list(dict.fromkeys(join_columns)),
            order_columns,
        )

    # --- LLM access ---------------------------------------------------------

    async def _ask_llm(self, prompt: str) -> str:
        """Send ``prompt`` with the evaluator system prompt and return the reply text.

        autogen model clients take ``LLMMessage`` objects and return a
        ``CreateResult``, not an OpenAI ``choices`` payload. The call is
        therefore routed through a throw-away, tool-less ``AssistantAgent``,
        which accepts a plain-text task. Creating a new agent for each call
        also keeps every analysis free of earlier conversation history.
        """
        helper = AssistantAgent(
            name="sql_evaluator_llm",
            model_client=self.model_client,
            system_message=SQL_EVALUATOR_SYSTEM_PROMPT,
        )
        result = await helper.run(task=prompt)
        if not result.messages:
            return ""
        # The last message is the helper agent's reply to the task.
        content = getattr(result.messages[-1], "content", "")
        return content if isinstance(content, str) else str(content)

    def _create_agent(self) -> "AssistantAgent":
        """Create the SQL evaluation agent with tools.

        The tools are closures so that they can reach ``self``. autogen uses
        each function's name and docstring as the tool name and description
        shown to the model.
        """

        async def validate_sql_syntax(sql: str, database_id: str) -> str:
            """Validate SQL syntax without executing."""
            issues, warnings = self._syntax_checks(sql)
            schema_issues, schema_warnings = self._schema_checks(sql, database_id)
            issues.extend(schema_issues)
            warnings.extend(schema_warnings)

            # Advanced validation using LLM
            validation_prompt = f"""
            Validate this SQL query for syntax and best practices:
            
            SQL:
            {sql}
            
            Database: {database_id}
            
            Check for:
            1. Syntax errors
            2. Logic errors
            3. Performance issues
            4. Best practice violations
            
            Provide a detailed analysis.
            """
            llm_analysis = await self._ask_llm(validation_prompt)

            return json.dumps({
                "sql": sql,
                "is_valid": not issues,
                "issues": issues,
                "warnings": warnings,
                "llm_analysis": llm_analysis
            }, indent=2)

        async def analyze_query_complexity(sql: str, database_id: str) -> str:
            """Analyze the complexity of a SQL query."""
            complexity_factors = self._complexity_factors(sql)
            complexity_score, complexity_level = self._score_complexity(complexity_factors)

            # Get detailed analysis from LLM
            analysis_prompt = f"""
            Analyze the complexity of this SQL query:
            
            SQL:
            {sql}
            
            Database: {database_id}
            
            Provide:
            1. Complexity assessment
            2. Potential performance bottlenecks
            3. Optimization opportunities
            4. Readability concerns
            """
            detailed_analysis = await self._ask_llm(analysis_prompt)

            return json.dumps({
                "sql": sql,
                "complexity_level": complexity_level,
                "complexity_score": complexity_score,
                "complexity_factors": complexity_factors,
                "detailed_analysis": detailed_analysis
            }, indent=2)

        async def optimize_sql(sql: str, database_id: str, optimization_goal: str = "performance") -> str:
            """Optimize SQL query for better performance or readability."""
            quick_optimizations = self._quick_optimizations(sql)

            # Use LLM for sophisticated optimization
            optimization_prompt = f"""
            Optimize this SQL query for {optimization_goal}:
            
            Original SQL:
            {sql}
            
            Database: {database_id}
            
            Provide:
            1. Optimized SQL query
            2. Explanation of changes
            3. Expected performance improvements
            4. Any trade-offs to consider
            
            Focus on:
            - Query performance
            - Index usage
            - Join optimization
            - Reducing data scans
            - Simplifying complex logic
            """
            optimization_response = await self._ask_llm(optimization_prompt)

            return json.dumps({
                "original_sql": sql,
                "optimized_sql": self._extract_sql_from_response(optimization_response),
                "quick_optimizations": quick_optimizations,
                "optimization_goal": optimization_goal,
                "detailed_explanation": optimization_response
            }, indent=2)

        async def suggest_indexes(sql: str, database_id: str) -> str:
            """Suggest indexes to improve query performance."""
            where_columns, join_columns, order_columns = self._extract_index_columns(sql)

            # Use LLM for comprehensive index suggestions
            index_prompt = f"""
            Suggest indexes for this SQL query:
            
            SQL:
            {sql}
            
            Database: {database_id}
            
            Columns used in:
            - WHERE: {where_columns}
            - JOIN: {join_columns}
            - ORDER BY: {order_columns}
            
            Provide:
            1. Recommended indexes with CREATE INDEX statements
            2. Explanation for each index
            3. Expected performance improvements
            4. Any existing indexes that might already help
            """
            index_suggestions = await self._ask_llm(index_prompt)

            return json.dumps({
                "sql": sql,
                "where_columns": where_columns,
                "join_columns": join_columns,
                "order_columns": order_columns,
                "index_suggestions": index_suggestions
            }, indent=2)

        async def evaluate_sql_performance(sql: str, database_id: str) -> str:
            """Evaluate potential performance of SQL query."""
            performance_factors = self._performance_factors(sql)
            performance_score = self._score_performance(performance_factors)

            # Detailed performance analysis from LLM
            performance_prompt = f"""
            Analyze the performance characteristics of this SQL query:
            
            SQL:
            {sql}
            
            Database: {database_id}
            
            Evaluate:
            1. Potential bottlenecks
            2. Resource usage (CPU, memory, I/O)
            3. Scalability concerns
            4. Execution plan considerations
            5. Data volume impact
            
            Provide specific recommendations for improvement.
            """
            performance_analysis = await self._ask_llm(performance_prompt)

            return json.dumps({
                "sql": sql,
                "performance_score": performance_score,
                "performance_factors": performance_factors,
                "detailed_analysis": performance_analysis
            }, indent=2)

        # Create the agent
        return AssistantAgent(
            name="sql_evaluator",
            model_client=self.model_client,
            tools=[
                validate_sql_syntax,
                analyze_query_complexity,
                optimize_sql,
                suggest_indexes,
                evaluate_sql_performance
            ],
            system_message=SQL_EVALUATOR_SYSTEM_PROMPT,
            reflect_on_tool_use=True,
            model_client_stream=True,
        )

    def _extract_sql_from_response(self, response: str) -> str:
        """Extract SQL from an LLM response.

        Strategies, tried in order:
          1. the first fenced code block tagged ``sql`` (any case);
          2. the first fenced code block that contains SELECT;
          3. the text after an 'Optimized SQL:' / 'SQL:' / 'Query:' marker,
             cut at the next blank line, '---' or '###';
          4. lines from the first one containing SELECT up to the first ';'.

        Returns:
            The extracted SQL, or '' if the response is empty or nothing is found.
        """
        if not response:
            return ""

        # Fences may carry a language tag and use \r\n line endings.
        fenced = re.findall(r'```[ \t]*(\w*)[^\n]*\n(.*?)```', response, re.DOTALL)
        for lang, body in fenced:
            if lang.lower() == 'sql':
                return body.strip()
        for _lang, body in fenced:
            if re.search(r'\bSELECT\b', body, re.IGNORECASE):
                return body.strip()

        # Look for SQL after certain markers
        for marker in ['Optimized SQL:', 'SQL:', 'Query:']:
            if marker in response:
                sql_part = response.split(marker)[-1].strip()
                # Extract until the next marker or end
                for end_marker in ['\n\n', '---', '###']:
                    if end_marker in sql_part:
                        sql_part = sql_part.split(end_marker)[0]
                return sql_part.strip()

        # If no specific format found, try to extract SELECT statement
        sql_lines = []
        in_sql = False
        for line in response.split('\n'):
            if 'SELECT' in line.upper():
                in_sql = True
            if in_sql:
                sql_lines.append(line)
                if ';' in line:
                    break

        return '\n'.join(sql_lines).strip()

    async def query(self, task: str) -> None:
        """Run a query against the SQL evaluator agent, streaming output to the console."""
        await Console(self.agent.run_stream(task=task))

    async def close(self) -> None:
        """Close the model client connection."""
        await self.model_client.close()

## Initialize Components and Test

In [None]:
# Initialize components
# BIRD dev-split locations. The paths are relative to the current working
# directory, which presumably is this notebook's folder -- TODO confirm.
data_path = "../data/bird/dev_databases"
tables_json_path = "../data/bird/dev_tables.json"
dataset_name = "bird"

# Create schema manager
# lazy=True presumably defers loading per-database schemas until first use;
# verify in schema_manager.
schema_manager = SchemaManager(
    data_path=data_path,
    tables_json_path=tables_json_path,
    dataset_name=dataset_name,
    lazy=True
)

# Create SQL evaluator agent
# Uses the class's default model. `schema_manager` and `sql_evaluator` are
# globals that the later cells rely on.
sql_evaluator = SQLEvaluationAgent(schema_manager)

### Test SQL Validation

In [None]:
# Test validating a correct SQL query
# The agent picks a tool from the task text; this prompt is meant to route to
# validate_sql_syntax. Top-level `await` works only inside Jupyter/IPython.
valid_sql = "SELECT School, County FROM schools WHERE County = 'Alameda'"

await sql_evaluator.query(f"""
Validate this SQL syntax for database 'california_schools':
{valid_sql}
""")

In [None]:
# Test validating SQL with errors
# Intentionally broken: `invalid_table` is presumably absent from the
# california_schools schema, and the trailing WHERE has no condition.
invalid_sql = "SELECT School FROM invalid_table WHERE "

await sql_evaluator.query(f"""
Validate this SQL syntax for database 'california_schools':
{invalid_sql}
""")

### Test Query Complexity Analysis

In [None]:
# Test analyzing a complex query
# Exercises a JOIN, an aggregate, GROUP BY, HAVING and a scalar subquery in
# the HAVING condition.
complex_sql = """
SELECT s.School, AVG(sat.AvgScrMath) as avg_math_score
FROM schools s
JOIN satscores sat ON s.CDSCode = sat.cds
WHERE s.County = 'Alameda'
  AND sat.AvgScrMath IS NOT NULL
GROUP BY s.School
HAVING AVG(sat.AvgScrMath) > (
  SELECT AVG(AvgScrMath) 
  FROM satscores 
  WHERE AvgScrMath IS NOT NULL
)
ORDER BY avg_math_score DESC
"""

await sql_evaluator.query(f"""
Analyze the complexity of this SQL query for database 'california_schools':
{complex_sql}
""")

### Test SQL Optimization

In [None]:
# Test optimizing an inefficient query
# Deliberately inefficient: SELECT * over a join that has no WHERE filter.
inefficient_sql = """
SELECT *
FROM schools s
JOIN satscores sat ON s.CDSCode = sat.cds
ORDER BY sat.AvgScrMath DESC
"""

await sql_evaluator.query(f"""
Optimize this SQL query for database 'california_schools' for better performance:
{inefficient_sql}
""")

### Test Index Suggestions

In [None]:
# Test getting index suggestions
# Has WHERE, JOIN ... ON and ORDER BY columns for the index tool to extract.
query_needing_indexes = """
SELECT s.School, s.County, sat.AvgScrMath
FROM schools s
JOIN satscores sat ON s.CDSCode = sat.cds
WHERE s.County = 'Alameda'
  AND sat.AvgScrMath > 600
ORDER BY sat.AvgScrMath DESC
"""

await sql_evaluator.query(f"""
Suggest indexes for this SQL query on database 'california_schools':
{query_needing_indexes}
""")

### Test Performance Evaluation

In [None]:
# Test evaluating query performance
# Combines DISTINCT with GROUP BY, a LEFT JOIN and an IN-subquery, so several
# performance heuristics fire at once.
performance_test_sql = """
SELECT DISTINCT s.County, COUNT(*) as school_count, AVG(sat.AvgScrMath) as avg_score
FROM schools s
LEFT JOIN satscores sat ON s.CDSCode = sat.cds
WHERE s.County IN (SELECT County FROM schools GROUP BY County HAVING COUNT(*) > 50)
GROUP BY s.County
HAVING AVG(sat.AvgScrMath) > 500
ORDER BY avg_score DESC
"""

await sql_evaluator.query(f"""
Evaluate the performance characteristics of this SQL query for database 'california_schools':
{performance_test_sql}
""")

### Complete SQL Evaluation Workflow

In [None]:
# Test complete evaluation workflow
# Uses an implicit (comma) join plus SELECT *. The task asks the agent to
# chain all five tools in one run.
test_sql = """
SELECT *
FROM schools s, satscores sat
WHERE s.CDSCode = sat.cds
  AND s.County = 'Los Angeles'
"""

await sql_evaluator.query(f"""
Perform a complete evaluation of this SQL query for database 'california_schools':
{test_sql}

Please:
1. Validate the syntax
2. Analyze the complexity
3. Optimize for performance
4. Suggest indexes
5. Evaluate performance characteristics
""")

In [None]:
# Close connections
# Run this last: it closes the model client that sql_evaluator uses for all
# of its queries.
await sql_evaluator.close()