# Text-to-SQL: Exploratory Analysis

This notebook explores the sample e-commerce database, inspects the Text-to-SQL test set, and shows how to set up the query generator (running it requires an OpenAI API key).

## Objectives
1. Explore the database schema and relationships
2. Analyze sample data distribution
3. Test basic natural language queries
4. Understand query complexity patterns

In [None]:
import sys
import sqlite3
import pandas as pd
import json
from pathlib import Path

# Add src to path
sys.path.append(str(Path('../src').resolve()))

from schema_manager import SchemaManager
from query_generator import TextToSQLGenerator, QueryResult
from query_validator import QueryValidator

## 1. Database Schema Exploration

In [None]:
# Initialize database connection.
# sqlite3.connect() silently creates a new, empty database file when the path
# does not exist, which would make every later cell fail with "no such table".
# Fail fast with a clear message instead.
db_path = "../data/sample_database.db"
if not Path(db_path).is_file():
    raise FileNotFoundError(
        f"Sample database not found at {Path(db_path).resolve()}; "
        "check the notebook's working directory."
    )
conn = sqlite3.connect(db_path)

# Initialize schema manager (loads table/column/key metadata for db_path).
schema_manager = SchemaManager(db_path)
schema_info = schema_manager.get_schema()

In [None]:
# List every table with its columns (name and type) and any key constraints.
print("Database Tables:")
print("=" * 50)
for tbl in schema_info.tables:
    print(f"\n{tbl.name}")
    print("-" * 50)
    column_desc = ', '.join([f'{c.name} ({c.type})' for c in tbl.columns])
    print(f"Columns: {column_desc}")
    # Only print key lines when the schema actually declares them.
    if tbl.primary_key:
        print(f"Primary Key: {tbl.primary_key}")
    if tbl.foreign_keys:
        print(f"Foreign Keys: {tbl.foreign_keys}")

In [None]:
# Render the schema as CREATE TABLE statements via SchemaManager.format_schema.
schema_str = schema_manager.format_schema(format_type="create_table")
banner = "=" * 80
print("\nDatabase Schema (CREATE TABLE format):")
print(banner)
print(schema_str)

## 2. Data Distribution Analysis

In [None]:
# Get row counts for each table.
# Table names come from the schema, but they are still quoted as SQL
# identifiers (embedded double quotes doubled) so that names which are
# reserved words or contain spaces/quotes do not break the query.
print("Table Row Counts:")
print("=" * 50)
for table in schema_info.tables:
    quoted_name = '"' + table.name.replace('"', '""') + '"'
    query = f"SELECT COUNT(*) as count FROM {quoted_name}"
    count = pd.read_sql_query(query, conn).iloc[0]['count']
    print(f"{table.name}: {count} rows")

In [None]:
# Peek at a few customer rows, then count new customers per registration day.
print("\nCustomers Sample:")
customers_df = pd.read_sql_query("SELECT * FROM customers LIMIT 5", conn)
display(customers_df)

print("\nCustomer Registration Over Time:")
registration_sql = """
    SELECT 
        DATE(registration_date) as date,
        COUNT(*) as new_customers
    FROM customers
    GROUP BY DATE(registration_date)
    ORDER BY date
    """
registration_stats = pd.read_sql_query(registration_sql, conn)
display(registration_stats)

In [None]:
# Per-category summary of products: count, mean price and mean stock level,
# largest categories first.
category_sql = """
    SELECT 
        category,
        COUNT(*) as product_count,
        ROUND(AVG(price), 2) as avg_price,
        ROUND(AVG(stock_quantity), 2) as avg_stock
    FROM products
    GROUP BY category
    ORDER BY product_count DESC
    """
print("Products by Category:")
products_by_category = pd.read_sql_query(category_sql, conn)
display(products_by_category)

In [None]:
# Explore orders and revenue per order status.
# NOTE: SQL leaves the row order of a GROUP BY result unspecified, so sort
# explicitly (busiest status first, status name as tie-breaker). This keeps the
# table stable across runs and consistent with the other summary cells.
print("Order Statistics:")
order_stats = pd.read_sql_query(
    """
    SELECT 
        status,
        COUNT(*) as order_count,
        ROUND(AVG(total_amount), 2) as avg_order_value,
        ROUND(SUM(total_amount), 2) as total_revenue
    FROM orders
    GROUP BY status
    ORDER BY order_count DESC, status
    """,
    conn
)
display(order_stats)

In [None]:
# Count reviews per rating value and express each as a percentage of all
# reviews, highest rating first.
review_sql = """
    SELECT 
        rating,
        COUNT(*) as review_count,
        ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM reviews), 2) as percentage
    FROM reviews
    GROUP BY rating
    ORDER BY rating DESC
    """
print("Review Distribution:")
review_dist = pd.read_sql_query(review_sql, conn)
display(review_dist)

## 3. Test Natural Language Queries

In [None]:
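# Initialize Text-to-SQL generator (requires an OpenAI API key).
# Prefer exporting OPENAI_API_KEY in your shell before starting Jupyter rather
# than hard-coding it here; never commit a real key. To test, uncomment:
# import os
# os.environ['OPENAI_API_KEY'] = 'your-key-here'
# generator = TextToSQLGenerator(db_path)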

In [None]:
# Load test queries.
# JSON is UTF-8 by specification. Pass the encoding explicitly instead of
# relying on the platform default (e.g. cp1252 on Windows), which can fail or
# garble non-ASCII questions.
from collections import Counter

with open('../data/test_queries.json', 'r', encoding='utf-8') as f:
    test_queries = json.load(f)

print(f"Loaded {len(test_queries)} test queries\n")

# Show query categories: number of test queries per 'category' value.
categories = Counter(query['category'] for query in test_queries)

print("Query Categories:")
for cat, count in sorted(categories.items()):
    print(f"  {cat}: {count} queries")

In [None]:
# For each complexity tier, show the first test query labelled with it
# (tiers with no queries are skipped).
print("\nSample Queries by Complexity:\n")
print("=" * 80)

for level in ('simple', 'medium', 'complex'):
    hits = [q for q in test_queries if q['complexity'] == level]
    if not hits:
        continue
    first = hits[0]
    print(f"\n{level.upper()} Query:")
    print(f"Question: {first['question']}")
    print(f"Expected SQL: {first['expected_sql']}")
    print("-" * 80)

In [None]:
# Sanity-check the test set: run the first query's reference SQL directly
# against the database and show the result.
example_query = test_queries[0]
question = example_query['question']
print(f"Executing: {question}")
reference_sql = example_query['expected_sql']
print(f"SQL: {reference_sql}\n")

result_df = pd.read_sql_query(reference_sql, conn)
display(result_df)

## 4. Query Complexity Analysis

In [None]:
# Analyze query patterns in test set
import re

# Word-boundary-aware patterns, applied to the upper-cased query. Plain
# substring checks misfire on identifiers such as "country" (contains COUNT)
# or "max_price" (contains MAX), and treating any "(" as a subquery flagged
# every query containing e.g. COUNT(*).
_JOIN_RE = re.compile(r'\bJOIN\b')
_SUBQUERY_RE = re.compile(r'\(\s*SELECT\b')
_AGGREGATE_RE = re.compile(r'\b(?:COUNT|SUM|AVG|MAX|MIN)\s*\(')
_GROUP_BY_RE = re.compile(r'\bGROUP\s+BY\b')
_ORDER_BY_RE = re.compile(r'\bORDER\s+BY\b')
_WHERE_RE = re.compile(r'\bWHERE\b')
_TABLE_REF_RE = re.compile(r'\b(?:FROM|JOIN)\s+(\w+)')


def analyze_query_complexity(sql: str) -> dict:
    """Return heuristic feature flags describing an SQL query's complexity.

    Matching is case-insensitive (the query is upper-cased) and regex-based.
    It is an approximation: keywords inside string literals or comments are
    still counted, and a parenthesised SELECT inside a CTE (``WITH x AS
    (SELECT ...)``) counts as a subquery.

    Args:
        sql: The SQL query text.

    Returns:
        A dict with the boolean flags ``has_join``, ``has_subquery``
        (a ``(`` followed by ``SELECT``), ``has_aggregate`` (COUNT/SUM/AVG/
        MAX/MIN called as a function), ``has_group_by``, ``has_order_by`` and
        ``has_where``. It also contains the integer ``num_tables``: the number
        of table names directly following a FROM or JOIN keyword. Tables in
        subqueries are included, a repeated table is counted each time, and a
        comma-separated ``FROM a, b`` list counts only its first table.
    """
    sql_upper = sql.upper()
    return {
        'has_join': bool(_JOIN_RE.search(sql_upper)),
        'has_subquery': bool(_SUBQUERY_RE.search(sql_upper)),
        'has_aggregate': bool(_AGGREGATE_RE.search(sql_upper)),
        'has_group_by': bool(_GROUP_BY_RE.search(sql_upper)),
        'has_order_by': bool(_ORDER_BY_RE.search(sql_upper)),
        'has_where': bool(_WHERE_RE.search(sql_upper)),
        'num_tables': len(_TABLE_REF_RE.findall(sql_upper)),
    }

# Analyze all test queries
# One row per test query: its SQL feature flags plus its labelled complexity tier.
complexity_stats = [
    {**analyze_query_complexity(q['expected_sql']), 'complexity': q['complexity']}
    for q in test_queries
]

complexity_df = pd.DataFrame(complexity_stats)
print("\nQuery Pattern Distribution:")
print("=" * 80)
# Per tier: mean of each boolean flag (fraction of queries using that feature)
# and mean num_tables.
summary = complexity_df.groupby('complexity').mean()
print(summary.round(2))

## 5. Schema Relevance Testing

In [None]:
# Ask SchemaManager which tables it considers relevant for a handful of
# representative natural-language questions.
test_questions = [
    "Show me all customers from California",
    "What are the top selling products?",
    "Which customers have never placed an order?",
    "What's the average order value by customer city?",
]

print("Table Relevance Detection:")
print("=" * 80)
for q_text in test_questions:
    tables_hit = schema_manager.identify_relevant_tables(q_text)
    print(f"\nQuestion: {q_text}")
    print(f"Relevant tables: {tables_hit}")

## 6. Summary and Insights

### Key Findings:
1. **Database Structure**: 5 tables with clear relationships (customers, products, orders, order_items, reviews)
2. **Data Volume**: Small sample dataset suitable for testing and demos
3. **Query Complexity**: Test queries range from simple filters to complex multi-table joins
4. **Common Patterns**: 
   - Customer analysis queries
   - Product performance queries
   - Order and revenue analytics
   - Review sentiment analysis

### Next Steps:
- Optimize prompts for different query complexities (see notebook 02)
- Evaluate model performance on test queries (see notebook 03)
- Test edge cases and error handling

In [None]:
# Cleanup: close the SQLite connection `conn` opened in the schema-exploration
# section. Re-run that cell before executing any query cell again.
conn.close()