# Query Decomposer Agent Test

This notebook tests the Query Decomposer Agent that determines if queries are simple or complex and decomposes complex queries into sub-queries.

In [1]:
from dotenv import load_dotenv
import json
import re
from typing import Dict, Any, List, Optional
import xml.etree.ElementTree as ET

load_dotenv()

True

In [2]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from schema_manager import SchemaManager

## Define the Query Decomposer Agent

In [3]:
# System prompt for the query decomposer
QUERY_DECOMPOSER_SYSTEM_PROMPT = """You are an expert Query Decomposer Agent specializing in analyzing natural language questions about databases.

Your task is to:
1. Determine if a user query is simple or complex
2. For complex queries, decompose them into logical sub-queries
3. For simple queries, return the query as-is
4. Provide reasoning for your decisions

A simple query typically:
- Asks for direct information from a single table
- Has straightforward filtering conditions  
- Requires basic aggregations (COUNT, SUM, AVG)
- Can be answered with a single SQL statement

A complex query typically:
- Requires multiple steps to answer
- Involves comparisons with averages or other computed values
- Needs intermediate results to compute the final answer
- Involves complex joins or nested queries

You must respond in XML format exactly as specified in the examples."""

In [5]:
# Define the prompt template
DECOMPOSE_QUERY_PROMPT = """Given a user query and database information, analyze whether this is a simple or complex query.
If complex, decompose it into sub-queries.

Database: {database_id}  
Schema Information: {schema_info}
User Query: {query}
Evidence: {evidence}

Examples show that complex queries often need sub-queries when:
1. Calculating averages for comparison (e.g., "schools with SAT excellence rate over the average")
2. Finding extreme values first (e.g., "the lowest average salary branch")
3. Multi-step filtering (e.g., finding youngest client in specific conditions)

Response format for simple queries:
<QueryAnalysis>
    <QueryType>Simple</QueryType>
    <Reasoning>This query can be answered directly with a single SQL statement</Reasoning>
    <Result>
        <Query>{query}</Query>
    </Result>
</QueryAnalysis>

Response format for complex queries:
<QueryAnalysis>
    <QueryType>Complex</QueryType>
    <Reasoning>This query requires multiple steps because...</Reasoning>
    <Result>
        <SubQuery number="1">
            <Description>{description of what this sub-query does}</Description>
            <Query>{sub-query text}</Query>
        </SubQuery>
        <SubQuery number="2">
            <Description>{description of what this sub-query does}</Description>
            <Query>{sub-query text}</Query>
        </SubQuery>
        ...
    </Result>
</QueryAnalysis>

Examples from const.py:

Example 1 - Complex query needing average calculation:
Query: "List school names of charter schools with an SAT excellence rate over the average."
This decomposes to:
1. Get the average SAT excellence rate of charter schools
2. List school names with rate over the average

Example 2 - Complex query with extreme value:
Query: "What is the gender of the youngest client who opened account in the lowest average salary branch?"
This decomposes to:
1. Find the district_id with the lowest average salary
2. Find the youngest client in that district
3. Get the gender of that client

Example 3 - Simple query:
Query: "Show the stadium name and the number of concerts in each stadium."
This is simple - can be answered with a single JOIN and GROUP BY."""

In [6]:
# Initialize the schema manager
schema_manager = SchemaManager(
    data_path="../data/bird",
    tables_json_path="../data/bird/dev_tables.json",
    dataset_name="bird",
    lazy=True
)

load json file from ../data/bird/dev_tables.json


In [7]:
# Define query decomposer tools
async def analyze_query_complexity(query: str, database_id: str, evidence: str = "") -> str:
    """Analyze if a query is simple or complex and decompose if needed."""
    
    # Get schema information
    try:
        # Get basic database info
        db_info = schema_manager.db2dbjsons.get(database_id, {})
        schema_summary = {
            "tables": db_info.get('table_names_original', []),
            "table_count": db_info.get('table_count', 0)
        }
        schema_info = json.dumps(schema_summary)
    except:
        schema_info = "Schema information not available"
    
    # Use LLM to analyze query
    prompt = DECOMPOSE_QUERY_PROMPT.format(
        database_id=database_id,
        schema_info=schema_info,
        query=query,
        evidence=evidence
    )
    
    # In practice, this would call the LLM - for now, we'll implement some logic
    return analyze_query_complexity_logic(query, database_id, evidence, schema_info)

def analyze_query_complexity_logic(query: str, database_id: str, evidence: str, schema_info: str) -> str:
    """Simple logic to analyze query complexity."""
    query_lower = query.lower()
    
    # Indicators of complex queries
    complex_indicators = [
        "over the average",
        "above average", 
        "below average",
        "youngest",
        "oldest",
        "lowest",
        "highest",
        "minimum",
        "maximum",
        "top",
        "bottom",
        "compared to",
        "more than the average",
        "less than the average"
    ]
    
    # Check for complex indicators
    is_complex = any(indicator in query_lower for indicator in complex_indicators)
    
    # Check for multiple conditions that suggest complexity
    has_multiple_conditions = query_lower.count("and") > 1 or query_lower.count("where") > 1
    
    if is_complex or has_multiple_conditions:
        return generate_complex_analysis(query, database_id, evidence)
    else:
        return generate_simple_analysis(query)

def generate_simple_analysis(query: str) -> str:
    """Generate simple query analysis."""
    return f"""<QueryAnalysis>
    <QueryType>Simple</QueryType>
    <Reasoning>This query can be answered directly with a single SQL statement</Reasoning>
    <Result>
        <Query>{query}</Query>
    </Result>
</QueryAnalysis>"""

def generate_complex_analysis(query: str, database_id: str, evidence: str) -> str:
    """Generate complex query analysis with decomposition."""
    query_lower = query.lower()
    sub_queries = []
    
    # Example patterns for decomposition
    if "over the average" in query_lower or "above average" in query_lower:
        # Pattern: comparison with average
        sub_queries.append({
            "description": "Calculate the average value for comparison",
            "query": "Get the average value of the metric mentioned in the query"
        })
        sub_queries.append({
            "description": "Find items that exceed the average",
            "query": query
        })
    
    elif any(word in query_lower for word in ["youngest", "oldest", "lowest", "highest"]):
        # Pattern: finding extreme values
        if "youngest" in query_lower or "oldest" in query_lower:
            sub_queries.append({
                "description": "Find the person with the extreme age characteristic",
                "query": "Identify the youngest/oldest person based on birth date"
            })
        elif "lowest" in query_lower or "highest" in query_lower:
            sub_queries.append({
                "description": "Find the item with the extreme value",
                "query": "Identify the item with the lowest/highest value"
            })
        
        # Add final query to get the requested information
        sub_queries.append({
            "description": "Get the final requested information",
            "query": query
        })
    
    # Build the XML response
    sub_query_xml = ""
    for i, sq in enumerate(sub_queries, 1):
        sub_query_xml += f"""        <SubQuery number="{i}">
            <Description>{sq['description']}</Description>
            <Query>{sq['query']}</Query>
        </SubQuery>
"""
    
    return f"""<QueryAnalysis>
    <QueryType>Complex</QueryType>
    <Reasoning>This query requires multiple steps because it involves {identify_complexity_reason(query_lower)}</Reasoning>
    <Result>
{sub_query_xml}    </Result>
</QueryAnalysis>"""

def identify_complexity_reason(query_lower: str) -> str:
    """Identify the reason for query complexity."""
    if "average" in query_lower:
        return "comparison with an average value that must be calculated first"
    elif any(word in query_lower for word in ["youngest", "oldest", "lowest", "highest"]):
        return "finding extreme values that require intermediate calculations"
    else:
        return "multiple conditions that need to be evaluated in sequence"

In [8]:
# Create the Query Decomposer Agent
model_client = OpenAIChatCompletionClient(model="gpt-4o")

query_decomposer_agent = AssistantAgent(
    name="query_decomposer",
    model_client=model_client,
    tools=[analyze_query_complexity],
    system_message=QUERY_DECOMPOSER_SYSTEM_PROMPT,
    reflect_on_tool_use=True,
    model_client_stream=True,
)

## Test the Query Decomposer Agent

In [9]:
# Test with a simple query
simple_query = "Show the stadium name and the number of concerts in each stadium"
await Console(query_decomposer_agent.run_stream(
    task=f"Analyze this query for database 'concert_singer': {simple_query}"
))

---------- TextMessage (user) ----------
Analyze this query for database 'concert_singer': Show the stadium name and the number of concerts in each stadium
---------- ModelClientStreamingChunkEvent (query_decomposer) ----------
<analysis>
    <type>simple</type>
    <query>Show the stadium name and the number of concerts in each stadium</query>
    <reasoning>
        The query involves retrieving the stadium name and counting the number of concerts in each stadium. This can typically be accomplished with a single SQL statement that performs a GROUP BY operation on the stadium name to count the concerts. There are no additional aggregations or complex conditions needed. Thus, the query is simple.
    </reasoning>
</analysis>


TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Analyze this query for database 'concert_singer': Show the stadium name and the number of concerts in each stadium", type='TextMessage'), TextMessage(source='query_decomposer', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content='<analysis>\n    <type>simple</type>\n    <query>Show the stadium name and the number of concerts in each stadium</query>\n    <reasoning>\n        The query involves retrieving the stadium name and counting the number of concerts in each stadium. This can typically be accomplished with a single SQL statement that performs a GROUP BY operation on the stadium name to count the concerts. There are no additional aggregations or complex conditions needed. Thus, the query is simple.\n    </reasoning>\n</analysis>', type='TextMessage')], stop_reason=None)

In [10]:
# Test with a complex query involving average
complex_query_1 = "List school names of charter schools with an SAT excellence rate over the average"
await Console(query_decomposer_agent.run_stream(
    task=f"Analyze this query for database 'california_schools': {complex_query_1}. Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr"
))

---------- TextMessage (user) ----------
Analyze this query for database 'california_schools': List school names of charter schools with an SAT excellence rate over the average. Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr
---------- ToolCallRequestEvent (query_decomposer) ----------
[FunctionCall(id='call_nc2NKnPrtABrdfYThcLByu02', arguments='{"query":"List school names of charter schools with an SAT excellence rate over the average.","database_id":"california_schools","evidence":"Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr"}', name='analyze_query_complexity')]
---------- ToolCallExecutionEvent (query_decomposer) ----------
[FunctionExecutionResult(content="'description of what this sub-query does'", name='analyze_query_complexity', call_id='call_nc2NKnPrtABrdfYThcLByu02', is_error=True)]
---------- ModelClientStreamingChunkEvent (query_decomposer) ----------
<analysis>
    <

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Analyze this query for database 'california_schools': List school names of charter schools with an SAT excellence rate over the average. Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr", type='TextMessage'), ToolCallRequestEvent(source='query_decomposer', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_nc2NKnPrtABrdfYThcLByu02', arguments='{"query":"List school names of charter schools with an SAT excellence rate over the average.","database_id":"california_schools","evidence":"Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr"}', name='analyze_query_complexity')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='query_decomposer', models_usage=None, metadata={}, content=[FunctionExecutionResult(content="'description of what t

In [11]:
# Test with a complex query involving extreme values
complex_query_2 = "What is the gender of the youngest client who opened account in the lowest average salary branch?"
await Console(query_decomposer_agent.run_stream(
    task=f"Analyze this query for database 'financial': {complex_query_2}. Evidence: Later birthdate refers to younger age; A11 refers to average salary"
))

---------- TextMessage (user) ----------
Analyze this query for database 'financial': What is the gender of the youngest client who opened account in the lowest average salary branch?. Evidence: Later birthdate refers to younger age; A11 refers to average salary
---------- ToolCallRequestEvent (query_decomposer) ----------
[FunctionCall(id='call_1qnkA7DQuaMlZrNcEInLUJTg', arguments='{"query":"What is the gender of the youngest client who opened an account in the lowest average salary branch?","database_id":"financial","evidence":"Later birthdate refers to younger age; A11 refers to average salary"}', name='analyze_query_complexity')]
---------- ToolCallExecutionEvent (query_decomposer) ----------
[FunctionExecutionResult(content="'description of what this sub-query does'", name='analyze_query_complexity', call_id='call_1qnkA7DQuaMlZrNcEInLUJTg', is_error=True)]
---------- ModelClientStreamingChunkEvent (query_decomposer) ----------
<analysis>
    <type>complex</type>
    <subqueries>
 

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Analyze this query for database 'financial': What is the gender of the youngest client who opened account in the lowest average salary branch?. Evidence: Later birthdate refers to younger age; A11 refers to average salary", type='TextMessage'), ToolCallRequestEvent(source='query_decomposer', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_1qnkA7DQuaMlZrNcEInLUJTg', arguments='{"query":"What is the gender of the youngest client who opened an account in the lowest average salary branch?","database_id":"financial","evidence":"Later birthdate refers to younger age; A11 refers to average salary"}', name='analyze_query_complexity')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='query_decomposer', models_usage=None, metadata={}, content=[FunctionExecutionResult(content="'description of what this sub-query does'", name='analyze_qu

In [12]:
# Test with more queries
test_queries = [
    {
        "query": "What is the total number of students in California?",
        "database": "california_schools",
        "evidence": ""
    },
    {
        "query": "Find all transactions over 500 dollars",
        "database": "financial",
        "evidence": ""
    },
    {
        "query": "Which department has the highest average salary compared to the company average?",
        "database": "employee",
        "evidence": ""
    }
]

for test in test_queries:
    print(f"\n{'='*50}")
    print(f"Query: {test['query']}")
    print(f"Database: {test['database']}")
    print(f"{'='*50}\n")
    
    task = f"Analyze this query for database '{test['database']}': {test['query']}"
    if test['evidence']:
        task += f". Evidence: {test['evidence']}"
    
    await Console(query_decomposer_agent.run_stream(task=task))


Query: What is the total number of students in California?
Database: california_schools

---------- TextMessage (user) ----------
Analyze this query for database 'california_schools': What is the total number of students in California?
---------- ModelClientStreamingChunkEvent (query_decomposer) ----------
<analysis>
    <type>simple</type>
    <query>What is the total number of students in California?</query>
    <reasoning>
        The query asks for the total number of students, which requires a straightforward aggregation by summing up a column (likely representing student numbers) across all relevant records in the database. As this can be achieved with a single SQL query using the SUM function and does not involve any complex conditions or multiple steps, it is considered a simple query.
    </reasoning>
</analysis>

Query: Find all transactions over 500 dollars
Database: financial

---------- TextMessage (user) ----------
Analyze this query for database 'financial': Find all tr

## Create a more sophisticated version using actual LLM

In [13]:
class QueryDecomposerAgent:
    """Agent that analyzes query complexity and decomposes complex queries."""
    
    def __init__(self, model_client, schema_manager):
        self.model_client = model_client
        self.schema_manager = schema_manager
        self.agent = self._create_agent()
    
    def _create_agent(self) -> AssistantAgent:
        """Create the query decomposer agent with tools."""
        
        async def decompose_query(query: str, database_id: str, evidence: str = "") -> str:
            """Analyze query complexity and decompose if needed."""
            
            # Get schema information
            schema_info = self._get_schema_info(database_id)
            
            # Create the prompt
            prompt = DECOMPOSE_QUERY_PROMPT.format(
                database_id=database_id,
                schema_info=schema_info,
                query=query,
                evidence=evidence
            )
            
            # Get LLM response
            messages = [
                {"role": "system", "content": QUERY_DECOMPOSER_SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ]
            
            response = await self.model_client.get_completion(messages)
            return response.choices[0].message.content
        
        async def parse_decomposition_result(xml_response: str) -> str:
            """Parse the XML response and return structured data."""
            try:
                root = ET.fromstring(xml_response)
                query_type = root.find('QueryType').text
                reasoning = root.find('Reasoning').text
                
                result = {
                    "query_type": query_type,
                    "reasoning": reasoning
                }
                
                if query_type == "Simple":
                    result["query"] = root.find('.//Query').text
                else:
                    sub_queries = []
                    for sq in root.findall('.//SubQuery'):
                        sub_queries.append({
                            "number": sq.get('number'),
                            "description": sq.find('Description').text,
                            "query": sq.find('Query').text
                        })
                    result["sub_queries"] = sub_queries
                
                return json.dumps(result, indent=2)
            except Exception as e:
                return f"Error parsing XML: {str(e)}"
        
        # Create the agent
        return AssistantAgent(
            name="query_decomposer",
            model_client=self.model_client,
            tools=[decompose_query, parse_decomposition_result],
            system_message=QUERY_DECOMPOSER_SYSTEM_PROMPT,
            reflect_on_tool_use=True,
            model_client_stream=True,
        )
    
    def _get_schema_info(self, database_id: str) -> str:
        """Get schema information for a database."""
        try:
            db_info = self.schema_manager.db2dbjsons.get(database_id, {})
            
            # Create a summary of the schema
            schema_summary = {
                "database": database_id,
                "tables": db_info.get('table_names_original', []),
                "table_count": db_info.get('table_count', 0),
                "total_columns": db_info.get('total_column_count', 0)
            }
            
            # Add some column information if available
            if 'column_names_original' in db_info:
                sample_columns = []
                for i, (table_idx, col_name) in enumerate(db_info['column_names_original'][:10]):
                    if table_idx >= 0:
                        table_name = db_info['table_names_original'][table_idx]
                        sample_columns.append(f"{table_name}.{col_name}")
                schema_summary["sample_columns"] = sample_columns
            
            return json.dumps(schema_summary, indent=2)
        except Exception as e:
            return f"Error getting schema info: {str(e)}"
    
    async def query(self, task: str) -> None:
        """Run a query against the decomposer agent."""
        await Console(self.agent.run_stream(task=task))


In [14]:
# Create the sophisticated agent
sophisticated_decomposer = QueryDecomposerAgent(model_client, schema_manager)

# Test it
await sophisticated_decomposer.query(
    "Analyze this query for database 'california_schools': List school names of charter schools with an SAT excellence rate over the average. Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr"
)

---------- TextMessage (user) ----------
Analyze this query for database 'california_schools': List school names of charter schools with an SAT excellence rate over the average. Evidence: Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr
---------- ToolCallRequestEvent (query_decomposer) ----------
[FunctionCall(id='call_VK1G8waeftBVCHpFuLWPj9Zf', arguments='{"query":"List school names of charter schools with an SAT excellence rate over the average.","database_id":"california_schools","evidence":"Charter schools refers to `Charter School (Y/N)` = 1; Excellence rate = NumGE1500 / NumTstTakr"}', name='decompose_query')]
---------- ToolCallExecutionEvent (query_decomposer) ----------
[FunctionExecutionResult(content="'description of what this sub-query does'", name='decompose_query', call_id='call_VK1G8waeftBVCHpFuLWPj9Zf', is_error=True)]
---------- ModelClientStreamingChunkEvent (query_decomposer) ----------
```xml
<queryAnalysis>
    <comple

In [None]:
# Close the model client
await model_client.close()