# SQL Generator Agent Demo

This notebook demonstrates the SQLGeneratorAgent, which generates SQL queries based on:
- Query intent from the query tree node
- Schema linking information (tables, columns, joins)
- Database schema metadata

## Key Features
1. **LLM-driven SQL generation** - Uses OpenAI GPT-4o to generate SQL
2. **No hardcoded logic** - Agent only formats context and extracts LLM output
3. **XML-based output parsing** - Robust parsing of `<generation>` tags
4. **Schema-aware generation** - Uses schema linking from previous agents
5. **Support for various SQL types** - Simple queries, joins, aggregations, subqueries

## Workflow Integration
In the full text-to-SQL workflow:
1. QueryAnalyzerAgent analyzes the user query and creates nodes
2. SchemaLinkerAgent identifies relevant tables and columns
3. **SQLGeneratorAgent generates SQL based on the schema linking** ← This notebook
4. SQLEvaluatorAgent executes and evaluates the results

In [1]:
import os
import sys
import asyncio
import sqlite3
import json
import logging
import re
from typing import Dict, Any, List, Optional
from dotenv import load_dotenv

sys.path.append('../src')
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Reduce noise from autogen
logging.getLogger('autogen_core').setLevel(logging.WARNING)

In [2]:
from pathlib import Path
from keyvalue_memory import KeyValueMemory
from task_context_manager import TaskContextManager
from query_tree_manager import QueryTreeManager
from database_schema_manager import DatabaseSchemaManager
from node_history_manager import NodeHistoryManager
from query_analyzer_agent import QueryAnalyzerAgent
from schema_reader import SchemaReader
from memory_content_types import (
    TaskContext, QueryNode, NodeStatus, TaskStatus,
    TableSchema, ColumnInfo
)
from sql_generator_agent import SQLGeneratorAgent

data_path = "/home/norman/work/text-to-sql/MAC-SQL/data/bird"
tables_json_path = Path(data_path) / "dev_tables.json"
db_name = "california_schools"

In [3]:
task_id = "experimental-test"

query = "What is the highest eligible free rate for K-12 students in schools in Alameda County?"
intent = "Find the maximum eligible free rate for K-12 students in schools located in Alameda County"
memory = KeyValueMemory()
        
# Initialize task
task_manager = TaskContextManager(memory)
await task_manager.initialize(task_id, query, db_name)

# Load schema
schema_manager = DatabaseSchemaManager(memory)
await schema_manager.initialize()

schema_reader = SchemaReader(
    data_path=data_path,
    tables_json_path=str(tables_json_path),
    dataset_name="bird",
    lazy=False
)
await schema_manager.load_from_schema_reader(schema_reader, db_name)

# Initialize query tree
tree_manager = QueryTreeManager(memory)
node_id = await tree_manager.initialize(intent)
await tree_manager.set_current_node_id(node_id)

# Create schema linking for the node (simulating what schema linker would do)
schema_linking = {
    "selected_tables": [
        {
            "name": "frpm",
            "alias": "f",
            "purpose": "To find the highest eligible free rate for K-12 students in Alameda County",
            "columns": [
                {
                    "name": "County Name",
                    "used_for": "filter",
                    "reason": "To filter the records for Alameda County"
                },
                {
                    "name": "Percent (%) Eligible Free (K-12)",
                    "used_for": "aggregate",
                    "reason": "To determine the highest eligible free rate for K-12 students"
                }
            ]
        }
    ],
    "joins": [],
    "sample_query_pattern": 'SELECT MAX(f."Percent (%) Eligible Free (K-12)") FROM frpm AS f WHERE f."County Name" = \'Alameda\''
}

# Update node with schema linking
await tree_manager.update_node(node_id, {"schema_linking": schema_linking})
print(f"Schema linking added to node {node_id}")

2025-05-30 00:42:11,453 - TaskContextManager - INFO - Initialized task context for task experimental-test
2025-05-30 00:42:11,453 - DatabaseSchemaManager - INFO - Initialized empty database schema


load json file from /home/norman/work/text-to-sql/MAC-SQL/data/bird/dev_tables.json

Loading all database info...
Found 11 databases in bird dataset


2025-05-30 00:42:23,891 - DatabaseSchemaManager - INFO - Initialized empty database schema
2025-05-30 00:42:23,892 - DatabaseSchemaManager - INFO - Added table 'frpm' to schema
2025-05-30 00:42:23,892 - DatabaseSchemaManager - INFO - Added table 'satscores' to schema
2025-05-30 00:42:23,893 - DatabaseSchemaManager - INFO - Added table 'schools' to schema
2025-05-30 00:42:23,893 - DatabaseSchemaManager - INFO - Loaded schema for database 'california_schools' with 3 tables
2025-05-30 00:42:23,893 - QueryTreeManager - INFO - Initialized query tree with root node root
2025-05-30 00:42:23,893 - QueryTreeManager - INFO - Set current node to root
2025-05-30 00:42:23,894 - QueryTreeManager - INFO - Updated node root


Schema linking added to node root


In [4]:
# Create SQL Generator Agent
agent = SQLGeneratorAgent(memory, llm_config={
    "model_name": "gpt-4o",
    "temperature": 0.1,
    "timeout": 60
}, debug=True)

print("SQLGeneratorAgent created successfully")
print(f"Agent name: {agent.agent_name}")
print(f"Managers initialized: schema_manager, tree_manager, history_manager")

2025-05-30 00:42:23,914 - SQLGeneratorAgent - DEBUG - Created AssistantAgent: sql_generator
2025-05-30 00:42:23,914 - SQLGeneratorAgent - DEBUG - Created MemoryAgentTool for sql_generator
2025-05-30 00:42:23,915 - SQLGeneratorAgent - INFO - Initialized sql_generator with model gpt-4o


SQLGeneratorAgent created successfully
Agent name: sql_generator
Managers initialized: schema_manager, tree_manager, history_manager


In [5]:
# Run SQL generation
print("Running SQL generation for node:", node_id)
print(f"Intent: {intent}")
print("\nCalling LLM to generate SQL...")

result = await agent.run("Generate SQL for the current node")

# Check the results
node = await tree_manager.get_node(node_id)
if node and hasattr(node, 'generation') and node.generation:
    print("\n✅ SQL Generation Successful!")
    print(f"Generated SQL: {node.generation.get('sql', 'No SQL')}")
    print(f"Query Type: {node.generation.get('query_type', 'Unknown')}")
else:
    print("\n❌ SQL generation failed - no results found")

2025-05-30 00:42:23,917 - SQLGeneratorAgent - INFO - SQL generator context prepared for node: root
2025-05-30 00:42:23,917 - SQLGeneratorAgent - INFO - Node detail: {'nodeId': 'root', 'status': 'created', 'childIds': [], 'intent': 'Find the maximum eligible free rate for K-12 students in schools located in Alameda County', 'schema_linking': {'selected_tables': [{'name': 'frpm', 'alias': 'f', 'purpose': 'To find the highest eligible free rate for K-12 students in Alameda County', 'columns': [{'name': 'County Name', 'used_for': 'filter', 'reason': 'To filter the records for Alameda County'}, {'name': 'Percent (%) Eligible Free (K-12)', 'used_for': 'aggregate', 'reason': 'To determine the highest eligible free rate for K-12 students'}]}], 'joins': [], 'sample_query_pattern': 'SELECT MAX(f."Percent (%) Eligible Free (K-12)") FROM frpm AS f WHERE f."County Name" = \'Alameda\''}, 'generation': {}, 'evaluation': {}}


Running SQL generation for node: root
Intent: Find the maximum eligible free rate for K-12 students in schools located in Alameda County

Calling LLM to generate SQL...


2025-05-30 00:42:27,053 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-30 00:42:27,056 - SQLGeneratorAgent - INFO - Raw LLM output: <generation>
  <query_type>simple</query_type>
  <sql>
    SELECT MAX(f."Percent (%) Eligible Free (K-12)") AS max_eligible_free_rate
    FROM frpm AS f
    WHERE f."County Name" = 'Alameda'
  </sql>
  <explanation>
    The query selects the maximum value of the "Percent (%) Eligible Free (K-12)" column from the "frpm" table for records where the "County Name" is 'Alameda'. This directly addresses the intent to find the highest eligible free rate for K-12 students in schools located in Alameda County.
  </explanation>
  <considerations>
    - Assumptions made: The "frpm" table contains the necessary data for the query.
    - Limitations: The query assumes that the "County Name" is correctly spelled and capitalized as 'Alameda'.
    - Data type formatting applied: The county name is a TEXT type and i


✅ SQL Generation Successful!
Generated SQL: SELECT MAX(f."Percent (%) Eligible Free (K-12)") AS max_eligible_free_rate FROM frpm AS f WHERE f."County Name" = 'Alameda'
Query Type: simple


In [6]:
# Inspect Memory Contents
print("\n" + "="*60)
print("MEMORY INSPECTION")
print("="*60)

# Show all memory contents
memory_contents = await memory.show_all('full')
print("\n📦 Full Memory Contents:")
print(memory_contents)

print("\n" + "-"*40)
print("Key Memory Components:")
print("-"*40)

# 1. Task Context
task_context = await memory.get("taskContext")
if task_context:
    print("\n1️⃣ Task Context:")
    print(f"   Task ID: {task_context.get('taskId')}")
    print(f"   Query: {task_context.get('query')}")
    print(f"   Database: {task_context.get('databaseId')}")
    print(f"   Status: {task_context.get('status')}")

# 2. Database Schema
db_schema = await memory.get("databaseSchema")
if db_schema:
    print("\n2️⃣ Database Schema:")
    tables = db_schema.get('tables', {})
    print(f"   Tables: {list(tables.keys())}")
    for table_name, table_info in tables.items():
        columns = table_info.get('columns', {})
        print(f"   - {table_name}: {len(columns)} columns")

# 3. Query Tree
query_tree = await memory.get("queryTree")
if query_tree:
    print("\n3️⃣ Query Tree:")
    print(f"   Root ID: {query_tree.get('rootId')}")
    print(f"   Current Node ID: {query_tree.get('currentNodeId')}")
    nodes = query_tree.get('nodes', {})
    print(f"   Total Nodes: {len(nodes)}")
    
    # Show current node details
    current_node_id = query_tree.get('currentNodeId')
    if current_node_id and current_node_id in nodes:
        current_node = nodes[current_node_id]
        print(f"\n   📍 Current Node ({current_node_id}):")
        print(f"      Intent: {current_node.get('intent', 'N/A')}")
        print(f"      Status: {current_node.get('status', 'N/A')}")
        
        # Show schema linking if present
        if 'schema_linking' in current_node:
            print(f"      Schema Linking: ✅ Present")
            schema_linking = current_node['schema_linking']
            if 'selected_tables' in schema_linking:
                tables = schema_linking['selected_tables']
                print(f"        - Tables: {[t['name'] for t in tables]}")
        
        # Show generation if present
        if 'generation' in current_node:
            print(f"      Generation: ✅ Present")
            generation = current_node['generation']
            sql = generation.get('sql', '')
            if sql:
                print(f"        - SQL: {sql[:60]}..." if len(sql) > 60 else f"        - SQL: {sql}")
                print(f"        - Query Type: {generation.get('query_type', 'N/A')}")

# 4. Node History
node_history = await memory.get("nodeHistory")
if node_history:
    print(f"\n4️⃣ Node History: {len(node_history)} operations recorded")
    # Show last 3 operations
    for op in node_history[-3:]:
        print(f"   - {op.get('operation', 'N/A')} on {op.get('nodeId', 'N/A')} at {op.get('timestamp', 'N/A')}")

print("\n" + "="*60)


MEMORY INSPECTION

📦 Full Memory Contents:
=== Memory Store Detailed Contents ===
Total items in store: 13

Key: nodeHistory
  MIME Type: MemoryMimeType.JSON
  Metadata: {'variable_name': 'nodeHistory'}
  Value: [
  {
    "timestamp": "2025-05-30T00:42:27.057802",
    "nodeId": "root",
    "operation": "generate_sql",
    "data": {
      "nodeId": "root",
      "status": "sql_generated",
      "intent": "Find the maximum eligible free rate for K-12 students in schools located in Alameda County",
      "parentId": null,
      "childIds": [],
      "schema_linking": {
        "selected_tables": [
          {
            "name": "frpm",
            "alias": "f",
            "purpose": "To find the highest eligible free rate for K-12 students in Alameda County",
            "columns": [
              {
                "name": "County Name",
                "used_for": "filter",
                "reason": "To filter the records for Alameda County"
              },
              {
          

In [7]:
# Display detailed results
if result and result.messages:
    print("\n" + "="*60)
    print("DETAILED SQL GENERATION RESULTS")
    print("="*60)
    
    # Show the last message (LLM response)
    last_message = result.messages[-1]
    print(f"\n[{getattr(last_message, 'source', 'Assistant')}]:")
    print(last_message.content[:500] + "..." if len(last_message.content) > 500 else last_message.content)
    
    # Get full generation details from node
    node = await tree_manager.get_node(node_id)
    if node and hasattr(node, 'generation') and node.generation:
        generation = node.generation
        
        print("\n" + "-"*40)
        print("EXTRACTED GENERATION DATA:")
        print("-"*40)
        
        print(f"\n📊 Query Type: {generation.get('query_type', 'Unknown')}")
        
        print(f"\n💾 Generated SQL:")
        print(generation.get('sql', 'No SQL generated'))
        
        explanation = generation.get('explanation', '')
        if explanation:
            print(f"\n📝 Explanation:")
            print(explanation)
        
        considerations = generation.get('considerations', '')
        if considerations:
            print(f"\n⚠️  Considerations:")
            print(considerations)
        
        print("\n" + "="*60)
else:
    print("No messages in result")


DETAILED SQL GENERATION RESULTS

[sql_generator]:
<generation>
  <query_type>simple</query_type>
  <sql>
    SELECT MAX(f."Percent (%) Eligible Free (K-12)") AS max_eligible_free_rate
    FROM frpm AS f
    WHERE f."County Name" = 'Alameda'
  </sql>
  <explanation>
    The query selects the maximum value of the "Percent (%) Eligible Free (K-12)" column from the "frpm" table for records where the "County Name" is 'Alameda'. This directly addresses the intent to find the highest eligible free rate for K-12 students in schools located in Alameda C...

----------------------------------------
EXTRACTED GENERATION DATA:
----------------------------------------

📊 Query Type: simple

💾 Generated SQL:
SELECT MAX(f."Percent (%) Eligible Free (K-12)") AS max_eligible_free_rate FROM frpm AS f WHERE f."County Name" = 'Alameda'

📝 Explanation:
The query selects the maximum value of the "Percent (%) Eligible Free (K-12)" column from the "frpm" table for records where the "County Name" is 'Alameda'

In [9]:
# Quick Memory Inspection for Example 2
print("\n📊 Memory2 Key Contents (Example 2 - Joins):")
print("-" * 40)

# Check the query tree for the join example
query_tree2 = await memory.get("queryTree")
if query_tree2:
    current_node_id2 = query_tree2.get('currentNodeId')
    nodes2 = query_tree2.get('nodes', {})
    
    if current_node_id2 and current_node_id2 in nodes2:
        node2_data = nodes2[current_node_id2]
        
        # Show schema linking
        if 'schema_linking' in node2_data:
            schema_linking2 = node2_data['schema_linking']
            print("Schema Linking:")
            if 'joins' in schema_linking2 and schema_linking2['joins']:
                for join in schema_linking2['joins']:
                    print(f"  - JOIN: {join['from_table']} → {join['to_table']} ON {join['on']}")
        
        # Show generated SQL
        if 'generation' in node2_data:
            gen = node2_data['generation']
            print(f"\nGenerated SQL Type: {gen.get('query_type', 'N/A')}")
            print(f"Tables Used: {', '.join(t['name'] for t in schema_linking2.get('selected_tables', []))}")


📊 Memory2 Key Contents (Example 2 - Joins):
----------------------------------------
Schema Linking:

Generated SQL Type: simple
Tables Used: frpm


In [10]:
# Example 3: SQL Generation with Previous Error (Retry Scenario)
print("\n" + "="*60)
print("EXAMPLE 3: SQL RETRY WITH ERROR CONTEXT")
print("="*60)

# Create a node that has a previous SQL attempt with an error
memory3 = KeyValueMemory()
task_manager3 = TaskContextManager(memory3)
await task_manager3.initialize("test3", "Complex aggregation query", db_name)

schema_manager3 = DatabaseSchemaManager(memory3)
await schema_manager3.initialize()
await schema_manager3.load_from_schema_reader(schema_reader, db_name)

tree_manager3 = QueryTreeManager(memory3)
node_id3 = await tree_manager3.initialize("Find schools with above-average free meal rates")
await tree_manager3.set_current_node_id(node_id3)

# Add schema linking
schema_linking3 = {
    "selected_tables": [
        {
            "name": "frpm",
            "alias": "f",
            "purpose": "To calculate free meal rates"
        }
    ]
}

# Simulate a previous failed attempt
previous_sql = 'SELECT * FROM frpm WHERE "Free Meal Rate" > AVG("Free Meal Rate")'
execution_error = {
    "error": "misuse of aggregate function AVG()",
    "error_type": "sqlite3.OperationalError"
}

# Update node with previous attempt
await tree_manager3.update_node(node_id3, {
    "schema_linking": schema_linking3,
    "sql": previous_sql,
    "executionResult": execution_error,
    "status": "executed_failed"
})

print(f"Previous SQL (failed): {previous_sql}")
print(f"Error: {execution_error['error']}")

# Generate improved SQL
agent3 = SQLGeneratorAgent(memory3, llm_config={
    "model_name": "gpt-4o",
    "temperature": 0.1,
    "timeout": 60
}, debug=False)

print("\nGenerating improved SQL based on error...")
result3 = await agent3.run("Generate improved SQL based on error")

# Show improved results
node3 = await tree_manager3.get_node(node_id3)
if node3 and hasattr(node3, 'generation') and node3.generation:
    print(f"\n✅ Improved SQL Generated:")
    print(node3.generation.get('sql', 'No SQL'))
    
    explanation = node3.generation.get('explanation', '')
    if explanation:
        print(f"\nHow it fixes the error:")
        print(explanation[:200] + "..." if len(explanation) > 200 else explanation)

2025-05-30 00:43:58,805 - TaskContextManager - INFO - Initialized task context for task test3
2025-05-30 00:43:58,806 - DatabaseSchemaManager - INFO - Initialized empty database schema
2025-05-30 00:43:58,807 - DatabaseSchemaManager - INFO - Initialized empty database schema
2025-05-30 00:43:58,807 - DatabaseSchemaManager - INFO - Added table 'frpm' to schema
2025-05-30 00:43:58,807 - DatabaseSchemaManager - INFO - Added table 'satscores' to schema
2025-05-30 00:43:58,808 - DatabaseSchemaManager - INFO - Added table 'schools' to schema
2025-05-30 00:43:58,808 - DatabaseSchemaManager - INFO - Loaded schema for database 'california_schools' with 3 tables
2025-05-30 00:43:58,808 - QueryTreeManager - INFO - Initialized query tree with root node root
2025-05-30 00:43:58,809 - QueryTreeManager - INFO - Set current node to root
2025-05-30 00:43:58,809 - QueryTreeManager - INFO - Updated node root
2025-05-30 00:43:58,820 - SQLGeneratorAgent - DEBUG - Created AssistantAgent: sql_generator
2025-


EXAMPLE 3: SQL RETRY WITH ERROR CONTEXT
Previous SQL (failed): SELECT * FROM frpm WHERE "Free Meal Rate" > AVG("Free Meal Rate")
Error: misuse of aggregate function AVG()

Generating improved SQL based on error...


2025-05-30 00:44:03,924 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-30 00:44:03,925 - SQLGeneratorAgent - INFO - Raw LLM output: <generation>
  <query_type>aggregate</query_type>
  <sql>
    SELECT f.school_name, f.free_meal_rate
    FROM frpm AS f
    WHERE f.free_meal_rate > (
      SELECT AVG(free_meal_rate) FROM frpm
    )
  </sql>
  <explanation>
    The query selects schools with a free meal rate above the average free meal rate across all schools. It uses an aggregate subquery to calculate the average free meal rate and filters the schools based on this calculated average.
  </explanation>
  <considerations>
    - Assumptions made: The column `free_meal_rate` exists in the `frpm` table and represents the rate of free meals.
    - Limitations: This query assumes that `free_meal_rate` is a numeric column that can be averaged.
    - Changes from previous attempt: Ensured the use of an aggregate subquery to calculate the a


✅ Improved SQL Generated:
SELECT f.school_name, f.free_meal_rate FROM frpm AS f WHERE f.free_meal_rate > ( SELECT AVG(free_meal_rate) FROM frpm )

How it fixes the error:
The query selects schools with a free meal rate above the average free meal rate across all schools. It uses an aggregate subquery to calculate the average free meal rate and filters the schools based...


## Summary

The SQLGeneratorAgent demonstrates:

1. **Pure LLM-driven approach** - No hardcoded SQL logic or optimization
2. **Context-aware generation** - Uses schema linking from previous agents
3. **Error handling** - Can improve SQL based on previous execution errors
4. **Flexible output** - Generates various SQL types based on requirements

### Key Architecture Points:
- Inherits from `BaseMemoryAgent`
- Uses managers: `DatabaseSchemaManager`, `QueryTreeManager`, `NodeHistoryManager`
- Parses XML output with `<generation>` tags
- Stores results in query tree nodes

### Integration with Workflow:
- Reads from current node's `schema_linking` field
- Writes to node's `generation` field with SQL, explanation, and considerations
- Updates node status and records history

In [11]:
# Memory Architecture Summary
print("\n" + "="*60)
print("MEMORY ARCHITECTURE IN SQL GENERATOR WORKFLOW")
print("="*60)

print("""
The SQLGeneratorAgent interacts with KeyValueMemory through managers:

📁 KeyValueMemory
   │
   ├── 📋 taskContext
   │     └── {taskId, query, databaseId, status, ...}
   │
   ├── 🗃️ databaseSchema  
   │     └── {tables: {tableName: {columns, indexes, ...}}}
   │
   ├── 🌳 queryTree
   │     ├── rootId: "root"
   │     ├── currentNodeId: "root" 
   │     └── nodes: {
   │           "root": {
   │               intent: "...",
   │               status: "sql_generated",
   │               schema_linking: {...},    ← Input for SQLGenerator
   │               generation: {             ← Output from SQLGenerator
   │                   sql: "SELECT ...",
   │                   query_type: "simple",
   │                   explanation: "...",
   │                   considerations: "..."
   │               }
   │           }
   │         }
   │
   └── 📜 nodeHistory
         └── [{operation: "generate_sql", nodeId: "root", ...}]

Key Points:
- SQLGeneratorAgent reads from node's 'schema_linking' field
- SQLGeneratorAgent writes to node's 'generation' field
- All access is through managers (no direct memory access)
- Node status is updated after SQL generation
""")


MEMORY ARCHITECTURE IN SQL GENERATOR WORKFLOW

The SQLGeneratorAgent interacts with KeyValueMemory through managers:

📁 KeyValueMemory
   │
   ├── 📋 taskContext
   │     └── {taskId, query, databaseId, status, ...}
   │
   ├── 🗃️ databaseSchema  
   │     └── {tables: {tableName: {columns, indexes, ...}}}
   │
   ├── 🌳 queryTree
   │     ├── rootId: "root"
   │     ├── currentNodeId: "root" 
   │     └── nodes: {
   │           "root": {
   │               intent: "...",
   │               status: "sql_generated",
   │               schema_linking: {...},    ← Input for SQLGenerator
   │               generation: {             ← Output from SQLGenerator
   │                   sql: "SELECT ...",
   │                   query_type: "simple",
   │                   explanation: "...",
   │                   considerations: "..."
   │               }
   │           }
   │         }
   │
   └── 📜 nodeHistory
         └── [{operation: "generate_sql", nodeId: "root", ...}]

Key Points:
- SQLG

In [12]:
# Example 2: SQL Generation with Joins
print("\n" + "="*60)
print("EXAMPLE 2: SQL WITH JOINS")
print("="*60)

# Create a new node with join requirements
memory2 = KeyValueMemory()
task_manager2 = TaskContextManager(memory2)
await task_manager2.initialize("test2", "Average SAT scores by county", db_name)

schema_manager2 = DatabaseSchemaManager(memory2)
await schema_manager2.initialize()
await schema_manager2.load_from_schema_reader(schema_reader, db_name)

tree_manager2 = QueryTreeManager(memory2)
node_id2 = await tree_manager2.initialize("Find average SAT scores for schools in each county")
await tree_manager2.set_current_node_id(node_id2)

# Schema linking with joins
schema_linking2 = {
    "selected_tables": [
        {
            "name": "schools",
            "alias": "s",
            "purpose": "To get school location (county)",
            "columns": [
                {"name": "CDSCode", "used_for": "join", "reason": "Primary key"},
                {"name": "County", "used_for": "group_by", "reason": "Group by county"}
            ]
        },
        {
            "name": "satscores",
            "alias": "sat",
            "purpose": "To get SAT score data",
            "columns": [
                {"name": "cds", "used_for": "join", "reason": "Foreign key to schools"},
                {"name": "AvgScrMath", "used_for": "aggregate", "reason": "Math scores to average"},
                {"name": "AvgScrRead", "used_for": "aggregate", "reason": "Reading scores to average"}
            ]
        }
    ],
    "joins": [
        {
            "from_table": "schools",
            "to_table": "satscores", 
            "on": "schools.CDSCode = satscores.cds"
        }
    ]
}

await tree_manager2.update_node(node_id2, {"schema_linking": schema_linking2})

# Generate SQL with joins
agent2 = SQLGeneratorAgent(memory2, llm_config={
    "model_name": "gpt-4o",
    "temperature": 0.1,
    "timeout": 60
}, debug=False)

result2 = await agent2.run("Generate SQL with joins")

# Show results
node2 = await tree_manager2.get_node(node_id2)
if node2 and hasattr(node2, 'generation') and node2.generation:
    print(f"\n✅ Generated SQL with Joins:")
    print(node2.generation.get('sql', 'No SQL'))
    print(f"\nQuery Type: {node2.generation.get('query_type', 'Unknown')}")

2025-05-30 00:44:03,937 - TaskContextManager - INFO - Initialized task context for task test2
2025-05-30 00:44:03,937 - DatabaseSchemaManager - INFO - Initialized empty database schema
2025-05-30 00:44:03,937 - DatabaseSchemaManager - INFO - Initialized empty database schema
2025-05-30 00:44:03,938 - DatabaseSchemaManager - INFO - Added table 'frpm' to schema
2025-05-30 00:44:03,938 - DatabaseSchemaManager - INFO - Added table 'satscores' to schema
2025-05-30 00:44:03,939 - DatabaseSchemaManager - INFO - Added table 'schools' to schema
2025-05-30 00:44:03,939 - DatabaseSchemaManager - INFO - Loaded schema for database 'california_schools' with 3 tables
2025-05-30 00:44:03,939 - QueryTreeManager - INFO - Initialized query tree with root node root
2025-05-30 00:44:03,939 - QueryTreeManager - INFO - Set current node to root
2025-05-30 00:44:03,939 - QueryTreeManager - INFO - Updated node root
2025-05-30 00:44:03,950 - SQLGeneratorAgent - DEBUG - Created AssistantAgent: sql_generator
2025-


EXAMPLE 2: SQL WITH JOINS


2025-05-30 00:44:08,532 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-30 00:44:08,533 - SQLGeneratorAgent - INFO - Raw LLM output: <generation>
  <query_type>join</query_type>
  <sql>
    SELECT 
      s.County, 
      AVG(sat.AvgScrMath) AS AvgMathScore, 
      AVG(sat.AvgScrRead) AS AvgReadScore
    FROM 
      schools s
    JOIN 
      satscores sat ON s.CDSCode = sat.cds
    GROUP BY 
      s.County
  </sql>
  <explanation>
    The query calculates the average SAT Math and Reading scores for schools grouped by county. It joins the 'schools' table with the 'satscores' table using the CDSCode and cds columns, which are the primary and foreign keys respectively. The results are grouped by the County column from the 'schools' table.
  </explanation>
  <considerations>
    - Assumptions made: The CDSCode in 'schools' and cds in 'satscores' are correctly linked as primary and foreign keys.
    - Limitations: This query assumes th


✅ Generated SQL with Joins:
SELECT s.County, AVG(sat.AvgScrMath) AS AvgMathScore, AVG(sat.AvgScrRead) AS AvgReadScore FROM schools s JOIN satscores sat ON s.CDSCode = sat.cds GROUP BY s.County

Query Type: join
