# Orchestrator Agent Test

This notebook implements and tests the Orchestrator Agent, which plans and monitors the text-to-SQL workflow.

In [None]:
import json
import re
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional

from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
# Presumably this supplies the API key used by the OpenAI model client
# created further down — confirm against the project's .env template.
load_dotenv()

In [None]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Import our agents
from schema_manager import SchemaManager
from schema_agent import SchemaAgent
from sql_executor_agent import SQLExecutorAgent
from sql_executor import SQLExecutor

## Define Orchestrator System Prompt and Templates

In [None]:
# System message passed to the orchestrator AssistantAgent instances in this
# notebook (as `system_message`).
# NOTE(review): the prompt names a "Query Decomposer" and a "SQL Generator",
# but this notebook only imports a schema agent and a SQL executor agent —
# confirm whether the other two are expected to exist elsewhere.
ORCHESTRATOR_SYSTEM_PROMPT = """You are an Orchestrator Agent that coordinates the text-to-SQL workflow.

Your responsibilities:
1. Analyze user queries to understand what information is needed
2. Break down complex queries into steps
3. Coordinate with other agents:
   - Schema Agent: for database structure information
   - Query Decomposer: for breaking complex queries
   - SQL Generator: for creating SQL queries
   - SQL Executor: for running and validating queries
4. Monitor progress and ensure all information is collected
5. Generate the final SQL query

Workflow:
1. Understand the user query
2. Identify required information pieces
3. For each piece:
   - Get schema information
   - Generate SQL
   - Execute and verify
4. Combine results for final SQL

Always maintain a clear plan and track progress."""

In [None]:
# Prompt template for the planning step. It is rendered with str.format(), so
# only {database_id}, {query} and {evidence} are substitution fields. The
# illustrative hints inside the XML skeleton are written with doubled braces
# ({{...}}) so that format() emits them as literal "{...}" text instead of
# raising KeyError on an unknown field name.
ORCHESTRATOR_PLAN_TEMPLATE = """Analyze the following query and create an execution plan:

Database: {database_id}
Query: {query}
Evidence: {evidence}

Create a plan that identifies:
1. What information is needed to answer the query
2. Which tables and columns are likely involved
3. What relationships exist between the information
4. The order in which to retrieve information

Format your response as:
<ExecutionPlan>
    <QueryAnalysis>
        <Intent>{{what the user wants to find}}</Intent>
        <Complexity>{{simple/complex}}</Complexity>
    </QueryAnalysis>
    <InformationRequired>
        <InfoPiece number="1">
            <Description>{{what information is needed}}</Description>
            <PotentialTables>{{tables that might contain this}}</PotentialTables>
            <Dependency>{{if depends on other info pieces}}</Dependency>
        </InfoPiece>
        ...
    </InformationRequired>
    <ExecutionOrder>{{sequence of steps}}</ExecutionOrder>
</ExecutionPlan>"""

## Define Workflow State Management

In [None]:
@dataclass
class WorkflowState:
    """Tracks the state of the text-to-SQL workflow for a single query.

    The container fields default to fresh, per-instance empty collections.
    Passing ``None`` explicitly for any of them is still accepted and is
    normalised to an empty container in ``__post_init__``.
    """
    query: str  # natural-language question being answered
    database_id: str  # identifier of the target database
    evidence: str  # extra hints supplied with the question
    plan: Optional[Dict] = None  # execution plan, once one has been created
    information_pieces: List[Dict] = field(default_factory=list)  # plan steps; each may carry a 'status' key
    schema_info: Dict = field(default_factory=dict)  # table name -> schema info
    intermediate_results: List[Dict] = field(default_factory=list)  # per-step SQL results
    final_sql: Optional[str] = None  # final query, once generated

    def __post_init__(self):
        # Backward compatibility: callers may still pass None explicitly for
        # the container fields, which used to be the declared default.
        if self.information_pieces is None:
            self.information_pieces = []
        if self.intermediate_results is None:
            self.intermediate_results = []
        if self.schema_info is None:
            self.schema_info = {}

    def add_information_piece(self, piece: Dict):
        """Add an information piece to track."""
        self.information_pieces.append(piece)

    def add_intermediate_result(self, result: Dict):
        """Add an intermediate SQL result."""
        self.intermediate_results.append(result)

    def update_schema_info(self, table: str, info: Dict):
        """Update schema information for a table."""
        self.schema_info[table] = info

    def is_complete(self) -> bool:
        """Return True when at least one piece exists and all are 'complete'."""
        if not self.information_pieces:
            # An empty plan is never considered complete.
            return False
        return all(piece.get('status') == 'complete' for piece in self.information_pieces)

## Implement the Orchestrator Agent

In [None]:
class OrchestratorAgent:
    """Orchestrator agent that coordinates the text-to-SQL workflow.

    Wraps an ``AssistantAgent`` equipped with tool functions for planning,
    schema lookup, SQL generation, execution, result combination and progress
    monitoring. In this class the tools are demo stubs that return canned
    data. Per-query progress is kept in ``workflow_states``, keyed by the
    query text.
    """

    # Superlative keywords that trigger the two-step "find the extreme value,
    # then fetch its details" sample plan.
    _EXTREME_KEYWORDS = ("youngest", "oldest", "highest", "lowest")

    def __init__(self, 
                 schema_agent: SchemaAgent,
                 sql_executor_agent: SQLExecutorAgent,
                 model: str = "gpt-4o"):
        """Store collaborators, create the model client and build the agent.

        Args:
            schema_agent: Agent used for schema questions.
            sql_executor_agent: Agent used for SQL execution.
            model: Model name handed to ``OpenAIChatCompletionClient``.
        """
        self.schema_agent = schema_agent
        self.sql_executor_agent = sql_executor_agent
        self.model_client = OpenAIChatCompletionClient(model=model)
        # Set up state before building the agent so the tool closures never
        # see an instance without `workflow_states`.
        self.workflow_states: Dict[str, WorkflowState] = {}
        self.agent = self._create_agent()
    
    def _create_agent(self) -> AssistantAgent:
        """Create the orchestrator agent with tools."""
        # NOTE: the one-line docstrings of the tool functions below are kept
        # short on purpose; the agent framework presumably surfaces them to
        # the model as tool descriptions — confirm before rewording them.
        
        async def create_execution_plan(query: str, database_id: str, evidence: str = "") -> str:
            """Create an execution plan for the query."""
            # A fresh state replaces any earlier run for the same query text.
            self.workflow_states[query] = WorkflowState(
                query=query, database_id=database_id, evidence=evidence
            )
            
            # Demo implementation: the plan comes from keyword rules, not from
            # the model, so no prompt is rendered here.
            return self._create_sample_plan(query, database_id, evidence)
        
        async def get_schema_information(database_id: str, table_names: List[str]) -> str:
            """Get schema information for specified tables."""
            # Placeholder: in practice this would delegate to the schema agent.
            results = {table: f"Schema info for {table}" for table in table_names}
            return json.dumps(results, indent=2)
        
        async def generate_sql_for_info(info_piece: Dict, schema_info: Dict) -> str:
            """Generate SQL for a specific information piece."""
            # Placeholder: in practice this would delegate to a SQL generation
            # agent. A missing or None description is treated as empty.
            description = (info_piece.get('description') or '').lower()
            
            if "average" in description:
                return "SELECT AVG(column) FROM table"
            if "count" in description:
                return "SELECT COUNT(*) FROM table"
            return "SELECT * FROM table WHERE condition"
        
        async def execute_and_verify_sql(sql: str, database_id: str) -> str:
            """Execute SQL and verify results."""
            # Placeholder: always reports success with canned sample data.
            return json.dumps({
                "success": True,
                "sql": sql,
                "row_count": 10,
                "sample_data": [["data1"], ["data2"]]
            })
        
        async def combine_results_for_final_sql(query: str, intermediate_results: List[Dict]) -> str:
            """Combine intermediate results to generate final SQL."""
            state = self.workflow_states.get(query)
            if not state:
                return "Error: No workflow state found"
            
            # Placeholder: a real implementation would build the statement
            # from the intermediate results.
            final_sql = "SELECT /* final SQL based on intermediate results */"
            state.final_sql = final_sql
            
            return json.dumps({
                "final_sql": final_sql,
                "based_on_results": len(intermediate_results)
            })
        
        async def monitor_workflow_progress(query: str) -> str:
            """Check the progress of the workflow."""
            state = self.workflow_states.get(query)
            if not state:
                return "No workflow found for this query"
            
            progress = {
                "query": query,
                "database": state.database_id,
                "total_pieces": len(state.information_pieces),
                "completed_pieces": sum(1 for p in state.information_pieces if p.get('status') == 'complete'),
                "has_final_sql": state.final_sql is not None
            }
            
            return json.dumps(progress, indent=2)
        
        # Create the agent
        return AssistantAgent(
            name="orchestrator",
            model_client=self.model_client,
            tools=[
                create_execution_plan,
                get_schema_information,
                generate_sql_for_info,
                execute_and_verify_sql,
                combine_results_for_final_sql,
                monitor_workflow_progress
            ],
            system_message=ORCHESTRATOR_SYSTEM_PROMPT,
            reflect_on_tool_use=True,
            model_client_stream=True,
        )
    
    def _create_sample_plan(self, query: str, database_id: str, evidence: str) -> str:
        """Create a sample execution plan for demonstration.

        The plan is chosen by keyword rules on the lower-cased query and also
        recorded on the workflow state already registered for ``query``.

        Args:
            query: Natural-language question; also the workflow-state key.
            database_id: Unused by the rule-based planner.
            evidence: Unused by the rule-based planner.

        Returns:
            The plan as an ``<ExecutionPlan>`` XML string. Interpolated text
            is XML-escaped so queries containing ``<``, ``>`` or ``&`` still
            yield well-formed XML.

        Raises:
            KeyError: If no workflow state has been registered for ``query``.
        """
        from xml.sax.saxutils import escape  # local import keeps the block self-contained

        query_lower = query.lower()
        wants_average = "average" in query_lower
        wants_extreme = any(word in query_lower for word in self._EXTREME_KEYWORDS)
        
        # Every keyword that can select a two-step plan below also marks the
        # query as complex, so "oldest"/"lowest" are no longer labelled simple.
        is_complex = wants_average or wants_extreme
        
        if wants_average and "over" in query_lower:
            # Pattern: comparison with average
            info_pieces = [
                {
                    "number": "1",
                    "description": "Calculate the average value for comparison",
                    "potential_tables": "To be determined by schema analysis",
                    "dependency": "None"
                },
                {
                    "number": "2",
                    "description": "Find items that exceed the average",
                    "potential_tables": "Same as piece 1",
                    "dependency": "Piece 1"
                },
            ]
        elif wants_extreme:
            # Pattern: locate a min/max, then look up its details
            info_pieces = [
                {
                    "number": "1",
                    "description": "Find the extreme value (min/max)",
                    "potential_tables": "To be determined",
                    "dependency": "None"
                },
                {
                    "number": "2",
                    "description": "Get detailed information for the extreme case",
                    "potential_tables": "Same as piece 1",
                    "dependency": "Piece 1"
                },
            ]
        else:
            # Simple query
            info_pieces = [
                {
                    "number": "1",
                    "description": "Direct query execution",
                    "potential_tables": "To be determined",
                    "dependency": "None"
                },
            ]
        
        # Format as XML
        pieces_xml = "".join(
            f"""
        <InfoPiece number="{piece['number']}">
            <Description>{escape(piece['description'])}</Description>
            <PotentialTables>{escape(piece['potential_tables'])}</PotentialTables>
            <Dependency>{escape(piece['dependency'])}</Dependency>
        </InfoPiece>"""
            for piece in info_pieces
        )
        
        plan_xml = f"""<ExecutionPlan>
    <QueryAnalysis>
        <Intent>Find {escape(query)}</Intent>
        <Complexity>{'complex' if is_complex else 'simple'}</Complexity>
    </QueryAnalysis>
    <InformationRequired>{pieces_xml}
    </InformationRequired>
    <ExecutionOrder>Sequential execution from piece 1 to {len(info_pieces)}</ExecutionOrder>
</ExecutionPlan>"""
        
        # Update workflow state (the unescaped piece dicts are stored as-is).
        state = self.workflow_states[query]
        state.plan = {"xml": plan_xml, "info_pieces": info_pieces}
        for piece in info_pieces:
            state.add_information_piece({**piece, "status": "pending"})
        
        return plan_xml
    
    async def query(self, task: str) -> None:
        """Run a query against the orchestrator agent."""
        await Console(self.agent.run_stream(task=task))
    
    async def close(self) -> None:
        """Close the model client connection."""
        await self.model_client.close()

## Initialize Components

In [None]:
# Initialize components
# Locations of the BIRD dev split, relative to the notebook's working directory.
dataset_name = "bird"
data_path = "../data/bird/dev_databases"
tables_json_path = "../data/bird/dev_tables.json"

# Schema access layer (constructed with lazy=True).
schema_manager = SchemaManager(data_path=data_path, tables_json_path=tables_json_path, dataset_name=dataset_name, lazy=True)

# SQL execution layer over the same database directory.
sql_executor = SQLExecutor(data_path=data_path, dataset_name=dataset_name)

# Agents wrapping the two layers above.
schema_agent = SchemaAgent(schema_manager)
sql_executor_agent = SQLExecutorAgent(sql_executor)

# Orchestrator that coordinates both agents.
orchestrator = OrchestratorAgent(schema_agent, sql_executor_agent)

## Test the Orchestrator Agent

In [None]:
# Test with a simple query: a single-sentence question with no evidence line.
# The task text is free-form; the agent decides which tools to call and with
# which arguments.
simple_task = """Process this query for database 'california_schools':
Query: Show all schools in Alameda county
"""

# Top-level `await` relies on the notebook's already-running event loop.
await orchestrator.query(simple_task)

In [None]:
# Test with a complex query: the question compares against an average and
# comes with an evidence line defining the domain terms it uses.
complex_task = """Process this query for database 'california_schools':
Query: List school names of charter schools with an SAT excellence rate over the average
Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr
"""

# Top-level `await` relies on the notebook's already-running event loop.
await orchestrator.query(complex_task)

In [None]:
# Test the full workflow: the task spells out each stage explicitly so the
# agent is prompted to use every tool it was given, in order.
workflow_task = """Execute the complete workflow for this query:
Database: california_schools
Query: Find the average SAT score for charter schools in counties with more than 100 schools
Evidence: Charter schools have `Charter School (Y/N)` = 1

Please:
1. Create an execution plan
2. Get schema information for relevant tables
3. Generate SQL for each information piece
4. Execute and verify the SQL
5. Combine results for final SQL
6. Monitor progress throughout
"""

# Top-level `await` relies on the notebook's already-running event loop.
await orchestrator.query(workflow_task)

## Enhanced Orchestrator with Real Agent Integration

In [None]:
class EnhancedOrchestratorAgent(OrchestratorAgent):
    """Enhanced orchestrator with real agent integration.

    Replaces the demo tools of the base class with tools that ask the model
    for a plan, run each plan step through the schema and SQL executor
    agents, and finally ask the model for the combined SQL statement.
    """

    # NOTE(review): every `self.model_client.create(...)` call in this class
    # passes OpenAI-style role/content dicts. The autogen model client appears
    # to expect typed message objects (SystemMessage / UserMessage from
    # autogen_core.models) and to return a result exposing `.content` — confirm
    # against the installed autogen version. `_completion_text` already
    # accepts both response shapes.

    @staticmethod
    def _completion_text(response: Any) -> str:
        """Return the text of a model response.

        Accepts a result that exposes the text directly as ``content`` as well
        as an OpenAI-style ``choices[0].message.content`` layout.
        """
        content = getattr(response, "content", None)
        if isinstance(content, str):
            return content
        return response.choices[0].message.content

    @staticmethod
    def _render_plan_prompt(query: str, database_id: str, evidence: str) -> str:
        """Fill the plan template with the three request fields."""
        try:
            return ORCHESTRATOR_PLAN_TEMPLATE.format(
                database_id=database_id,
                query=query,
                evidence=evidence
            )
        except (KeyError, IndexError, ValueError):
            # str.format() rejects a template whose illustrative "{...}" hints
            # are not brace-escaped; in that case substitute only the three
            # real fields and leave the hints untouched.
            return (
                ORCHESTRATOR_PLAN_TEMPLATE
                .replace("{database_id}", database_id)
                .replace("{query}", query)
                .replace("{evidence}", evidence)
            )
    
    def _create_agent(self) -> AssistantAgent:
        """Create the enhanced orchestrator agent."""
        # NOTE: the one-line docstrings of the tool functions are kept short
        # on purpose; the agent framework presumably surfaces them to the
        # model as tool descriptions — confirm before rewording them.
        
        async def create_execution_plan(query: str, database_id: str, evidence: str = "") -> str:
            """Create an execution plan for the query."""
            # A fresh state replaces any earlier run for the same query text.
            self.workflow_states[query] = WorkflowState(
                query=query, database_id=database_id, evidence=evidence
            )
            
            # Use the actual model to create the plan
            messages = [
                {"role": "system", "content": ORCHESTRATOR_SYSTEM_PROMPT},
                {"role": "user", "content": self._render_plan_prompt(query, database_id, evidence)}
            ]
            
            response = await self.model_client.create(messages=messages)
            plan_xml = self._completion_text(response)
            
            # Parse and store the plan
            self._parse_and_store_plan(query, plan_xml)
            
            return plan_xml
        
        async def execute_workflow_step(query: str, step_number: int) -> str:
            """Execute a specific step in the workflow."""
            state = self.workflow_states.get(query)
            if not state:
                return "Error: No workflow state found"
            
            # Steps are 1-based. Rejecting values below 1 stops 0 or negative
            # numbers from wrapping around to the end of the list.
            if not 1 <= step_number <= len(state.information_pieces):
                return "Error: Invalid step number"
            
            info_piece = state.information_pieces[step_number - 1]
            
            # 1. Get schema information
            tables = info_piece.get('potential_tables') or []
            if isinstance(tables, str):
                tables = [tables]
            
            # `task=` is passed by keyword, matching how this file invokes
            # `run_stream` on its own agent.
            schema_result = await self.schema_agent.agent.run(
                task=f"Get schema information for tables: {tables} in database {state.database_id}"
            )
            
            # 2. Generate SQL
            sql_prompt = f"""
            Generate SQL for: {info_piece.get('description', '')}
            Database: {state.database_id}
            Schema: {schema_result}
            Evidence: {state.evidence}
            """
            
            sql_result = await self.model_client.create(
                messages=[
                    {"role": "system", "content": "Generate SQL based on the requirements"},
                    {"role": "user", "content": sql_prompt}
                ]
            )
            sql = self._completion_text(sql_result)
            
            # 3. Execute SQL
            exec_result = await self.sql_executor_agent.agent.run(
                task=f"Execute this SQL on database {state.database_id}: {sql}"
            )
            
            # Update state
            info_piece['status'] = 'complete'
            info_piece['sql'] = sql
            info_piece['result'] = exec_result
            
            state.add_intermediate_result({
                "step": step_number,
                "sql": sql,
                "result": exec_result
            })
            
            # NOTE(review): the summary is reported unconditionally; the
            # executor's result is not inspected for failures here.
            return json.dumps({
                "step": step_number,
                "status": "complete",
                "sql": sql,
                "result_summary": "Execution successful"
            }, indent=2)
        
        # Create the enhanced agent
        return AssistantAgent(
            name="enhanced_orchestrator",
            model_client=self.model_client,
            tools=[
                create_execution_plan,
                execute_workflow_step,
                self.monitor_workflow_progress,
                self.generate_final_sql
            ],
            system_message=ORCHESTRATOR_SYSTEM_PROMPT,
            reflect_on_tool_use=True,
            model_client_stream=True,
        )
    
    def _parse_and_store_plan(self, query: str, plan_xml: str):
        """Parse the execution plan XML and store in workflow state.

        On success the parsed pieces are appended to the state (each with
        status ``"pending"``) and ``state.plan`` is marked ``parsed: True``.
        On a parse failure no pieces are added and ``state.plan`` records the
        raw text together with the error message.

        Raises:
            KeyError: If no workflow state has been registered for ``query``.
        """
        # Looked up before the try block so the failure branch can always
        # reach it.
        state = self.workflow_states[query]
        
        try:
            # Model replies often wrap the XML in prose or a code fence, so
            # parse only the <ExecutionPlan> element when one can be located.
            found = re.search(r"<ExecutionPlan\b.*?</ExecutionPlan>", plan_xml, re.DOTALL)
            root = ET.fromstring(found.group(0) if found else plan_xml)
            
            # Build the full list first so a failure part-way through never
            # leaves a half-populated plan on the state.
            pieces = [
                {
                    "number": info_elem.get('number'),
                    "description": (info_elem.findtext('Description', default='') or '').strip(),
                    "potential_tables": (info_elem.findtext('PotentialTables', default='') or '').strip(),
                    "dependency": (info_elem.findtext('Dependency', default='') or '').strip(),
                    "status": "pending"
                }
                for info_elem in root.findall('.//InfoPiece')
            ]
        except (ET.ParseError, TypeError, ValueError) as e:
            state.plan = {"xml": plan_xml, "parsed": False, "error": str(e)}
            return
        
        for piece in pieces:
            state.add_information_piece(piece)
        state.plan = {"xml": plan_xml, "parsed": True}
    
    async def monitor_workflow_progress(self, query: str) -> str:
        """Monitor the progress of the workflow."""
        state = self.workflow_states.get(query)
        if not state:
            return "No workflow found for this query"
        
        pieces = state.information_pieces
        progress = {
            "query": query,
            "database": state.database_id,
            "plan_created": state.plan is not None,
            "total_steps": len(pieces),
            "completed_steps": sum(1 for p in pieces if p.get('status') == 'complete'),
            "intermediate_results": len(state.intermediate_results),
            "has_final_sql": state.final_sql is not None
        }
        
        # Add details for each step; .get() tolerates pieces that were added
        # without a description or status.
        progress["step_details"] = [
            {
                "step": position,
                "description": piece.get('description'),
                "status": piece.get('status')
            }
            for position, piece in enumerate(pieces, start=1)
        ]
        
        return json.dumps(progress, indent=2)
    
    async def generate_final_sql(self, query: str) -> str:
        """Generate the final SQL based on all intermediate results."""
        state = self.workflow_states.get(query)
        if not state:
            return "Error: No workflow state found"
        
        if not state.is_complete():
            return "Error: Not all steps are complete"
        
        # Combine all intermediate results
        context = {
            "original_query": state.query,
            "database": state.database_id,
            "evidence": state.evidence,
            "steps": []
        }
        
        pieces = state.information_pieces
        for position, result in enumerate(state.intermediate_results, start=1):
            # Pair each result with the piece it was recorded for. A re-run
            # step adds an extra result, so list position alone is unreliable.
            step = result.get('step', position)
            in_range = isinstance(step, int) and 1 <= step <= len(pieces)
            piece = pieces[step - 1] if in_range else {}
            context["steps"].append({
                "step": step,
                "description": piece.get('description', ''),
                "sql": result.get('sql'),
                "result": result.get('result')
            })
        
        # Generate final SQL. `default=str` is deliberate: step results are
        # whatever the executor agent returned and may not be JSON-native;
        # their string form is enough for a prompt.
        final_prompt = f"""
        Based on the following intermediate results, generate the final SQL query:
        
        Original Query: {context['original_query']}
        Database: {context['database']}
        Evidence: {context['evidence']}
        
        Intermediate Steps:
        {json.dumps(context['steps'], indent=2, default=str)}
        
        Generate the final SQL that answers the original query.
        """
        
        response = await self.model_client.create(
            messages=[
                {"role": "system", "content": "You are a SQL expert. Generate the final SQL based on intermediate results."},
                {"role": "user", "content": final_prompt}
            ]
        )
        
        final_sql = self._completion_text(response)
        state.final_sql = final_sql
        
        return json.dumps({
            "final_sql": final_sql,
            "based_on_steps": len(state.intermediate_results)
        }, indent=2)

In [None]:
# Create enhanced orchestrator, reusing the schema and SQL executor agents
# built earlier in the notebook.
enhanced_orchestrator = EnhancedOrchestratorAgent(schema_agent, sql_executor_agent)

# Test with a complete workflow: the numbered steps in the task text mirror
# the four tools registered on the enhanced agent.
# Top-level `await` relies on the notebook's already-running event loop.
await enhanced_orchestrator.query("""
Complete workflow for this query:
Database: california_schools
Query: List school names of charter schools with an SAT excellence rate over the average
Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr

Execute all steps:
1. Create execution plan
2. Execute each workflow step
3. Monitor progress
4. Generate final SQL
""")

In [None]:
# Close connections
await orchestrator.close()
await enhanced_orchestrator.close()
await schema_agent.close()
await sql_executor_agent.close()