In [1]:
!pip -q install torch torchvision pandas scikit-learn tqdm

In [2]:
# Colab-only: mount Google Drive at /content/drive; the dataset paths configured
# further down in this notebook point into that mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import os, json, math, random, glob
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from sklearn.metrics import f1_score


# Default dataset locations on the mounted Google Drive. All constants in this
# block are the default argument values of train_multitask.
DEFAULT_IMAGES_DIR = "/content/drive/MyDrive/blue beacon dataset/images1"
DEFAULT_EFFECTS_CSV = "/content/drive/MyDrive/blue beacon dataset/annotations/Copy of effects.csv - Sheet1 (1).csv"
DEFAULT_CALAM_CSV   = "/content/drive/MyDrive/blue beacon dataset/annotations/Copy of calamity.csv - Sheet1.csv"
OUT_DIR = "out_multitask"  # directory that receives best.pt and labels.json
IMG_SIZE = 256  # images are resized to IMG_SIZE x IMG_SIZE
BATCH = 16  # DataLoader batch size
EPOCHS = 20  # number of training epochs; also T_max of the cosine LR schedule
LR = 3e-4  # AdamW learning rate
VAL_SPLIT = 0.10  # fraction of matched images held out for validation
TEST_SPLIT = 0.00  # fraction held out for the final test pass (0 skips the test pass)
SEED = 42  # seeds the RNGs and the train/val/test split
ALPHA = 1.0  # weight of the effects loss in the total loss
BETA  = 1.0  # weight of the calamities loss in the total loss

# Use the GPU when PyTorch can see one; models and batches are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

def set_seed(s=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) with `s`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(s)


def read_csv(p: str) -> pd.DataFrame:
    """Load an annotation CSV and binarise every label column to 0/1.

    The file must contain an ``image`` column (file names, whitespace-stripped).
    Every other column is treated as a label column. Accepted cell values:

    * booleans and the words true/false, t/f, yes/no, y/n (any case, padded or not),
    * numbers, which count as positive when ``>= 0.5``,
    * blank / NaN cells, which count as 0.

    Args:
        p: Path of the CSV file.

    Returns:
        The DataFrame with ``image`` as str and all label columns as int 0/1.

    Raises:
        ValueError: If the ``image`` column is missing, or a label cell is
            neither numeric nor one of the recognised boolean words.
    """
    df = pd.read_csv(p)
    if "image" not in df.columns:
        raise ValueError(f"{p} must have an 'image' column")
    df["image"] = df["image"].astype(str).str.strip()

    truthy = {"true", "t", "yes", "y"}
    falsy = {"false", "f", "no", "n", "", "nan", "none"}
    for c in df.columns:
        if c == "image":
            continue
        # Work on a normalised text view so "TRUE", " true " and True all look alike.
        text = df[c].astype(str).str.strip().str.lower().fillna("0")
        text = text.mask(text.isin(truthy), "1").mask(text.isin(falsy), "0")
        try:
            values = pd.to_numeric(text)
        except ValueError as exc:
            raise ValueError(
                f"{p}: column {c!r} contains a value that is neither numeric nor boolean"
            ) from exc
        df[c] = (values.astype(float) >= 0.5).astype(int)
    return df

def index_all_images(images_dir: str) -> Dict[str, str]:
    """Map lowercased base filename -> full path for every file under `images_dir`.

    Sub-folders are searched recursively. When the same base name occurs more
    than once, the lexicographically smallest path wins, so the result does not
    depend on directory listing order.

    Args:
        images_dir: Root folder to index. It is taken literally: characters
            such as ``[``, ``*`` or ``?`` in the folder name are escaped before
            globbing instead of being interpreted as wildcards.

    Returns:
        Dict from ``basename.lower()`` to the path of the chosen file.
    """
    pattern = os.path.join(glob.escape(str(Path(images_dir))), "**", "*")
    mapping: Dict[str, str] = {}
    for p in sorted(glob.glob(pattern, recursive=True)):
        if not os.path.isfile(p):
            continue
        # setdefault keeps the first (smallest) path seen for a duplicated name.
        mapping.setdefault(os.path.basename(p).lower(), p)
    return mapping

def make_merged_df(images_dir: str, effects_csv: str, calam_csv: str) -> Tuple[pd.DataFrame, List[str], List[str]]:
    """Join the two annotation CSVs on ``image`` and resolve every image to a file.

    The CSVs are outer-merged, so an image listed in only one of them is kept
    and has NaN in the other CSV's label columns. Rows whose image cannot be
    found under `images_dir` are reported and dropped.

    Args:
        images_dir: Root folder containing the images (sub-folders allowed).
        effects_csv: Path of the effects annotation CSV.
        calam_csv: Path of the calamities annotation CSV.

    Returns:
        ``(df, effect_classes, calam_classes)`` where ``df`` holds the merged
        labels plus an ``image_path`` column, and the two lists name the label
        columns of each task *as they appear in* ``df``. A label name used by
        both CSVs is returned with the ``_e`` / ``_c`` suffix the merge gives it.

    Raises:
        ValueError: If either CSV has no label column besides ``image``.
    """
    df_e = read_csv(effects_csv)
    df_c = read_csv(calam_csv)

    effect_classes  = [c for c in df_e.columns if c != "image"]
    calam_classes   = [c for c in df_c.columns if c != "image"]
    if not effect_classes:  raise ValueError("effects.csv needs at least one effect column besides 'image'")
    if not calam_classes:   raise ValueError("calamities.csv needs at least one calamity column besides 'image'")

    df = pd.merge(df_e, df_c, on="image", how="outer", suffixes=("_e", "_c"))

    # pandas suffixes any label column that exists in BOTH CSVs. Follow that
    # rename here, otherwise later lookups by the old name find nothing and the
    # labels silently degrade to 0.
    shared = set(effect_classes) & set(calam_classes)
    effect_classes = [f"{c}_e" if c in shared else c for c in effect_classes]
    calam_classes = [f"{c}_c" if c in shared else c for c in calam_classes]

    idx = index_all_images(images_dir)
    root = Path(images_dir)
    resolved_paths = []
    missing = []
    for name in df["image"].tolist():
        # 1) exact relative path; is_file() so a directory is never accepted as an image
        cand = root / name
        if cand.is_file():
            resolved_paths.append(str(cand))
            continue

        # 2) case-insensitive match on the base name anywhere under images_dir
        key = os.path.basename(name).lower()
        if key in idx:
            resolved_paths.append(idx[key])
        else:
            resolved_paths.append(None)
            missing.append(name)

    if missing:
        print(f"Warning: {len(missing)} images listed in CSVs not found under {images_dir}. They will be skipped.")
        print("First few missing:", missing[:10])

    df["image_path"] = resolved_paths
    df = df[~df["image_path"].isna()].reset_index(drop=True)
    return df, effect_classes, calam_classes

class MultiTaskDataset(Dataset):
    """Image dataset yielding labels and validity masks for two multi-label tasks.

    Each item is ``(x, yE, mE, yC, mC)``:

    * ``x``  - the normalised image tensor,
    * ``yE`` / ``yC`` - float32 0/1 label vectors for the effect / calamity columns,
    * ``mE`` / ``mC`` - float32 masks, 1 where the label is known and 0 where the
      cell is NaN or the column is absent (e.g. the image appears in only one
      annotation CSV). Masked positions carry label 0 and are meant to be
      excluded from losses and metrics.
    """

    def __init__(self, df: pd.DataFrame, effect_cols: List[str], calam_cols: List[str], img_size=256, augment=True):
        """
        Args:
            df: Rows with an ``image_path`` column plus the label columns.
            effect_cols: Names of the effect label columns.
            calam_cols: Names of the calamity label columns.
            img_size: Images are resized to ``img_size x img_size``.
            augment: Add colour jitter, horizontal flip and small rotations
                (training); otherwise only resize + normalise.
        """
        self.df = df
        self.effect_cols = effect_cols
        self.calam_cols = calam_cols
        if augment:
            self.tf = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ColorJitter(0.2,0.2,0.2,0.1),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(10),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
            ])
        else:
            self.tf = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
            ])

    def __len__(self): return len(self.df)

    @staticmethod
    def _labels_and_mask(row: pd.Series, cols: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (labels, mask) float32 tensors for `cols`; unknown labels get label 0 and mask 0."""
        labels, mask = [], []
        for c in cols:
            v = row.get(c, np.nan)
            known = not pd.isna(v)
            labels.append(int(v) if known else 0)
            mask.append(1.0 if known else 0.0)
        return (
            torch.tensor(labels, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.float32),
        )

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Context manager closes the file handle as soon as the pixels are decoded.
        with Image.open(row["image_path"]) as im:
            img = im.convert("RGB")
        x = self.tf(img)

        yE, mE = self._labels_and_mask(row, self.effect_cols)
        yC, mC = self._labels_and_mask(row, self.calam_cols)

        return x, yE, mE, yC, mC

# Model: one backbone, two heads
class MultiTaskEffB0(nn.Module):
    """EfficientNet-B0 feature extractor shared by two linear multi-label heads.

    ``forward`` returns raw logits ``(effects, calamities)``; apply a sigmoid
    to obtain per-class probabilities.
    """

    def __init__(self, n_effects: int, n_calam: int, pretrained: bool = True):
        """
        Args:
            n_effects: Number of outputs of the effects head.
            n_calam: Number of outputs of the calamities head.
            pretrained: Initialise the backbone from ImageNet weights. Pass
                False when a checkpoint is loaded right afterwards, which
                avoids the weight download.
        """
        super().__init__()
        weights = None
        if pretrained:
            # torchvision builds without the weights enum fall back to random init;
            # any other error (e.g. a failed download) is no longer swallowed.
            weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
            if weights_enum is not None:
                weights = weights_enum.IMAGENET1K_V1
        self.backbone = models.efficientnet_b0(weights=weights)
        in_feat = self.backbone.classifier[1].in_features
        # The stock classifier is unused; both task heads below replace it.
        self.backbone.classifier = nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(0.2)
        self.head_effects = nn.Linear(in_feat, n_effects)
        self.head_calam   = nn.Linear(in_feat, n_calam)

    def forward(self, x):
        """Return ``(effect_logits, calamity_logits)`` for an image batch `x`."""
        feats = self.backbone.features(x)
        feats = self.pool(feats).flatten(1)
        feats = self.dropout(feats)
        outE = self.head_effects(feats)
        outC = self.head_calam(feats)
        return outE, outC


def compute_pos_weight(loader, head_idx: int, n_classes: int):
    """Per-class ``neg / pos`` ratio for ``BCEWithLogitsLoss(pos_weight=...)``.

    Args:
        loader: Iterable of batches; ``batch[head_idx]`` holds the 0/1 labels
            and ``batch[head_idx + 1]`` the matching validity mask, both with
            one column per class.
        head_idx: Position of the label tensor inside each batch.
        n_classes: Number of classes (columns) of that head.

    Returns:
        1-D CPU tensor of length `n_classes`. Positive and negative counts are
        each floored at 1 so the ratio is finite and non-zero.
    """
    pos = torch.zeros(n_classes)
    tot = torch.zeros(n_classes)
    with torch.no_grad():
        for batch in loader:
            y = batch[head_idx]
            m = batch[head_idx + 1]
            pos += (y * m).sum(dim=0).cpu()
            tot += m.sum(dim=0).cpu()
    # Derive the negative count from the TRUE positive count; clamping `pos`
    # first would under-count negatives by one for classes with no positives.
    neg = torch.clamp(tot - pos, min=1.0)
    pos = torch.clamp(pos, min=1.0)
    return neg / pos


def _masked_f1_per_class(y_true: np.ndarray, mask: np.ndarray, y_pred: np.ndarray) -> List[float]:
    """Binary F1 of every class column, computed only on rows whose mask entry is True.

    Columns without a single labelled row are skipped, so the result can be
    shorter than the number of classes.
    """
    scores = []
    for j in range(y_true.shape[1]):
        rows = mask[:, j]
        if rows.sum() == 0:
            continue
        # zero_division=0 yields the same 0.0 score as the default for a class
        # with no positives, but without one warning per class per epoch.
        scores.append(f1_score(y_true[rows, j], y_pred[rows, j], zero_division=0))
    return scores


def _evaluate_heads(model: nn.Module, loader) -> Tuple[List[float], List[float]]:
    """Run `model` over `loader` and return per-class F1 lists ``(effects, calamities)``.

    Probabilities are thresholded at 0.5. Both lists are empty when the loader
    yields no batch.
    """
    model.eval()
    yE_all, mE_all, pE_all = [], [], []
    yC_all, mC_all, pC_all = [], [], []
    with torch.no_grad():
        for x, yE, mE, yC, mC in loader:
            oE, oC = model(x.to(device))
            yE_all.append(yE.numpy()); mE_all.append(mE.numpy())
            yC_all.append(yC.numpy()); mC_all.append(mC.numpy())
            pE_all.append(torch.sigmoid(oE).cpu().numpy())
            pC_all.append(torch.sigmoid(oC).cpu().numpy())

    if not pE_all:
        return [], []

    f1_e = _masked_f1_per_class(
        np.concatenate(yE_all, axis=0),
        np.concatenate(mE_all, axis=0).astype(bool),
        (np.concatenate(pE_all, axis=0) >= 0.5).astype(int),
    )
    f1_c = _masked_f1_per_class(
        np.concatenate(yC_all, axis=0),
        np.concatenate(mC_all, axis=0).astype(bool),
        (np.concatenate(pC_all, axis=0) >= 0.5).astype(int),
    )
    return f1_e, f1_c


def train_multitask(
    images_dir=DEFAULT_IMAGES_DIR,
    effects_csv=DEFAULT_EFFECTS_CSV,
    calam_csv=DEFAULT_CALAM_CSV,
    out_dir=OUT_DIR,
    img_size=IMG_SIZE,
    batch_size=BATCH,
    epochs=EPOCHS,
    lr=LR,
    val_split=VAL_SPLIT,
    test_split=TEST_SPLIT,
    alpha=ALPHA,
    beta=BETA,
    seed=SEED
):
    """Train the two-head EfficientNet-B0 and keep the best checkpoint.

    Pipeline: merge the annotation CSVs, split the rows into train/val/test,
    train with masked, class-weighted BCE on both heads, score the validation
    split after every epoch (mean of the two macro-F1 values) and save the
    model whenever that score improves. If a test split exists, the best
    checkpoint is reloaded and scored on it at the end.

    Args:
        images_dir: Root folder with the images.
        effects_csv: Path of the effects annotation CSV.
        calam_csv: Path of the calamities annotation CSV.
        out_dir: Folder that receives ``best.pt`` and ``labels.json``.
        img_size: Side length images are resized to.
        batch_size: Batch size of all data loaders.
        epochs: Number of training epochs (also the cosine schedule length).
        lr: AdamW learning rate.
        val_split: Fraction of rows used for validation.
        test_split: Fraction of rows used for the final test pass.
        alpha: Weight of the effects loss.
        beta: Weight of the calamities loss.
        seed: Seed for all RNGs and for the split.

    Returns:
        Path of the checkpoint file (``<out_dir>/best.pt``).

    Raises:
        ValueError: If no annotated image could be matched to a file on disk.
    """
    set_seed(seed)

    df_all, effect_cols, calam_cols = make_merged_df(images_dir, effects_csv, calam_csv)
    n = len(df_all)
    if n == 0:
        raise ValueError("No images found after matching CSVs to disk. Check names/case/paths.")

    # Split sizes; if rounding leaves no training rows, take them back from val, then test.
    n_test = int(round(n * test_split))
    n_val  = int(round(n * val_split))
    n_train = n - n_val - n_test
    if n_train < 1:
        take = 1 - n_train
        take_from_val = min(take, n_val); n_val -= take_from_val; take -= take_from_val
        take_from_test = min(take, n_test); n_test -= take_from_test
        n_train = n - n_val - n_test
    print(f"Split -> train:{n_train} val:{n_val} test:{n_test}")

    has_val = n_val > 0
    if not has_val:
        # Without validation data there is no score to select on.
        print("Warning: validation split is empty; the checkpoint of the latest epoch is kept.")

    # random_split is only used to draw the row indices; the datasets are built from the row subsets.
    dstrain, dsval, dstest = random_split(df_all, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(seed))
    ds_tr = MultiTaskDataset(dstrain.dataset.iloc[dstrain.indices], effect_cols, calam_cols, img_size, augment=True)
    ds_va = MultiTaskDataset(dsval.dataset.iloc[dsval.indices], effect_cols, calam_cols, img_size, augment=False)
    ds_te = MultiTaskDataset(dstest.dataset.iloc[dstest.indices], effect_cols, calam_cols, img_size, augment=False) if n_test>0 else None

    tr_loader = DataLoader(ds_tr, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
    va_loader = DataLoader(ds_va, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    te_loader = DataLoader(ds_te, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) if ds_te else None

    model = MultiTaskEffB0(n_effects=len(effect_cols), n_calam=len(calam_cols)).to(device)

    # Class-imbalance weights; batch layout is (x, yE, mE, yC, mC), hence indices 1 and 3.
    posw_e = compute_pos_weight(tr_loader, head_idx=1, n_classes=len(effect_cols)).to(device)
    posw_c = compute_pos_weight(tr_loader, head_idx=3, n_classes=len(calam_cols)).to(device)
    # reduction="none" so the per-element loss can be multiplied by the mask.
    crit_e = nn.BCEWithLogitsLoss(pos_weight=posw_e, reduction="none")
    crit_c = nn.BCEWithLogitsLoss(pos_weight=posw_c, reduction="none")

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    Path(out_dir).mkdir(parents=True, exist_ok=True)
    best_macro = -1.0
    best_path = str(Path(out_dir)/"best.pt")

    for ep in range(1, epochs+1):
        # ---- train ----
        model.train(); run_loss=0.0; seen=0
        for x, yE, mE, yC, mC in tqdm(tr_loader, desc=f"Epoch {ep}/{epochs}"):
            x, yE, mE, yC, mC = x.to(device), yE.to(device), mE.to(device), yC.to(device), mC.to(device)
            opt.zero_grad()
            oE, oC = model(x)
            # Mean over the labelled entries only.
            lossE = (crit_e(oE, yE) * mE).sum() / torch.clamp(mE.sum(), min=1.0)
            lossC = (crit_c(oC, yC) * mC).sum() / torch.clamp(mC.sum(), min=1.0)
            loss = alpha*lossE + beta*lossC
            loss.backward(); opt.step()
            run_loss += loss.item() * x.size(0); seen += x.size(0)
        sch.step()
        train_loss = run_loss / max(seen,1)

        # ---- validate ----
        if has_val:
            f1_e, f1_c = _evaluate_heads(model, va_loader)
            macro_e = float(np.mean(f1_e)) if f1_e else 0.0
            macro_c = float(np.mean(f1_c)) if f1_c else 0.0
            macro_avg = (macro_e + macro_c) / 2.0
        else:
            macro_e = macro_c = macro_avg = float("nan")

        print(f"Epoch {ep}: train_loss={train_loss:.4f} | val_macroE={macro_e:.3f} val_macroC={macro_c:.3f} macro_avg={macro_avg:.3f}")

        if not has_val or macro_avg > best_macro:
            if has_val:
                best_macro = macro_avg
            torch.save({
                "state_dict": model.state_dict(),
                "effect_classes": effect_cols,
                "calamity_classes": calam_cols,
                "img_size": img_size
            }, best_path)
            with open(Path(out_dir)/"labels.json","w") as f:
                json.dump({"effects":effect_cols, "calamities":calam_cols}, f, indent=2)
            print(" Saved best to", best_path)

    # ---- test (only when a test split was requested) ----
    if te_loader is not None:
        print("\n== TEST ==")
        ckpt = torch.load(best_path, map_location=device)
        model.load_state_dict(ckpt["state_dict"])
        f1_e, f1_c = _evaluate_heads(model, te_loader)
        print("Effects per-class F1:", [round(x,3) for x in f1_e], "Macro:", round(float(np.mean(f1_e)) if f1_e else 0.0,3))
        print("Calamities per-class F1:", [round(x,3) for x in f1_c], "Macro:", round(float(np.mean(f1_c)) if f1_c else 0.0,3))

    print("\nDone. Artifacts in:", out_dir)
    return best_path


@torch.no_grad()
def predict_image(img_path: str, ckpt_path: str):
    """Score one image with a checkpoint written by train_multitask.

    Args:
        img_path: Path of the image file.
        ckpt_path: Path of the checkpoint; it must hold the keys
            ``state_dict``, ``effect_classes``, ``calamity_classes`` and
            ``img_size``.

    Returns:
        ``{"effects": [(label, prob), ...], "calamities": [(label, prob), ...]}``
        with sigmoid probabilities, in the label order stored in the checkpoint.
    """
    ckpt = torch.load(ckpt_path, map_location=device)
    eff = ckpt["effect_classes"]; cal = ckpt["calamity_classes"]; size = ckpt["img_size"]
    model = MultiTaskEffB0(n_effects=len(eff), n_calam=len(cal)).to(device)
    model.load_state_dict(ckpt["state_dict"]); model.eval()

    # Same resize + normalisation as the non-augmented dataset transform.
    tf = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    # Context manager releases the file handle once the image is decoded.
    with Image.open(img_path) as im:
        x = tf(im.convert("RGB")).unsqueeze(0).to(device)
    oE, oC = model(x)
    pE = torch.sigmoid(oE).squeeze(0).cpu().numpy()
    pC = torch.sigmoid(oC).squeeze(0).cpu().numpy()
    return dict(effects=list(zip(eff, pE.tolist())), calamities=list(zip(cal, pC.tolist())))


Device: cuda


In [6]:
# Run training. Every argument is spelled out and equals the corresponding
# module-level default; the returned value is the checkpoint path.
best_ckpt = train_multitask(
    images_dir="/content/drive/MyDrive/blue beacon dataset/images1",
    effects_csv="/content/drive/MyDrive/blue beacon dataset/annotations/Copy of effects.csv - Sheet1 (1).csv",
    calam_csv="/content/drive/MyDrive/blue beacon dataset/annotations/Copy of calamity.csv - Sheet1.csv",
    out_dir="out_multitask",
    img_size=256,
    batch_size=16,
    epochs=20,
    lr=3e-4,
    val_split=0.10,
    test_split=0.00,  # no test pass; only train and validation splits are used
    alpha=1.0,
    beta=1.0,
    seed=42
)


Split -> train:13 val:2 test:0
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 215MB/s]
Epoch 1/20: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1: train_loss=1.8239 | val_macroE=0.222 val_macroC=0.333 macro_avg=0.278
✅ Saved best to out_multitask/best.pt


Epoch 2/20: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2: train_loss=1.6962 | val_macroE=0.222 val_macroC=0.333 macro_avg=0.278


Epoch 3/20: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 3: train_loss=1.5634 | val_macroE=0.444 val_macroC=0.333 macro_avg=0.389
✅ Saved best to out_multitask/best.pt


Epoch 4/20: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 4: train_loss=1.4029 | val_macroE=0.556 val_macroC=0.667 macro_avg=0.611
✅ Saved best to out_multitask/best.pt


Epoch 5/20: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 5: train_loss=1.3287 | val_macroE=0.556 val_macroC=0.667 macro_avg=0.611


Epoch 6/20: 100%|██████████| 1/1 [00:00<00:00,  1.40it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 6: train_loss=1.1892 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 7/20: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 7: train_loss=1.1103 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 8/20: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 8: train_loss=1.0825 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 9/20: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 9: train_loss=0.9442 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 10/20: 100%|██████████| 1/1 [00:00<00:00,  1.49it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 10: train_loss=0.8758 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 11/20: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 11: train_loss=0.8722 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 12/20: 100%|██████████| 1/1 [00:00<00:00,  1.52it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 12: train_loss=0.8444 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 13/20: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 13: train_loss=0.7834 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 14/20: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 14: train_loss=0.7358 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 15/20: 100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 15: train_loss=0.7316 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 16/20: 100%|██████████| 1/1 [00:00<00:00,  1.15it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 16: train_loss=0.6928 | val_macroE=0.556 val_macroC=0.556 macro_avg=0.556


Epoch 17/20: 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 17: train_loss=0.6681 | val_macroE=0.667 val_macroC=0.556 macro_avg=0.611


Epoch 18/20: 100%|██████████| 1/1 [00:00<00:00,  1.12it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 18: train_loss=0.6410 | val_macroE=0.667 val_macroC=0.556 macro_avg=0.611


Epoch 19/20: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 19: train_loss=0.6596 | val_macroE=0.667 val_macroC=0.556 macro_avg=0.611


Epoch 20/20: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 20: train_loss=0.6701 | val_macroE=0.667 val_macroC=0.667 macro_avg=0.667
✅ Saved best to out_multitask/best.pt

Done. Artifacts in: out_multitask


In [7]:

import torch, json
from PIL import Image
from torchvision import transforms, models
import torch.nn as nn

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Checkpoint to load and the single image to score.
# NOTE(review): presumably best.pt is the file written by the training cell, whose
# relative out_dir resolves against Colab's /content working directory - confirm.
CKPT_PATH = "/content/out_multitask/best.pt"
IMG_PATH  = "/content/drive/MyDrive/blue beacon dataset/images1/oilspill_1.jpg"


class MultiTaskEffB0(nn.Module):
    """EfficientNet-B0 feature extractor shared by two linear multi-label heads.

    Re-declared in this cell so inference can run without the training cell.
    The attribute names must stay identical to the class used for training,
    otherwise the checkpoint's ``state_dict`` will not load.
    """

    def __init__(self, n_effects: int, n_calam: int, pretrained: bool = True):
        """
        Args:
            n_effects: Number of outputs of the effects head.
            n_calam: Number of outputs of the calamities head.
            pretrained: Initialise the backbone from ImageNet weights. Pass
                False when a checkpoint is loaded right afterwards, which
                avoids the weight download.
        """
        super().__init__()
        weights = None
        if pretrained:
            # torchvision builds without the weights enum fall back to random init;
            # any other error (e.g. a failed download) is no longer swallowed.
            weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
            if weights_enum is not None:
                weights = weights_enum.IMAGENET1K_V1
        self.backbone = models.efficientnet_b0(weights=weights)
        in_feat = self.backbone.classifier[1].in_features
        # The stock classifier is unused; both task heads below replace it.
        self.backbone.classifier = nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(0.2)
        self.head_effects = nn.Linear(in_feat, n_effects)
        self.head_calam   = nn.Linear(in_feat, n_calam)

    def forward(self, x):
        """Return raw logits ``(effects, calamities)`` for an image batch `x`."""
        feats = self.backbone.features(x)
        feats = self.pool(feats).flatten(1)
        feats = self.dropout(feats)
        return self.head_effects(feats), self.head_calam(feats)


# Load the checkpoint; it carries the weights plus the label names of both heads.
ckpt = torch.load(CKPT_PATH, map_location=device)
eff_labels = ckpt["effect_classes"]
cal_labels = ckpt["calamity_classes"]
img_size   = ckpt.get("img_size", 256)  # falls back to 256 if the key is absent

# Head sizes are derived from the stored label lists so the state_dict shapes match.
model = MultiTaskEffB0(n_effects=len(eff_labels), n_calam=len(cal_labels)).to(device)
model.load_state_dict(ckpt["state_dict"])
model.eval()  # disables dropout for inference


# Deterministic preprocessing: resize + ImageNet mean/std normalisation, no augmentation.
tf = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

im = Image.open(IMG_PATH).convert("RGB")
x = tf(im).unsqueeze(0).to(device)  # add the batch dimension

# Forward pass without gradients; sigmoid turns each logit into an independent probability.
with torch.no_grad():
    outE, outC = model(x)
    pE = torch.sigmoid(outE).squeeze(0).cpu().numpy()
    pC = torch.sigmoid(outC).squeeze(0).cpu().numpy()


def topk(names, probs, k=5):
    """Return at most `k` ``(name, probability)`` pairs, highest probability first.

    Entries with equal probability keep their original relative order.
    """
    ranked = sorted(enumerate(names), key=lambda pair: probs[pair[0]], reverse=True)
    return [(label, float(probs[pos])) for pos, label in ranked[:k]]

# Show the (at most five) highest-scoring labels of each head.
print("\nTop Effects:")
for name, p in topk(eff_labels, pE, k=min(5, len(eff_labels))):
    print(f"  {name:24s}  {p:.3f}")

print("\nTop Calamities:")
for name, p in topk(cal_labels, pC, k=min(5, len(cal_labels))):
    print(f"  {name:24s}  {p:.3f}")


# Per-calamity acceptance thresholds: 0.50 by default, stricter for the labels
# named below (matched case-insensitively).
thresholds = {lbl: 0.50 for lbl in cal_labels}
for k in thresholds:
    if k.lower() in ["tsunami"]: thresholds[k] = 0.80
    if k.lower() in ["oil_spill","ship_boat_wreckage"]: thresholds[k] = 0.70

# Keep the calamities whose probability reaches their threshold; abstain when none does.
accepted_calams = [(cal_labels[i], float(pC[i])) for i in range(len(cal_labels)) if pC[i] >= thresholds[cal_labels[i]]]
abstain = (len(accepted_calams) == 0)

# Result dict built from plain floats: all probabilities of both heads, plus the
# accepted calamities sorted by descending probability and the abstain flag.
result = {
    "effects": [{"label": eff_labels[i], "p": float(pE[i])} for i in range(len(eff_labels))],
    "calamities": [{"label": cal_labels[i], "p": float(pC[i])} for i in range(len(cal_labels))],
    "accepted_calamities": sorted(accepted_calams, key=lambda x: x[1], reverse=True),
    "abstain": abstain
}

print("\nAccepted calamities (thresholded):", result["accepted_calamities"])
print("Abstain on cause?:", result["abstain"])


Device: cuda

Top Effects:
  oil_sheen                 0.860
  debris_on_beach           0.353
  flooded_area              0.292

Top Calamities:
  oil_spill                 0.847
  debris                    0.306
  floods                    0.198

Accepted calamities (thresholded): [('oil_spill', 0.8465505242347717)]
Abstain on cause?: False
