In [1]:
import pickle
import numpy as np

# Load one of your processed samples
with open('Processed_Data_Subset/processed_samples_1980.pkl', 'rb') as f:
    samples = pickle.load(f)

sample = samples[0]  # Look at first sample

print("Input trajectory shape:", sample['input_trajectory'].shape)
print("Number of environmental timesteps:", len(sample['environmental_data']))
print("\nOne environmental data point contains:")
for key in sample['environmental_data'][0].keys():
    if key == 'wind':
        print(f"  wind fields: {list(sample['environmental_data'][0]['wind'].keys())}")
        print(f"    u_300 shape: {sample['environmental_data'][0]['wind']['u_300'].shape}")
    elif key == 'sst':
        print(f"  sst shape: {sample['environmental_data'][0]['sst'].shape}")
    elif key == 'geopotential':
        print(f"  geopotential shape: {sample['environmental_data'][0]['geopotential'].shape}")

Input trajectory shape: (8, 6)
Number of environmental timesteps: 8

One environmental data point contains:
  wind fields: ['u_300', 'v_300', 'u_500', 'v_500', 'u_700', 'v_700', 'u_850', 'v_850']
    u_300 shape: (21, 21)
  sst shape: (21, 21)
  geopotential shape: (21, 41)


#### Embedding (Same as vanilla)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class SinusoidalPositionEmbedding(nn.Module):
    """Timestep embedding for diffusion process."""
    
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

## Gaussian Diffusion with DDIM (Implicit Modeling) for faster convergence

In [24]:
class GaussianDiffusion:
    """
    Simplified DDPM for storm track forecasting.
    """
    
    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.timesteps = timesteps
        
        # Linear beta schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Calculations for diffusion q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        
        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
    
    def q_sample(self, x_start, t, noise=None):
        """
        Forward diffusion: add noise to clean data.
        
        Args:
            x_start: (batch, 5, 2) - clean future positions
            t: (batch,) - diffusion timestep
            noise: optional noise to add
        
        Returns:
            x_t: noisy version of x_start at timestep t
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        
        sqrt_alpha = self.sqrt_alphas_cumprod[t]
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t]
        
        # Reshape for broadcasting
        sqrt_alpha = sqrt_alpha[:, None, None]
        sqrt_one_minus_alpha = sqrt_one_minus_alpha[:, None, None]
        
        return sqrt_alpha * x_start + sqrt_one_minus_alpha * noise
    
    def p_sample(self, model, x_t, t, past_traj, era5_features):
        """
        Reverse diffusion: denoise one step.
        
        Args:
            model: DiffusionTransformer
            x_t: (batch, 5, 2) - noisy positions at timestep t
            t: (batch,) - current timestep
            past_traj: (batch, 8, 6) - conditioning
            era5_features: (batch, 8, 10) - conditioning
        
        Returns:
            x_{t-1}: less noisy positions
        """
        # Predict noise
        predicted_noise = model(past_traj, era5_features, x_t, t)
        device = x_t.device

        # Calculate x_0 prediction
        alpha = self.alphas_cumprod.to(device)[t][:, None, None]
        alpha_prev = self.alphas_cumprod_prev.to(device)[t][:, None, None]
        beta = self.betas.to(device)[t][:, None, None]
        
        # Predict x_0
        pred_x0 = (x_t - torch.sqrt(1 - alpha) * predicted_noise) / torch.sqrt(alpha)
        
        # Calculate x_{t-1}
        alphas_t = self.alphas.to(device)[t][:, None, None]
        mean = (
            torch.sqrt(alpha_prev) * beta * pred_x0 +
            torch.sqrt(alphas_t) * (1 - alpha_prev) * x_t
        ) / (1 - alpha)
        
        if t[0] > 0:
            noise = torch.randn_like(x_t)
            variance = self.posterior_variance.to(device)[t][:, None, None]
            return mean + torch.sqrt(variance) * noise
        else:
            return mean
    
    # @torch.no_grad()
    # def sample(self, model, past_traj, era5_features, device):
    #     """
    #     Generate storm track by denoising from pure noise.
        
    #     Args:
    #         model: trained DiffusionTransformer
    #         past_traj: (batch, 8, 6)
    #         era5_features: (batch, 8, 10)
        
    #     Returns:
    #         predicted_track: (batch, 5, 2) - forecasted positions
    #     """
    #     batch_size = past_traj.shape[0]
        
    #     # Start from pure noise
    #     x = torch.randn(batch_size, 5, 2, device=device)
        
    #     # Iteratively denoise
    #     for i in reversed(range(self.timesteps)):
    #         t = torch.full((batch_size,), i, device=device, dtype=torch.long)
    #         x = self.p_sample(model, x, t, past_traj, era5_features)
        
    #     return x

    # @torch.no_grad()
    # def ddim_sample(self, model, past_traj, era5_features, device, steps=50):
    #     """
    #     Fast sampling using DDIM with fewer steps.
        
    #     Args:
    #         steps: Number of sampling steps (default 50, vs 1000 for DDPM)
    #     """
    #     batch_size = past_traj.shape[0]
        
    #     # Select subset of timesteps to use
    #     step_size = self.timesteps // steps
    #     timesteps = list(range(0, self.timesteps, step_size))
    #     timesteps = list(reversed(timesteps))
        
    #     # Start from pure noise
    #     x = torch.randn(batch_size, 5, 2, device=device)
        
    #     # Iteratively denoise (only 'steps' iterations instead of 1000!)
    #     for i, t_curr in enumerate(timesteps):
    #         t = torch.full((batch_size,), t_curr, device=device, dtype=torch.long)
            
    #         # Predict noise
    #         predicted_noise = model(past_traj, era5_features, x, t)
            
    #         # DDIM update (deterministic, no random noise added)
    #         alpha = self.alphas_cumprod.to(device)[t][:, None, None]
            
    #         # Predict x_0
    #         pred_x0 = (x - torch.sqrt(1 - alpha) * predicted_noise) / torch.sqrt(alpha)
            
    #         # Get next timestep
    #         if i < len(timesteps) - 1:
    #             t_next = timesteps[i + 1]
    #             alpha_next = self.alphas_cumprod.to(device)[t_next][:, None, None]
                
    #             # DDIM step
    #             x = torch.sqrt(alpha_next) * pred_x0 + torch.sqrt(1 - alpha_next) * predicted_noise
    #         else:
    #             x = pred_x0
        
    #     return x

    @torch.no_grad()
    def ddim_sample(self, model, past_traj, era5_features, device, steps=50):
        """
        Fast sampling using DDIM with fewer steps.
        
        Args:
            steps: Number of sampling steps (default 50, vs 1000 for DDPM)
        """
        batch_size = past_traj.shape[0]
        
        # Select subset of timesteps to use
        step_size = self.timesteps // steps
        timesteps = list(range(0, self.timesteps, step_size))
        timesteps = list(reversed(timesteps))
        
        # Start from pure noise
        x = torch.randn(batch_size, 5, 2, device=device)
        
        # Iteratively denoise (only 'steps' iterations instead of 1000!)
        for i, t_curr in enumerate(timesteps):
            t = torch.full((batch_size,), t_curr, device=device, dtype=torch.long)
            
            # Predict noise
            predicted_noise = model(past_traj, era5_features, x, t)
            
            # DDIM update (deterministic, no random noise added)
            alpha = self.alphas_cumprod.to(device)[t_curr]  # Scalar indexing
            
            # Predict x_0
            pred_x0 = (x - torch.sqrt(1 - alpha) * predicted_noise) / torch.sqrt(alpha)
            
            # Get next timestep
            if i < len(timesteps) - 1:
                t_next = timesteps[i + 1]
                alpha_next = self.alphas_cumprod.to(device)[t_next]  # Scalar indexing
                
                # DDIM step
                x = torch.sqrt(alpha_next) * pred_x0 + torch.sqrt(1 - alpha_next) * predicted_noise
            else:
                x = pred_x0
        
        return x

## CNN Based ERA5 Encoder

In [25]:
"""
CNN Encoder for ERA5 Spatial Fields
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ERA5CNNEncoder(nn.Module):
    """
    Encode ERA5 spatial grids using CNN.
    
    Takes multiple 2D fields and produces a fixed-size embedding vector.
    """
    
    def __init__(self, output_dim=256):
        super().__init__()
        
        # Separate encoders for different field types (different spatial sizes)
        
        # Wind + SST encoder (21x21 grids)
        # Input: 9 channels (8 wind + 1 SST)
        self.wind_sst_encoder = nn.Sequential(
            nn.Conv2d(9, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 21x21 -> 10x10
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 10x10 -> 5x5
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # 5x5 -> 1x1
        )
        
        # Geopotential encoder (21x41 grids)
        # Input: 1 channel
        self.geo_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 21x41 -> 10x20
            
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 10x20 -> 5x10
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # 5x10 -> 1x1
        )
        
        # Combine features and project to output dimension
        self.fusion = nn.Sequential(
            nn.Linear(128 + 64, output_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
    
    def forward(self, env_data_batch):
        """
        Args:
            env_data_batch: List of environmental data (batch_size samples)
            Each sample contains 8 timesteps of ERA5 fields
        
        Returns:
            Tensor of shape (batch_size, n_timesteps, output_dim)
        """
        device = next(self.wind_sst_encoder.parameters()).device
        batch_features = []
        
        for env_timesteps in env_data_batch:  # Each sample in batch
            timestep_features = []
            
            for env in env_timesteps:  # Each of 8 timesteps
                # Prepare wind + SST (9 channels, 21x21)
                wind_sst_fields = []
                for level in [300, 500, 700, 850]:
                    wind_sst_fields.append(env['wind'][f'u_{level}'])
                    wind_sst_fields.append(env['wind'][f'v_{level}'])
                wind_sst_fields.append(env['sst'])
                
                wind_sst = np.stack(wind_sst_fields, axis=0)  # (9, 21, 21)
                
                # Handle NaN values (replace with mean)
                for i in range(wind_sst.shape[0]):
                    field = wind_sst[i]
                    if np.any(np.isnan(field)):
                        wind_sst[i] = np.nan_to_num(field, nan=np.nanmean(field))
                
                wind_sst = torch.FloatTensor(wind_sst).unsqueeze(0).to(device)  # (1, 9, 21, 21)
                
                # Prepare geopotential (1 channel, 21x41)
                geo = env['geopotential']
                if np.any(np.isnan(geo)):
                    geo = np.nan_to_num(geo, nan=np.nanmean(geo))
                geo = torch.FloatTensor(geo).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 21, 41)
                
                # Encode
                wind_sst_feat = self.wind_sst_encoder(wind_sst).squeeze(-1).squeeze(-1)  # (1, 128)
                geo_feat = self.geo_encoder(geo).squeeze(-1).squeeze(-1)  # (1, 64)
                
                # Fuse
                combined = torch.cat([wind_sst_feat, geo_feat], dim=1)  # (1, 192)
                fused = self.fusion(combined)  # (1, output_dim)
                
                timestep_features.append(fused)
            
            # Stack timesteps
            timestep_features = torch.cat(timestep_features, dim=0)  # (8, output_dim)
            batch_features.append(timestep_features)
        
        return torch.stack(batch_features, dim=0)  # (batch, 8, output_dim)


class DiffusionTransformerCNN(nn.Module):
    """
    Updated Diffusion Transformer using CNN encoder for ERA5.
    """
    
    def __init__(
        self,
        d_model=256,
        n_heads=8,
        n_layers=6,
        dropout=0.1,
    ):
        super().__init__()
        
        self.d_model = d_model
        
        # CNN encoder for ERA5
        self.era5_encoder = ERA5CNNEncoder(output_dim=d_model)
        
        # Embeddings
        self.traj_embed = nn.Linear(6, d_model)
        self.pos_embed = nn.Linear(2, d_model)
        
        # Diffusion timestep embedding
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbedding(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        
        # Transformer encoder for conditioning
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.condition_encoder = nn.TransformerEncoder(encoder_layer, n_layers)
        
        # Transformer decoder for denoising
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.denoiser = nn.TransformerDecoder(decoder_layer, n_layers)
        
        # Output head
        self.output_head = nn.Linear(d_model, 2)
        
        # Learnable positional encoding
        self.forecast_pos_embed = nn.Parameter(torch.randn(5, d_model))
    
    def forward(self, past_traj, env_data_batch, noisy_positions, diffusion_t):
        """
        Args:
            past_traj: (batch, 8, 6)
            env_data_batch: List of environmental data dicts
            noisy_positions: (batch, 5, 2)
            diffusion_t: (batch,)
        """
        batch_size = past_traj.shape[0]
        
        # Encode ERA5 with CNN
        era5_tokens = self.era5_encoder(env_data_batch)  # (batch, 8, d_model)
        
        # Embed trajectory
        traj_tokens = self.traj_embed(past_traj)  # (batch, 8, d_model)
        
        # Concatenate conditioning
        conditioning = torch.cat([traj_tokens, era5_tokens], dim=1)  # (batch, 16, d_model)
        conditioning = self.condition_encoder(conditioning)
        
        # Embed noisy positions
        pos_tokens = self.pos_embed(noisy_positions)  # (batch, 5, d_model)
        pos_tokens = pos_tokens + self.forecast_pos_embed.unsqueeze(0)
        
        # Add diffusion timestep
        t_embed = self.time_embed(diffusion_t)
        pos_tokens = pos_tokens + t_embed.unsqueeze(1)
        
        # Denoise
        denoised = self.denoiser(pos_tokens, conditioning)
        predicted_noise = self.output_head(denoised)
        
        return predicted_noise

### Dataset

In [18]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from pathlib import Path

In [26]:
class StormDatasetCNN(Dataset):
    """Dataset that keeps raw ERA5 grids for CNN processing."""
    
    def __init__(self, pkl_file):
        with open(pkl_file, 'rb') as f:
            self.samples = pickle.load(f)
        
        # Filter valid samples
        self.samples = [
            s for s in self.samples 
            if all(s['targets'][f't+{fh}h'] is not None for fh in [6, 12, 24, 48, 72])
        ]
        
        print(f"Loaded {len(self.samples)} valid samples")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Past trajectory
        traj = torch.FloatTensor(sample['input_trajectory'])
        
        # ERA5 raw data (keep as-is for CNN)
        env_data = sample['environmental_data']
        
        # Target positions
        targets = []
        for fh in [6, 12, 24, 48, 72]:
            t = sample['targets'][f't+{fh}h']
            targets.append([t['lat'], t['lon']])
        targets = torch.FloatTensor(targets)
        
        return traj, env_data, targets


def collate_fn_cnn(batch):
    """Custom collate to handle list of dicts in ERA5 data."""
    trajs, envs, targets = zip(*batch)
    
    trajs = torch.stack(trajs)
    targets = torch.stack(targets)
    # envs stays as list of lists of dicts
    
    return trajs, list(envs), targets

### Training Function

In [None]:
%pip install wandb

In [27]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
import os
from pathlib import Path
import numpy as np

# ============================================================================
# CONFIGURATION
# ============================================================================

CONFIG = {
    # Model
    'd_model': 256,
    'n_heads': 8,
    'n_layers': 6,
    'dropout': 0.1,
    
    # Training
    'batch_size': 16,
    'learning_rate': 1e-4,
    'weight_decay': 0.01,
    'n_epochs': 100,
    'grad_clip': 1.0,
    
    # Diffusion
    'diffusion_timesteps': 100,
    
    # Data
    'data_path': 'Processed_Data_Subset/processed_samples_1980.pkl',
    
    # Logging
    'log_interval': 10,
    'save_interval': 20,
    'checkpoint_dir': 'checkpoints',

    # Evaluation settings
    'eval_interval': 5,      # Evaluate every 5 epochs
    'eval_samples': 20,      # Evaluate on 20 samples
    'eval_ddim_steps': 50,   # Use 50 DDIM steps (vs 1000 DDPM)
}


# ============================================================================
# TRAINING WITH WANDB
# ============================================================================

def train_epoch_logged(model, diffusion, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    step = 0
    
    for batch_idx, (traj, env_data, targets) in enumerate(dataloader):
        traj = traj.to(device)
        targets = targets.to(device)
        
        batch_size = traj.shape[0]
        
        # Sample diffusion timesteps
        t = torch.randint(0, diffusion.timesteps, (batch_size,), device=device)
        
        # Add noise
        noise = torch.randn_like(targets)
        noisy_targets = diffusion.q_sample(targets, t, noise=noise)
        
        # Forward pass
        predicted_noise = model(traj, env_data, noisy_targets, t)
        
        # Loss
        loss = nn.MSELoss()(predicted_noise, noise)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
        optimizer.step()
        
        total_loss += loss.item()
        step += 1
        
        # Log to wandb
        if batch_idx % CONFIG['log_interval'] == 0:
            wandb.log({
                'train/batch_loss': loss.item(),
                'train/epoch': epoch,
                'train/step': epoch * len(dataloader) + batch_idx,
            })
            print(f"  Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss


# @torch.no_grad()
# def evaluate(model, diffusion, dataset, device, n_samples=50):
#     """Quick evaluation on a subset."""
#     model.eval()
    
#     errors = []
    
#     for idx in range(min(n_samples, len(dataset))):
#         traj, env_data, actual = dataset[idx]
        
#         # Generate prediction
#         pred = diffusion.sample(
#             model,
#             traj.unsqueeze(0).to(device),
#             [env_data],  # Wrap in list for batch
#             device
#         )
        
#         pred = pred[0].cpu().numpy()
#         actual = actual.numpy()
        
#         # Calculate 72h error (last position)
#         lat_err = (pred[4, 0] - actual[4, 0]) * 111
#         lon_err = (pred[4, 1] - actual[4, 1]) * 111 * np.cos(np.radians(actual[4, 0]))
#         dist_err = np.sqrt(lat_err**2 + lon_err**2)
#         errors.append(dist_err)
    
#     return np.mean(errors), np.std(errors)

@torch.no_grad()
def evaluate(model, diffusion, dataset, device, n_samples=20):
    """Fast evaluation using DDIM sampling."""
    model.eval()
    
    errors = []
    
    for idx in range(min(n_samples, len(dataset))):
        traj, env_data, actual = dataset[idx]
        
        # Use DDIM for fast sampling (50 steps instead of 1000)
        pred = diffusion.ddim_sample(
            model,
            traj.unsqueeze(0).to(device),
            [env_data],
            device,
            steps=50  # 20x faster than full sampling!
        )
        
        pred = pred[0].cpu().numpy()
        actual = actual.numpy()
        
        # Calculate 72h error
        lat_err = (pred[4, 0] - actual[4, 0]) * 111
        lon_err = (pred[4, 1] - actual[4, 1]) * 111 * np.cos(np.radians(actual[4, 0]))
        dist_err = np.sqrt(lat_err**2 + lon_err**2)
        errors.append(dist_err)
    
    return np.mean(errors), np.std(errors)


def save_checkpoint(model, optimizer, scheduler, epoch, loss, filename):
    """Save checkpoint with all training state."""
    checkpoint_path = Path(CONFIG['checkpoint_dir']) / filename
    checkpoint_path.parent.mkdir(exist_ok=True)
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': CONFIG,
    }, checkpoint_path)
    
    print(f"Saved checkpoint: {checkpoint_path}")
    
    # Also save to wandb
    wandb.save(str(checkpoint_path))

### Training RUN

In [28]:

# Initialize wandb
wandb.init(
    project="cyclone-diffusion-transformer",
    config=CONFIG,
    name=f"cnn-encoder-{CONFIG['d_model']}d-{CONFIG['n_layers']}L",
)

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"  Using device: {device}")

if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Load data
print(f"\n Loading data...")
dataset = StormDatasetCNN(CONFIG['data_path'])
dataloader = DataLoader(
    dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn_cnn
)

# Create model
print(f"\n Building model...")
model = DiffusionTransformerCNN(
    d_model=CONFIG['d_model'],
    n_heads=CONFIG['n_heads'],
    n_layers=CONFIG['n_layers'],
    dropout=CONFIG['dropout']
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"   Parameters: {n_params/1e6:.2f}M")
wandb.config.update({'n_parameters': n_params})

# Watch model gradients in wandb
wandb.watch(model, log='all', log_freq=100)

# Diffusion
diffusion = GaussianDiffusion(timesteps=CONFIG['diffusion_timesteps'])
diffusion.betas = diffusion.betas.to(device)
diffusion.alphas_cumprod = diffusion.alphas_cumprod.to(device)
diffusion.alphas_cumprod_prev = diffusion.alphas_cumprod_prev.to(device)
diffusion.sqrt_alphas_cumprod = diffusion.sqrt_alphas_cumprod.to(device)
diffusion.sqrt_one_minus_alphas_cumprod = diffusion.sqrt_one_minus_alphas_cumprod.to(device)

# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['n_epochs']
)

# Training loop
print(f"\n Starting training for {CONFIG['n_epochs']} epochs...")
print("="*70)

best_loss = float('inf')

for epoch in range(CONFIG['n_epochs']):
    print(f"\n Epoch {epoch+1}/{CONFIG['n_epochs']}")
    
    # Train
    avg_loss = train_epoch_logged(model, diffusion, dataloader, optimizer, device, epoch)
    
    if (epoch + 1) % CONFIG['eval_interval'] == 0:
        eval_mean, eval_std = evaluate(model, diffusion, dataset, device, 
                                      n_samples=CONFIG['eval_samples'])
        
        wandb.log({
            'eval/72h_error_mean_km': eval_mean,
            'eval/72h_error_std_km': eval_std,
        })
        print(f"   Eval 72h Error: {eval_mean:.1f} ± {eval_std:.1f} km")
    
    # Log training metrics
    wandb.log({
        'train/epoch_loss': avg_loss,
        'train/learning_rate': optimizer.param_groups[0]['lr'],
        'epoch': epoch,
    })
    print(f"   Train Loss: {avg_loss:.4f}")
    
    # Step scheduler
    scheduler.step()
    
    # Save best model
    if avg_loss < best_loss:
        best_loss = avg_loss
        save_checkpoint(model, optimizer, scheduler, epoch, avg_loss, 'best_model.pt')
        wandb.run.summary["best_loss"] = best_loss
        wandb.run.summary["best_epoch"] = epoch
    
    # Regular checkpoints
    if (epoch + 1) % CONFIG['save_interval'] == 0:
        save_checkpoint(
            model, optimizer, scheduler, epoch, avg_loss,
            f'checkpoint_epoch_{epoch+1}.pt'
        )

print("\n" + "="*70)
print("✅ Training complete!")
print(f"   Best loss: {best_loss:.4f}")

# Save final model
save_checkpoint(model, optimizer, scheduler, CONFIG['n_epochs']-1, avg_loss, 'final_model.pt')

wandb.finish()


0,1
epoch,▁▃▆█
train/batch_loss,▇█▂▆▄▃▅▅▄▅▂▃▆██▄▆▃▂▂▁▂▂▂▂
train/epoch,▁▁▁▁▁▃▃▃▃▃▅▅▅▅▅▆▆▆▆▆█████
train/epoch_loss,█▂▁▂
train/learning_rate,█▇▅▁
train/step,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██

0,1
best_epoch,2.0
best_loss,1.03109
epoch,3.0
train/batch_loss,0.9343
train/epoch,4.0
train/epoch_loss,1.04635
train/learning_rate,0.0001
train/step,224.0


  Using device: cuda
   GPU: Tesla T4
   Memory: 15.6 GB

 Loading data...
Loaded 727 valid samples

 Building model...
   Parameters: 11.76M

 Starting training for 100 epochs...

 Epoch 1/100
  Batch 0/46, Loss: 1.4081
  Batch 10/46, Loss: 1.1226
  Batch 20/46, Loss: 1.1075
  Batch 30/46, Loss: 0.9606
  Batch 40/46, Loss: 1.1020
   Train Loss: 1.1296




Saved checkpoint: checkpoints/best_model.pt

 Epoch 2/100
  Batch 0/46, Loss: 1.0963
  Batch 10/46, Loss: 1.0645
  Batch 20/46, Loss: 0.9546
  Batch 30/46, Loss: 0.8338
  Batch 40/46, Loss: 1.2151
   Train Loss: 1.0147
Saved checkpoint: checkpoints/best_model.pt

 Epoch 3/100
  Batch 0/46, Loss: 0.9193
  Batch 10/46, Loss: 1.0805
  Batch 20/46, Loss: 1.1388
  Batch 30/46, Loss: 0.8843
  Batch 40/46, Loss: 0.9086
   Train Loss: 1.0316

 Epoch 4/100
  Batch 0/46, Loss: 1.0074
  Batch 10/46, Loss: 0.9310
  Batch 20/46, Loss: 1.0559
  Batch 30/46, Loss: 1.0718
  Batch 40/46, Loss: 0.7969
   Train Loss: 1.0245

 Epoch 5/100
  Batch 0/46, Loss: 1.0596
  Batch 10/46, Loss: 0.9815
  Batch 20/46, Loss: 0.8802
  Batch 30/46, Loss: 1.0010
  Batch 40/46, Loss: 1.0398
   Eval 72h Error: 12691.3 ± 251.7 km
   Train Loss: 1.0211

 Epoch 6/100
  Batch 0/46, Loss: 0.8128
  Batch 10/46, Loss: 1.0061
  Batch 20/46, Loss: 0.9743
  Batch 30/46, Loss: 1.0931
  Batch 40/46, Loss: 0.9501
   Train Loss: 1.0253



   Eval 72h Error: 12566.5 ± 246.0 km
   Train Loss: 0.9724
Saved checkpoint: checkpoints/checkpoint_epoch_20.pt

 Epoch 21/100
  Batch 0/46, Loss: 1.0494
  Batch 10/46, Loss: 1.0265
  Batch 20/46, Loss: 0.9200
  Batch 30/46, Loss: 1.0584
  Batch 40/46, Loss: 1.0472
   Train Loss: 1.0161

 Epoch 22/100
  Batch 0/46, Loss: 1.0970
  Batch 10/46, Loss: 0.8061
  Batch 20/46, Loss: 1.2029
  Batch 30/46, Loss: 1.2116
  Batch 40/46, Loss: 0.8526
   Train Loss: 1.0069

 Epoch 23/100
  Batch 0/46, Loss: 1.0661
  Batch 10/46, Loss: 0.7999
  Batch 20/46, Loss: 1.1445
  Batch 30/46, Loss: 0.9130
  Batch 40/46, Loss: 0.8711
   Train Loss: 0.9932

 Epoch 24/100
  Batch 0/46, Loss: 1.1194
  Batch 10/46, Loss: 1.1215
  Batch 20/46, Loss: 1.0059
  Batch 30/46, Loss: 1.0857
  Batch 40/46, Loss: 1.0438
   Train Loss: 1.0233

 Epoch 25/100
  Batch 0/46, Loss: 1.1407
  Batch 10/46, Loss: 1.0373
  Batch 20/46, Loss: 1.0280
  Batch 30/46, Loss: 0.9946
  Batch 40/46, Loss: 0.9544
   Eval 72h Error: 12673.5 ± 




 Epoch 41/100
  Batch 0/46, Loss: 1.0695
  Batch 10/46, Loss: 0.8712
  Batch 20/46, Loss: 1.1414
  Batch 30/46, Loss: 1.0774
  Batch 40/46, Loss: 0.9142
   Train Loss: 0.9973

 Epoch 42/100
  Batch 0/46, Loss: 1.2402
  Batch 10/46, Loss: 1.0705
  Batch 20/46, Loss: 0.9209
  Batch 30/46, Loss: 1.1678
  Batch 40/46, Loss: 0.9803
   Train Loss: 1.0291

 Epoch 43/100
  Batch 0/46, Loss: 0.8584
  Batch 10/46, Loss: 1.0534
  Batch 20/46, Loss: 0.8491
  Batch 30/46, Loss: 0.9446
  Batch 40/46, Loss: 1.0508
   Train Loss: 1.0074

 Epoch 44/100
  Batch 0/46, Loss: 1.0186
  Batch 10/46, Loss: 1.1116
  Batch 20/46, Loss: 1.0932
  Batch 30/46, Loss: 1.0754
  Batch 40/46, Loss: 1.0611
   Train Loss: 1.0013

 Epoch 45/100
  Batch 0/46, Loss: 0.9930
  Batch 10/46, Loss: 1.0185
  Batch 20/46, Loss: 0.8837
  Batch 30/46, Loss: 0.8550
  Batch 40/46, Loss: 0.8513
   Eval 72h Error: 12685.2 ± 243.9 km
   Train Loss: 0.9820

 Epoch 46/100
  Batch 0/46, Loss: 1.0768
  Batch 10/46, Loss: 0.8352
  Batch 20/4



Saved checkpoint: checkpoints/checkpoint_epoch_60.pt

 Epoch 61/100
  Batch 0/46, Loss: 1.0311
  Batch 10/46, Loss: 0.8507
  Batch 20/46, Loss: 1.0880
  Batch 30/46, Loss: 1.0801
  Batch 40/46, Loss: 0.8794
   Train Loss: 0.9963

 Epoch 62/100
  Batch 0/46, Loss: 0.7921
  Batch 10/46, Loss: 0.9192
  Batch 20/46, Loss: 0.9847
  Batch 30/46, Loss: 0.8829
  Batch 40/46, Loss: 0.9399
   Train Loss: 0.9990

 Epoch 63/100
  Batch 0/46, Loss: 0.9160
  Batch 10/46, Loss: 1.0070
  Batch 20/46, Loss: 0.9905
  Batch 30/46, Loss: 1.0637
  Batch 40/46, Loss: 1.0310
   Train Loss: 0.9845

 Epoch 64/100
  Batch 0/46, Loss: 1.1839
  Batch 10/46, Loss: 0.9365
  Batch 20/46, Loss: 0.9128
  Batch 30/46, Loss: 0.9420
  Batch 40/46, Loss: 0.9611
   Train Loss: 0.9783

 Epoch 65/100
  Batch 0/46, Loss: 0.9728
  Batch 10/46, Loss: 1.0348
  Batch 20/46, Loss: 0.9546
  Batch 30/46, Loss: 0.9931
  Batch 40/46, Loss: 0.8564
   Eval 72h Error: 12613.4 ± 295.6 km
   Train Loss: 1.0053

 Epoch 66/100
  Batch 0/46, 




 Epoch 81/100
  Batch 0/46, Loss: 1.0137
  Batch 10/46, Loss: 0.9562
  Batch 20/46, Loss: 0.9981
  Batch 30/46, Loss: 1.3419
  Batch 40/46, Loss: 0.9127
   Train Loss: 0.9893

 Epoch 82/100
  Batch 0/46, Loss: 1.0505
  Batch 10/46, Loss: 0.7879
  Batch 20/46, Loss: 1.1129
  Batch 30/46, Loss: 1.0908
  Batch 40/46, Loss: 1.0641
   Train Loss: 1.0171

 Epoch 83/100
  Batch 0/46, Loss: 0.9512
  Batch 10/46, Loss: 1.0176
  Batch 20/46, Loss: 1.0834
  Batch 30/46, Loss: 1.1693
  Batch 40/46, Loss: 1.0333
   Train Loss: 0.9799

 Epoch 84/100
  Batch 0/46, Loss: 1.0027
  Batch 10/46, Loss: 1.0031
  Batch 20/46, Loss: 0.9873
  Batch 30/46, Loss: 0.9598
  Batch 40/46, Loss: 1.0679
   Train Loss: 0.9567
Saved checkpoint: checkpoints/best_model.pt

 Epoch 85/100
  Batch 0/46, Loss: 0.9126
  Batch 10/46, Loss: 0.9828
  Batch 20/46, Loss: 0.8955
  Batch 30/46, Loss: 0.8981
  Batch 40/46, Loss: 0.9245
   Eval 72h Error: 12573.4 ± 287.7 km
   Train Loss: 0.9529
Saved checkpoint: checkpoints/best_mod

KeyboardInterrupt: 

## GRUCNN 

In [22]:
"""
GRU-CNN Baseline from Original Paper
Adapted to work with our .pkl data format
"""

import torch
import torch.nn as nn
import numpy as np


class GRUCNN(nn.Module):
    """
    GRU-CNN model from the paper.
    
    Architecture:
    1. CNN encodes ERA5 spatial fields at each timestep
    2. GRU processes sequence of (trajectory + ERA5 embeddings)
    3. Decoder predicts future positions
    """
    
    def __init__(
        self,
        hidden_dim=128,
        num_layers=2,
        dropout=0.1
    ):
        super().__init__()
        
        # CNN encoder for ERA5 (same as our diffusion model)
        self.era5_encoder = ERA5CNNEncoder(output_dim=hidden_dim)
        
        # Trajectory encoder
        self.traj_encoder = nn.Linear(6, hidden_dim)
        
        # GRU for sequence processing
        self.gru = nn.GRU(
            input_size=hidden_dim * 2,  # traj + era5
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        
        # Decoder: predict 5 future positions (10 values: 5 × [lat, lon])
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 10)  # 5 positions × 2 coords
        )
    
    def forward(self, past_traj, env_data_batch):
        """
        Args:
            past_traj: (batch, 8, 6) - past trajectory
            env_data_batch: list of env data for CNN
        
        Returns:
            predictions: (batch, 5, 2) - predicted future positions
        """
        batch_size = past_traj.shape[0]
        
        # Encode ERA5
        era5_features = self.era5_encoder(env_data_batch)  # (batch, 8, hidden_dim)
        
        # Encode trajectory
        traj_features = self.traj_encoder(past_traj)  # (batch, 8, hidden_dim)
        
        # Combine features
        combined = torch.cat([traj_features, era5_features], dim=-1)  # (batch, 8, hidden_dim*2)
        
        # GRU processes sequence
        gru_out, _ = self.gru(combined)  # (batch, 8, hidden_dim)
        
        # Use last timestep to predict future
        last_hidden = gru_out[:, -1, :]  # (batch, hidden_dim)
        
        # Decode to future positions
        predictions = self.decoder(last_hidden)  # (batch, 10)
        predictions = predictions.view(batch_size, 5, 2)  # (batch, 5, 2)
        
        return predictions


# Training function for GRU-CNN
def train_epoch_grucnn(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    
    for batch_idx, (traj, env_data, targets) in enumerate(dataloader):
        traj = traj.to(device)
        targets = targets.to(device)
        
        # Forward pass
        predictions = model(traj, env_data)
        
        # Loss: MSE on position predictions
        loss = nn.MSELoss()(predictions, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            wandb.log({
                'grucnn/batch_loss': loss.item(),
                'grucnn/epoch': epoch,
                'grucnn/step': epoch * len(dataloader) + batch_idx,
            })
            print(f"  Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
    
    return total_loss / len(dataloader)


@torch.no_grad()
def evaluate_grucnn(model, dataset, device, n_samples=20):
    """Evaluate GRU-CNN model."""
    model.eval()
    
    errors = []
    
    for idx in range(min(n_samples, len(dataset))):
        traj, env_data, actual = dataset[idx]
        
        # Predict
        pred = model(traj.unsqueeze(0).to(device), [env_data])
        pred = pred[0].cpu().numpy()
        actual = actual.numpy()
        
        # Calculate 72h error
        lat_err = (pred[4, 0] - actual[4, 0]) * 111
        lon_err = (pred[4, 1] - actual[4, 1]) * 111 * np.cos(np.radians(actual[4, 0]))
        dist_err = np.sqrt(lat_err**2 + lon_err**2)
        errors.append(dist_err)
    
    return np.mean(errors), np.std(errors)


def train_grucnn_baseline():
    """Main training loop for GRU-CNN baseline."""
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training GRU-CNN Baseline on {device}")
    
    # Initialize wandb
    wandb.init(
        project="cyclone-diffusion-transformer",
        name="grucnn-baseline",
        config={
            'model': 'GRU-CNN',
            'hidden_dim': 128,
            'num_layers': 2,
            'learning_rate': 1e-3,
            'batch_size': 16,
            'n_epochs': 100,
        }
    )
    
    # Load data (same as diffusion model)
    dataset = StormDatasetCNN('Processed_Data_Subset/processed_samples_1980.pkl')
    dataloader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn_cnn
    )
    
    # Create model
    model = GRUCNN(hidden_dim=128, num_layers=2).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    
    # Training
    best_loss = float('inf')
    
    for epoch in range(100):
        print(f"\nEpoch {epoch+1}/100")
        
        # Train
        avg_loss = train_epoch_grucnn(model, dataloader, optimizer, device, epoch)
        
        # Evaluate every 5 epochs
        if (epoch + 1) % 5 == 0:
            eval_mean, eval_std = evaluate_grucnn(model, dataset, device, n_samples=20)
            wandb.log({
                'grucnn/eval_72h_error_km': eval_mean,
                'grucnn/eval_72h_error_std_km': eval_std,
            })
            print(f"   Eval 72h Error: {eval_mean:.1f} ± {eval_std:.1f} km")
        
        wandb.log({
            'grucnn/epoch_loss': avg_loss,
            'grucnn/learning_rate': optimizer.param_groups[0]['lr'],
        })
        print(f"   Train Loss: {avg_loss:.4f}")
        
        scheduler.step()
        
        # Save best
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), 'grucnn_best.pt')
            print(f"✓ Saved best model")
    
    wandb.finish()

## Comparison Framework

In [23]:
"""
Comprehensive Model Comparison
"""

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def comprehensive_evaluation(diffusion_model, grucnn_model, diffusion_obj, dataset, device):
    """
    Evaluate both models on all metrics for comparison.
    """
    diffusion_model.eval()
    grucnn_model.eval()
    
    results = {
        'sample_id': [],
        'horizon': [],
        'diffusion_error_km': [],
        'grucnn_error_km': [],
        'cyclone_name': [],
        'actual_lat': [],
        'actual_lon': [],
        'diffusion_lat': [],
        'diffusion_lon': [],
        'grucnn_lat': [],
        'grucnn_lon': [],
    }
    
    print("Running comprehensive evaluation...")
    
    for idx in range(len(dataset)):
        traj, env_data, actual = dataset[idx]
        sample = dataset.samples[idx]
        
        # Diffusion prediction (DDIM for speed)
        with torch.no_grad():
            diff_pred = diffusion_obj.ddim_sample(
                diffusion_model,
                traj.unsqueeze(0).to(device),
                [env_data],
                device,
                steps=50
            )[0].cpu().numpy()
        
        # GRU-CNN prediction
        with torch.no_grad():
            gru_pred = grucnn_model(
                traj.unsqueeze(0).to(device),
                [env_data]
            )[0].cpu().numpy()
        
        actual_np = actual.numpy()
        
        # Calculate errors for each horizon
        for i, horizon in enumerate([6, 12, 24, 48, 72]):
            # Diffusion error
            diff_lat_err = (diff_pred[i, 0] - actual_np[i, 0]) * 111
            diff_lon_err = (diff_pred[i, 1] - actual_np[i, 1]) * 111 * np.cos(np.radians(actual_np[i, 0]))
            diff_err = np.sqrt(diff_lat_err**2 + diff_lon_err**2)
            
            # GRU-CNN error
            gru_lat_err = (gru_pred[i, 0] - actual_np[i, 0]) * 111
            gru_lon_err = (gru_pred[i, 1] - actual_np[i, 1]) * 111 * np.cos(np.radians(actual_np[i, 0]))
            gru_err = np.sqrt(gru_lat_err**2 + gru_lon_err**2)
            
            results['sample_id'].append(idx)
            results['horizon'].append(horizon)
            results['diffusion_error_km'].append(diff_err)
            results['grucnn_error_km'].append(gru_err)
            results['cyclone_name'].append(sample['cyclone_name'])
            results['actual_lat'].append(actual_np[i, 0])
            results['actual_lon'].append(actual_np[i, 1])
            results['diffusion_lat'].append(diff_pred[i, 0])
            results['diffusion_lon'].append(diff_pred[i, 1])
            results['grucnn_lat'].append(gru_pred[i, 0])
            results['grucnn_lon'].append(gru_pred[i, 1])
        
        if (idx + 1) % 50 == 0:
            print(f"  Evaluated {idx+1}/{len(dataset)} samples")
    
    return pd.DataFrame(results)


def create_comparison_table(df):
    """Create publication-quality comparison table."""
    
    summary = df.groupby('horizon').agg({
        'diffusion_error_km': ['mean', 'std'],
        'grucnn_error_km': ['mean', 'std']
    }).round(1)
    
    print("\n" + "="*80)
    print("MODEL COMPARISON: Track Error (km)")
    print("="*80)
    print(f"{'Horizon':<10} {'Diffusion Transformer':<30} {'GRU-CNN Baseline':<30}")
    print(f"{'(hours)':<10} {'Mean ± Std':<30} {'Mean ± Std':<30}")
    print("-"*80)
    
    for horizon in [6, 12, 24, 48, 72]:
        diff_mean = summary.loc[horizon, ('diffusion_error_km', 'mean')]
        diff_std = summary.loc[horizon, ('diffusion_error_km', 'std')]
        gru_mean = summary.loc[horizon, ('grucnn_error_km', 'mean')]
        gru_std = summary.loc[horizon, ('grucnn_error_km', 'std')]
        
        diff_str = f"{diff_mean:.1f} ± {diff_std:.1f}"
        gru_str = f"{gru_mean:.1f} ± {gru_std:.1f}"
        
        # Highlight winner
        winner = "✓" if diff_mean < gru_mean else ""
        
        print(f"{horizon:<10} {diff_str:<30} {gru_str:<30} {winner}")
    
    print("="*80)
    
    return summary


def plot_model_comparison(df):
    """Create comparison visualizations."""
    
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # 1. Error by horizon
    ax = axes[0, 0]
    horizons = [6, 12, 24, 48, 72]
    diff_means = [df[df['horizon']==h]['diffusion_error_km'].mean() for h in horizons]
    gru_means = [df[df['horizon']==h]['grucnn_error_km'].mean() for h in horizons]
    
    ax.plot(horizons, diff_means, 'o-', linewidth=2, markersize=8, label='Diffusion Transformer')
    ax.plot(horizons, gru_means, 's-', linewidth=2, markersize=8, label='GRU-CNN')
    ax.set_xlabel('Forecast Horizon (hours)', fontsize=12)
    ax.set_ylabel('Mean Track Error (km)', fontsize=12)
    ax.set_title('Track Error by Forecast Horizon', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 2. Error distribution comparison
    ax = axes[0, 1]
    df_72h = df[df['horizon'] == 72]
    ax.boxplot([df_72h['diffusion_error_km'], df_72h['grucnn_error_km']], 
               labels=['Diffusion', 'GRU-CNN'])
    ax.set_ylabel('72h Track Error (km)', fontsize=12)
    ax.set_title('72-hour Forecast Error Distribution', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    
    # 3. Scatter plot: Diffusion vs GRU-CNN
    ax = axes[1, 0]
    df_72h = df[df['horizon'] == 72]
    ax.scatter(df_72h['grucnn_error_km'], df_72h['diffusion_error_km'], alpha=0.5)
    max_err = max(df_72h['grucnn_error_km'].max(), df_72h['diffusion_error_km'].max())
    ax.plot([0, max_err], [0, max_err], 'r--', alpha=0.5, label='Equal performance')
    ax.set_xlabel('GRU-CNN Error (km)', fontsize=12)
    ax.set_ylabel('Diffusion Transformer Error (km)', fontsize=12)
    ax.set_title('72h: Per-Sample Error Comparison', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 4. Improvement histogram
    ax = axes[1, 1]
    df_72h['improvement'] = df_72h['grucnn_error_km'] - df_72h['diffusion_error_km']
    ax.hist(df_72h['improvement'], bins=30, edgecolor='black')
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No improvement')
    ax.set_xlabel('Error Reduction (km)', fontsize=12)
    ax.set_ylabel('Number of Samples', fontsize=12)
    ax.set_title('72h: Error Reduction Distribution\n(Positive = Diffusion Better)', 
                 fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()


def plot_sample_tracks_comparison(df, dataset, sample_indices=[0, 10, 20]):
    """Plot side-by-side track comparisons for specific samples."""
    
    n_samples = len(sample_indices)
    fig = plt.figure(figsize=(18, 6*n_samples))
    
    for plot_idx, sample_idx in enumerate(sample_indices):
        traj, env_data, actual = dataset[sample_idx]
        sample_data = df[df['sample_id'] == sample_idx]
        
        ax = plt.subplot(n_samples, 1, plot_idx+1, projection=ccrs.PlateCarree())
        ax.coastlines()
        ax.add_feature(cfeature.LAND, alpha=0.3)
        ax.add_feature(cfeature.OCEAN, alpha=0.3)
        ax.gridlines(draw_labels=True)
        
        # Past trajectory
        past_lats = traj[:, 0].numpy()
        past_lons = traj[:, 1].numpy()
        ax.plot(past_lons, past_lats, 'o-', color='black', linewidth=2, 
                markersize=6, label='Past 24h', transform=ccrs.PlateCarree())
        
        # Actual
        actual_data = sample_data[sample_data['horizon'].isin([6, 12, 24, 48, 72])]
        actual_lats = actual_data['actual_lat'].values
        actual_lons = actual_data['actual_lon'].values
        full_actual_lats = np.concatenate([past_lats[-1:], actual_lats])
        full_actual_lons = np.concatenate([past_lons[-1:], actual_lons])
        ax.plot(full_actual_lons, full_actual_lats, 's-', color='green', 
                linewidth=2, markersize=8, label='Actual', transform=ccrs.PlateCarree())
        
        # Diffusion
        diff_lats = actual_data['diffusion_lat'].values
        diff_lons = actual_data['diffusion_lon'].values
        full_diff_lats = np.concatenate([past_lats[-1:], diff_lats])
        full_diff_lons = np.concatenate([past_lons[-1:], diff_lons])
        ax.plot(full_diff_lons, full_diff_lats, '^-', color='red', 
                linewidth=2, markersize=8, label='Diffusion', transform=ccrs.PlateCarree())
        
        # GRU-CNN
        gru_lats = actual_data['grucnn_lat'].values
        gru_lons = actual_data['grucnn_lon'].values
        full_gru_lats = np.concatenate([past_lats[-1:], gru_lats])
        full_gru_lons = np.concatenate([past_lons[-1:], gru_lons])
        ax.plot(full_gru_lons, full_gru_lats, 'D-', color='blue', 
                linewidth=2, markersize=8, label='GRU-CNN', transform=ccrs.PlateCarree())
        
        # Calculate 72h errors
        diff_72h_err = sample_data[sample_data['horizon']==72]['diffusion_error_km'].values[0]
        gru_72h_err = sample_data[sample_data['horizon']==72]['grucnn_error_km'].values[0]
        
        ax.legend(loc='upper right')
        ax.set_title(f"Sample {sample_idx} | 72h Error: Diffusion={diff_72h_err:.0f}km, GRU-CNN={gru_72h_err:.0f}km",
                    fontsize=12, fontweight='bold')
        
        # Set extent
        all_lats = np.concatenate([past_lats, actual_lats, diff_lats, gru_lats])
        all_lons = np.concatenate([past_lons, actual_lons, diff_lons, gru_lons])
        margin = 3
        ax.set_extent([all_lons.min()-margin, all_lons.max()+margin,
                      all_lats.min()-margin, all_lats.max()+margin])
    
    plt.tight_layout()
    plt.savefig('track_comparison_samples.png', dpi=300, bbox_inches='tight')
    plt.show()