# DSPy Text-to-SQL Test Notebook

This notebook demonstrates how to use the DSPy implementation of the Text-to-SQL system. It tests the SchemaManager, the individual agent components (SchemaExtractor, SqlDecomposer, SqlValidator), and the complete Text-to-SQL pipeline.

## 1. Setup and Environment Configuration

In [1]:
import os
import sys
import logging
import json
import time
from pprint import pprint

# Add parent directory to path for imports
sys.path.append(os.path.abspath('..'))

# Load environment variables (for API keys)
from dotenv import load_dotenv
load_dotenv()

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

print("Python version:", sys.version)
print("Current working directory:", os.getcwd())

Python version: 3.9.21 (main, Dec 11 2024, 16:24:11) 
[GCC 11.2.0]
Current working directory: /home/norman/work/text-to-sql/MAC-SQL/dspy_sql


## 2. Import Dependencies and Configure DSPy

In [2]:
# Import necessary modules
import sqlite3
import dspy
from core.utils import load_json_file

# Import our system components
from schema_manager import SchemaManager
from models import create_gemini_lm, GeminiProLM
from agents import SchemaExtractor, SqlDecomposer, SqlValidator
from text_to_sql import DSPyTextToSQL

# Verify DSPy installation
print("DSPy version:", dspy.__version__)

# Check if API key is set for Gemini
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY:
    print("✅ GOOGLE_API_KEY is set")
else:
    print("❌ GOOGLE_API_KEY is not set, LLM functionality will be limited")

# Create a language model instance to test configuration
try:
    lm = create_gemini_lm()
    print("✅ Successfully created Gemini LM instance")
except Exception as e:
    print(f"❌ Error creating Gemini LM instance: {e}")

2025-05-10 22:32:39,916 - httpx - INFO - HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK"


DSPy version: 2.6.23
✅ GOOGLE_API_KEY is set
✅ Successfully created Gemini LM instance


## 3. SchemaManager Testing

In [3]:
# Define paths for the database and schema files
DATA_PATH = "../data/bird"  # Path to the BIRD dataset
DEV_DB_DIRECTORY = os.path.join(DATA_PATH, "dev_databases")  # Database files are here
TABLES_JSON_PATH = os.path.join(DATA_PATH, "dev_tables.json")  # Path to tables.json

# Verify that the required files and directories exist
print("Checking file paths and database structure...")
if os.path.exists(DATA_PATH):
    print(f"✅ Found data path: {DATA_PATH}")
else:
    print(f"❌ Data path not found: {DATA_PATH}")
    
if os.path.exists(DEV_DB_DIRECTORY):
    print(f"✅ Found database directory: {DEV_DB_DIRECTORY}")
else:
    print(f"❌ Database directory not found: {DEV_DB_DIRECTORY}")
    
if os.path.exists(TABLES_JSON_PATH):
    print(f"✅ Found tables.json at {TABLES_JSON_PATH}")
else:
    print(f"❌ tables.json not found at {TABLES_JSON_PATH}")
    
# Check the number of database entries in tables.json
try:
    tables_data = load_json_file(TABLES_JSON_PATH)
    print(f"Number of database entries in tables.json: {len(tables_data)}")
    # Display some sample database IDs
    sample_db_ids = [entry["db_id"] for entry in tables_data[:5]]
    print(f"Sample database IDs: {sample_db_ids}")
except Exception as e:
    print(f"❌ Error loading tables.json: {e}")

Checking file paths and database structure...
✅ Found data path: ../data/bird
✅ Found database directory: ../data/bird/dev_databases
✅ Found tables.json at ../data/bird/dev_tables.json
load json file from ../data/bird/dev_tables.json
Number of database entries in tables.json: 11
Sample database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']
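The checks above only confirm that the top-level paths exist. A further consistency check is that every `db_id` listed in `dev_tables.json` has a matching SQLite file on disk. The sketch below assumes the standard BIRD layout `dev_databases/<db_id>/<db_id>.sqlite`; both that layout and the helper name `find_missing_databases` are assumptions, not part of this project's code.

```python
import os
from typing import Iterable, List


def find_missing_databases(db_directory: str, db_ids: Iterable[str]) -> List[str]:
    """Return the db_ids whose SQLite file is not found.

    Assumes the BIRD layout <db_directory>/<db_id>/<db_id>.sqlite.
    """
    missing = []
    for db_id in db_ids:
        db_file = os.path.join(db_directory, db_id, f"{db_id}.sqlite")
        if not os.path.isfile(db_file):
            missing.append(db_id)
    return missing


# Usage in this notebook (not run here):
# missing = find_missing_databases(DEV_DB_DIRECTORY, [e["db_id"] for e in tables_data])
# print("Missing databases:", missing or "none")
```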


In [4]:
# Initialize the SchemaManager with the correct path to the dev_databases directory
try:
    # NOTE: We use DEV_DB_DIRECTORY as data_path to correctly locate the SQLite files
    schema_manager = SchemaManager(DEV_DB_DIRECTORY, TABLES_JSON_PATH)
    print("✅ SchemaManager initialized successfully")
    
    # Print the number of database schemas loaded
    print(f"Number of database schemas loaded: {len(schema_manager.db2dbjsons)}")
    
    # Get the list of database IDs
    db_ids = list(schema_manager.db2dbjsons.keys())
    print(f"Sample database IDs: {db_ids[:5]}")
    
    # Verify that the cache is initially empty
    print(f"Initial cache size (db2infos): {len(schema_manager.db2infos)}")
except Exception as e:
    print(f"❌ Error initializing SchemaManager: {e}")

load json file from ../data/bird/dev_tables.json
✅ SchemaManager initialized successfully
Number of database schemas loaded: 11
Sample database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']
Initial cache size (db2infos): 0
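The empty `db2infos` above behaves like a lazily filled cache: cell 6 later checks whether a database appears in it after the first `get_db_schema` call. A minimal sketch of that pattern, assuming a dict keyed by `db_id` and a hypothetical `load_fn` that performs the expensive SQLite read; the class `LazySchemaCache` is illustrative and the real SchemaManager internals may differ.

```python
from typing import Any, Callable, Dict


class LazySchemaCache:
    """Hypothetical illustration of a per-database cache like db2infos."""

    def __init__(self, load_fn: Callable[[str], Dict[str, Any]]):
        self._load_fn = load_fn  # expensive loader, e.g. reading an SQLite file
        self.db2infos: Dict[str, Dict[str, Any]] = {}

    def get(self, db_id: str) -> Dict[str, Any]:
        # Call the loader only the first time a db_id is requested.
        if db_id not in self.db2infos:
            self.db2infos[db_id] = self._load_fn(db_id)
        return self.db2infos[db_id]


calls = []


def fake_loader(db_id: str) -> Dict[str, Any]:
    calls.append(db_id)  # record each real load
    return {"db_id": db_id}


cache = LazySchemaCache(fake_loader)
cache.get("debit_card_specializing")
cache.get("debit_card_specializing")
print(len(cache.db2infos), calls)  # 1 ['debit_card_specializing']
```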


In [5]:
# Test database complexity analysis
if len(db_ids) > 0:
    print("Testing database complexity analysis...")
    
    # Test on the first 3 databases (slicing returns all of them if there are fewer)
    test_dbs = db_ids[:3]
    
    for db_id in test_dbs:
        try:
            db_dict = schema_manager.db2dbjsons[db_id]
            avg_column_count = db_dict.get('avg_column_count', 0)
            total_column_count = db_dict.get('total_column_count', 0)
            table_count = db_dict.get('table_count', 0)
            
            need_prune = schema_manager.is_need_prune(db_id)
            
            print(f"\nDatabase: {db_id}")
            print(f"  Tables: {table_count}")
            print(f"  Total columns: {total_column_count}")
            print(f"  Average columns per table: {avg_column_count}")
            print(f"  Complexity: {'Complex, pruning needed' if need_prune else 'Simple, no pruning needed'}")
        except Exception as e:
            print(f"❌ Error analyzing {db_id}: {e}")
else:
    print("No database IDs available for testing")

Testing database complexity analysis...

Database: debit_card_specializing
  Tables: 5
  Total columns: 21
  Average columns per table: 4
  Complexity: Simple, no pruning needed

Database: financial
  Tables: 8
  Total columns: 55
  Average columns per table: 6
  Complexity: Complex, pruning needed

Database: formula_1
  Tables: 13
  Total columns: 94
  Average columns per table: 7
  Complexity: Complex, pruning needed
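The three results above are consistent with a simple threshold rule on column counts. The sketch assumes the same thresholds as the upstream MAC-SQL Selector (no pruning when the average column count per table is at most 6 and the total is at most 30); whether this notebook's `is_need_prune` uses exactly these values is an assumption. The average is truncated to an integer, as the printed values suggest (55 / 8 is shown as 6). The function name `needs_prune` is hypothetical.

```python
def needs_prune(table_count: int, total_column_count: int,
                max_avg: int = 6, max_total: int = 30) -> bool:
    """Sketch of a column-count pruning rule (thresholds are assumptions)."""
    # Integer average, matching the truncated values printed above.
    avg_column_count = total_column_count // table_count if table_count else 0
    return not (avg_column_count <= max_avg and total_column_count <= max_total)


# Counts taken from the output above.
for db_id, tables, columns in [("debit_card_specializing", 5, 21),
                               ("financial", 8, 55),
                               ("formula_1", 13, 94)]:
    print(db_id, needs_prune(tables, columns))
# debit_card_specializing False
# financial True
# formula_1 True
```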


In [6]:
# Test retrieving schema for a simple database
if len(db_ids) > 0:
    # Find a simple database that doesn't need pruning
    simple_db = None
    for db_id in db_ids:
        if not schema_manager.is_need_prune(db_id):
            simple_db = db_id
            break
    
    if simple_db:
        print(f"Testing schema retrieval with database ID: {simple_db}")
        
        try:
            # Get the database schema
            schema_info = schema_manager.get_db_schema(simple_db)
            
            # Print some statistics about the schema
            chosen_columns = schema_info["chosen_columns"]
            print(f"\nTables in schema: {len(chosen_columns)}")
            
            # Print table names and column counts
            print("\nTable Details:")
            for table, columns in chosen_columns.items():
                print(f"  - {table}: {len(columns)} columns")
                # Print first few columns
                if len(columns) > 0:
                    print(f"    Sample columns: {columns[:3]}...")
            
            # Check for foreign keys
            fk_str = schema_info["fk_str"]
            if fk_str:
                # One foreign key per line
                fk_lines = fk_str.split("\n")
                print(f"\nForeign Key Count: {len(fk_lines)}")
                print("First foreign key:" if len(fk_lines) > 1 else "Foreign key:")
                print(f"  {fk_lines[0]}")
            else:
                print("\nNo foreign keys found")
                
            # Print a small portion of the schema string
            schema_str = schema_info["schema_str"]
            preview_length = min(300, len(schema_str))
            print(f"\nSchema String Preview ({preview_length} chars of {len(schema_str)} total):")
            print(schema_str[:preview_length] + "...")
            
            # Verify that cache now contains this database
            print(f"\nCache status - db2infos contains {simple_db}: {simple_db in schema_manager.db2infos}")
            
        except Exception as e:
            print(f"❌ Error retrieving schema: {e}")
    else:
        print("No simple databases found for testing")
else:
    print("No database IDs available for testing")

Testing schema retrieval with database ID: debit_card_specializing

Tables in schema: 5

Table Details:
  - customers: 3 columns
    Sample columns: ['CustomerID', 'Segment', 'Currency']...
  - gasstations: 4 columns
    Sample columns: ['GasStationID', 'ChainID', 'Country']...
  - products: 2 columns
    Sample columns: ['ProductID', 'Description']...
  - transactions_1k: 9 columns
    Sample columns: ['TransactionID', 'Date', 'Time']...
  - yearmonth: 3 columns
    Sample columns: ['CustomerID', 'Date', 'Consumption']...

Foreign Key Count: 1
Foreign key:
  yearmonth.`CustomerID` = customers.`CustomerID`

Schema String Preview (300 chars of 1104 total):
# Table: customers
[
  (CustomerID, CustomerID.),
  (Segment, client segment. Value examples: ['SME', 'LAM', 'KAM'].),
  (Currency, Currency. Value examples: ['CZK', 'EUR'].)
]
# Table: gasstations
[
  (GasStationID, Gas Station ID.),
  (ChainID, Chain ID.),
  (Country, Country. Value examples: ['CZ...

Cache status - db2infos contain
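The schema preview shows the per-table prompt format: a `# Table:` header followed by a bracketed list of `(column, description)` entries, with optional value examples. The hypothetical formatter below reproduces the `customers` block from that preview; its input layout, a list of `(name, description, examples)` tuples, is an assumption and not SchemaManager's actual data structure.

```python
from typing import List, Optional, Sequence, Tuple

# (column name, description, optional value examples): assumed input layout
Column = Tuple[str, str, Optional[Sequence[str]]]


def format_table_schema(table: str, columns: List[Column]) -> str:
    """Render one table in the '# Table: ...' layout shown in the preview."""
    lines = []
    for name, description, examples in columns:
        text = f"{description}."
        if examples:
            # The list repr gives the ['SME', 'LAM', 'KAM'] form seen in the output.
            text += f" Value examples: {list(examples)}."
        lines.append(f"  ({name}, {text})")
    return f"# Table: {table}\n[\n" + ",\n".join(lines) + "\n]"


print(format_table_schema("customers", [
    ("CustomerID", "CustomerID", None),
    ("Segment", "client segment", ["SME", "LAM", "KAM"]),
    ("Currency", "Currency", ["CZK", "EUR"]),
]))
```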

## 4. Initialize the Text-to-SQL System

In [7]:
# Initialize the Text-to-SQL system
try:
    # Make sure to use the dev_databases path for proper database access
    text_to_sql = DSPyTextToSQL(
        data_path=DEV_DB_DIRECTORY,  # Use the correct path to databases
        tables_json_path=TABLES_JSON_PATH,
        dataset_name="bird"
    )
    print("✅ Text-to-SQL system initialized successfully")
    
    # Print information about the system
    print(f"Schema Manager has {len(text_to_sql.schema_manager.db2dbjsons)} database schemas loaded")
    print(f"Using '{text_to_sql.lm.model}' as the language model")
    
    # List available components
    print("\nSystem components:")
    print("- Schema Extractor")
    print("- SQL Decomposer")
    print("- SQL Validator")
    
except Exception as e:
    print(f"❌ Error initializing Text-to-SQL system: {e}")
    print("Check that paths and API keys are correctly configured")

2025-05-10 22:32:40,720 - text_to_sql - INFO - Using tables JSON: ../data/bird/dev_tables.json


load json file from ../data/bird/dev_tables.json
❌ Error initializing Text-to-SQL system: module 'dspy' has no attribute 'Chain'
Check that paths and API keys are correctly configured
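The error above appears to come from code in the project's `text_to_sql` module that references `dspy.Chain`, which the installed DSPy 2.6.23 does not provide. Because `text_to_sql` is never assigned, the cells in sections 5 to 8 skip their tests. In DSPy, a multi-stage program is usually written as a `dspy.Module` subclass whose `forward` method calls its sub-modules in sequence. The sketch below shows that control flow in plain Python, with stub stages standing in for SchemaExtractor, SqlDecomposer, and SqlValidator. The `Pipeline` class and the stub behavior are hypothetical; the result keys mirror the ones the later cells read from `process_query`.

```python
from typing import Any, Callable, Dict


class Pipeline:
    """Hypothetical plain-Python stand-in for a dspy.Module chaining three stages."""

    def __init__(self,
                 extract: Callable[[str, str], Dict[str, Any]],
                 decompose: Callable[[str, Dict[str, Any]], Dict[str, Any]],
                 validate: Callable[[str, str], Dict[str, Any]],
                 need_prune: Callable[[str], bool]):
        self.extract = extract
        self.decompose = decompose
        self.validate = validate
        self.need_prune = need_prune

    def forward(self, db_id: str, query: str) -> Dict[str, Any]:
        # Stage 1: prune the schema only for databases flagged as complex.
        extracted = self.extract(db_id, query) if self.need_prune(db_id) else {}
        # Stage 2: split the question into sub-questions and draft SQL.
        decomposed = self.decompose(query, extracted)
        # Stage 3: review and refine the draft SQL.
        validated = self.validate(query, decomposed["sql"])
        return {
            "extracted_schema": extracted,
            "sub_questions": decomposed["sub_questions"],
            "initial_sql": decomposed["sql"],
            "final_sql": validated["refined_sql"],
            "explanation": validated["explanation"],
        }


# Stub stages; real ones would call the language model.
pipeline = Pipeline(
    extract=lambda db_id, q: {"customers": ["CustomerID", "Segment"]},
    decompose=lambda q, schema: {"sub_questions": [q],
                                 "sql": "SELECT COUNT(*) FROM customers"},
    validate=lambda q, sql: {"refined_sql": sql, "explanation": ""},
    need_prune=lambda db_id: db_id != "debit_card_specializing",
)
result = pipeline.forward("debit_card_specializing",
                          "How many records are in the customers table?")
print(result["final_sql"])         # SELECT COUNT(*) FROM customers
print(result["extracted_schema"])  # {} (simple database, no pruning)
```

In an actual `dspy.Module`, the stages would be sub-modules assigned in `__init__`, and `forward` would run when the module instance is called.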


## 5. Test with Simple Query

In [8]:
# Test with a simple query on a small database
if 'text_to_sql' in locals():
    # Use a simple database that we identified earlier
    if 'simple_db' in locals() and simple_db:
        db_id = simple_db
    else:
        # Default to the first database in our list
        db_id = db_ids[0] if len(db_ids) > 0 else None
    
    if db_id:
        # Create a simple query based on the database schema
        schema_info = schema_manager.get_db_schema(db_id)
        table_names = list(schema_info["chosen_columns"].keys())
        first_table = table_names[0] if table_names else "unknown"
        
        query = f"How many records are in the {first_table} table?"
        
        print(f"Testing simple query on database: {db_id}")
        print(f"Query: '{query}'")
        
        try:
            # Start timer
            start_time = time.time()
            
            # Process the query
            result = text_to_sql.process_query(db_id, query)
            
            # Calculate processing time
            processing_time = time.time() - start_time
            
            # Display results
            print(f"\nProcessing completed in {processing_time:.2f} seconds")
            print("\nInitial SQL:")
            print(result["initial_sql"])
            
            print("\nFinal SQL:")
            print(result["final_sql"])
            
            if result["explanation"]:
                print("\nRefinement Explanation:")
                print(result["explanation"])
        except Exception as e:
            print(f"❌ Error processing query: {e}")
    else:
        print("No database available for testing")
else:
    print("Text-to-SQL system not initialized")

Text-to-SQL system not initialized


## 6. Test with Complex Query and Schema Pruning

In [9]:
# Test with a complex query on a database that needs pruning
if 'text_to_sql' in locals():
    # Find a complex database that needs pruning
    complex_db = None
    for db_id in db_ids:
        if schema_manager.is_need_prune(db_id):
            complex_db = db_id
            break
    
    if complex_db:
        # Get schema info to create a reasonable query
        schema_info = schema_manager.get_db_schema(complex_db)
        table_names = list(schema_info["chosen_columns"].keys())
        
        # Create a query that might require joins if there are foreign keys
        if len(table_names) >= 2 and schema_info["fk_str"]:
            table1 = table_names[0]
            table2 = table_names[1]
            query = f"List information from {table1} joined with {table2}"
        else:
            # Simple query on a complex database
            first_table = table_names[0] if table_names else "unknown"
            query = f"Show me a summary of data from the {first_table} table"
        
        print(f"Testing complex query on database: {complex_db}")
        print(f"Query: '{query}'")
        print("Database complexity: Complex, pruning needed")
        
        try:
            # Start timer
            start_time = time.time()
            
            # Process the query
            result = text_to_sql.process_query(complex_db, query)
            
            # Calculate processing time
            processing_time = time.time() - start_time
            
            # Display results
            print(f"\nProcessing completed in {processing_time:.2f} seconds")
            
            # Show extracted schema if available
            if result["extracted_schema"]:
                print("\nExtracted Schema (pruned tables/columns):")
                for table, selection in result["extracted_schema"].items():
                    print(f"  - {table}: {selection}")
            
            # Show decomposed sub-questions
            if result["sub_questions"]:
                print("\nSub-questions:")
                for i, q in enumerate(result["sub_questions"]):
                    print(f"  {i+1}. {q}")
            
            print("\nInitial SQL:")
            print(result["initial_sql"])
            
            print("\nFinal SQL:")
            print(result["final_sql"])
            
            if result["explanation"]:
                print("\nRefinement Explanation:")
                print(result["explanation"])
        except Exception as e:
            print(f"❌ Error processing complex query: {e}")
    else:
        print("No complex database found for testing")
else:
    print("Text-to-SQL system not initialized")

Text-to-SQL system not initialized
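Cell 9 uses `fk_str` only as a yes/no signal before asking a join question about the first two tables, which may not be linked. Each line of `fk_str` has the form printed in cell 6, `` yearmonth.`CustomerID` = customers.`CustomerID` ``, so it can be parsed into explicit pairs and used to pick two tables that actually share a key. A sketch, assuming every line follows that format and table names are plain identifiers; `parse_foreign_keys` is a hypothetical helper.

```python
import re
from typing import List, Tuple

# Matches lines like: yearmonth.`CustomerID` = customers.`CustomerID`
FK_PATTERN = re.compile(r"^\s*(\w+)\.`([^`]+)`\s*=\s*(\w+)\.`([^`]+)`\s*$")


def parse_foreign_keys(fk_str: str) -> List[Tuple[str, ...]]:
    """Return (table, column, ref_table, ref_column) tuples; skip unparsable lines."""
    pairs = []
    for line in fk_str.splitlines():
        match = FK_PATTERN.match(line)
        if match:
            pairs.append(match.groups())
    return pairs


print(parse_foreign_keys("yearmonth.`CustomerID` = customers.`CustomerID`"))
# [('yearmonth', 'CustomerID', 'customers', 'CustomerID')]
```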


## 7. Test Individual Components

In [10]:
# Test each component individually
if 'text_to_sql' in locals():
    # Use a sample database
    sample_db_id = db_ids[0] if db_ids else None
    
    if sample_db_id:
        # Create a simple query
        query = "Show me all customer information"
        
        print(f"Testing individual components with database: {sample_db_id}")
        print(f"Query: '{query}'")
        
        try:
            # 1. Test Schema Extractor
            print("\n=== Schema Extractor Component Test ===")
            
            # Get basic schema for extraction
            basic_schema = schema_manager.get_db_schema(sample_db_id)
            
            # Start timer
            start_time = time.time()
            
            # Run schema extraction
            extraction_result = text_to_sql.schema_extractor(
                db_id=sample_db_id,
                query=query,
                db_schema=basic_schema["schema_str"],
                foreign_keys=basic_schema["fk_str"],
                evidence=""
            )
            
            # Calculate processing time
            extraction_time = time.time() - start_time
            
            print(f"Extraction completed in {extraction_time:.2f} seconds")
            print("Extracted schema:")
            print(extraction_result.extracted_schema)
            
            # 2. Test SQL Decomposer
            print("\n=== SQL Decomposer Component Test ===")
            
            # Get schema with extracted tables (or full schema if extraction was empty)
            extracted_schema = extraction_result.extracted_schema
            schema_info = schema_manager.get_db_schema(
                sample_db_id, 
                extracted_schema=extracted_schema
            )
            
            # Start timer
            start_time = time.time()
            
            # Run SQL decomposition
            decomposition_result = text_to_sql.sql_decomposer(
                query=query,
                schema_info=schema_info["schema_str"],
                foreign_keys=schema_info["fk_str"],
                evidence=""
            )
            
            # Calculate processing time
            decomposition_time = time.time() - start_time
            
            print(f"Decomposition completed in {decomposition_time:.2f} seconds")
            
            print("Sub-questions:")
            for i, q in enumerate(decomposition_result.sub_questions):
                print(f"  {i+1}. {q}")
                
            print("\nGenerated SQL:")
            print(decomposition_result.sql)
            
            # 3. Test SQL Validator
            print("\n=== SQL Validator Component Test ===")
            
            # Extract SQL from decomposition result
            sql = decomposition_result.sql
            
            # Start timer
            start_time = time.time()
            
            # Run SQL validation
            validation_result = text_to_sql.sql_validator(
                query=query,
                sql=sql,
                schema_info=schema_info["schema_str"],
                foreign_keys=schema_info["fk_str"],
                db_id=sample_db_id,
                error_info="",
                evidence=""
            )
            
            # Calculate processing time
            validation_time = time.time() - start_time
            
            print(f"Validation completed in {validation_time:.2f} seconds")
            
            print("Refined SQL:")
            print(validation_result.refined_sql)
            
            if hasattr(validation_result, 'explanation') and validation_result.explanation:
                print("\nRefinement Explanation:")
                print(validation_result.explanation)
            
            # Compare times
            print("\n=== Performance Comparison ===")
            print(f"Schema Extraction:   {extraction_time:.2f} seconds")
            print(f"SQL Decomposition:   {decomposition_time:.2f} seconds")
            print(f"SQL Validation:      {validation_time:.2f} seconds")
            print(f"Total:               {extraction_time + decomposition_time + validation_time:.2f} seconds")
        except Exception as e:
            print(f"❌ Error testing components: {e}")
    else:
        print("No database available for testing components")
else:
    print("Text-to-SQL system not initialized")

Text-to-SQL system not initialized
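Cells 8 to 10 repeat the same `start_time = time.time()` bookkeeping around each stage. A small context manager could collect the per-stage timings used in the performance comparison. This is an illustrative refactor with a hypothetical `timed` helper, not part of the project; it uses `time.perf_counter()`, which is monotonic and intended for measuring intervals, unlike `time.time()`.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(stage: str, timings: Dict[str, float]) -> Iterator[None]:
    """Record the elapsed time of the with-block under timings[stage]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the stage raises, so failed stages are still timed.
        timings[stage] = time.perf_counter() - start


timings: Dict[str, float] = {}
with timed("Schema Extraction", timings):
    time.sleep(0.01)  # stand-in for text_to_sql.schema_extractor(...)
with timed("SQL Decomposition", timings):
    time.sleep(0.01)  # stand-in for text_to_sql.sql_decomposer(...)

for stage, seconds in timings.items():
    print(f"{stage + ':':<20} {seconds:.2f} seconds")
print(f"{'Total:':<20} {sum(timings.values()):.2f} seconds")
```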


## 8. Test Multiple Databases with Same Query

In [11]:
# Test the system with multiple databases using the same query
if 'text_to_sql' in locals() and len(db_ids) >= 3:
    # Select 3 databases to test
    test_dbs = db_ids[:3]
    
    # Use a generic query that should work on any database
    query = "List the top 5 records from the main table"
    
    print("Testing the same query across multiple databases:")
    print(f"Query: '{query}'")
    
    results_summary = []
    
    for db_id in test_dbs:
        print(f"\n===== Testing with database: {db_id} =====")
        
        try:
            # Process the query
            result = text_to_sql.process_query(db_id, query)
            
            # Store results summary
            results_summary.append({
                "db_id": db_id,
                "status": "success",
                "sql": result["final_sql"]
            })
            
            # Show final SQL
            print(f"Final SQL: {result['final_sql']}")
            
        except Exception as e:
            print(f"❌ Error processing query for {db_id}: {e}")
            results_summary.append({
                "db_id": db_id,
                "status": "error",
                "error": str(e)
            })
    
    # Print summary of results
    print("\n===== Results Summary =====")
    for result in results_summary:
        status_icon = "✅" if result["status"] == "success" else "❌"
        print(f"{status_icon} {result['db_id']}: {result['status']}")
        
else:
    print("Text-to-SQL system not initialized or insufficient databases for testing")

Text-to-SQL system not initialized or insufficient databases for testing
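Cell 11 records a database as a success as soon as `process_query` returns, without checking that the final SQL runs. `sqlite3` is imported in cell 2 but never used; one option is to execute each final query against the database and keep the error message, which is the kind of text the validator's `error_info` argument appears to expect. The sketch assumes the BIRD layout `dev_databases/<db_id>/<db_id>.sqlite`; `try_execute` is a hypothetical helper.

```python
import sqlite3
from pathlib import Path
from typing import Any, List, Optional, Tuple


def try_execute(db_path: str, sql: str,
                max_rows: int = 5) -> Tuple[Optional[List[Any]], Optional[str]]:
    """Run one SQL statement read-only; return (rows, None) or (None, error)."""
    # mode=ro opens the file read-only, so generated SQL cannot modify the data,
    # and a missing file raises an error instead of creating an empty database.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        conn = sqlite3.connect(uri, uri=True)
    except sqlite3.Error as e:
        return None, str(e)
    try:
        return conn.execute(sql).fetchmany(max_rows), None
    except (sqlite3.Error, sqlite3.Warning) as e:
        # Covers syntax errors, unknown tables or columns, and write attempts.
        return None, str(e)
    finally:
        conn.close()


# Usage inside cell 11's loop (not run here; assumes the BIRD file layout):
# db_path = os.path.join(DEV_DB_DIRECTORY, db_id, f"{db_id}.sqlite")
# rows, error = try_execute(db_path, result["final_sql"])
```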


## 9. Conclusion and Next Steps

"""
This notebook exercises the DSPy implementation of the Text-to-SQL system. In the saved run, only the SchemaManager tests completed: `DSPyTextToSQL` failed to initialize with `module 'dspy' has no attribute 'Chain'`, so the cells in sections 5-8 skipped their tests.

1. **SchemaManager** (tested) - Loads database schemas from tables.json and provides access to them.
   - Flags `financial` and `formula_1` as complex (pruning needed) and `debit_card_specializing` as simple
   - Formats schema information (tables, column descriptions, value examples, foreign keys) for the language model
   - Caches per-database information in `db2infos`, which is empty at initialization

2. **Complete Pipeline** (not exercised in this run) - Intended to process queries through all stages:
   - Schema extraction for complex databases
   - Query decomposition into sub-questions
   - SQL generation and validation/refinement

3. **Individual Components** (not exercised in this run) - Each component has a specific task:
   - SchemaExtractor: Selects relevant tables and columns for complex schemas
   - SqlDecomposer: Breaks down complex questions and generates SQL
   - SqlValidator: Refines the generated SQL for correctness

4. **Adaptability** (not exercised in this run) - The system is designed to adapt to different database schemas and query complexities:
   - Works with both simple and complex databases
   - Handles different query types and complexity levels
   - Performs schema pruning only when needed

Next steps for improvement:

1. **Compatibility** - Replace the `dspy.Chain` reference so that `DSPyTextToSQL` initializes under DSPy 2.6.23, then re-run sections 4 to 8
2. **Optimization** - Fine-tune the DSPy modules with more examples
3. **Error Handling** - Add more robust error handling for edge cases
4. **Performance** - Consider optimizations for large schemas and complex queries
5. **Evaluation** - Add more comprehensive evaluation metrics
6. **Frontend Integration** - Connect with a user interface for easier interaction
"""