#### Embedding (Same as vanilla)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class SinusoidalPositionEmbedding(nn.Module):
    """Timestep embedding for diffusion process.

    Each scalar timestep is mapped to a ``dim``-wide vector: the first half
    holds sines and the second half cosines of the timestep scaled by
    geometrically spaced frequencies (from 1 down to 1/10000).

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed a batch of timesteps.

        Args:
            t: (batch,) tensor of diffusion timesteps.

        Returns:
            (batch, dim) tensor of sinusoidal features.
        """
        half_dim = self.dim // 2
        # max(..., 1) keeps the divisor non-zero when dim is 2 or 3
        # (half_dim == 1), where the single frequency is simply 1.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        angles = t[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim columns; zero-pad the last one so
            # the output width always equals ``dim``.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

In [24]:
class GaussianDiffusion:
    """
    Simplified DDPM for storm track forecasting.

    Precomputes a linear beta schedule and the quantities derived from it,
    and provides the forward (noising) step, a single reverse (denoising)
    step and the full sampling loop. The schedule tensors are created on the
    CPU and moved to the device of the data every time they are used.
    """

    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.timesteps = timesteps

        # Linear beta schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Shift right by one and prepend 1.0 so index t holds alpha_bar_{t-1}.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Calculations for diffusion q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    @staticmethod
    def _extract(values, t, like):
        """Gather ``values[t]`` on ``like``'s device, shaped to broadcast over ``like``."""
        device = like.device
        picked = values.to(device)[t.to(device)]
        return picked.reshape(-1, *([1] * (like.dim() - 1)))

    def q_sample(self, x_start, t, noise=None):
        """
        Forward diffusion: add noise to clean data.

        Args:
            x_start: (batch, 5, 2) - clean future positions
            t: (batch,) - diffusion timestep
            noise: optional noise to add

        Returns:
            x_t: noisy version of x_start at timestep t
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # The schedule lives on the CPU; gather it on x_start's device so the
        # arithmetic below also works when the data is on an accelerator.
        sqrt_alpha = self._extract(self.sqrt_alphas_cumprod, t, x_start)
        sqrt_one_minus_alpha = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start)

        return sqrt_alpha * x_start + sqrt_one_minus_alpha * noise

    def p_sample(self, model, x_t, t, past_traj, era5_features):
        """
        Reverse diffusion: denoise one step.

        Args:
            model: DiffusionTransformer
            x_t: (batch, 5, 2) - noisy positions at timestep t
            t: (batch,) - current timestep
            past_traj: (batch, 8, 6) - conditioning
            era5_features: (batch, 8, 10) - conditioning

        Returns:
            x_{t-1}: less noisy positions
        """
        # Predict noise
        predicted_noise = model(past_traj, era5_features, x_t, t)

        alpha_bar = self._extract(self.alphas_cumprod, t, x_t)
        alpha_bar_prev = self._extract(self.alphas_cumprod_prev, t, x_t)
        beta = self._extract(self.betas, t, x_t)
        alpha_t = self._extract(self.alphas, t, x_t)

        # Predict x_0 from the noise estimate
        pred_x0 = (x_t - torch.sqrt(1 - alpha_bar) * predicted_noise) / torch.sqrt(alpha_bar)

        # Posterior mean of q(x_{t-1} | x_t, x_0)
        mean = (
            torch.sqrt(alpha_bar_prev) * beta * pred_x0 +
            torch.sqrt(alpha_t) * (1 - alpha_bar_prev) * x_t
        ) / (1 - alpha_bar)

        # Noise is added per sample: only entries with t > 0 receive it, so a
        # batch mixing different timesteps is handled correctly.
        needs_noise = t > 0
        if not bool(needs_noise.any()):
            return mean

        noise = torch.randn_like(x_t)
        variance = self._extract(self.posterior_variance, t, x_t)
        keep = needs_noise.to(device=x_t.device, dtype=x_t.dtype)
        keep = keep.reshape(-1, *([1] * (x_t.dim() - 1)))
        return mean + keep * torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, model, past_traj, era5_features, device):
        """
        Generate storm track by denoising from pure noise.

        Args:
            model: trained DiffusionTransformer
            past_traj: (batch, 8, 6)
            era5_features: (batch, 8, 10)
            device: device on which the initial noise and timesteps are created

        Returns:
            predicted_track: (batch, 5, 2) - forecasted positions
        """
        batch_size = past_traj.shape[0]

        # Start from pure noise
        x = torch.randn(batch_size, 5, 2, device=device)

        # Iteratively denoise from t = timesteps - 1 down to t = 0
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(model, x, t, past_traj, era5_features)

        return x

In [None]:
class ERA5Pooler(nn.Module):
    """Reduce every ERA5 grid to a single number by NaN-ignoring averaging."""

    def __init__(self):
        super().__init__()
        # No learnable parameters: pooling is a plain spatial mean.

    def forward(self, env_data_batch):
        """
        Average each field of each timestep over its grid.

        Args:
            env_data_batch: one entry per sample; each entry is a sequence of
                per-timestep dicts with keys 'wind' (holding 'u_<level>' and
                'v_<level>' grids), 'sst' and 'geopotential'.

        Returns:
            Float tensor of shape (batch, n_timesteps, 10): the u/v means for
            levels 300, 500, 700 and 850, followed by the SST mean and the
            geopotential mean.
        """
        pressure_levels = (300, 500, 700, 850)
        pooled = []

        for sample_envs in env_data_batch:
            per_step = []

            for env in sample_envs:
                wind = env['wind']
                # u then v for each level, in ascending level order
                row = [
                    np.nanmean(wind[f'{component}_{level}'])
                    for level in pressure_levels
                    for component in ('u', 'v')
                ]
                row.append(np.nanmean(env['sst']))
                row.append(np.nanmean(env['geopotential']))
                per_step.append(row)

            pooled.append(per_step)

        return torch.FloatTensor(pooled)  # (batch, 8, 10)

## CNN-Based ERA5 Encoder

In [25]:
"""
CNN Encoder for ERA5 Spatial Fields
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ERA5CNNEncoder(nn.Module):
    """
    Encode ERA5 spatial grids using CNN.

    Takes multiple 2D fields and produces a fixed-size embedding vector per
    timestep. Wind and SST share one convolutional branch, geopotential has
    its own, and a linear layer fuses the two.

    Args:
        output_dim: size of the embedding produced for each timestep.
    """

    def __init__(self, output_dim=256):
        super().__init__()

        # Separate encoders for different field types (different spatial sizes)

        # Wind + SST encoder (21x21 grids)
        # Input: 9 channels (8 wind + 1 SST)
        self.wind_sst_encoder = nn.Sequential(
            nn.Conv2d(9, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 21x21 -> 10x10

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 10x10 -> 5x5

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # 5x5 -> 1x1
        )

        # Geopotential encoder (21x41 grids)
        # Input: 1 channel
        self.geo_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 21x41 -> 10x20

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 10x20 -> 5x10

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # 5x10 -> 1x1
        )

        # Combine features and project to output dimension
        self.fusion = nn.Sequential(
            nn.Linear(128 + 64, output_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

    @staticmethod
    def _fill_nan(field):
        """Return ``field`` with NaNs replaced by the mean of its non-NaN cells.

        A field that is entirely NaN has no usable mean (``np.nanmean`` would
        itself return NaN), so it is filled with 0.0 instead; otherwise the
        NaNs would propagate through the CNN.
        """
        field = np.asarray(field)
        nan_mask = np.isnan(field)
        if not nan_mask.any():
            return field
        fill = 0.0 if nan_mask.all() else np.nanmean(field)
        return np.nan_to_num(field, nan=fill)

    def forward(self, env_data_batch):
        """
        Args:
            env_data_batch: List of environmental data (batch_size samples)
            Each sample contains 8 timesteps of ERA5 fields

        Returns:
            Tensor of shape (batch_size, n_timesteps, output_dim)
        """
        # Inputs arrive as numpy grids; place them where the weights live.
        device = next(self.wind_sst_encoder.parameters()).device
        batch_features = []

        for env_timesteps in env_data_batch:  # Each sample in batch
            timestep_features = []

            for env in env_timesteps:  # Each of 8 timesteps
                # Prepare wind + SST (9 channels, 21x21), NaN-filled per field
                wind = env['wind']
                fields = []
                for level in (300, 500, 700, 850):
                    fields.append(self._fill_nan(wind[f'u_{level}']))
                    fields.append(self._fill_nan(wind[f'v_{level}']))
                fields.append(self._fill_nan(env['sst']))

                wind_sst = torch.tensor(
                    np.stack(fields, axis=0), dtype=torch.float32, device=device
                ).unsqueeze(0)  # (1, 9, 21, 21)

                # Prepare geopotential (1 channel, 21x41)
                geo = torch.tensor(
                    self._fill_nan(env['geopotential']), dtype=torch.float32, device=device
                ).unsqueeze(0).unsqueeze(0)  # (1, 1, 21, 41)

                # Encode
                wind_sst_feat = self.wind_sst_encoder(wind_sst).squeeze(-1).squeeze(-1)  # (1, 128)
                geo_feat = self.geo_encoder(geo).squeeze(-1).squeeze(-1)  # (1, 64)

                # Fuse
                combined = torch.cat([wind_sst_feat, geo_feat], dim=1)  # (1, 192)
                timestep_features.append(self.fusion(combined))  # (1, output_dim)

            # Stack timesteps
            batch_features.append(torch.cat(timestep_features, dim=0))  # (8, output_dim)

        return torch.stack(batch_features, dim=0)  # (batch, 8, output_dim)


class DiffusionTransformerCNN(nn.Module):
    """
    Diffusion Transformer whose ERA5 conditioning comes from the CNN encoder.

    Past-trajectory tokens and CNN-encoded ERA5 tokens are encoded together
    by a Transformer encoder; a Transformer decoder then attends to that
    context from the noisy forecast positions and outputs a noise estimate.
    """

    def __init__(
        self,
        d_model=256,
        n_heads=8,
        n_layers=6,
        dropout=0.1,
    ):
        super().__init__()

        self.d_model = d_model
        ff_dim = d_model * 4

        # ERA5 grids -> one d_model token per timestep
        self.era5_encoder = ERA5CNNEncoder(output_dim=d_model)

        # Linear token embeddings for past trajectory and noisy positions
        self.traj_embed = nn.Linear(6, d_model)
        self.pos_embed = nn.Linear(2, d_model)

        # Diffusion timestep: sinusoidal features followed by a small MLP
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbedding(d_model),
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # Encoder and decoder layers share the same hyper-parameters
        layer_kwargs = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )

        # Transformer encoder over the conditioning tokens
        self.condition_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), n_layers
        )

        # Transformer decoder that denoises the forecast tokens
        self.denoiser = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs), n_layers
        )

        # Projects each decoded token to a 2-value noise estimate
        self.output_head = nn.Linear(d_model, 2)

        # Learnable positional encoding, one row per forecast slot
        self.forecast_pos_embed = nn.Parameter(torch.randn(5, d_model))

    def forward(self, past_traj, env_data_batch, noisy_positions, diffusion_t):
        """
        Predict the noise contained in ``noisy_positions``.

        Args:
            past_traj: (batch, 8, 6)
            env_data_batch: List of environmental data dicts
            noisy_positions: (batch, 5, 2)
            diffusion_t: (batch,)

        Returns:
            Tensor with the same leading shape as ``noisy_positions`` and a
            last dimension of 2.
        """
        # Conditioning sequence: trajectory tokens followed by ERA5 tokens
        era5_tokens = self.era5_encoder(env_data_batch)  # (batch, 8, d_model)
        context = torch.cat(
            [self.traj_embed(past_traj), era5_tokens], dim=1
        )  # (batch, 16, d_model)
        context = self.condition_encoder(context)

        # Query tokens: embedded noisy positions + slot encoding + timestep
        queries = self.pos_embed(noisy_positions) + self.forecast_pos_embed[None]
        queries = queries + self.time_embed(diffusion_t)[:, None, :]

        # Decode against the context and project to noise space
        return self.output_head(self.denoiser(queries, context))

### Dataset

In [18]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from pathlib import Path

In [26]:
class StormDatasetCNN(Dataset):
    """Dataset that keeps raw ERA5 grids for CNN processing.

    Samples are read from a pickle file; only those with a non-None target
    at every forecast horizon are kept.

    Args:
        pkl_file: path of the pickled list of sample dicts.
    """

    # Forecast lead times (hours) that every usable sample must provide.
    FORECAST_HOURS = (6, 12, 24, 48, 72)

    def __init__(self, pkl_file):
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only open files from a trusted source.
        with open(pkl_file, 'rb') as f:
            raw_samples = pickle.load(f)

        # Filter valid samples
        self.samples = [s for s in raw_samples if self._has_all_targets(s)]

        print(f"Loaded {len(self.samples)} valid samples")

    @classmethod
    def _has_all_targets(cls, sample):
        """Return True when every forecast horizon has a non-None target."""
        # A missing key is treated like a None target so an incomplete sample
        # is dropped instead of aborting the whole load with a KeyError.
        targets = sample.get('targets') or {}
        return all(targets.get(f't+{fh}h') is not None for fh in cls.FORECAST_HOURS)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(traj, env_data, targets)`` for sample ``idx``.

        ``traj`` is a float tensor built from 'input_trajectory', ``env_data``
        is the sample's 'environmental_data' untouched, and ``targets`` is a
        float tensor with one ``[lat, lon]`` row per forecast horizon.
        """
        sample = self.samples[idx]

        # Past trajectory
        traj = torch.FloatTensor(sample['input_trajectory'])

        # ERA5 raw data (keep as-is for CNN)
        env_data = sample['environmental_data']

        # Target positions, in ascending lead-time order
        targets = []
        for fh in self.FORECAST_HOURS:
            t = sample['targets'][f't+{fh}h']
            targets.append([t['lat'], t['lon']])
        targets = torch.FloatTensor(targets)

        return traj, env_data, targets


def collate_fn_cnn(batch):
    """Collate ``(traj, env, targets)`` samples into one batch.

    Trajectories and targets are stacked along a new leading dimension; the
    environment entries cannot be stacked (they are lists of dicts), so they
    are returned as a plain list with one entry per sample.
    """
    traj_seq, env_seq, target_seq = zip(*batch)
    return torch.stack(traj_seq), list(env_seq), torch.stack(target_seq)

### COMBINED MODEL RUNS

In [None]:
"""
Unified Training Framework for Storm Track Forecasting
=====================================================

Three models, one data pipeline, fair comparison.
"""

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
import numpy as np


# ============================================================================
# SHARED COMPONENTS
# ============================================================================

# Use the same ERA5CNNEncoder we already built
# Use the same StormDatasetCNN we already built
# Use the same collate_fn_cnn we already built


# ============================================================================
# MODEL 1: GRU-CNN (Baseline from Paper)
# ============================================================================

class GRUCNN_Baseline(nn.Module):
    """
    GRU-CNN baseline that regresses future positions directly.

    ERA5 grids are encoded by the CNN, concatenated per timestep with the
    embedded trajectory, run through a GRU, and the last GRU output is
    decoded into all five forecast positions at once (no diffusion).
    """

    def __init__(self, hidden_dim=128, num_layers=2, dropout=0.1):
        super().__init__()

        # ERA5 grids -> hidden_dim features per timestep
        self.era5_encoder = ERA5CNNEncoder(output_dim=hidden_dim)

        # 6 trajectory features -> hidden_dim
        self.traj_embed = nn.Linear(6, hidden_dim)

        # nn.GRU only applies dropout between stacked layers, so it is
        # switched off for a single-layer GRU.
        gru_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=hidden_dim * 2,  # traj + era5
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=gru_dropout,
            batch_first=True
        )

        # MLP head emitting 10 values = 5 positions x 2 coords
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, 10)
        )

    def forward(self, past_traj, env_data_batch):
        """
        Map the past track and its ERA5 context to future positions.

        Args:
            past_traj: (batch, 8, 6)
            env_data_batch: list of env dicts

        Returns:
            (batch, 5, 2) predicted positions
        """
        era5_features = self.era5_encoder(env_data_batch)  # (batch, 8, hidden)
        traj_features = self.traj_embed(past_traj)  # (batch, 8, hidden)

        # Per-timestep feature concatenation -> (batch, 8, 2*hidden)
        gru_in = torch.cat([traj_features, era5_features], dim=-1)
        gru_out, _ = self.gru(gru_in)  # (batch, 8, hidden)

        # Only the final timestep's output feeds the decoder
        flat = self.decoder(gru_out[:, -1])  # (batch, 10)
        return flat.view(past_traj.shape[0], 5, 2)


# ============================================================================
# MODEL 2: GRU-CNN with Pooled Features (Ablation)
# ============================================================================

class GRUCNN_Pooled(nn.Module):
    """
    Ablation: GRU-CNN but with pooled ERA5 instead of CNN.
    Tests if spatial structure matters.

    Each ERA5 grid is reduced to its spatial mean (10 numbers per timestep),
    linearly embedded, concatenated with the embedded trajectory and passed
    through the same GRU + decoder layout as the CNN variant.
    """

    def __init__(self, hidden_dim=128, num_layers=2, dropout=0.1):
        super().__init__()

        # Simple pooling instead of CNN
        self.era5_pooler = ERA5Pooler()
        self.era5_embed = nn.Linear(10, hidden_dim)

        # Rest is same as CNN version
        self.traj_embed = nn.Linear(6, hidden_dim)

        self.gru = nn.GRU(
            input_size=hidden_dim * 2,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, 10)
        )

    def forward(self, past_traj, env_data_batch):
        """
        Direct prediction from the past track and pooled ERA5 features.

        Args:
            past_traj: (batch, 8, 6)
            env_data_batch: list of env dicts

        Returns:
            (batch, 5, 2) predicted positions
        """
        batch_size = past_traj.shape[0]

        # Pool ERA5 (loses spatial structure)
        era5_pooled = self.era5_pooler(env_data_batch)  # (batch, 8, 10)
        # The pooler builds its tensor on the CPU; move it next to the
        # embedding weights so the model also runs after .to('cuda').
        era5_pooled = era5_pooled.to(self.era5_embed.weight.device)
        era5_features = self.era5_embed(era5_pooled)  # (batch, 8, hidden)

        traj_features = self.traj_embed(past_traj)
        combined = torch.cat([traj_features, era5_features], dim=-1)

        gru_out, _ = self.gru(combined)
        # Decode from the last timestep's GRU output
        predictions = self.decoder(gru_out[:, -1, :])
        return predictions.view(batch_size, 5, 2)


# ============================================================================
# MODEL 3: Diffusion Transformer (Your Novel Approach)
# ============================================================================

# Keep your existing DiffusionTransformerCNN and GaussianDiffusion classes


# ============================================================================
# UNIFIED TRAINING FUNCTIONS
# ============================================================================

def train_direct_model(model, dataloader, optimizer, device, epoch, model_name):
    """
    Run one training epoch for a direct-prediction (non-diffusion) model.

    Each batch is scored with MSE between predicted and target positions,
    gradients are clipped to a norm of 1.0, and every 10th batch is logged
    to wandb under the ``model_name`` prefix.

    Returns:
        Mean batch loss over the epoch.
    """
    model.train()
    criterion = nn.MSELoss()
    running_loss = 0

    for batch_idx, (traj, env_data, targets) in enumerate(dataloader):
        # env_data stays a Python list; only the tensors are moved
        traj, targets = traj.to(device), targets.to(device)

        # Forward pass and position MSE
        loss = criterion(model(traj, env_data), targets)

        # Backward pass with gradient clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        if batch_idx % 10 != 0:
            continue
        wandb.log({
            f'{model_name}/batch_loss': batch_loss,
            f'{model_name}/step': epoch * len(dataloader) + batch_idx,
        })

    return running_loss / len(dataloader)


@torch.no_grad()
def evaluate_direct_model(model, dataset, device, n_samples=20):
    """Evaluate direct prediction models.

    Runs the model on the first ``n_samples`` items of ``dataset`` and
    measures the distance between predicted and actual position at the last
    (index 4, i.e. 72 h) forecast slot, using 111 km per degree and scaling
    the longitude term by the cosine of the actual latitude.

    Args:
        model: callable as ``model(traj_batch, [env_data])``.
        dataset: indexable, yielding ``(traj, env_data, actual)`` tensors.
        device: device the trajectory is moved to before the forward pass.
        n_samples: maximum number of samples to evaluate.

    Returns:
        ``(mean_km, std_km)`` of the per-sample errors; both are NaN when no
        sample was evaluated.
    """
    model.eval()
    errors = []

    for idx in range(min(n_samples, len(dataset))):
        traj, env_data, actual = dataset[idx]

        pred = model(traj.unsqueeze(0).to(device), [env_data])[0].cpu().numpy()
        actual = actual.numpy()

        # 72h error
        lat_err = (pred[4, 0] - actual[4, 0]) * 111
        # Wrap the longitude difference into [-180, 180) so two points on
        # either side of the longitude seam count as close, not ~360 deg apart.
        lon_diff = (pred[4, 1] - actual[4, 1] + 180.0) % 360.0 - 180.0
        lon_err = lon_diff * 111 * np.cos(np.radians(actual[4, 0]))
        dist_err = np.sqrt(lat_err**2 + lon_err**2)
        errors.append(dist_err)

    if not errors:
        # Same values np.mean/np.std give for an empty list, minus the warning.
        return float('nan'), float('nan')

    return np.mean(errors), np.std(errors)


# ============================================================================
# UNIFIED TRAINING SCRIPT
# ============================================================================

def _train_direct_run(model_cls, dataset, dataloader, device, n_epochs, *,
                      title, run_name, log_prefix, checkpoint_path):
    """Train one direct-prediction model in its own wandb run; return its best epoch loss."""
    print("\n" + "="*70)
    print(title)
    print("="*70)

    wandb.init(project="cyclone-comparison", name=run_name, reinit=True)

    best_loss = float('inf')
    try:
        model = model_cls(hidden_dim=128, num_layers=2).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

        for epoch in range(n_epochs):
            print(f"\nEpoch {epoch+1}/{n_epochs}")

            loss = train_direct_model(model, dataloader, optimizer, device, epoch, log_prefix)

            # Distance evaluation is only run every 5th epoch
            if (epoch + 1) % 5 == 0:
                eval_mean, eval_std = evaluate_direct_model(model, dataset, device)
                wandb.log({f'{log_prefix}/eval_72h_km': eval_mean, f'{log_prefix}/epoch': epoch})
                print(f"  Loss: {loss:.4f}, Eval: {eval_mean:.1f}±{eval_std:.1f} km")
            else:
                print(f"  Loss: {loss:.4f}")

            # Log the learning rate used for this epoch before stepping the schedule
            wandb.log({f'{log_prefix}/epoch_loss': loss, f'{log_prefix}/lr': optimizer.param_groups[0]['lr']})
            scheduler.step()

            # Checkpoint whenever the epoch training loss improves
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), checkpoint_path)
    finally:
        # Close the wandb run even if training raises, so the next run
        # starts from a clean state.
        wandb.finish()

    return best_loss


def train_all_models(data_path, n_epochs=100):
    """
    Train the direct-prediction models on the same data for fair comparison.

    The GRU-CNN baseline and the pooled-ERA5 ablation are trained one after
    the other on one shared dataset and dataloader, each in its own wandb
    run, and each saves its best checkpoint. The diffusion transformer is
    currently skipped.

    Args:
        data_path: path of the pickled samples passed to StormDatasetCNN.
        n_epochs: number of epochs for each model.
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load shared dataset
    dataset = StormDatasetCNN(data_path)
    dataloader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn_cnn
    )

    # ========================================================================
    # MODEL 1: GRU-CNN Baseline
    # ========================================================================

    best_loss_grucnn = _train_direct_run(
        GRUCNN_Baseline, dataset, dataloader, device, n_epochs,
        title="TRAINING MODEL 1: GRU-CNN Baseline",
        run_name="grucnn-baseline",
        log_prefix='grucnn',
        checkpoint_path='grucnn_baseline_best.pt',
    )

    # ========================================================================
    # MODEL 2: GRU-CNN Pooled (Ablation)
    # ========================================================================

    best_loss_pooled = _train_direct_run(
        GRUCNN_Pooled, dataset, dataloader, device, n_epochs,
        title="TRAINING MODEL 2: GRU-CNN Pooled (Ablation)",
        run_name="grucnn-pooled",
        log_prefix='pooled',
        checkpoint_path='grucnn_pooled_best.pt',
    )

    # ========================================================================
    # MODEL 3: Diffusion Transformer (if you want to retry with fixes)
    # ========================================================================

    print("\n" + "="*70)
    print("SKIPPING Diffusion Transformer (can retry later if fixed)")
    print("="*70)

    # You can add this back once we debug the diffusion training

    print("\n" + "="*70)
    print("TRAINING COMPLETE!")
    print("="*70)
    print(f"GRU-CNN Baseline best loss: {best_loss_grucnn:.4f}")
    print(f"GRU-CNN Pooled best loss: {best_loss_pooled:.4f}")


In [None]:
# Notebook entry point: run the model comparison for 50 epochs on the
# samples stored in processed_data/processed_samples_1980.pkl.
train_all_models('processed_data/processed_samples_1980.pkl', n_epochs=50)