# Memory Agent Tool Test

This notebook demonstrates how to use the `MemoryAgentTool` with our `KeyValueMemory` implementation to create agents that can read from and write to a shared memory store.

The `MemoryAgentTool` class provides:
1. **Memory integration** for agents by extending BaseTool
2. **Pre-processing** via reader callbacks that fetch relevant context from memory before agent execution
3. **Post-processing** via parser callbacks that extract and store information from agent outputs after execution
4. A structured way to pass information between steps in the text-to-SQL workflow

In [1]:
import asyncio
import json
import logging
import re
from typing import Dict, Any, List, Optional

# Set up logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Reduce noise from autogen
logging.getLogger('autogen_core').setLevel(logging.WARNING)

In [2]:
# Import our KeyValueMemory and MemoryAgentTool
from memory import KeyValueMemory
from memory_agent_tool import MemoryAgentTool
from workflow_utils import extract_sql_from_text

# Import necessary AutoGen components
from autogen_core import CancellationToken
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_ext.models.openai import OpenAIChatCompletionClient

## 1. Set up our memory store and model client

In [3]:
# Initialize the shared memory store
memory = KeyValueMemory(name="text_to_sql_memory")

# Set up the model client - replace with your specific model
model_client = OpenAIChatCompletionClient(
    model="gpt-4o",  # or other appropriate model
    temperature=0.1,
    timeout=120
)

## 2. Define memory callback functions for different agent roles

Each agent will have custom reader and parser functions that define how it interacts with memory.

In [4]:
# Schema Selector Agent memory callbacks
async def schema_selector_reader(memory, task, cancellation_token):
    """Read relevant info for the schema selector agent."""
    # For schema selection, we might want to read previously selected schemas
    # for similar queries to maintain consistency
    query_history = await memory.get("query_history")
    context = {}
    
    if query_history:
        context["query_history"] = query_history
        
    # We might also want to read any schema preferences
    schema_preferences = await memory.get("schema_preferences")
    if schema_preferences:
        context["schema_preferences"] = schema_preferences
    
    return context

async def schema_selector_parser(memory, task, result, cancellation_token):
    """Parse and store schema selection results."""
    if result.messages:
        last_message = result.messages[-1].content
        
        # Look for database schema in XML format
        schema_match = re.search(r'<database_schema>.*?</database_schema>', last_message, re.DOTALL)
        if schema_match:
            schema_str = schema_match.group()
            await memory.set("current_schema", schema_str)
            print("Stored schema in memory")
            
        # Try to extract database ID
        db_id_match = re.search(r'"db_id"\s*:\s*"([^"]+)"', task, re.DOTALL)
        if db_id_match:
            db_id = db_id_match.group(1)
            await memory.set("current_db_id", db_id)
            print(f"Stored current_db_id: {db_id}")
            
        # Store the task and response for history
        query_history = await memory.get("query_history") or []
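        # The task may be a JSON string or plain text; fall back to the raw task
        try:
            task_json = json.loads(task) if isinstance(task, str) else {}
        except json.JSONDecodeError:
            task_json = {}
        query = task_json.get("query", task) if isinstance(task_json, dict) else task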
        
        query_history.append({"query": query, "role": "selector"})
        await memory.set("query_history", query_history)

# SQL Generator Agent memory callbacks
async def sql_generator_reader(memory, task, cancellation_token):
    """Read relevant info for the SQL generator agent."""
    context = {}
    
    # Critical: Read the database schema
    schema = await memory.get("current_schema")
    if schema:
        context["schema"] = schema
    
    # Read database ID
    db_id = await memory.get("current_db_id")
    if db_id:
        context["db_id"] = db_id
    
    # Read previously generated SQL for similar queries
    query_history = await memory.get("query_history")
    if query_history:
        sql_history = [item for item in query_history if "sql" in item]
        if sql_history:
            context["sql_history"] = sql_history
    
    return context

async def sql_generator_parser(memory, task, result, cancellation_token):
    """Parse and store SQL generation results."""
    if result.messages:
        last_message = result.messages[-1].content
        
        # Extract SQL query from the response
        sql = extract_sql_from_text(last_message)
        if sql:
            await memory.set("current_sql", sql)
            print(f"Stored SQL in memory: {sql[:50]}...")
            
            # Update query history
            query_history = await memory.get("query_history") or []
            # Find the current query (should be the last one without SQL)
            for item in reversed(query_history):
                if "sql" not in item:
                    item["sql"] = sql
                    break
            await memory.set("query_history", query_history)

# SQL Executor Agent memory callbacks
async def sql_executor_reader(memory, task, cancellation_token):
    """Read relevant info for the SQL executor agent."""
    context = {}
    
    # Read the database schema
    schema = await memory.get("current_schema")
    if schema:
        context["schema"] = schema
    
    # Read database ID
    db_id = await memory.get("current_db_id")
    if db_id:
        context["db_id"] = db_id
    
    # Read the SQL to execute
    sql = await memory.get("current_sql")
    if sql:
        context["sql"] = sql
    
    return context

async def sql_executor_parser(memory, task, result, cancellation_token):
    """Parse and store SQL execution results."""
    if result.messages:
        last_message = result.messages[-1].content
        
        # Try to extract execution results from JSON format
        try:
            # Look for JSON in the response
            json_match = re.search(r'```json\s*([\s\S]*?)\s*```', last_message)
            if json_match:
                json_str = json_match.group(1)
                execution_result = json.loads(json_str)
                await memory.set("execution_result", execution_result)
                print("Stored execution result in memory")
                
                # Extract the status
                status = execution_result.get("status", "UNKNOWN")
                await memory.set("execution_status", status)
                
                # Update query history
                query_history = await memory.get("query_history") or []
                for item in reversed(query_history):
                    if "sql" in item and "execution_result" not in item:
                        item["execution_result"] = execution_result
                        item["execution_status"] = status
                        break
                await memory.set("query_history", query_history)
        except (json.JSONDecodeError, AttributeError):
            # If the JSON is malformed or is not an object, store the raw result
            await memory.set("execution_result_raw", last_message)
            print("Stored raw execution result in memory")
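
The JSON extraction step in `sql_executor_parser` can be checked in isolation, without a model call. The sketch below is a hedged illustration rather than part of the notebook's modules: `extract_json_block` is a hypothetical helper that applies the same regular expression and returns `None` both when no JSON block is present and when the block is malformed. Note that the parser above stores nothing when the reply has no JSON block at all, because the raw fallback only runs when an exception is raised.

```python
import json
import re

# Build the fence marker programmatically so this example nests cleanly in Markdown
FENCE = "`" * 3
JSON_BLOCK = re.compile(FENCE + r"json\s*([\s\S]*?)\s*" + FENCE)


def extract_json_block(text):
    """Return the parsed first json-fenced block, or None if absent or invalid."""
    match = JSON_BLOCK.search(text)
    if not match:
        return None
    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError:
        return None


# Sample replies (made up for illustration)
payload = {"status": "SUCCESS", "final_sql": "SELECT 1", "execution_result": [[1]]}
good_reply = f"Query ran.\n{FENCE}json\n{json.dumps(payload)}\n{FENCE}"
bad_reply = f"{FENCE}json\n{{not valid json\n{FENCE}"
plain_reply = "No structured output here."

parsed = extract_json_block(good_reply)
print(parsed["status"] if parsed else "no JSON found")
```

A helper like this makes the three cases (valid block, malformed block, no block) explicit, so the caller can decide what to store for each.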

## 3. Create our Agents with System Messages

In [5]:
# Define system messages for each agent
SCHEMA_SELECTOR_SYSTEM_MESSAGE = """
You are a database schema selector agent. Your role is to:
1. Analyze the natural language query
2. Extract relevant schema parts from the database
3. Return the selected schema in XML format wrapped in <database_schema> tags

Be precise and focus only on tables and columns directly related to the query.
"""

SQL_GENERATOR_SYSTEM_MESSAGE = """
You are an advanced SQL query generator. Your role is to:
1. Read the database schema provided in XML format
2. Analyze the natural language query
3. Generate a valid SQL query that answers the question
4. Return your SQL query inside ```sql code blocks

Make sure your SQL is valid for SQLite and follows standard SQL best practices.
"""

SQL_EXECUTOR_SYSTEM_MESSAGE = """
You are a SQL execution agent. Your role is to:
1. Execute the provided SQL query against the database
2. Verify the results are correct and complete
3. Handle any errors or refinements needed
4. Return the execution results in JSON format inside ```json code blocks

The JSON must include: status, final_sql, and execution_result fields.
"""

# Create the agents
schema_selector = AssistantAgent(
    name="schema_selector",
    system_message=SCHEMA_SELECTOR_SYSTEM_MESSAGE,
    model_client=model_client,
    description="Selects relevant parts of the database schema for a query"
)

sql_generator = AssistantAgent(
    name="sql_generator",
    system_message=SQL_GENERATOR_SYSTEM_MESSAGE,
    model_client=model_client,
    description="Generates SQL queries from natural language"
)

sql_executor = AssistantAgent(
    name="sql_executor",
    system_message=SQL_EXECUTOR_SYSTEM_MESSAGE,
    model_client=model_client,
    description="Executes SQL queries and returns results"
)

# Wrap each agent with memory capabilities
schema_selector_tool = MemoryAgentTool(
    agent=schema_selector,
    memory=memory,
    reader_callback=schema_selector_reader,
    parser_callback=schema_selector_parser
)

sql_generator_tool = MemoryAgentTool(
    agent=sql_generator,
    memory=memory,
    reader_callback=sql_generator_reader,
    parser_callback=sql_generator_parser
)

sql_executor_tool = MemoryAgentTool(
    agent=sql_executor,
    memory=memory,
    reader_callback=sql_executor_reader,
    parser_callback=sql_executor_parser
)

## How MemoryAgentTool Works

The `MemoryAgentTool` is built on top of AutoGen's `BaseTool` and integrates with our `KeyValueMemory` implementation. It works as follows:

1. **Initialization**: Each tool is created with:
   - An agent (the core component that performs the task)
   - A shared memory instance (used across multiple agents)
   - Custom reader and parser callbacks (specific to each agent's role)

2. **Memory Reader Callback**: This function is called before the agent runs and reads relevant information from memory. It returns a dictionary of context that gets added to the agent's task.

3. **Memory Parser Callback**: This function is called after the agent completes its task and extracts relevant information from the agent's response to store in memory.

4. **Memory Flow**: 
   - Selector agent selects schema and stores it in memory
   - Generator agent reads the schema from memory and creates SQL
   - Executor agent reads the SQL from memory and executes it
   - Each agent updates the memory with its results

This approach allows each agent to focus on its specific task while the memory provides a persistent state between steps.
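
The read, run, parse sequence described above can be sketched without AutoGen. This is a minimal illustration under stated assumptions, not the actual `MemoryAgentTool` implementation: `DictMemory`, `run_with_memory`, and `fake_agent` are hypothetical stand-ins, and the way the context is appended to the task is a guess at one reasonable format.

```python
import asyncio
from types import SimpleNamespace


class DictMemory:
    """Hypothetical stand-in for KeyValueMemory: an async wrapper around a dict."""

    def __init__(self):
        self._data = {}

    async def get(self, key):
        return self._data.get(key)

    async def set(self, key, value):
        self._data[key] = value


async def run_with_memory(agent_run, memory, task, reader, parser):
    """Read context, run the agent on the enriched task, then parse the result."""
    context = await reader(memory, task, None)
    # Assumed format: append the context to the task as plain text
    enriched = task if not context else f"{task}\n\nMemory context: {context}"
    result = await agent_run(enriched)
    await parser(memory, task, result, None)
    return result


async def demo():
    memory = DictMemory()
    await memory.set("current_schema", "<database_schema>patients</database_schema>")

    async def reader(memory, task, cancellation_token):
        schema = await memory.get("current_schema")
        return {"schema": schema} if schema else {}

    async def parser(memory, task, result, cancellation_token):
        await memory.set("last_reply", result.messages[-1].content)

    async def fake_agent(task):
        # Echoes the task so the injected context is visible in the reply
        return SimpleNamespace(messages=[SimpleNamespace(content=f"saw: {task}")])

    await run_with_memory(fake_agent, memory, "List patients", reader, parser)
    return await memory.get("last_reply")


last_reply = asyncio.run(demo())
print(last_reply)
```

Inside a notebook, where an event loop is already running, call `await demo()` instead of `asyncio.run(demo())`.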

## 4. Create a Coordinator with the Memory-Enabled Tools

In [6]:
# Create a coordinating agent with the memory-enabled tools
coordinator = AssistantAgent(
    name="text_to_sql_coordinator",
    system_message="""You are a text-to-SQL workflow coordinator. You have access to three specialized tools:
1. schema_selector: Analyzes the query and selects relevant parts of the database schema
2. sql_generator: Generates SQL queries based on the schema and natural language query
3. sql_executor: Executes the SQL and returns the results

For any natural language query about a database, you should:
1. First use the schema_selector tool to get the relevant schema
2. Then use the sql_generator tool to create a SQL query
3. Finally use the sql_executor tool to run the query and get results
4. Summarize the results in a user-friendly way

These tools share a memory system so relevant information will be passed between them automatically.
""",
    model_client=model_client,
    tools=[schema_selector_tool, sql_generator_tool, sql_executor_tool],
    description="Coordinates the text-to-SQL workflow using memory-enabled agent tools"
)

## 5. Reset Memory and Set Initial Database Information

In [7]:
# Reset memory for a fresh run
await memory.clear()

# Set preferences that will be used by the schema selector
await memory.set("schema_preferences", {
    "max_tables": 5,
    "include_foreign_keys": True,
    "include_examples": True
})

# Initialize empty query history
await memory.set("query_history", [])

2025-05-20 12:45:36,038 - root - INFO - [KeyValueMemory] Memory cleared.


TypeError: Expected dict or Image instance, got <class 'list'>
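
The traceback above suggests that the `memory.set("query_history", [])` call was rejected because the value is a bare list. The `KeyValueMemory` source is not shown here, so the following is only a sketch of one possible workaround, assuming the store accepts dict values: wrap the list in a dict on write and unwrap it on read. `StrictDictMemory`, `set_list`, and `get_list` are hypothetical names used for illustration.

```python
import asyncio


class StrictDictMemory:
    """Hypothetical store that rejects non-dict values, mimicking the traceback."""

    def __init__(self):
        self._data = {}

    async def set(self, key, value):
        if not isinstance(value, dict):
            raise TypeError(f"Expected dict, got {type(value)}")
        self._data[key] = value

    async def get(self, key):
        return self._data.get(key)


async def set_list(memory, key, items):
    """Store a list by wrapping it in a dict envelope."""
    await memory.set(key, {"items": list(items)})


async def get_list(memory, key):
    """Read a wrapped list back; a missing key yields an empty list."""
    stored = await memory.get(key)
    return list(stored["items"]) if stored else []


async def demo():
    memory = StrictDictMemory()
    await set_list(memory, "query_history", [])
    history = await get_list(memory, "query_history")
    history.append({"query": "example", "role": "selector"})
    await set_list(memory, "query_history", history)
    return await get_list(memory, "query_history")


history = asyncio.run(demo())
print(history)
```

If this assumption holds, the reader and parser callbacks would need to use the same wrapping wherever they read or write `query_history`.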

## 6. Run a Text-to-SQL Workflow

Let's test the workflow with a sample query from the BIRD dataset.

In [None]:
# Create a sample task in the expected format
sample_task = json.dumps({
    "db_id": "hospital_1",
    "query": "What are the names of patients who have been admitted more than twice?",
    "evidence": "The hospital_1 database tracks patient admissions. Each patient can have multiple admissions over time."
})

# Run the workflow
result = await coordinator.run(task=sample_task)

# Print the final response
for message in result.messages:
    if message.source != "user":
        print(f"\n{message.source}: {message.content}")

## 7. Examine Memory After Execution

In [None]:
# Check what's in memory after running the workflow
schema = await memory.get("current_schema")
print(f"Schema in memory: {'Yes, length=' + str(len(schema)) if schema else 'No'}")

sql = await memory.get("current_sql")
print(f"SQL in memory: {'Yes - ' + sql[:50] + '...' if sql else 'No'}")

execution_result = await memory.get("execution_result")
print(f"Execution result in memory: {'Yes' if execution_result else 'No'}")

status = await memory.get("execution_status")
print(f"Execution status: {status if status else 'Unknown'}")

# Print query history
history = await memory.get("query_history")
print(f"\nQuery history entries: {len(history) if history else 0}")
if history:
    for i, entry in enumerate(history):
        print(f"\nEntry {i+1}:")
        print(f"Query: {entry.get('query')}")
        print(f"Role: {entry.get('role', 'unknown')}")
        if 'sql' in entry:
            sql = entry['sql']
            print(f"SQL: {sql[:50]}{'...' if len(sql) > 50 else ''}")
        if 'execution_status' in entry:
            print(f"Status: {entry['execution_status']}")

## 8. Running Another Query Using Existing Memory

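Let's run another query on the same database. Because memory is not cleared between runs, the reader callbacks pick up state from the first run: `query_history` for the schema selector and, if SQL was stored, `sql_history` for the SQL generator.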
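
To make concrete what a follow-up run can see, here is a small standalone sketch of the history logic used by the callbacks: SQL is attached to the most recent entry without one, and the generator's reader only surfaces entries that already have SQL. The helper names `attach_sql` and `sql_history` and the sample SQL string are hypothetical and used only for illustration.

```python
def attach_sql(query_history, sql):
    """Attach sql to the most recent entry that has none; return True if attached."""
    for item in reversed(query_history):
        if "sql" not in item:
            item["sql"] = sql
            return True
    return False


def sql_history(query_history):
    """Return only the entries that already carry generated SQL."""
    return [item for item in query_history if "sql" in item]


# First query: the selector records it, then the generator attaches SQL
history = [{"query": "Patients admitted more than twice?", "role": "selector"}]
attach_sql(history, "SELECT name FROM patient")  # placeholder SQL

# Second query: recorded by the selector, no SQL yet
history.append({"query": "Average length of stay over 65?", "role": "selector"})

# What the SQL generator's reader would surface for the second query
visible = sql_history(history)
print(visible)
```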

In [None]:
# Create a follow-up task
follow_up_task = json.dumps({
    "db_id": "hospital_1",
    "query": "What is the average length of stay for patients over the age of 65?",
    "evidence": "Admission records include admission and discharge dates."
})

# Run the workflow
result = await coordinator.run(task=follow_up_task)

# Print the final response
for message in result.messages:
    if message.source != "user":
        print(f"\n{message.source}: {message.content}")

## 9. Check Updated Query History

## 10. Integrating with Workflow Runners

To fully leverage the MemoryAgentTool in your text-to-SQL pipeline, you can integrate it with the existing workflow runners. Here's a simplified example of how to adapt the workflow runner to use memory-enabled agents:

In [None]:
class MemoryWorkflowRunner:
    """
    A workflow runner that uses memory-enabled agent tools for text-to-SQL processing.
    
    This class manages the creation and execution of memory-enabled agent tools
    for the text-to-SQL pipeline, maintaining state between steps and across
    multiple queries.
    """
    
    def __init__(
        self,
        model_client,
        memory=None,
        timeout: int = 120
    ):
        """Initialize the memory workflow runner."""
        self.model_client = model_client
        self.memory = memory or KeyValueMemory(name="text_to_sql_memory")
        self.timeout = timeout
        
        # Create the agents and tools
        self._create_agents()
    
    def _create_agents(self):
        """Create memory-enabled agents and tools."""
        # Create the base agents with system messages
        self.selector_agent = AssistantAgent(
            name="schema_selector",
            system_message=SCHEMA_SELECTOR_SYSTEM_MESSAGE,
            model_client=self.model_client
        )
        
        self.generator_agent = AssistantAgent(
            name="sql_generator",
            system_message=SQL_GENERATOR_SYSTEM_MESSAGE,
            model_client=self.model_client
        )
        
        self.executor_agent = AssistantAgent(
            name="sql_executor",
            system_message=SQL_EXECUTOR_SYSTEM_MESSAGE,
            model_client=self.model_client
        )
        
        # Create memory-enabled tools
        self.selector_tool = MemoryAgentTool(
            agent=self.selector_agent,
            memory=self.memory,
            reader_callback=schema_selector_reader,
            parser_callback=schema_selector_parser
        )
        
        self.generator_tool = MemoryAgentTool(
            agent=self.generator_agent,
            memory=self.memory,
            reader_callback=sql_generator_reader,
            parser_callback=sql_generator_parser
        )
        
        self.executor_tool = MemoryAgentTool(
            agent=self.executor_agent,
            memory=self.memory,
            reader_callback=sql_executor_reader,
            parser_callback=sql_executor_parser
        )
        
        # Create coordinator
        self.coordinator = AssistantAgent(
            name="text_to_sql_coordinator",
            system_message="""You are a text-to-SQL workflow coordinator. Follow these steps:
            1. Use schema_selector to get the database schema
            2. Use sql_generator to create the SQL query
            3. Use sql_executor to execute and verify the query
            4. Summarize the results in a user-friendly way""",
            model_client=self.model_client,
            tools=[self.selector_tool, self.generator_tool, self.executor_tool]
        )
    
    async def initialize_memory(self):
        """Initialize the memory for a new session."""
        await self.memory.clear()
        await self.memory.set("query_history", [])
        await self.memory.set("schema_preferences", {
            "max_tables": 5,
            "include_foreign_keys": True,
            "include_examples": True
        })
    
    async def run_workflow(self, task_json):
        """Run the text-to-SQL workflow on a task."""
        # Convert dict to JSON string if needed
        if isinstance(task_json, dict):
            task_json = json.dumps(task_json)
            
        # Create cancellation token
        cancellation_token = CancellationToken()
        
        # Run the coordinator
        result = await self.coordinator.run(
            task=task_json, 
            cancellation_token=cancellation_token
        )
        
        # Collect final results from memory
        final_result = {
            "sql": await self.memory.get("current_sql"),
            "execution_result": await self.memory.get("execution_result"),
            "status": await self.memory.get("execution_status")
        }
        
        return result, final_result

## 11. Demo of the MemoryWorkflowRunner

Let's test our MemoryWorkflowRunner with a sample task:

In [None]:
# Create a memory workflow runner
runner = MemoryWorkflowRunner(model_client=model_client)

# Initialize memory
await runner.initialize_memory()

# Define a sample task
sample_task = {
    "db_id": "restaurant_1",
    "query": "What is the average price of items in the breakfast menu?",
    "evidence": "The restaurant_1 database contains menu items organized by category."
}

# Run the workflow
print("Running workflow...")
result, final_result = await runner.run_workflow(sample_task)

# Display the results
print("\nCoordinator output:")
for message in result.messages:
    if message.source != "user":
        print(f"\n{message.source}: {message.content}")
        
print("\nFinal result from memory:")
print(f"SQL: {final_result['sql']}")
print(f"Status: {final_result['status']}")
if final_result['execution_result']:
    print("Execution result available in memory")

## 12. Best Practices for Using Memory-Enabled Agent Tools

When implementing memory-enabled agent tools in your workflow, follow these best practices:

1. **Clearly define memory schemas** - Decide what information should be stored in memory and use consistent key names across your reader/parser functions

2. **Keep reader callbacks focused** - Each reader should only fetch the information relevant to its specific agent's task

3. **Write robust parser callbacks** - Handle different response formats and edge cases in your parser functions

4. **Format memory context properly** - When adding memory context to a task, format it clearly so the agent can understand it

5. **Use a shared memory instance** - Ensure all agents use the same memory instance to maintain state between steps

6. **Check memory before queries** - Look for existing information in memory to avoid redundant work

7. **Update query history** - Maintain a record of query history for future reference and learning

8. **Add metadata to memory entries** - Include timestamps and other metadata with memory entries for better tracking

9. **Implement error handling** - Add proper error handling in reader and parser callbacks to ensure workflow continuity

10. **Test with diverse queries** - Verify memory capabilities with a variety of query types and complexity levels
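
Practices 1, 8, and 9 can be sketched together in a few lines. This is a hedged illustration, not part of `MemoryAgentTool`: `Keys`, `with_metadata`, and `safe_reader` are hypothetical helpers, and the metadata fields are one possible choice.

```python
import asyncio
import logging
from datetime import datetime, timezone
from functools import wraps


class Keys:
    """Practice 1: keep memory key names in one place."""

    SCHEMA = "current_schema"
    SQL = "current_sql"
    HISTORY = "query_history"


def with_metadata(value, source):
    """Practice 8: wrap a value with its source and a UTC timestamp."""
    return {
        "value": value,
        "source": source,
        "stored_at": datetime.now(timezone.utc).isoformat(),
    }


def safe_reader(reader):
    """Practice 9: log reader failures and fall back to an empty context."""

    @wraps(reader)
    async def wrapper(memory, task, cancellation_token):
        try:
            return await reader(memory, task, cancellation_token)
        except Exception:
            logging.getLogger(__name__).exception("Reader %s failed", reader.__name__)
            return {}

    return wrapper


@safe_reader
async def broken_reader(memory, task, cancellation_token):
    # Simulates a reader that hits an unexpected error
    raise KeyError("missing key")


entry = with_metadata("SELECT 1", source="sql_generator")
context = asyncio.run(broken_reader(None, "task", None))
print(entry["source"], context)
```

Catching `Exception` rather than using a bare `except` lets task cancellation propagate, since `asyncio.CancelledError` derives from `BaseException` in Python 3.8 and later.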

## 13. Conclusion

The `MemoryAgentTool` adds state and memory capabilities to AutoGen agents in a text-to-SQL workflow. By using a shared memory store and custom callbacks, we can create a pipeline where:

1. Each agent has access to relevant information from previous steps
2. Agents can focus on their specific tasks without needing to manage state
3. Information is efficiently passed between workflow steps
4. Previous queries and results can inform future processing

This approach lets later steps reuse earlier results and supports multi-step workflows, while maintaining a clean separation of concerns between the different agents in the pipeline.

In [None]:
# Print updated query history
history = await memory.get("query_history")
print(f"Query history entries: {len(history) if history else 0}")
if history:
    for i, entry in enumerate(history):
        print(f"\nEntry {i+1}:")
        print(f"Query: {entry.get('query')}")
        if 'sql' in entry:
            sql = entry['sql']
            print(f"SQL: {sql[:50]}{'...' if len(sql) > 50 else ''}")
        if 'execution_status' in entry:
            print(f"Status: {entry['execution_status']}")