# SQL Executor Test Notebook

This notebook tests the functionality of the `SQLExecutor` class in `dispatcher/sql_executor.py`.

## 1. Setup and Imports

In [1]:
import os
import sys
import time
import sqlite3
from typing import Dict, Any

# Add the project root to the path to import the code
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import the SQLExecutor class
from dispatcher.sql_executor import SQLExecutor

# Import other required modules
from core.utils import load_json_file

## 2. Configuration and Environment Check

In [2]:
# Define paths for testing
# We'll use BIRD dataset for testing as it seems to be a common choice in the codebase
DATA_PATH = "../data/bird"
DATASET_NAME = "bird"
DEV_DB_DIRECTORY = os.path.join(DATA_PATH, "dev_databases")
TABLES_JSON_PATH = os.path.join(DATA_PATH, "dev_tables.json")

# Check if the required files and directories exist
print("Checking file paths and database structure...")
if os.path.exists(DATA_PATH):
    print(f"✅ Found data path: {DATA_PATH}")
else:
    print(f"❌ Data path not found: {DATA_PATH}")
    
if os.path.exists(DEV_DB_DIRECTORY):
    print(f"✅ Found database directory: {DEV_DB_DIRECTORY}")
else:
    print(f"❌ Database directory not found: {DEV_DB_DIRECTORY}")
    
if os.path.exists(TABLES_JSON_PATH):
    print(f"✅ Found tables.json at {TABLES_JSON_PATH}")
else:
    print(f"❌ tables.json not found at {TABLES_JSON_PATH}")
    
# Load tables data to get db_ids
tables_data = load_json_file(TABLES_JSON_PATH)
db_ids = [item["db_id"] for item in tables_data]
print(f"Number of database entries in tables.json: {len(db_ids)}")

# Display some sample database IDs
sample_db_ids = db_ids[:5]
print(f"Sample database IDs: {sample_db_ids}")

Checking file paths and database structure...
✅ Found data path: ../data/bird
✅ Found database directory: ../data/bird/dev_databases
✅ Found tables.json at ../data/bird/dev_tables.json
load json file from ../data/bird/dev_tables.json
Number of database entries in tables.json: 11
Sample database IDs: ['debit_card_specializing', 'financial', 'formula_1', 'california_schools', 'card_games']


## 3. Initialize the SQLExecutor

In [3]:
try:
    # Initialize the SQL Executor
    sql_executor = SQLExecutor(DEV_DB_DIRECTORY, DATASET_NAME)
    print("✅ SQL Executor initialized successfully")
    
    # Print the configuration
    print(f"Data path: {sql_executor.data_path}")
    print(f"Dataset name: {sql_executor.dataset_name}")
except Exception as e:
    print(f"❌ Error initializing SQL Executor: {e}")

✅ SQL Executor initialized successfully
Data path: ../data/bird/dev_databases
Dataset name: bird


## 4. Test Basic SQL Execution

Test running simple SQL queries against the databases.

In [4]:
def test_basic_query(db_id: str, sql_query: str):
    """Test a basic SQL query and display the results."""
    print(f"Testing query on {db_id}:\n{sql_query}")
    
    try:
        result = sql_executor.safe_execute(sql_query, db_id)
        
        # Print result summary
        print(f"\nExecution success: {result.get('success', False)}")
        
        if result.get('success', False):
            print(f"Result row count: {result.get('row_count', 0)}")
            print(f"Columns: {result.get('column_names', [])}")
            print("\nSample data (up to 5 rows):")
            for row in result.get('data', []):
                print(f"  {row}")
        else:
            print(f"Error: {result.get('sqlite_error', 'Unknown error')}")
            print(f"Exception class: {result.get('exception_class', 'Unknown')}")
            
        # Check validity
        is_valid, reason = sql_executor.is_valid_result(result)
        print(f"\nResult validity: {is_valid}")
        if not is_valid:
            print(f"Invalid reason: {reason}")
        
        return result
    except Exception as e:
        print(f"❌ Error executing test: {e}")
        return None

In [5]:
# Select a sample database to test
if sample_db_ids:
    test_db_id = sample_db_ids[0]
    print(f"Testing with database: {test_db_id}")
    
    # Test a simple SELECT query
    # First, we need to know the table names in this database
    db_entry = next((item for item in tables_data if item["db_id"] == test_db_id), None)
    
    if db_entry and len(db_entry.get("table_names", [])) > 0:
        tables = db_entry["table_names"]
        first_table = tables[0]
        
        # Run a simple query to get all rows from the first table (limit 5)
        simple_query = f"SELECT * FROM {first_table} LIMIT 5"
        test_basic_query(test_db_id, simple_query)
    else:
        print(f"❌ No tables found for database {test_db_id}")
else:
    print("❌ No database IDs available for testing")

Testing with database: debit_card_specializing
Testing query on debit_card_specializing:
SELECT * FROM customers LIMIT 5

Execution success: True
Result row count: 5
Columns: ['CustomerID', 'Segment', 'Currency']

Sample data (up to 5 rows):
  (3, 'SME', 'EUR')
  (5, 'LAM', 'EUR')
  (6, 'SME', 'EUR')
  (7, 'LAM', 'EUR')
  (9, 'SME', 'EUR')

Result validity: True


## 5. Test Complex Queries

Test more complex SQL queries including joins and aggregations.

In [6]:
# Continue with the same database
if sample_db_ids:
    test_db_id = sample_db_ids[0]
    db_entry = next((item for item in tables_data if item["db_id"] == test_db_id), None)
    
    if db_entry and len(db_entry.get("table_names", [])) >= 2:
        tables = db_entry["table_names"]
        
        # Let's try a more complex query with a JOIN between the first two tables
        # We need to find columns to join on
        table1 = tables[0]
        table2 = tables[1]
        
        # For a real JOIN, we would need to know the foreign key relationships
        # For testing, we'll use a CROSS JOIN with LIMIT to avoid large result sets
        complex_query = f"SELECT {table1}.*, {table2}.* FROM {table1} CROSS JOIN {table2} LIMIT 5"
        print(f"\nTesting complex query (CROSS JOIN):")
        test_basic_query(test_db_id, complex_query)
        
        # Also test an aggregation query
        agg_query = f"SELECT COUNT(*) as count FROM {table1}"
        print(f"\nTesting aggregation query (COUNT):")
        test_basic_query(test_db_id, agg_query)
    else:
        print(f"❌ Not enough tables found for complex query testing in database {test_db_id}")


Testing complex query (CROSS JOIN):
Testing query on debit_card_specializing:
SELECT customers.*, gasstations.* FROM customers CROSS JOIN gasstations LIMIT 5

Execution success: True
Result row count: 5
Columns: ['CustomerID', 'Segment', 'Currency', 'GasStationID', 'ChainID', 'Country', 'Segment']

Sample data (up to 5 rows):
  (3, 'SME', 'EUR', 44, 13, 'CZE', 'Value for money')
  (3, 'SME', 'EUR', 45, 6, 'CZE', 'Premium')
  (3, 'SME', 'EUR', 46, 23, 'CZE', 'Other')
  (3, 'SME', 'EUR', 47, 33, 'CZE', 'Premium')
  (3, 'SME', 'EUR', 48, 4, 'CZE', 'Premium')

Result validity: True

Testing aggregation query (COUNT):
Testing query on debit_card_specializing:
SELECT COUNT(*) as count FROM customers

Execution success: True
Result row count: 1
Columns: ['count']

Sample data (up to 5 rows):
  (32461,)

Result validity: True


## 6. Test Error Handling

Test how the executor handles various error conditions.

In [7]:
# Test with invalid SQL syntax
if sample_db_ids:
    test_db_id = sample_db_ids[0]
    
    # Syntax error
    invalid_syntax_query = "SELECT * FORM invalid_table"
    print("\nTesting invalid SQL syntax:")
    syntax_result = test_basic_query(test_db_id, invalid_syntax_query)
    
    # Non-existent table
    non_existent_table_query = "SELECT * FROM non_existent_table"
    print("\nTesting query with non-existent table:")
    table_result = test_basic_query(test_db_id, non_existent_table_query)
    
    # Non-existent column
    if db_entry and len(db_entry.get("table_names", [])) > 0:
        table = db_entry["table_names"][0]
        non_existent_column_query = f"SELECT non_existent_column FROM {table}"
        print("\nTesting query with non-existent column:")
        column_result = test_basic_query(test_db_id, non_existent_column_query)


Testing invalid SQL syntax:
Testing query on debit_card_specializing:
SELECT * FORM invalid_table

Execution success: False
Error: near "FORM": syntax error
Exception class: <class 'sqlite3.OperationalError'>

Result validity: False
Invalid reason: near "FORM": syntax error

Testing query with non-existent table:
Testing query on debit_card_specializing:
SELECT * FROM non_existent_table

Execution success: False
Error: no such table: non_existent_table
Exception class: <class 'sqlite3.OperationalError'>

Result validity: False
Invalid reason: no such table: non_existent_table

Testing query with non-existent column:
Testing query on debit_card_specializing:
SELECT non_existent_column FROM customers

Execution success: False
Error: no such column: non_existent_column
Exception class: <class 'sqlite3.OperationalError'>

Result validity: False
Invalid reason: no such column: non_existent_column


## 7. Test Timeout Handling

Test how the executor handles queries that would take too long to execute.

In [8]:
# Simulating a very slow/complex query that would trigger a timeout
# Note: This is challenging to reliably test without a truly slow query
if sample_db_ids:
    test_db_id = sample_db_ids[0]
    
    # A recursive query that should time out
    # Note: This might not work on all SQLite versions, depending on recursive query support
    recursive_query = """
    WITH RECURSIVE slow_query(n) AS (
        SELECT 1
        UNION ALL
        SELECT n+1 FROM slow_query WHERE n < 5000000
    )
    SELECT COUNT(*) FROM slow_query;
    """
    
    print("\nTesting timeout handling with a recursive query:")
    print("Note: This test is expected to time out after 120 seconds")
    start_time = time.time()
    timeout_result = sql_executor.safe_execute(recursive_query, test_db_id)
    elapsed_time = time.time() - start_time
    
    print(f"\nQuery completed in {elapsed_time:.2f} seconds")
    print(f"Timeout occurred: {timeout_result.get('timeout', False)}")
    print(f"Success: {timeout_result.get('success', False)}")
    
    if not timeout_result.get('success', False):
        print(f"Error: {timeout_result.get('sqlite_error', 'Unknown error')}")
        print(f"Exception class: {timeout_result.get('exception_class', 'Unknown')}")


Testing timeout handling with a recursive query:
Note: This test is expected to time out after 120 seconds

Query completed in 0.38 seconds
Timeout occurred: False
Success: True


## 8. Test Result Validation

Test the validation of SQL execution results.

In [9]:
# Test validation with various result scenarios
if sample_db_ids:
    test_db_id = sample_db_ids[0]
    db_entry = next((item for item in tables_data if item["db_id"] == test_db_id), None)
    
    if db_entry and len(db_entry.get("table_names", [])) > 0:
        table = db_entry["table_names"][0]
        
        # 1. Test with a query that returns data
        print("\n1. Testing validation with a query that returns data:")
        normal_query = f"SELECT * FROM {table} LIMIT 5"
        normal_result = sql_executor.safe_execute(normal_query, test_db_id)
        is_valid, reason = sql_executor.is_valid_result(normal_result)
        print(f"Valid result: {is_valid}, Reason: {reason}")
        
        # 2. Test with a query that returns empty result
        print("\n2. Testing validation with a query that returns empty result:")
        empty_query = f"SELECT * FROM {table} WHERE 1=0"
        empty_result = sql_executor.safe_execute(empty_query, test_db_id)
        is_valid, reason = sql_executor.is_valid_result(empty_result)
        print(f"Valid result: {is_valid}, Reason: {reason}")
        
        # 3. Test with a query that might return NULL values
        print("\n3. Testing validation with a query that might return NULL values:")
        null_query = f"SELECT * FROM {table} WHERE 1 ORDER BY 1 DESC LIMIT 5"
        null_result = sql_executor.safe_execute(null_query, test_db_id)
        is_valid, reason = sql_executor.is_valid_result(null_result)
        print(f"Valid result: {is_valid}, Reason: {reason}")
        
        # 4. Test with an error result
        print("\n4. Testing validation with an error result:")
        error_query = f"SELECT * FROM non_existent_table"
        error_result = sql_executor.safe_execute(error_query, test_db_id)
        is_valid, reason = sql_executor.is_valid_result(error_result)
        print(f"Valid result: {is_valid}, Reason: {reason}")


1. Testing validation with a query that returns data:
Valid result: True, Reason: 

2. Testing validation with a query that returns empty result:
Valid result: False, Reason: No data returned from query

3. Testing validation with a query that might return NULL values:
Valid result: True, Reason: 

4. Testing validation with an error result:
Valid result: False, Reason: no such table: non_existent_table


## 9. Test with Different Datasets

Test handling of different datasets (if available).

In [10]:
# Test with Spider dataset if available
SPIDER_DATA_PATH = "../data/spider"
SPIDER_TABLES_JSON_PATH = os.path.join(SPIDER_DATA_PATH, "tables.json")

if os.path.exists(SPIDER_DATA_PATH) and os.path.exists(SPIDER_TABLES_JSON_PATH):
    print("Testing with Spider dataset")
    
    # Initialize another SQL executor for Spider
    spider_executor = SQLExecutor(os.path.join(SPIDER_DATA_PATH, "database"), "spider")
    
    # Load Spider database IDs
    spider_tables_data = load_json_file(SPIDER_TABLES_JSON_PATH)
    spider_db_ids = list(set(item["db_id"] for item in spider_tables_data))
    
    if spider_db_ids:
        spider_test_db_id = spider_db_ids[0]
        print(f"Testing with Spider database: {spider_test_db_id}")
        
        # Get a table from this database
        spider_db_entry = next((item for item in spider_tables_data if item["db_id"] == spider_test_db_id), None)
        
        if spider_db_entry and len(spider_db_entry.get("table_names", [])) > 0:
            spider_table = spider_db_entry["table_names"][0]
            
            # Run a simple query
            spider_query = f"SELECT * FROM {spider_table} LIMIT 5"
            print(f"\nExecuting query: {spider_query}")
            
            try:
                spider_result = spider_executor.safe_execute(spider_query, spider_test_db_id)
                print(f"\nExecution success: {spider_result.get('success', False)}")
                
                if spider_result.get('success', False):
                    print(f"Result row count: {spider_result.get('row_count', 0)}")
                    print(f"Columns: {spider_result.get('column_names', [])}")
                    print("\nSample data (up to 5 rows):")
                    for row in spider_result.get('data', []):
                        print(f"  {row}")
                else:
                    print(f"Error: {spider_result.get('sqlite_error', 'Unknown error')}")
                
                # Test validation differences with Spider
                is_valid, reason = spider_executor.is_valid_result(spider_result)
                print(f"\nSpider result validity: {is_valid}, Reason: {reason}")
                print("Note: Spider validation should always return True for non-error results")
                
            except Exception as e:
                print(f"❌ Error executing Spider test: {e}")
        else:
            print(f"❌ No tables found for Spider database {spider_test_db_id}")
    else:
        print("❌ No Spider database IDs found")
else:
    print("❌ Spider dataset not available")

Testing with Spider dataset
load json file from ../data/spider/tables.json
Testing with Spider database: candidate_poll

Executing query: SELECT * FROM candidate LIMIT 5

Execution success: True
Result row count: 5
Columns: ['Candidate_ID', 'People_ID', 'Poll_Source', 'Date', 'Support_rate', 'Consider_rate', 'Oppose_rate', 'Unsure_rate']

Sample data (up to 5 rows):
  (1, 1, 'WNBC/Marist Poll', 'Feb 12–15, 2007', 0.25, 0.3, 0.43, 0.2)
  (2, 3, 'WNBC/Marist Poll', 'Feb 12–15, 2007', 0.17, 0.42, 0.32, 0.9)
  (3, 4, 'FOX News/Opinion Dynamics Poll', 'Feb 13–14, 2007', 0.18, 0.34, 0.44, 0.3)
  (4, 6, 'Newsweek Poll', 'Nov 9–10, 2006', 0.33, 0.2, 0.45, 0.2)
  (5, 7, 'Newsweek Poll', 'Nov 9–10, 2006', 0.24, 0.3, 0.32, 0.4)

Spider result validity: True, Reason: 
Note: Spider validation should always return True for non-error results


## 10. Performance Testing

Test the performance of SQL execution for various query types.

In [11]:
def measure_query_performance(db_id: str, query: str, iterations: int = 5):
    """Measure the performance of a query by running it multiple times."""
    print(f"\nMeasuring performance for query:\n{query}")
    
    times = []
    for i in range(iterations):
        start_time = time.time()
        result = sql_executor.safe_execute(query, db_id)
        elapsed_time = time.time() - start_time
        times.append(elapsed_time)
        
        success = result.get('success', False)
        if not success:
            print(f"❌ Iteration {i+1} failed: {result.get('sqlite_error', 'Unknown error')}")
        else:
            print(f"✅ Iteration {i+1} completed in {elapsed_time:.4f} seconds")
    
    avg_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)
    
    print(f"\nPerformance summary:")
    print(f"- Average time: {avg_time:.4f} seconds")
    print(f"- Minimum time: {min_time:.4f} seconds")
    print(f"- Maximum time: {max_time:.4f} seconds")
    
    return avg_time

In [12]:
# Run performance tests on different query types
if sample_db_ids:
    test_db_id = sample_db_ids[0]
    db_entry = next((item for item in tables_data if item["db_id"] == test_db_id), None)
    
    if db_entry and len(db_entry.get("table_names", [])) > 0:
        tables = db_entry["table_names"]
        
        print("\nPerformance Testing:")
        print(f"Database: {test_db_id}")
        print(f"Tables: {tables}")
        
        if len(tables) > 0:
            table = tables[0]
            
            # Test 1: Simple SELECT
            simple_select = f"SELECT * FROM {table} LIMIT 100"
            print("\nTest 1: Simple SELECT query")
            simple_time = measure_query_performance(test_db_id, simple_select)
            
            # Test 2: Aggregation query
            agg_query = f"SELECT COUNT(*) FROM {table}"
            print("\nTest 2: Aggregation query")
            agg_time = measure_query_performance(test_db_id, agg_query)
            
            # Test 3: Complex query with JOINs (if multiple tables)
            if len(tables) >= 2:
                join_query = f"SELECT t1.*, t2.* FROM {tables[0]} t1 CROSS JOIN {tables[1]} t2 LIMIT 100"
                print("\nTest 3: JOIN query")
                join_time = measure_query_performance(test_db_id, join_query)
            else:
                print("\nTest 3: JOIN query - Skipped (not enough tables)")
                join_time = None
            
            # Test 4: Query with WHERE clause
            where_query = f"SELECT * FROM {table} WHERE 1=1 LIMIT 100"
            print("\nTest 4: Query with WHERE clause")
            where_time = measure_query_performance(test_db_id, where_query)
            
            # Compare timings
            print("\nPerformance Comparison:")
            print(f"- Simple SELECT: {simple_time:.4f} seconds")
            print(f"- Aggregation query: {agg_time:.4f} seconds")
            if join_time:
                print(f"- JOIN query: {join_time:.4f} seconds")
            print(f"- WHERE query: {where_time:.4f} seconds")


Performance Testing:
Database: debit_card_specializing
Tables: ['customers', 'gasstations', 'products', 'transactions_1k', 'yearmonth']

Test 1: Simple SELECT query

Measuring performance for query:
SELECT * FROM customers LIMIT 100
✅ Iteration 1 completed in 0.0011 seconds
✅ Iteration 2 completed in 0.0002 seconds
✅ Iteration 3 completed in 0.0002 seconds
✅ Iteration 4 completed in 0.0002 seconds
✅ Iteration 5 completed in 0.0002 seconds

Performance summary:
- Average time: 0.0004 seconds
- Minimum time: 0.0002 seconds
- Maximum time: 0.0011 seconds

Test 2: Aggregation query

Measuring performance for query:
SELECT COUNT(*) FROM customers
✅ Iteration 1 completed in 0.0002 seconds
✅ Iteration 2 completed in 0.0002 seconds
✅ Iteration 3 completed in 0.0001 seconds
✅ Iteration 4 completed in 0.0001 seconds
✅ Iteration 5 completed in 0.0001 seconds

Performance summary:
- Average time: 0.0001 seconds
- Minimum time: 0.0001 seconds
- Maximum time: 0.0002 seconds

Test 3: JOIN query

Mea