In [14]:
import pickle
import numpy as np

# Open the 1980 subset and pull out the first processed sample for inspection.
# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open('Processed_Data_Subset/processed_samples_1980.pkl', 'rb') as f:
    samples = pickle.load(f)

sample = samples[0]  # first sample only
env_steps = sample['environmental_data']
first_env = env_steps[0]  # environmental fields at the first timestep

print("Input trajectory shape:", sample['input_trajectory'].shape)
print("Number of environmental timesteps:", len(env_steps))
print("\nOne environmental data point contains:")
for key in first_env.keys():
    if key == 'wind':
        # Wind is a nested dict of 2-D fields; list its keys and show one grid shape.
        wind_fields = first_env['wind']
        print(f"  wind fields: {list(wind_fields.keys())}")
        print(f"    u_300 shape: {wind_fields['u_300'].shape}")
    elif key == 'sst':
        print(f"  sst shape: {first_env['sst'].shape}")
    elif key == 'geopotential':
        print(f"  geopotential shape: {first_env['geopotential'].shape}")

Input trajectory shape: (8, 6)
Number of environmental timesteps: 8

One environmental data point contains:
  wind fields: ['u_300', 'v_300', 'u_500', 'v_500', 'u_700', 'v_700', 'u_850', 'v_850']
    u_300 shape: (21, 21)
  sst shape: (21, 21)
  geopotential shape: (21, 41)


#### Gaussian Diffusion and Timestep Embedding (unchanged from the vanilla model)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class SinusoidalPositionEmbedding(nn.Module):
    """Sinusoidal timestep embedding for the diffusion process.

    Maps a 1-D tensor of timesteps ``t`` of shape ``(batch,)`` to a tensor of
    shape ``(batch, dim)`` whose first half holds sines and second half holds
    cosines of ``t`` scaled by geometrically spaced frequencies.

    Args:
        dim: size of the returned embedding. For an odd ``dim`` the final
            feature is zero padding so the output width is always ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Return the ``(batch, dim)`` embedding of the timesteps ``t``."""
        device = t.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3
        # (half_dim == 1); for half_dim >= 2 the value is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = t[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim == dim - 1 features; pad the last
            # one so downstream layers sized with `dim` still fit.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

In [9]:
class GaussianDiffusion:
    """
    Simplified DDPM for storm track forecasting.

    Holds a linear beta schedule and the quantities derived from it, and
    provides the forward noising step (`q_sample`), one reverse step
    (`p_sample`) and the full reverse loop (`sample`). The model passed to
    `p_sample` / `sample` is called as ``model(past_traj, era5_features, x_t, t)``
    and its output is used as the predicted noise.

    Args:
        timesteps: number of diffusion steps.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
    """

    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.timesteps = timesteps

        # Linear beta schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 for step 0.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Calculations for diffusion q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Calculations for posterior q(x_{t-1} | x_t, x_0); equals 0 at step 0.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def to(self, device):
        """Move every schedule tensor to `device` and return `self`."""
        for name, value in list(vars(self).items()):
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
        return self

    @staticmethod
    def _extract(schedule, t, x):
        """Gather ``schedule[t]`` and reshape it to broadcast against ``x``.

        The schedule is moved to ``t``'s device first, so indexing works even
        if only some of the schedule tensors were moved by the caller.
        """
        values = schedule.to(t.device)[t]
        return values.reshape(t.shape[0], *([1] * (x.dim() - 1)))

    def q_sample(self, x_start, t, noise=None):
        """
        Forward diffusion: add noise to clean data.

        Args:
            x_start: (batch, ...) clean data, e.g. (batch, 5, 2) future positions
            t: (batch,) integer diffusion timestep per sample
            noise: optional noise with the shape of `x_start`; drawn from a
                standard normal when omitted

        Returns:
            x_t: noisy version of x_start at timestep t
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alpha_bar = self._extract(self.sqrt_alphas_cumprod, t, x_start)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start)

        return sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise

    def p_sample(self, model, x_t, t, past_traj, era5_features):
        """
        Reverse diffusion: denoise one step.

        Args:
            model: callable returning the predicted noise for
                ``(past_traj, era5_features, x_t, t)``
            x_t: (batch, ...) noisy data at timestep t
            t: (batch,) integer timestep per sample
            past_traj: conditioning, forwarded to the model unchanged
            era5_features: conditioning, forwarded to the model unchanged

        Returns:
            x_{t-1}: less noisy data. Samples with ``t == 0`` receive the
            posterior mean without added noise.
        """
        # Predict noise
        predicted_noise = model(past_traj, era5_features, x_t, t)

        alpha_bar = self._extract(self.alphas_cumprod, t, x_t)
        alpha_bar_prev = self._extract(self.alphas_cumprod_prev, t, x_t)
        beta = self._extract(self.betas, t, x_t)
        alpha = self._extract(self.alphas, t, x_t)

        # Predict x_0 from the predicted noise
        pred_x0 = (x_t - torch.sqrt(1 - alpha_bar) * predicted_noise) / torch.sqrt(alpha_bar)

        # Posterior mean of q(x_{t-1} | x_t, x_0)
        mean = (
            torch.sqrt(alpha_bar_prev) * beta * pred_x0 +
            torch.sqrt(alpha) * (1 - alpha_bar_prev) * x_t
        ) / (1 - alpha_bar)

        # Decide per sample whether to add noise, instead of looking only at
        # t[0]; when every t is 0 no random numbers are drawn at all.
        needs_noise = t > 0
        if not bool(needs_noise.any()):
            return mean

        noise = torch.randn_like(x_t)
        variance = self._extract(self.posterior_variance, t, x_t)
        mask = needs_noise.to(x_t.dtype).reshape(t.shape[0], *([1] * (x_t.dim() - 1)))
        return mean + mask * torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, model, past_traj, era5_features, device, target_shape=(5, 2)):
        """
        Generate storm track by denoising from pure noise.

        This method does not change the model's train/eval mode; call
        ``model.eval()`` beforehand if dropout should be disabled.

        Args:
            model: trained noise-prediction model
            past_traj: conditioning whose first dimension is the batch size
            era5_features: conditioning, forwarded to the model unchanged
            device: device on which the initial noise and timesteps are created
            target_shape: per-sample shape of the generated output

        Returns:
            predicted_track: (batch, *target_shape) - forecasted positions
        """
        batch_size = past_traj.shape[0]

        # Start from pure noise
        x = torch.randn(batch_size, *target_shape, device=device)

        # Iteratively denoise
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(model, x, t, past_traj, era5_features)

        return x

## CNN-Based ERA5 Encoder

In [15]:
"""
CNN Encoder for ERA5 Spatial Fields
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ERA5CNNEncoder(nn.Module):
    """
    Encode ERA5 spatial grids using CNN.

    Each timestep's environment dict supplies eight wind grids
    (``env['wind']['u_<level>']`` / ``['v_<level>']``), one ``env['sst']`` grid
    and one ``env['geopotential']`` grid. Wind and SST are stacked into a
    9-channel image for one CNN, geopotential goes through a second CNN, and
    the two pooled feature vectors are fused into one `output_dim` embedding
    per timestep.

    Args:
        output_dim: size of the embedding produced for each timestep.
    """

    # Level suffixes of the wind keys, in channel order (u then v per level).
    WIND_LEVELS = (300, 500, 700, 850)

    def __init__(self, output_dim=256):
        super().__init__()

        # Separate encoders for different field types (different spatial sizes)

        # Wind + SST encoder (21x21 grids)
        # Input: 9 channels (8 wind + 1 SST)
        self.wind_sst_encoder = nn.Sequential(
            nn.Conv2d(9, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 21x21 -> 10x10

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 10x10 -> 5x5

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # 5x5 -> 1x1
        )

        # Geopotential encoder (21x41 grids)
        # Input: 1 channel
        self.geo_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 21x41 -> 10x20

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 10x20 -> 5x10

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # 5x10 -> 1x1
        )

        # Combine features and project to output dimension
        self.fusion = nn.Sequential(
            nn.Linear(128 + 64, output_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

    @staticmethod
    def _fill_nan(field):
        """Return `field` as an array with NaNs replaced by the field's mean.

        A field that is entirely NaN has no mean, so it is filled with 0.0
        instead of being left as NaN (which would poison the CNN output).
        """
        grid = np.asarray(field)
        missing = np.isnan(grid)
        if not missing.any():
            return grid
        fill_value = 0.0 if missing.all() else np.nanmean(grid)
        return np.nan_to_num(grid, nan=fill_value)

    def forward(self, env_data_batch):
        """
        Args:
            env_data_batch: List with one entry per sample; each entry is a
                list of per-timestep environment dicts. Every sample must
                have the same number of timesteps.

        Returns:
            Tensor of shape (batch_size, n_timesteps, output_dim), on the same
            device as this module's parameters.

        Raises:
            ValueError: if the batch is empty or samples disagree on the
                number of timesteps.
        """
        batch_size = len(env_data_batch)
        if batch_size == 0:
            raise ValueError("env_data_batch must contain at least one sample")

        step_counts = {len(env_timesteps) for env_timesteps in env_data_batch}
        if len(step_counts) != 1:
            raise ValueError(
                f"all samples must have the same number of timesteps, got {sorted(step_counts)}"
            )
        n_timesteps = step_counts.pop()

        # Collect every (sample, timestep) grid so the CNNs run once on the
        # whole batch instead of once per grid.
        wind_sst_grids = []
        geo_grids = []
        for env_timesteps in env_data_batch:
            for env in env_timesteps:
                channels = []
                for level in self.WIND_LEVELS:
                    channels.append(self._fill_nan(env['wind'][f'u_{level}']))
                    channels.append(self._fill_nan(env['wind'][f'v_{level}']))
                channels.append(self._fill_nan(env['sst']))
                wind_sst_grids.append(np.stack(channels, axis=0))  # (9, H, W)
                geo_grids.append(self._fill_nan(env['geopotential']))  # (H, W')

        # Build the inputs directly on the module's device; the original
        # CPU-only tensors failed once the model was moved to CUDA.
        device = next(self.parameters()).device
        wind_sst = torch.as_tensor(
            np.stack(wind_sst_grids, axis=0), dtype=torch.float32, device=device
        )  # (batch * n_timesteps, 9, H, W)
        geo = torch.as_tensor(
            np.stack(geo_grids, axis=0), dtype=torch.float32, device=device
        ).unsqueeze(1)  # (batch * n_timesteps, 1, H, W')

        # Encode
        wind_sst_feat = self.wind_sst_encoder(wind_sst).squeeze(-1).squeeze(-1)  # (N, 128)
        geo_feat = self.geo_encoder(geo).squeeze(-1).squeeze(-1)  # (N, 64)

        # Fuse
        fused = self.fusion(torch.cat([wind_sst_feat, geo_feat], dim=1))  # (N, output_dim)

        return fused.reshape(batch_size, n_timesteps, -1)


class DiffusionTransformerCNN(nn.Module):
    """
    Noise-prediction transformer conditioned on the past track and on
    CNN-encoded ERA5 grids.

    A transformer encoder processes the concatenated trajectory and ERA5
    tokens; a transformer decoder attends to that memory from the embedded
    noisy forecast positions and a linear head maps each decoder output to a
    2-value noise estimate.
    """

    def __init__(self, d_model=256, n_heads=8, n_layers=6, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        hidden = d_model * 4
        # Settings shared by the encoder and decoder layers.
        layer_kwargs = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=hidden,
            dropout=dropout,
            batch_first=True,
        )

        # Submodules are created in a fixed order so parameter names and
        # seeded initialisation stay stable.
        self.era5_encoder = ERA5CNNEncoder(output_dim=d_model)  # CNN over ERA5 grids
        self.traj_embed = nn.Linear(6, d_model)   # past-track features -> token
        self.pos_embed = nn.Linear(2, d_model)    # noisy position -> token

        # Diffusion-step embedding: sinusoidal features followed by an MLP.
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbedding(d_model),
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model)
        )

        # Encoder over the conditioning tokens.
        self.condition_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), n_layers
        )

        # Decoder that denoises the forecast tokens against the conditioning.
        self.denoiser = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs), n_layers
        )

        # Per-token projection to the 2-value noise estimate.
        self.output_head = nn.Linear(d_model, 2)

        # One learnable embedding per forecast slot (5 slots).
        self.forecast_pos_embed = nn.Parameter(torch.randn(5, d_model))

    def forward(self, past_traj, env_data_batch, noisy_positions, diffusion_t):
        """
        Predict the noise contained in `noisy_positions`.

        Args:
            past_traj: (batch, 8, 6)
            env_data_batch: List of environmental data dicts
            noisy_positions: (batch, 5, 2)
            diffusion_t: (batch,)

        Returns:
            Predicted noise with the same leading shape as `noisy_positions`
            and 2 values per forecast slot.
        """
        # Conditioning memory: ERA5 tokens from the CNN, trajectory tokens
        # from a linear layer, concatenated along the sequence axis.
        env_tokens = self.era5_encoder(env_data_batch)  # (batch, 8, d_model)
        track_tokens = self.traj_embed(past_traj)       # (batch, 8, d_model)
        memory = self.condition_encoder(
            torch.cat([track_tokens, env_tokens], dim=1)  # (batch, 16, d_model)
        )

        # Query tokens: noisy positions + slot embedding + diffusion-step
        # embedding broadcast over the forecast slots.
        step_embedding = self.time_embed(diffusion_t).unsqueeze(1)
        queries = (
            self.pos_embed(noisy_positions)
            + self.forecast_pos_embed.unsqueeze(0)
            + step_embedding
        )

        return self.output_head(self.denoiser(queries, memory))

### Dataset

In [10]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from pathlib import Path

In [11]:
class StormDatasetCNN(Dataset):
    """Dataset that keeps raw ERA5 grids for CNN processing.

    Loads a pickled list of sample dicts and keeps only the samples that
    provide a non-None target for every forecast hour in `FORECAST_HOURS`.

    Args:
        pkl_file: path to the pickle file holding the list of samples.
    """

    # Forecast horizons, in the order used for the target tensor rows.
    FORECAST_HOURS = (6, 12, 24, 48, 72)

    def __init__(self, pkl_file):
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # files produced by your own preprocessing.
        with open(pkl_file, 'rb') as f:
            loaded = pickle.load(f)

        # Filter valid samples
        self.samples = [s for s in loaded if self._has_all_targets(s)]

        print(f"Loaded {len(self.samples)} valid samples")

    @classmethod
    def _has_all_targets(cls, sample):
        """True when every forecast hour has a target that is not None.

        A horizon key that is absent is treated like a None target, so such a
        sample is skipped instead of raising KeyError during loading.
        """
        targets = sample.get('targets') or {}
        return all(targets.get(f't+{fh}h') is not None for fh in cls.FORECAST_HOURS)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(traj, env_data, targets)`` for sample `idx`.

        Returns:
            traj: float tensor built from ``sample['input_trajectory']``
            env_data: ``sample['environmental_data']`` passed through as-is
            targets: float tensor with one ``[lat, lon]`` row per forecast hour
        """
        sample = self.samples[idx]

        # Past trajectory
        traj = torch.FloatTensor(sample['input_trajectory'])

        # ERA5 raw data (keep as-is for CNN)
        env_data = sample['environmental_data']

        # Target positions, one [lat, lon] row per forecast hour
        rows = []
        for fh in self.FORECAST_HOURS:
            point = sample['targets'][f't+{fh}h']
            rows.append([point['lat'], point['lon']])
        targets = torch.FloatTensor(rows)

        return traj, env_data, targets


def collate_fn_cnn(batch):
    """Collate ``(traj, env_data, targets)`` items into one batch.

    Trajectories and targets are stacked into tensors along a new leading
    dimension; the environmental entries are returned as a plain list, one
    entry per sample, because they are nested lists of dicts.
    """
    # Transpose the list of 3-tuples into three per-field columns.
    traj_column, env_column, target_column = zip(*batch)

    return (
        torch.stack(traj_column),
        list(env_column),  # left unstacked for the CNN encoder
        torch.stack(target_column),
    )

### Training Function

In [12]:
def train_epoch_cnn(model, diffusion, dataloader, optimizer, device):
    """Run one training epoch of noise-prediction and return the mean loss.

    For each batch a random diffusion step is drawn per sample, the targets
    are noised with `diffusion.q_sample`, the model predicts that noise and
    the MSE between prediction and true noise is minimised. Gradients are
    clipped to a norm of 1.0 before each optimizer step.

    Args:
        model: called as ``model(traj, env_data, noisy_targets, t)``
        diffusion: object exposing `timesteps` and `q_sample`
        dataloader: yields ``(traj, env_data, targets)`` batches
        optimizer: optimizer over the model's parameters
        device: device the trajectory and target tensors are moved to

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    criterion = nn.MSELoss()
    running_loss = 0

    for step, batch in enumerate(dataloader):
        past, env_grids, future = batch
        past = past.to(device)
        future = future.to(device)

        # One random diffusion step per sample, then the matching noised targets.
        steps = torch.randint(0, diffusion.timesteps, (past.shape[0],), device=device)
        true_noise = torch.randn_like(future)
        corrupted = diffusion.q_sample(future, steps, noise=true_noise)

        # The raw environmental data is encoded inside the model.
        noise_estimate = model(past, env_grids, corrupted, steps)
        loss = criterion(noise_estimate, true_noise)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()

        running_loss += loss.item()

        # Progress line every tenth batch.
        if step % 10 == 0:
            print(f"  Batch {step}/{len(dataloader)}, Loss: {loss.item():.4f}")

    return running_loss / len(dataloader)

### Training Run

In [13]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load data with CNN dataset
dataset = StormDatasetCNN('Processed_Data_Subset/processed_samples_1980.pkl')
dataloader = DataLoader(
    dataset,
    batch_size=16,  # Smaller batch due to CNN memory
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn_cnn
)

# Create model with CNN encoder
model = DiffusionTransformerCNN(
    d_model=256,
    n_heads=8,
    n_layers=6,
    dropout=0.1
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

# Diffusion: move EVERY schedule tensor to the training device. Moving only
# a hand-picked subset left `alphas` and `posterior_variance` on the CPU,
# which breaks sampling on CUDA.
diffusion = GaussianDiffusion(timesteps=1000)
for attr_name, attr_value in list(vars(diffusion).items()):
    if torch.is_tensor(attr_value):
        setattr(diffusion, attr_name, attr_value.to(device))

# Training length; also drives the cosine schedule so the two cannot drift apart.
n_epochs = 100

# Optimizer with weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

# Training
best_loss = float('inf')

for epoch in range(n_epochs):
    print(f"\nEpoch {epoch+1}/{n_epochs}")
    avg_loss = train_epoch_cnn(model, diffusion, dataloader, optimizer, device)
    print(f"Average loss: {avg_loss:.4f}")

    scheduler.step()

    is_best = avg_loss < best_loss
    is_periodic = (epoch + 1) % 20 == 0

    if is_best or is_periodic:
        # Build the checkpoint once; both save paths use the same payload.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }

    # Save best model
    if is_best:
        best_loss = avg_loss
        torch.save(checkpoint, 'best_model_cnn.pt')
        print(f"✓ Saved best model (loss: {avg_loss:.4f})")

    # Regular checkpoints
    if is_periodic:
        torch.save(checkpoint, f'checkpoint_cnn_epoch_{epoch+1}.pt')


Using device: cpu
Loaded 727 valid samples


NameError: name 'SinusoidalPositionEmbedding' is not defined