# Schema Extractor Test Notebook

This notebook focuses specifically on testing the SchemaExtractor component of the Text-to-SQL system, which is responsible for extracting relevant tables and columns from database schemas based on natural language queries.

## 1. Setup and Environment Configuration

In [1]:
import os
import sys
import logging
import json
import time
from pprint import pprint

# Add parent directory to path for imports
sys.path.append(os.path.abspath('..'))

# Load environment variables (for API keys)
from dotenv import load_dotenv
load_dotenv()

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

print("Python version:", sys.version)
print("Current working directory:", os.getcwd())

Python version: 3.9.21 (main, Dec 11 2024, 16:24:11) 
[GCC 11.2.0]
Current working directory: /home/norman/work/text-to-sql/MAC-SQL/dspy_sql


## 2. Import Dependencies

In [2]:
# Import necessary modules
import sqlite3
import dspy
from core.utils import load_json_file, parse_json

# Import only the schema-related components
from schema_manager import SchemaManager
from models import create_gemini_lm, GeminiProLM
from agents import SchemaExtractor

# Define paths for the database and schema files
DATA_PATH = "../data/bird"  # Path to the BIRD dataset
DEV_DB_DIRECTORY = os.path.join(DATA_PATH, "dev_databases")  # Database files are here
TABLES_JSON_PATH = os.path.join(DATA_PATH, "dev_tables.json")  # Path to tables.json

dspy.settings.configure(lm=create_gemini_lm())

# NOTE: We use DEV_DB_DIRECTORY as data_path to correctly locate the SQLite files
schema_manager = SchemaManager(DEV_DB_DIRECTORY, TABLES_JSON_PATH)
print("✅ SchemaManager initialized successfully")

# Get the list of database IDs
db_ids = list(schema_manager.db2dbjsons.keys())
print(f"Available database IDs: {db_ids[:5] if len(db_ids) >= 5 else db_ids}")

2025-05-11 10:42:34,738 - httpx - INFO - HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK"


load json file from ../data/bird/dev_tables.json
✅ SchemaManager initialized successfully
Available database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']


In [3]:
# Schema Extractor Signature
schema_extractor_signature = dspy.Signature(
    "db_id, query, db_schema, foreign_keys, evidence -> extracted_schema"
)

class SchemaExtractor(dspy.Module):
    """DSPy module to extract relevant schema from database"""
    
    def __init__(self, lm=None):
        super().__init__()
        self.lm = lm
        # Define a chain of thought predictor for schema extraction
        self.predictor = dspy.ChainOfThought(
            schema_extractor_signature,
            lm=lm
        )
    
    def forward(self, db_id, query, db_schema, foreign_keys, evidence=""):
        """Extract the relevant schema based on the query"""
        result = self.predictor(
            db_id=db_id,
            query=query,
            db_schema=db_schema,
            foreign_keys=foreign_keys, 
            evidence=evidence
        )
        
        # Try to parse JSON from the result
        try:
            extracted_schema = parse_json(result.extracted_schema)
            return dspy.Prediction(extracted_schema=extracted_schema)
        except Exception as e:
            logger.error(f"Error parsing extracted schema: {e}")
            return dspy.Prediction(extracted_schema={})

Checking file paths and database structure...
✅ Found database directory: ../data/bird/dev_databases
✅ Found tables.json at ../data/bird/dev_tables.json
load json file from ../data/bird/dev_tables.json
✅ SchemaManager initialized successfully
Available database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']


## 4. Initialize SchemaExtractor

In [4]:
# Initialize the SchemaExtractor component
try:
    schema_extractor = SchemaExtractor(lm=lm)
    print("✅ SchemaExtractor initialized successfully")
except Exception as e:
    print(f"❌ Error initializing SchemaExtractor: {e}")

✅ SchemaExtractor initialized successfully


## 5. Test SchemaExtractor on Simple Database

In [5]:
# Test SchemaExtractor on a simple database
if 'schema_extractor' in locals() and 'schema_manager' in locals() and len(db_ids) > 0:
    # Find a simple database that doesn't need pruning
    simple_db = None
    for db_id in db_ids:
        if not schema_manager.is_need_prune(db_id):
            simple_db = db_id
            break
    
    if simple_db:
        print(f"Testing SchemaExtractor with simple database ID: {simple_db}")
        
        # Get the database schema
        schema_info = schema_manager.get_db_schema(simple_db)
        
        # Create a simple query based on the database schema
        table_names = list(schema_info["chosen_columns"].keys())
        first_table = table_names[0] if table_names else "unknown"
        query = f"List all data from the {first_table} table"
        
        print(f"Query: '{query}'")

        # Run schema extraction
        extraction_result = schema_extractor(
            db_id=simple_db,
            query=query,
            db_schema=schema_info["schema_str"],
            foreign_keys=schema_info["fk_str"],
            evidence=""
        )
        
        # Display extracted schema
        print("\nExtracted Schema:")
        print(extraction_result.extracted_schema)
        
        # Verify if the extraction includes the table mentioned in the query
        if first_table in extraction_result.extracted_schema:
            print(f"\n✅ Successfully extracted {first_table} table")
        else:
            print(f"\n❌ Failed to extract {first_table} table")

    else:
        print("No simple database found for testing")
else:
    print("SchemaExtractor or SchemaManager not properly initialized")

Testing SchemaExtractor with simple database ID: debit_card_specializing


2025-05-10 23:13:15,606 - models - ERROR - Error in GeminiProLM.__call__: generate_content() got an unexpected keyword argument 'generation_config'
2025-05-10 23:13:15,606 - models - ERROR - Error in GeminiProLM.__call__: generate_content() got an unexpected keyword argument 'generation_config'


Query: 'List all data from the customers table'


ValueError: Expected a JSON object but parsed a <class 'str'>

## 6. Test SchemaExtractor on Complex Database

In [None]:
# Test SchemaExtractor on a complex database
if 'schema_extractor' in locals() and 'schema_manager' in locals() and len(db_ids) > 0:
    # Find a complex database that needs pruning
    complex_db = None
    for db_id in db_ids:
        if schema_manager.is_need_prune(db_id):
            complex_db = db_id
            break
    
    if complex_db:
        print(f"Testing SchemaExtractor with complex database ID: {complex_db}")
        
        # Get the database schema
        schema_info = schema_manager.get_db_schema(complex_db)
        
        # Create a query that might require specific tables
        table_names = list(schema_info["chosen_columns"].keys())
        
        if len(table_names) >= 2:
            table1 = table_names[0]
            table2 = table_names[1]
            query = f"Show data from {table1} related to {table2}"
        else:
            first_table = table_names[0] if table_names else "unknown"
            query = f"Show me information from the {first_table} table"
        
        print(f"Query: '{query}'")
        print(f"Database has {len(table_names)} tables and is complex enough to need pruning")
        
        try:
            # Start timer
            start_time = time.time()
            
            # Run schema extraction
            extraction_result = schema_extractor(
                db_id=complex_db,
                query=query,
                db_schema=schema_info["schema_str"],
                foreign_keys=schema_info["fk_str"],
                evidence=""
            )
            
            # Calculate processing time
            extraction_time = time.time() - start_time
            
            # Display extraction results
            print(f"\nExtraction completed in {extraction_time:.2f} seconds")
            print("\nExtracted Schema:")
            print(extraction_result.extracted_schema)
            
            # Analyze extraction results
            extracted_tables = list(extraction_result.extracted_schema.keys())
            extraction_ratio = len(extracted_tables) / len(table_names) * 100
            
            print(f"\nExtracted {len(extracted_tables)} out of {len(table_names)} tables ({extraction_ratio:.1f}%)")
            
            # Check if tables mentioned in the query were extracted
            if len(table_names) >= 2:
                if table1 in extracted_tables and table2 in extracted_tables:
                    print(f"✅ Successfully extracted both {table1} and {table2} tables mentioned in the query")
                elif table1 in extracted_tables:
                    print(f"⚠️ Extracted {table1} but missed {table2}")
                elif table2 in extracted_tables:
                    print(f"⚠️ Extracted {table2} but missed {table1}")
                else:
                    print(f"❌ Failed to extract either {table1} or {table2}")
            else:
                first_table = table_names[0] if table_names else "unknown"
                if first_table in extracted_tables:
                    print(f"✅ Successfully extracted {first_table} mentioned in the query")
                else:
                    print(f"❌ Failed to extract {first_table}")
        except Exception as e:
            print(f"\n❌ Error running SchemaExtractor: {e}")
    else:
        print("No complex database found for testing")
else:
    print("SchemaExtractor or SchemaManager not properly initialized")

## 7. Test Different Query Types

In [None]:
# Test different query types on a sample database
if 'schema_extractor' in locals() and 'schema_manager' in locals() and len(db_ids) > 0:
    # Use the first database for testing different queries
    sample_db = db_ids[0]
    
    # Get the database schema
    schema_info = schema_manager.get_db_schema(sample_db)
    table_names = list(schema_info["chosen_columns"].keys())
    
    # Skip if there are no tables
    if not table_names:
        print(f"No tables found in database {sample_db}")
    else:
        print(f"Testing different query types on database: {sample_db}")
        print(f"Database has tables: {table_names}")
        
        # Create different types of queries
        first_table = table_names[0]
        
        queries = [
            f"How many records are in the {first_table} table?",  # Count query
            f"What is the average value in {first_table}?",  # Aggregation query
            f"Find the top 5 records in {first_table}",  # Limit query
            "Show all tables in the database",  # Schema exploration
            "Generate a complex report spanning multiple tables"  # Complex query
        ]
        
        for i, query in enumerate(queries, 1):
            print(f"\n===== Query {i}: '{query}' =====")
            
            try:
                # Run schema extraction
                extraction_result = schema_extractor(
                    db_id=sample_db,
                    query=query,
                    db_schema=schema_info["schema_str"],
                    foreign_keys=schema_info["fk_str"],
                    evidence=""
                )
                
                # Display extracted schema
                print("Extracted Schema:")
                print(extraction_result.extracted_schema)
                
                # Count extracted tables
                extracted_tables = list(extraction_result.extracted_schema.keys())
                print(f"Extracted {len(extracted_tables)} tables")
                
                # For the last query (complex), we expect multiple tables
                if i == 5 and len(extracted_tables) > 1:
                    print("✅ Successfully extracted multiple tables for complex query")
                elif i < 4 and first_table in extracted_tables:
                    print(f"✅ Successfully extracted {first_table} table mentioned in the query")
                elif i == 4 and len(extracted_tables) > 0:
                    print("✅ Successfully extracted some tables for schema exploration query")
                else:
                    print("⚠️ Extraction may not be optimal for this query")
            except Exception as e:
                print(f"❌ Error running SchemaExtractor: {e}")
else:
    print("SchemaExtractor or SchemaManager not properly initialized")

## 8. Conclusion

This notebook has tested the SchemaExtractor component, which is responsible for extracting relevant tables and columns from database schemas based on natural language queries. The tests covered:

1. **Simple Database Testing** - How the extractor performs on small, straightforward schemas
2. **Complex Database Testing** - How the extractor performs on large, complex schemas that need pruning
3. **Query Type Testing** - How different types of queries affect extraction results

The SchemaExtractor is a critical component of the Text-to-SQL system, as it reduces the complexity of large database schemas for the subsequent SQL generation steps, which improves both performance and accuracy.