# Text-to-SQL Workflow with Agent Tools and KeyValueMemory

This notebook demonstrates a text-to-SQL workflow using:
- **KeyValueMemory** to store the database schema and intermediate results
- **Three Agent Tools**: schema_selector, sql_generator, sql_evaluator
- **California Schools Database** as our example database

The workflow processes a simple query through all three agents, with each agent reading from and writing to the shared memory.

In [1]:
import os
import sys
import asyncio
import sqlite3
import json
import logging
import re
from typing import Dict, Any, List, Optional
from dotenv import load_dotenv

# Make the project modules that live in ../src importable from this notebook
# (the next cell imports keyvalue_memory, schema_reader, etc. from there).
sys.path.append('../src')
# Load settings from a local .env file into the process environment.
# NOTE(review): presumably this is where the OpenAI API key comes from — confirm.
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Reduce noise from autogen
logging.getLogger('autogen_core').setLevel(logging.WARNING)

In [2]:
# Import our modules
from keyvalue_memory import KeyValueMemory
from schema_reader import SchemaReader
from memory_agent_tool import MemoryAgentTool
from workflow_utils import extract_sql_from_text
from sql_executor import SQLExecutor

# Import AutoGen components
from autogen_core import CancellationToken
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_ext.models.openai import OpenAIChatCompletionClient

## 1. Load California Schools Database Schema

First, we'll load the database schema using SchemaReader and store it in KeyValueMemory.

In [3]:
# Initialize memory and schema reader
memory = KeyValueMemory(name="text_to_sql_memory")

# Clear memory for fresh start
await memory.clear()

# Load California Schools schema
data_path = "/home/norman/work/text-to-sql/MAC-SQL/data/bird"
tables_json_path = os.path.join(data_path, "dev_tables.json")

schema_reader = SchemaReader(
    data_path=data_path,
    tables_json_path=tables_json_path,
    dataset_name="bird",
    lazy=False
)

# Load and store the full database schema
db_id = "california_schools"

# Generate schema description (using all tables)
selected_schema = {}  # Empty means select all
schema_xml, fk_infos, chosen_schema = schema_reader.generate_schema_description(
    db_id=db_id,
    selected_schema=selected_schema,
    use_gold_schema=False
)

# Store the full schema in memory
await memory.set("full_database_schema", schema_xml)
await memory.set("database_id", db_id)
await memory.set("foreign_keys", fk_infos)

print(f"Loaded schema for database: {db_id}")
print(f"Schema length: {len(schema_xml)} characters")
print(f"Foreign keys: {len(fk_infos)}")

# Also store the database path for execution
db_path = os.path.join(data_path, "dev_databases", db_id, f"{db_id}.sqlite")
await memory.set("database_path", db_path)
print(f"Database path: {db_path}")

2025-05-24 16:35:59,070 - root - INFO - [KeyValueMemory] Memory cleared.


load json file from /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_tables.json

Loading all database info...
Found 11 databases in bird dataset
Loaded schema for database: california_schools
Schema length: 14313 characters
Foreign keys: 2
Database path: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases/california_schools/california_schools.sqlite


## 2. Set up Agent Memory Callbacks

Define how each agent reads from and writes to memory.

In [4]:
# Schema Selector callbacks
async def schema_selector_reader(memory, task, cancellation_token):
    """Assemble the schema selector's context from shared memory.

    Looks up ``full_database_schema`` and ``database_id`` in ``memory`` and
    returns a dict holding whichever of the two has a truthy value. ``task``
    and ``cancellation_token`` are accepted but not used here.
    """
    print("\n[SCHEMA SELECTOR READER] Starting...")
    gathered = {}

    # The complete schema stored by the loading cell.
    schema_text = await memory.get("full_database_schema")
    if not schema_text:
        print("[SCHEMA SELECTOR READER] No full schema found in memory")
    else:
        gathered["full_database_schema"] = schema_text
        print(f"[SCHEMA SELECTOR READER] Loaded full schema ({len(schema_text)} chars)")

    # The database identifier is optional context; nothing is reported when absent.
    database_name = await memory.get("database_id")
    if database_name:
        gathered["database_id"] = database_name
        print(f"[SCHEMA SELECTOR READER] Database ID: {database_name}")

    print("[SCHEMA SELECTOR READER] Complete\n")
    return gathered

async def schema_selector_parser(memory, task, result, cancellation_token):
    """Persist the schema subset chosen by the schema selector agent.

    Searches the content of the last message in ``result.messages`` for a
    ``<selected_schema>...</selected_schema>`` section and stores it (tags
    included) under ``selected_schema``. When no such section is present, the
    value of ``full_database_schema`` is stored there instead. Nothing is
    written when ``result.messages`` is empty. ``task`` and
    ``cancellation_token`` are unused.
    """
    print("\n[SCHEMA SELECTOR PARSER] Starting...")
    if result.messages:
        reply_text = result.messages[-1].content

        # Non-greedy match with DOTALL so the section may span several lines.
        tagged = re.search(r'<selected_schema>.*?</selected_schema>', reply_text, re.DOTALL)
        if tagged is None:
            # No explicit selection in the reply: fall back to the whole schema.
            everything = await memory.get("full_database_schema")
            await memory.set("selected_schema", everything)
            print("[SCHEMA SELECTOR PARSER] No selection found, using full schema")
        else:
            chosen = tagged.group(0)
            await memory.set("selected_schema", chosen)
            print(f"[SCHEMA SELECTOR PARSER] Stored selected schema ({len(chosen)} chars)")
    print("[SCHEMA SELECTOR PARSER] Complete\n")

# SQL Generator callbacks
async def sql_generator_reader(memory, task, cancellation_token):
    """Assemble the SQL generator's context from shared memory.

    Returns a dict with ``database_schema`` (taken from the ``selected_schema``
    memory key) and ``database_id``, each included only when the stored value
    is truthy. ``task`` and ``cancellation_token`` are unused.
    """
    print("\n[SQL GENERATOR READER] Starting...")
    gathered = {}

    # The schema subset is exposed to the agent under the "database_schema" key.
    schema_subset = await memory.get("selected_schema")
    if not schema_subset:
        print("[SQL GENERATOR READER] No selected schema found")
    else:
        gathered["database_schema"] = schema_subset
        print(f"[SQL GENERATOR READER] Loaded selected schema ({len(schema_subset)} chars)")

    database_name = await memory.get("database_id")
    if database_name:
        gathered["database_id"] = database_name
        print(f"[SQL GENERATOR READER] Database ID: {database_name}")

    print("[SQL GENERATOR READER] Complete\n")
    return gathered

async def sql_generator_parser(memory, task, result, cancellation_token):
    """Store the generated SQL and execute it.

    Extracts the SQL statement from the content of the last message in
    ``result.messages``, saves it under ``generated_sql``, runs it against the
    SQLite file whose path is stored under ``database_path`` and records the
    outcome under ``execution_result``.

    The stored outcome is a dict. On success it holds ``status`` ("success"),
    ``columns``, ``data``, ``row_count`` and ``sql``; on failure it holds
    ``status`` ("error"), ``error`` (the exception text) and ``sql``.
    Execution errors are recorded rather than raised. ``task`` and
    ``cancellation_token`` are unused.
    """
    print("\n[SQL GENERATOR PARSER] Starting...")
    if result.messages:
        last_message = result.messages[-1].content

        # Extract SQL
        sql = extract_sql_from_text(last_message)
        if sql:
            await memory.set("generated_sql", sql)
            print(f"[SQL GENERATOR PARSER] Extracted SQL: {sql}")

            # Execute the SQL immediately
            db_path = await memory.get("database_path")
            if db_path:
                print(f"[SQL GENERATOR PARSER] Executing SQL on database: {db_path}")
                try:
                    conn = sqlite3.connect(db_path)
                    try:
                        cursor = conn.cursor()
                        cursor.execute(sql)
                        results = cursor.fetchall()
                        # cursor.description is None for statements that return no rows.
                        columns = [desc[0] for desc in cursor.description] if cursor.description else []
                    finally:
                        # Always release the connection, including when the
                        # generated SQL is invalid and execute() raises.
                        conn.close()

                    execution_result = {
                        "status": "success",
                        "columns": columns,
                        "data": results,
                        "row_count": len(results),
                        "sql": sql
                    }

                    await memory.set("execution_result", execution_result)
                    print(f"[SQL GENERATOR PARSER] Execution successful - {len(results)} rows returned")

                except Exception as e:
                    # Deliberately broad: any failure while running model-written
                    # SQL is reported back through memory instead of aborting
                    # the workflow.
                    execution_result = {
                        "status": "error",
                        "error": str(e),
                        "sql": sql
                    }
                    await memory.set("execution_result", execution_result)
                    print(f"[SQL GENERATOR PARSER] Execution error: {e}")
            else:
                print("[SQL GENERATOR PARSER] No database path found")
        else:
            print("[SQL GENERATOR PARSER] No SQL found in response")
    print("[SQL GENERATOR PARSER] Complete\n")

# SQL Evaluator callbacks  
async def sql_evaluator_reader(memory, task, cancellation_token):
    """Assemble the SQL evaluator's context from shared memory.

    The returned dict may contain ``sql_query`` (from ``generated_sql``),
    ``execution_result`` and ``database_schema`` (from ``selected_schema``).
    When SQL is available but no execution result has been stored yet, the
    SQL is run through ``SQLExecutor`` against ``database_path`` and the
    outcome, success or error, is written back to memory before being added
    to the context. ``task`` and ``cancellation_token`` are unused.
    """
    print("\n[SQL EVALUATOR READER] Starting...")
    context = {}

    sql_text = await memory.get("generated_sql")
    if sql_text:
        context["sql_query"] = sql_text
        print(f"[SQL EVALUATOR READER] Found SQL: {sql_text}")

    outcome = await memory.get("execution_result")
    if outcome:
        context["execution_result"] = outcome
        print(f"[SQL EVALUATOR READER] Found existing execution result: {outcome['status']}")
    elif sql_text:
        # Nothing has been executed yet: run the SQL here so the evaluator
        # has results to look at.
        database_file = await memory.get("database_path")
        if database_file:
            print(f"[SQL EVALUATOR READER] No execution result found - executing SQL now")
            runner = SQLExecutor(database_file)

            try:
                rows, column_names = runner.execute_sql(sql_text, get_columns=True)

                outcome = {
                    "status": "success",
                    "columns": column_names,
                    "data": rows,
                    "row_count": len(rows),
                    "sql": sql_text
                }

                await memory.set("execution_result", outcome)
                context["execution_result"] = outcome
                print(f"[SQL EVALUATOR READER] Execution successful - {len(rows)} rows returned")

            except Exception as exc:
                outcome = {
                    "status": "error",
                    "error": str(exc),
                    "sql": sql_text
                }
                await memory.set("execution_result", outcome)
                context["execution_result"] = outcome
                print(f"[SQL EVALUATOR READER] Execution error: {exc}")

    # Also provide the selected schema so the evaluator can see which tables
    # and columns the query was written against.
    schema_subset = await memory.get("selected_schema")
    if schema_subset:
        context["database_schema"] = schema_subset
        print(f"[SQL EVALUATOR READER] Added schema context ({len(schema_subset)} chars)")

    print("[SQL EVALUATOR READER] Complete\n")
    return context

async def sql_evaluator_parser(memory, task, result, cancellation_token):
    """Persist the evaluator agent's verdict.

    The content of the last message in ``result.messages`` is stored verbatim
    under ``evaluation_result``. If that text contains a fenced ``json`` block
    that parses, the decoded object is additionally stored under
    ``evaluation_structured``; a failure to extract or parse it is only
    printed. Nothing is written when ``result.messages`` is empty. ``task``
    and ``cancellation_token`` are unused.
    """
    print("\n[SQL EVALUATOR PARSER] Starting...")
    if result.messages:
        evaluation_text = result.messages[-1].content

        # Always keep the raw evaluation text.
        await memory.set("evaluation_result", evaluation_text)
        print("[SQL EVALUATOR PARSER] Stored evaluation text")

        # Best effort: pull the structured verdict out of the fenced block.
        try:
            fenced = re.search(r'```json\s*([\s\S]*?)\s*```', evaluation_text)
            if fenced is not None:
                payload = json.loads(fenced.group(1))
                await memory.set("evaluation_structured", payload)
                print("[SQL EVALUATOR PARSER] Extracted and stored structured evaluation")
        except Exception as exc:
            print(f"[SQL EVALUATOR PARSER] Could not extract structured evaluation: {exc}")
    print("[SQL EVALUATOR PARSER] Complete\n")

## 3. Create Agents with Simple System Messages

In [5]:
# Initialize model client
# A single client instance is shared by every agent created in this notebook.
# A low temperature keeps the generated schema selections and SQL close to
# deterministic.
# NOTE(review): timeout is presumably in seconds, and the API key is presumably
# picked up from the environment populated by load_dotenv() — confirm both.
model_client = OpenAIChatCompletionClient(
    model="gpt-4o",
    temperature=0.1,
    timeout=120
)

# Create agents with simple, focused prompts

# Schema selector: its prompt asks for the answer inside <selected_schema>
# tags, which is exactly what schema_selector_parser searches for.
schema_selector = AssistantAgent(
    name="schema_selector",
    system_message="""You select relevant tables and columns from a database schema for a given query.

When given a full database schema and a user query:
1. Identify which tables are needed to answer the query
2. Select only the relevant columns from those tables
3. Keep all foreign key relationships between selected tables
4. Return your selection wrapped in <selected_schema> tags

Be concise and only include what's necessary for the query.""",
    model_client=model_client
)

# SQL generator: its prompt asks for the query in a ```sql fenced block.
# sql_generator_parser hands the reply to extract_sql_from_text, which
# presumably understands that fence (helper is defined outside this notebook).
sql_generator = AssistantAgent(
    name="sql_generator", 
    system_message="""You generate SQL queries for SQLite databases.

When given a database schema and a user query:
1. Write a valid SQLite query that answers the question
2. Use proper JOIN syntax when multiple tables are needed
3. Return the SQL wrapped in ```sql code blocks

Keep queries simple and correct.""",
    model_client=model_client
)

# SQL evaluator: its prompt asks for a ```json fenced block, which
# sql_evaluator_parser extracts and decodes into "evaluation_structured".
sql_evaluator = AssistantAgent(
    name="sql_evaluator",
    system_message="""You evaluate SQL execution results and provide insights.

When given a SQL query and its execution results:
1. Analyze if the results correctly answer the original question
2. Check if the result set makes sense (not too many/few rows)
3. Identify any potential issues or improvements
4. Provide a brief summary of the findings

Return your evaluation in ```json format with:
{
  "answers_question": true/false,
  "result_quality": "good/acceptable/poor", 
  "summary": "Brief description of the results",
  "suggestions": ["Any improvements or issues found"]
}""",
    model_client=model_client
)

# Wrap agents with memory tools
# Each wrapper pairs one agent with the shared `memory` object and the
# reader/parser callbacks defined earlier, so the agent can be handed to the
# coordinator as a tool.
# NOTE(review): judging by the logged output, the reader callback runs before
# the agent is invoked and the parser callback runs on its result — confirm
# against MemoryAgentTool.
schema_selector_tool = MemoryAgentTool(
    agent=schema_selector,
    memory=memory,
    reader_callback=schema_selector_reader,
    parser_callback=schema_selector_parser
)

sql_generator_tool = MemoryAgentTool(
    agent=sql_generator,
    memory=memory,
    reader_callback=sql_generator_reader,
    parser_callback=sql_generator_parser
)

sql_evaluator_tool = MemoryAgentTool(
    agent=sql_evaluator,
    memory=memory,
    reader_callback=sql_evaluator_reader,
    parser_callback=sql_evaluator_parser
)

## 4. Create Coordinator Agent

In [6]:
# Import team components
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat

# Create coordinator that orchestrates the workflow
# The coordinator never touches memory directly: it only decides which of the
# three memory-backed tools to call next. Its prompt instructs it to finish
# with the word "TERMINATE", which is the term the team's
# TextMentionTermination condition watches for.
coordinator = AssistantAgent(
    name="coordinator",
    system_message="""You coordinate a text-to-SQL workflow to generate correct SQL queries.

Your tools are:
1. schema_selector - Selects relevant schema parts for a query
2. sql_generator - Generates SQL from the selected schema (also executes it)
3. sql_evaluator - Evaluates the execution results and provides insights

Your goal is to generate a CORRECT SQL query that properly answers the user's question.

Workflow:
1. Start by calling schema_selector to identify relevant tables/columns
2. Call sql_generator to create and execute SQL
3. Call sql_evaluator to check if the results are correct
4. If the evaluator indicates issues or the SQL is incorrect:
   - You may call schema_selector again to refine the schema selection
   - Call sql_generator again with better guidance
   - Continue iterating until you have correct SQL
5. Once you have correct SQL with good results, summarize the final answer and say "TERMINATE"

IMPORTANT: You must say "TERMINATE" only after you have successfully generated correct SQL and summarized the results.""",
    model_client=model_client,
    tools=[schema_selector_tool, sql_generator_tool, sql_evaluator_tool]
)

## 5. Create a Team with Termination Condition

We'll use a team-based approach where the coordinator continues working until it says "TERMINATE".

In [7]:
# Create a team with termination condition
# The run stops once a message mentions "TERMINATE", the word the
# coordinator's system message tells it to say when it is done.
termination_condition = TextMentionTermination("TERMINATE")

# Create a team with just the coordinator
# With a single participant, the round-robin chat keeps giving the turn back
# to the coordinator until the termination condition fires.
# NOTE(review): no maximum-turn limit is configured here, so a coordinator
# that never says "TERMINATE" would presumably keep running — confirm.
team = RoundRobinGroupChat(
    participants=[coordinator],
    termination_condition=termination_condition
)

## 6. Test with a Simple Query

Let's test the workflow with a simple query about California schools.

In [8]:
# Simple query about California schools
simple_query = "What are the top 5 schools with the highest average math SAT scores?"

# Store the query in memory for reference
# NOTE(review): none of the reader callbacks defined above read "user_query";
# within this notebook it is only displayed by the memory-inspection cell.
await memory.set("user_query", simple_query)

print(f"Processing query: {simple_query}")
print("-" * 60)

# Run the workflow using the team
stream = team.run_stream(task=simple_query)

# Stream and display messages as they're generated
# Items without a `source` attribute are labelled 'Unknown'; in the recorded
# output this is the final item, which holds the full message list of the run.
async for message in stream:
    print(f"\n[{getattr(message, 'source', 'Unknown')}]:")
    print(message)
    print("-" * 40)

print("\n" + "="*60)
print("WORKFLOW COMPLETE")
print("="*60)

Processing query: What are the top 5 schools with the highest average math SAT scores?
------------------------------------------------------------

[user]:
source='user' models_usage=None metadata={} content='What are the top 5 schools with the highest average math SAT scores?' type='TextMessage'
----------------------------------------


2025-05-24 16:36:13,173 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[SCHEMA SELECTOR READER] Starting...
[SCHEMA SELECTOR READER] Loaded full schema (14313 chars)
[SCHEMA SELECTOR READER] Database ID: california_schools
[SCHEMA SELECTOR READER] Complete


[coordinator]:
source='coordinator' models_usage=RequestUsage(prompt_tokens=338, completion_tokens=27) metadata={} content=[FunctionCall(id='call_GTtwEiiM8Tavu6wuwVC9NwK5', arguments='{"task":"Find the top 5 schools with the highest average math SAT scores."}', name='schema_selector')] type='ToolCallRequestEvent'
----------------------------------------


2025-05-24 16:36:14,811 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[SCHEMA SELECTOR PARSER] Starting...
[SCHEMA SELECTOR PARSER] Stored selected schema (473 chars)
[SCHEMA SELECTOR PARSER] Complete


[coordinator]:
source='coordinator' models_usage=None metadata={} content=[FunctionExecutionResult(content='{"messages": [{"source": "user", "models_usage": null, "metadata": {}}, {"source": "schema_selector", "models_usage": {"prompt_tokens": 4045, "completion_tokens": 137}, "metadata": {}}], "stop_reason": null}', name='schema_selector', call_id='call_GTtwEiiM8Tavu6wuwVC9NwK5', is_error=False)] type='ToolCallExecutionEvent'
----------------------------------------

[coordinator]:
source='coordinator' models_usage=None metadata={} content='{"messages": [{"source": "user", "models_usage": null, "metadata": {}}, {"source": "schema_selector", "models_usage": {"prompt_tokens": 4045, "completion_tokens": 137}, "metadata": {}}], "stop_reason": null}' type='ToolCallSummaryMessage'
----------------------------------------


2025-05-24 16:36:16,042 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[SQL GENERATOR READER] Starting...
[SQL GENERATOR READER] Loaded selected schema (473 chars)
[SQL GENERATOR READER] Database ID: california_schools
[SQL GENERATOR READER] Complete


[coordinator]:
source='coordinator' models_usage=RequestUsage(prompt_tokens=431, completion_tokens=33) metadata={} content=[FunctionCall(id='call_AWiWbfafXGzdFrjffiJpdjHP', arguments='{"task":"Find the top 5 schools with the highest average math SAT scores using the relevant schema parts identified."}', name='sql_generator')] type='ToolCallRequestEvent'
----------------------------------------


2025-05-24 16:36:17,612 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[SQL GENERATOR PARSER] Starting...
[SQL GENERATOR PARSER] Extracted SQL: SELECT schools.School, satscores.AvgScrMath
FROM satscores
JOIN schools ON satscores.cds = schools.CDSCode
ORDER BY satscores.AvgScrMath DESC
LIMIT 5;
[SQL GENERATOR PARSER] Executing SQL on database: /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_databases/california_schools/california_schools.sqlite
[SQL GENERATOR PARSER] Execution successful - 5 rows returned
[SQL GENERATOR PARSER] Complete


[coordinator]:
source='coordinator' models_usage=None metadata={} content=[FunctionExecutionResult(content='{"messages": [{"source": "user", "models_usage": null, "metadata": {}}, {"source": "sql_generator", "models_usage": {"prompt_tokens": 252, "completion_tokens": 126}, "metadata": {}}], "stop_reason": null}', name='sql_generator', call_id='call_AWiWbfafXGzdFrjffiJpdjHP', is_error=False)] type='ToolCallExecutionEvent'
----------------------------------------

[coordinator]:
source='coordinator' models_usage=None m

2025-05-24 16:36:18,702 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[SQL EVALUATOR READER] Starting...
[SQL EVALUATOR READER] Found SQL: SELECT schools.School, satscores.AvgScrMath
FROM satscores
JOIN schools ON satscores.cds = schools.CDSCode
ORDER BY satscores.AvgScrMath DESC
LIMIT 5;
[SQL EVALUATOR READER] Found existing execution result: success
[SQL EVALUATOR READER] Added schema context (473 chars)
[SQL EVALUATOR READER] Complete


[coordinator]:
source='coordinator' models_usage=RequestUsage(prompt_tokens=529, completion_tokens=36) metadata={} content=[FunctionCall(id='call_CtwItOydPu941hr0Nmyae79i', arguments='{"task":"Evaluate the results of the SQL query that finds the top 5 schools with the highest average math SAT scores."}', name='sql_evaluator')] type='ToolCallRequestEvent'
----------------------------------------


2025-05-24 16:36:19,726 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[SQL EVALUATOR PARSER] Starting...
[SQL EVALUATOR PARSER] Stored evaluation text
[SQL EVALUATOR PARSER] Extracted and stored structured evaluation
[SQL EVALUATOR PARSER] Complete


[coordinator]:
source='coordinator' models_usage=None metadata={} content=[FunctionExecutionResult(content='{"messages": [{"source": "user", "models_usage": null, "metadata": {}}, {"source": "sql_evaluator", "models_usage": {"prompt_tokens": 480, "completion_tokens": 55}, "metadata": {}}], "stop_reason": null}', name='sql_evaluator', call_id='call_CtwItOydPu941hr0Nmyae79i', is_error=False)] type='ToolCallExecutionEvent'
----------------------------------------

[coordinator]:
source='coordinator' models_usage=None metadata={} content='{"messages": [{"source": "user", "models_usage": null, "metadata": {}}, {"source": "sql_evaluator", "models_usage": {"prompt_tokens": 480, "completion_tokens": 55}, "metadata": {}}], "stop_reason": null}' type='ToolCallSummaryMessage'
----------------------------------------


2025-05-24 16:36:20,955 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



[coordinator]:
source='coordinator' models_usage=RequestUsage(prompt_tokens=632, completion_tokens=59) metadata={} content='The top 5 schools with the highest average math SAT scores are:\n\n1. School A\n2. School B\n3. School C\n4. School D\n5. School E\n\nThese schools have the highest average scores in the math section of the SAT. \n\nTERMINATE' type='TextMessage'
----------------------------------------

[Unknown]:
messages=[TextMessage(source='user', models_usage=None, metadata={}, content='What are the top 5 schools with the highest average math SAT scores?', type='TextMessage'), ToolCallRequestEvent(source='coordinator', models_usage=RequestUsage(prompt_tokens=338, completion_tokens=27), metadata={}, content=[FunctionCall(id='call_GTtwEiiM8Tavu6wuwVC9NwK5', arguments='{"task":"Find the top 5 schools with the highest average math SAT scores."}', name='schema_selector')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='coordinator', models_usage=None, metadata={}, co

In [9]:
# Check memory contents after workflow
print("Memory Contents After Workflow:")
print("-" * 60)

# Check all the keys we expect to be set
memory_keys = [
    "user_query",
    "full_database_schema", 
    "selected_schema",
    "generated_sql",
    "execution_result",
    "evaluation_result",
    "evaluation_structured"
]

for key in memory_keys:
    value = await memory.get(key)
    if value:
        print(f"\n{key.upper()}:")
        if isinstance(value, str):
            # For strings, show first 300 chars
            preview = value[:300] + "..." if len(value) > 300 else value
            print(preview)
        elif isinstance(value, dict):
            # For dicts, show structure
            print(f"  Type: dict with keys: {list(value.keys())}")
            if 'status' in value:
                print(f"  Status: {value['status']}")
            if 'row_count' in value:
                print(f"  Rows: {value['row_count']}")
            if 'sql' in value:
                print(f"  SQL: {value['sql']}")
            if 'data' in value and value['data']:
                print(f"  Sample data (first row): {value['data'][0] if value['data'] else 'No data'}")
        else:
            print(f"  Type: {type(value).__name__}")
            print(f"  Value: {str(value)[:200]}...")
    else:
        print(f"\n{key.upper()}: Not found in memory")

Memory Contents After Workflow:
------------------------------------------------------------

USER_QUERY:
What are the top 5 schools with the highest average math SAT scores?

FULL_DATABASE_SCHEMA:
<database_schema>
  <table name="frpm">
    <column name="CDSCode">
      <description>CDSCode</description>
    </column>
    <column name="Academic Year">
      <description>Academic Year</description>
      <values>['2014-2015']</values>
    </column>
    <column name="County Code">
      <descri...

SELECTED_SCHEMA:
<selected_schema>
  <table name="satscores">
    <column name="cds"/>
    <column name="sname"/>
    <column name="AvgScrMath"/>
  </table>
  <table name="schools">
    <column name="CDSCode"/>
    <column name="School"/>
  </table>
  <foreign_keys>
    <foreign_key>
      <from_table>satscores</fro...

GENERATED_SQL:
SELECT schools.School, satscores.AvgScrMath
FROM satscores
JOIN schools ON satscores.cds = schools.CDSCode
ORDER BY satscores.AvgScrMath DESC
LIMIT 5;

EXECUTIO