<a href="https://www.kaggle.com/code/poorvaahuja/camourflage-improvement-research?scriptVersionId=270513988" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm

In [4]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

Device: cuda


In [6]:
IMG_SIZE = 224
BATCH_SIZE = 8          # adjust if OOM
EPOCHS = 20
NUM_WORKERS = 0         # set 0 if worker issues on Kaggle
LR = 3e-4
LABEL_SMOOTH = 0.1
SAVE_PATH = "best_model.pth"
USE_SEGMENTATION = True

# Loss weights from PDF suggestion
ALPHA_DOM = 0.5
BETA_SUPCON = 0.2
ETA_CONS = 0.1

# Mixup/CutMix probabilities and alphas
PROB_MIXUP = 0.5
PROB_CUTMIX = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# warmup epochs
WARMUP_EPOCHS = 5

# early stopping
EARLY_STOPPING_PATIENCE = 8
FREEZE_EPOCHS = 10

In [7]:
info_dir  = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir = "/kaggle/input/cod10k/COD10K-v3/Train"
test_dir  = "/kaggle/input/cod10k/COD10K-v3/Test"

# these exist in Info/
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

In [8]:
# def merge_txts(txt_list, root_dir, transform, use_masks=True):
#     samples = []
#     for txt_file, label in txt_list:
#         with open(txt_file, "r") as f:
#             for line in f:
#                 fname = line.strip().split()[0]
#                 samples.append((fname, label))
#     return COD10KDataset(root_dir, txt_file=None, transform=transform, use_masks=use_masks)

Noise + Transform

In [9]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])


In [10]:
class COD10KDataset(Dataset):
    """
    Parses lines of the form:
      <filename> <label>
    Returns (weak_image, strong_image, label, mask_or_none)
    """
    def __init__(self, root_dir, txt_file, weak_transform=None, strong_transform=None, use_masks=True):
        self.root_dir = root_dir
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.use_masks = use_masks
        self.samples = []

        if not os.path.exists(txt_file):
            raise RuntimeError(f"TXT file not found: {txt_file}")

        with open(txt_file, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 0:
                    continue
                if len(parts) >= 2:
                    fname = parts[0]
                    try:
                        lbl = int(parts[1])
                    except:
                        lbl = 1 if "CAM" in fname else 0
                else:
                    fname = parts[0]
                    lbl = 1 if "CAM" in parts[0] else 0
                img_path = os.path.join(self.root_dir, "Image", fname)
                if os.path.exists(img_path):
                    self.samples.append((fname, lbl))
                else:
                    pass

        if len(self.samples) == 0:
            raise RuntimeError(f"No samples found. Check {txt_file} and the Image folder under {root_dir}.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        fname, lbl = self.samples[idx]
        img_path = os.path.join(self.root_dir, "Image", fname)
        img = Image.open(img_path).convert("RGB")

        if self.weak_transform:
            weak = self.weak_transform(img)
        else:
            weak = transforms.ToTensor()(img)
        if self.strong_transform:
            strong = self.strong_transform(img)
        else:
            strong = weak.clone()

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(fname)[0] + ".png"
            mask_path = os.path.join(self.root_dir, "GT_Object", mask_name)
            if os.path.exists(mask_path):
                m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                m = np.array(m).astype(np.float32) / 255.0
                mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)

        return weak, strong, lbl, mask

# helper to build weighted sampler
def build_weighted_sampler(dataset):
    labels = [lbl for (_, lbl) in dataset.samples]
    counts = Counter(labels)
    total = len(labels)
    class_weights = {c: total / (counts[c] * len(counts)) for c in counts}
    weights = [class_weights[lbl] for lbl in labels]
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    return sampler

In [11]:
train_ds = COD10KDataset(train_dir, train_cam_txt, weak_transform=weak_tf, strong_transform=strong_tf, use_masks=USE_SEGMENTATION)
val_ds   = COD10KDataset(test_dir,  test_cam_txt,  weak_transform=val_tf, strong_transform=None, use_masks=USE_SEGMENTATION)

train_sampler = build_weighted_sampler(train_ds)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
print("Train samples:", len(train_ds), "Val samples:", len(val_ds))

Train samples: 3038 Val samples: 2026


## Backbones

In [12]:
class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.densenet201(pretrained=pretrained).features
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in ["denseblock1","denseblock2","denseblock3","denseblock4"]:
                feats.append(x)
        return feats


class MobileNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.mobilenet_v3_large(pretrained=pretrained).features
    def forward(self, x):
        feats = []
        out = x
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in (2,5,9,12):
                feats.append(out)
        if len(feats) < 4:
            feats.append(out)
        return feats

In [13]:
class SwinExtractor(nn.Module):
    def __init__(self, model_name="swin_tiny_patch4_window7_224", pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=True)
    def forward(self, x):
        return self.model(x)

In [14]:
class CBAMlite(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, max(channels//reduction,4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(channels//reduction,4), channels, 1),
            nn.Sigmoid()
        )
        self.spatial = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 1, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return x * self.se(x) * self.spatial(x)


In [15]:
class GatedFusion(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, max(dim//4, 4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(dim//4,4), dim, 1),
            nn.Sigmoid()
        )
    def forward(self, H, X):
        if H.shape[2:] != X.shape[2:]:
            X = F.interpolate(X, size=H.shape[2:], mode='bilinear', align_corners=False)
        g = self.g_fc(H)
        return g * H + (1 - g) * X

In [16]:
class CrossAttention(nn.Module):
    def __init__(self, d_cnn, d_swin, d_out):
        super().__init__()
        self.q = nn.Linear(d_cnn, d_out)
        self.k = nn.Linear(d_swin, d_out)
        self.v = nn.Linear(d_swin, d_out)
        self.scale = d_out ** -0.5
    def forward(self, feat_cnn, feat_swin):
        B, Cc, H, W = feat_cnn.shape
        q = feat_cnn.permute(0,2,3,1).reshape(B, H*W, Cc)
        if feat_swin.dim() == 4:
            Bs, Cs, Hs, Ws = feat_swin.shape
            kv = feat_swin.permute(0,2,3,1).reshape(Bs, Hs*Ws, Cs)
        else:
            kv = feat_swin
        K = self.k(kv)
        V = self.v(kv)
        Q = self.q(q)
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, V)
        out = out.reshape(B, H, W, -1).permute(0,3,1,2)
        return out


## Segmentation Decoder

In [17]:
class SegDecoder(nn.Module):
    def __init__(self, in_channels_list, mid_channels=128):
        super().__init__()
        self.projs = nn.ModuleList([nn.Conv2d(c, mid_channels, 1) for c in in_channels_list])
        self.conv = nn.Sequential(nn.Conv2d(mid_channels * len(in_channels_list), mid_channels, 3, padding=1), nn.ReLU(inplace=True))
        self.out = nn.Conv2d(mid_channels, 1, 1)
    def forward(self, feat_list):
        target_size = feat_list[0].shape[2:]
        ups = []
        for f, p in zip(feat_list, self.projs):
            x = p(f)
            if x.shape[2:] != target_size:
                x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)
            ups.append(x)
        x = torch.cat(ups, dim=1)
        x = self.conv(x)
        x = self.out(x)
        return x

## Probing Backbones

In [18]:
dnet = DenseNetExtractor().to(device).eval()
mnet = MobileNetExtractor().to(device).eval()
snet = SwinExtractor().to(device).eval()
with torch.no_grad():
    dummy = torch.randn(1,3,IMG_SIZE,IMG_SIZE).to(device)
    featsA = dnet(dummy)
    featsB = mnet(dummy)
    featsS = snet(dummy)
chA = [f.shape[1] for f in featsA]
chB = [f.shape[1] for f in featsB]
chS = [f.shape[1] for f in featsS]
print("DenseNet channels:", chA)
print("MobileNet channels:", chB)
print("Swin channels:", chS)

Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /root/.cache/torch/hub/checkpoints/densenet201-c1103571.pth
100%|██████████| 77.4M/77.4M [00:00<00:00, 191MB/s] 
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 134MB/s] 


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

DenseNet channels: [256, 512, 1792, 1920]
MobileNet channels: [24, 40, 80, 112]
Swin channels: [56, 28, 14, 7]


# Fusion Model (DenseNet + MobileNet + Swin cross attention)

In [19]:
class FusionWithSwin(nn.Module):
    def __init__(self, dense_chs, mobile_chs, swin_chs, d=256, use_seg=True, num_classes=2):
        super().__init__()
        self.backA = DenseNetExtractor()
        self.backB = MobileNetExtractor()
        self.backS = SwinExtractor()
        L = min(len(dense_chs), len(mobile_chs), len(swin_chs))
        self.L = L
        self.d = d
        self.alignA = nn.ModuleList([nn.Conv2d(c, d, 1) for c in dense_chs[:L]])
        self.alignB = nn.ModuleList([nn.Conv2d(c, d, 1) for c in mobile_chs[:L]])
        self.cbamA = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.cbamB = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.gates = nn.ModuleList([GatedFusion(d) for _ in range(L)])
        self.cross_atts = nn.ModuleList([CrossAttention(d, swin_chs[i], d) for i in range(L)])
        self.reduce = nn.Conv2d(d * L, d, 1)
        self.classifier = nn.Sequential(
            nn.Linear(d, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        self.use_seg = use_seg
        if self.use_seg:
            self.segdecoder = SegDecoder([d] * L, mid_channels=128)

        # Domain head for DANN (simple MLP)
        self.domain_head = nn.Sequential(
            nn.Linear(d, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

    def forward(self, x, grl_lambda=0.0):
        fa = self.backA(x)
        fb = self.backB(x)
        fs = self.backS(x)
        fused_feats = []
        aligned_for_dec = []
        for i in range(self.L):
            a = self.alignA[i](fa[i])
            a = self.cbamA[i](a)
            b = self.alignB[i](fb[i])
            b = self.cbamB[i](b)
            if b.shape[2:] != a.shape[2:]:
                b = F.interpolate(b, size=a.shape[2:], mode='bilinear', align_corners=False)
            fused = self.gates[i](a, b)
            swin_feat = fs[i]
            swin_att = self.cross_atts[i](fused, swin_feat)
            if swin_att.shape[2:] != fused.shape[2:]:
                swin_att = F.interpolate(swin_att, size=fused.shape[2:], mode='bilinear', align_corners=False)
            fused = fused + swin_att
            fused_feats.append(fused)
            aligned_for_dec.append(fused)
        target = fused_feats[-1]
        upsampled = [F.interpolate(f, size=target.shape[2:], mode='bilinear', align_corners=False) if f.shape[2:] != target.shape[2:] else f for f in fused_feats]
        concat = torch.cat(upsampled, dim=1)
        fused = self.reduce(concat)
        z = F.adaptive_avg_pool2d(fused, (1,1)).view(fused.size(0), -1)
        logits = self.classifier(z)
        out = {"logits": logits, "feat": z}
        if self.use_seg:
            out["seg"] = self.segdecoder(aligned_for_dec)

        # Domain prediction with GRL effect applied by multiplying lambda and reversing sign in custom grad fn
        if grl_lambda > 0.0:
            # GRL implemented outside (we'll pass z through GRL function)
            pass
        out["domain_logits"] = self.domain_head(z)
        return out

# instantiate model
model = FusionWithSwin(dense_chs=chA, mobile_chs=chB, swin_chs=chS, d=256, use_seg=USE_SEGMENTATION, num_classes=2).to(device)
print("Model parameters (M):", sum(p.numel() for p in model.parameters())/1e6)

Model parameters (M): 51.586615


In [20]:
class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5):
        super().__init__()
        self.gamma = gamma
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        ce = F.cross_entropy(logits, target, reduction='none')
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

def dice_loss_logits(pred_logits, target):
    pred = torch.sigmoid(pred_logits)
    target = target.float()
    inter = (pred * target).sum(dim=(1,2,3))
    denom = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3))
    dice = (2 * inter + 1e-6) / (denom + 1e-6)
    return 1.0 - dice.mean()

clf_loss_ce = LabelSmoothingCE(LABEL_SMOOTH)
clf_loss_focal = FocalLoss(gamma=1.5)
seg_bce = nn.BCEWithLogitsLoss()

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)


In [21]:
#Supervised contrastive Loss
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, features, labels):
        # features: [N, D], labels: [N]
        device = features.device
        f = F.normalize(features, dim=1)
        sim = torch.matmul(f, f.T) / self.temperature  # [N,N]
        labels = labels.contiguous().view(-1,1)
        mask = torch.eq(labels, labels.T).float().to(device)
        # remove diagonal
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(len(features), device=device))
        denom = exp_logits.sum(1, keepdim=True)
        # for each i, positive samples are where mask==1 (excluding self)
        pos_mask = mask - torch.eye(len(features), device=device)
        pos_exp = (exp_logits * pos_mask).sum(1)
        # avoid divide by zero
        loss = -torch.log((pos_exp + 1e-8) / (denom + 1e-8) + 1e-12)
        # average only across anchors that have positives
        valid = (pos_mask.sum(1) > 0).float()
        loss = (loss * valid).sum() / (valid.sum() + 1e-8)
        return loss
supcon_loss_fn = SupConLoss(temperature=0.07)

In [22]:
# Domain Adversarial: Gradient Reversal Layer (GRL)

from torch.autograd import Function
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, l):
        ctx.l = l
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.l, None

def grad_reverse(x, l=1.0):
    return GradReverse.apply(x, l)

In [23]:
# Optimizer + scheduler + mixed precision + clipping
# -----------------------------
# param groups: smaller LR for backbones, larger for heads
backbone_params = []
head_params = []
for name, param in model.named_parameters():
    if any(k in name for k in ['backA', 'backB', 'backS']):  # backbone names
        backbone_params.append(param)
    else:
        head_params.append(param)

opt = torch.optim.AdamW([
    {'params': backbone_params, 'lr': LR * 0.2},
    {'params': head_params, 'lr': LR}
], lr=LR, weight_decay=1e-4)

# warmup + cosine schedule
def get_cosine_with_warmup_scheduler(optimizer, warmup_epochs, total_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1.0, warmup_epochs))
        # cosine from warmup -> total
        t = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

scheduler = get_cosine_with_warmup_scheduler(opt, WARMUP_EPOCHS, EPOCHS)

scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

# -----------------------------
# Mixup & CutMix helpers
# -----------------------------
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)   # use builtin int
    cut_h = int(H * cut_rat)   # use builtin int

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup(x, y, alpha=MIXUP_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def apply_cutmix(x, y, alpha=CUTMIX_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    new_x = x.clone()
    new_x[:, :, bby1:bby2, bbx1:bbx2] = x[idx, :, bby1:bby2, bbx1:bbx2]
    lam_adjusted = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return new_x, y, y[idx], lam_adjusted


  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))


## Training

In [24]:
best_vf1 = 0.0
best_epoch = 0
patience_count = 0

def compute_combined_clf_loss(logits, targets, mix_info=None, use_focal=False):
    # mix_info: (mode, y_a, y_b, lam) or None
    if mix_info is None:
        if use_focal:
            return clf_loss_focal(logits, targets)
        else:
            return clf_loss_ce(logits, targets)
    else:
        # mixup/cutmix: soft labels
        y_a, y_b, lam = mix_info
        if use_focal:
            # focal is not designed for soft labels; approximate by weighted CE
            loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
        else:
            loss = lam * clf_loss_ce(logits, y_a) + (1 - lam) * clf_loss_ce(logits, y_b)
        return loss
for epoch in range(1, EPOCHS+1):
    # freeze/unfreeze strategy
    if epoch <= FREEZE_EPOCHS:
        # freeze early layers of backbones
        for name, p in model.named_parameters():
            if any(k in name for k in ['backA.features.conv0','backA.features.norm0','backA.features.denseblock1']):
                p.requires_grad = False
    else:
        for p in model.parameters():
            p.requires_grad = True


    model.train()
    running_loss = 0.0
    y_true, y_pred = [], []
    n_batches = 0

    for weak_imgs, strong_imgs, labels, masks in tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}"):
        weak_imgs = weak_imgs.to(device); strong_imgs = strong_imgs.to(device)
        labels = labels.to(device)
        if masks is not None:
            masks = masks.to(device)

        # combine weak and strong optionally for the classifier path; we'll feed weak to model for main forward
        imgs = weak_imgs

        # optionally apply mixup/cutmix on imgs (on weak view)
        mix_info = None
        rand = random.random()
        if rand < PROB_MIXUP:
            imgs, y_a, y_b, lam = apply_mixup(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)
        elif rand < PROB_MIXUP + PROB_CUTMIX:
            imgs, y_a, y_b, lam = apply_cutmix(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)

        with torch.cuda.amp.autocast(enabled=(device=="cuda")):
            out = model(imgs)  # returns logits, feat, seg, domain_logits
            logits = out["logits"]
            feat = out["feat"]
            seg_out = out.get("seg", None)
            domain_logits = out.get("domain_logits", None)

            # classification loss (label-smoothing or focal)
            clf_loss = compute_combined_clf_loss(logits, labels, mix_info=mix_info, use_focal=False)

            # segmentation loss if available & mask present
            seg_loss = 0.0
            if USE_SEGMENTATION and (masks is not None):
                seg_pred = out["seg"]
                seg_loss = seg_loss_fn(seg_pred, masks)
            # supcon loss on features (use features from weak)
            supcon_loss = supcon_loss_fn(feat, labels)

            # consistency: forward strong view and compare predictions
            out_strong = model(strong_imgs)
            logits_strong = out_strong["logits"]
            probs_weak = F.softmax(logits.detach(), dim=1)
            probs_strong = F.softmax(logits_strong, dim=1)
            # L2 between probability vectors (could be KL)
            cons_loss = F.mse_loss(probs_weak, probs_strong)
            # domain adversarial: need domain labels; for now assume source-only (skip) unless domain label available
            # To support domain adaptation, user should provide target dataloader and stack batches with domain labels
            dom_loss = 0.0
            # (If domain labels are provided, compute dom logits after GRL: domain_logits_grl = domain_head(grad_reverse(feat, l)))
            # then dom_loss = criterion(domain_logits_grl, domain_labels)

            total_loss = clf_loss + seg_loss + BETA_SUPCON * supcon_loss + ETA_CONS * cons_loss + ALPHA_DOM * dom_loss

        opt.zero_grad()
        scaler.scale(total_loss).backward()
        # gradient clipping
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()

        running_loss += total_loss.item()
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
        n_batches += 1

    scheduler.step()

    # metrics
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Train Loss: {running_loss/max(1,n_batches):.4f} Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # -------------------
    # VALIDATION
    # -------------------
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0.0
    with torch.no_grad():
        for weak_imgs, _, labels, masks in val_loader:
            imgs = weak_imgs.to(device)
            labels = labels.to(device)
            if masks is not None:
                masks = masks.to(device)

            out = model(imgs)
            logits = out["logits"]
            feat = out["feat"]
            seg_out = out.get("seg", None)
            loss = compute_combined_clf_loss(logits, labels, mix_info=None, use_focal=False)
            if USE_SEGMENTATION and (masks is not None):
                loss += seg_loss_fn(seg_out, masks)
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(logits.argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/max(1,len(val_loader)):.4f} Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

    # early stopping & save best
    if vf1 > best_vf1:
        best_vf1 = vf1
        best_epoch = epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "best_vf1": best_vf1
        }, SAVE_PATH)
        patience_count = 0
        print(f"Saved best model at epoch {epoch} (F1 {best_vf1:.4f})")
    else:
        patience_count += 1
        if patience_count >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("Training finished. Best val F1:", best_vf1, "at epoch", best_epoch)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 1/20: 100%|██████████| 380/380 [05:25<00:00,  1.17it/s]


[Epoch 1] Train Loss: 2.1491 Acc: 0.6359 Prec: 0.5000 Rec: 0.3180 F1: 0.3887
[Epoch 1] Val Loss: 2.2695 Acc: 0.8850 Prec: 0.5000 Rec: 0.4425 F1: 0.4695
Saved best model at epoch 1 (F1 0.4695)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 2/20: 100%|██████████| 380/380 [05:17<00:00,  1.20it/s]


[Epoch 2] Train Loss: 1.3793 Acc: 0.9967 Prec: 0.5000 Rec: 0.4984 F1: 0.4992
[Epoch 2] Val Loss: 1.1828 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
Saved best model at epoch 2 (F1 1.0000)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 3/20: 100%|██████████| 380/380 [05:11<00:00,  1.22it/s]


[Epoch 3] Train Loss: 1.2713 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 3] Val Loss: 1.1991 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 4/20: 100%|██████████| 380/380 [05:09<00:00,  1.23it/s]


[Epoch 4] Train Loss: 1.2525 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 4] Val Loss: 1.1289 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 5/20:  81%|████████  | 307/380 [04:11<00:59,  1.22it/s]


KeyboardInterrupt: 

In [25]:
# Test-time augmentation (TTA) helper
# -----------------------------
def tta_predict(model, img_pil, device=device, scales=[224, 288, 320], flip=True):
    model.eval()
    logits_accum = None
    with torch.no_grad():
        for s in scales:
            tf = transforms.Compose([
                transforms.Resize((s, s)),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])
            x = tf(img_pil).unsqueeze(0).to(device)
            out = model(x)
            logits = out["logits"]
            if flip:
                x_f = torch.flip(x, dims=[3])
                logits_f = model(x_f)["logits"]
                logits = (logits + logits_f) / 2.0
            if logits_accum is None:
                logits_accum = logits
            else:
                logits_accum += logits
    logits_accum /= len(scales)
    return logits_accum

In [None]:
# Grad-CAM helper (very simple)
# -----------------------------
def get_gradcam_heatmap(model, input_tensor, target_class=None, layer_name='backA.features.denseblock4'):
    """
    Very light Grad-CAM: find a conv layer by name, register hook, compute gradients wrt target logit.
    Returns upsampled heatmap (H,W) normalized in [0,1].
    """
    model.eval()
    # find layer
    target_module = None
    for name, module in model.named_modules():
        if name == layer_name:
            target_module = module
            break
    if target_module is None:
        raise RuntimeError("Layer not found for Grad-CAM: " + layer_name)

    activations = []
    gradients = []

    def forward_hook(module, input, output):
        activations.append(output.detach())
    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0].detach())

    h1 = target_module.register_forward_hook(forward_hook)
    h2 = target_module.register_full_backward_hook(backward_hook)

    out = model(input_tensor)
    logits = out["logits"]
    if target_class is None:
        target_class = logits.argmax(1).item()
    loss = logits[:, target_class].sum()
    model.zero_grad()
    loss.backward(retain_graph=True)

    act = activations[0]  # [B,C,H,W]
    grad = gradients[0]   # [B,C,H,W]
    weights = grad.mean(dim=(2,3), keepdim=True)  # [B,C,1,1]
    cam = (weights * act).sum(dim=1, keepdim=True)  # [B,1,H,W]
    cam = F.relu(cam)
    cam = F.interpolate(cam, size=(input_tensor.size(2), input_tensor.size(3)), mode='bilinear', align_corners=False)
    cam = cam.squeeze().cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    h1.remove(); h2.remove()
    return cam