In [1]:
# Schema Management Agent Example
"""
This notebook demonstrates creating an AI agent that uses the SchemaManager
to provide database schema information in an interactive way.
"""

from dotenv import load_dotenv

load_dotenv()

True

In [2]:
import json
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from schema_manager import SchemaManager
from typing import List, Dict, Any

In [3]:
# Initialize the SchemaManager
# Update these paths to match your environment
data_path = "../data/bird"
tables_json_path = f"{data_path}/dev_tables.json"
dataset_name = "bird"

schema_manager = SchemaManager(
    data_path=data_path,
    tables_json_path=tables_json_path,
    dataset_name=dataset_name,
    lazy=True  # Use lazy loading for performance
)

load json file from ../data/bird/dev_tables.json


In [4]:
# Define schema-related functions that the agent can use

async def list_databases() -> str:
    """List all available databases in the dataset."""
    db_ids = list(schema_manager.db2dbjsons.keys())
    return f"Available databases: {json.dumps(db_ids[:10])}..." if len(db_ids) > 10 else f"Available databases: {json.dumps(db_ids)}"

async def get_database_info(db_id: str) -> str:
    """Get basic information about a specific database."""
    if db_id not in schema_manager.db2dbjsons:
        return f"Database '{db_id}' not found."
    
    db_info = schema_manager.db2dbjsons[db_id]
    summary = {
        "db_id": db_id,
        "table_count": db_info['table_count'],
        "total_columns": db_info['total_column_count'],
        "avg_columns_per_table": db_info['avg_column_count'],
        "table_names": db_info['table_names_original']
    }
    return json.dumps(summary, indent=2)

async def get_table_schema(db_id: str, table_name: str) -> str:
    """Get detailed schema information for a specific table."""
    if db_id not in schema_manager.db2dbjsons:
        return f"Database '{db_id}' not found."
    
    # Load database info if not already loaded
    if db_id not in schema_manager.db2infos:
        schema_manager.db2infos[db_id] = schema_manager._load_single_db_info(db_id)
    
    db_info = schema_manager.db2infos[db_id]
    
    if table_name not in db_info['desc_dict']:
        return f"Table '{table_name}' not found in database '{db_id}'."
    
    # Get table information
    columns_desc = db_info['desc_dict'][table_name]
    columns_val = db_info['value_dict'][table_name]
    primary_keys = db_info['pk_dict'][table_name]
    foreign_keys = db_info['fk_dict'][table_name]
    
    result = {
        "table": table_name,
        "columns": [],
        "primary_keys": primary_keys,
        "foreign_keys": []
    }
    
    # Add column information
    for (col_name, full_col_name, _), (_, values_str) in zip(columns_desc, columns_val):
        col_info = {
            "name": col_name,
            "description": full_col_name,
            "sample_values": values_str
        }
        result["columns"].append(col_info)
    
    # Add foreign key information
    for from_col, to_table, to_col in foreign_keys:
        fk_info = {
            "from_column": from_col,
            "to_table": to_table,
            "to_column": to_col
        }
        result["foreign_keys"].append(fk_info)
    
    return json.dumps(result, indent=2)

async def generate_schema_xml(db_id: str, table_names: List[str] = None) -> str:
    """Generate XML schema description for a database or specific tables."""
    if db_id not in schema_manager.db2dbjsons:
        return f"Database '{db_id}' not found."
    
    # Create selected schema dict
    if table_names:
        selected_schema = {table: "keep_all" for table in table_names}
    else:
        # Include all tables
        db_info = schema_manager.db2dbjsons[db_id]
        selected_schema = {table: "keep_all" for table in db_info['table_names_original']}
    
    # Generate schema XML
    schema_xml, fk_infos, chosen_schema = schema_manager.generate_schema_description(
        db_id, selected_schema, use_gold_schema=False
    )
    
    return schema_xml

In [5]:
# Initialize the OpenAI model client
model_client = OpenAIChatCompletionClient(
    model="gpt-4o",
)

# Define the schema agent with the SchemaManager tools
schema_agent = AssistantAgent(
    name="schema_agent",
    model_client=model_client,
    tools=[
        list_databases,
        get_database_info,
        get_table_schema,
        generate_schema_xml
    ],
    system_message="""You are a database schema expert assistant. 
    You help users understand database schemas by providing:
    - Information about available databases
    - Details about specific tables and their relationships
    - Sample values for columns when available
    - Schema descriptions in XML format
    
    Use the provided tools to access schema information from the SchemaManager.""",
    reflect_on_tool_use=True,
    model_client_stream=True,
)

In [6]:
# Example: Ask about available databases
await Console(schema_agent.run_stream(task="What databases are available?"))

---------- TextMessage (user) ----------
What databases are available?
---------- ToolCallRequestEvent (schema_agent) ----------
[FunctionCall(id='call_l11hJKse77QOkmBKbUvwoi78', arguments='{}', name='list_databases')]
---------- ToolCallExecutionEvent (schema_agent) ----------
[FunctionExecutionResult(content='Available databases: ["debit_card_specializing", "financial", "formula_1", "california_schools", "card_games", "european_football_2", "thrombosis_prediction", "toxicology", "student_club", "superhero"]...', name='list_databases', call_id='call_l11hJKse77QOkmBKbUvwoi78', is_error=False)]
---------- ModelClientStreamingChunkEvent (schema_agent) ----------
The available databases are:

1. Debit Card Specializing
2. Financial
3. Formula 1
4. California Schools
5. Card Games
6. European Football 2
7. Thrombosis Prediction
8. Toxicology
9. Student Club
10. Superhero

Let me know if you need information about a specific database or table!


TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content='What databases are available?', type='TextMessage'), ToolCallRequestEvent(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_l11hJKse77QOkmBKbUvwoi78', arguments='{}', name='list_databases')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='schema_agent', models_usage=None, metadata={}, content=[FunctionExecutionResult(content='Available databases: ["debit_card_specializing", "financial", "formula_1", "california_schools", "card_games", "european_football_2", "thrombosis_prediction", "toxicology", "student_club", "superhero"]...', name='list_databases', call_id='call_l11hJKse77QOkmBKbUvwoi78', is_error=False)], type='ToolCallExecutionEvent'), TextMessage(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content='The available databases are:\n\n1. Debit Car

In [7]:
# Example: Get information about a specific database
await Console(schema_agent.run_stream(task="Tell me about the 'california_schools' database"))

---------- TextMessage (user) ----------
Tell me about the 'california_schools' database
---------- ToolCallRequestEvent (schema_agent) ----------
[FunctionCall(id='call_e8tcdsQfypHR6NgyIK9InQun', arguments='{"db_id":"california_schools"}', name='get_database_info')]
---------- ToolCallExecutionEvent (schema_agent) ----------
[FunctionExecutionResult(content='{\n  "db_id": "california_schools",\n  "table_count": 3,\n  "total_columns": 89,\n  "avg_columns_per_table": 29,\n  "table_names": [\n    "frpm",\n    "satscores",\n    "schools"\n  ]\n}', name='get_database_info', call_id='call_e8tcdsQfypHR6NgyIK9InQun', is_error=False)]
---------- ModelClientStreamingChunkEvent (schema_agent) ----------
The `california_schools` database contains the following information:

- **Total Tables:** 3
- **Total Columns:** 89
- **Average Columns per Table:** 29

**Tables:**
1. **frpm**: This table likely includes information related to the Free and Reduced Price Meals program.
2. **satscores**: This tab

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Tell me about the 'california_schools' database", type='TextMessage'), ToolCallRequestEvent(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_e8tcdsQfypHR6NgyIK9InQun', arguments='{"db_id":"california_schools"}', name='get_database_info')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='schema_agent', models_usage=None, metadata={}, content=[FunctionExecutionResult(content='{\n  "db_id": "california_schools",\n  "table_count": 3,\n  "total_columns": 89,\n  "avg_columns_per_table": 29,\n  "table_names": [\n    "frpm",\n    "satscores",\n    "schools"\n  ]\n}', name='get_database_info', call_id='call_e8tcdsQfypHR6NgyIK9InQun', is_error=False)], type='ToolCallExecutionEvent'), TextMessage(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content='The `calif

In [8]:
# Example: Get detailed schema for a specific table
await Console(schema_agent.run_stream(
    task="Show me the schema for the 'schools' table in the 'california_schools' database"
))

---------- TextMessage (user) ----------
Show me the schema for the 'schools' table in the 'california_schools' database
---------- ToolCallRequestEvent (schema_agent) ----------
[FunctionCall(id='call_UsMR8tPfmuzw0TxUsB34W5IG', arguments='{"db_id":"california_schools","table_name":"schools"}', name='get_table_schema')]
---------- ToolCallExecutionEvent (schema_agent) ----------
[FunctionExecutionResult(content='{\n  "table": "schools",\n  "columns": [\n    {\n      "name": "CDSCode",\n      "description": "CDSCode",\n      "sample_values": ""\n    },\n    {\n      "name": "NCESDist",\n      "description": "National Center for Educational Statistics school district identification number",\n      "sample_values": "[None, \'0622710\', \'0634320\', \'0628050\', \'0634410\', \'0614550\', \'0633840\']"\n    },\n    {\n      "name": "NCESSchool",\n      "description": "National Center for Educational Statistics school identification number",\n      "sample_values": "[None, \'12271\', \'13785

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Show me the schema for the 'schools' table in the 'california_schools' database", type='TextMessage'), ToolCallRequestEvent(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_UsMR8tPfmuzw0TxUsB34W5IG', arguments='{"db_id":"california_schools","table_name":"schools"}', name='get_table_schema')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='schema_agent', models_usage=None, metadata={}, content=[FunctionExecutionResult(content='{\n  "table": "schools",\n  "columns": [\n    {\n      "name": "CDSCode",\n      "description": "CDSCode",\n      "sample_values": ""\n    },\n    {\n      "name": "NCESDist",\n      "description": "National Center for Educational Statistics school district identification number",\n      "sample_values": "[None, \'0622710\', \'0634320\', \'0628050\', \'0634410\', \'0614550\', \'063

In [9]:
# Example: Generate XML schema for specific tables
await Console(schema_agent.run_stream(
    task="Generate an XML schema description for the 'schools' and 'satscores' tables in the 'california_schools' database"
))

---------- TextMessage (user) ----------
Generate an XML schema description for the 'schools' and 'satscores' tables in the 'california_schools' database
---------- ToolCallRequestEvent (schema_agent) ----------
[FunctionCall(id='call_YswGWyVVYZgfl8zaLippQqbz', arguments='{"db_id":"california_schools","table_names":["schools","satscores"]}', name='generate_schema_xml')]
---------- ToolCallExecutionEvent (schema_agent) ----------
[FunctionExecutionResult(content='<database_schema>\n  <table name="frpm">\n    <column name="CDSCode">\n      <description>CDSCode</description>\n    </column>\n    <column name="Academic Year">\n      <description>Academic Year</description>\n      <values>[\'2014-2015\']</values>\n    </column>\n    <column name="County Code">\n      <description>County Code</description>\n      <values>[\'19\', \'37\', \'30\', \'36\', \'33\', \'43\']</values>\n    </column>\n    <column name="District Code">\n      <description>District Code</description>\n    </column>\n  

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Generate an XML schema description for the 'schools' and 'satscores' tables in the 'california_schools' database", type='TextMessage'), ToolCallRequestEvent(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_YswGWyVVYZgfl8zaLippQqbz', arguments='{"db_id":"california_schools","table_names":["schools","satscores"]}', name='generate_schema_xml')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='schema_agent', models_usage=None, metadata={}, content=[FunctionExecutionResult(content='<database_schema>\n  <table name="frpm">\n    <column name="CDSCode">\n      <description>CDSCode</description>\n    </column>\n    <column name="Academic Year">\n      <description>Academic Year</description>\n      <values>[\'2014-2015\']</values>\n    </column>\n    <column name="County Code">\n      <description>County Code</de

In [10]:
# Example: Complex query about relationships
await Console(schema_agent.run_stream(
    task="What are the relationships between tables in the 'california_schools' database? Which tables are connected?"
))

---------- TextMessage (user) ----------
What are the relationships between tables in the 'california_schools' database? Which tables are connected?
---------- ModelClientStreamingChunkEvent (schema_agent) ----------
In the `california_schools` database, the relationships between the tables are defined through foreign keys. Here are the connections:

1. **`schools` Table:**
   - The `CDSCode` column in the `schools` table acts as a primary key.

2. **`satscores` Table:**
   - The `cds` column in the `satscores` table is a foreign key that references the `CDSCode` in the `schools` table. This indicates that each record in the `satscores` table is associated with a specific school in the `schools` table.

3. **`frpm` Table:**
   - Likewise, the `CDSCode` column in the `frpm` table is a foreign key that references the `CDSCode` in the `schools` table. This establishes a relationship between school records and their corresponding free and reduced-price meal data.

In summary, both the `sat

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="What are the relationships between tables in the 'california_schools' database? Which tables are connected?", type='TextMessage'), TextMessage(source='schema_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content='In the `california_schools` database, the relationships between the tables are defined through foreign keys. Here are the connections:\n\n1. **`schools` Table:**\n   - The `CDSCode` column in the `schools` table acts as a primary key.\n\n2. **`satscores` Table:**\n   - The `cds` column in the `satscores` table is a foreign key that references the `CDSCode` in the `schools` table. This indicates that each record in the `satscores` table is associated with a specific school in the `schools` table.\n\n3. **`frpm` Table:**\n   - Likewise, the `CDSCode` column in the `frpm` table is a foreign key that references the `CDSCode` in the `schools` table. This 

In [15]:
# Advanced Example: Create a schema analysis agent that can answer complex questions

class SchemaAnalysisAgent(AssistantAgent):
    """An advanced schema analysis agent with additional capabilities."""
    
    def __init__(self, schema_manager: SchemaManager, **kwargs):
        self.schema_manager = schema_manager
        
        # Define advanced analysis functions
        async def analyze_database_complexity(db_id: str) -> str:
            """Analyze the complexity of a database schema."""
            if db_id not in self.schema_manager.db2dbjsons:
                return f"Database '{db_id}' not found."
            
            db_info = self.schema_manager.db2dbjsons[db_id]
            
            # Determine complexity
            is_complex = self.schema_manager._is_complex_schema(db_id)
            
            analysis = {
                "db_id": db_id,
                "is_complex": is_complex,
                "metrics": {
                    "table_count": db_info['table_count'],
                    "avg_columns_per_table": db_info['avg_column_count'],
                    "max_columns_in_table": db_info['max_column_count'],
                    "total_columns": db_info['total_column_count']
                },
                "complexity_reasoning": "Complex" if is_complex else "Simple",
                "recommendation": "Schema pruning recommended" if is_complex else "No pruning needed"
            }
            
            return json.dumps(analysis, indent=2)
        
        async def find_tables_by_keyword(keyword: str) -> str:
            """Find tables that contain a specific keyword in their name or columns."""
            results = []
            
            for db_id, db_info in self.schema_manager.db2dbjsons.items():
                matching_tables = []
                
                # Check table names
                for table_name in db_info['table_names_original']:
                    if keyword.lower() in table_name.lower():
                        matching_tables.append({
                            "table": table_name,
                            "match_type": "table_name"
                        })
                
                # Check column names
                for idx, (tb_idx, col_name) in enumerate(db_info['column_names_original']):
                    if keyword.lower() in col_name.lower():
                        if tb_idx >= 0:
                            table_name = db_info['table_names_original'][tb_idx]
                            matching_tables.append({
                                "table": table_name,
                                "column": col_name,
                                "match_type": "column_name"
                            })
                
                if matching_tables:
                    results.append({
                        "database": db_id,
                        "matches": matching_tables
                    })
            
            if not results:
                return f"No tables or columns found containing '{keyword}'"
            
            return json.dumps(results[:5], indent=2)  # Limit to 5 databases
        
        # Initialize with enhanced tools
        super().__init__(
            tools=[
                list_databases,
                get_database_info,
                get_table_schema,
                generate_schema_xml,
                analyze_database_complexity,
                find_tables_by_keyword
            ],
            system_message="""You are an advanced database schema analysis expert. 
            You can:
            - Analyze database complexity and provide recommendations
            - Find tables and columns by keywords
            - Identify relationships between tables
            - Generate detailed schema documentation
            - Provide insights about database design patterns
            
            Use the available tools to provide comprehensive schema analysis.""",
            **kwargs
        )

# Create the advanced agent
schema_analysis_agent = SchemaAnalysisAgent(
    schema_manager=schema_manager,
    name="schema_analysis_agent",
    model_client=model_client,
    reflect_on_tool_use=True,
    model_client_stream=True,
)

In [16]:
# Test the advanced agent: Analyze database complexity
await Console(schema_analysis_agent.run_stream(
    task="Analyze the complexity of the 'california_schools' database"
))

---------- TextMessage (user) ----------
Analyze the complexity of the 'california_schools' database
---------- ToolCallRequestEvent (schema_analysis_agent) ----------
[FunctionCall(id='call_58wgmtJIMThanJgzI1ytMTCR', arguments='{"db_id":"california_schools"}', name='analyze_database_complexity')]
---------- ToolCallExecutionEvent (schema_analysis_agent) ----------
[FunctionExecutionResult(content='{\n  "db_id": "california_schools",\n  "is_complex": true,\n  "metrics": {\n    "table_count": 3,\n    "avg_columns_per_table": 29,\n    "max_columns_in_table": 49,\n    "total_columns": 89\n  },\n  "complexity_reasoning": "Complex",\n  "recommendation": "Schema pruning recommended"\n}', name='analyze_database_complexity', call_id='call_58wgmtJIMThanJgzI1ytMTCR', is_error=False)]
---------- ModelClientStreamingChunkEvent (schema_analysis_agent) ----------
The 'california_schools' database is determined to be complex. Here are the details of the complexity analysis:

- **Total Tables**: 3
- *

TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Analyze the complexity of the 'california_schools' database", type='TextMessage'), ToolCallRequestEvent(source='schema_analysis_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_58wgmtJIMThanJgzI1ytMTCR', arguments='{"db_id":"california_schools"}', name='analyze_database_complexity')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='schema_analysis_agent', models_usage=None, metadata={}, content=[FunctionExecutionResult(content='{\n  "db_id": "california_schools",\n  "is_complex": true,\n  "metrics": {\n    "table_count": 3,\n    "avg_columns_per_table": 29,\n    "max_columns_in_table": 49,\n    "total_columns": 89\n  },\n  "complexity_reasoning": "Complex",\n  "recommendation": "Schema pruning recommended"\n}', name='analyze_database_complexity', call_id='call_58wgmtJIMThanJgzI1ytMTCR', is_error=False)], type='ToolCall

In [17]:
# Test the advanced agent: Find tables by keyword
await Console(schema_analysis_agent.run_stream(
    task="Find all tables and columns that contain the word 'student' across all databases"
))

---------- TextMessage (user) ----------
Find all tables and columns that contain the word 'student' across all databases
---------- ToolCallRequestEvent (schema_analysis_agent) ----------
[FunctionCall(id='call_gNPp5FHz0emnGesc30IVD4hz', arguments='{"keyword":"student"}', name='find_tables_by_keyword')]
---------- ToolCallExecutionEvent (schema_analysis_agent) ----------
[FunctionExecutionResult(content="No tables or columns found containing 'student'", name='find_tables_by_keyword', call_id='call_gNPp5FHz0emnGesc30IVD4hz', is_error=False)]
---------- ModelClientStreamingChunkEvent (schema_analysis_agent) ----------
There are no tables or columns across the available databases that contain the word 'student'. If you have specific tables to search within or additional keywords, please let me know!


TaskResult(messages=[TextMessage(source='user', models_usage=None, metadata={}, content="Find all tables and columns that contain the word 'student' across all databases", type='TextMessage'), ToolCallRequestEvent(source='schema_analysis_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content=[FunctionCall(id='call_gNPp5FHz0emnGesc30IVD4hz', arguments='{"keyword":"student"}', name='find_tables_by_keyword')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='schema_analysis_agent', models_usage=None, metadata={}, content=[FunctionExecutionResult(content="No tables or columns found containing 'student'", name='find_tables_by_keyword', call_id='call_gNPp5FHz0emnGesc30IVD4hz', is_error=False)], type='ToolCallExecutionEvent'), TextMessage(source='schema_analysis_agent', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), metadata={}, content="There are no tables or columns across the available databases that contain the word 'stu

In [18]:
# Close the model client
await model_client.close()