# SpatialNeuralAdapter Demo

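This notebook demonstrates the SpatialNeuralAdapter end to end: it generates synthetic spatio-temporal data, splits and scales it, trains a trend model together with a learned spatial basis, and evaluates the predictions on both the standardized and the original scale.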

In [None]:
# Import required libraries
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Dict, Any
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
import time

from geospatial_neural_adapter import (
    SpatialNeuralAdapter,
    SpatialBasisLearner,
    TrendModel,
    compute_metrics,
)

from geospatial_neural_adapter.data.generators import generate_time_synthetic_data
from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling, denormalize_predictions

# Reproducibility: apply the same seed to NumPy's global RNG and to PyTorch.
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(42)

# Plot appearance: reset matplotlib to its "default" style sheet, then make
# seaborn's "husl" palette the active colour palette.
plt.style.use("default")
sns.set_palette("husl")

print("✅ All imports successful!")

## 1. Data Generation with Meaningful Correlations

We'll use the improved data generator that creates features with strong correlations to targets.

In [None]:
# Build the synthetic dataset on an evenly spaced 1-D grid of locations.
print("Generating correlated temporal synthetic data...")

# Grid size, series length and noise settings. These stay module-level because
# later cells reuse them (e.g. `locations` is the x-axis of every spatial plot).
n_locations, n_time_steps = 50, 200
noise_std, eigenvalue = 0.1, 2.0
locations = np.linspace(-5, 5, n_locations)

# NOTE(review): the exact meaning of each keyword below is defined by
# generate_time_synthetic_data, which is not part of this notebook.
generator_kwargs = dict(
    locs=locations,
    n_time_steps=n_time_steps,
    noise_std=noise_std,
    eigenvalue=eigenvalue,
    eta_rho=0.8,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.1,
    non_linear_strength=0.2,
    seed=42,
)
cat_features, cont_features, targets = generate_time_synthetic_data(**generator_kwargs)

# Quick sanity report: array shapes plus the location and spread of the raw targets.
print(f"Data shapes: {cont_features.shape}, {targets.shape}")
print(
    f"Original targets - Mean: {targets.mean():.2f}, Std: {targets.std():.2f}\n"
    f"Original targets - Range: {targets.min():.2f} to {targets.max():.2f}"
)

In [None]:
# Explore the raw data: per-feature correlation with the target, followed by a
# 2x2 overview figure (distribution, spatial slice, temporal slice, correlations).

# Pearson correlation between the flattened targets and each flattened
# continuous feature (one feature per index of the last axis). Computed once
# and reused for both the printed table and the bar chart.
targets_flat = targets.flatten()
feature_corrs = [
    np.corrcoef(targets_flat, cont_features[:, :, i].flatten())[0, 1]
    for i in range(cont_features.shape[-1])
]

print("Feature-Target Correlations:")
for i, corr in enumerate(feature_corrs):
    print(f"  Feature {i}: {corr:.4f}")

# Visualize data characteristics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Target distribution
axes[0, 0].hist(targets_flat, bins=30, alpha=0.7, edgecolor='black')
axes[0, 0].set_title('Target Distribution')
axes[0, 0].set_xlabel('Target Value')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Spatial pattern at first time step (first axis of `targets` is time)
axes[0, 1].plot(locations, targets[0, :], 'o-', linewidth=2, markersize=4)
axes[0, 1].set_title('Spatial Pattern at t=0')
axes[0, 1].set_xlabel('Location')
axes[0, 1].set_ylabel('Target Value')
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Temporal pattern at the middle location. The index is derived from
# the data so the plot and its title stay correct if the grid size changes.
mid_loc = targets.shape[1] // 2
time_steps = np.arange(len(targets))
axes[1, 0].plot(time_steps, targets[:, mid_loc], linewidth=2)
axes[1, 0].set_title(f'Temporal Pattern at Location {mid_loc}')
axes[1, 0].set_xlabel('Time Step')
axes[1, 0].set_ylabel('Target Value')
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Feature correlations (same values as printed above)
axes[1, 1].bar(range(len(feature_corrs)), feature_corrs, alpha=0.7, edgecolor='black')
axes[1, 1].set_title('Feature-Target Correlations')
axes[1, 1].set_xlabel('Feature Index')
axes[1, 1].set_ylabel('Correlation')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 2. Enhanced Preprocessing with Automatic Scaling

We'll use the improved preprocessing pipeline that handles normalization and denormalization automatically.

In [None]:
# Split the arrays into train / validation / test datasets and scale them.
print("Preparing datasets with automatic scaling...")

# Split ratios and scaler choices passed to the preprocessing helper; the
# scalers are requested to be fit on the training portion only.
split_and_scale_options = dict(
    train_ratio=0.7,
    val_ratio=0.15,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)
train_dataset, val_dataset, test_dataset, preprocessor = prepare_all_with_scaling(
    cat_features=cat_features,
    cont_features=cont_features,
    targets=targets,
    **split_and_scale_options,
)

# Each dataset exposes three tensors, unpacked here as
# (categorical features, continuous features, targets).
(train_cat, train_cont, train_targets) = train_dataset.tensors
(val_cat, val_cont, val_targets) = val_dataset.tensors
(test_cat, test_cont, test_targets) = test_dataset.tensors

print(f"Dataset sizes: {len(train_dataset)}, {len(val_dataset)}, {len(test_dataset)}")

# Report the fitted target scaler (first entry of its mean / scale).
scaler_info = preprocessor.get_scaler_info()
target_mean = scaler_info['target_mean'][0]
target_scale = scaler_info['target_scale'][0]
print(f"Target scaler - Mean: {target_mean:.2f}, Std: {target_scale:.2f}")

# Shuffled mini-batches over the training dataset.
# NOTE(review): batch_size is 32 here, while the training config defined later
# in the notebook sets "batch_size": 128 — confirm which one the trainer uses.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Dimensions of the raw continuous-feature array; N and F size the models below.
T, N, F = cont_features.shape
print(f"Original data shape: {cont_features.shape}")
for split_name, split_dataset in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_name} time steps: {len(split_dataset)}")
print(f"Number of locations: {N}")
print(f"Number of features: {F}")

In [None]:
# Side-by-side view of the targets before and after standardization.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
dist_ax, spatial_ax = axes

# Left panel: density histograms of all raw targets and of the scaled
# training targets.
hist_style = dict(bins=30, alpha=0.7, density=True)
dist_ax.hist(targets.flatten(), label='Original', **hist_style)
dist_ax.hist(train_targets.numpy().flatten(), label='Standardized', **hist_style)
dist_ax.set(title='Target Distribution Comparison', xlabel='Value', ylabel='Density')
dist_ax.legend()
dist_ax.grid(True, alpha=0.3)

# Right panel: the first validation row, once mapped back through the
# preprocessor ("Original") and once as stored in the dataset ("Standardized").
val_targets_orig = denormalize_predictions(val_targets.numpy(), preprocessor)
line_style = dict(alpha=0.7, linewidth=2)
spatial_ax.plot(locations, val_targets_orig[0], 'o-', label='Original', **line_style)
spatial_ax.plot(locations, val_targets[0].numpy(), 's-', label='Standardized', **line_style)
spatial_ax.set(title='Spatial Pattern Comparison', xlabel='Location', ylabel='Value')
spatial_ax.legend()
spatial_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Model Creation

In [None]:
# Instantiate the two sub-models that the trainer will optimise.
print("Creating models...")

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
# The models are not moved to the device in this cell; `device` is handed to
# the trainer later.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Trend model, sized by the feature count F and location count N of the data.
# It is built before the basis learner so that the seeded weight
# initialisation order is stable.
trend = TrendModel(
    num_continuous_features=F,
    hidden_layer_sizes=[256, 128, 64],
    n_locations=N,
    init_weight=None,
    init_bias=None,
    freeze_init=False,
    dropout_rate=0.1,
)

# Spatial basis learner with a 10-dimensional latent space and no PCA init.
basis = SpatialBasisLearner(
    num_locations=N,
    latent_dim=10,
    pca_init=None,
)

# Total number of parameter elements in each model.
for model_label, model in (("Trend", trend), ("Basis", basis)):
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{model_label} model parameters: {n_params:,}")

# Model summary
print("\nTrend Model Architecture:")
print(trend)

## 4. Training Configuration

In [None]:
# Hyper-parameters handed to the trainer through its `config` argument.
# NOTE(review): the meaning of each key is defined by the trainer
# implementation, which is not shown in this notebook; the blank-line grouping
# below follows the key names only.
config = dict(
    rho=5.0,
    dual_momentum=0.2,
    max_iters=100,  # moderate training length
    min_outer=50,
    lr_mu=1e-3,
    batch_size=128,
    phi_every=5,
    phi_freeze=50,
    tol=1e-4,

    adaptive_rho_mu=10.0,
    adaptive_rho_tau_inc=2.0,
    adaptive_rho_tau_dec=2.0,

    matrix_reg=1e-6,

    irl1_max_iters=10,
    irl1_eps=1e-6,
    irl1_tol=5e-4,

    coord_threshold=1e-12,
    avoid_zero_eps=1e-12,

    pretrain_epochs=5,
    use_mixed_precision=True,
)

# Echo every setting, one per line, in insertion order.
print("Training Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in config.items()))

## 5. SpatialNeuralAdapter Training

In [None]:
# Wire the models, data and configuration into a SpatialNeuralAdapter trainer.
print("Creating SpatialNeuralAdapter...")

# TensorBoard writer for this run; it is closed by the training cell.
log_dir = "./logs/spatial_neural_adapter_demo"
writer = SummaryWriter(log_dir)

# The trainer receives the training loader, the validation tensors, the
# location grid, the config dict, the target device and the writer.
trainer_kwargs = dict(
    trend=trend,
    basis=basis,
    train_loader=train_loader,
    val_cont=val_cont,
    val_y=val_targets,
    locs=locations,
    config=config,
    device=device,
    writer=writer,
    tau1=0.1,
    tau2=0.1,
)
trainer = SpatialNeuralAdapter(**trainer_kwargs)

# Show the settings the trainer reports for itself.
trainer.print_config()

In [None]:
# Full training pipeline: trend pre-training, basis initialisation, then the
# ADMM loop. The TensorBoard writer is closed in `finally` so that it is not
# left open when any training step raises.
try:
    # Pretrain trend model (epoch count comes from the config, default 5)
    print("Pretraining trend model...")
    trainer.pretrain_trend(epochs=config.get("pretrain_epochs", 5))

    # Initialize basis
    print("Initializing spatial basis...")
    trainer.init_basis_dense()

    # Run ADMM training. perf_counter is monotonic, so the measured duration
    # is not distorted by system clock adjustments during a long run.
    print("Starting ADMM training...")
    start_time = time.perf_counter()

    best_val = trainer.run()

    training_time = time.perf_counter() - start_time
    print(f"Training completed in {training_time:.2f}s")
    print(f"Best validation RMSE: {best_val:.6f}")
finally:
    # Close tensorboard writer, on success and on failure alike
    writer.close()

## 6. Model Evaluation

In [None]:
# Score the trained model on the validation split, first on the standardized
# scale and then after mapping predictions and targets back through the
# preprocessor.
# NOTE(review): these "final" metrics use the same validation tensors that were
# passed to the trainer, while `test_dataset` is never evaluated in this
# notebook — confirm whether a held-out test evaluation is intended.
print("Evaluating model...")

# Put both sub-models into evaluation mode.
for sub_model in (trainer.trend, trainer.basis):
    sub_model.eval()

# Move the validation tensors to the device once and reuse them.
val_cont_dev = val_cont.to(device)
val_targets_dev = val_targets.to(device)

with torch.no_grad():
    # Predictions on the standardized scale.
    # NOTE(review): predict() is also given the targets — verify against its
    # documentation how they are used.
    y_pred_std = trainer.predict(val_cont_dev, val_targets_dev)

    # Standardized-scale metrics from the library helper, plus plain MSE.
    rmse_std, mae_std, r2_std = compute_metrics(val_targets_dev, y_pred_std)
    mse_std = torch.nn.functional.mse_loss(val_targets_dev, y_pred_std).item()

# Map predictions and targets back through the preprocessor.
y_pred_denorm = denormalize_predictions(y_pred_std.cpu().numpy(), preprocessor)
val_targets_denorm = denormalize_predictions(val_targets.cpu().numpy(), preprocessor)

# Error-based metrics on the original scale.
denorm_errors = val_targets_denorm - y_pred_denorm
denorm_sq_errors = denorm_errors ** 2

mse_denorm = np.mean(denorm_sq_errors)
rmse_denorm = np.sqrt(mse_denorm)
mae_denorm = np.mean(np.abs(denorm_errors))

# R-squared on the original scale: 1 - residual SS / total SS.
ss_res_denorm = np.sum(denorm_sq_errors)
ss_tot_denorm = np.sum((val_targets_denorm - val_targets_denorm.mean()) ** 2)
r2_denorm = 1 - (ss_res_denorm / ss_tot_denorm)

# Collected for the summary cell.
metrics = dict(
    rmse_std=rmse_std,
    mae_std=mae_std,
    r2_std=r2_std,
    mse_std=mse_std,
    rmse_denorm=rmse_denorm,
    mae_denorm=mae_denorm,
    r2_denorm=r2_denorm,
    mse_denorm=mse_denorm,
)

print(f"Standardized metrics: RMSE={rmse_std:.6f}, MAE={mae_std:.6f}, R²={r2_std:.6f}")
print(f"Denormalized metrics: RMSE={rmse_denorm:.6f}, MAE={mae_denorm:.6f}, R²={r2_denorm:.6f}")

## 7. Results Visualization

In [None]:
# Four-panel diagnostic figure built from the denormalized validation results.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(scatter_ax, spatial_ax), (resid_ax, resid_hist_ax) = axes

# Aliases for the denormalized arrays; the summary cell reuses these names.
val_y_np = val_targets_denorm
y_pred_np = y_pred_denorm

actual_flat = val_y_np.flatten()
pred_flat = y_pred_np.flatten()

# Panel 1: predicted vs. actual, with the identity line as reference.
scatter_ax.scatter(actual_flat, pred_flat, alpha=0.5, s=20)
identity_span = [val_y_np.min(), val_y_np.max()]
scatter_ax.plot(identity_span, identity_span, 'r--', linewidth=2, label='Perfect Prediction')
scatter_ax.set(
    title='Predictions vs Actual Values',
    xlabel='Actual Values',
    ylabel='Predicted Values',
)
scatter_ax.legend()
scatter_ax.grid(True, alpha=0.3)

# Panel 2: actual and predicted values across locations for one validation row.
sample_idx = 0
curve_style = dict(alpha=0.7, linewidth=2, markersize=4)
spatial_ax.plot(locations, val_y_np[sample_idx], 'o-', label='Actual', **curve_style)
spatial_ax.plot(locations, y_pred_np[sample_idx], 's-', label='Predicted', **curve_style)
spatial_ax.set(
    title='Spatial Pattern Comparison',
    xlabel='Location',
    ylabel='Target Value',
)
spatial_ax.legend()
spatial_ax.grid(True, alpha=0.3)

# Panel 3: residuals (actual minus predicted) against the predictions.
residuals = actual_flat - pred_flat
resid_ax.scatter(pred_flat, residuals, alpha=0.5, s=20)
resid_ax.axhline(y=0, color='r', linestyle='--', alpha=0.7)
resid_ax.set(
    title='Residuals vs Predicted Values',
    xlabel='Predicted Values',
    ylabel='Residuals',
)
resid_ax.grid(True, alpha=0.3)

# Panel 4: histogram of the residuals.
resid_hist_ax.hist(residuals, bins=30, alpha=0.7, edgecolor='black')
resid_hist_ax.set(
    title='Residuals Distribution',
    xlabel='Residual Value',
    ylabel='Frequency',
)
resid_hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Spatial Basis Analysis

In [None]:
# Inspect the spatial basis held by the trainer after training.
print("Analyzing learned spatial basis...")

# Copy the basis tensor to a NumPy array on the CPU, detached from autograd.
basis_matrix = trainer.basis.basis.detach().cpu().numpy()
print(f"Basis matrix shape: {basis_matrix.shape}")
print(f"Basis norm: {np.linalg.norm(basis_matrix):.4f}")

# Visualize the basis: one panel per column, up to the size of the 2x3 grid.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Tie the number of plotted columns to the grid instead of a separate literal.
n_shown = min(len(axes), basis_matrix.shape[1])
for i in range(n_shown):
    # Plot basis vector (column i) as a spatial pattern over the locations
    axes[i].plot(locations, basis_matrix[:, i], 'o-', linewidth=2, markersize=4)
    axes[i].set_title(f'Spatial Basis {i+1}')
    axes[i].set_xlabel('Location')
    axes[i].set_ylabel('Basis Value')
    axes[i].grid(True, alpha=0.3)

# When the basis has fewer columns than panels, hide the leftover panels so
# the figure does not show empty frames.
for unused_ax in axes[n_shown:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

# Basis statistics over all entries of the matrix
print("\nBasis Statistics:")
print(f"Mean: {basis_matrix.mean():.4f}")
print(f"Std: {basis_matrix.std():.4f}")
print(f"Min: {basis_matrix.min():.4f}")
print(f"Max: {basis_matrix.max():.4f}")

## 9. Performance Summary

In [None]:
# Final report: run settings, every collected metric, and two spread ratios.
banner = "=" * 50
print(banner)
print("SPATIAL NEURAL ADAPTER DEMO SUMMARY")
print(banner)
print(f"Device: {device}")
print(f"Training time: {training_time:.2f}s")

# Metrics are stored under "<name>_<scale>" keys; print the standardized
# group first, then the denormalized group.
metric_labels = (("RMSE", "rmse"), ("MAE", "mae"), ("R²", "r2"), ("MSE", "mse"))
for scale in ("std", "denorm"):
    for label, prefix in metric_labels:
        metric_value = metrics[prefix + "_" + scale]
        print(f"Final {label} ({scale}): {metric_value:.6f}")

print(f"Best validation RMSE: {best_val:.6f}")
print("Tensorboard logs: ./logs/spatial_neural_adapter_demo")
print(banner)

# How much of the actual value range and standard deviation the predictions
# reproduce, both on the denormalized scale.
print("\nPerformance Analysis:")
pred_range = y_pred_np.max() - y_pred_np.min()
actual_range = val_y_np.max() - val_y_np.min()
print(f"Scale recovery: {pred_range:.2f} / {actual_range:.2f} = {pred_range / actual_range:.2%}")

pred_std = y_pred_np.std()
actual_std = val_y_np.std()
print(f"Std recovery: {pred_std:.2f} / {actual_std:.2f} = {pred_std / actual_std:.2%}")

print("\n✅ SpatialNeuralAdapter demo completed successfully!")
print("💡 Run 'tensorboard --logdir ./logs/spatial_neural_adapter_demo' to view training progress")

## 10. Additional Analysis: Feature Importance

In [None]:
# Rank the continuous features by how strongly they correlate with the target.
# "Importance" here is the absolute Pearson correlation computed on the raw
# data; it is not derived from the trained model.
print("Feature Importance Analysis:")

# One correlation per feature (last axis of `cont_features`).
targets_flat = targets.flatten()
feature_correlations = [
    np.corrcoef(targets_flat, cont_features[:, :, feat].flatten())[0, 1]
    for feat in range(cont_features.shape[-1])
]

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
corr_ax, rank_ax = axes
bar_style = dict(alpha=0.7, edgecolor='black')

# Left panel: signed correlations in feature order.
corr_ax.bar(range(len(feature_correlations)), feature_correlations, **bar_style)
corr_ax.set(
    title='Feature-Target Correlations',
    xlabel='Feature Index',
    ylabel='Correlation',
)
corr_ax.grid(True, alpha=0.3)

# Right panel: absolute correlations sorted from largest to smallest.
feature_importance = np.abs(feature_correlations)
sorted_indices = np.flip(np.argsort(feature_importance))

rank_ax.bar(range(len(sorted_indices)), feature_importance[sorted_indices], **bar_style)
rank_ax.set(
    title='Feature Importance Ranking',
    xlabel='Feature Rank',
    ylabel='Absolute Correlation',
)
rank_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# List the top-ranked features (at most five) with their signed correlation.
print("\nTop 5 Most Important Features:")
for rank, idx in enumerate(sorted_indices[:5], start=1):
    print(f"  {rank}. Feature {idx}: {feature_correlations[idx]:.4f}")