In [None]:
import os
import time
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# --- Project Imports ---
from src.data.cmems_dataset import load_cmems_uv, SlidingWindowUVDataset
from src.utils import seed_all

# --- Configuration ---
CONFIG = {
    # Dataset file (.nc). The default is the original author's local path; set the
    # CMEMS_DATA_PATH environment variable to run on another machine without
    # editing the notebook.
    "data_path": os.environ.get(
        "CMEMS_DATA_PATH",
        "/home/svillhauer/Desktop/Thesis/Currents/deep_spatiotemporal_currents/src/data/cmems_mod_glo_phy_anfc_merged-uv_PT1H-i_1770985217793.nc",
    ),
    "out_dir": "thesis_experiments",
    "epochs": 30,
    "batch_size": 16,
    "lr": 1e-3,
    # Number of input frames; also sizes the input channels of the
    # "standard_unet" variant (2 * seq_len) in train_and_evaluate.
    "seq_len": 3,
    "base_ch": 32,
    "lstm_ch": 128,
    "seed": 42
}

# Seed before any model/data object is created so runs are repeatable.
seed_all(CONFIG['seed'])
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {DEVICE}")
os.makedirs(CONFIG['out_dir'], exist_ok=True)

# --- Z-Score Helper Classes (Defined Here to Ensure Consistency) ---
class ZScoreStats:
    """Plain container for the z-score statistics of the two velocity components.

    Attributes:
        u_mean, u_std: mean and standard deviation used for channel 0 (u).
        v_mean, v_std: mean and standard deviation used for channel 1 (v).
    """

    def __init__(self, u_mean, u_std, v_mean, v_std):
        # Statistics for the u component.
        self.u_mean, self.u_std = u_mean, u_std
        # Statistics for the v component.
        self.v_mean, self.v_std = v_mean, v_std

def compute_zscore(uv):
    """Compute one global mean/std pair per velocity component.

    Args:
        uv: array indexed as ``uv[:, 0]`` (u) and ``uv[:, 1]`` (v); the
            notebook uses the layout (T, 2, H, W).

    Returns:
        ZScoreStats holding (u_mean, u_std, v_mean, v_std). Each std has
        1e-8 added so that dividing by it later cannot hit exactly zero.
    """
    eps = 1e-8
    moments = []
    # Channel 0 is u, channel 1 is v; statistics are taken over every
    # remaining element of the array.
    for channel in (0, 1):
        field = uv[:, channel]
        moments.extend((np.mean(field), np.std(field) + eps))
    return ZScoreStats(*moments)

def apply_zscore(uv, stats):
    """Return a z-score normalized copy of a velocity array.

    The channel axis is taken to be the third-from-last axis, so the same
    call works for a single field (2, H, W), a time series (T, 2, H, W) and
    batched sequences (B, T, 2, H, W). For the (T, 2, H, W) layout this is
    identical to indexing axis 1; indexing axis 1 on a (2, H, W) field would
    instead normalize image rows, which is the silent error this avoids.

    Args:
        uv: NumPy array with at least three dimensions, channels (u, v) on
            axis -3. The input is not modified.
        stats: object exposing ``u_mean``, ``u_std``, ``v_mean``, ``v_std``.

    Returns:
        A new array of the same shape with each component normalized.
    """
    uv_n = uv.copy()
    uv_n[..., 0, :, :] = (uv_n[..., 0, :, :] - stats.u_mean) / stats.u_std
    uv_n[..., 1, :, :] = (uv_n[..., 1, :, :] - stats.v_mean) / stats.v_std
    return uv_n

def invert_zscore(uv_n, stats):
    """Converts normalized data back to physical units (m/s).

    The channel axis is taken to be the third-from-last axis, so a single
    field (2, H, W), a batch (N, 2, H, W) and batched sequences
    (B, T, 2, H, W) are all handled. For (N, 2, H, W) the result is identical
    to indexing axis 1; for a (2, H, W) field, axis-1 indexing would rescale
    image rows 0 and 1 instead of the u and v channels.

    Args:
        uv_n: normalized velocities as a NumPy array or torch.Tensor with
            channels (u, v) on axis -3. Tensors are detached and moved to the
            CPU first. The input is not modified.
        stats: object exposing ``u_mean``, ``u_std``, ``v_mean``, ``v_std``.

    Returns:
        NumPy array of the same shape in physical units.
    """
    # Handle both tensor and numpy input
    if isinstance(uv_n, torch.Tensor):
        uv_n = uv_n.detach().cpu().numpy()

    uv = uv_n.copy()
    uv[..., 0, :, :] = uv[..., 0, :, :] * stats.u_std + stats.u_mean
    uv[..., 1, :, :] = uv[..., 1, :, :] * stats.v_std + stats.v_mean
    return uv

Running on: cuda


In [None]:
# --- 1. ConvLSTM Cell ---
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    One convolution over the channel-concatenated input and previous hidden
    state produces all four gate pre-activations at once.

    Args:
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: (kh, kw) of the gate convolution; padding is kh//2, kw//2
            so odd kernels preserve the spatial size.
        bias: whether the gate convolution has a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: (B, input_dim, H, W) input for this step.
            cur_state: tuple (h, c), each (B, hidden_dim, H, W).

        Returns:
            Tuple (h_next, c_next) with the same shapes as the incoming state.
        """
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)
        # Output channels are laid out as [input | forget | output | candidate].
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Create zero-initialized (h, c) states.

        The states are allocated with the device and dtype of the cell's own
        convolution weights, so they always match the module after ``.to(...)``
        instead of depending on a notebook-level device variable.

        Args:
            batch_size: number of sequences in the batch.
            image_size: (height, width) of the feature map fed to the cell.

        Returns:
            Tuple (h, c), each of shape (batch_size, hidden_dim, height, width).
        """
        height, width = image_size
        shape = (batch_size, self.hidden_dim, height, width)
        weight = self.conv.weight
        return weight.new_zeros(shape), weight.new_zeros(shape)

# --- 2. Thesis Variant Model ---
class ThesisVariant(nn.Module):
    """Encoder / bottleneck / decoder network with switchable components.

    The ``mode`` string selects the variant by substring:
      * contains ``'lstm'``: the bottleneck is a ConvLSTMCell run over the
        time axis and ``forward`` expects input of shape (B, T, C, H, W);
        otherwise the bottleneck is a plain double convolution and
        ``forward`` expects (B, C, H, W).
      * contains ``'unet'``: the decoder concatenates encoder skip features;
        otherwise the UpBlocks are built with ``use_skips=False``.

    Args:
        in_ch: channels of each 2D input passed to the encoder.
        out_ch: channels of the final 1x1 convolution.
        base_ch: channel width of the first encoder stage (doubled per stage).
        lstm_ch: channel width of the bottleneck output (both variants).
        seq_len: stored on the module; not used by the layers themselves.
        mode: variant selector, see above.
    """

    def __init__(self, in_ch, out_ch, base_ch=32, lstm_ch=128, seq_len=3, mode='unet_convlstm'):
        super().__init__()
        self.mode = mode
        self.seq_len = seq_len

        def double_conv(in_c, out_c):
            # Two 3x3 conv + BatchNorm + ReLU stages; spatial size is preserved.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, 1, 1), nn.BatchNorm2d(out_c), nn.ReLU(True),
                nn.Conv2d(out_c, out_c, 3, 1, 1), nn.BatchNorm2d(out_c), nn.ReLU(True)
            )

        self.inc = double_conv(in_ch, base_ch)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), double_conv(base_ch, base_ch * 2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), double_conv(base_ch * 2, base_ch * 4))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), double_conv(base_ch * 4, base_ch * 8))

        if 'lstm' in mode:
            self.bottleneck = ConvLSTMCell(base_ch * 8, lstm_ch, (3, 3), True)
        else:
            self.bottleneck = double_conv(base_ch * 8, lstm_ch)
        # Both bottleneck variants emit lstm_ch channels.
        bot_out_ch = lstm_ch

        use_skips = ('unet' in mode)
        self.up1 = UpBlock(bot_out_ch, base_ch * 4, use_skips)
        self.up2 = UpBlock(base_ch * 4, base_ch * 2, use_skips)
        self.up3 = UpBlock(base_ch * 2, base_ch, use_skips)
        self.outc = nn.Conv2d(base_ch, out_ch, 1)

    def _encode(self, x):
        """Run the shared 2D encoder and return the four stage outputs."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        return x1, x2, x3, x4

    def forward(self, x):
        """Map the input to an (B, out_ch, H', W') tensor.

        Args:
            x: (B, T, C, H, W) when the mode contains ``'lstm'``, otherwise
                (B, C, H, W).

        Returns:
            Output of the final 1x1 convolution.
        """
        if 'lstm' in self.mode:
            b, t, c, h, w = x.shape
            # reshape (not view) so non-contiguous inputs, e.g. permuted or
            # sliced tensors, are accepted as well; frames are folded into the
            # batch axis so the 2D encoder processes them all at once.
            x1, x2, x3, x4 = self._encode(x.reshape(b * t, c, h, w))

            # Unfold the time axis again and run the cell over the sequence.
            lstm_in = x4.reshape(b, t, *x4.shape[1:])
            h_state, c_state = self.bottleneck.init_hidden(b, (x4.shape[-2], x4.shape[-1]))
            for t_step in range(t):
                h_state, c_state = self.bottleneck(lstm_in[:, t_step], (h_state, c_state))
            bot_out = h_state

            # Skip connections use the features of the last input frame only.
            s1, s2, s3 = (
                feat.reshape(b, t, *feat.shape[1:])[:, -1] for feat in (x1, x2, x3)
            )
        else:
            s1, s2, s3, x4 = self._encode(x)
            bot_out = self.bottleneck(x4)

        x = self.up1(bot_out, s3)
        x = self.up2(x, s2)
        x = self.up3(x, s1)
        return self.outc(x)

class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample followed by two conv+BN+ReLU layers.

    Args:
        in_c: channels of the tensor coming from the previous (coarser) stage.
        out_c: channels produced by this stage. When ``use_skips`` is True the
            skip tensor is expected to have ``out_c`` channels as well, since
            the first convolution takes ``in_c + out_c`` input channels.
        use_skips: concatenate the encoder skip tensor before the convolutions.
    """

    def __init__(self, in_c, out_c, use_skips):
        super().__init__()
        self.use_skips = use_skips
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_c + (out_c if use_skips else 0), out_c, 3, 1, 1),
            nn.BatchNorm2d(out_c), nn.ReLU(True),
            nn.Conv2d(out_c, out_c, 3, 1, 1),
            nn.BatchNorm2d(out_c), nn.ReLU(True)
        )

    def forward(self, x1, x2):
        """Upsample ``x1`` and (optionally) fuse it with the skip tensor ``x2``.

        Args:
            x1: (B, in_c, h, w) features from the coarser stage.
            x2: skip features at the target resolution; ignored when the block
                was built with ``use_skips=False``.

        Returns:
            (B, out_c, H, W) features.
        """
        x1 = self.up(x1)
        if not self.use_skips:
            # Without skips the upsampled tensor goes straight into the convs.
            # (Previously no tensor was bound on this path and the call raised
            # UnboundLocalError.)
            return self.conv(x1)

        # Pad the upsampled tensor so it matches the skip tensor when the
        # encoder's pooling floored an odd spatial size.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        return self.conv(torch.cat([x2, x1], dim=1))

In [None]:
def uncertainty_loss(pred_mean, pred_logvar, target):
    """Mean per-element loss ``0.5 * exp(-logvar) * (mean - target)^2 + 0.5 * logvar``.

    Args:
        pred_mean: predicted values.
        pred_logvar: predicted log-variance, broadcastable with ``pred_mean``.
        target: reference values, broadcastable with ``pred_mean``.

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    residual = pred_mean - target
    inv_var = torch.exp(-pred_logvar)
    # Squared error weighted by the predicted precision ...
    data_term = 0.5 * inv_var * residual ** 2
    # ... plus the term that penalizes large predicted variance.
    variance_term = 0.5 * pred_logvar
    return torch.mean(data_term + variance_term)

def train_and_evaluate(mode, train_loader, val_loader, stats, config):
    """Train one ThesisVariant and validate it after every epoch.

    Args:
        mode: variant name passed to ThesisVariant; ``"standard_unet"``
            additionally folds the time axis of each batch into channels.
        train_loader: iterable of (X, Y) training batches.
        val_loader: iterable of (X, Y) validation batches, passed to evaluate.
        stats: normalization statistics forwarded to evaluate.
        config: dict providing 'seq_len', 'base_ch', 'lstm_ch', 'lr', 'epochs'.

    Returns:
        (history, model): ``history`` is a list with one dict per epoch
        (keys "epoch", "val_rmse", "val_nll", "train_loss"); ``model`` is the
        network after the final epoch.
    """
    print(f"\n[{mode.upper()}] Starting Training...")

    # The plain U-Net has no time axis, so its frames are stacked as channels.
    stack_frames = (mode == "standard_unet")
    in_ch = 2 * config['seq_len'] if stack_frames else 2

    # out_ch=4: two mean channels followed by two log-variance channels.
    model = ThesisVariant(
        in_ch=in_ch,
        out_ch=4,
        base_ch=config['base_ch'],
        lstm_ch=config['lstm_ch'],
        seq_len=config['seq_len'],
        mode=mode,
    )
    model = model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    history = []

    last_epoch = config['epochs']
    for epoch in range(1, last_epoch + 1):
        model.train()
        total_loss = 0

        # Per-epoch progress bar; removed from the output once the epoch ends.
        pbar = tqdm(train_loader, desc=f"Ep {epoch}", leave=False)
        for X, Y in pbar:
            X = X.to(DEVICE)
            Y = Y.to(DEVICE)

            if stack_frames:
                b, t, c, h, w = X.shape
                X = X.view(b, t * c, h, w)

            optimizer.zero_grad()
            out = model(X)
            loss = uncertainty_loss(out[:, :2], out[:, 2:], Y)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            pbar.set_postfix(loss=batch_loss)

        # Validation
        val_rmse, val_nll = evaluate(model, val_loader, stats, mode)

        epoch_record = {
            "epoch": epoch,
            "val_rmse": val_rmse,
            "val_nll": val_nll,
            "train_loss": total_loss / len(train_loader)
        }
        history.append(epoch_record)

        # Only report every fifth epoch to keep the cell output short.
        if epoch % 5 != 0:
            continue
        print(f"Ep {epoch}: Train Loss={total_loss/len(train_loader):.3f} | Val RMSE={val_rmse:.3f} | Val NLL={val_nll:.3f}")

    return history, model

def evaluate(model, loader, stats, mode):
    """Compute validation RMSE (de-normalized units) and mean loss.

    The RMSE is the square root of the mean squared error over every u/v
    value in the loader. Squared errors are accumulated across batches and a
    single square root is taken at the end; averaging per-batch RMSE values
    would not give the RMSE of the whole set.

    Args:
        model: network returning (B, 4, H, W): channels 0-1 are the predicted
            means, channels 2-3 the predicted log-variances.
        loader: iterable of (X, Y) batches.
        stats: normalization statistics passed to invert_zscore.
        mode: ``"standard_unet"`` folds the time axis of X into channels.

    Returns:
        (rmse, nll) as Python floats; ``nll`` is the sample-weighted mean of
        uncertainty_loss in normalized space. Both are NaN for an empty loader.
    """
    model.eval()
    sq_err_sum = 0.0   # sum of squared errors after de-normalization
    n_values = 0       # number of scalar values contributing to sq_err_sum
    nll_accum = 0.0
    n_samples = 0

    with torch.no_grad():
        for X, Y in loader:
            X, Y = X.to(DEVICE), Y.to(DEVICE)
            if mode == "standard_unet":
                b, t, c, h, w = X.shape
                X = X.reshape(b, t * c, h, w)

            out = model(X)
            pred_mean, pred_logvar = out[:, :2], out[:, 2:]
            batch_size = X.size(0)

            # NLL (loss in normalized space), weighted by batch size so a
            # smaller final batch does not get the weight of a full one.
            loss = uncertainty_loss(pred_mean, pred_logvar, Y)
            nll_accum += loss.item() * batch_size
            n_samples += batch_size

            # Component-wise error after z-score inversion; accumulate in
            # float64 to limit rounding when summing many float32 values.
            pred_phys = invert_zscore(pred_mean, stats).astype(np.float64)
            true_phys = invert_zscore(Y, stats).astype(np.float64)
            sq_err = (pred_phys - true_phys) ** 2
            sq_err_sum += float(sq_err.sum())
            n_values += sq_err.size

    if n_samples == 0 or n_values == 0:
        # Nothing to evaluate: report NaN instead of dividing by zero.
        return float("nan"), float("nan")

    rmse = float(np.sqrt(sq_err_sum / n_values))
    return rmse, nll_accum / n_samples

In [None]:
# 1. Learning Curves
# Relies on MODES, results, models, val_ds and stats produced by the training
# cell; run that cell first.
plt.figure(figsize=(10, 5))
labels = {
    "unet_convlstm": "Proposed (U-Net + LSTM)",
    "cnn_convlstm": "No Skips (CNN + LSTM)",
    "standard_unet": "No Memory (Standard U-Net)"
}

for mode in MODES:
    # Skip variants that were not trained (e.g. after an interrupted run).
    if mode in results and len(results[mode]) > 0:
        df = pd.DataFrame(results[mode])
        plt.plot(df['epoch'], df['val_rmse'], label=labels.get(mode, mode), marker='.')

plt.title("Thesis Model Comparison: RMSE")
plt.xlabel("Epoch")
plt.ylabel("RMSE (m/s)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# 2. Difference Field (Qualitative)
sample_idx = 10
X_sample, Y_sample = val_ds[sample_idx]

# Ensure inputs are Tensors
if isinstance(X_sample, np.ndarray):
    X_in = torch.from_numpy(X_sample).float().unsqueeze(0).to(DEVICE)
else:
    X_in = X_sample.float().unsqueeze(0).to(DEVICE)

# Get Ground Truth Speed (Use invert_zscore)
if isinstance(Y_sample, torch.Tensor):
    Y_sample_np = Y_sample.cpu().numpy()
else:
    Y_sample_np = np.asarray(Y_sample)

# invert_zscore is written for batched (N, 2, H, W) data, so a single
# (2, H, W) field is given a leading batch axis and unwrapped afterwards.
# Passing the un-batched field would rescale image rows instead of channels.
Y_phys = invert_zscore(Y_sample_np[None], stats)[0]
true_speed = np.sqrt(Y_phys[0]**2 + Y_phys[1]**2)

# Setup Figure: one column for the ground truth plus one per variant.
n_cols = len(MODES) + 1
fig, axes = plt.subplots(2, n_cols, figsize=(18, 8), constrained_layout=True, squeeze=False)
for ax in axes.ravel():
    # Panels that stay empty (untrained variants) should not show bare frames.
    ax.axis('off')
max_speed_disp = np.max(true_speed)
# Error panels share a colour range of half the maximum true speed.
max_error_disp = max_speed_disp * 0.5

# Plot Ground Truth
im_gt = axes[0,0].imshow(true_speed, cmap='viridis', vmin=0, vmax=max_speed_disp)
axes[0,0].set_title("Ground Truth\n(Current Speed)")
fig.colorbar(im_gt, ax=axes[0,0], fraction=0.046, pad=0.04, label="m/s")

for i, mode in enumerate(MODES):
    if mode not in models:
        continue
    model = models[mode]
    model.eval()

    if mode == "standard_unet":
        # Fold the time axis into channels for the non-recurrent variant.
        inp = X_in.reshape(1, -1, X_in.shape[-2], X_in.shape[-1])
    else:
        inp = X_in

    with torch.no_grad():
        out = model(inp)

    # Keep the batch axis for invert_zscore (see note above), then unwrap.
    pred_phys = invert_zscore(out[:1, :2], stats)[0]
    pred_speed = np.sqrt(pred_phys[0]**2 + pred_phys[1]**2)
    diff = np.abs(true_speed - pred_speed)
    title = labels.get(mode, mode)

    # Plot Prediction
    im_pred = axes[0, i+1].imshow(pred_speed, cmap='viridis', vmin=0, vmax=max_speed_disp)
    axes[0, i+1].set_title(f"{title}\nPrediction")

    # Plot Difference
    im_err = axes[1, i+1].imshow(diff, cmap='inferno', vmin=0, vmax=max_error_disp)
    axes[1, i+1].set_title(f"Error Difference\n{title}")
    cb = fig.colorbar(im_err, ax=axes[1, i+1], fraction=0.046, pad=0.04)
    cb.set_label("Error (m/s)")

plt.suptitle(f"Comparison at Sample {sample_idx}", fontsize=16)
plt.show()

Loading dataset...
Data Shape: (1045, 2, 64, 64) (Time, Channels, H, W)
Computing Z-Score Stats on Train set...
  U mean=0.1161, std=0.1591
  V mean=-0.0443, std=0.1502

[UNET_CONVLSTM] Starting Training...


KeyboardInterrupt: 

In [None]:
# NOTE(review): this cell repeats the previous plotting cell in a simpler
# layout and is a candidate for deletion. It previously called invert_minmax,
# which is not defined in this notebook; the data is z-score normalized, so
# invert_zscore is used instead.

# 1. Learning Curves
plt.figure(figsize=(10, 5))
labels = {
    "unet_convlstm": "Proposed (U-Net + LSTM)",
    "cnn_convlstm": "No Skips (CNN + LSTM)",
    "standard_unet": "No Memory (Standard U-Net)"
}

for mode in MODES:
    # Skip variants without recorded history (e.g. after an interrupted run).
    if mode not in results or len(results[mode]) == 0:
        continue
    df = pd.DataFrame(results[mode])
    plt.plot(df['epoch'], df['val_rmse'], label=labels[mode], marker='.')

plt.title("Thesis Model Comparison: RMSE")
plt.xlabel("Epoch")
plt.ylabel("RMSE (m/s)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# 2. Difference Field (Qualitative)
# Pick a specific sample to visualize
sample_idx = 10
X_sample, Y_sample = val_ds[sample_idx]
# as_tensor accepts both arrays and tensors without the copy-construct warning
# that torch.tensor(tensor) emits; float() matches the model's parameter dtype.
X_in = torch.as_tensor(X_sample).float().unsqueeze(0).to(DEVICE) # (1, T, 2, H, W)

# Get Ground Truth Speed
# invert_zscore is written for batched (N, 2, H, W) data: add a batch axis for
# the call and drop it again afterwards.
Y_phys = invert_zscore(Y_sample[None], stats)[0]
true_speed = np.sqrt(Y_phys[0]**2 + Y_phys[1]**2)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Plot GT
axes[0,0].imshow(true_speed, cmap='viridis')
axes[0,0].set_title("Ground Truth")
axes[1,0].axis('off')

for i, mode in enumerate(MODES):
    if mode not in models:
        continue
    model = models[mode]
    model.eval()

    # Prepare input
    if mode == "standard_unet":
        # Fold the time axis into channels for the non-recurrent variant.
        inp = X_in.reshape(1, -1, X_in.shape[-2], X_in.shape[-1])
    else:
        inp = X_in

    with torch.no_grad():
        out = model(inp)

    # Phys conversion (batch axis kept for invert_zscore, then unwrapped)
    pred_phys = invert_zscore(out[:1, :2], stats)[0]
    pred_speed = np.sqrt(pred_phys[0]**2 + pred_phys[1]**2)
    diff = np.abs(true_speed - pred_speed)

    # Plot Pred
    axes[0, i+1].imshow(pred_speed, cmap='viridis')
    axes[0, i+1].set_title(labels[mode])

    # Plot Diff
    axes[1, i+1].imshow(diff, cmap='inferno')
    axes[1, i+1].set_title(f"Error ({mode})")

plt.tight_layout()
plt.show()