# Appa: Autoencoder Reconstruction Error Analysis

This notebook analyzes reconstruction errors from a trained Appa autoencoder by processing ERA5 data through the model and computing various error metrics.

Based on the reconstruction.py script, this notebook provides an interactive way to:
- Load ERA5 data and process it through the autoencoder
- Compute reconstruction errors (MSE, signed errors)
- Generate error histograms and statistics
- Visualize error patterns across variables and spatial locations
- Analyze error characteristics for different atmospheric conditions

## Setup and Imports


In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
from omegaconf import OmegaConf
from einops import rearrange
import warnings
warnings.filterwarnings('ignore')

# Add the appa module to the path
sys.path.append('/home/azureuser/cloudfiles/code/Users/randy.chase/appa_tio/')

import appa
from appa.save import load_auto_encoder
from appa.data.datasets import ERA5Dataset
from appa.data.dataloaders import get_dataloader
from appa.data.transforms import StandardizeTransform
from appa.data.const import (
    CONTEXT_VARIABLES,
    ERA5_ATMOSPHERIC_VARIABLES,
    ERA5_PRESSURE_LEVELS,
    ERA5_RESOLUTION,
    ERA5_SURFACE_VARIABLES,
    ERA5_VARIABLES,
    SUB_PRESSURE_LEVELS,
)
from appa.config import PATH_ERA5, PATH_MASK, PATH_STAT
from appa.config.hydra import compose

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10


## Data Setup

**IMPORTANT**: Before running this analysis, you need to download the ERA5 data. The Appa project uses the WeatherBench2 data format.

### Option 1: Download ERA5 Data Subset (Space-Efficient)

For reconstruction analysis, you only need a small subset of the data. Here are space-efficient options:

#### **Option 1A: Small Date Range (Recommended for Testing)**
```bash
# Download just 1 month of data (much smaller!)
python scripts/data/download_era5.py \
    output_path=/path/to/your/data/era5_2019-01-01_2019-01-31.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-31

# Compute statistics for this subset
python scripts/data/data_stats.py \
    data_path=/path/to/your/data/era5_2019-01-01_2019-01-31.zarr \
    output_path=/path/to/your/data/stats_era5_2019-01-01_2019-01-31.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-31
```

#### **Option 1B: Even Smaller - Just Surface Variables**
```bash
# Download only surface variables (6 variables instead of 71)
python scripts/data/download_era5.py \
    output_path=/path/to/your/data/era5_surface_only.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-31 \
    variables="2m_temperature,10m_u_component_of_wind,10m_v_component_of_wind,mean_sea_level_pressure,total_precipitation,sea_surface_temperature"
```

#### **Option 1C: Use Lower Resolution (Much Smaller)**
```bash
# Download at lower resolution (360x181 instead of 1440x721)
python scripts/data/download_era5.py \
    output_path=/path/to/your/data/era5_low_res.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-31 \
    resolution=360x181
```

#### **Option 1D: Full Dataset (If You Have Space)**
```bash
# Download full dataset (the largest option; see the size estimates below)
python scripts/data/download_era5.py output_path=/path/to/your/data/era5_1993-2021-1h-1440x721.zarr

# Compute statistics
python scripts/data/data_stats.py data_path=/path/to/your/data/era5_1993-2021-1h-1440x721.zarr output_path=/path/to/your/data/stats_era5_1993-2021-1h-1440x721.zarr
```

### Option 2: Use Existing Data

If you already have ERA5 data in the correct format, just update the paths in the configuration below.

## Configuration

Configure the reconstruction analysis parameters. Update the paths to match your model and data directories.


### 📊 **File Size Estimates**

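Approximate space used by each option. These are rough estimates: the actual on-disk size depends on temporal resolution, chunking, and compression.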

| Option | Date Range | Variables | Resolution | Estimated Size | Autoencoder Compatible |
|--------|------------|-----------|------------|----------------|----------------------|
| **Quick Start** | 1 week | All 71 | 1440×721 | ~100-200 MB | ✅ Yes |
| **1A: Small Date Range** | 1 month | All 71 | 1440×721 | ~500 MB | ✅ Yes |
| **1B: Surface Only** | 1 month | 6 surface | 1440×721 | ~50 MB | ❌ No - Missing atmospheric |
| **1C: Low Resolution** | 1 month | All 71 | 360×181 | ~30 MB | ✅ Yes (but lower quality) |
| **1D: Full Dataset** | 1993-2021 | All 71 | 1440×721 | ~50 GB | ✅ Yes |

**Recommendation**: Start with **Quick Start** (1 week, all variables) for testing, then scale up if needed.
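
As a sanity check on the estimates above, the uncompressed size of a subset can be bounded with simple arithmetic. The sketch below assumes float32 values and hourly time steps (both are assumptions, not read from the download script), and `uncompressed_bytes` is a hypothetical helper. On-disk Zarr stores are usually smaller because of compression.

```python
def uncompressed_bytes(n_times, n_vars, n_lat, n_lon, bytes_per_value=4):
    """Upper bound on array size, ignoring compression and metadata."""
    return n_times * n_vars * n_lat * n_lon * bytes_per_value

# One week of hourly data (assumed), 71 channels, 1440x721 grid
week_full = uncompressed_bytes(7 * 24, 71, 721, 1440)
# Same week on the 360x181 grid
week_low = uncompressed_bytes(7 * 24, 71, 181, 360)

print(f"1 week, 1440x721: {week_full / 1e9:.1f} GB uncompressed")
print(f"1 week, 360x181:  {week_low / 1e9:.1f} GB uncompressed")
```

If the bound is far above the expected download size, it may be worth checking the temporal resolution and compression settings before starting a download.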


### 🚀 **Quick Start: Download Small Subset**

Download one week of data on 13 pressure levels, then compute statistics for that subset:

```bash
# Navigate to the project directory first
cd /path/to/appa_tio

# Create a data directory
mkdir -p ~/appa_data

# Use the 13-level dataset instead of 37-level (much smaller!)
python scripts/data/download_era5.py \
    scripts/data/configs/download_era5.yaml \
    output_path=~/appa_data/era5_1week_13level.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-07 \
    use_coarsened_levels=true

# Compute statistics for this subset
python scripts/data/data_stats.py \
    scripts/data/configs/data_stats.yaml \
    data_path=~/appa_data/era5_1week_13level.zarr \
    output_path=~/appa_data/stats_1week_13level.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-07
```

**This should be ~50-100 MB.**

### 🔧 **Alternative: Use Lower Resolution (Even Smaller)**

If you want the absolute smallest download:

```bash
# Use lower resolution (360x181 instead of 1440x721)
python scripts/data/download_era5.py \
    scripts/data/configs/download_era5.yaml \
    output_path=~/appa_data/era5_1week_lowres.zarr \
    start_date=2019-01-01 \
    end_date=2019-01-07 \
    resolution=360x181 \
    use_coarsened_levels=true
```

**This should be ~10-20 MB** - perfect for testing!


In [None]:
# Configuration for reconstruction analysis
# NOTE: You need to download ERA5 data first! See the data setup section above.

# Update these paths to match your setup:
config = {
    'ae_model_path': '/home/azureuser/cloudfiles/code/Users/randy.chase/appa_models/autoencoders/workshop/0/',  # Update this path
    'data_path': '/path/to/your/era5_data.zarr',  # Path to downloaded ERA5 data
    'stats_path': '/home/azureuser/cloudfiles/code/Users/randy.chase/appa_data/data/stats_era5_1993-2021-1h-1440x721.zarr',  # Path to ERA5 statistics
    'mask_path': '/home/azureuser/cloudfiles/code/Users/randy.chase/appa_data/data/masks_era5_1993-2021-1h-1440x721.zarr',   # Path to land/sea mask
    'checkpoint': 'best',  # Options: 'best', 'last'
    'start_date': '2019-01-01',
    'end_date': '2019-01-31',  # Start with a small date range for testing
    'batch_size': 4,
    'num_bins': 100,  # Number of bins for error histograms
    'sub_pressure_levels': True,  # Use sub pressure levels if available
}

# Error analysis parameters
error_config = {
    'bin_ranges': {
        '2m_temperature': [-10, 10],
        '10m_u_component_of_wind': [-10, 10],
        '10m_v_component_of_wind': [-10, 10],
        'mean_sea_level_pressure': [-7, 7],
        'total_precipitation': [-4, 4],
        'sea_surface_temperature': [-20, 20],
        'temperature': [-10, 10],
        'u_component_of_wind': [-10, 10],
        'v_component_of_wind': [-10, 10],
        'geopotential': [-600, 600],
        'specific_humidity': [-5, 5],
    },
    'multipliers': {
        'specific_humidity': 1000,  # kg/kg to g/kg
        'total_precipitation': 1000,  # m to mm
        'mean_sea_level_pressure': 0.01,  # Pa to hPa
    }
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

# Check if model path exists
ae_path = Path(config['ae_model_path'])
if ae_path.exists():
    print(f"\n✓ Autoencoder path exists: {ae_path}")
    print("Contents:")
    for item in ae_path.iterdir():
        print(f"  - {item.name}")
else:
    print(f"\n✗ Autoencoder path does not exist: {ae_path}")
    print("Please update the 'ae_model_path' in the config above")


In [None]:
# Check if data paths exist
print("Checking data availability...")
data_checks = {
    'ERA5 Data': config['data_path'],
    'Statistics': config['stats_path'], 
    'Masks': config['mask_path'],
    'Autoencoder': config['ae_model_path']
}

all_data_available = True
for name, path in data_checks.items():
    if Path(path).exists():
        print(f"✓ {name}: {path}")
    else:
        print(f"✗ {name}: {path} (NOT FOUND)")
        all_data_available = False

if not all_data_available:
    print("\n⚠️  Some required data is missing!")
    print("Please download the ERA5 data using the instructions above, or update the paths in the config.")
    print("You can also use the default paths from the project config:")
    print(f"  data_path: {PATH_ERA5}")
    print(f"  stats_path: {PATH_STAT}")
    print(f"  mask_path: {PATH_MASK}")
else:
    print("\n✓ All required data is available!")


## Load Autoencoder Model

Load the trained autoencoder model for reconstruction analysis.


In [None]:
# Load autoencoder model
print("Loading autoencoder...")
ae_model = load_auto_encoder(
    path=Path(config['ae_model_path']),
    model_name="model_best" if config['checkpoint'] == 'best' else "model_last",
    device=device,
    eval_mode=True
)
print(f"Autoencoder loaded successfully")

# Get model information
print(f"Model latent shape: {ae_model.latent_shape}")
print(f"Model device: {next(ae_model.parameters()).device}")

# Load autoencoder config to get pressure levels
ae_config_path = Path(config['ae_model_path']) / "config.yaml"
if ae_config_path.exists():
    ae_config = compose(ae_config_path)
    print(f"Autoencoder config loaded")
    
    # Determine pressure levels
    if hasattr(ae_config.train, 'sub_pressure_levels') and ae_config.train.sub_pressure_levels:
        atm_levels = SUB_PRESSURE_LEVELS
        print(f"Using sub pressure levels: {len(atm_levels)} levels")
    else:
        atm_levels = ERA5_PRESSURE_LEVELS
        print(f"Using full pressure levels: {len(atm_levels)} levels")
else:
    print("Warning: Could not load autoencoder config, using default pressure levels")
    atm_levels = SUB_PRESSURE_LEVELS if config['sub_pressure_levels'] else ERA5_PRESSURE_LEVELS

print(f"Pressure levels: {atm_levels}")


## Setup Data Loading

Set up the ERA5 dataset and data loader for reconstruction analysis.
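
The standardization step can be illustrated with a generic per-channel z-score. This is a minimal sketch on synthetic data, not the actual `StandardizeTransform` implementation (whose internals are not shown in this notebook); the helper names `standardize` and `unstandardize` and the channel statistics are assumptions for illustration.

```python
import numpy as np

def standardize(x, mean, std):
    """Map physical values to zero-mean, unit-variance channels."""
    return (x - mean) / std

def unstandardize(z, mean, std):
    """Invert standardize()."""
    return z * std + mean

rng = np.random.default_rng(0)
# Synthetic batch [B, Z, Lat, Lon] with two channels on very different scales
x = np.stack([
    rng.normal(280.0, 15.0, size=(8, 4, 6)),      # temperature-like (K)
    rng.normal(101325.0, 900.0, size=(8, 4, 6)),  # pressure-like (Pa)
], axis=1)

# Per-channel statistics, shaped to broadcast over [B, Z, Lat, Lon]
mean = x.mean(axis=(0, 2, 3), keepdims=True)
std = x.std(axis=(0, 2, 3), keepdims=True)

z = standardize(x, mean, std)
x_back = unstandardize(z, mean, std)
roundtrip_ok = np.allclose(x, x_back)
print(z.mean(axis=(0, 2, 3)), z.std(axis=(0, 2, 3)), roundtrip_ok)
```

Working in standardized units makes errors comparable across variables; converting back is only needed when errors should be reported in physical units.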


In [None]:
# Create standardization transform
print("Setting up data standardization...")
st = StandardizeTransform(
    config['stats_path'],
    state_variables=ERA5_VARIABLES,
    context_variables=CONTEXT_VARIABLES,
    levels=atm_levels,
)
print(f"Standardization transform created")
print(f"State mean shape: {st.state_mean.shape}")
print(f"State std shape: {st.state_std.shape}")

# Create dataset
print(f"\nCreating ERA5 dataset...")
print(f"Date range: {config['start_date']} to {config['end_date']}")
print(f"Data path: {config['data_path']}")

dataset = ERA5Dataset(
    path=config['data_path'],
    start_date=config['start_date'],
    end_date=config['end_date'],
    num_samples=None,  # Use all samples in date range
    transform=st,
    trajectory_size=1,  # Single timestep for reconstruction
    state_variables=ERA5_VARIABLES,
    context_variables=CONTEXT_VARIABLES,
    levels=atm_levels,
)

print(f"Dataset created with {len(dataset)} samples")

# Create data loader
dataloader = get_dataloader(
    dataset,
    batch_size=config['batch_size'],
    num_workers=4,
    prefetch_factor=2,
    shuffle=False,  # Don't shuffle for consistent analysis
    drop_last=False,
)

print(f"Data loader created with {len(dataloader)} batches")
print(f"Batch size: {config['batch_size']}")


## Compute Reconstruction Errors

Process the ERA5 data through the autoencoder and compute various error metrics.
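
Two pieces of this step are easy to get subtly wrong: the reshape between gridded and flattened layouts, and the spatial mean for a masked variable such as sea surface temperature. The sketch below uses NumPy on synthetic data; the helper names are hypothetical, and the reshape is meant to mirror the pattern `B T Z Lat Lon -> (B T) (Lat Lon) Z`. The masked mean divides by the number of valid points rather than by the full grid size.

```python
import numpy as np

def flatten_grid(x):
    """[B, T, Z, Lat, Lon] -> [(B T), (Lat Lon), Z]."""
    B, T, Z, Lat, Lon = x.shape
    return x.transpose(0, 1, 3, 4, 2).reshape(B * T, Lat * Lon, Z)

def unflatten_grid(x_flat, B, Lon):
    """[(B T), (Lat Lon), Z] -> [B, T, Z, Lat, Lon]."""
    BT, N, Z = x_flat.shape
    return x_flat.reshape(B, BT // B, N // Lon, Lon, Z).transpose(0, 1, 4, 2, 3)

def masked_spatial_mse(truth, pred, mask):
    """Mean squared error over grid points where mask is True."""
    sq_err = (truth - pred) ** 2               # [..., Lat, Lon]
    total = (sq_err * mask).sum(axis=(-1, -2))
    return total / mask.sum()

rng = np.random.default_rng(0)
state = rng.normal(size=(2, 1, 3, 4, 8))       # [B, T, Z, Lat, Lon]
restored = unflatten_grid(flatten_grid(state), B=2, Lon=8)

pred = state + 0.5                             # constant error of 0.5
mask = np.zeros((4, 8), dtype=bool)
mask[:, :2] = True                             # 8 of 32 points are "sea"
mse_masked = masked_spatial_mse(state[:, :, 0], pred[:, :, 0], mask)
# Averaging the masked error over the whole grid underestimates the error
mse_naive = ((state[:, :, 0] - pred[:, :, 0]) ** 2 * mask).mean(axis=(-1, -2))
print(mse_masked, mse_naive)
```

In this toy case the naive mean is smaller than the masked mean by the fraction of valid points, which is why the normalization matters for masked variables.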


In [None]:
# Initialize error tracking
all_std_mse = []
all_signed_errors = []
all_ground_truth = []
all_predictions = []
all_dates = []

# Setup for sea surface temperature masking if needed
sst_idx = None
sea_mask = None
if "sea_surface_temperature" in ERA5_VARIABLES:
    import xarray as xr
    sst_idx = ERA5_VARIABLES.index("sea_surface_temperature")
    sea_mask_cpu = xr.open_zarr(config['mask_path'])["sea_surface_temperature_mask"].values
    sea_mask = torch.from_numpy(sea_mask_cpu).to(device)[None, None]
    print(f"Sea surface temperature masking enabled (index: {sst_idx})")

# Process data through autoencoder
print(f"\nProcessing {len(dataloader)} batches through autoencoder...")
ae_model.eval()

with torch.no_grad():
    for batch_idx, (state, context, date) in enumerate(dataloader):
        if batch_idx % 10 == 0:
            print(f"Processing batch {batch_idx}/{len(dataloader)}")
        
        # Move data to device
        state = state.to(device, non_blocking=True)
        context = context.to(device, non_blocking=True)
        date = date.to(device, non_blocking=True)
        
        # Reshape for autoencoder
        state_flat = rearrange(state, "B T Z Lat Lon -> (B T) (Lat Lon) Z")
        context_flat = rearrange(context, "B T K Lat Lon -> (B T) (Lat Lon) K")
        date_flat = rearrange(date, "B T D -> (B T) D")
        
        # Forward pass through autoencoder
        _, state_pred_flat = ae_model(state_flat, date_flat, context_flat)
        
        # Reshape back to original format
        state_pred = rearrange(
            state_pred_flat, 
            "(B T) (Lat Lon) Z -> B T Z Lat Lon", 
            B=state.shape[0], 
            Lon=ERA5_RESOLUTION[0]
        )
        
        # Compute standardized MSE (before unstandardization)
        std_error = (state - state_pred) ** 2  # [B T Z Lat Lon]
        std_mse = std_error.mean(dim=(-1, -2))  # [B T Z] - mean over spatial dimensions
        
        # Handle sea surface temperature masking
        if sst_idx is not None:
            # Only compute error over sea points for SST: sum the masked
            # squared error and divide by the number of sea points, so that
            # land points do not dilute the mean
            sst_error = (state[:, :, sst_idx] - state_pred[:, :, sst_idx]) ** 2
            sst_error_masked = sst_error * sea_mask
            std_mse[:, :, sst_idx] = sst_error_masked.sum(dim=(-1, -2)) / sea_mask.sum()
        
        # Store standardized errors
        all_std_mse.append(std_mse.cpu())
        
        # Unstandardize for signed error computation
        state_unstd, _ = st.unstandardize(state.cpu())
        state_pred_unstd, _ = st.unstandardize(state_pred.cpu())
        
        # Compute signed errors (in physical units)
        signed_error = state_unstd - state_pred_unstd  # [B T Z Lat Lon]
        
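        # Store data (full-resolution fields for every sample are kept in CPU
        # memory; shorten the date range if memory is tight)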
        all_signed_errors.append(signed_error)
        all_ground_truth.append(state_unstd)
        all_predictions.append(state_pred_unstd)
        all_dates.append(date.cpu())
        
        # Clean up GPU memory
        del state, context, date, state_pred, state_flat, context_flat, date_flat, state_pred_flat
        torch.cuda.empty_cache()

print("\nReconstruction completed!")
print(f"Processed {len(all_std_mse)} batches")


## Aggregate and Analyze Results

Combine all the error data and compute summary statistics.
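
The summary statistics rely on a channel layout in which surface variables come first, followed by each atmospheric variable repeated over all pressure levels. The sketch below illustrates that indexing and the MSE-to-RMSE aggregation on synthetic data; the variable names, levels, and the `channel_index` helper are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical names and levels, chosen only for illustration
surface_vars = ["2m_temperature", "mean_sea_level_pressure"]
atmospheric_vars = ["temperature", "geopotential"]
levels = [500, 850, 1000]

def channel_index(var, level=None):
    """Index of a variable in a [surface..., atmospheric x levels] layout."""
    if level is None:
        return surface_vars.index(var)
    offset = len(surface_vars)
    return offset + atmospheric_vars.index(var) * len(levels) + levels.index(level)

n_channels = len(surface_vars) + len(atmospheric_vars) * len(levels)
rng = np.random.default_rng(0)
mse = rng.uniform(0.01, 0.1, size=(5, 1, n_channels))   # [N, T, Z]

# Average MSE over samples and time first, then take the square root
rmse = np.sqrt(mse.mean(axis=(0, 1)))                    # [Z]

idx = channel_index("geopotential", 850)
print(idx, rmse[idx])
```

Note that the square root of the mean MSE is generally not equal to the mean of per-sample RMSE values, so the order of averaging and square root should be kept consistent when comparing runs.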


In [None]:
# Aggregate all results
print("Aggregating results...")

# Concatenate all data
std_mse_all = torch.cat(all_std_mse, dim=0)  # [N, T, Z]
signed_errors_all = torch.cat(all_signed_errors, dim=0)  # [N, T, Z, Lat, Lon]
ground_truth_all = torch.cat(all_ground_truth, dim=0)  # [N, T, Z, Lat, Lon]
predictions_all = torch.cat(all_predictions, dim=0)  # [N, T, Z, Lat, Lon]
dates_all = torch.cat(all_dates, dim=0)  # [N, T, D]

print(f"Aggregated data shapes:")
print(f"  Standardized MSE: {std_mse_all.shape}")
print(f"  Signed errors: {signed_errors_all.shape}")
print(f"  Ground truth: {ground_truth_all.shape}")
print(f"  Predictions: {predictions_all.shape}")
print(f"  Dates: {dates_all.shape}")

# Compute summary statistics
print("\nComputing summary statistics...")

# Overall MSE statistics
overall_mse = std_mse_all.mean(dim=(0, 1))  # Mean over samples and time
overall_rmse = torch.sqrt(overall_mse)  # Root mean square error

# Variable-wise statistics
surface_vars = ERA5_SURFACE_VARIABLES
atmospheric_vars = ERA5_ATMOSPHERIC_VARIABLES

print(f"\n=== RECONSTRUCTION ERROR SUMMARY ===")
print(f"Total samples: {std_mse_all.shape[0]}")
print(f"Total variables: {std_mse_all.shape[2]}")
print(f"Surface variables: {len(surface_vars)}")
print(f"Atmospheric variables: {len(atmospheric_vars)} × {len(atm_levels)} = {len(atmospheric_vars) * len(atm_levels)}")

# Surface variable errors
print(f"\n=== SURFACE VARIABLE ERRORS (RMSE) ===")
for i, var in enumerate(surface_vars):
    rmse_val = overall_rmse[i].item()
    print(f"{var:25s}: {rmse_val:.4f}")

# Atmospheric variable errors (averaged across pressure levels)
print(f"\n=== ATMOSPHERIC VARIABLE ERRORS (RMSE, averaged across levels) ===")
atm_start_idx = len(surface_vars)
for i, var in enumerate(atmospheric_vars):
    var_start = atm_start_idx + i * len(atm_levels)
    var_end = var_start + len(atm_levels)
    var_rmse = overall_rmse[var_start:var_end].mean().item()
    print(f"{var:25s}: {var_rmse:.4f}")

# Overall statistics
print(f"\n=== OVERALL STATISTICS ===")
print(f"Mean RMSE: {overall_rmse.mean().item():.4f}")
print(f"Median RMSE: {overall_rmse.median().item():.4f}")
print(f"Min RMSE: {overall_rmse.min().item():.4f}")
print(f"Max RMSE: {overall_rmse.max().item():.4f}")


## Visualize Error Patterns

Create comprehensive visualizations of the reconstruction errors.
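
Because signed errors are in physical units, a histogram per variable is often easier to interpret than one pooled over all variables. The following is a minimal sketch on synthetic data that applies a unit multiplier and a fixed bin range per variable, in the spirit of the `bin_ranges` and `multipliers` settings in the configuration; the `error_histogram` helper is hypothetical.

```python
import numpy as np

# Illustrative subset of per-variable settings
bin_ranges = {"specific_humidity": (-5, 5), "2m_temperature": (-10, 10)}
multipliers = {"specific_humidity": 1000}  # kg/kg -> g/kg

def error_histogram(errors, var, num_bins=100):
    """Density histogram of signed errors for one variable, in display units."""
    scaled = np.asarray(errors, dtype=float) * multipliers.get(var, 1)
    scaled = scaled[np.isfinite(scaled)]       # drop NaN/inf values, if any
    lo, hi = bin_ranges[var]
    counts, edges = np.histogram(scaled, bins=num_bins, range=(lo, hi), density=True)
    return counts, edges

rng = np.random.default_rng(0)
q_err = rng.normal(0.0, 1e-3, size=10_000)     # synthetic errors in kg/kg
counts, edges = error_histogram(q_err, "specific_humidity", num_bins=50)
print(len(counts), edges[0], edges[-1])
```

The returned `counts` and `edges` can be drawn with `ax.stairs(counts, edges)` or a bar plot, one panel per variable.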


In [None]:
# Create comprehensive error visualizations
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Autoencoder Reconstruction Error Analysis', fontsize=16, fontweight='bold')

# 1. RMSE by variable
ax1 = axes[0, 0]
variable_names = surface_vars + [f"{var}_{level}hPa" for var in atmospheric_vars for level in atm_levels]
rmse_values = overall_rmse.numpy()

# Plot surface variables
surface_rmse = rmse_values[:len(surface_vars)]
x_pos = np.arange(len(surface_vars))
bars1 = ax1.bar(x_pos, surface_rmse, alpha=0.7, label='Surface', color='skyblue')

# Plot atmospheric variables (averaged across levels)
atm_start = len(surface_vars)
atm_rmse = []
atm_labels = []
for i, var in enumerate(atmospheric_vars):
    var_start = atm_start + i * len(atm_levels)
    var_end = var_start + len(atm_levels)
    var_avg_rmse = rmse_values[var_start:var_end].mean()
    atm_rmse.append(var_avg_rmse)
    atm_labels.append(var)

x_pos_atm = np.arange(len(atm_labels)) + len(surface_vars) + 1
bars2 = ax1.bar(x_pos_atm, atm_rmse, alpha=0.7, label='Atmospheric (avg)', color='lightcoral')

ax1.set_xlabel('Variables')
ax1.set_ylabel('RMSE')
ax1.set_title('RMSE by Variable')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Set x-axis labels
all_labels = surface_vars + atm_labels
all_x_pos = list(x_pos) + list(x_pos_atm)
ax1.set_xticks(all_x_pos)
ax1.set_xticklabels(all_labels, rotation=45, ha='right')

# 2. Error distribution histogram
ax2 = axes[0, 1]
all_errors_flat = signed_errors_all.flatten().numpy()
# Remove extreme outliers for better visualization
q1, q99 = np.percentile(all_errors_flat, [1, 99])
errors_filtered = all_errors_flat[(all_errors_flat >= q1) & (all_errors_flat <= q99)]

ax2.hist(errors_filtered, bins=config['num_bins'], alpha=0.7, density=True, color='lightgreen')
ax2.axvline(0, color='red', linestyle='--', alpha=0.7, label='Zero error')
ax2.set_xlabel('Signed Error')
ax2.set_ylabel('Density')
ax2.set_title('Error Distribution (1st-99th percentile)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. RMSE by pressure level (for atmospheric variables)
ax3 = axes[1, 0]
pressure_rmse = {}
for i, level in enumerate(atm_levels):
    level_rmse = []
    for j, var in enumerate(atmospheric_vars):
        var_idx = atm_start + j * len(atm_levels) + i
        level_rmse.append(rmse_values[var_idx])
    pressure_rmse[level] = np.mean(level_rmse)

levels = list(pressure_rmse.keys())
level_rmse_vals = list(pressure_rmse.values())
ax3.plot(levels, level_rmse_vals, 'o-', linewidth=2, markersize=6)
ax3.set_xlabel('Pressure Level (hPa)')
ax3.set_ylabel('Average RMSE')
ax3.set_title('RMSE vs Pressure Level')
ax3.grid(True, alpha=0.3)
ax3.invert_xaxis()  # Higher pressure (closer to the surface) on the left

# 4. Spatial error pattern (example for 2m temperature)
ax4 = axes[1, 1]
temp_2m_idx = surface_vars.index('2m_temperature')
temp_2m_errors = signed_errors_all[:, 0, temp_2m_idx, :, :].mean(dim=0)  # Average over samples

# Create a simple spatial plot
im = ax4.imshow(temp_2m_errors.numpy(), cmap='RdBu_r', aspect='auto')
ax4.set_title('2m Temperature Error Pattern')
ax4.set_xlabel('Longitude')
ax4.set_ylabel('Latitude')
plt.colorbar(im, ax=ax4, label='Error (K)')

plt.tight_layout()
plt.show()

print("Error visualization completed!")


## Sample Visualization

Visualize some example reconstructions to see the quality of the autoencoder.
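
Side-by-side maps are easier to compare when ground truth and prediction share one color scale, and when diverging error maps are centered on zero. The helpers below are a hypothetical sketch of how such limits could be computed on synthetic data.

```python
import numpy as np

def symmetric_limits(errors, percentile=99.0):
    """Return (-v, v), where v is a high percentile of |errors| (NaNs ignored)."""
    v = float(np.nanpercentile(np.abs(errors), percentile))
    return -v, v

def shared_limits(truth, pred):
    """Common (vmin, vmax) so truth and prediction use the same color scale."""
    lo = float(min(np.nanmin(truth), np.nanmin(pred)))
    hi = float(max(np.nanmax(truth), np.nanmax(pred)))
    return lo, hi

rng = np.random.default_rng(0)
truth = rng.normal(280.0, 10.0, size=(16, 32))
pred = truth + rng.normal(0.0, 0.5, size=(16, 32))

err_vmin, err_vmax = symmetric_limits(truth - pred)
vmin, vmax = shared_limits(truth, pred)
print(err_vmin, err_vmax, vmin, vmax)
```

The resulting values can be passed as `vmin` and `vmax` to `imshow`, for example with `cmap='RdBu_r'` for the error maps so that white corresponds to zero error.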


In [None]:
# Visualize sample reconstructions
print("=== SAMPLE RECONSTRUCTION VISUALIZATION ===")

# Select a few samples for visualization
n_samples = min(4, ground_truth_all.shape[0])
sample_indices = torch.randperm(ground_truth_all.shape[0])[:n_samples].tolist()  # plain ints for indexing and titles

# Create visualization for 2m temperature
temp_2m_idx = surface_vars.index('2m_temperature')

fig, axes = plt.subplots(2, n_samples, figsize=(4*n_samples, 8))
if n_samples == 1:
    axes = axes.reshape(-1, 1)

for i, sample_idx in enumerate(sample_indices):
    # Ground truth
    gt_temp = ground_truth_all[sample_idx, 0, temp_2m_idx, :, :]
    pred_temp = predictions_all[sample_idx, 0, temp_2m_idx, :, :]
    error_temp = signed_errors_all[sample_idx, 0, temp_2m_idx, :, :]
    
    # Ground truth plot
    im1 = axes[0, i].imshow(gt_temp.numpy(), cmap='viridis', aspect='auto')
    axes[0, i].set_title(f'Ground Truth (Sample {sample_idx})')
    axes[0, i].set_xlabel('Longitude')
    axes[0, i].set_ylabel('Latitude')
    plt.colorbar(im1, ax=axes[0, i], label='Temperature (K)')
    
    # Prediction plot
    im2 = axes[1, i].imshow(pred_temp.numpy(), cmap='viridis', aspect='auto')
    axes[1, i].set_title(f'Prediction (Sample {sample_idx})')
    axes[1, i].set_xlabel('Longitude')
    axes[1, i].set_ylabel('Latitude')
    plt.colorbar(im2, ax=axes[1, i], label='Temperature (K)')

plt.suptitle('2m Temperature: Ground Truth vs Predictions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Error visualization
fig, axes = plt.subplots(1, n_samples, figsize=(4*n_samples, 4))
if n_samples == 1:
    axes = [axes]

for i, sample_idx in enumerate(sample_indices):
    error_temp = signed_errors_all[sample_idx, 0, temp_2m_idx, :, :]
    
    im = axes[i].imshow(error_temp.numpy(), cmap='RdBu_r', aspect='auto')
    axes[i].set_title(f'Error (Sample {sample_idx})')
    axes[i].set_xlabel('Longitude')
    axes[i].set_ylabel('Latitude')
    plt.colorbar(im, ax=axes[i], label='Error (K)')

plt.suptitle('2m Temperature Reconstruction Errors', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Print sample statistics (the square root of the standardized MSE gives the RMSE)
print("\nSample Reconstruction Statistics:")
for i, sample_idx in enumerate(sample_indices):
    sample_rmse = std_mse_all[sample_idx, 0, :].mean().sqrt().item()
    temp_rmse = std_mse_all[sample_idx, 0, temp_2m_idx].sqrt().item()
    print(f"Sample {sample_idx}: Overall RMSE = {sample_rmse:.4f}, 2m Temp RMSE = {temp_rmse:.4f}")


## Save Results

Save the analysis results for later use.


In [None]:
# Save analysis results
output_dir = Path("reconstruction_analysis")
output_dir.mkdir(exist_ok=True)

print("Saving analysis results...")

# Create variable names for easy reference
variable_names = surface_vars + [f"{var}_{level}hPa" for var in atmospheric_vars for level in atm_levels]

# Save aggregated data (smaller subset for memory efficiency)
save_data = {
    'std_mse': std_mse_all,
    'overall_rmse': overall_rmse,
    'config': config,
    'variable_names': variable_names,
    'surface_vars': surface_vars,
    'atmospheric_vars': atmospheric_vars,
    'pressure_levels': atm_levels,
}

torch.save(save_data, output_dir / "reconstruction_analysis.pt")
print(f"Analysis data saved to: {output_dir / 'reconstruction_analysis.pt'}")

# Save a few sample reconstructions for visualization
sample_data = {
    'ground_truth': ground_truth_all[:10],  # First 10 samples
    'predictions': predictions_all[:10],
    'errors': signed_errors_all[:10],
    'dates': dates_all[:10],
}

torch.save(sample_data, output_dir / "sample_reconstructions.pt")
print(f"Sample reconstructions saved to: {output_dir / 'sample_reconstructions.pt'}")

print(f"\nAnalysis completed! Results saved to: {output_dir}")
print(f"\nSummary:")
print(f"  - Total samples analyzed: {std_mse_all.shape[0]}")
print(f"  - Date range: {config['start_date']} to {config['end_date']}")
print(f"  - Overall mean RMSE: {overall_rmse.mean().item():.4f}")
print(f"  - Best performing variable: {variable_names[overall_rmse.argmin().item()]} (RMSE: {overall_rmse.min().item():.4f})")
print(f"  - Worst performing variable: {variable_names[overall_rmse.argmax().item()]} (RMSE: {overall_rmse.max().item():.4f})")


## Next Steps

This analysis provides a comprehensive view of the autoencoder's reconstruction performance. You can now:

1. **Compare different models**: Run this analysis on different autoencoder checkpoints to compare performance
2. **Temporal analysis**: Extend the date range to analyze seasonal patterns in reconstruction errors
3. **Spatial analysis**: Investigate which geographical regions have higher reconstruction errors
4. **Variable-specific analysis**: Deep dive into specific variables that show high errors
5. **Error correlation**: Analyze which variables tend to have correlated reconstruction errors
6. **Model improvement**: Use these insights to guide model architecture improvements

The saved results can be loaded later for further analysis:
```python
results = torch.load('reconstruction_analysis/reconstruction_analysis.pt')
```
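
As a starting point for the error-correlation idea in step 5, the sketch below computes pairwise correlations of per-sample MSE between variables. It runs on synthetic data with hypothetical variable names; with the saved results, an array of shape `[N, Z]` such as `results['std_mse'][:, 0].numpy()` could presumably be passed instead.

```python
import numpy as np

def error_correlation(mse, names):
    """Correlation of per-sample MSE between variables.

    mse: array of shape [N, Z] (samples x variables).
    Returns a dict mapping (name_a, name_b) to the Pearson correlation.
    """
    corr = np.corrcoef(mse, rowvar=False)      # [Z, Z]
    return {(a, b): corr[i, j]
            for i, a in enumerate(names)
            for j, b in enumerate(names) if i < j}

rng = np.random.default_rng(0)
base = rng.uniform(0.0, 1.0, size=200)
# Synthetic per-sample MSE: the first two variables share a common driver
mse = np.stack([
    base + 0.05 * rng.normal(size=200),
    base + 0.05 * rng.normal(size=200),
    rng.uniform(0.0, 1.0, size=200),
], axis=1)

pairs = error_correlation(mse, ["var_a", "var_b", "var_c"])
for pair, r in sorted(pairs.items(), key=lambda kv: -abs(kv[1])):
    print(pair, round(float(r), 3))
```

Strongly correlated pairs suggest samples where the autoencoder struggles with several variables at once, which can help narrow down which conditions to inspect.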
