# Test QueryUnderstandingAgent

This notebook tests the QueryUnderstandingAgent functionality.

In [1]:
# Setup imports
import sys
sys.path.append("..")

from dispatcher.query_understanding_agent import query_understanding_agent
from dispatcher.tools import read_database_schema_and_records
import json

load json file from /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_tables.json
load json file from /home/norman/work/text-to-sql/MAC-SQL/data/spider/tables.json




## 1. Check Agent Configuration

In [2]:
# Verify agent is created
print(f"Agent name: {query_understanding_agent.name}")
print(f"\nSystem message (first 500 chars):\n{query_understanding_agent.system_message[:500]}...")

Agent name: QueryUnderstandingAgent

System message (first 500 chars):
You are the QueryUnderstandingAgent. Your role is to understand natural language queries in the context of database schemas, 
decompose complex queries into simpler sub-queries, and identify entities, attributes, operations, and filters.

Your tasks:
1. Receive a natural language query and any existing inter_step_data (provided as an XML string).
2. Use the 'read_database_schema_and_records' tool to fetch relevant schema information and sample data.
   You will receive the tool output as a strin...


## 2. Test read_database_schema_and_records Function

In [3]:
# First, let's see what databases are available
test_result = read_database_schema_and_records(
    db_id="test_db",  # This will likely fail and show us available databases
    dataset_name="bird"
)
print(json.dumps(test_result, indent=2))

Database ID 'test_db' not found in schema_manager.db2dbjsons
Available database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']...
{
  "error": "Database ID 'test_db' not found. Available databases: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']..."
}


In [4]:
# Test with a real database (use one from the error message above)
# Common BIRD databases: california_schools, card_games, codebase_community, debit_card_specializing
db_id = "california_schools"  # Replace with actual DB ID if needed

schema_result = read_database_schema_and_records(
    db_id=db_id,
    dataset_name="bird",
    include_sample_data=True
)

# Display the result
if "error" not in schema_result:
    print(f"Schema for database '{db_id}':")
    for table_name, table_info in schema_result.items():
        print(f"\nTable: {table_name}")
        print(f"  Columns: {list(table_info['columns'].keys())}")
        if 'sample_data' in table_info and table_info['sample_data']:
            print(f"  Sample row: {table_info['sample_data'][0]}")
else:
    print(schema_result)

Generating schema description for database 'california_schools'...
Current directory: /home/norman/work/text-to-sql/MAC-SQL/dispatcher
BIRD_DATA_PATH: /home/norman/work/text-to-sql/MAC-SQL/data/bird
BIRD_DB_DIRECTORY: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases
Checking database path: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases/california_schools/california_schools.sqlite, exists: True
Current directory: /home/norman/work/text-to-sql/MAC-SQL/dispatcher
BIRD_DATA_PATH: /home/norman/work/text-to-sql/MAC-SQL/data/bird
BIRD_DB_DIRECTORY: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases
Checking database path: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases/california_schools/california_schools.sqlite, exists: True
Current directory: /home/norman/work/text-to-sql/MAC-SQL/dispatcher
BIRD_DATA_PATH: /home/norman/work/text-to-sql/MAC-SQL/data/bird
BIRD_DB_DIRECTORY: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases
Chec

In [10]:
# Create a user proxy agent and test function
import autogen
import re

def parse_agent_response_robust(response_str):
    """Parse XML response from agent using regex for partial blocks"""
    result = {
        "summary": "",
        "entities": [],
        "attributes": [],
        "operations": [],
        "filters": [],
        "plan_items": [],
        "tables": []
    }
    
    # Extract summary using regex
    summary_match = re.search(r'<summaryOfUnderstanding>(.*?)</summaryOfUnderstanding>', response_str, re.DOTALL)
    if summary_match:
        result["summary"] = summary_match.group(1).strip()
    
    # Extract entities
    entity_pattern = r'<entity>.*?<tableName>(.*?)</tableName>.*?<purpose>(.*?)</purpose>.*?</entity>'
    entities = re.findall(entity_pattern, response_str, re.DOTALL)
    for table, purpose in entities:
        result["entities"].append({"table": table.strip(), "purpose": purpose.strip()})
    
    # Extract attributes 
    attr_pattern = r'<attribute>.*?<columnName>(.*?)</columnName>.*?<tableName>(.*?)</tableName>.*?<operation>(.*?)</operation>.*?</attribute>'
    attributes = re.findall(attr_pattern, response_str, re.DOTALL)
    for col, table, op in attributes:
        result["attributes"].append({"column": col.strip(), "table": table.strip(), "operation": op.strip()})
    
    # Extract operations
    op_pattern = r'<operation>.*?<type>(.*?)</type>.*?<description>(.*?)</description>.*?</operation>'
    operations = re.findall(op_pattern, response_str, re.DOTALL)
    for op_type, desc in operations:
        result["operations"].append({"type": op_type.strip(), "description": desc.strip()})
    
    # Extract filters
    filter_pattern = r'<filter>.*?<attribute>(.*?)</attribute>.*?<condition>(.*?)</condition>.*?<value>(.*?)</value>.*?</filter>'
    filters = re.findall(filter_pattern, response_str, re.DOTALL)
    for attr, cond, val in filters:
        result["filters"].append({"attribute": attr.strip(), "condition": cond.strip(), "value": val.strip()})
    
    # Extract plan items
    plan_pattern = r'<item>.*?<itemId>(.*?)</itemId>.*?<description>(.*?)</description>'
    plans = re.findall(plan_pattern, response_str, re.DOTALL)
    for item_id, desc in plans:
        result["plan_items"].append({"id": item_id.strip(), "description": desc.strip()})
    
    # Extract tables
    table_pattern = r'<table>.*?<n>(.*?)</n>'
    tables = re.findall(table_pattern, response_str, re.DOTALL)
    for table_name in tables:
        result["tables"].append({"name": table_name.strip()})
    
    return result

# Create user proxy
user_proxy = autogen.UserProxyAgent(
    name="User",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=0,
    code_execution_config=False,
)

In [16]:
db_id_to_test = "california_schools"
natural_language_query = "List all schools in California"
# Include placeholder for inter_step_data as per system message
initial_message = f"""Here is the task:
Database ID: {db_id_to_test}
Natural Language Query: "{natural_language_query}"
Existing inter_step_data: <inter_step_data></inter_step_data>
"""

print(f"--- Initiating Chat with Query: '{natural_language_query}' for DB: '{db_id_to_test}' ---")

# Initiate chat
# The UserProxyAgent sends the initial message to the QueryUnderstandingAgent.
chat_results = user_proxy.initiate_chat(
    recipient=query_understanding_agent,
    message=initial_message,
    max_turns=3 # Should be enough: 1. User->Assistant, 2. Assistant (calls tool, processes, replies)->User
)

--- Initiating Chat with Query: 'List all schools in California' for DB: 'california_schools' ---
[33mUser[0m (to QueryUnderstandingAgent):

Here is the task:
Database ID: california_schools
Natural Language Query: "List all schools in California"
Existing inter_step_data: <inter_step_data></inter_step_data>


--------------------------------------------------------------------------------
[33mQueryUnderstandingAgent[0m (to User):

[32m***** Suggested tool call (call_BELXI7xgPLUpT5K3mDNchv7G): read_database_schema_and_records *****[0m
Arguments: 
{"db_id":"california_schools","include_sample_data":true}
[32m*************************************************************************************************[0m

--------------------------------------------------------------------------------


In [14]:
# --- Analyse Results ---
print("\n--- Chat History ---")
for msg in chat_results.chat_history:
    print(f"Sender: {msg['role']}, Content:\n{msg['content']}\n--------------------")

# The last message from the assistant should be the XML response
# (Adjust index if there are more messages, e.g. if user_proxy replies with TERMINATE)
final_agent_reply = None
if chat_results.chat_history and len(chat_results.chat_history) > 1:
    # Typically, user_proxy sends first, assistant replies second (which includes tool call and final response)
    # If UserProxyAgent is very basic and doesn't reply after assistant, the last message is assistant's.
    # If Assistant's message is a tool_call, then next one is tool_response, then final assistant message.
    # AutoGen groups these: the textual response from assistant after tool use is its "turn".
    
    # We are looking for the last message from the QueryUnderstandingAgent that is NOT a tool call
    for i in range(len(chat_results.chat_history) - 1, -1, -1):
        msg = chat_results.chat_history[i]
        if msg['role'].lower() == query_understanding_agent.name.lower() and not msg.get('tool_calls'):
            final_agent_reply = msg['content']
            break
    if not final_agent_reply and chat_results.summary: # Fallback to summary if available
        final_agent_reply = chat_results.summary


--- Chat History ---
Sender: assistant, Content:
Here is the task:
Database ID: california_schools
Natural Language Query: "List all schools in California"
Existing inter_step_data: <inter_step_data></inter_step_data>

--------------------
Sender: assistant, Content:
None
--------------------


In [11]:
# Test with a simple query
simple_query_message = """Please analyze this query: 'List all schools in California'

Database ID: california_schools  
Dataset: bird

Use the read_database_schema_and_records function to get the schema information with sample data.
"""

# Send the query to the agent
user_proxy.send(simple_query_message, query_understanding_agent, request_reply=True)

# Get the last message from the agent
last_message = user_proxy.last_message(query_understanding_agent)
print("Agent Response (first 1000 chars):")
print(last_message["content"][:1000])
print("\n...")

# Parse the response
parsed_response = parse_agent_response_robust(last_message["content"])
print("\n\nParsed Response:")
print(json.dumps(parsed_response, indent=2))

[33mUser[0m (to QueryUnderstandingAgent):

Please analyze this query: 'List all schools in California'

Database ID: california_schools  
Dataset: bird

Use the read_database_schema_and_records function to get the schema information with sample data.


--------------------------------------------------------------------------------
[33mQueryUnderstandingAgent[0m (to User):

[32m***** Suggested tool call (call_5xUrH5fINQahTMzFKap1obMg): read_database_schema_and_records *****[0m
Arguments: 
{"db_id":"california_schools","dataset_name":"bird","include_sample_data":true}
[32m*************************************************************************************************[0m

--------------------------------------------------------------------------------
Agent Response (first 1000 chars):


TypeError: 'NoneType' object is not subscriptable

In [12]:
last_message

{'tool_calls': [{'id': 'call_5xUrH5fINQahTMzFKap1obMg',
   'function': {'arguments': '{"db_id":"california_schools","dataset_name":"bird","include_sample_data":true}',
    'name': 'read_database_schema_and_records'},
   'type': 'function'}],
 'content': None,
 'role': 'assistant'}

In [None]:
# Test with a complex query
complex_query_message = """Please analyze this complex query: 
'What is the average FRPM Count per school for schools with enrollment greater than 1000 students, grouped by county, and only showing counties with more than 5 such schools?'

Database ID: california_schools
Dataset: bird

Use the read_database_schema_and_records function to get the schema information.
This is a complex query that needs to be decomposed into multiple steps.
"""

# Send the query to the agent
user_proxy.send(complex_query_message, query_understanding_agent, request_reply=True)

# Get the response
complex_response = user_proxy.last_message(query_understanding_agent)
print("Agent Response for Complex Query (first 1500 chars):")
print(complex_response["content"][:1500])
print("\n...")

# Parse the complex query response
parsed_complex = parse_agent_response_robust(complex_response["content"])
print("\n\nParsed Complex Query Response:")
print(f"Summary: {parsed_complex.get('summary', 'N/A')}")
print(f"\nNumber of plan steps: {len(parsed_complex.get('plan_items', []))}")
print("\nPlan Steps:")
for item in parsed_complex.get('plan_items', []):
    print(f"  {item['id']}: {item['description']}")
print(f"\nEntities involved: {[e['table'] for e in parsed_complex.get('entities', [])]}")
print(f"Operations: {[op['type'] for op in parsed_complex.get('operations', [])]}")

In [None]:
# Test with a medium complexity query
medium_query_message = """Please analyze this query: 
'Show the total enrollment for schools in each county'

Database ID: california_schools
Dataset: bird

Use the read_database_schema_and_records function to get the schema information.
"""

# Send the query to the agent
user_proxy.send(medium_query_message, query_understanding_agent, request_reply=True)

# Get and parse the response
medium_response = user_proxy.last_message(query_understanding_agent)
parsed_medium = parse_agent_response_robust(medium_response["content"])

print("Medium Query Analysis:")
print(f"Summary: {parsed_medium.get('summary', 'N/A')}")
print(f"Number of plan steps: {len(parsed_medium.get('plan_items', []))}")
print(f"Operations: {[op['type'] for op in parsed_medium.get('operations', [])]}")
print("\nAttributes:")
for attr in parsed_medium.get('attributes', []):
    print(f"  {attr['column']} from {attr['table']} - {attr['operation']}")