In [None]:
# ============================================================
# Cell 1: Imports & Data (same as your existing cells 1-6)
# ============================================================
import sys
sys.path.append(".")

import os, json, time
import numpy as np
import torch
from torch.utils.data import DataLoader
from src.data.cmems_dataset import load_cmems_uv
from src.models.unet_convlstm_unc import UNetConvLSTMUncertainty, gaussian_nll
from src.data.cmems_dataset_ar import SlidingWindowMultiStep

# --- Data loading (same as before) ---
# The dataset location can be overridden with the CMEMS_DATA_PATH environment
# variable, so the notebook can run on another machine without editing this
# cell. When the variable is unset, the original local path is used.
data_path = os.environ.get(
    "CMEMS_DATA_PATH",
    "/home/svillhauer/Desktop/Thesis/Currents/deep_spatiotemporal_currents_uncertainty/cmems_mod_glo_phy_anfc_merged-uv_PT1H-i_1771446846329.nc",
)

# Read the "utotal"/"vtotal" variables at depth index 0, regridded to 64x64.
# load_cmems_uv returns the velocity array plus time/lat/lon coordinates.
uv, time_arr, lat, lon = load_cmems_uv(
    data_path, u_var="utotal", v_var="vtotal",
    depth_index=0, regrid_hw=(64, 64)
)
print(uv.shape)

# --- Z-score normalisation ---
class ZScoreStats:
    """Plain holder for the mean / standard deviation of the u and v components.

    The four values are stored as ordinary instance attributes, so
    ``instance.__dict__`` yields them as a dict in the order
    ``u_mean, u_std, v_mean, v_std``.
    """

    def __init__(self, u_mean, u_std, v_mean, v_std):
        # u component statistics
        self.u_mean, self.u_std = u_mean, u_std
        # v component statistics
        self.v_mean, self.v_std = v_mean, v_std

def compute_zscore(uv):
    """Compute global mean/std of each velocity component.

    ``uv[:, 0]`` is treated as the u component and ``uv[:, 1]`` as the v
    component. Each standard deviation has 1e-8 added so a later division by
    it cannot hit an exact zero.

    Returns:
        ZScoreStats holding Python floats (u_mean, u_std, v_mean, v_std).
    """
    moments = []
    for channel_idx in (0, 1):  # 0 -> u, 1 -> v
        channel = uv[:, channel_idx]
        moments.append(float(np.mean(channel)))
        moments.append(float(np.std(channel) + 1e-8))
    return ZScoreStats(*moments)

def apply_zscore(uv, stats):
    """Return a z-scored copy of ``uv``; the input array is left untouched.

    Args:
        uv: array whose axis 1 holds the u component at index 0 and the v
            component at index 1.
        stats: object exposing ``u_mean``, ``u_std``, ``v_mean``, ``v_std``.

    Returns:
        A copy of ``uv`` with each component shifted by its mean and divided
        by its standard deviation.
    """
    normalised = uv.copy()
    per_channel = (
        (stats.u_mean, stats.u_std),  # channel 0: u
        (stats.v_mean, stats.v_std),  # channel 1: v
    )
    for channel_idx, (mean, std) in enumerate(per_channel):
        normalised[:, channel_idx] = (normalised[:, channel_idx] - mean) / std
    return normalised

# Window configuration: input frames per sample and number of target steps.
seq_len = 12
max_horizon = 12

# Chronological split: the first 70% of time steps train, the remainder validates.
split_t = int(len(uv) * 0.7)
uv_train = uv[:split_t]
uv_val = uv[split_t:]

# Statistics come from the training portion only and are reused for validation.
stats = compute_zscore(uv_train)
uv_train_n, uv_val_n = (apply_zscore(part, stats) for part in (uv_train, uv_val))

# Both datasets share the same sliding-window settings.
window_kwargs = dict(seq_len=seq_len, max_horizon=max_horizon)
train_ds = SlidingWindowMultiStep(uv_train_n, **window_kwargs)
val_ds = SlidingWindowMultiStep(uv_val_n, **window_kwargs)

print(f"Train: {len(train_ds)} samples,  Val: {len(val_ds)} samples")

In [None]:
# ============================================================
# Cell 2: Training functions
# ============================================================

def ar_forward_loss(model, X, Y_seq, ar_steps):
    """Autoregressive rollout with NLL and MSE accumulated over the steps.

    At each step the model is called on the current input window and returns
    ``(mu, logvar)``. The predicted mean ``mu`` is appended along dim 1 while
    the oldest frame is dropped, so later steps are conditioned on the model's
    own predictions. No ``detach`` is applied, so gradients flow through the
    whole rollout.

    Args:
        model: callable returning ``(mu, logvar)`` for an input window.
        X: initial input window; frames are stacked along dim 1.
        Y_seq: targets; ``Y_seq[:, t]`` is the target for rollout step ``t``.
        ar_steps: number of rollout steps, ``1 <= ar_steps <= Y_seq.shape[1]``.

    Returns:
        Tuple ``(nll, rmse)`` of scalar tensors: the Gaussian NLL averaged over
        the steps, and the square root of the per-step MSE averaged over the
        steps.

    Raises:
        ValueError: if ``ar_steps`` is smaller than 1 (the averages would
            otherwise be 0/0 = NaN with nothing to back-propagate).
        IndexError: if ``ar_steps`` exceeds the number of target steps in
            ``Y_seq``.
    """
    if ar_steps < 1:
        raise ValueError(f"ar_steps must be >= 1, got {ar_steps}")
    horizon = Y_seq.shape[1]
    if ar_steps > horizon:
        # Fail before running any forward pass instead of midway through the rollout.
        raise IndexError(
            f"ar_steps={ar_steps} exceeds the {horizon} target steps available in Y_seq"
        )

    current_input = X
    total_nll = torch.tensor(0.0, device=X.device)
    total_mse = torch.tensor(0.0, device=X.device)

    for t in range(ar_steps):
        mu, logvar = model(current_input)
        target = Y_seq[:, t]
        total_nll = total_nll + gaussian_nll(mu, logvar, target)
        total_mse = total_mse + torch.mean((mu - target) ** 2)

        # Slide the window: drop the oldest frame, append the prediction.
        next_frame = mu.unsqueeze(1)
        current_input = torch.cat([current_input[:, 1:], next_frame], dim=1)

    return total_nll / ar_steps, torch.sqrt(total_mse / ar_steps)


@torch.no_grad()
def validate(model, loader, ar_steps, device):
    """Evaluate the model on ``loader`` with an ``ar_steps``-long rollout.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: network passed through to ``ar_forward_loss``.
        loader: iterable yielding ``(X, Y_seq)`` batches.
        ar_steps: number of autoregressive steps per batch.
        device: device the batches are moved to.

    Returns:
        Tuple ``(nll, rmse)`` of Python floats. ``nll`` is the batch-size
        weighted mean of the per-batch NLL (assumes the per-batch NLL is itself
        a mean over samples -- TODO confirm against gaussian_nll). ``rmse`` is
        the square root of the batch-size weighted mean squared error. Both are
        0.0 when the loader yields no batches.
    """
    model.eval()
    nll_sum, sq_err_sum, n = 0.0, 0.0, 0
    for X, Y_seq in loader:
        X, Y_seq = X.to(device), Y_seq.to(device)
        nll, batch_rmse = ar_forward_loss(model, X, Y_seq, ar_steps)
        bs = X.shape[0]
        nll_sum += nll.item() * bs
        # ar_forward_loss returns an RMSE per batch. Averaging RMSE values
        # directly underestimates the dataset RMSE, so pool in squared-error
        # space and take the square root once at the end.
        sq_err_sum += (batch_rmse.item() ** 2) * bs
        n += bs
    denom = max(n, 1)
    return nll_sum / denom, (sq_err_sum / denom) ** 0.5

In [None]:
# ============================================================
# Cell 3: Curriculum training loop
# ============================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable from a fresh kernel.
seed = 42
torch.manual_seed(seed)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False)

model = UNetConvLSTMUncertainty(base_ch=32, lstm_ch=256).to(device)
opt   = torch.optim.Adam(model.parameters(), lr=1e-3)

# --- Config ---
run_dir = "runs/curriculum_unc"
os.makedirs(run_dir, exist_ok=True)
history_path = os.path.join(run_dir, "history.jsonl")
# Start every run with a fresh history file.
if os.path.exists(history_path):
    os.remove(history_path)

curriculum       = [1, 2, 4, 8, 12]   # AR steps per stage
epochs_per_stage = 100
val_every        = 5
grad_clip        = 1.0
lr_decay         = 0.5                 # multiply LR by this at each new stage

best_val_rmse = float("inf")
global_epoch  = 0


def _save_checkpoint(state, path):
    """Write ``state`` to ``path`` through a temp file and an atomic rename.

    If the kernel dies during the write, the previous checkpoint at ``path``
    is still intact instead of being left truncated.
    """
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


for stage_idx, ar_steps in enumerate(curriculum):
    print(f"\n{'='*60}")
    print(f"STAGE {stage_idx}: ar_steps = {ar_steps}")
    print(f"{'='*60}")

    # Decay LR after first stage
    if stage_idx > 0:
        for pg in opt.param_groups:
            pg["lr"] *= lr_decay
        print(f"  LR -> {opt.param_groups[0]['lr']:.2e}")

    for local_epoch in range(epochs_per_stage):
        t0 = time.time()
        model.train()  # validate() leaves the model in eval mode
        train_nll_sum, n_train = 0.0, 0

        for X, Y_seq in train_loader:
            X, Y_seq = X.to(device), Y_seq.to(device)

            nll, _ = ar_forward_loss(model, X, Y_seq, ar_steps)

            opt.zero_grad(set_to_none=True)
            nll.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            bs = X.shape[0]
            train_nll_sum += nll.item() * bs
            n_train += bs

        train_nll = train_nll_sum / max(n_train, 1)
        elapsed = time.time() - t0

        # --- Validate ---
        # Every `val_every` epochs, and always on the last epoch of a stage.
        do_val = (local_epoch % val_every == 0) or (local_epoch == epochs_per_stage - 1)
        if do_val:
            val_nll, val_rmse = validate(model, val_loader, ar_steps, device)
        else:
            val_nll, val_rmse = float("nan"), float("nan")

        # --- Print ---
        if do_val:
            print(f"  [{global_epoch:04d}] ar={ar_steps} train_nll={train_nll:.6f} "
                  f"val_nll={val_nll:.6f} val_rmse={val_rmse:.6f} ({elapsed:.1f}s)", flush=True)
        else:
            print(f"  [{global_epoch:04d}] ar={ar_steps} train_nll={train_nll:.6f} ({elapsed:.1f}s)", flush=True)

        # --- Save history (appended every epoch -- safe if kernel dies) ---
        row = {
            "global_epoch": global_epoch, "stage": stage_idx, "ar_steps": ar_steps,
            "train_nll": float(train_nll),
            "val_nll": float(val_nll) if do_val else None,
            "val_rmse": float(val_rmse) if do_val else None,
            "lr": float(opt.param_groups[0]["lr"]),
        }
        with open(history_path, "a") as f:
            f.write(json.dumps(row) + "\n")

        # --- Best tracking ---
        # Update the running best BEFORE building the checkpoint dict, so the
        # "best_val_rmse" stored in best.pt / last.pt reflects this epoch
        # rather than the previous best.
        # NOTE(review): val_rmse is measured at the current stage's ar_steps and
        # best_val_rmse is never reset between stages, so values from different
        # stages are compared directly. best.pt is therefore likely to come
        # from an early, short-horizon stage -- consider validating at a fixed
        # horizon or tracking the best per stage.
        is_best = do_val and val_rmse < best_val_rmse
        if is_best:
            best_val_rmse = val_rmse

        # --- Checkpoints ---
        ckpt = {
            "global_epoch": global_epoch, "stage": stage_idx, "ar_steps": ar_steps,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "best_val_rmse": best_val_rmse if best_val_rmse < float("inf") else None,
            "stats": stats.__dict__, "seq_len": seq_len, "max_horizon": max_horizon,
            "curriculum": curriculum,
        }
        _save_checkpoint(ckpt, os.path.join(run_dir, "last.pt"))

        if is_best:
            _save_checkpoint(ckpt, os.path.join(run_dir, "best.pt"))
            print(f"    -> New best val_rmse={val_rmse:.6f}")

        global_epoch += 1

    # End-of-stage checkpoint (state after the final epoch of this stage)
    _save_checkpoint(ckpt, os.path.join(run_dir, f"stage_{stage_idx}_ar{ar_steps}.pt"))
    print(f"  Saved stage checkpoint: stage_{stage_idx}_ar{ar_steps}.pt")

print(f"\nDone! Best val_rmse = {best_val_rmse:.6f}")