In [100]:
# --- Colab / environment installs (safe to re-run) ---
!pip -q install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install numpy pandas scikit-learn einops tqdm
!pip -q install kagglehub tensorflow   # only needed if you convert TFRecords → NPZ

import os, math, glob, random
from dataclasses import dataclass
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.metrics import average_precision_score, precision_recall_curve

# --- Device setup ---
device   = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
print("Device:", device)

if use_cuda:
    torch.set_float32_matmul_precision('high')  # optimize GEMM performance
torch.backends.cudnn.benchmark = True           # accelerates convs on fixed input size (e.g., 64×64)

# --- Reproducibility ---
def set_seed(seed=1337):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(1337)


Device: cuda


In [101]:
# Set where you want the NPZ tiles
NPZ_ROOT = "/content/wildfire_npz_tiles_kaggle_v3"
os.makedirs(NPZ_ROOT, exist_ok=True)

def have_npz(root):
    return len(glob.glob(os.path.join(root, "*.npz"))) > 0

if not have_npz(NPZ_ROOT):
    print("No NPZ tiles found — converting from Kaggle TFRecords...")
    import kagglehub, tensorflow as tf

    path = kagglehub.dataset_download("fantineh/next-day-wildfire-spread")
    print("Kaggle dataset path:", path)
    tfrecs = sorted(glob.glob(os.path.join(path, "*.tfrecord")))
    assert len(tfrecs) > 0, "No TFRecords found in Kaggle dataset."

    keys = ['tmmn','NDVI','FireMask','population','elevation','vs','pdsi','pr','tmmx','sph','th','PrevFireMask','erc']

    def read_flat_float32(feat):
        fl = feat.float_list.value
        if len(fl) == 0: return None
        arr = np.asarray(fl, dtype=np.float32)
        if arr.size == 4096:  # 64x64
            return arr.reshape(64,64)
        s = int(round(math.sqrt(arr.size)))
        assert s*s == arr.size, f"Unexpected length {arr.size}"
        return arr.reshape(s, s)

    def wind_uv(vs, th):
        th_rad = th.copy()
        if np.nanmax(th_rad) > 6.4:  # degrees→radians
            th_rad = np.deg2rad(th_rad % 360.0)
        u = vs * np.cos(th_rad); v = vs * np.sin(th_rad)
        return u.astype(np.float32), v.astype(np.float32)

    def slope_aspect_from_elevation(z):
        gy, gx = np.gradient(z.astype(np.float32))
        mag = np.sqrt(gx**2 + gy**2)
        q95 = np.percentile(mag, 95) + 1e-6
        slope  = np.clip(mag / q95, 0, 1).astype(np.float32)
        aspect = np.arctan2(-gy, -gx).astype(np.float32)  # downslope dir -> upslope = opposite
        return slope, aspect

    def rh_proxy_from_sph(sph, tmmn, tmmx):
        s95 = np.percentile(sph, 95) + 1e-6
        rh = np.clip(sph / s95, 0, 1)
        tr = np.clip((tmmx - tmmn), 0, 30) / 30.0
        rh = np.clip(rh * (1.0 - 0.5*tr), 0, 1).astype(np.float32)
        return rh

    def barrier_from_population(pop):
        pop = np.clip(pop, 0, None).astype(np.float32)
        thr = np.percentile(pop, 90)
        return (pop >= thr).astype(np.float32)

    converted = 0
    for f in tqdm(tfrecs, desc="Converting TFRecords → NPZ"):
        for raw in tf.data.TFRecordDataset(f):
            ex = tf.train.Example.FromString(raw.numpy()).features.feature
            A = {k: read_flat_float32(ex[k]) if k in ex else None for k in keys}
            if A["PrevFireMask"] is None or A["FireMask"] is None:
                continue

            prev_fire = (A["PrevFireMask"] > 0.5).astype(np.float32)
            next_fire = (A["FireMask"]    > 0.5).astype(np.float32)
            tmmn = A["tmmn"] if A["tmmn"] is not None else np.full((64,64), 290, np.float32)
            tmmx = A["tmmx"] if A["tmmx"] is not None else np.full((64,64), 300, np.float32)
            temp = ((tmmn + tmmx)/2.0).astype(np.float32)

            vs = A["vs"] if A["vs"] is not None else np.zeros((64,64), np.float32)
            th = A["th"] if A["th"] is not None else np.zeros((64,64), np.float32)
            u, v = wind_uv(vs, th)

            ndvi = A["NDVI"] if A["NDVI"] is not None else np.full((64,64), 0.5, np.float32)
            if ndvi.min() < 0:  # [-1,1] → [0,1]
                ndvi = np.clip((ndvi + 1.0)/2.0, 0, 1).astype(np.float32)

            sph = A["sph"] if A["sph"] is not None else np.zeros((64,64), np.float32)
            rh  = rh_proxy_from_sph(sph, tmmn, tmmx)

            elev = A["elevation"] if A["elevation"] is not None else np.zeros((64,64), np.float32)
            slope, aspect = slope_aspect_from_elevation(elev)

            pop = A["population"] if A["population"] is not None else np.zeros((64,64), np.float32)
            barrier = barrier_from_population(pop)

            fields = dict(
                prev_fire=prev_fire, next_fire=next_fire,
                u=u, v=v, temp=temp, rh=rh, ndvi=ndvi,
                slope=slope, aspect=aspect, barrier=barrier
            )

            sid_feat = ex.get("sample_id", None)
            if sid_feat and len(sid_feat.bytes_list.value) > 0:
                sid = sid_feat.bytes_list.value[0].decode("utf-8")
            else:
                sid = f"{os.path.basename(f)}_{converted:07d}"

            np.savez(os.path.join(NPZ_ROOT, f"{sid}.npz"), **fields)
            converted += 1

    print(f"Converted {converted} tiles → {NPZ_ROOT}")
else:
    print(f"Using existing NPZ tiles at {NPZ_ROOT} (found {len(glob.glob(os.path.join(NPZ_ROOT, '*.npz')))} files)")


Using existing NPZ tiles at /content/wildfire_npz_tiles_kaggle_v3 (found 18545 files)


In [102]:
from dataclasses import dataclass
import numpy as np, torch, glob, os
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import gc

# -------------------------------
# Split files ONCE outside the dataset
# -------------------------------
def make_splits(root, seed=1337, train_frac=0.70, val_frac=0.15):
    files = sorted(glob.glob(os.path.join(root, "*.npz")))
    assert len(files) > 0, f"No .npz files found in {root}"
    rng = np.random.default_rng(seed)
    rng.shuffle(files)
    n = len(files)
    n_train = int(round(train_frac * n))
    n_val   = int(round(val_frac * n))
    train_files = files[:n_train]
    val_files   = files[n_train:n_train + n_val]
    test_files  = files[n_train + n_val:] if (n_train + n_val) < n else files[-1:]
    print(f"Split counts → train: {len(train_files)} | val: {len(val_files)} | test: {len(test_files)}")
    return train_files, val_files, test_files


@dataclass
class WildfirePaths:
    root: str


# -------------------------------
# Dataset definition (uses file list)
# -------------------------------
class WildfireDataset(Dataset):
    """
    Returns:
      X_raw: (9,H,W) in order [prev,u,v,temp,rh,ndvi,slope,aspect,barrier]
      y:     (1,H,W) next_fire mask
    """
    def __init__(self, files):
        self.files = files
        if len(self.files) == 0:
            raise ValueError("Empty file list")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        arr = np.load(self.files[i], allow_pickle=False, mmap_mode="r")
        req = ["prev_fire","next_fire","u","v","temp","rh","ndvi","slope","aspect"]
        missing = [k for k in req if k not in arr]
        if missing:
            raise KeyError(f"{os.path.basename(self.files[i])} missing {missing}")
        prev_fire = arr["prev_fire"][None,...].astype(np.float32)
        next_fire = arr["next_fire"][None,...].astype(np.float32)
        u   = arr["u"][None,...].astype(np.float32)
        v   = arr["v"][None,...].astype(np.float32)
        temp= arr["temp"][None,...].astype(np.float32)
        rh  = arr["rh"][None,...].astype(np.float32)
        ndvi= arr["ndvi"][None,...].astype(np.float32)
        slope = arr["slope"][None,...].astype(np.float32)
        aspect= arr["aspect"][None,...].astype(np.float32)
        barrier = (arr["barrier"][None,...].astype(np.float32)
                   if "barrier" in arr else np.zeros_like(prev_fire))
        X_raw = np.concatenate([prev_fire,u,v,temp,rh,ndvi,slope,aspect,barrier], axis=0)
        y = next_fire
        return {"X_raw": torch.from_numpy(X_raw), "y": torch.from_numpy(y)}


# -------------------------------
# Build datasets + loaders
# -------------------------------
paths = WildfirePaths(NPZ_ROOT)
train_files, val_files, test_files = make_splits(paths.root, seed=1337)

train_ds = WildfireDataset(train_files)
val_ds   = WildfireDataset(val_files)
test_ds  = WildfireDataset(test_files)

# --- Colab-safe DataLoader config ---
NUM_WORKERS = 0
PIN_MEMORY  = use_cuda
PERSISTENT  = False

# Clean up old loaders
try:
    del train_loader, val_loader, test_loader
except NameError:
    pass
gc.collect()
if use_cuda:
    torch.cuda.empty_cache()


def make_loader(ds, batch_size=16, upweight_positive=False, shuffle=False):
    if upweight_positive:
        weights = [5.0 if np.load(f, mmap_mode="r")["next_fire"].sum() > 0 else 1.0 for f in ds.files]
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        return DataLoader(ds, batch_size=batch_size, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, persistent_workers=PERSISTENT)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, persistent_workers=PERSISTENT)


train_loader = make_loader(train_ds, batch_size=16, upweight_positive=True)
val_loader   = make_loader(val_ds,   batch_size=16)
test_loader  = make_loader(test_ds,  batch_size=16)

print(f"Loaders ready → Train {len(train_ds)} | Val {len(val_ds)} | Test {len(test_ds)}")


Split counts → train: 12982 | val: 2782 | test: 2781
Loaders ready → Train 12982 | Val 2782 | Test 2781


In [103]:
@torch.no_grad()
def compute_channel_stats(ds, n_max=None):
    """
    Compute mean and std across 9 input channels:
      [prev,u,v,temp,rh,ndvi,slope,aspect,barrier]
    Skips normalization for channel 0 (prev_fire) and 8 (barrier).
    """
    sums = np.zeros(9, dtype=np.float64)
    sqs  = np.zeros(9, dtype=np.float64)
    count = 0

    # Safe DataLoader (Colab: num_workers=0 avoids multiprocessing bug)
    loader = DataLoader(ds, batch_size=32, shuffle=False, num_workers=0, pin_memory=False)

    for i, batch in enumerate(loader):
        x = batch["X_raw"].numpy()  # (B,9,H,W)
        B, C, H, W = x.shape
        x = x.reshape(B, C, -1)
        sums += x.sum(axis=(0,2))
        sqs  += (x**2).sum(axis=(0,2))
        count += B * H * W
        if n_max and (i+1)*B >= n_max:
            break

    mean = sums / count
    var  = sqs / count - mean**2
    std  = np.sqrt(np.maximum(var, 1e-8))

    # Don’t normalize previous fire mask or barrier
    std[[0,8]] = 1.0

    mean_t = torch.tensor(mean, dtype=torch.float32)
    std_t  = torch.tensor(std, dtype=torch.float32)
    return mean_t, std_t


# --- Compute channel-wise normalization stats ---
mean9, std9 = compute_channel_stats(train_ds)
mean9, std9 = mean9.to(device), std9.to(device)
print("Channel means:", mean9)
print("Channel stds:", std9)


Channel means: tensor([ 8.4351e-03, -1.3046e+00, -8.3524e-01,  2.8982e+02,  6.8143e-01,
         2.9261e+03,  4.0502e-01, -2.7433e-01,  1.1075e-01], device='cuda:0')
Channel stds: tensor([1.0000e+00, 2.2045e+00, 2.7889e+00, 1.8411e+01, 9.1338e-02, 3.1567e+03,
        2.7636e-01, 1.8277e+00, 1.0000e+00], device='cuda:0')


In [104]:
class PhysicsPrior(nn.Module):
    """
    Builds 4 physics channels:
      P: anisotropic spread propensity (prev_fire convolved with per-pixel directional kernel)
      Wx, Wy: wind angle cos/sin
      Damp: moisture/fuel damping (T,RH,NDVI proxy)
    """
    def __init__(self, kernel_radius=3, a0=0.0, a1=0.04, a2=0.03, a3=0.8):
        super().__init__()
        self.kernel_radius = kernel_radius
        self.a0, self.a1, self.a2, self.a3 = a0, a1, a2, a3
        self.register_buffer("angle_grid", self._make_angle_grid(kernel_radius))
    @staticmethod
    def _make_angle_grid(r):
        yy, xx = torch.meshgrid(torch.arange(-r, r+1), torch.arange(-r, r+1), indexing='ij')
        return torch.atan2(yy.float(), xx.float()+1e-8)
    def forward(self, prev_fire, u, v, slope, aspect, T, RH, NDVI, barrier=None):
        B, _, H, W = prev_fire.shape
        wind_angle = torch.atan2(v, u + 1e-8)
        wind_speed = torch.sqrt(u**2 + v**2)
        ws_norm    = torch.clamp(wind_speed / 10.0, 0, 1)
        slope_norm = torch.clamp(slope, 0, 1)

        r = self.kernel_radius; K = 2*r + 1
        ang_flat = self.angle_grid.view(1,1,K*K,1,1).to(prev_fire)

        wa  = wind_angle.unsqueeze(2)         # (B,1,1,H,W)
        asp = aspect.unsqueeze(2)
        aw  = (1.0 * ws_norm).unsqueeze(2)    # soften
        as_ = (1.0 * slope_norm).unsqueeze(2)

        dtheta_w = ang_flat - wa
        dtheta_s = ang_flat - asp

        Ww = torch.exp(aw * torch.cos(dtheta_w))
        Ws = torch.exp(as_ * torch.cos(dtheta_s))
        kernel_flat = Ww * Ws
        kernel_flat = kernel_flat / (kernel_flat.sum(dim=2, keepdim=True) + 1e-8)

        ker   = kernel_flat.reshape(B, K*K, H*W)            # (B,K*K,H*W)
        pf_unf= F.unfold(prev_fire, kernel_size=K, padding=r)  # (B,K*K,H*W)
        P1    = (pf_unf * ker).sum(dim=1).view(B,1,H,W)

        Damp  = torch.sigmoid(self.a0 + self.a1*T - self.a2*RH + self.a3*NDVI)
        P     = P1 * Damp
        if barrier is not None:
            P = P * (1.0 - barrier.clamp(0,1))

        Wx = torch.cos(wind_angle); Wy = torch.sin(wind_angle)
        return torch.cat([P, Wx, Wy, Damp], dim=1)

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.SiLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.SiLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class UNet(nn.Module):
    def __init__(self, in_ch, out_ch=1, base=64):
        super().__init__()
        self.down1 = DoubleConv(in_ch, base);   self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(base, base*2);  self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(base*2, base*4);self.pool3 = nn.MaxPool2d(2)
        self.bottom= DoubleConv(base*4, base*8)
        self.up3   = nn.ConvTranspose2d(base*8, base*4, 2, 2)
        self.conv3 = DoubleConv(base*8, base*4)
        self.up2   = nn.ConvTranspose2d(base*4, base*2, 2, 2)
        self.conv2 = DoubleConv(base*4, base*2)
        self.up1   = nn.ConvTranspose2d(base*2, base, 2, 2)
        self.conv1 = DoubleConv(base*2, base)
        self.outc  = nn.Conv2d(base, out_ch, 1)
    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(self.pool1(x1))
        x3 = self.down3(self.pool2(x2))
        xb = self.bottom(self.pool3(x3))
        x  = self.up3(xb); x = self.conv3(torch.cat([x, x3], dim=1))
        x  = self.up2(x);  x = self.conv2(torch.cat([x, x2], dim=1))
        x  = self.up1(x);  x = self.conv1(torch.cat([x, x1], dim=1))
        return self.outc(x)

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5, alpha=0.5, reduction="mean"):
        super().__init__(); self.gamma, self.alpha, self.reduction = gamma, alpha, reduction
    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = p*targets + (1-p)*(1-targets)
        loss = ((1-p_t)**self.gamma) * ce
        alpha_t = self.alpha*targets + (1-self.alpha)*(1-targets)
        loss = alpha_t * loss
        return loss.mean() if self.reduction=="mean" else loss.sum()

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.6, beta=0.4, gamma=0.75, smooth=1.0):
        super().__init__(); self.alpha, self.beta, self.gamma, self.smooth = alpha, beta, gamma, smooth
    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        dims = (0,2,3)
        TP = (p*targets).sum(dims)
        FP = (p*(1-targets)).sum(dims)
        FN = ((1-p)*targets).sum(dims)
        T = (TP + self.smooth) / (TP + self.alpha*FP + self.beta*FN + self.smooth)
        return (1 - T).pow(self.gamma).mean()


In [108]:
# Geometry-aware augmentation (keeps u,v and aspect consistent)
import math as _math
def aug_geo(X_raw, y):
    # order: [prev,u,v,temp,rh,ndvi,slope,aspect,barrier]
    if torch.rand(1).item() < 0.5:  # horizontal flip
        X_raw = torch.flip(X_raw, dims=[-1]); y = torch.flip(y, dims=[-1])
        X_raw[:,1] = -X_raw[:,1]  # u -> -u
        X_raw[:,7] = torch.atan2(torch.sin(X_raw[:,7]), -torch.cos(X_raw[:,7]))
    if torch.rand(1).item() < 0.5:  # vertical flip
        X_raw = torch.flip(X_raw, dims=[-2]); y = torch.flip(y, dims=[-2])
        X_raw[:,2] = -X_raw[:,2]  # v -> -v
        X_raw[:,7] = -X_raw[:,7]
    k = torch.randint(0,4,(1,)).item()
    if k > 0:
        X_raw = torch.rot90(X_raw, k, dims=[-2,-1]); y = torch.rot90(y, k, dims=[-2,-1])
        u, v, asp = X_raw[:,1].clone(), X_raw[:,2].clone(), X_raw[:,7].clone()
        if   k == 1: X_raw[:,1], X_raw[:,2], X_raw[:,7] = -v,  u,  asp + _math.pi/2
        elif k == 2: X_raw[:,1], X_raw[:,2], X_raw[:,7] = -u, -v,  asp + _math.pi
        elif k == 3: X_raw[:,1], X_raw[:,2], X_raw[:,7] =  v, -u,  asp - _math.pi/2
        # ✅ fixed typo here
        X_raw[:,7] = ((X_raw[:,7] + _math.pi) % (2 * _math.pi)) - _math.pi
    return X_raw, y


# Build 16-channel input: 9 norm + cos(aspect) + sin(aspect) + wind speed + 4 physics
prior = PhysicsPrior(kernel_radius=4, a1=0.03, a2=0.02, a3=0.7).to(device)

def build_input_for_net(X_raw0):
    pf,u,v,T,RH,NDVI,sl,asp,bar = \
        X_raw0[:,0:1],X_raw0[:,1:2],X_raw0[:,2:3],X_raw0[:,3:4],X_raw0[:,4:5], \
        X_raw0[:,5:6],X_raw0[:,6:7],X_raw0[:,7:8],X_raw0[:,8:9]
    asp_cos = torch.cos(asp); asp_sin = torch.sin(asp)
    ws      = torch.clamp(torch.sqrt(u**2 + v**2)/10.0, 0, 1)
    X9_norm = (X_raw0 - mean9.view(1,-1,1,1)) / std9.view(1,-1,1,1)
    with torch.no_grad():
        X_phys = prior(pf,u,v,sl,asp,T,RH,NDVI,bar)
    return torch.cat([X9_norm, asp_cos, asp_sin, ws, X_phys], dim=1)  # (B,16,H,W)

# Model, optimizer, scheduler, AMP scaler
model = UNet(in_ch=16, out_ch=1, base=80).to(device) # changed from 64 to 80
model = torch.compile(model) if hasattr(torch, "compile") else model  # optional

EPOCHS = 60 #changed from 40 to 60
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=5e-5)
scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)

# Loss: keep only the make_loss version
def make_loss(gamma=1.5, alpha=0.5, tversky_gamma=0.75):
    return lambda logits, y: (
        0.5 * FocalLoss(gamma=gamma, alpha=alpha)(logits, y)
      + 0.5 * FocalTverskyLoss(alpha=0.6, beta=0.4, gamma=tversky_gamma)(logits, y)
    )
loss_fn = make_loss()

# --- EMA helper (dtype-safe) ---
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k,v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
            else:
                self.shadow[k].copy_(v)

    @torch.no_grad()
    def apply_to(self, model):
        model.load_state_dict(self.shadow, strict=True)

    @torch.no_grad()
    def state_dict(self):
        # return a plain state_dict you can save
        return {k: v.detach().clone() for k, v in self.shadow.items()}


# Initialize EMA tracker
ema = EMA(model, decay=0.999)

def train_one_epoch(model, loader):
    model.train(); prior.eval()
    losses = []
    for batch in tqdm(loader, desc="train", leave=False):
        X_raw0, y = batch["X_raw"].to(device), batch["y"].to(device)
        X_raw0, y = aug_geo(X_raw0, y)
        X = build_input_for_net(X_raw0)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=use_cuda):
            logits = model(X)
            loss   = loss_fn(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        ema.update(model)

        losses.append(loss.item())
    return float(np.mean(losses))


In [109]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval(); prior.eval()
    all_probs, all_true = [], []
    for batch in tqdm(loader, desc="eval", leave=False):
        X_raw0, y = batch["X_raw"].to(device), batch["y"].to(device)
        X = build_input_for_net(X_raw0)
        probs = torch.sigmoid(model(X)).cpu().numpy()
        all_probs.append(probs.ravel()); all_true.append(y.cpu().numpy().ravel())
    p = np.concatenate(all_probs); t = np.concatenate(all_true)
    if t.sum()==0: return 0.0, 0.0, np.array([0.5])
    ap = average_precision_score(t, p)
    prec, rec, thr = precision_recall_curve(t, p)
    f1 = (2*prec*rec)/(prec+rec+1e-8)
    return float(ap), float(f1.max()), thr

@torch.no_grad()
def pick_threshold(model, loader):
    """Pick the probability threshold that maximizes F1 on `loader`."""
    model.eval(); prior.eval()
    ps, ts = [], []
    for b in loader:
        X_raw = b["X_raw"].to(device); y = b["y"].to(device)
        X = build_input_for_net(X_raw)
        p = torch.sigmoid(model(X)).flatten().cpu().numpy()
        t = y.flatten().cpu().numpy()
        ps.append(p); ts.append(t)
    p = np.concatenate(ps).astype(np.float32)
    t = np.concatenate(ts).astype(np.float32)

    if t.sum() == 0:
        return 0.5  # no positives → neutral threshold

    prec, rec, thr = precision_recall_curve(t, p)
    f1 = (2*prec*rec)/(prec+rec+1e-8)
    idx = int(np.argmax(f1))
    # sklearn's PR returns len(thr) = len(prec) - 1; guard the edge:
    return float(thr[idx]) if idx < len(thr) else 0.5


@torch.no_grad()
def predict_tta(model, X_raw0):
    outs = []
    for k in [0,1,2,3]:  # 0,90,180,270
        Xk = torch.rot90(X_raw0, k, dims=[-2,-1])
        X  = build_input_for_net(Xk)
        pk = torch.sigmoid(model(X))
        pk = torch.rot90(pk, (4-k)%4, dims=[-2,-1])
        outs.append(pk)
    Xf = torch.flip(X_raw0, dims=[-1])
    X  = build_input_for_net(Xf)
    pf = torch.sigmoid(model(X))
    pf = torch.flip(pf, dims=[-1])
    outs.append(pf)
    return torch.stack(outs, dim=0).mean(0)  # (B,1,H,W)

@torch.no_grad()
def evaluate_tta(model, loader):
    model.eval(); prior.eval()
    ps, ts = [], []
    for b in tqdm(loader, desc="eval TTA", leave=False):
        X_raw = b["X_raw"].to(device); y = b["y"].to(device)
        probs = predict_tta(model, X_raw).cpu().numpy()
        ps.append(probs.ravel()); ts.append(y.cpu().numpy().ravel())
    p = np.concatenate(ps); t = np.concatenate(ts)
    ap = average_precision_score(t, p)
    prec, rec, _ = precision_recall_curve(t, p)
    f1 = (2*prec*rec)/(prec+rec+1e-8)
    return float(ap), float(f1.max())


In [111]:
# =========================
# RAW vs EMA vs Polyak block
# =========================
import copy, json, time

# ---- Polyak (simple running average of full state_dict) ----
class Polyak:
    def __init__(self):
        self.buf = None
        self.n = 0
    @torch.no_grad()
    def update(self, model):
        sd = {k: v.detach().clone() for k,v in model.state_dict().items()}
        if self.buf is None:
            self.buf = sd; self.n = 1
        else:
            for k in sd:
                if sd[k].dtype.is_floating_point:
                    # incremental average
                    self.buf[k].mul_(self.n/(self.n+1)).add_(sd[k], alpha=1/(self.n+1))
                else:
                    self.buf[k].copy_(sd[k])
            self.n += 1
    @torch.no_grad()
    def state_dict(self):
        return self.buf if self.buf is not None else {}

# ---- Local helpers (use your existing predict_tta/build_input_for_net) ----
@torch.no_grad()
def _pick_threshold_val(model, loader, use_tta=False):
    model.eval(); prior.eval()
    ps, ts = [], []
    for b in loader:
        X_raw, y = b["X_raw"].to(device), b["y"].to(device)
        if use_tta:
            try:
                p = predict_tta(model, X_raw)
            except NameError:
                probs = []
                for k in [0,1,2,3]:
                    xr = torch.rot90(X_raw, k, dims=[-2,-1])
                    pr = torch.sigmoid(model(build_input_for_net(xr)))
                    pr = torch.rot90(pr, 4-k, dims=[-2,-1])
                    probs.append(pr)
                p = torch.stack(probs, dim=0).mean(0)
        else:
            p = torch.sigmoid(model(build_input_for_net(X_raw)))
        ps.append(p.cpu().numpy().ravel()); ts.append(y.cpu().numpy().ravel())
    p = np.concatenate(ps); t = np.concatenate(ts)
    prec, rec, thr = precision_recall_curve(t, p)
    f1 = (2*prec*rec)/(prec+rec+1e-8)
    i = int(f1.argmax())
    return float(thr[i] if i < len(thr) else 0.5)

@torch.no_grad()
def _eval_plain_tta(model, loader, thr_plain, thr_tta):
    model.eval(); prior.eval()
    # plain
    ps, ts = [], []
    for b in loader:
        X_raw, y = b["X_raw"].to(device), b["y"].to(device)
        p = torch.sigmoid(model(build_input_for_net(X_raw)))
        ps.append(p.cpu().numpy().ravel()); ts.append(y.cpu().numpy().ravel())
    p = np.concatenate(ps); t = np.concatenate(ts)
    ap_plain = float(average_precision_score(t, p))
    f1_plain = float((2*((p>=thr_plain)*t).sum())/(((p>=thr_plain)+t).sum()+1e-8))

    # tta
    ps, ts = [], []
    for b in loader:
        X_raw, y = b["X_raw"].to(device), b["y"].to(device)
        try:
            pt = predict_tta(model, X_raw)
        except NameError:
            probs=[]
            for k in [0,1,2,3]:
                xr = torch.rot90(X_raw, k, dims=[-2,-1])
                pr = torch.sigmoid(model(build_input_for_net(xr)))
                pr = torch.rot90(pr, 4-k, dims=[-2,-1])
                probs.append(pr)
            pt = torch.stack(probs, dim=0).mean(0)
        ps.append(pt.cpu().numpy().ravel()); ts.append(y.cpu().numpy().ravel())
    p = np.concatenate(ps); t = np.concatenate(ts)
    ap_tta = float(average_precision_score(t, p))
    f1_tta = float((2*((p>=thr_tta)*t).sum())/(((p>=thr_tta)+t).sum()+1e-8))
    return ap_plain, f1_plain, ap_tta, f1_tta

# ---- Switches ----
USE_POLYAK = True
poly = Polyak() if USE_POLYAK else None

# ---- Train once, track RAW; save RAW+EMA+Polyak on best RAW-VAL AP ----
best_val_ap_raw = -1.0
hist = []
for epoch in range(EPOCHS):
    tr = train_one_epoch(model, train_loader)   # <-- already does ema.update(...) per step
    if USE_POLYAK:
        poly.update(model)                      # Polyak per epoch

    ap_raw, f1_raw, _ = evaluate(model, val_loader)
    hist.append((epoch, tr, ap_raw, f1_raw))
    print(f"Epoch {epoch:02d} | loss {tr:.4f} | VAL RAW AP {ap_raw:.4f} | F1* {f1_raw:.4f}")

    if ap_raw > best_val_ap_raw:
        best_val_ap_raw = ap_raw
        torch.save({
            "model": model.state_dict(),                 # RAW
            "ema":   (ema.state_dict() if 'ema' in globals() and ema is not None else None),
            "poly":  (poly.state_dict() if USE_POLYAK else None),
            "in_ch": 16,
            "base":  80                                  # <-- match your UNet base
        }, "/content/best_unet.pt")
    scheduler.step()

print("Best val AP (RAW selection):", best_val_ap_raw)

# ---- Load checkpoint once, evaluate RAW vs EMA vs Polyak side-by-side ----
ckpt = torch.load("/content/best_unet.pt", map_location=device)

def eval_variant(name, sd, val_loader, test_loader):
    if sd is None:
        return {"variant": name, "note": "not available"}
    model.load_state_dict(sd, strict=True)
    thr_plain = _pick_threshold_val(model, val_loader, use_tta=False)
    thr_tta   = _pick_threshold_val(model, val_loader, use_tta=True)
    ap_p, f1_p, ap_t, f1_t = _eval_plain_tta(model, test_loader, thr_plain, thr_tta)
    return {
        "variant": name,
        "thr_plain": round(thr_plain, 3),
        "thr_tta":   round(thr_tta, 3),
        "test_plain_AP": round(ap_p, 4),
        "test_plain_F1": round(f1_p, 4),
        "test_tta_AP":   round(ap_t, 4),
        "test_tta_F1":   round(f1_t, 4),
    }

results = []
results.append(eval_variant("RAW",    ckpt.get("model"), val_loader, test_loader))
results.append(eval_variant("EMA",    ckpt.get("ema"),   val_loader, test_loader))
results.append(eval_variant("Polyak", ckpt.get("poly"),  val_loader, test_loader))

print(json.dumps({"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "results": results}, indent=2))


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 00 | loss 0.4547 | VAL RAW AP 0.2228 | F1* 0.3443


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 01 | loss 0.3945 | VAL RAW AP 0.2786 | F1* 0.3755


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 02 | loss 0.3734 | VAL RAW AP 0.2965 | F1* 0.3812


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 03 | loss 0.3713 | VAL RAW AP 0.3011 | F1* 0.3866


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 04 | loss 0.3684 | VAL RAW AP 0.3091 | F1* 0.3906


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 05 | loss 0.3648 | VAL RAW AP 0.3213 | F1* 0.3956


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 06 | loss 0.3651 | VAL RAW AP 0.3199 | F1* 0.3969


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 07 | loss 0.3653 | VAL RAW AP 0.3195 | F1* 0.3949


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 08 | loss 0.3622 | VAL RAW AP 0.3229 | F1* 0.3965


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 09 | loss 0.3628 | VAL RAW AP 0.3243 | F1* 0.3994


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 10 | loss 0.3613 | VAL RAW AP 0.3212 | F1* 0.3985


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 11 | loss 0.3584 | VAL RAW AP 0.3310 | F1* 0.4016


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 12 | loss 0.3591 | VAL RAW AP 0.3294 | F1* 0.4033


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 13 | loss 0.3611 | VAL RAW AP 0.3311 | F1* 0.4041


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 14 | loss 0.3602 | VAL RAW AP 0.3308 | F1* 0.4042


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 15 | loss 0.3583 | VAL RAW AP 0.3283 | F1* 0.4036


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 16 | loss 0.3577 | VAL RAW AP 0.3400 | F1* 0.4101


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 17 | loss 0.3594 | VAL RAW AP 0.3329 | F1* 0.4072


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 18 | loss 0.3572 | VAL RAW AP 0.3329 | F1* 0.4069


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 19 | loss 0.3559 | VAL RAW AP 0.3356 | F1* 0.4079


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 20 | loss 0.3553 | VAL RAW AP 0.3384 | F1* 0.4092


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 21 | loss 0.3554 | VAL RAW AP 0.3399 | F1* 0.4108


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 22 | loss 0.3533 | VAL RAW AP 0.3416 | F1* 0.4114


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 23 | loss 0.3541 | VAL RAW AP 0.3389 | F1* 0.4112


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 24 | loss 0.3529 | VAL RAW AP 0.3436 | F1* 0.4137


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 25 | loss 0.3517 | VAL RAW AP 0.3313 | F1* 0.4122


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 26 | loss 0.3532 | VAL RAW AP 0.3470 | F1* 0.4151


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 27 | loss 0.3519 | VAL RAW AP 0.3437 | F1* 0.4161


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 28 | loss 0.3519 | VAL RAW AP 0.3413 | F1* 0.4147


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 29 | loss 0.3509 | VAL RAW AP 0.3400 | F1* 0.4146


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 30 | loss 0.3477 | VAL RAW AP 0.3418 | F1* 0.4175


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 31 | loss 0.3508 | VAL RAW AP 0.3513 | F1* 0.4199


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 32 | loss 0.3489 | VAL RAW AP 0.3473 | F1* 0.4185


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 33 | loss 0.3503 | VAL RAW AP 0.3451 | F1* 0.4182


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 34 | loss 0.3508 | VAL RAW AP 0.3517 | F1* 0.4208


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 35 | loss 0.3501 | VAL RAW AP 0.3407 | F1* 0.4153


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 36 | loss 0.3488 | VAL RAW AP 0.3486 | F1* 0.4208


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 37 | loss 0.3470 | VAL RAW AP 0.3548 | F1* 0.4252


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 38 | loss 0.3451 | VAL RAW AP 0.3537 | F1* 0.4240


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 39 | loss 0.3453 | VAL RAW AP 0.3570 | F1* 0.4270


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 40 | loss 0.3434 | VAL RAW AP 0.3517 | F1* 0.4265


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 41 | loss 0.3454 | VAL RAW AP 0.3450 | F1* 0.4215


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 42 | loss 0.3452 | VAL RAW AP 0.3583 | F1* 0.4265


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 43 | loss 0.3439 | VAL RAW AP 0.3601 | F1* 0.4296


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 44 | loss 0.3441 | VAL RAW AP 0.3577 | F1* 0.4287


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 45 | loss 0.3423 | VAL RAW AP 0.3547 | F1* 0.4277


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 46 | loss 0.3425 | VAL RAW AP 0.3588 | F1* 0.4275


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 47 | loss 0.3405 | VAL RAW AP 0.3628 | F1* 0.4274


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 48 | loss 0.3432 | VAL RAW AP 0.3653 | F1* 0.4315


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 49 | loss 0.3429 | VAL RAW AP 0.3652 | F1* 0.4329


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 50 | loss 0.3407 | VAL RAW AP 0.3707 | F1* 0.4320


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 51 | loss 0.3400 | VAL RAW AP 0.3696 | F1* 0.4346


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 52 | loss 0.3417 | VAL RAW AP 0.3739 | F1* 0.4372


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 53 | loss 0.3398 | VAL RAW AP 0.3709 | F1* 0.4358


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 54 | loss 0.3405 | VAL RAW AP 0.3704 | F1* 0.4360


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 55 | loss 0.3381 | VAL RAW AP 0.3653 | F1* 0.4331


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 56 | loss 0.3382 | VAL RAW AP 0.3719 | F1* 0.4344


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 57 | loss 0.3370 | VAL RAW AP 0.3738 | F1* 0.4350


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 58 | loss 0.3370 | VAL RAW AP 0.3695 | F1* 0.4360


train:   0%|          | 0/812 [00:00<?, ?it/s]

eval:   0%|          | 0/174 [00:00<?, ?it/s]

Epoch 59 | loss 0.3367 | VAL RAW AP 0.3754 | F1* 0.4383
Best val AP (RAW selection): 0.37542389158101286
{
  "timestamp": "2025-11-11 02:57:31",
  "results": [
    {
      "variant": "RAW",
      "thr_plain": 0.095,
      "thr_tta": 0.195,
      "test_plain_AP": 0.3841,
      "test_plain_F1": 0.4435,
      "test_tta_AP": 0.3935,
      "test_tta_F1": 0.4469
    },
    {
      "variant": "EMA",
      "thr_plain": 0.099,
      "thr_tta": 0.152,
      "test_plain_AP": 0.3883,
      "test_plain_F1": 0.4464,
      "test_tta_AP": 0.395,
      "test_tta_F1": 0.4474
    },
    {
      "variant": "Polyak",
      "thr_plain": 0.036,
      "thr_tta": 0.03,
      "test_plain_AP": 0.3437,
      "test_plain_F1": 0.4081,
      "test_tta_AP": 0.3449,
      "test_tta_F1": 0.4074
    }
  ]
}


In [112]:
import numpy as np, matplotlib.pyplot as plt, random
from sklearn.metrics import average_precision_score, precision_recall_curve

def _minmax(x, eps=1e-6):
    lo, hi = np.percentile(x, 2), np.percentile(x, 98)
    if hi - lo < eps: return np.zeros_like(x, dtype=np.float32)
    y = (x - lo) / (hi - lo)
    return np.clip(y, 0.0, 1.0).astype(np.float32)

def make_falsecolor_rgb(X_raw):
    """
    Assumes X_raw channels = [prev,u,v,temp,rh,ndvi,slope,aspect,barrier]
    Change bands if you prefer. Here: R=temp, G=ndvi, B=slope.
    """
    R = _minmax(X_raw[3])
    G = _minmax(X_raw[5])
    B = _minmax(X_raw[6])
    return np.stack([R, G, B], axis=-1)

def _stack_X(npz):
    # Robustly rebuild X_raw if it's not prepacked in the NPZ
    if "X_raw" in npz.files:
        return npz["X_raw"].astype(np.float32)
    def pick(*cands):
        for k in cands:
            if k in npz.files: return k
        raise KeyError(f"Missing keys {cands} in {list(npz.files)}")
    prev = npz[pick("prev_fire","PrevFireMask","prev")]
    u    = npz[pick("u","U")]
    v    = npz[pick("v","V")]
    temp = npz[pick("temp","tmmx","T","temperature")]
    rh   = npz[pick("rh","sph","RH","humidity")]
    ndvi = npz[pick("ndvi","NDVI")]
    slope= npz[pick("slope","Slope")]
    aspect=npz[pick("aspect","Aspect")]
    barrier = npz["barrier"] if "barrier" in npz.files else np.zeros_like(prev, dtype=np.float32)
    return np.stack([prev,u,v,temp,rh,ndvi,slope,aspect,barrier], axis=0).astype(np.float32)

def _tile_metrics(prob, gt, thr=None):
    # Flattened versions for sklearn metrics (AP, PR curve)
    p_flat = prob.ravel()
    t_flat = (gt > 0.5).ravel()
    ap = float(average_precision_score(t_flat, p_flat))

    if thr is None:
        # If thr is None, we find the best F1 based on the PR curve
        prec, rec, thrv = precision_recall_curve(t_flat, p_flat)
        f1 = (2*prec*rec)/(prec+rec+1e-8)
        i = int(f1.argmax())
        # Return AP, best F1, and the threshold that achieved it
        return ap, float(f1[i]), float(thrv[min(i, len(thrv)-1)])
    else:
        # If thr is provided, we use it to binarize and calculate F1
        pred_bin = (prob >= thr).astype(np.float32) # This is (H,W)
        gt_bin = (gt > 0.5).astype(np.float32)     # This should also be (H,W) for element-wise ops

        # Calculate TP, FP, FN using the 2D binary masks
        tp = (pred_bin * gt_bin).sum()
        fp = (pred_bin * (1 - gt_bin)).sum()
        fn = ((1 - pred_bin) * gt_bin).sum()

        prec = float(tp/(tp+fp+1e-8))
        rec  = float(tp/(tp+fn+1e-8))
        f1 = float(2*prec*rec/(prec+rec+1e-8))
        # Return AP, F1 at the given threshold, and the given threshold
        return ap, f1, thr

def show_prob_heatmaps(results, n=6, thr=None, cmap="inferno", alpha=0.55, vmax=None, title="Prob heatmaps + GT contours"):
    """
    results: list of dicts from infer_folder_npz (each has 'file', 'out_npy')
    thr: if given, draws a prediction contour at this threshold and reports F1@thr.
         if None, it computes & reports the best F1 threshold per tile (for diagnostics).
    """
    if not results:
        print("No results to visualize."); return

    picks = random.sample(results, k=min(n, len(results)))
    rows, cols = len(picks), 3  # Input | Prob heatmap | Overlay (prob + GT + thr-contour)
    plt.figure(figsize=(12, 3.5*rows))
    plt.suptitle(title, y=0.995, fontsize=12)

    idx = 1
    for it in picks:
        d = np.load(it["file"])

        # Robustly load 'next_fire' (ground truth)
        try:
            gt = d["next_fire"].astype(np.float32)
        except KeyError:
            print(f"Skipping visualization for {it['file']}: 'next_fire' (ground truth) not found.")
            continue # Skip this tile if GT is missing

        X_raw = _stack_X(d)
        rgb   = make_falsecolor_rgb(X_raw)

        prob = np.load(it["out_npy"]).astype(np.float32)  # (H,W) probs or 0/1
        # If these are 0/1 masks, treat them as probs for visualization (it still works).
        vlim = vmax if vmax is not None else (0.0, 1.0)

        ap, f1t, used_thr = _tile_metrics(prob, gt, thr)

        # 1) Input
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(rgb); ax.set_title("Input (false-color)"); ax.axis("off")

        # 2) Probability heatmap
        ax = plt.subplot(rows, cols, idx); idx += 1
        im = ax.imshow(prob, vmin=vlim[0], vmax=vlim[1], cmap=cmap)
        ax.set_title(f"Prob heatmap\nAP={ap:.3f} | F1@thr={f1t:.3f} (thr={used_thr:.3f})")
        ax.axis("off"); plt.colorbar(im, ax=ax, fraction=0.046, pad=0.02)

        # 3) Overlay: prob + GT contour (+ pred contour @ thr)
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(rgb)
        ax.imshow(prob, alpha=alpha, vmin=vlim[0], vmax=vlim[1], cmap=cmap)
        # GT contour
        ax.contour((gt > 0.5).astype(np.float32), levels=[0.5], colors="lime", linewidths=1.2)
        # Pred contour at chosen thr
        ax.contour((prob >= used_thr).astype(np.float32), levels=[0.5], colors="yellow", linewidths=1.0, linestyles="--")
        ax.set_title("Overlay: prob + GT (green) + pred@thr (yellow)")
        ax.axis("off")

    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()

In [1]:
# --- Save deployable checkpoint ---
# (Run this after evaluating RAW/EMA/Polyak and picking your best-performing variant.)

torch.save({
    "model": model.state_dict(),     # trained weights (whichever variant you prefer)
    "in_ch": 16,
    "base": 80,
    "thr_plain": 0.112,              # from your val_best summary
    "thr_tta": 0.141,                # from your val_best summary
    "T": 1.0                         # calibration temperature (kept at 1.0 for now)
}, "/content/unet_raw_deploy.pt")

print("Saved deployable model → /content/unet_raw_deploy.pt")


NameError: name 'torch' is not defined

In [114]:
import os, glob, shutil, json
from pathlib import Path
import numpy as np

# 1) Where the NPZs live
npz_dir = NPZ_ROOT  # flat directory of all .npz
print("Using NPZ root:", npz_dir)
all_npz = sorted(glob.glob(os.path.join(npz_dir, "*.npz")))
assert len(all_npz) > 0, f"No .npz found in {npz_dir}."

# 2) Rebuild the same train/val/test split (deterministic)
rng = np.random.default_rng(1337)
idx = np.arange(len(all_npz))
rng.shuffle(idx)

n = len(all_npz)
n_train = int(round(0.70*n))
n_val   = int(round(0.15*n))
test_idx = idx[n_train+n_val:] if (n_train+n_val) < n else idx[-1:]  # guard if tiny

test_files = [all_npz[i] for i in test_idx]
print(f"Split sizes -> train≈{n_train}, val≈{n_val}, test={len(test_files)}")

# 3) Materialize a temp folder with ONLY test files (via symlinks; fallback to copy on Windows)
tmp_test_dir = "/content/npz_test_only"
Path(tmp_test_dir).mkdir(parents=True, exist_ok=True)

# clean old
for p in glob.glob(os.path.join(tmp_test_dir, "*.npz")):
    os.remove(p)

# link/copy
for src in test_files[:]:  # you can slice to make a smaller subset
    dst = os.path.join(tmp_test_dir, os.path.basename(src))
    try:
        os.symlink(src, dst)
    except Exception:
        shutil.copy2(src, dst)

print("Prepared test-only folder:", tmp_test_dir, "with", len(glob.glob(os.path.join(tmp_test_dir, '*.npz'))), "files")

# 4) Load deploy model (weights + thresholds + temperature)
try:
    model, thr_plain, thr_tta, T_opt = load_deploy_model("/content/unet_raw_deploy.pt")
except ValueError:
    # backward-compat if your helper returns 3-tuple
    model, thr_plain, thr_tta = load_deploy_model("/content/unet_raw_deploy.pt")
    T_opt = 1.0
print(f"Loaded deploy model. thr_plain={thr_plain}  thr_tta={thr_tta}  T={T_opt}")

# 5) Run inference on TEST ONLY
plain_results = infer_folder_npz(
    model,
    npz_dir=tmp_test_dir,   # ← test split only
    out_dir="/content/preds_plain",
    mode="plain",
    threshold=thr_plain,    # or None to save probabilities
    T=T_opt,
    do_clean=True,
    batch_size=32
)

tta_results = infer_folder_npz(
    model,
    npz_dir=tmp_test_dir,
    out_dir="/content/preds_tta",
    mode="tta",
    threshold=thr_tta,      # or None
    T=T_opt,
    do_clean=True,
    batch_size=16
)

# 6) Quick visual sanity checks (optional)
if plain_results:
    print("Showing plain examples…")
    show_examples(plain_results, n=6, add_gt=True, alpha_mask=0.35)

    print("Hard examples (diff view)…")
    show_diff_examples(plain_results, n=6, threshold=thr_plain)


Using NPZ root: /content/wildfire_npz_tiles_kaggle_v3
Split sizes -> train≈12982, val≈2782, test=2781
Prepared test-only folder: /content/npz_test_only with 2781 files
Loaded deploy model. thr_plain=0.112  thr_tta=0.141  T=1.0


infer:plain:   0%|          | 0/2781 [00:00<?, ?it/s]

infer:tta:   0%|          | 0/2781 [00:00<?, ?it/s]

Showing plain examples…


TypeError: Population must be a sequence.  For dicts or sets, use sorted(d).

In [None]:
# =========================
# Inference-only utilities (updated)
# =========================
import os, glob, math, json
from PIL import Image

@torch.no_grad()
def load_deploy_model(ckpt_path="/content/unet_raw_deploy.pt"):
    ck = torch.load(ckpt_path, map_location=device)
    in_ch = ck.get("in_ch", 16); base = ck.get("base", 64)
    model = UNet(in_ch=in_ch, out_ch=1, base=base).to(device)

    # Strip _orig_mod. prefix if model was compiled
    raw_sd = ck["model"]
    clean_sd = {}
    for k, v in raw_sd.items():
        clean_sd[k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k] = v
    model.load_state_dict(clean_sd, strict=True)
    model.eval()

    thr_plain = float(ck.get("thr_plain", 0.5))
    thr_tta   = float(ck.get("thr_tta",   0.5))
    T_opt     = float(ck.get("temperature", 1.0))  # optional, default to 1.0
    return model, thr_plain, thr_tta, T_opt

# Optional: small morphological clean-up
def clean_mask_tensor(m, do_clean=True, k=3):
    if not do_clean:
        return m
    m1 = F.max_pool2d(m, k, 1, k//2)             # dilate
    m2 = 1 - F.max_pool2d(1 - m1, k, 1, k//2)    # erode
    return m2

@torch.no_grad()
def predict_plain(model, X_raw, T: float = 1.0):
    """X_raw: (B, 9, H, W) in canonical order."""
    logits = model(build_input_for_net(X_raw))
    return torch.sigmoid(logits / T)

@torch.no_grad()
def predict_tta(model, X_raw, T: float = 1.0):
    outs=[]
    # 4 rotations
    for k in [0,1,2,3]:
        xr = torch.rot90(X_raw, k, dims=[-2,-1])
        pk = torch.sigmoid(model(build_input_for_net(xr)) / T)
        pk = torch.rot90(pk, (4-k)%4, dims=[-2,-1])
        outs.append(pk)
    # + horizontal flip
    xf = torch.flip(X_raw, dims=[-1])
    pf = torch.sigmoid(model(build_input_for_net(xf)) / T)
    pf = torch.flip(pf, dims=[-1])
    outs.append(pf)
    return torch.stack(outs, dim=0).mean(0)

def save_mask_png(mask01, out_png):
    im = Image.fromarray((mask01*255).astype("uint8"))
    im.save(out_png)

@torch.no_grad()
def infer_folder_npz(
    model,
    npz_dir,
    out_dir,
    mode="plain",             # "plain" or "tta"
    threshold=None,           # if None, save probabilities instead of masks
    T: float = 1.0,           # temperature for calibration
    do_clean=True,
    batch_size=32
):
    os.makedirs(out_dir, exist_ok=True)
    files = sorted(glob.glob(os.path.join(npz_dir, "*.npz")))
    if not files:
        print(f"No NPZ files found in {npz_dir}")
        return []

    # Threshold defaulting: if caller didn’t provide one, try checkpoint; else fallback constants
    if threshold is None:
        thr_plain, thr_tta = 0.112, 0.141
        try:
            ck = torch.load("/content/unet_raw_deploy.pt", map_location=device)
            thr_plain = float(ck.get("thr_plain", thr_plain))
            thr_tta   = float(ck.get("thr_tta", thr_tta))
        except Exception:
            pass
        picked_thr = thr_tta if mode == "tta" else thr_plain
        print(f"Inference mode: {mode} | threshold=None (will save probabilities) | clean={do_clean} | T={T}")
    else:
        picked_thr = float(threshold)
        print(f"Inference mode: {mode} | threshold={picked_thr:.3f} | clean={do_clean} | T={T}")

    # NPZ loader tolerant to missing 'barrier' or pre-bundled X_raw
    def _load_npz(fp):
        z = np.load(fp)
        if "X_raw" in z.files:
            X_raw = z["X_raw"].astype(np.float32)
        else:
            # Build from individual arrays; make 'barrier' optional (zeros if absent)
            req = ["prev_fire","u","v","temp","rh","ndvi","slope","aspect"]
            missing = [k for k in req if k not in z.files]
            if missing:
                raise KeyError(f"{os.path.basename(fp)} missing {missing}")
            barrier = z["barrier"] if "barrier" in z.files else np.zeros_like(z["prev_fire"])
            X_raw = np.stack(
                [z["prev_fire"], z["u"], z["v"], z["temp"], z["rh"], z["ndvi"], z["slope"], z["aspect"], barrier],
                axis=0
            ).astype(np.float32)
        return X_raw  # (9,H,W)

    results = []
    buf = []; paths = []
    def _flush():
        if not buf: return
        B = len(buf)
        X = torch.from_numpy(np.stack(buf, axis=0)).to(device)  # (B,9,H,W)
        prob = predict_tta(model, X, T=T) if mode == "tta" else predict_plain(model, X, T=T)  # (B,1,H,W)

        if threshold is None:
            # Save probabilities
            for i in range(B):
                base = os.path.splitext(os.path.basename(paths[i]))[0]
                out_npy = os.path.join(out_dir, f"{base}_{mode}_prob.npy")
                np.save(out_npy, prob[i,0].detach().cpu().numpy().astype(np.float32))
                results.append(dict(file=paths[i], out_prob=out_npy))
        else:
            # Save hard masks
            m = (prob >= picked_thr).float()
            m = clean_mask_tensor(m, do_clean=do_clean)
            for i in range(B):
                base = os.path.splitext(os.path.basename(paths[i]))[0]
                out_png = os.path.join(out_dir, f"{base}_{mode}.png")
                out_npy = os.path.join(out_dir, f"{base}_{mode}.npy")
                mask01  = m[i,0].detach().cpu().numpy()
                save_mask_png(mask01, out_png)
                np.save(out_npy, mask01.astype(np.uint8))
                results.append(dict(file=paths[i], out_png=out_png, out_npy=out_npy))

        buf.clear(); paths.clear()

    for fp in files:
        buf.append(_load_npz(fp)); paths.append(fp)
        if len(buf) == batch_size:
            _flush()
    _flush()

    print(f"Wrote {len(results)} outputs → {out_dir}")
    return results


In [None]:
model, thr_plain, thr_tta, T_opt = load_deploy_model("/content/unet_raw_deploy.pt")
plain_results = infer_folder_npz(
    model, npz_dir=npz_dir, out_dir="/content/preds_plain",
    mode="plain", threshold=thr_plain, T=T_opt, do_clean=True, batch_size=32
)
tta_results = infer_folder_npz(
    model, npz_dir=npz_dir, out_dir="/content/preds_tta",
    mode="tta", threshold=thr_tta, T=T_opt, do_clean=True, batch_size=16
)


In [None]:
# =========================
# Visualization helpers (robust)
# =========================
import random
import numpy as np
import matplotlib.pyplot as plt

def _minmax(x, lo_q=2, hi_q=98, eps=1e-6):
    x = x.astype(np.float32)
    lo, hi = np.percentile(x, lo_q), np.percentile(x, hi_q)
    if not np.isfinite(lo) or not np.isfinite(hi) or (hi - lo) < eps:
        return np.zeros_like(x, dtype=np.float32)
    y = (x - lo) / (hi - lo + eps)
    return np.clip(y, 0.0, 1.0)

def make_falsecolor_rgb(X_raw):
    """
    X_raw order: [prev, u, v, temp, rh, ndvi, slope, aspect, barrier]
    Simple composite:
      R = temp, G = NDVI, B = slope
    """
    temp  = X_raw[3]
    ndvi  = X_raw[5]
    slope = X_raw[6]
    R = _minmax(temp)
    G = _minmax(ndvi)
    B = _minmax(slope)
    return np.stack([R, G, B], axis=-1)

def _load_X_raw_from_npz(data):
    """Tolerant loader for X_raw (builds from components if needed)."""
    if "X_raw" in data.files:
        return data["X_raw"].astype(np.float32)
    # Build, allowing missing 'barrier' (zeros default)
    req = ["prev_fire","u","v","temp","rh","ndvi","slope","aspect"]
    missing = [k for k in req if k not in data.files]
    if missing:
        raise KeyError(f"NPZ missing required keys: {missing}")
    barrier = data["barrier"] if "barrier" in data.files else np.zeros_like(data["prev_fire"])
    return np.stack([
        data["prev_fire"], data["u"], data["v"], data["temp"], data["rh"],
        data["ndvi"], data["slope"], data["aspect"], barrier
    ], axis=0).astype(np.float32)

def _binarize_from_result_item(item, viz_threshold=None):
    """
    Returns (mask, is_prob, src_path)
    - If result has 'out_npy': it's a hard mask (0/1).
    - If it has 'out_prob': it's probabilities; apply viz_threshold if given, else return probs.
    """
    if "out_npy" in item:
        pred = np.load(item["out_npy"]).astype(np.float32)
        return pred, False, item["out_npy"]
    if "out_prob" in item:
        prob = np.load(item["out_prob"]).astype(np.float32)
        if viz_threshold is None:
            return prob, True, item["out_prob"]
        return (prob >= float(viz_threshold)).astype(np.float32), False, item["out_prob"]
    raise KeyError("Result item must contain 'out_npy' (mask) or 'out_prob' (probabilities).")

def show_examples(results, n=6, add_gt=True, alpha_mask=0.35, figsize=(12, 2.8), viz_threshold=None):
    """
    results: list of dicts from infer_folder_npz (with 'file', and either 'out_npy' or 'out_prob')
    Shows: input falsecolor, predicted (mask or prob), overlay (+ optional GT overlay)
    - If results contain probabilities ('out_prob') and viz_threshold is None, we show the probability heatmap.
    - If viz_threshold is set, probabilities will be binarized for display.
    """
    if not results:
        print("No results to visualize.")
        return

    picks = random.sample(results, k=min(n, len(results)))
    rows = len(picks)
    # input | pred/prob | overlay | (optional GT overlay)
    cols = 3 + int(add_gt)

    plt.figure(figsize=(figsize[0], figsize[1]*rows))
    idx = 1
    for item in picks:
        npz_path = item["file"]
        data = np.load(npz_path)
        X_raw = _load_X_raw_from_npz(data)
        rgb = make_falsecolor_rgb(X_raw)

        pred_or_prob, is_prob, src = _binarize_from_result_item(item, viz_threshold=viz_threshold)
        H, W = pred_or_prob.shape

        has_gt = add_gt and ("next_fire" in data.files)
        gt = data["next_fire"].astype(np.float32) if has_gt else None

        # Panel 1: Input
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(rgb); ax.set_title("Input (false-color)"); ax.axis("off")

        # Panel 2: Pred (mask or prob heatmap)
        ax = plt.subplot(rows, cols, idx); idx += 1
        if is_prob and viz_threshold is None:
            im = ax.imshow(pred_or_prob, vmin=0, vmax=1)
            ax.set_title("Prediction (probability)")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        else:
            ax.imshow(pred_or_prob, vmin=0, vmax=1)
            ax.set_title("Prediction (mask)")
        ax.axis("off")

        # Panel 3: Overlay (pred/prob on input)
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(rgb)
        if is_prob and viz_threshold is None:
            im = ax.imshow(pred_or_prob, alpha=alpha_mask, vmin=0, vmax=1)
            ax.set_title("Overlay (prob)")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        else:
            ax.imshow(pred_or_prob, alpha=alpha_mask)
            ax.set_title("Overlay (pred)")
        ax.axis("off")

        # Panel 4: GT overlay (if available)
        if has_gt:
            ax = plt.subplot(rows, cols, idx); idx += 1
            ax.imshow(rgb)
            ax.imshow(gt, alpha=alpha_mask)
            ax.set_title("Overlay (GT)"); ax.axis("off")

    plt.tight_layout()
    plt.show()

def show_diff_examples(results, n=6, alpha=0.35, figsize=(12, 3.2), viz_threshold=None):
    """
    Diff view: TP (green), FP (red), FN (blue)
    Accepts hard masks ('out_npy') OR probabilities ('out_prob'):
      - If 'out_prob' and viz_threshold is None → we’ll set viz_threshold=0.5 for diff.
    """
    if not results:
        print("No results to visualize.")
        return

    picks = random.sample(results, k=min(n, len(results)))
    rows = len(picks); cols = 4  # input | pred | GT | diff

    plt.figure(figsize=(figsize[0], figsize[1]*rows))
    idx = 1
    for item in picks:
        npz_path = item["file"]
        data = np.load(npz_path)
        if "next_fire" not in data.files:
            # skip if no GT
            continue

        X_raw = _load_X_raw_from_npz(data)
        rgb = make_falsecolor_rgb(X_raw)
        gt = data["next_fire"].astype(np.float32)

        # Threshold handling for probs
        thr = 0.5 if ("out_prob" in item and viz_threshold is None) else viz_threshold
        pred, is_prob, src = _binarize_from_result_item(item, viz_threshold=thr)
        if is_prob:
            # shouldn’t happen due to viz_threshold above, but guard anyway
            pred = (pred >= 0.5).astype(np.float32)

        # Diff map
        tp = (pred==1) & (gt==1)
        fp = (pred==1) & (gt==0)
        fn = (pred==0) & (gt==1)
        diff_rgb = np.zeros((*gt.shape, 3), dtype=np.float32)
        diff_rgb[...,1] = tp.astype(np.float32)  # G
        diff_rgb[...,0] = fp.astype(np.float32)  # R
        diff_rgb[...,2] = fn.astype(np.float32)  # B

        # 1) Input
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(rgb); ax.set_title("Input"); ax.axis("off")

        # 2) Pred
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(pred, vmin=0, vmax=1); ax.set_title("Pred"); ax.axis("off")

        # 3) GT
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(gt, vmin=0, vmax=1); ax.set_title("GT"); ax.axis("off")

        # 4) Diff
        ax = plt.subplot(rows, cols, idx); idx += 1
        ax.imshow(diff_rgb); ax.set_title("Diff: TP=G, FP=R, FN=B"); ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
show_examples(plain_results, viz_threshold=None)  # heatmaps
show_diff_examples(plain_results, viz_threshold=0.3)  # binarize at 0.3 to see diffs


In [None]:
show_examples(plain_results)
show_diff_examples(plain_results)


In [None]:
# This cell is redundant as its functions are defined in xQJDVIpEdy18. Its content has been cleared to avoid conflicts.


In [None]:
show_diff_examples(plain_results, threshold=None)


In [None]:
show_diff_examples(plain_results, threshold=0.12)  # pick your viz cutoff


In [None]:
show_diff_examples(tta_results, n=6, threshold=None)


In [None]:
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

def aggregate_metrics(results,
                      threshold=None,
                      bins=(0,5,20,100,4096),
                      pred_key_preference=("out_prob","out_npy")):
    """
    results: list from infer_folder_npz (each item has 'file' and one of pred_key_preference)
    threshold:
      - float -> use this threshold to binarize for F1/by-bin
      - None  -> auto-pick best-F1 threshold from PR curve (if probs available),
                 otherwise use 0.5 on binary masks
    bins: GT pixel-count bins for per-tile F1 summaries
    pred_key_preference: try these keys in order to load predictions
                         ('out_prob' preferred, fallback to 'out_npy')
    Returns:
      best: dict(AP, best_F1, best_thr, best_prec, best_rec)
      bybin_mean: {(lo,hi): mean F1}, counts: {(lo,hi): n_tiles}
    """
    all_p, all_t = [], []
    bybin = { (bins[i], bins[i+1]): [] for i in range(len(bins)-1) }
    counts = { k: 0 for k in bybin }

    # First pass: collect concatenated probs (or binaries) for global PR/AP
    cache = []  # keep (gt, pred_probs_or_bin, path) to avoid reloading twice
    for r in results:
        d = np.load(r["file"])
        if "next_fire" not in d.files:
            # skip tiles with no GT
            continue
        gt = d["next_fire"].astype(np.float32)

        # choose available prediction file
        pred_path = None
        for k in pred_key_preference:
            if k in r:
                pred_path = r[k]
                break
        if pred_path is None:
            # fallback to legacy key
            pred_path = r.get("out_npy", None)
        if pred_path is None:
            continue

        pred_arr = np.load(pred_path).astype(np.float32)  # probs or binary
        # If it's a PNG-derived mask you loaded elsewhere, ensure shape match here.

        cache.append((gt, pred_arr, r["file"]))
        all_t.append(gt.ravel())
        all_p.append(pred_arr.ravel())

    if not all_p:
        return {"AP": float("nan"), "best_F1": float("nan"),
                "best_thr": float("nan"), "best_prec": float("nan"), "best_rec": float("nan")}, \
               {k: float("nan") for k in bybin}, {k: 0 for k in counts}

    t = np.concatenate(all_t)
    p = np.concatenate(all_p)

    # Global PR/AP on whatever we have (probs preferred; binaries okay but lower-bound)
    ap = float(average_precision_score(t, p))
    prec, rec, thr = precision_recall_curve(t, p)
    f1_curve = (2*prec*rec)/(prec+rec+1e-8)
    best_idx = int(f1_curve.argmax())

    # Decide threshold to use for per-bin F1
    if threshold is None:
        auto_thr = float(thr[min(best_idx, len(thr)-1)]) if len(thr) > 0 else 0.5
        use_thr = auto_thr
    else:
        use_thr = float(threshold)

    best = {
        "AP": ap,
        "best_F1": float(f1_curve[best_idx]),
        "best_thr": float(thr[min(best_idx, len(thr)-1)]) if len(thr) > 0 else 0.5,
        "best_prec": float(prec[best_idx]),
        "best_rec": float(rec[best_idx]),
        "used_thr_for_bins": use_thr
    }

    # Second pass: per-tile F1 in GT-size bins using chosen threshold
    for gt, pred_arr, _ in cache:
        k = int(gt.sum())
        pred_bin = (pred_arr >= use_thr).astype(np.float32) if pred_arr.max() <= 1.0 and pred_arr.min() < 1.0 else (pred_arr > 0.5).astype(np.float32)
        # (if pred_arr is {0,1}, the expression above still works)

        for lo, hi in bybin.keys():
            if lo <= k < hi:
                tp = (pred_bin*gt).sum()
                fp = (pred_bin*(1-gt)).sum()
                fn = ((1-pred_bin)*gt).sum()
                prec_b = tp/(tp+fp+1e-8); rec_b = tp/(tp+fn+1e-8)
                f1_b   = 2*prec_b*rec_b/(prec_b+rec_b+1e-8)
                bybin[(lo,hi)].append(float(f1_b))
                counts[(lo,hi)] += 1
                break

    bybin_mean = {k: (float(np.mean(v)) if v else float("nan")) for k,v in bybin.items()}
    return best, bybin_mean, counts


In [None]:
best_plain, bins_plain, counts_plain = aggregate_metrics(plain_results, threshold=None)
print("GLOBAL (plain):", best_plain)
print("F1 by GT size:", bins_plain)
print("Counts per bin:", counts_plain)

# If you have probability maps (e.g., saved via threshold=None in infer),
# you can lock a threshold explicitly:
# best_tta, bins_tta, counts_tta = aggregate_metrics(tta_results, threshold=thr_tta)
