# Geo-Experiment Playground

This notebook provides an interactive environment for experimenting with the geo-experiment evaluation framework.

## Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Import our modules
from data_simulation.generators import SimpleNullGenerator, DataConfig
from assignment.methods import RandomAssignment, KMeansEmbeddingAssignment, PrognosticScoreAssignment, EmbeddingBasedAssignment, HybridEmbeddingAssignment
from assignment.spatial_utils import add_spectral_spatial_embedding
from assignment.stratified_utils import stratified_assignment_within_clusters, evaluate_cluster_balance, print_balance_summary
from reporting.models import MeanMatchingModel, GBRModel, TBRModel, SyntheticControlModel
from evaluation.metrics import EvaluationRunner, EvaluationConfig
from diagnostics.plots import DiagnosticPlotter
from pipeline.runner import ExperimentRunner
from pipeline.config import ExperimentConfig

# Plot styling: matplotlib defaults plus seaborn's "husl" colour palette
plt.style.use('default')
sns.set_palette('husl')

# Assignment strategies exercised in this notebook: class name -> one-line summary
ASSIGNMENT_METHOD_SUMMARIES = {
    'RandomAssignment': 'Simple random assignment',
    'KMeansEmbeddingAssignment': 'K-means clustering on features',
    'PrognosticScoreAssignment': 'OLS-based prognostic scoring',
    'EmbeddingBasedAssignment': 'General embedding approach (neural + spatial)',
    'HybridEmbeddingAssignment': 'Semi-supervised prediction-aware assignment',
}

print("✅ All modules imported successfully!")
print("📊 Available assignment methods:")
for class_name, summary in ASSIGNMENT_METHOD_SUMMARIES.items():
    print(f"  • {class_name}: {summary}")

## Quick Start: Single Experiment

Let's start with a simple single experiment to understand the framework.

In [None]:
# Small configuration: 20 geos over 60 days, split into 40 pre-period days
# and 20 evaluation days, with a fixed seed for reproducibility.
config = ExperimentConfig(
    n_geos=20,
    n_days=60,
    pre_period_days=40,
    eval_period_days=20,
    seed=42,
)

# One experiment run; show_plots=True asks the runner to display its plots.
runner = ExperimentRunner(config)
results = runner.run_single_experiment(show_plots=True)

ci_lower, ci_upper = results['iroas_ci'][0], results['iroas_ci'][1]

print("\n📊 Single Experiment Results:")
print(f"iROAS Estimate: {results['iroas_estimate']:.4f}")
print(f"95% CI: [{ci_lower:.4f}, {ci_upper:.4f}]")
print(f"CI Width: {results['ci_width']:.4f}")
print(f"Significant: {results['significant']}")

## Full Evaluation Example

Now let's run a complete evaluation across multiple simulations to see how the method performs statistically.

In [None]:
# Full evaluation: repeat the experiment many times to estimate each method's
# false-positive rate, CI coverage and CI width.
# Numbers are kept small so the playground runs quickly.
full_eval_config = ExperimentConfig(
    n_geos=30,           # Moderate number of geos
    n_days=60,           # 60 days total
    pre_period_days=40,  # 40 days for training
    eval_period_days=20, # 20 days for evaluation
    n_simulations=50,    # 50 simulations (increase for more robust results)
    n_bootstrap=200,     # 200 bootstrap samples per simulation
    seed=42
)

print("🔄 Running full evaluation (this may take a minute)...")
print(f"Configuration: {full_eval_config.n_simulations} simulations, {full_eval_config.n_geos} geos each")

eval_runner = ExperimentRunner(full_eval_config)

# Register the reporting models to compare
eval_runner.add_reporting_method('GBR', GBRModel())
eval_runner.add_reporting_method('TBR', TBRModel())
eval_runner.add_reporting_method('SCM', SyntheticControlModel())

detailed_results, summary_results = eval_runner.run_full_evaluation(verbose=True)

print("\n📈 Summary Results:")
print(summary_results)

print("\n📊 Creating results visualization...")
fig = eval_runner.plot_results(detailed_results)
plt.show()


def _summary_row_label(row_index, row):
    """Return a readable label for one row of ``summary_results``.

    Joins the row's string-valued cells (e.g. method names, if the summary has
    such columns); falls back to the row's index label when there are none.
    """
    text_cells = [str(value) for value in row.values if isinstance(value, str)]
    return ' / '.join(text_cells) if text_cells else str(row_index)


# Pooled statistics over every row of detailed_results. With several reporting
# methods registered, these rows may span more than one method.
print("\n🔍 Key Insights:")
print(f"• Average iROAS estimate (all rows pooled): {detailed_results['iroas_estimate'].mean():.4f}")
print(f"• Standard deviation of estimates (all rows pooled): {detailed_results['iroas_estimate'].std():.4f}")

# Operating characteristics per summary row. Several reporting methods were
# registered above, so reading only .iloc[0] would silently describe just one.
for row_index, row in summary_results.iterrows():
    label = _summary_row_label(row_index, row)
    print(f"• [{label}] False positive rate: {row['false_positive_rate']:.3f} (should be ~0.05), "
          f"coverage rate: {row['coverage_rate']:.3f} (should be ~0.95), "
          f"mean CI width: {row['mean_ci_width']:.4f}")

# Distributions of the estimates and CI widths (all rows of detailed_results pooled)
fig_dist, (ax_estimates, ax_widths) = plt.subplots(1, 2, figsize=(10, 6))

ax_estimates.hist(detailed_results['iroas_estimate'], bins=20, alpha=0.7, edgecolor='black')
ax_estimates.axvline(0, color='red', linestyle='--', label='True iROAS (0)')
ax_estimates.set_xlabel('iROAS Estimate')
ax_estimates.set_ylabel('Frequency')
ax_estimates.set_title('Distribution of iROAS Estimates')
ax_estimates.legend()

ax_widths.hist(detailed_results['ci_width'], bins=20, alpha=0.7, edgecolor='black', color='orange')
ax_widths.set_xlabel('Confidence Interval Width')
ax_widths.set_ylabel('Frequency')
ax_widths.set_title('Distribution of CI Widths')

fig_dist.tight_layout()
plt.show()

## Data Generation Experiments

Let's see how one data-generation parameter, the daily sales noise level, changes the simulated data.

In [None]:
# Two datasets identical in geos, days and seed, differing only in daily sales
# noise, so any difference in variability comes from the noise setting.
low_noise_config = DataConfig(n_geos=30, n_days=90, daily_sales_noise=100, seed=123)    # Low noise
high_noise_config = DataConfig(n_geos=30, n_days=90, daily_sales_noise=1000, seed=123)  # High noise

low_noise_gen = SimpleNullGenerator(low_noise_config)
high_noise_gen = SimpleNullGenerator(high_noise_config)

panel_low, features_low = low_noise_gen.generate()
panel_high, features_high = high_noise_gen.generate()

# (label, panel) pairs in drawing order; low noise first, then high noise
noise_panels = [('Low Noise', panel_low), ('High Noise', panel_high)]

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
ts_ax, dist_ax = axes

# Left panel: daily sales of one geo under each noise level
geo_low = panel_low.loc[panel_low['geo'] == 'geo_000']
geo_high = panel_high.loc[panel_high['geo'] == 'geo_000']
for label, geo_series in (('Low Noise', geo_low), ('High Noise', geo_high)):
    ts_ax.plot(geo_series['date'], geo_series['sales'], label=label, alpha=0.8)
ts_ax.set_title('Sales Time Series Comparison (geo_000)')
ts_ax.set_ylabel('Sales')
ts_ax.legend()
ts_ax.tick_params(axis='x', rotation=45)

# Right panel: sales distribution pooled over all geos and days
for label, panel in noise_panels:
    dist_ax.hist(panel['sales'], alpha=0.6, label=label, bins=30)
dist_ax.set_title('Sales Distribution Comparison')
dist_ax.set_xlabel('Sales')
dist_ax.set_ylabel('Frequency')
dist_ax.legend()

plt.tight_layout()
plt.show()

for level_name, panel in (('Low', panel_low), ('High', panel_high)):
    print(f"{level_name} noise std: {panel['sales'].std():.2f}")

## Assignment Method Testing

Compare the available assignment strategies on the same synthetic dataset and check how well each one balances the treatment and control groups.

In [None]:
# Comprehensive Assignment Method Comparison
print("📊 COMPREHENSIVE ASSIGNMENT METHOD COMPARISON")
print("="*80)

# Features whose treatment/control balance is evaluated for every method,
# and the short names used for them on the heatmap axis.
feature_cols = ['base_sales', 'base_spend', 'covariate']
FEATURE_DISPLAY_NAMES = {'base_sales': 'Sales', 'base_spend': 'Spend', 'covariate': 'Covariate'}
non_spatial_cols = ['geo'] + feature_cols


def _balance_quality(avg_smd):
    """Grade an average SMD with this notebook's thresholds.

    < 0.05 -> "Excellent", < 0.1 -> "Good", < 0.2 -> "Moderate",
    anything else (including NaN) -> "Poor".
    """
    if avg_smd < 0.05:
        return "Excellent"
    if avg_smd < 0.1:
        return "Good"
    if avg_smd < 0.2:
        return "Moderate"
    return "Poor"


# Generate test data with all necessary features
np.random.seed(42)
test_geo_features = pd.DataFrame({
    'geo': [f'geo_{i:03d}' for i in range(40)],
    'base_sales': np.random.normal(12000, 4000, 40),
    'base_spend': np.random.normal(6000, 2000, 40),
    'covariate': np.random.normal(0, 1.5, 40),
    'xy1': np.random.uniform(0, 100, 40),  # Spatial coordinates
    'xy2': np.random.uniform(0, 100, 40)
})

# Panel data for the panel-based methods: per-geo level + linear trend +
# daily noise + weekly seasonality, floored at 1000.
dates = pd.date_range('2024-01-01', periods=60)
test_panel_data = []
for _, geo_row in test_geo_features.iterrows():
    base_sales = geo_row['base_sales']
    trend = np.random.normal(0, 50)
    for day_idx, date in enumerate(dates):
        sales = (base_sales + trend * day_idx + np.random.normal(0, 500) +
                 300 * np.sin(day_idx * 2 * np.pi / 7))  # Weekly seasonality
        test_panel_data.append({
            'geo': geo_row['geo'],
            'date': date,
            'sales': max(sales, 1000),
            'spend_dollars': np.random.normal(5000, 1000)
        })
test_panel_df = pd.DataFrame(test_panel_data)

print(f"Dataset: {len(test_geo_features)} geos with {len(test_panel_df)} panel observations")

# Define all available assignment methods
assignment_methods = {
    'Random': {
        'method': RandomAssignment(),
        'description': 'Simple random assignment (baseline)',
        'requires_panel': False,
        'requires_spatial': False
    },
    'K-Means': {
        'method': KMeansEmbeddingAssignment(n_clusters=5),
        'description': 'K-means clustering on standardized features',
        'requires_panel': False,
        'requires_spatial': False
    },
    'Prognostic Score': {
        'method': PrognosticScoreAssignment(n_strata=5),
        'description': 'OLS-based prognostic scoring with stratification',
        'requires_panel': True,
        'requires_spatial': False
    },
    'Embedding-Based': {
        'method': EmbeddingBasedAssignment(n_clusters=5, neural_epochs=15),
        'description': 'Neural + spectral spatial embeddings',
        'requires_panel': False,
        'requires_spatial': True
    },
    'Hybrid Embedding': {
        'method': HybridEmbeddingAssignment(n_clusters=5, neural_epochs=15),
        'description': 'Semi-supervised: reconstruction + prediction loss',
        'requires_panel': True,
        'requires_spatial': True
    }
}

# One column per method: sales balance, covariate balance, SMD heatmap
fig, axes = plt.subplots(3, len(assignment_methods), figsize=(25, 15))

print("\n" + "="*80)
print("🎯 ASSIGNMENT METHOD DETAILED COMPARISON")
print("="*80)

balance_summary = {}
for i, (method_name, method_info) in enumerate(assignment_methods.items()):
    print(f"\n🔹 {method_name.upper()}:")
    print(f"   {method_info['description']}")
    print("-" * 60)

    try:
        # Spatial methods receive the full frame (incl. xy1/xy2); the others only
        # the non-spatial features. Panel-based methods additionally get the panel.
        if method_info['requires_panel'] and method_info['requires_spatial']:
            assignment_df = method_info['method'].assign(
                test_geo_features, panel_data=test_panel_df, seed=42
            )
        elif method_info['requires_panel']:
            assignment_df = method_info['method'].assign(
                test_geo_features[non_spatial_cols], panel_data=test_panel_df, seed=42
            )
        elif method_info['requires_spatial']:
            assignment_df = method_info['method'].assign(test_geo_features, seed=42)
        else:
            assignment_df = method_info['method'].assign(
                test_geo_features[non_spatial_cols], seed=42
            )

        # Features used for the balance analysis mirror what the method was given
        if method_info['requires_spatial']:
            analysis_features = test_geo_features
        else:
            analysis_features = test_geo_features[non_spatial_cols]

        merged = analysis_features.merge(assignment_df, on='geo')

        treatment_count = (assignment_df['assignment'] == 'treatment').sum()
        control_count = (assignment_df['assignment'] == 'control').sum()
        print(f"   Assignment: {treatment_count} treatment, {control_count} control")

        # Methods without clustering get a single dummy cluster so the balance
        # helper can treat every method the same way.
        if 'cluster' not in assignment_df.columns:
            assignment_df = assignment_df.copy()
            assignment_df['cluster'] = 0
            print("   Structure: No clustering (all geos treated as single group)")
        else:
            n_clusters = len(assignment_df['cluster'].unique())
            cluster_dist = assignment_df['cluster'].value_counts().sort_index()
            print(f"   Structure: {n_clusters} clusters/strata")
            print(f"   Distribution: {dict(cluster_dist)}")

        balance_df = evaluate_cluster_balance(analysis_features, assignment_df, feature_cols)

        # Aggregate SMD magnitudes. .abs() is a no-op if evaluate_cluster_balance
        # already reports absolute values; otherwise it stops positive and
        # negative imbalances from cancelling in the mean.
        overall_balance = balance_df[balance_df['cluster'] == 'Overall']
        smd_magnitude = overall_balance['standardized_diff'].abs()
        avg_smd = smd_magnitude.mean()
        max_smd = smd_magnitude.max()

        balance_summary[method_name] = {
            'avg_smd': avg_smd,
            'max_smd': max_smd,
            'treatment_count': treatment_count,
            'control_count': control_count,
            'n_clusters': len(assignment_df['cluster'].unique())
        }

        print(f"   Balance: Avg SMD = {avg_smd:.3f}, Max SMD = {max_smd:.3f}")
        balance_quality = _balance_quality(avg_smd)
        print(f"   Quality: {balance_quality} ({'✅' if avg_smd < 0.1 else '⚠️' if avg_smd < 0.2 else '❌'})")

        # Row 1: sales balance
        sns.boxplot(data=merged, x='assignment', y='base_sales', ax=axes[0, i])
        axes[0, i].set_title(f'{method_name}\nSales Balance')
        axes[0, i].set_ylabel('Base Sales' if i == 0 else '')

        # Row 2: covariate balance, coloured by cluster when there is more than one
        if len(assignment_df['cluster'].unique()) > 1:
            palette = sns.color_palette("Set2", n_colors=len(merged['cluster'].unique()))
            sns.scatterplot(data=merged, x='assignment', y='covariate',
                            hue='cluster', palette=palette, ax=axes[1, i], s=60, alpha=0.8)
            axes[1, i].set_title(f'{method_name}\nCovariate by Cluster')
            if i < len(assignment_methods) - 1:  # Keep the legend only on the last plot
                axes[1, i].get_legend().remove()
        else:
            sns.boxplot(data=merged, x='assignment', y='covariate', ax=axes[1, i])
            axes[1, i].set_title(f'{method_name}\nCovariate Balance')
        axes[1, i].set_ylabel('Covariate' if i == 0 else '')

        # Row 3: per-feature |SMD| heatmap (green = balanced, red = imbalanced)
        balance_pivot = overall_balance.set_index('feature')['standardized_diff'].abs()
        balance_matrix = balance_pivot.to_numpy().reshape(-1, 1)

        axes[2, i].imshow(balance_matrix, cmap='RdYlGn_r', aspect='auto', vmin=0, vmax=0.3)
        axes[2, i].set_title(f'{method_name}\nBalance Quality')
        # Tick labels come from the rows actually plotted, so they cannot drift
        # out of sync with the order evaluate_cluster_balance returns.
        row_labels = [FEATURE_DISPLAY_NAMES.get(feat, str(feat)) for feat in balance_pivot.index]
        axes[2, i].set_yticks(range(len(row_labels)))
        axes[2, i].set_yticklabels(row_labels if i == 0 else [''] * len(row_labels))
        axes[2, i].set_xticks([0])
        axes[2, i].set_xticklabels(['SMD'])

        for j, val in enumerate(balance_matrix.flatten()):
            color = 'white' if val > 0.15 else 'black'
            axes[2, i].text(0, j, f'{val:.2f}', ha='center', va='center', color=color, fontweight='bold')

    except Exception as e:
        # Best-effort comparison: report the failure and continue with the other methods
        error_text = str(e)
        print(f"   ❌ Error: {error_text}")
        balance_summary[method_name] = {'avg_smd': np.nan, 'error': error_text}

        short_error = error_text if len(error_text) <= 30 else error_text[:30] + '...'
        for row in range(3):
            axes[row, i].text(0.5, 0.5, f'Error:\n{short_error}',
                              transform=axes[row, i].transAxes, ha='center', va='center')
            axes[row, i].set_title(f'{method_name}\n(Failed)')

plt.tight_layout()
plt.show()

# Summary comparison table
print("\n" + "="*80)
print("📋 SUMMARY COMPARISON TABLE")
print("="*80)

summary_df = pd.DataFrame(balance_summary).T
summary_df = summary_df.dropna(subset=['avg_smd'])  # Remove failed methods

if len(summary_df) > 0:
    summary_df = summary_df.sort_values('avg_smd')
    print(f"{'Method':<18} {'Avg SMD':<8} {'Max SMD':<8} {'T/C Split':<12} {'Clusters':<8} {'Quality':<12}")
    print("-" * 75)

    for method, row in summary_df.iterrows():
        quality = _balance_quality(row['avg_smd'])
        split = f"{int(row['treatment_count'])}/{int(row['control_count'])}"
        print(f"{method:<18} {row['avg_smd']:<8.3f} {row['max_smd']:<8.3f} {split:<12} {int(row['n_clusters']):<8} {quality:<12}")

# Final recommendations
print("\n🎯 RECOMMENDATIONS:")
print("="*50)
print("• Random: Use as baseline for comparison")
print("• K-Means: Good general-purpose method for feature-based balance")
print("• Prognostic Score: Best when historical outcomes predict future performance")
print("• Embedding-Based: Ideal when spatial structure matters")
print("• Hybrid Embedding: Most sophisticated, use when prediction accuracy is key")
print("\n💡 SMD Interpretation:")
print("• < 0.05: Excellent balance")
print("• 0.05-0.1: Good balance")
print("• 0.1-0.2: Moderate balance (acceptable)")
print("• > 0.2: Poor balance (concerning)")

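## Stratified Assignment: Detailed Balance Analysis

Now let's look more closely at how stratified assignment works. K-means groups the geos into clusters based on their features, treatment and control are assigned within each cluster, and balance is checked both overall and within each cluster.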

In [None]:
# Detailed Balance Analysis: Understanding Stratified Assignment
print("🔬 DETAILED BALANCE ANALYSIS")
print("="*60)

# Focused test data: 24 geos with three non-spatial features
np.random.seed(123)
test_features = pd.DataFrame({
    'geo': [f'geo_{i:03d}' for i in range(24)],
    'base_sales': np.random.normal(10000, 3000, 24),
    'base_spend': np.random.normal(5000, 1500, 24),
    'covariate': np.random.normal(0, 2, 24)
})

print(f"Test data: {len(test_features)} geos")
print("Feature ranges:")
print(f"  • base_sales: {test_features['base_sales'].min():.0f} - {test_features['base_sales'].max():.0f}")
print(f"  • base_spend: {test_features['base_spend'].min():.0f} - {test_features['base_spend'].max():.0f}")
print(f"  • covariate: {test_features['covariate'].min():.2f} - {test_features['covariate'].max():.2f}")

# K-Means assignment illustrates the stratified-assignment idea
print("\n🎯 K-MEANS STRATIFIED ASSIGNMENT DEMO:")
print("-" * 50)

kmeans_method = KMeansEmbeddingAssignment(n_clusters=3)
kmeans_assignment = kmeans_method.assign(test_features, treatment_ratio=0.5, seed=123)

merged_detailed = test_features.merge(kmeans_assignment, on='geo')

# Per-cluster composition: size, treatment/control split and mean features
print("\nCluster formation and assignment:")
for cluster_id in sorted(merged_detailed['cluster'].unique()):
    cluster_data = merged_detailed[merged_detailed['cluster'] == cluster_id]
    treatment_in_cluster = (cluster_data['assignment'] == 'treatment').sum()
    control_in_cluster = (cluster_data['assignment'] == 'control').sum()

    avg_sales = cluster_data['base_sales'].mean()
    avg_spend = cluster_data['base_spend'].mean()
    avg_cov = cluster_data['covariate'].mean()

    print(f"\n  Cluster {cluster_id}: {len(cluster_data)} geos total")
    print(f"    Assignment: {treatment_in_cluster} treatment, {control_in_cluster} control")
    print(f"    Avg features: sales={avg_sales:.0f}, spend={avg_spend:.0f}, cov={avg_cov:.2f}")

# Balance overall and within each cluster
print("\n📊 BALANCE EVALUATION:")
print("-" * 30)
balance_results = evaluate_cluster_balance(
    test_features,
    kmeans_assignment,
    ['base_sales', 'base_spend', 'covariate']
)

print_balance_summary(balance_results, threshold=0.1)

fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# 1. Cluster visualization in feature space
scatter = axes[0, 0].scatter(merged_detailed['base_sales'], merged_detailed['base_spend'],
                             c=merged_detailed['cluster'], cmap='Set1', s=80, alpha=0.7)
axes[0, 0].set_xlabel('Base Sales')
axes[0, 0].set_ylabel('Base Spend')
axes[0, 0].set_title('K-Means Clusters in Feature Space')
plt.colorbar(scatter, ax=axes[0, 0], label='Cluster')

# 2. Treatment/control assignment. Each arm is drawn as its own scatter so the
#    legend gets one correctly coloured handle per arm; a single multi-coloured
#    scatter yields only one legend handle.
is_treatment = merged_detailed['assignment'] == 'treatment'
axes[0, 1].scatter(merged_detailed.loc[~is_treatment, 'base_sales'],
                   merged_detailed.loc[~is_treatment, 'base_spend'],
                   c='blue', s=80, alpha=0.7, label='Control')
axes[0, 1].scatter(merged_detailed.loc[is_treatment, 'base_sales'],
                   merged_detailed.loc[is_treatment, 'base_spend'],
                   c='red', s=80, alpha=0.7, label='Treatment')
axes[0, 1].set_xlabel('Base Sales')
axes[0, 1].set_ylabel('Base Spend')
axes[0, 1].set_title('Treatment/Control Assignment')
axes[0, 1].legend()

# 3. Balance comparison by cluster
cluster_balance = balance_results[balance_results['cluster'] != 'Overall']
if len(cluster_balance) > 0:
    sns.barplot(data=cluster_balance, x='cluster', y='standardized_diff',
                hue='feature', ax=axes[1, 0])
    axes[1, 0].set_title('Within-Cluster Balance (SMD)')
    axes[1, 0].axhline(y=0.1, color='red', linestyle='--', alpha=0.7, label='Balance Threshold')
    axes[1, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# 4. Overall vs average within-cluster balance per feature
overall_balance = balance_results[balance_results['cluster'] == 'Overall']
avg_within_cluster = cluster_balance.groupby('feature')['standardized_diff'].mean().reset_index()
avg_within_cluster['balance_type'] = 'Within-Cluster Avg'
overall_balance_plot = overall_balance[['feature', 'standardized_diff']].copy()
overall_balance_plot['balance_type'] = 'Overall'

# ignore_index avoids handing seaborn a frame with duplicated index labels
balance_comparison = pd.concat([
    overall_balance_plot[['feature', 'standardized_diff', 'balance_type']],
    avg_within_cluster[['feature', 'standardized_diff', 'balance_type']]
], ignore_index=True)

sns.barplot(data=balance_comparison, x='feature', y='standardized_diff',
            hue='balance_type', ax=axes[1, 1])
axes[1, 1].set_title('Overall vs Within-Cluster Balance')
axes[1, 1].axhline(y=0.1, color='red', linestyle='--', alpha=0.7)
axes[1, 1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print("\n💡 KEY CONCEPTS:")
print("• Stratified assignment separates clustering from treatment assignment")
print("• Each cluster contributes proportionally to treatment and control groups")
print("• This prevents imbalance from assigning entire clusters to one group")
print("• Balance is evaluated both overall and within each cluster")

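## Embedding-Based Assignment Deep Dive

A closer look at the embedding-based assignment method. Several of its configurations are compared on their spatial clusters and treatment/control balance, followed by a quick balance comparison against random and K-means assignment.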

In [None]:
# Embedding-Based Assignment Demonstration
print("🎯 EMBEDDING-BASED ASSIGNMENT DEMO")
print("="*60)
print("Combines neural embeddings (learned features) + spectral spatial embeddings")

# EmbeddingBasedAssignment is already imported in the setup cell.

# Test data: three features plus spatial coordinates
np.random.seed(2024)
embedding_features = pd.DataFrame({
    'geo': [f'geo_{i:03d}' for i in range(30)],
    'base_sales': np.random.normal(15000, 5000, 30),
    'base_spend': np.random.normal(7500, 2500, 30),
    'covariate': np.random.normal(0, 2, 30),
    # Spatial coordinates (e.g., latitude/longitude scaled)
    'xy1': np.random.uniform(0, 100, 30),
    'xy2': np.random.uniform(0, 100, 30)
})

# Features whose balance is evaluated, and the non-spatial view of the data
balance_feature_cols = ['base_sales', 'base_spend', 'covariate']
non_spatial_cols = ['geo'] + balance_feature_cols

print(f"\nDataset: {len(embedding_features)} geos with features + spatial coordinates")
print(f"Features: {[col for col in embedding_features.columns if col not in ['geo', 'xy1', 'xy2']]}")
print("Spatial coords: ['xy1', 'xy2']")

# Different configurations of the embedding method
embedding_configs = {
    'Default': EmbeddingBasedAssignment(neural_epochs=20),
    'High Neural Dim': EmbeddingBasedAssignment(
        neural_embedding_dim=12,
        spatial_embedding_dim=3,
        neural_epochs=20
    ),
    'More Clusters': EmbeddingBasedAssignment(
        n_clusters=6,
        neural_epochs=20
    ),
    'Custom Features': EmbeddingBasedAssignment(
        feature_cols=['base_sales', 'covariate'],  # Skip base_spend
        neural_epochs=20
    )
}

# Top row: spatial clusters; bottom row: treatment assignment in feature space
fig, axes = plt.subplots(2, len(embedding_configs), figsize=(20, 10))

print(f"\n🔍 Testing {len(embedding_configs)} embedding-based configurations:")
print("-" * 70)

for i, (config_name, method) in enumerate(embedding_configs.items()):
    print(f"\n{config_name.upper()}:")
    print(f"  Neural dim: {method.neural_embedding_dim}, Spatial dim: {method.spatial_embedding_dim}")
    print(f"  Clusters: {method.n_clusters}, Features: {method.feature_cols or 'default'}")

    assignment_df = method.assign(embedding_features, seed=2024)
    merged_embedding = embedding_features.merge(assignment_df, on='geo')

    treatment_count = (assignment_df['assignment'] == 'treatment').sum()
    control_count = (assignment_df['assignment'] == 'control').sum()
    n_clusters_actual = len(assignment_df['cluster'].unique())

    print(f"  Result: {treatment_count}T/{control_count}C across {n_clusters_actual} clusters")

    balance_results = evaluate_cluster_balance(
        embedding_features, assignment_df, balance_feature_cols
    )

    # Average SMD magnitude (.abs() is a no-op if the helper already reports |SMD|)
    overall_balance = balance_results[balance_results['cluster'] == 'Overall']
    avg_smd = overall_balance['standardized_diff'].abs().mean()
    print(f"  Avg SMD: {avg_smd:.3f} ({'Good' if avg_smd < 0.1 else 'Moderate' if avg_smd < 0.2 else 'Poor'} balance)")

    # Top plot: spatial distribution coloured by cluster
    axes[0, i].scatter(
        merged_embedding['xy1'], merged_embedding['xy2'],
        c=merged_embedding['cluster'], cmap='tab10', s=80, alpha=0.8
    )
    axes[0, i].set_xlabel('Spatial X1')
    axes[0, i].set_ylabel('Spatial X2')
    axes[0, i].set_title(f'{config_name}\nSpatial Clusters')

    # Bottom plot: feature space with treatment assignment
    treatment_mask = merged_embedding['assignment'] == 'treatment'
    axes[1, i].scatter(
        merged_embedding.loc[treatment_mask, 'base_sales'],
        merged_embedding.loc[treatment_mask, 'covariate'],
        c='red', label='Treatment', s=80, alpha=0.8
    )
    axes[1, i].scatter(
        merged_embedding.loc[~treatment_mask, 'base_sales'],
        merged_embedding.loc[~treatment_mask, 'covariate'],
        c='blue', label='Control', s=80, alpha=0.8
    )
    axes[1, i].set_xlabel('Base Sales')
    axes[1, i].set_ylabel('Covariate')
    axes[1, i].set_title(f'{config_name}\nTreatment Assignment')
    if i == 0:
        axes[1, i].legend()

plt.tight_layout()
plt.show()

print("\n🎯 EMBEDDING-BASED KEY FEATURES:")
print("="*50)
print("• Neural embeddings: Learn representations from static geo features")
print("• Spatial embeddings: Capture geographic proximity via spectral methods")
print("• Combined embeddings: Merge learned + spatial representations")
print("• Stratified assignment: Balanced treatment/control within each cluster")
print("• General purpose: Works with limited data, fast training")

# Quick balance comparison with simpler methods
print("\n⚡ QUICK PERFORMANCE COMPARISON:")
print("-" * 40)

comparison_methods = {
    'Random': RandomAssignment(),
    'K-Means': KMeansEmbeddingAssignment(n_clusters=4),
    'Embedding-Based': EmbeddingBasedAssignment(n_clusters=4, neural_epochs=15)
}

balance_scores = {}
for method_name, method in comparison_methods.items():
    # Only the embedding method uses the spatial coordinates
    if method_name == 'Embedding-Based':
        assignment = method.assign(embedding_features, seed=2024)
    else:
        assignment = method.assign(embedding_features[non_spatial_cols], seed=2024)

    # Methods without a cluster column (e.g. RandomAssignment) get one dummy cluster
    if 'cluster' not in assignment.columns:
        assignment = assignment.copy()
        assignment['cluster'] = 0

    balance = evaluate_cluster_balance(
        embedding_features[non_spatial_cols],
        assignment,
        balance_feature_cols
    )
    overall_smd = balance[balance['cluster'] == 'Overall']['standardized_diff'].abs().mean()
    balance_scores[method_name] = overall_smd

    treatment_count = (assignment['assignment'] == 'treatment').sum()
    # Derive the remainder from the result instead of hard-coding the geo count
    control_count = len(assignment) - treatment_count
    print(f"{method_name:15}: {treatment_count}T/{control_count}C, Avg SMD = {overall_smd:.3f}")

# Visualize balance comparison
plt.figure(figsize=(10, 6))
methods = list(balance_scores.keys())
scores = list(balance_scores.values())
colors = ['skyblue', 'lightcoral', 'lightgreen']

bars = plt.bar(methods, scores, color=colors, alpha=0.8, edgecolor='black')
plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.7, label='Good Balance Threshold')
plt.ylabel('Average Standardized Mean Difference (SMD)')
plt.title('Balance Quality Comparison: Embedding-Based vs Other Methods')
plt.legend()

# Value labels on top of each bar
for bar, score in zip(bars, scores):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
             f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

print("\n💡 INTERPRETATION:")
print("• Lower SMD = better balance between treatment and control groups")
print("• Embedding-based method combines feature learning + spatial awareness")
print("• Good for cases with static features and spatial structure")
print("• Neural component learns complex feature relationships")

## Hybrid Embedding Assignment (Semi-Supervised)

Test the advanced HybridEmbeddingAssignment method that uses a hybrid loss function combining reconstruction (unsupervised) and prediction (supervised) objectives. This method is prediction-aware and requires panel data.

In [None]:
# Hybrid Embedding Assignment (Semi-Supervised) Demonstration
print("🧠 HYBRID EMBEDDING ASSIGNMENT DEMO")
print("="*60)
print("Semi-supervised approach: Reconstruction (unsupervised) + Prediction (supervised)")

# HybridEmbeddingAssignment is already imported in the setup cell.

# Test data with time series (required for the hybrid approach)
np.random.seed(2025)
n_geos = 20
n_days = 50

# Geo features: spatial coordinates only
hybrid_geo_features = pd.DataFrame({
    'geo': [f'geo_{i:03d}' for i in range(n_geos)],
    'xy1': np.random.uniform(0, 100, n_geos),
    'xy2': np.random.uniform(0, 100, n_geos)
})

# Panel data: per-geo level + linear trend + daily noise + weekly seasonality.
# Records are collected in a list and turned into a DataFrame once at the end.
dates = pd.date_range('2024-01-01', periods=n_days)
panel_data = []

for _, geo_row in hybrid_geo_features.iterrows():
    # Each geo has a different base sales level and trend
    base_sales = np.random.normal(12000, 3000)
    trend = np.random.normal(0, 50)  # Some geos grow, others decline

    for day_idx, date in enumerate(dates):
        sales = (base_sales +
                 trend * day_idx +
                 np.random.normal(0, 800) +  # Daily noise
                 500 * np.sin(day_idx * 2 * np.pi / 7))  # Weekly seasonality

        panel_data.append({
            'geo': geo_row['geo'],
            'date': date,
            'sales': max(sales, 1000),  # Floor sales at 1000 to keep them positive
            'spend_dollars': np.random.normal(5000, 1000)
        })

hybrid_panel_data = pd.DataFrame(panel_data)

print(f"\nDataset: {n_geos} geos with {n_days} days of time series data")
print(f"Panel data shape: {hybrid_panel_data.shape}")
print("Spatial coords: ['xy1', 'xy2']")
print("Time series: Sales data with trends and seasonality")

# Test different hybrid configurations
print("\n🔬 HYBRID LOSS CONFIGURATIONS:")
print("-" * 50)

hybrid_configs = {
    'Balanced': HybridEmbeddingAssignment(
        reconstruction_weight=0.5,
        prediction_weight=0.25,
        regularization_weight=0.25,
        neural_epochs=20,
        n_clusters=3
    ),
    'Prediction-Focused': HybridEmbeddingAssignment(
        reconstruction_weight=0.3,
        prediction_weight=0.5,    # Higher emphasis on prediction
        regularization_weight=0.2,
        neural_epochs=20,
        n_clusters=3
    ),
    'Reconstruction-Focused': HybridEmbeddingAssignment(
        reconstruction_weight=0.6,   # Higher emphasis on reconstruction
        prediction_weight=0.2,
        regularization_weight=0.2,
        neural_epochs=20,
        n_clusters=3
    )
}

# Run experiments with different loss configurations
results_comparison = {}
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

for i, (config_name, method) in enumerate(hybrid_configs.items()):
    print(f"\n{config_name.upper()}:")
    print(f"  Loss weights - Recon: {method.reconstruction_weight:.1f}, " +
          f"Pred: {method.prediction_weight:.1f}, Reg: {method.regularization_weight:.1f}")
    
    # Create assignment
    assignment_df = method.assign(
        hybrid_geo_features, 
        panel_data=hybrid_panel_data, 
        seed=2025
    )
    
    # Merge for analysis
    merged_hybrid = hybrid_geo_features.merge(assignment_df, on='geo')
    
    # Print summary
    treatment_count = (assignment_df['assignment'] == 'treatment').sum()
    control_count = (assignment_df['assignment'] == 'control').sum()
    n_clusters_actual = len(assignment_df['cluster'].unique())
    
    print(f"  Result: {treatment_count}T/{control_count}C across {n_clusters_actual} clusters")
    
    # Store results for comparison
    results_comparison[config_name] = {
        'assignment_df': assignment_df,
        'merged_data': merged_hybrid,
        'treatment_count': treatment_count,
        'control_count': control_count
    }
    
    # Visualize spatial clustering
    scatter = axes[0, i].scatter(
        merged_hybrid['xy1'], merged_hybrid['xy2'],
        c=merged_hybrid['cluster'], cmap='Set1', s=100, alpha=0.8, edgecolor='black'
    )
    axes[0, i].set_xlabel('Spatial X1')
    axes[0, i].set_ylabel('Spatial X2')
    axes[0, i].set_title(f'{config_name}\nSpatial Clusters')
    
    # Visualize treatment assignment
    treatment_mask = merged_hybrid['assignment'] == 'treatment'
    axes[1, i].scatter(
        merged_hybrid.loc[treatment_mask, 'xy1'],
        merged_hybrid.loc[treatment_mask, 'xy2'],
        c='red', label='Treatment', s=100, alpha=0.8, edgecolor='black'
    )
    axes[1, i].scatter(
        merged_hybrid.loc[~treatment_mask, 'xy1'],
        merged_hybrid.loc[~treatment_mask, 'xy2'],
        c='blue', label='Control', s=100, alpha=0.8, edgecolor='black'
    )
    axes[1, i].set_xlabel('Spatial X1')
    axes[1, i].set_ylabel('Spatial X2')
    axes[1, i].set_title(f'{config_name}\nTreatment Assignment')
    if i == 0:
        axes[1, i].legend()

plt.tight_layout()
plt.show()

# Compare the different approaches
print(f"\n📊 PERFORMANCE ANALYSIS:")
print("="*50)

print("Analyzing how different loss weightings affect cluster formation...")

# Analyze sales patterns by cluster for each method
for config_name, results in results_comparison.items():
    print(f"\n{config_name.upper()} ANALYSIS:")
    
    # Get sales data for this assignment
    assignment_with_sales = results['assignment_df'].merge(
        hybrid_panel_data.groupby('geo')['sales'].mean().reset_index(), 
        on='geo'
    )
    
    # Analyze clusters
    cluster_stats = assignment_with_sales.groupby('cluster').agg({
        'sales': ['mean', 'std', 'count'],
        'assignment': lambda x: f"{(x=='treatment').sum()}T/{(x=='control').sum()}C"
    }).round(0)
    
    cluster_stats.columns = ['Avg_Sales', 'Sales_Std', 'Count', 'T/C_Split']
    print(cluster_stats.to_string())

print(f"\n🎯 HYBRID EMBEDDING KEY ADVANTAGES:")
print("="*50)
print("• Semi-supervised learning: Combines unsupervised (reconstruction) + supervised (prediction)")
print("• Prediction-aware: Neural embeddings learn to predict future sales outcomes")
print("• Time series intelligence: Uses historical patterns, not just static features")
print("• Tunable objectives: Adjust loss weights based on prediction vs balance priorities")
print("• Spatial regularization: Geographic structure preserved through spectral embeddings")

# Demonstrate the prediction capability
print(f"\n🔮 PREDICTION AWARENESS DEMO:")
print("-" * 40)

# Use the balanced configuration for this demo
method = hybrid_configs['Balanced']

# Show how the method splits time series data
print("Time series data splitting:")
pre_period_data, prediction_targets, common_geos = method._prepare_time_series_data(hybrid_panel_data)

print(f"• Pre-period data shape: {pre_period_data.shape} (geos × time_steps)")
print(f"• Prediction targets shape: {prediction_targets.shape} (future sales per geo)")
print(f"• Pre-period uses: {pre_period_data.shape[1]} days ({method.pre_period_fraction:.1%} of timeline)")
print(f"• Prediction period: {int(n_days * method.prediction_fraction)} days ({method.prediction_fraction:.1%} of timeline)")

# Show sales distribution
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.hist(pre_period_data.mean(axis=1), bins=15, alpha=0.7, edgecolor='black')
plt.xlabel('Average Pre-Period Sales')
plt.ylabel('Number of Geos')
plt.title('Pre-Period Sales Distribution\n(Used for Learning Embeddings)')

plt.subplot(1, 2, 2)
plt.hist(prediction_targets, bins=15, alpha=0.7, color='orange', edgecolor='black')
plt.xlabel('Future Sales Target')
plt.ylabel('Number of Geos')
plt.title('Prediction Targets\n(Used for Supervised Learning)')

plt.tight_layout()
plt.show()

print(f"\n💡 WHEN TO USE HYBRID EMBEDDING:")
print("• Rich panel data available (time series for each geo)")
print("• Want assignments that consider likely future outcomes")
print("• Care about both balance AND prediction accuracy")
print("• Have computational resources for neural network training")
print("• Geographic spillovers are important (spatial component)")

print(f"\n🔄 COMPARISON WITH OTHER METHODS:")
print("• EmbeddingBasedAssignment: Static features, faster, less data required")
print("• HybridEmbeddingAssignment: Time series, prediction-aware, more sophisticated")
print("• Both use spatial embeddings and stratified assignment for balance")

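## Model Performance Comparison

Run a small evaluation to see how the methods perform. The first cell below also shows how to save results as CSV files and plots as PNG files. The second cell runs a quick evaluation and displays the results inline.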

In [None]:
# Demonstration of CSV and Plot Output Functionality
print("💾 CSV AND PLOT OUTPUT DEMO")
print("="*50)

# Create a configuration for output demonstration
output_config = ExperimentConfig(
    n_geos=20,
    n_days=50,
    pre_period_days=35,
    eval_period_days=15,
    n_simulations=10,  # Small number for demo
    n_bootstrap=100,
    seed=42
)

# Create runner with multiple assignment methods for comparison
output_runner = ExperimentRunner(output_config)
output_runner.add_assignment_method('K-Means', KMeansEmbeddingAssignment(n_clusters=4))
output_runner.add_assignment_method('Prognostic', PrognosticScoreAssignment(n_strata=4))

print(f"Configuration: {output_config.n_simulations} simulations, {output_config.n_geos} geos")
print(f"Assignment methods: {list(output_runner.assignment_methods.keys())}")

# Run evaluation with output files
print(f"\n📊 Running evaluation with CSV and plot output...")
detailed_results, summary_results = output_runner.run_full_evaluation(
    verbose=True,
    save_csv=True,           # Save results as CSV files
    save_plots=True,         # Save plots as PNG files  
    output_dir="demo_results" # Output directory
)

print(f"\n✅ Files saved to 'demo_results' directory!")
print(f"📄 Generated files:")
print(f"  • evaluation_summary_[timestamp].csv - Key performance metrics")
print(f"  • detailed_results_[timestamp].csv - All simulation results")  
print(f"  • evaluation_results_[timestamp].png - Main results visualization")
print(f"  • summary_metrics_[timestamp].png - Performance metrics charts")
print(f"  • method_comparison_[timestamp].png - Method comparison heatmap")

# Also demonstrate single experiment output
print(f"\n🔬 Running single experiment with plot saving...")
single_results = output_runner.run_single_experiment(
    show_plots=False,        # Don't display plots
    save_plots=True,         # Save plots to files
    output_dir="demo_results/single_experiment"
)

print(f"\n💡 Use Cases:")
print(f"• save_csv=True: Generate reports, share results, further analysis")
print(f"• save_plots=True: Include in presentations, documentation, papers")
print(f"• Timestamps prevent file overwrites")
print(f"• Organized output directories keep results structured")

In [None]:
# Quick evaluation: a deliberately small simulation budget keeps this cell fast.
# 50 days are split into a 35-day pre-period and a 15-day evaluation period.
quick_config = ExperimentConfig(
    seed=42,
    n_geos=25,
    n_days=50,
    pre_period_days=35,
    eval_period_days=15,
    n_bootstrap=100,
    n_simulations=20,  # Small for quick testing
)

runner = ExperimentRunner(quick_config)
detailed_results, summary_results = runner.run_full_evaluation(verbose=True)

# Visualize the per-simulation results returned by the evaluation
fig = runner.plot_results(detailed_results)
plt.show()

## Custom Experiments

This section is for your own experiments and testing new ideas.

In [None]:
# 🧪 Experiment with different parameter combinations

# Example: How does the number of geos affect CI width?


def mean_ci_width_for(n_geos: int) -> float:
    """Run a small full evaluation with ``n_geos`` geos and return the mean CI width.

    All other settings are held fixed (60 days split 40 pre / 20 eval,
    10 simulations, seed 42), so only the geo count varies between calls.
    The config and runner are local, so calling this does not overwrite the
    notebook globals ``config``, ``runner``, ``detailed_results`` or ``n_geos``.
    """
    ci_config = ExperimentConfig(
        n_geos=n_geos,
        n_days=60,
        pre_period_days=40,
        eval_period_days=20,
        n_simulations=10,  # Small for speed
        seed=42
    )
    ci_detailed_results, _ = ExperimentRunner(ci_config).run_full_evaluation(verbose=False)
    # NOTE(review): this averages over every row of the detailed results, which
    # may pool several assignment methods/models — confirm that is intended.
    return ci_detailed_results['ci_width'].mean()


geo_counts = [10, 25, 50, 100]
ci_widths = []

for geo_count in geo_counts:
    avg_ci_width = mean_ci_width_for(geo_count)
    ci_widths.append(avg_ci_width)
    print(f"n_geos={geo_count}: avg CI width = {avg_ci_width:.3f}")

# Plot relationship
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(geo_counts, ci_widths, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('Number of Geos')
ax.set_ylabel('Average CI Width')
ax.set_title('CI Width vs Number of Geos')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# 🎯 Your experiments here!
#
# Suggested questions to explore:
#   1. Effect of noise level on false positive rates
#   2. Optimal pre-period vs evaluation period lengths
#   3. Impact of treatment ratio on power
#   4. Bootstrap sample size vs CI stability
#
# Starter template (uncomment and fill in). It uses fresh names so that
# running it does not overwrite variables from earlier cells:
#   my_config = ExperimentConfig(
#       n_geos=...,
#       n_days=...,
#       # ... other parameters
#   )
#   my_runner = ExperimentRunner(my_config)
#   my_results = my_runner.run_single_experiment(show_plots=True)

print("🚀 Ready for your experiments!")

## Development Notes

Use this section for notes, debugging, and development work.

In [None]:
# Development and debugging space

# Quick data validation on a tiny simulated dataset (5 geos, 10 days).
# The DataConfig gets its own name rather than `config`, so the
# ExperimentConfig used elsewhere in the notebook is not replaced by an
# object of a different type.
dev_data_config = DataConfig(n_geos=5, n_days=10, seed=999)
gen = SimpleNullGenerator(dev_data_config)
panel, features = gen.generate()

print("Panel data sample:")
print(panel.head())
print(f"\nPanel shape: {panel.shape}")
print(f"Features shape: {features.shape}")
print(f"Validation: {gen.validate_data(panel, features)}")