# Schema Manager Test Notebook

This notebook tests the functionality of the `SchemaManager` class in the `dispatcher/schema_manager.py` file.

In [1]:
import os
import sys
import json
from typing import Dict, List, Any, Tuple

# Add the project root to the path to import the code
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from dispatcher.schema_manager import SchemaManager

## 1. Initialize the Schema Manager

We'll initialize the SchemaManager with the Bird dataset.

In [2]:
# Define paths for the BIRD dataset
data_path = "../data/bird"
tables_json_path = "../data/bird/dev_tables.json"
dataset_name = "bird"

# Initialize the Schema Manager with lazy loading
schema_manager = SchemaManager(
    data_path=data_path,
    tables_json_path=tables_json_path,
    dataset_name=dataset_name,
    lazy=True
)

print(f"Schema Manager initialized with dataset: {dataset_name}")

load json file from ../data/bird/dev_tables.json
Schema Manager initialized with dataset: bird


## 2. Load and Examine the Database Mappings

Let's look at the database information loaded from the tables.json file.

In [3]:
# Get the number of databases
num_dbs = len(schema_manager.db2dbjsons)
print(f"Number of databases: {num_dbs}")

# Show available database IDs
db_ids = list(schema_manager.db2dbjsons.keys())
print(f"First 5 database IDs: {db_ids[:5]}")

# Example: Get information for one database
if db_ids:
    example_db_id = db_ids[0]
    example_db_info = schema_manager.db2dbjsons[example_db_id]
    print(f"\nExample database: {example_db_id}")
    print(f"Tables: {example_db_info['table_names']}")
    print(f"Number of tables: {example_db_info['table_count']}")
    print(f"Total columns: {example_db_info['total_column_count']}")
    print(f"Avg columns per table: {example_db_info['avg_column_count']}")

Number of databases: 11
First 5 database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']

Example database: debit_card_specializing
Tables: ['customers', 'gasstations', 'products', 'transactions_1k', 'yearmonth']
Number of tables: 5
Total columns: 21
Avg columns per table: 4


## 3. Load Database Information for a Specific DB

Now let's load detailed information for a specific database.

In [4]:
# Choose a database to examine (using the first one from the list)
if db_ids:
    test_db_id = db_ids[0]
    
    # Load database info
    db_info = schema_manager._load_single_db_info(test_db_id)
    
    print(f"Loaded database info for: {test_db_id}")
    
    # Show tables in this database
    tables = list(db_info["desc_dict"].keys())
    print(f"Tables in {test_db_id}: {tables}")
    
    # Show column info for the first table
    if tables:
        first_table = tables[-1]
        columns = db_info["desc_dict"][first_table]
        print(f"\nColumns in {first_table}:")
        for col_name, full_name, desc in columns:
            print(f"  - {col_name} ({full_name})")
            
        # Show primary keys for this table
        pk_columns = db_info["pk_dict"][first_table]
        print(f"\nPrimary keys in {first_table}: {pk_columns}")
        
        # Show foreign keys for this table
        fk_info = db_info["fk_dict"][first_table]
        print(f"\nForeign keys in {first_table}:")
        for from_col, to_table, to_col in fk_info:
            print(f"  - {from_col} -> {to_table}.{to_col}")

Loaded database info for: debit_card_specializing
Tables in debit_card_specializing: ['customers', 'gasstations', 'products', 'transactions_1k', 'yearmonth']

Columns in yearmonth:
  - CustomerID (Customer ID)
  - Date (Date)
  - Consumption (Consumption)

Primary keys in yearmonth: ['CustomerID', 'Date']

Foreign keys in yearmonth:
  - CustomerID -> customers.CustomerID


## 4. Generate Schema Description in XML Format

Now let's test generating the XML schema description for a database.

In [5]:
if db_ids:
    test_db_id = db_ids[0]
    
    # Create a simple schema selection where we keep all tables/columns
    selected_schema = {}  # Empty means keep all
    
    # Generate the schema description
    schema_xml, fk_infos, chosen_schema = schema_manager.generate_schema_description(
        db_id=test_db_id,
        selected_schema=selected_schema,
        use_gold_schema=False
    )
    
    print(f"Generated XML schema description for {test_db_id}")
    print(f"Number of foreign key relationships: {len(fk_infos)}")
    print(f"Number of tables in chosen schema: {len(chosen_schema)}")
    
    # Display a sample of the XML
    print("\nSample of the XML schema description:")
    schema_sample = '\n'.join(schema_xml.split('\n'))
    print(schema_sample)

Generated XML schema description for debit_card_specializing
Number of foreign key relationships: 1
Number of tables in chosen schema: 5

Sample of the XML schema description:
<database_schema>
  <table name="customers">
    <column name="CustomerID">
      <description>CustomerID</description>
    </column>
    <column name="Segment">
      <description>client segment</description>
      <values>['SME', 'LAM', 'KAM']</values>
    </column>
    <column name="Currency">
      <description>Currency</description>
      <values>['CZK', 'EUR']</values>
    </column>
  </table>
  <table name="gasstations">
    <column name="GasStationID">
      <description>Gas Station ID</description>
    </column>
    <column name="ChainID">
      <description>Chain ID</description>
    </column>
    <column name="Country">
      <description>Country</description>
      <values>['CZE', 'SVK']</values>
    </column>
    <column name="Segment">
      <description>chain segment</description>
      <values>['O

In [6]:
# Check complexity for all databases
complex_dbs = []
simple_dbs = []

for db_id in db_ids:
    is_complex = schema_manager._is_complex_schema(db_id)
    if is_complex:
        complex_dbs.append(db_id)
    else:
        simple_dbs.append(db_id)

print(f"Complex databases: {len(complex_dbs)} out of {len(db_ids)}")
if complex_dbs:
    print(f"Examples of complex DBs: {complex_dbs[:3]}")

print(f"\nSimple databases: {len(simple_dbs)} out of {len(db_ids)}")
if simple_dbs:
    print(f"Examples of simple DBs: {simple_dbs[:3]}")

# Test schema pruning using a complex DB if available
if complex_dbs:
    # Choose a complex database for testing pruning
    test_db_id = complex_dbs[0]
    
    # Load database info if not already loaded
    if test_db_id not in schema_manager.db2infos:
        schema_manager.db2infos[test_db_id] = schema_manager._load_single_db_info(test_db_id)
    
    # Get table names for this database
    tables = list(schema_manager.db2infos[test_db_id]["desc_dict"].keys())
    
    # Create a pruned schema selection
    pruned_schema = {}
    
    # For the first table, drop all columns (keep only 6)
    if len(tables) > 0:
        pruned_schema[tables[0]] = "drop_all"
    
    # For the second table, keep all columns
    if len(tables) > 1:
        pruned_schema[tables[1]] = "keep_all"
    
    # For the third table, select specific columns
    if len(tables) > 2:
        # Get columns for this table
        columns = [name for name, _, _ in schema_manager.db2infos[test_db_id]["desc_dict"][tables[2]]]
        
        # Select half of the columns
        selected_columns = columns[:len(columns)//2]
        pruned_schema[tables[2]] = selected_columns
    
    # Generate schema with pruning
    pruned_xml, pruned_fk, pruned_chosen = schema_manager.generate_schema_description(
        db_id=test_db_id,
        selected_schema=pruned_schema,
        use_gold_schema=False
    )
    
    # Generate full schema for comparison
    full_xml, full_fk, full_chosen = schema_manager.generate_schema_description(
        db_id=test_db_id,
        selected_schema={},
        use_gold_schema=False
    )
    
    # Compare sizes
    pruned_size = len(pruned_xml)
    full_size = len(full_xml)
    reduction = 100 - (pruned_size / full_size * 100)
    
    print(f"\nPruning results for {test_db_id}:")
    print(f"Full schema size: {full_size} characters")
    print(f"Pruned schema size: {pruned_size} characters")
    print(f"Size reduction: {reduction:.2f}%")
    
    # Compare column counts
    full_col_count = sum(len(cols) for cols in full_chosen.values())
    pruned_col_count = sum(len(cols) for cols in pruned_chosen.values())
    
    print(f"\nFull schema column count: {full_col_count}")
    print(f"Pruned schema column count: {pruned_col_count}")
    print(f"Column reduction: {100 - (pruned_col_count / full_col_count * 100):.2f}%")

Complex databases: 9 out of 11
Examples of complex DBs: ['financial', 'formula_1', 'california_schools']

Simple databases: 2 out of 11
Examples of simple DBs: ['debit_card_specializing', 'toxicology']

Pruning results for financial:
Full schema size: 8398 characters
Pruned schema size: 8271 characters
Size reduction: 1.51%

Full schema column count: 55
Pruned schema column count: 54
Column reduction: 1.82%


In [7]:
# Test the method to determine if a schema is complex and needs pruning.
    
if db_ids and complex_dbs:
    # Choose a complex database for testing pruning
    test_db_id = complex_dbs[0]
    
    # Load database info if not already loaded
    if test_db_id not in schema_manager.db2infos:
        schema_manager.db2infos[test_db_id] = schema_manager._load_single_db_info(test_db_id)
    
    # Get table names for this database
    tables = list(schema_manager.db2infos[test_db_id]["desc_dict"].keys())
    
    # Create a pruned schema selection
    pruned_schema = {}
    
    # For the first table, drop all columns (keep only 6)
    if len(tables) > 0:
        pruned_schema[tables[0]] = "drop_all"
    
    # For the second table, keep all columns
    if len(tables) > 1:
        pruned_schema[tables[1]] = "keep_all"
    
    # For the third table, select specific columns
    if len(tables) > 2:
        # Get columns for this table
        columns = [name for name, _, _ in schema_manager.db2infos[test_db_id]["desc_dict"][tables[2]]]
        
        # Select half of the columns
        selected_columns = columns[:len(columns)//2]
        pruned_schema[tables[2]] = selected_columns
    
    # Generate schema with pruning
    pruned_xml, pruned_fk, pruned_chosen = schema_manager.generate_schema_description(
        db_id=test_db_id,
        selected_schema=pruned_schema,
        use_gold_schema=False
    )
    
    # Generate full schema for comparison
    full_xml, full_fk, full_chosen = schema_manager.generate_schema_description(
        db_id=test_db_id,
        selected_schema={},
        use_gold_schema=False
    )
    
    # Compare sizes
    pruned_size = len(pruned_xml)
    full_size = len(full_xml)
    reduction = 100 - (pruned_size / full_size * 100)
    
    print(f"Pruning results for {test_db_id}:")
    print(f"Full schema size: {full_size} characters")
    print(f"Pruned schema size: {pruned_size} characters")
    print(f"Size reduction: {reduction:.2f}%")
    
    # Compare column counts
    full_col_count = sum(len(cols) for cols in full_chosen.values())
    pruned_col_count = sum(len(cols) for cols in pruned_chosen.values())
    
    print(f"\nFull schema column count: {full_col_count}")
    print(f"Pruned schema column count: {pruned_col_count}")
    print(f"Column reduction: {100 - (pruned_col_count / full_col_count * 100):.2f}%")

Pruning results for financial:
Full schema size: 8398 characters
Pruned schema size: 8271 characters
Size reduction: 1.51%

Full schema column count: 55
Pruned schema column count: 54
Column reduction: 1.82%


In [8]:
# Load the gold schema if available
gold_schema_path = os.path.join(data_path, "dev_gold_schema.json")

if os.path.exists(gold_schema_path):
    try:
        with open(gold_schema_path, 'r') as f:
            gold_schema_data = json.load(f)
            
        print(f"Loaded gold schema from {gold_schema_path}")
        print(f"Gold schema has information for {len(gold_schema_data)} databases")
        
        # Test with a database from the gold schema
        test_db_ids = list(gold_schema_data.keys())
        if test_db_ids:
            test_db_id = test_db_ids[0]
            if test_db_id in schema_manager.db2dbjsons:
                gold_tables = gold_schema_data[test_db_id]
                
                print(f"\nGold schema for {test_db_id} includes {len(gold_tables)} tables")
                
                # Generate schema using gold tables
                gold_xml, gold_fk, gold_chosen = schema_manager.generate_schema_description(
                    db_id=test_db_id,
                    selected_schema=gold_tables,
                    use_gold_schema=True
                )
                
                print(f"\nGenerated schema with {len(gold_chosen)} tables using gold schema")
                
                # Display a sample of the XML
                print("\nSample of the gold-based XML schema description:")
                gold_sample = '\n'.join(gold_xml.split('\n')[:20]) + "\n..."
                print(gold_sample)
    except Exception as e:
        print(f"Error loading gold schema: {e}")
else:
    print(f"Gold schema file not found at {gold_schema_path}")

Loaded gold schema from ../data/bird/dev_gold_schema.json
Gold schema has information for 1533 databases
