# SQL Refiner Agent Tool Test

This notebook demonstrates how to use the `MemoryAgentTool` with our `KeyValueMemory` implementation to build a SQL refiner agent that reads from and writes to memory.

The SQL refiner agent improves and optimizes generated SQL queries based on database schema knowledge, error information, and execution results.

The `MemoryAgentTool` class provides:
1. **Memory integration** for agents by extending BaseTool
2. **Pre-processing** via reader callbacks that fetch relevant context from memory before agent execution
3. **Post-processing** via parser callbacks that extract and store information from agent outputs after execution
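
The three steps above can be sketched without any agent framework. The following is a minimal illustration of the read, run, parse pattern only; it is not the actual `MemoryAgentTool` implementation. `DictMemory`, `run_with_memory`, and `fake_agent` are hypothetical names introduced for this example, and the callback signatures are simplified (no cancellation token).

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict


class DictMemory:
    """Hypothetical in-memory stand-in for a key-value memory store."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    async def get(self, key: str) -> Any:
        return self._data.get(key)

    async def set(self, key: str, value: Any) -> None:
        self._data[key] = value


async def run_with_memory(
    memory: DictMemory,
    task: str,
    reader: Callable[[DictMemory, str], Awaitable[Dict[str, Any]]],
    agent: Callable[[str, Dict[str, Any]], Awaitable[str]],
    parser: Callable[[DictMemory, str, str], Awaitable[None]],
) -> str:
    """Read context, run the agent, then store what the parser extracts."""
    context = await reader(memory, task)   # pre-processing
    output = await agent(task, context)    # agent execution
    await parser(memory, task, output)     # post-processing
    return output


async def demo() -> Dict[str, Any]:
    memory = DictMemory()
    await memory.set("schema", "customers(customer_id, name)")

    async def reader(mem: DictMemory, task: str) -> Dict[str, Any]:
        return {"schema": await mem.get("schema")}

    async def fake_agent(task: str, context: Dict[str, Any]) -> str:
        # Stand-in for a model call: echoes the task and the schema it saw.
        return f"{task} | schema={context['schema']}"

    async def parser(mem: DictMemory, task: str, output: str) -> None:
        await mem.set("last_output", output)

    output = await run_with_memory(memory, "refine", reader, fake_agent, parser)
    return {"output": output, "stored": await memory.get("last_output")}
```

In a notebook cell this can be run with `await demo()`; in a plain script, `asyncio.run(demo())` does the same.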

In [1]:
import asyncio
import datetime  # used by sql_refiner_parser to timestamp history entries
import json
import logging
import re
from typing import Dict, Any, List, Optional

# Set up logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Reduce noise from autogen
logging.getLogger('autogen_core').setLevel(logging.WARNING)

In [2]:
# Import our KeyValueMemory and MemoryAgentTool
from memory import KeyValueMemory
from memory_agent_tool import MemoryAgentTool, MemoryAgentToolArgs

# Import necessary AutoGen components
from autogen_core import CancellationToken
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_ext.models.openai import OpenAIChatCompletionClient

## 1. Set up our memory store and model client

In [3]:
# Initialize the shared memory store
memory = KeyValueMemory(name="text_to_sql_memory")

# Set up the model client - replace with your specific model
model_client = OpenAIChatCompletionClient(
    model="gpt-4o",  # or other appropriate model
    temperature=0.1,
    timeout=120
)

## 2. Define memory callback functions for our refiner agent

The agent will have custom reader and parser functions that define how it interacts with memory.

In [4]:
# SQL Refiner Agent memory callbacks
async def sql_refiner_reader(memory, task, cancellation_token):
    """Read relevant information for the SQL refiner agent."""
    context = {}
    
    # Get database schema if available
    schema = await memory.get("current_schema") or await memory.get("full_database_schema")
    if schema:
        context["schema"] = schema
    
    # Get the original SQL if available
    original_sql = await memory.get("original_sql")
    if original_sql:
        context["original_sql"] = original_sql
    
    # Get error information if available
    error_info = await memory.get("sql_error_info")
    if error_info:
        context["error_info"] = error_info
    
    # Get execution results if available
    execution_results = await memory.get("sql_execution_results")
    if execution_results:
        context["execution_results"] = execution_results
        
    # Get refinement history if available
    refinement_history_json = await memory.get("refinement_history")
    if refinement_history_json:
        try:
            refinement_history = json.loads(refinement_history_json)
            if refinement_history:
                context["refinement_history"] = refinement_history
        except json.JSONDecodeError:
            logging.error("Failed to parse refinement history JSON")
    
    # Get database dialect/preferences if available
    db_info_json = await memory.get("db_info")
    if db_info_json:
        try:
            db_info = json.loads(db_info_json)
            context["db_info"] = db_info
        except json.JSONDecodeError:
            logging.error("Failed to parse DB info JSON")
            
    # Add the original query for context
    original_query = await memory.get("original_query")
    if original_query:
        context["original_query"] = original_query
    
    return context

async def sql_refiner_parser(memory, task, result, cancellation_token):
    """Parse and store SQL refinement results."""
    if result.messages:
        last_message = result.messages[-1].content
        print(f"Parsing refiner result: {last_message[:100]}...")
        
        # Extract the refined SQL query
        sql_matches = re.findall(r'```sql\s*(.+?)\s*```', last_message, re.DOTALL)
        if sql_matches:
            refined_sql = sql_matches[0].strip()
            await memory.set("refined_sql", refined_sql)
            print(f"Stored refined SQL in memory: {refined_sql[:100]}...")
            
            # Prefer the SQL submitted with this task so each history entry
            # records the query that was actually refined; fall back to memory
            original_sql = None
            try:
                task_obj = json.loads(task) if isinstance(task, str) else task
                if isinstance(task_obj, dict):
                    original_sql = task_obj.get("sql")
            except json.JSONDecodeError:
                logging.error("Failed to extract original SQL from task")
            if original_sql:
                await memory.set("original_sql", original_sql)
            else:
                original_sql = await memory.get("original_sql")
            
            # Store in refinement history
            refinement_entry = {
                "timestamp": str(datetime.datetime.now()),
                "original_sql": original_sql,
                "refined_sql": refined_sql,
                "refinement_explanation": last_message
            }
            
            # Update the refinement history
            refinement_history = []
            refinement_history_json = await memory.get("refinement_history")
            if refinement_history_json:
                try:
                    refinement_history = json.loads(refinement_history_json)
                except json.JSONDecodeError:
                    logging.error("Failed to parse refinement history, creating new one")
            
            refinement_history.append(refinement_entry)
            # Keep only the last 5 refinements
            if len(refinement_history) > 5:
                refinement_history = refinement_history[-5:]
            
            await memory.set("refinement_history", json.dumps(refinement_history))
            print("Updated refinement history")
            
            # Extract optimization notes if available
            optimization_match = re.search(r'<optimization_notes>(.+?)</optimization_notes>', last_message, re.DOTALL)
            if optimization_match:
                optimization_notes = optimization_match.group(1).strip()
                await memory.set("optimization_notes", optimization_notes)
                print("Stored optimization notes")
        else:
            print("No SQL query found in refiner output")

## 3. Create our SQL Refiner Agent with System Message

In [5]:
import datetime

# Define system message for the SQL refiner agent
SQL_REFINER_SYSTEM_MESSAGE = """
You are a SQL refinement and optimization expert. Your role is to:
1. Review and analyze SQL queries
2. Correct any syntax or logical errors
3. Optimize queries for better performance
4. Ensure queries correctly match the database schema

For each SQL query you review, you should:
- Check for proper table and column names based on the provided schema
- Verify join conditions and relationships
- Suggest index usage and query optimization techniques
- Improve readability with proper formatting and comments

If error information is provided, focus on fixing those specific issues.

Always return your refined SQL in a code block using ```sql ``` format.

Optionally, include optimization notes in XML format:
<optimization_notes>
Your detailed notes about the optimizations applied and why they improve the query.
</optimization_notes>
"""

# Create the agent
sql_refiner_agent = AssistantAgent(
    name="sql_refiner",
    system_message=SQL_REFINER_SYSTEM_MESSAGE,
    model_client=model_client,
    description="Refines and optimizes SQL queries based on schema and execution feedback"
)

# Wrap the agent with memory capabilities
sql_refiner_tool = MemoryAgentTool(
    agent=sql_refiner_agent,
    memory=memory,
    reader_callback=sql_refiner_reader,
    parser_callback=sql_refiner_parser
)

## 4. Set up Sample Database Schema and Information

We'll use the same sample database schema as in the schema selector test, plus additional database information for the refiner.

In [6]:
# Reset memory for a fresh run
await memory.clear()

# Set up a sample database schema
full_schema = """
<database_schema>
  <table name="customers">
    <column name="customer_id" type="INTEGER" primary_key="true" />
    <column name="name" type="TEXT" />
    <column name="email" type="TEXT" />
    <column name="join_date" type="DATE" />
  </table>
  <table name="orders">
    <column name="order_id" type="INTEGER" primary_key="true" />
    <column name="customer_id" type="INTEGER" foreign_key="customers.customer_id" />
    <column name="order_date" type="DATE" />
    <column name="total_amount" type="DECIMAL" />
  </table>
  <table name="products">
    <column name="product_id" type="INTEGER" primary_key="true" />
    <column name="name" type="TEXT" />
    <column name="price" type="DECIMAL" />
    <column name="category" type="TEXT" />
  </table>
  <table name="order_items">
    <column name="item_id" type="INTEGER" primary_key="true" />
    <column name="order_id" type="INTEGER" foreign_key="orders.order_id" />
    <column name="product_id" type="INTEGER" foreign_key="products.product_id" />
    <column name="quantity" type="INTEGER" />
    <column name="price" type="DECIMAL" />
  </table>
  <table name="inventory">
    <column name="inventory_id" type="INTEGER" primary_key="true" />
    <column name="product_id" type="INTEGER" foreign_key="products.product_id" />
    <column name="quantity" type="INTEGER" />
    <column name="warehouse" type="TEXT" />
  </table>
</database_schema>
"""

# Store the schema
await memory.set("full_database_schema", full_schema)
print("Full database schema stored in memory")

# Set up database information
db_info = {
    "dialect": "SQLite",
    "version": "3.39.0",
    "indexes": [
        {"table": "orders", "column": "customer_id", "name": "idx_orders_customer"},
        {"table": "order_items", "column": "order_id", "name": "idx_items_order"},
        {"table": "order_items", "column": "product_id", "name": "idx_items_product"}
    ],
    "row_counts": {
        "customers": 5000,
        "orders": 20000,
        "products": 1000,
        "order_items": 50000,
        "inventory": 2000
    }
}

await memory.set("db_info", json.dumps(db_info))
print("Database information stored in memory")

# Initialize empty refinement history
await memory.set("refinement_history", json.dumps([]))
print("Refinement history initialized")

2025-05-21 17:39:46,546 - root - INFO - [KeyValueMemory] Memory cleared.


Full database schema stored in memory
Database information stored in memory
Refinement history initialized


## 5. Test SQL Refiner with a Query Needing Optimization

Let's test our refiner with a SQL query that could be improved.

In [7]:
# Define the original query intent
original_query = "Find all customers who have placed orders with a total amount greater than $100"
await memory.set("original_query", original_query)

# A SQL query with room for improvement
unoptimized_sql = """
SELECT c.name, c.email 
FROM customers c
WHERE c.customer_id IN (
    SELECT customer_id 
    FROM orders 
    WHERE total_amount > 100
)
"""

await memory.set("original_sql", unoptimized_sql)

# Create a task for the refiner
task1 = json.dumps({
    "sql": unoptimized_sql,
    "query": original_query,
    "optimize_for": "readability and performance"
})

# Create a cancellation token
cancellation_token = CancellationToken()

# Create the proper arguments object
args = MemoryAgentToolArgs(task=task1)

# Run the agent
result1 = await sql_refiner_tool.run(
    args=args,
    cancellation_token=cancellation_token
)

# Display result
print(f"\nAgent Response:\n{result1.messages[-1].content}")

2025-05-21 17:39:50,033 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Parsing refiner result: The original SQL query is mostly correct, but there are a few optimizations and improvements that ca...
Stored refined SQL in memory: SELECT c.name, c.email
FROM customers c
WHERE EXISTS (
    SELECT 1
    FROM orders o
    WHERE o.cu...
Updated refinement history
Stored optimization notes

Agent Response:
The original SQL query is mostly correct, but there are a few optimizations and improvements that can be made for better performance and readability. Here's the refined query:

```sql
SELECT c.name, c.email
FROM customers c
WHERE EXISTS (
    SELECT 1
    FROM orders o
    WHERE o.customer_id = c.customer_id
      AND o.total_amount > 100
);
```

<optimization_notes>
- **Replaced IN with EXISTS**: Using `EXISTS` is generally more efficient than `IN` when checking for the existence of rows in a subquery, especially when dealing with large datasets. This is because `EXISTS` can stop processing as soon as it finds a matching row, whereas `IN` may need to process 

## 6. Test SQL Refiner with an Erroneous Query

Let's test our refiner with a SQL query that has an error that needs to be fixed.

In [8]:
# A SQL query with an error
error_sql = """
SELECT p.name, SUM(oi.quantity) as total_ordered
FROM products p
JOIN ordered_items oi ON p.product_id = oi.product_id  -- Error: table name should be order_items
GROUP BY p.product_id
ORDER BY total_ordered DESC
LIMIT 10
"""

# Simulated error info
error_info = "Error: no such table: ordered_items"
await memory.set("sql_error_info", error_info)

# Create a task for the refiner
task2 = json.dumps({
    "sql": error_sql,
    "query": "Find the top 10 most ordered products",
    "error": error_info
})

# Create the proper arguments object
args = MemoryAgentToolArgs(task=task2)

# Run the agent
result2 = await sql_refiner_tool.run(
    args=args,
    cancellation_token=cancellation_token
)

# Display result
print(f"\nAgent Response:\n{result2.messages[-1].content}")

2025-05-21 17:39:54,026 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Parsing refiner result: The error in the provided SQL query is due to the incorrect table name `ordered_items`. The correct ...
Stored refined SQL in memory: SELECT p.name, SUM(oi.quantity) AS total_ordered
FROM products p
JOIN order_items oi ON p.product_id...
Updated refinement history
Stored optimization notes

Agent Response:
The error in the provided SQL query is due to the incorrect table name `ordered_items`. The correct table name is `order_items`. Let's refine the query to fix this error and ensure optimal performance:

```sql
SELECT p.name, SUM(oi.quantity) AS total_ordered
FROM products p
JOIN order_items oi ON p.product_id = oi.product_id  -- Corrected table name
GROUP BY p.product_id
ORDER BY total_ordered DESC
LIMIT 10;
```

<optimization_notes>
- **Corrected Table Name**: Changed `ordered_items` to `order_items` to match the schema.
- **Index Usage**: The query benefits from the existing index on `order_items.product_id` (`idx_items_product`), which helps speed up the jo
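
The error string in this test was simulated. Since `db_info` names SQLite as the dialect, one way such a string could be produced in a real workflow is to run the query with Python's built-in `sqlite3` module and catch the exception. This is a sketch under that assumption: `try_execute` is a hypothetical helper, and the two tables are cut-down versions of the sample schema.

```python
import sqlite3
from typing import Optional


def try_execute(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Run a query; return None on success or an error string on failure."""
    try:
        conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return f"Error: {exc}"


# Cut-down, in-memory version of the sample schema (assumption for this sketch)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (product_id INTEGER PRIMARY KEY, name TEXT)")
conn.execute(
    "CREATE TABLE order_items ("
    "item_id INTEGER PRIMARY KEY, product_id INTEGER, quantity INTEGER)"
)

bad_sql = """
SELECT p.name, SUM(oi.quantity) AS total_ordered
FROM products p
JOIN ordered_items oi ON p.product_id = oi.product_id
GROUP BY p.product_id
"""
error_info = try_execute(conn, bad_sql)
fixed_error = try_execute(conn, bad_sql.replace("ordered_items", "order_items"))
print(error_info)   # Error: no such table: ordered_items
print(fixed_error)  # None
```

The captured string could then be stored with `await memory.set("sql_error_info", error_info)` before the refiner runs.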

## 7. Test SQL Refiner with Execution Feedback

Let's test our refiner with a SQL query that produces unexpected results.

In [9]:
# A SQL query with unexpected results
unexpected_sql = """
SELECT c.name, COUNT(*) as order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.name
HAVING order_count > 5
"""

# Simulated execution results
execution_results = """
Query executed successfully but returned 0 rows. The database has 5000 customers and 20000 orders.
Sample data:
- customers table has customers with multiple orders
- Average orders per customer is 4
- Some customers have up to 20 orders
"""
await memory.set("sql_execution_results", execution_results)

# Create a task for the refiner
task3 = json.dumps({
    "sql": unexpected_sql,
    "query": "Find customers who have placed more than 5 orders",
    "execution_results": execution_results,
    "feedback": "Query returned no results but we expect to see customers with more than 5 orders"
})

# Create the proper arguments object
args = MemoryAgentToolArgs(task=task3)

# Run the agent
result3 = await sql_refiner_tool.run(
    args=args,
    cancellation_token=cancellation_token
)

# Display result
print(f"\nAgent Response:\n{result3.messages[-1].content}")

2025-05-21 17:39:57,713 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Parsing refiner result: The query provided is intended to find customers who have placed more than 5 orders. However, it ret...
Stored refined SQL in memory: SELECT c.name, COUNT(o.order_id) AS order_count
FROM customers c
JOIN orders o ON c.customer_id = o....
Updated refinement history
Stored optimization notes

Agent Response:
The query provided is intended to find customers who have placed more than 5 orders. However, it returned no results, which is unexpected given the sample data. Let's refine the query to ensure it accurately reflects the intended logic and correct any potential issues:

```sql
SELECT c.name, COUNT(o.order_id) AS order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
HAVING COUNT(o.order_id) > 5;
```

<optimization_notes>
- **Corrected Grouping**: The original query grouped by `c.name`, which could lead to incorrect results if there are customers with the same name. Grouping by `c.customer_id` ensures that eac

## 8. Check Memory Contents

Let's examine what's stored in memory after our agent runs.

In [10]:
# Check the refined SQL
refined_sql = await memory.get("refined_sql")
print(f"Most recent refined SQL:\n{refined_sql}\n")

# Check optimization notes if available
optimization_notes = await memory.get("optimization_notes")
if optimization_notes:
    print(f"Optimization Notes:\n{optimization_notes}\n")
else:
    print("No optimization notes available\n")
    
# Get refinement history
refinement_history_json = await memory.get("refinement_history")
if refinement_history_json:
    refinement_history = json.loads(refinement_history_json)
    print(f"Refinement History ({len(refinement_history)} entries):")
    for i, entry in enumerate(refinement_history):
        print(f"\nEntry {i+1}:")
        print(f"Timestamp: {entry['timestamp']}")
        print(f"Original SQL: {(entry['original_sql'] or '')[:50]}...")
        print(f"Refined SQL: {entry['refined_sql'][:50]}...")
else:
    print("No refinement history found in memory")

Most recent refined SQL:
SELECT c.name, COUNT(o.order_id) AS order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
HAVING COUNT(o.order_id) > 5;

Optimization Notes:
- **Corrected Grouping**: The original query grouped by `c.name`, which could lead to incorrect results if there are customers with the same name. Grouping by `c.customer_id` ensures that each customer is uniquely identified.
- **Explicit Count**: Changed `COUNT(*)` to `COUNT(o.order_id)` to explicitly count the number of orders per customer, which is more semantically clear.
- **Index Usage**: The query benefits from the existing index on `orders.customer_id` (`idx_orders_customer`), which helps speed up the join operation.
- **Readability**: The query is formatted for better readability, with clear alignment and indentation.

Refinement History (3 entries):

Entry 1:
Timestamp: 2025-05-21 17:39:50.037501
Original SQL: 
SELECT c.name, c.email 
FROM customers c
WHERE c..
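
The history inspected above is capped at five entries by the parser. That bounded-append step can be sketched as a small pure function over the JSON string held in memory. `append_refinement` and `MAX_HISTORY` are names introduced for this example, not part of the notebook's code.

```python
import json
from typing import Any, Dict, List, Optional

MAX_HISTORY = 5  # same cap the parser applies


def append_refinement(
    history_json: Optional[str],
    entry: Dict[str, Any],
    max_entries: int = MAX_HISTORY,
) -> str:
    """Append an entry to a JSON-encoded list, keeping the newest max_entries."""
    try:
        history: List[Dict[str, Any]] = json.loads(history_json) if history_json else []
    except json.JSONDecodeError:
        history = []  # corrupt history is replaced rather than propagated
    history.append(entry)
    return json.dumps(history[-max_entries:])


# Seven appends leave only the five most recent entries
stored = None
for i in range(7):
    stored = append_refinement(stored, {"refined_sql": f"SELECT {i}"})
kept = [e["refined_sql"] for e in json.loads(stored)]
print(kept)  # ['SELECT 2', 'SELECT 3', 'SELECT 4', 'SELECT 5', 'SELECT 6']
```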

## 9. Demonstrate Using Refinement Results

Let's demonstrate how we might use the refined SQL in a workflow.

In [11]:
# Get the most recent refined SQL
refined_sql = await memory.get("refined_sql")
if refined_sql:
    print("Using the refined SQL in a workflow:")
    print(f"\nRefined SQL:\n{refined_sql}")
    
    # In a real workflow, we would execute this SQL against a database
    print("\nIn a real workflow:")
    print("1. We would execute this SQL against the database")
    print("2. Process the results for the user")
    print("3. Store the execution metrics for future optimization")
    
    # Extract any optimization info for the user
    optimization_notes = await memory.get("optimization_notes")
    if optimization_notes:
        print(f"\nOptimization insights for the user:\n{optimization_notes}")
else:
    print("No refined SQL available to use")

Using the refined SQL in a workflow:

Refined SQL:
SELECT c.name, COUNT(o.order_id) AS order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
HAVING COUNT(o.order_id) > 5;

In a real workflow:
1. We would execute this SQL against the database
2. Process the results for the user
3. Store the execution metrics for future optimization

Optimization insights for the user:
- **Corrected Grouping**: The original query grouped by `c.name`, which could lead to incorrect results if there are customers with the same name. Grouping by `c.customer_id` ensures that each customer is uniquely identified.
- **Explicit Count**: Changed `COUNT(*)` to `COUNT(o.order_id)` to explicitly count the number of orders per customer, which is more semantically clear.
- **Index Usage**: The query benefits from the existing index on `orders.customer_id` (`idx_orders_customer`), which helps speed up the join operation.
- **Readability**: The query is formatted for 
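
Step 1 of the workflow listed above, executing the refined SQL, can be sketched with Python's built-in `sqlite3` module. The tables below are a cut-down version of the sample schema and the rows are invented purely for illustration; `run_refined_query` is a hypothetical helper, not part of the notebook's code.

```python
import sqlite3
from typing import List, Tuple

REFINED_SQL = """
SELECT c.name, COUNT(o.order_id) AS order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
HAVING COUNT(o.order_id) > 5;
"""


def run_refined_query(sql: str) -> List[Tuple[str, int]]:
    """Execute a query against a small in-memory database with made-up rows."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE customers (customer_id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE orders (
            order_id INTEGER PRIMARY KEY,
            customer_id INTEGER REFERENCES customers(customer_id)
        );
        """
    )
    conn.executemany("INSERT INTO customers VALUES (?, ?)", [(1, "Ada"), (2, "Grace")])
    # Ada gets six orders, Grace gets two
    orders = [(i, 1) for i in range(1, 7)] + [(7, 2), (8, 2)]
    conn.executemany("INSERT INTO orders VALUES (?, ?)", orders)
    rows = conn.execute(sql).fetchall()
    conn.close()
    return rows


rows = run_refined_query(REFINED_SQL)
print(rows)  # [('Ada', 6)]
```

In the notebook, the query text would come from `await memory.get("refined_sql")` rather than a constant, and the row count or any raised error could be written back to `sql_execution_results` or `sql_error_info` for the next refinement pass.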

## 10. Conclusion

This notebook demonstrates how the `MemoryAgentTool` allows a SQL refiner agent to improve and optimize queries while maintaining context through memory. The key features demonstrated include:

1. Reading schema, original SQL, and error information from memory before refinement
2. Using execution feedback to guide refinements
3. Storing optimized SQL and optimization notes after agent execution
4. Maintaining a history of refinements for context and learning

The SQL refiner agent demonstrated three key capabilities:
1. **Optimization**: Improving a working but suboptimal query
2. **Error correction**: Fixing a schema-related error (an incorrect table name)
3. **Results-driven refinement**: Adjusting queries based on unexpected execution results

This pattern fits real-world text-to-SQL workflows, where the initially generated SQL is often imperfect and needs iterative refinement based on database feedback.