# KAN Performance Investigation Workbook

**Objective**: This workbook provides a structured environment to investigate the performance of the KAN-based D-IV-LATE estimator. It addresses two key questions:

1.  **Hyperparameter Sensitivity**: Can the KAN estimator's performance (as measured by RMSE) be improved by tuning its hyperparameters?
2.  **Uncertainty Quantification**: Can we obtain reliable confidence intervals and coverage rates for the KAN estimator using a computationally feasible bootstrap procedure?

The results from this investigation will inform the final narrative of the research paper.

## 1. Setup and Imports

First, we install the `efficient-kan` dependency, clone the project repository, and import the necessary libraries and the custom functions from the project's source code. These include utilities for training KANs, generating data, and estimating the D-IV-LATE.

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
from tqdm.notebook import tqdm

# --- Colab Environment Setup ---
print("Setting up Colab environment...")

# 1. Install the efficient-kan dependency
print("Installing efficient-kan library...")
!pip install git+https://github.com/shawcharles/efficient-kan.git

# 2. Clone the main project repository
repo_name = 'kan-d-iv-late' # The name of the repo directory
if not os.path.exists(repo_name):
    print(f"\nCloning repository '{repo_name}'...")
    !git clone https://github.com/shawcharles/kan-d-iv-late.git
else:
    print(f"\nRepository '{repo_name}' already exists.")

# 3. Add the project's code directory to the Python path
# The path structure is kan-d-iv-late/kan-d-iv-late/code
code_path = os.path.abspath(os.path.join(repo_name, 'kan-d-iv-late', 'code'))
if code_path not in sys.path:
    sys.path.append(code_path)
    print(f"Added '{code_path}' to sys.path")
else:
    print(f"'{code_path}' is already in sys.path")

# --- End Colab Setup ---

# 4. Now, import our standardized utilities and simulation functions
print("\nImports starting...")
import kan_utils
from kan_d_iv_late_simulation_enhanced import (
    generate_dlate_data,
    estimate_nuisance_functions_enhanced,
    dlate_estimator,
    run_enhanced_simulation, # We will adapt its logic
    create_enhanced_plots
)

# Set plot style
sns.set_theme(style="whitegrid")

print(f"\nUsing device: {kan_utils.DEVICE}")
print("Setup complete.")

## 2. Part 1: Hyperparameter Tuning Analysis

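In this section, we conduct a grid search over a predefined set of KAN hyperparameters (training steps, hidden dimension, and regularization strength). For each combination, we run a fast Monte Carlo simulation (few replications, no bootstrapping) and compute the Root Mean Squared Error (RMSE) of the D-IV-LATE estimate, pooled over all replications and all points of the y grid. Every combination is evaluated on the same simulated datasets (seeds `0` to `N_REPLICATIONS_FAST - 1`), so differences in RMSE reflect the hyperparameters rather than differences between datasets. The goal is to identify the combination that minimizes this RMSE.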

The results will be saved to `hyperparameter_tuning_results.csv`.
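
The nested loops in the next cell enumerate the full Cartesian product of the grid (3 × 3 × 3 = 27 settings, each fitted `N_REPLICATIONS_FAST` times). As a minimal sketch of the same logic in flatter form, the snippet below enumerates the grid with `itertools.product` and computes the tuning metric: a single RMSE pooled over replications and grid points, with failed (NaN) fits ignored. Note that this pooled value is not the same as averaging per-point RMSEs, because the square root is taken after averaging. The helpers `iter_grid` and `pooled_rmse` are hypothetical names, not part of the project code.

```python
import itertools
import numpy as np

# Hypothetical helper: yield one dict per hyperparameter combination, in the
# same order as the nested loops (last key varies fastest).
def iter_grid(grid):
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))

# Hypothetical helper: RMSE pooled over replications AND grid points.
# estimates, truth: arrays of shape (n_replications, len(y_grid)).
# NaN entries (failed fits) are ignored; the number of fully failed
# replications is returned alongside the RMSE.
def pooled_rmse(estimates, truth):
    sq_err = (np.asarray(estimates) - np.asarray(truth)) ** 2
    n_failed_reps = int(np.isnan(sq_err).all(axis=1).sum())
    return float(np.sqrt(np.nanmean(sq_err))), n_failed_reps

grid = {
    'KAN_STEPS': [50, 150, 250],
    'KAN_HIDDEN_DIM': [16, 32, 48],
    'KAN_REG_STRENGTH': [1e-3, 1e-4, 1e-5],
}
settings = list(iter_grid(grid))
print(len(settings))   # 27
print(settings[0])     # {'KAN_STEPS': 50, 'KAN_HIDDEN_DIM': 16, 'KAN_REG_STRENGTH': 0.001}

# Toy check: two replications on a 3-point grid; the second replication failed.
est = np.array([[1.0, 2.0, 3.0], [np.nan, np.nan, np.nan]])
true = np.array([[0.0, 2.0, 5.0], [0.0, 2.0, 5.0]])
print(pooled_rmse(est, true))   # approx (1.291, 1)
```

Reporting the failure count next to the RMSE makes it visible when a setting looks good only because its hardest replications failed and were dropped.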

In [None]:
# --- Configuration for Hyperparameter Search ---
N_REPLICATIONS_FAST = 20  # Number of MC replications for each setting (increased for more stability)
N_SAMPLES = 1000
Y_GRID_SIZE = 30

# Define the hyperparameter grid to search
HYPERPARAM_GRID = {
    'KAN_STEPS': [50, 150, 250],
    'KAN_HIDDEN_DIM': [16, 32, 48],
    'KAN_REG_STRENGTH': [1e-3, 1e-4, 1e-5]
}

# Directory to save results
output_dir = '../results'
os.makedirs(output_dir, exist_ok=True)

all_hyperparam_results = []

print("Starting Hyperparameter Sensitivity Analysis...")

y_grid = np.linspace(-8, 15, Y_GRID_SIZE)

# --- Main Grid Search Loop ---
for steps in tqdm(HYPERPARAM_GRID['KAN_STEPS'], desc="KAN Steps"):
    for hidden_dim in tqdm(HYPERPARAM_GRID['KAN_HIDDEN_DIM'], desc="Hidden Dims", leave=False):
        for reg_strength in tqdm(HYPERPARAM_GRID['KAN_REG_STRENGTH'], desc="Reg Strength", leave=False):
            
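            # Override KAN settings on the kan_utils module for this grid point.
            # Assumes the estimation code reads kan_utils.KAN_* at call time; values bound
            # via `from kan_utils import ...` or as default arguments would not see the change.
            # The overrides persist after the loop ends.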
            kan_utils.KAN_STEPS = steps
            kan_utils.KAN_HIDDEN_DIM = hidden_dim
            kan_utils.KAN_REG_STRENGTH = reg_strength

            kan_estimates = []
            true_values = []

            # Run fast simulation for this hyperparameter set
            for rep in range(N_REPLICATIONS_FAST):
                data, true_dlate_func = generate_dlate_data(n_samples=N_SAMPLES, seed=rep)
                true_dlate = np.array([true_dlate_func(y) for y in y_grid])
                true_values.append(true_dlate)
                
                try:
                    nuisance_df = estimate_nuisance_functions_enhanced(data, y_grid, model_type='kan')
                    dlate_est = dlate_estimator(data, nuisance_df, y_grid)
                    kan_estimates.append(dlate_est)
                except Exception as e:
                    print(f"Error during estimation: {e}")
                    kan_estimates.append(np.full(len(y_grid), np.nan))

            # Performance metric: RMSE pooled over all replications and y-grid points
            # (NaN entries from failed fits are ignored by np.nanmean)
            kan_estimates = np.array(kan_estimates)
            true_values = np.array(true_values)
            mean_rmse = np.sqrt(np.nanmean((kan_estimates - true_values)**2))
            
            all_hyperparam_results.append({
                'steps': steps, 
                'hidden_dim': hidden_dim, 
                'reg_strength': reg_strength, 
                'mean_rmse': mean_rmse
            })

# --- Process and Save Hyperparameter Results ---
hyperparam_results_df = pd.DataFrame(all_hyperparam_results)
hyperparam_results_path = os.path.join(output_dir, 'hyperparameter_tuning_results.csv')
hyperparam_results_df.to_csv(hyperparam_results_path, index=False)

print("\nHyperparameter Sensitivity Analysis Complete.")
print(f"Results saved to {hyperparam_results_path}")

# Find and display the best parameters
best_params = hyperparam_results_df.loc[hyperparam_results_df['mean_rmse'].idxmin()]
print("\nBest performing hyperparameters based on Mean RMSE:")
print(best_params)
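
The grid loop assigns `kan_utils.KAN_STEPS` and related attributes directly and never restores them, so whatever setting ran last stays in effect afterwards. A minimal sketch of a context manager that applies overrides only inside a `with` block and restores the originals on exit (even if an exception is raised) is shown below. `override_attrs` is a hypothetical helper, and `types.SimpleNamespace` stands in for `kan_utils` so the snippet runs without the project code. Like the original loop, it only helps if the estimation code reads the module attributes at call time.

```python
import contextlib
import types

@contextlib.contextmanager
def override_attrs(module, **overrides):
    """Temporarily set attributes on `module`, restoring the originals on exit."""
    sentinel = object()
    saved = {name: getattr(module, name, sentinel) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(module, name, value)
        yield module
    finally:
        for name, old in saved.items():
            if old is sentinel:
                # The attribute did not exist before; remove it again.
                if hasattr(module, name):
                    delattr(module, name)
            else:
                setattr(module, name, old)

# Stand-in for kan_utils so the sketch runs without the project code.
fake_kan_utils = types.SimpleNamespace(KAN_STEPS=100, KAN_HIDDEN_DIM=32)

with override_attrs(fake_kan_utils, KAN_STEPS=250, KAN_REG_STRENGTH=1e-4):
    print(fake_kan_utils.KAN_STEPS, fake_kan_utils.KAN_REG_STRENGTH)  # 250 0.0001

print(fake_kan_utils.KAN_STEPS)                      # 100
print(hasattr(fake_kan_utils, 'KAN_REG_STRENGTH'))   # False
```

In the grid search, the replication loop could be wrapped in `with override_attrs(kan_utils, KAN_STEPS=steps, KAN_HIDDEN_DIM=hidden_dim, KAN_REG_STRENGTH=reg_strength):`. Part 2 sets the tuned values on purpose and relies on them persisting, so it does not need this wrapper.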

### Visualize Tuning Results

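The heatmap below shows the pooled RMSE for every combination: rows index (training steps, regularization strength) pairs, columns index the hidden dimension, and lighter cells (reversed `viridis_r` colormap) mark lower RMSE.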
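
As a complement to the heatmap, the marginal mean RMSE for each level of one hyperparameter (averaged over the other two) summarizes the main effects in a few lines, although averaging can hide interactions that the heatmap shows. A minimal pandas sketch follows; it runs on a toy frame with the same column names as `hyperparam_results_df`, and the RMSE values in it are invented for illustration, not results from this study. `marginal_means` is a hypothetical helper name.

```python
import pandas as pd

# Toy stand-in for hyperparam_results_df (same columns; RMSE values are invented).
toy = pd.DataFrame({
    'steps':        [50, 50, 250, 250],
    'hidden_dim':   [16, 32, 16, 32],
    'reg_strength': [1e-3, 1e-4, 1e-3, 1e-4],
    'mean_rmse':    [0.40, 0.30, 0.20, 0.10],
})

def marginal_means(df, params=('steps', 'hidden_dim', 'reg_strength'), metric='mean_rmse'):
    """Mean of `metric` for each level of each hyperparameter, averaged over the rest."""
    return {p: df.groupby(p)[metric].mean() for p in params}

for param, means in marginal_means(toy).items():
    print(f"{param}:\n{means.to_string()}\n")
```

Applied to the real grid, `marginal_means(hyperparam_results_df)` would show, for example, whether RMSE falls steadily with more training steps regardless of the other two settings.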

In [None]:
if not hyperparam_results_df.empty:
    pivot_table = hyperparam_results_df.pivot_table(
        index=['steps', 'reg_strength'], 
        columns='hidden_dim', 
        values='mean_rmse'
    )

    plt.figure(figsize=(12, 8))
    sns.heatmap(pivot_table, annot=True, fmt=".4f", cmap="viridis_r", linewidths=.5)
    plt.title('Hyperparameter Grid Search Results (Mean RMSE)')
    plt.ylabel('Training Steps & Regularization Strength')
    plt.xlabel('Hidden Dimension')
    plt.show()
else:
    print("No results to plot.")

## 3. Part 2: Limited Bootstrap Analysis

Using the best hyperparameters identified in Part 1, we now run a more intensive simulation with a moderate number of Monte Carlo replications and bootstrapping enabled, which yields pointwise 95% confidence intervals for the D-IV-LATE at each point of the y grid. A Random Forest (RF) nuisance estimator is run on the same datasets as a benchmark.
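
`kan_utils.bootstrap_dlate_ci` is project code whose internals are not shown here. Assuming it implements a standard nonparametric percentile bootstrap (resample observations with replacement, re-run the full estimator on each resample, and take the 2.5% and 97.5% quantiles of the replicates at each y), its core logic looks roughly like the sketch below. The function `percentile_bootstrap_ci` and its arguments are hypothetical; only the returned keys mirror those the notebook reads. A toy estimator (column means) stands in for the nuisance-plus-D-IV-LATE pipeline.

```python
import numpy as np

def percentile_bootstrap_ci(data, estimator, n_bootstrap=199, alpha=0.05, seed=0):
    """Pointwise percentile bootstrap CI for a vector-valued estimator.

    data:      NumPy array whose rows are observations (resampled with replacement).
    estimator: callable mapping a data array to a 1-D array of estimates.
    """
    rng = np.random.default_rng(seed)
    n = len(data)
    point = np.asarray(estimator(data), dtype=float)
    boot = np.empty((n_bootstrap, point.size))
    for b in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)   # draw n row indices with replacement
        boot[b] = estimator(data[idx])     # re-run the whole estimator on the resample
    lower = np.quantile(boot, alpha / 2, axis=0)
    upper = np.quantile(boot, 1 - alpha / 2, axis=0)
    return {'point_estimates': point, 'ci_lower': lower, 'ci_upper': upper}

# Toy usage: 500 draws of a 3-dimensional variable, estimator = column means.
rng = np.random.default_rng(42)
toy_data = rng.normal(loc=[0.0, 1.0, 2.0], scale=1.0, size=(500, 3))
res = percentile_bootstrap_ci(toy_data, lambda d: d.mean(axis=0))
print(res['ci_lower'], res['ci_upper'])
```

In the notebook setting, `estimator` would wrap both `estimate_nuisance_functions_enhanced` and `dlate_estimator`, so every resample refits the nuisance models; with 199 resamples per replication this refitting is the dominant cost for KAN. If `data` is a pandas DataFrame, the resample would be taken with `data.iloc[idx]` instead of `data[idx]`.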

The primary goal is to assess the **empirical coverage** of the bootstrap CIs. Good coverage (close to 95%) would indicate that the bootstrap provides a reliable method for uncertainty quantification for the KAN estimator.
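
How close counts as "close to 95%" depends on the number of replications. At a given y, each replication's interval either covers the true value or not, so the empirical coverage over $R$ replications is a binomial proportion with Monte Carlo standard error $\sqrt{p(1-p)/R}$. At the nominal $p = 0.95$ and $R = 30$ (the `N_REPLICATIONS_BOOTSTRAP` value set in the next cell) this is about 0.040, so pointwise coverage anywhere from roughly 0.87 to 1.00 is within two standard errors of correct calibration. Because the band is pointwise and there are 30 grid points, a few points falling outside it is expected even for a well-calibrated interval. A small helper (hypothetical name `coverage_mc_band`):

```python
import math

def coverage_mc_band(n_reps, nominal=0.95, z=2.0):
    """Monte Carlo standard error of an empirical coverage rate, plus a
    +/- z * SE band around the nominal level clipped to [0, 1]."""
    se = math.sqrt(nominal * (1 - nominal) / n_reps)
    return se, max(0.0, nominal - z * se), min(1.0, nominal + z * se)

for r in (30, 200, 1000):
    se, lo, hi = coverage_mc_band(r)
    print(f"R={r:5d}  SE={se:.3f}  band=[{lo:.3f}, {hi:.3f}]")
```

The standard error shrinks only with $\sqrt{R}$, which is why a final large-scale run needs many more replications to distinguish, say, 0.90 from 0.95 coverage.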

The detailed results of this analysis will be saved to `bootstrap_analysis_results.csv`.

In [None]:
# --- Configuration for Bootstrap Analysis ---
N_REPLICATIONS_BOOTSTRAP = 30  # A reasonable number for a limited analysis
N_BOOTSTRAP_SAMPLES = 199      # Conventional B: (B + 1) * 0.025 = 5 is an integer, so the 2.5%/97.5% percentile endpoints fall on order statistics

# --- Set Best Hyperparameters from Part 1 ---
if 'best_params' in locals():
    print(f"Using best parameters from tuning: \n{best_params}")
    kan_utils.KAN_STEPS = int(best_params['steps'])
    kan_utils.KAN_HIDDEN_DIM = int(best_params['hidden_dim'])
    kan_utils.KAN_REG_STRENGTH = float(best_params['reg_strength'])
else:
    print("Warning: Best parameters not found. Using default KAN settings.")

print("\nStarting Limited Bootstrap Analysis...")

# --- Run Simulation with Bootstrap Enabled ---
# This logic is adapted from run_enhanced_simulation (imported above)

y_grid = np.linspace(-8, 15, Y_GRID_SIZE)
results_storage = {'kan': [], 'rf': []}
true_values_list = []

for rep in tqdm(range(N_REPLICATIONS_BOOTSTRAP), desc="Bootstrap MC Replications"):
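    # Offset the seeds so these evaluation datasets differ from the Part 1 tuning datasets
    # (seeds 0 .. N_REPLICATIONS_FAST - 1); reusing them would make the tuned KAN's RMSE optimistic.
    data, true_dlate_func = generate_dlate_data(n_samples=N_SAMPLES, seed=10_000 + rep)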
    true_dlate = np.array([true_dlate_func(y) for y in y_grid])
    true_values_list.append(true_dlate)
    
    for model_type in ['kan', 'rf']:
        try:
            # The bootstrap_dlate_ci function handles the full estimation loop
            dlate_results = kan_utils.bootstrap_dlate_ci(
                data, y_grid,
                nuisance_estimator=lambda d, y: estimate_nuisance_functions_enhanced(d, y, model_type=model_type),
                dlate_estimator=dlate_estimator,
                n_bootstrap=N_BOOTSTRAP_SAMPLES
            )
            results_storage[model_type].append(dlate_results)
        except Exception as e:
            print(f"{model_type.upper()} estimation failed in replication {rep}: {e}")
            # Append NaNs if estimation fails
            nan_results = {
                'point_estimates': np.full(len(y_grid), np.nan),
                'ci_lower': np.full(len(y_grid), np.nan),
                'ci_upper': np.full(len(y_grid), np.nan)
            }
            results_storage[model_type].append(nan_results)

print("Bootstrap analysis complete.")

In [None]:
# --- Process and Save Bootstrap Results ---
true_values = np.array(true_values_list)

results_df = pd.DataFrame({'y_value': y_grid, 'true_dlate': np.mean(true_values, axis=0)})

for model_type, storage in results_storage.items():
    estimates = np.array([r['point_estimates'] for r in storage])
    ci_lower = np.array([r['ci_lower'] for r in storage])
    ci_upper = np.array([r['ci_upper'] for r in storage])
    
    # Performance metrics
    results_df[f'{model_type}_estimate'] = np.nanmean(estimates, axis=0)
    results_df[f'{model_type}_bias'] = np.nanmean(estimates - true_values, axis=0)
    results_df[f'{model_type}_rmse'] = np.sqrt(np.nanmean((estimates - true_values)**2, axis=0))
    
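    # Confidence interval coverage, excluding failed replications (NaN CIs).
    # NaN comparisons evaluate to False, so without the mask failures would count as misses.
    valid = ~(np.isnan(ci_lower) | np.isnan(ci_upper))
    covered = (true_values >= ci_lower) & (true_values <= ci_upper)
    results_df[f'{model_type}_coverage'] = np.nanmean(np.where(valid, covered, np.nan), axis=0)
    results_df[f'{model_type}_n_valid'] = valid.sum(axis=0)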
    results_df[f'{model_type}_ci_width'] = np.nanmean(ci_upper - ci_lower, axis=0)

bootstrap_results_path = os.path.join(output_dir, 'bootstrap_analysis_results.csv')
results_df.to_csv(bootstrap_results_path, index=False)

print(f"Bootstrap analysis results saved to {bootstrap_results_path}")
print("\n--- Results Summary ---")
print(results_df[['y_value', 'kan_rmse', 'rf_rmse', 'kan_coverage', 'rf_coverage']].head())

### Visualize Bootstrap Results

We can now plot the key metrics from the bootstrap analysis: RMSE and, most importantly, empirical coverage.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Bootstrap Analysis Results (Best KAN vs. RF)', fontsize=16)

# Plot 1: RMSE Comparison
ax = axes[0]
ax.plot(results_df['y_value'], results_df['kan_rmse'], 'r-', linewidth=2, label='KAN RMSE')
ax.plot(results_df['y_value'], results_df['rf_rmse'], 'b-', linewidth=2, label='RF RMSE')
ax.set_xlabel('y')
ax.set_ylabel('RMSE')
ax.set_title('Root Mean Squared Error (RMSE)')
ax.legend()
ax.grid(True, alpha=0.4)

# Plot 2: Coverage Comparison
ax = axes[1]
ax.plot(results_df['y_value'], results_df['kan_coverage'], 'r-', linewidth=2, label='KAN Coverage')
ax.plot(results_df['y_value'], results_df['rf_coverage'], 'b-', linewidth=2, label='RF Coverage')
ax.axhline(y=0.95, color='k', linestyle='--', alpha=0.7, label='Nominal 95% Level')
ax.set_xlabel('y')
ax.set_ylabel('Empirical Coverage')
ax.set_title('95% Bootstrap Confidence Interval Coverage')
ax.set_ylim(0, 1.05)
ax.legend()
ax.grid(True, alpha=0.4)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# Print average coverage
avg_kan_coverage = results_df['kan_coverage'].mean()
avg_rf_coverage = results_df['rf_coverage'].mean()
print(f"Average KAN Coverage: {avg_kan_coverage:.3f}")
print(f"Average RF Coverage:  {avg_rf_coverage:.3f}")

## 4. Conclusion and Next Steps

Based on the results from the hyperparameter tuning and the bootstrap analysis, we can now make an informed decision on how to proceed with the paper.

**Scenario A: KAN is Salvageable**
- If the tuned KAN estimator shows RMSE that is competitive with or better than the Random Forest.
- AND if the bootstrap confidence intervals for the KAN estimator show good empirical coverage (close to 0.95).
- **Recommendation**: Proceed with one final, large-scale computation using the optimal KAN hyperparameters and bootstrap CIs. The paper's narrative will be about the power of KANs when carefully tuned and paired with robust inference methods.

**Scenario B: KAN is Not Superior in this Context**
- If the KAN estimator's RMSE remains significantly worse than the Random Forest, even after tuning.
- OR if the bootstrap CIs for KAN have poor coverage, indicating unreliable uncertainty estimates.
- **Recommendation**: Pivot the paper's narrative to a "caveat emptor" story. The key finding becomes a cautionary tale about the challenges of applying new, complex models. This requires no further large-scale computation.