<a href="https://www.kaggle.com/code/anuhskaa/camouflage-improvement-research-2?scriptVersionId=271776260" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm

In [3]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

Device: cuda


In [5]:
IMG_SIZE = 224
BATCH_SIZE = 8          # adjust if OOM
EPOCHS = 20
NUM_WORKERS = 0         # set 0 if worker issues on Kaggle
LR = 3e-4
LABEL_SMOOTH = 0.1
SAVE_PATH = "best_model.pth"
USE_SEGMENTATION = True

# Loss weights from PDF suggestion
ALPHA_DOM = 0.5
BETA_SUPCON = 0.2
ETA_CONS = 0.1

# Mixup/CutMix probabilities and alphas
PROB_MIXUP = 0.5
PROB_CUTMIX = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# warmup epochs
WARMUP_EPOCHS = 5

# early stopping
EARLY_STOPPING_PATIENCE = 8
FREEZE_EPOCHS = 10

In [6]:
info_dir  = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Train/Image"
test_dir_cod  = "/kaggle/input/cod10k/COD10K-v3/Test/Image"

# these exist in Info/
combined_train_cam  = os.path.join(info_dir, "CAM_train.txt")
combined_train_noncam  = os.path.join(info_dir, "NonCAM_train.txt")
combined_test_cam = os.path.join(info_dir, "CAM_test.txt")
combined_test_noncam = os.path.join(info_dir, "NonCAM_test.txt")

In [7]:
info_dir2 ="/kaggle/input/camo-coco/CAMO_COCO/Info"
train_cam_txt2 = os.path.join(info_dir2, "camo_train.txt")
train_noncam_txt2 = os.path.join(info_dir2, "non_camo_train.txt")
test_cam_txt2 = os.path.join(info_dir2, "camo_test.txt")
test_noncam_txt2 = os.path.join(info_dir2, "non_camo_test.txt")

In [8]:
train_dir_camo = {
    "cam": "/kaggle/input/camo-coco/CAMO_COCO/Camouflage/Images/Train",
    "noncam": "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage/Images/Train"
}
test_dir_camo = {
    "cam": "/kaggle/input/camo-coco/CAMO_COCO/Camouflage/Images/Test",
    "noncam": "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage/Images/Test"
}

In [9]:
train_dir = [train_dir_cod, train_dir_camo["cam"], train_dir_camo["noncam"]]
test_dir  = [test_dir_cod, test_dir_camo["cam"], test_dir_camo["noncam"]]


In [10]:
# Combine train files
with open(combined_train_cam) as f1, open(train_cam_txt2) as f2:
    train_cam_txt = f1.read().splitlines() + f2.read().splitlines()

with open(combined_train_noncam) as f1, open(train_noncam_txt2) as f2:
    train_noncam_txt = f1.read().splitlines() + f2.read().splitlines()

# Combine test files
with open(combined_test_cam) as f1, open(test_cam_txt2) as f2:
    test_cam_txt = f1.read().splitlines() + f2.read().splitlines()

with open(combined_test_noncam) as f1, open(test_noncam_txt2) as f2:
    test_noncam_txt= f1.read().splitlines() + f2.read().splitlines()


Noise + Transform

In [11]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])


In [12]:
import os
from collections import Counter
from torch.utils.data import Dataset, WeightedRandomSampler
from PIL import Image
import torch
import numpy as np
from torchvision import transforms

# NOTE: IMG_SIZE should be defined globally (e.g., 224 for Swin compatibility)
# IMG_SIZE = 352 # Use your defined IMG_SIZE

class COD10KDataset(Dataset):
    """
    Dataset supporting COD10K + other datasets (like CAMO).
    Each line in txt_file(s) must contain:
        <filename> <label>
    or just <filename> (then label inferred by name or folder).
    """

    def __init__(self, root_dirs, txt_files, weak_transform=None, strong_transform=None, use_masks=True):
        # root_dir can be a list of directories
        if isinstance(root_dirs, str):
            root_dirs = [root_dirs]
        self.root_dirs = root_dirs
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.use_masks = use_masks
        self.samples = []
        
        # txt_file can also be a list
        if isinstance(txt_files, str):
            txt_files = [txt_files]

        all_lines = []
        for t in txt_files:
            if not os.path.exists(t):
                raise RuntimeError(f"TXT file not found: {t}")
            with open(t, "r") as f:
                lines = f.readlines()
                all_lines.extend([(line.strip(), t) for line in lines if line.strip()])

        for line, src_txt in all_lines:
            parts = line.split()
            if len(parts) == 0:
                continue

            fname = parts[0]
            if len(parts) >= 2:
                try:
                    lbl = int(parts[1])
                except:
                    lbl = 1 if "CAM" in fname or "cam" in fname else 0
            else:
                lbl = 1 if "CAM" in fname or "cam" in fname else 0

            # Find which dataset this file belongs to
            found = False
            # Search in the immediate root, or common Image subfolders, or deep nested folders (CAMO-COCO style)
            search_subs = [
                "",  # If image is directly in root_dir (less common)
                "Image", "Imgs", "images", "JPEGImages", "img", # Common image folders (COD10K style)
                # Deep CAMO-COCO style paths (Images/Train/ or Images/Test/ subfolders)
                "Images/Train", "Images/Test",
            ]
            
            # Since fnames might contain sub-directories, we only want the filename itself for checking
            base_fname = os.path.basename(fname) 

            for rdir in self.root_dirs:
                for sub in search_subs:
                    img_path = os.path.join(rdir, sub, base_fname)
                    # For CAMO-COCO Non-Camouflage roots, the image structure is often Images/Train/[fname] 
                    # but the noncam root is already specific (Non_Camouflage). The robust search handles this.
                    
                    if os.path.exists(img_path):
                        self.samples.append((img_path, lbl, rdir))
                        found = True
                        break
                if found:
                    break

            if not found:
                print(f"[WARN] File not found in any root: {base_fname} (Searched in {self.root_dirs})")

        if len(self.samples) == 0:
            raise RuntimeError(f"No valid samples found from {txt_files}")

        print(f"✅ Loaded {len(self.samples)} samples from {len(self.root_dirs)} root directories.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # We need IMG_SIZE defined here or passed in during object creation. 
        # Since it is a global variable in the notebook, we'll assume it's available.
        global IMG_SIZE 
        
        img_path, lbl, rdir = self.samples[idx]
        img = Image.open(img_path).convert("RGB")

        if self.weak_transform:
            weak = self.weak_transform(img)
        else:
            weak = transforms.ToTensor()(img)
        if self.strong_transform:
            strong = self.strong_transform(img)
        else:
            strong = weak.clone()

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
            
            # Try multiple common mask dirs
            found_mask = False
            for mask_dir in ["GT_Object", "GT", "masks", "Mask"]:
                mask_path = os.path.join(rdir, mask_dir, mask_name)
                
                # Special handling for CAMO-COCO Non-Camouflage which is missing mask subfolders
                if "Non_Camouflage" in rdir and mask_dir in ["GT", "GT_Object"]:
                    # Non-Camouflage generally doesn't have GTs, skip searching for them explicitly
                    # This check is redundant if the file doesn't exist, but serves as a quick exit
                    continue 
                
                if os.path.exists(mask_path):
                    m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                    m = np.array(m).astype(np.float32) / 255.0
                    mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)
                    found_mask = True
                    break

            # FIX: If mask is not found but use_masks is True, return a zero mask Tensor
            if mask is None:
                # Create a zero mask of size [1, IMG_SIZE, IMG_SIZE] (1 channel, H, W)
                # This ensures the batch collator only deals with Tensors, resolving the error.
                mask = torch.zeros((1, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
                # print(f"[DEBUG] No mask found for {mask_name}, returning zero mask.") # Debugging line
                
        return weak, strong, lbl, mask


def build_weighted_sampler(dataset):
    # This must be updated to use the new sample structure (img_path, lbl, rdir)
    labels = [lbl for (_, lbl, _) in dataset.samples] 
    counts = Counter(labels)
    total = len(labels)
    
    # Ensure there are at least two classes to calculate class_weights
    if len(counts) <= 1:
        print(f"[WARN] Only {len(counts)} class(es) found. Using equal weights.")
        weights = [1.0] * total
    else:
        class_weights = {c: total / (counts[c] * len(counts)) for c in counts}
        weights = [class_weights[lbl] for lbl in labels]
        
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)


In [13]:
# Assuming ALL_ROOT_DIRS, ALL_TRAIN_TXTS, and ALL_VAL_TXTS have been defined correctly
# in the cells preceding this one.

# --- PATH AND VARIABLE DEFINITIONS (Consolidated based on user's history and images) ---

# [9] COD10K PATHS (Based on image_22ec1d.png)
info_dir  = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Train" # Renamed from original train_dir
test_dir_cod  = "/kaggle/input/cod10k/COD10K-v3/Test"  # Renamed from original test_dir

# COD10K Info files (filenames lists) - Using info_dir
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

# [10] CAMO-COCO PATHS (Based on image_22ec1d.png)
info_dir2 = "/kaggle/input/camo-coco/CAMO_COCO/Info"

# CAMO-COCO Info files (filenames lists) - Using info_dir2
train_cam_txt2 = os.path.join(info_dir2, "camo_train.txt")
train_noncam_txt2 = os.path.join(info_dir2, "non_camo_train.txt")
test_cam_txt2 = os.path.join(info_dir2, "camo_test.txt")
test_noncam_txt2 = os.path.join(info_dir2, "non_camo_test.txt")

# [11] CAMO-COCO ROOT DIRECTORIES (rdir passed to Dataset, contains Images/ and GT/)
# These are the actual root folders containing the image subdirectories (Images/Train, Images/Test)
train_dir_camo_cam = "/kaggle/input/camo-coco/CAMO_COCO/Camouflage"
train_dir_camo_noncam = "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage"
test_dir_camo_cam = "/kaggle/input/camo-coco/CAMO_COCO/Camouflage" # Same root as train for cam
test_dir_camo_noncam = "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage" # Same root as train for noncam


# CONSOLIDATED LISTS FOR DATASET INITIALIZATION

# 1. All Root Directories where the images/masks are physically stored:
ALL_ROOT_DIRS = [
    train_dir_cod,          # COD10K Train/
    test_dir_cod,           # COD10K Test/
    train_dir_camo_cam,     # CAMO-COCO Camouflage/
    train_dir_camo_noncam   # CAMO-COCO Non_Camouflage/
]

# 2. All Train TXT files (containing filenames):
ALL_TRAIN_TXTS = [
    train_cam_txt, 
    train_noncam_txt,
    train_cam_txt2,
    train_noncam_txt2
]

# 3. All Test/Validation TXT files (containing filenames):
ALL_VAL_TXTS = [
    test_cam_txt,
    test_noncam_txt,
    test_cam_txt2,
    test_noncam_txt2
]

# --- END PATH DEFINITIONS ---

# Create the final unified datasets
train_ds = COD10KDataset(
    root_dirs=ALL_ROOT_DIRS, 
    txt_files=ALL_TRAIN_TXTS, 
    weak_transform=weak_tf, 
    strong_transform=strong_tf, 
    use_masks=USE_SEGMENTATION
)

val_ds = COD10KDataset(
    root_dirs=ALL_ROOT_DIRS,  
    txt_files=ALL_VAL_TXTS,  
    weak_transform=val_tf, 
    strong_transform=None, 
    use_masks=USE_SEGMENTATION
)

# Build Sampler and DataLoaders
train_sampler = build_weighted_sampler(train_ds)

# Make sure DataLoader is imported (it should be in the initial imports of your notebook)
# BATCH_SIZE and NUM_WORKERS should also be defined previously

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Corrected print statement: len(val_ds) is an integer, so you cannot call len() on it.
print("Total Train samples:", len(train_ds), "Total Val samples:", len(val_ds))


✅ Loaded 7999 samples from 4 root directories.
✅ Loaded 4500 samples from 4 root directories.
Total Train samples: 7999 Total Val samples: 4500


## Backbones

In [14]:
class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.densenet201(pretrained=pretrained).features
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in ["denseblock1","denseblock2","denseblock3","denseblock4"]:
                feats.append(x)
        return feats


class MobileNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.mobilenet_v3_large(pretrained=pretrained).features
    def forward(self, x):
        feats = []
        out = x
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in (2,5,9,12):
                feats.append(out)
        if len(feats) < 4:
            feats.append(out)
        return feats

In [15]:
class SwinExtractor(nn.Module):
    def __init__(self, model_name="swin_tiny_patch4_window7_224", pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=True)
    def forward(self, x):
        return self.model(x)

In [16]:
class CBAMlite(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, max(channels//reduction,4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(channels//reduction,4), channels, 1),
            nn.Sigmoid()
        )
        self.spatial = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 1, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return x * self.se(x) * self.spatial(x)


In [17]:
class GatedFusion(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, max(dim//4, 4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(dim//4,4), dim, 1),
            nn.Sigmoid()
        )
    def forward(self, H, X):
        if H.shape[2:] != X.shape[2:]:
            X = F.interpolate(X, size=H.shape[2:], mode='bilinear', align_corners=False)
        g = self.g_fc(H)
        return g * H + (1 - g) * X

In [18]:
class CrossAttention(nn.Module):
    def __init__(self, d_cnn, d_swin, d_out):
        super().__init__()
        self.q = nn.Linear(d_cnn, d_out)
        self.k = nn.Linear(d_swin, d_out)
        self.v = nn.Linear(d_swin, d_out)
        self.scale = d_out ** -0.5
    def forward(self, feat_cnn, feat_swin):
        B, Cc, H, W = feat_cnn.shape
        q = feat_cnn.permute(0,2,3,1).reshape(B, H*W, Cc)
        if feat_swin.dim() == 4:
            Bs, Cs, Hs, Ws = feat_swin.shape
            kv = feat_swin.permute(0,2,3,1).reshape(Bs, Hs*Ws, Cs)
        else:
            kv = feat_swin
        K = self.k(kv)
        V = self.v(kv)
        Q = self.q(q)
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, V)
        out = out.reshape(B, H, W, -1).permute(0,3,1,2)
        return out


## Segmentation Decoder

In [19]:
class SegDecoder(nn.Module):
    def __init__(self, in_channels_list, mid_channels=128):
        super().__init__()
        self.projs = nn.ModuleList([nn.Conv2d(c, mid_channels, 1) for c in in_channels_list])
        self.conv = nn.Sequential(nn.Conv2d(mid_channels * len(in_channels_list), mid_channels, 3, padding=1), nn.ReLU(inplace=True))
        self.out = nn.Conv2d(mid_channels, 1, 1)
    def forward(self, feat_list):
        target_size = feat_list[0].shape[2:]
        ups = []
        for f, p in zip(feat_list, self.projs):
            x = p(f)
            if x.shape[2:] != target_size:
                x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)
            ups.append(x)
        x = torch.cat(ups, dim=1)
        x = self.conv(x)
        x = self.out(x)
        return x

## Probing Backbones

In [20]:
dnet = DenseNetExtractor().to(device).eval()
mnet = MobileNetExtractor().to(device).eval()
snet = SwinExtractor().to(device).eval()
with torch.no_grad():
    dummy = torch.randn(1,3,IMG_SIZE,IMG_SIZE).to(device)
    featsA = dnet(dummy)
    featsB = mnet(dummy)
    featsS = snet(dummy)
chA = [f.shape[1] for f in featsA]
chB = [f.shape[1] for f in featsB]
chS = [f.shape[1] for f in featsS]
print("DenseNet channels:", chA)
print("MobileNet channels:", chB)
print("Swin channels:", chS)

Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /root/.cache/torch/hub/checkpoints/densenet201-c1103571.pth
100%|██████████| 77.4M/77.4M [00:00<00:00, 208MB/s]
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 144MB/s] 


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

DenseNet channels: [256, 512, 1792, 1920]
MobileNet channels: [24, 40, 80, 112]
Swin channels: [56, 28, 14, 7]


# Fusion Model (DenseNet + MobileNet + Swin cross attention)

In [21]:
class FusionWithSwin(nn.Module):
    def __init__(self, dense_chs, mobile_chs, swin_chs, d=256, use_seg=True, num_classes=2):
        super().__init__()
        self.backA = DenseNetExtractor()
        self.backB = MobileNetExtractor()
        self.backS = SwinExtractor()
        L = min(len(dense_chs), len(mobile_chs), len(swin_chs))
        self.L = L
        self.d = d
        self.alignA = nn.ModuleList([nn.Conv2d(c, d, 1) for c in dense_chs[:L]])
        self.alignB = nn.ModuleList([nn.Conv2d(c, d, 1) for c in mobile_chs[:L]])
        self.cbamA = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.cbamB = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.gates = nn.ModuleList([GatedFusion(d) for _ in range(L)])
        self.cross_atts = nn.ModuleList([CrossAttention(d, swin_chs[i], d) for i in range(L)])
        self.reduce = nn.Conv2d(d * L, d, 1)
        self.classifier = nn.Sequential(
            nn.Linear(d, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        self.use_seg = use_seg
        if self.use_seg:
            self.segdecoder = SegDecoder([d] * L, mid_channels=128)

        # Domain head for DANN (simple MLP)
        self.domain_head = nn.Sequential(
            nn.Linear(d, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

    def forward(self, x, grl_lambda=0.0):
        fa = self.backA(x)
        fb = self.backB(x)
        fs = self.backS(x)
        fused_feats = []
        aligned_for_dec = []
        for i in range(self.L):
            a = self.alignA[i](fa[i])
            a = self.cbamA[i](a)
            b = self.alignB[i](fb[i])
            b = self.cbamB[i](b)
            if b.shape[2:] != a.shape[2:]:
                b = F.interpolate(b, size=a.shape[2:], mode='bilinear', align_corners=False)
            fused = self.gates[i](a, b)
            swin_feat = fs[i]
            swin_att = self.cross_atts[i](fused, swin_feat)
            if swin_att.shape[2:] != fused.shape[2:]:
                swin_att = F.interpolate(swin_att, size=fused.shape[2:], mode='bilinear', align_corners=False)
            fused = fused + swin_att
            fused_feats.append(fused)
            aligned_for_dec.append(fused)
        target = fused_feats[-1]
        upsampled = [F.interpolate(f, size=target.shape[2:], mode='bilinear', align_corners=False) if f.shape[2:] != target.shape[2:] else f for f in fused_feats]
        concat = torch.cat(upsampled, dim=1)
        fused = self.reduce(concat)
        z = F.adaptive_avg_pool2d(fused, (1,1)).view(fused.size(0), -1)
        logits = self.classifier(z)
        out = {"logits": logits, "feat": z}
        if self.use_seg:
            out["seg"] = self.segdecoder(aligned_for_dec)

        # Domain prediction with GRL effect applied by multiplying lambda and reversing sign in custom grad fn
        if grl_lambda > 0.0:
            # GRL implemented outside (we'll pass z through GRL function)
            pass
        out["domain_logits"] = self.domain_head(z)
        return out

# instantiate model
model = FusionWithSwin(dense_chs=chA, mobile_chs=chB, swin_chs=chS, d=256, use_seg=USE_SEGMENTATION, num_classes=2).to(device)
print("Model parameters (M):", sum(p.numel() for p in model.parameters())/1e6)

Model parameters (M): 51.586615


In [22]:
class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5):
        super().__init__()
        self.gamma = gamma
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        ce = F.cross_entropy(logits, target, reduction='none')
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

def dice_loss_logits(pred_logits, target):
    pred = torch.sigmoid(pred_logits)
    target = target.float()
    inter = (pred * target).sum(dim=(1,2,3))
    denom = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3))
    dice = (2 * inter + 1e-6) / (denom + 1e-6)
    return 1.0 - dice.mean()

clf_loss_ce = LabelSmoothingCE(LABEL_SMOOTH)
clf_loss_focal = FocalLoss(gamma=1.5)
seg_bce = nn.BCEWithLogitsLoss()

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)


In [23]:
#Supervised contrastive Loss
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, features, labels):
        # features: [N, D], labels: [N]
        device = features.device
        f = F.normalize(features, dim=1)
        sim = torch.matmul(f, f.T) / self.temperature  # [N,N]
        labels = labels.contiguous().view(-1,1)
        mask = torch.eq(labels, labels.T).float().to(device)
        # remove diagonal
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(len(features), device=device))
        denom = exp_logits.sum(1, keepdim=True)
        # for each i, positive samples are where mask==1 (excluding self)
        pos_mask = mask - torch.eye(len(features), device=device)
        pos_exp = (exp_logits * pos_mask).sum(1)
        # avoid divide by zero
        loss = -torch.log((pos_exp + 1e-8) / (denom + 1e-8) + 1e-12)
        # average only across anchors that have positives
        valid = (pos_mask.sum(1) > 0).float()
        loss = (loss * valid).sum() / (valid.sum() + 1e-8)
        return loss
supcon_loss_fn = SupConLoss(temperature=0.07)

In [24]:
# Domain Adversarial: Gradient Reversal Layer (GRL)

from torch.autograd import Function
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, l):
        ctx.l = l
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.l, None

def grad_reverse(x, l=1.0):
    return GradReverse.apply(x, l)

In [25]:
# Optimizer + scheduler + mixed precision + clipping
# -----------------------------
# param groups: smaller LR for backbones, larger for heads
backbone_params = []
head_params = []
for name, param in model.named_parameters():
    if any(k in name for k in ['backA', 'backB', 'backS']):  # backbone names
        backbone_params.append(param)
    else:
        head_params.append(param)

opt = torch.optim.AdamW([
    {'params': backbone_params, 'lr': LR * 0.2},
    {'params': head_params, 'lr': LR}
], lr=LR, weight_decay=1e-4)

# warmup + cosine schedule
def get_cosine_with_warmup_scheduler(optimizer, warmup_epochs, total_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1.0, warmup_epochs))
        # cosine from warmup -> total
        t = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

scheduler = get_cosine_with_warmup_scheduler(opt, WARMUP_EPOCHS, EPOCHS)

scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

# -----------------------------
# Mixup & CutMix helpers
# -----------------------------
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)   # use builtin int
    cut_h = int(H * cut_rat)   # use builtin int

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup(x, y, alpha=MIXUP_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def apply_cutmix(x, y, alpha=CUTMIX_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    new_x = x.clone()
    new_x[:, :, bby1:bby2, bbx1:bbx2] = x[idx, :, bby1:bby2, bbx1:bbx2]
    lam_adjusted = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return new_x, y, y[idx], lam_adjusted


  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))


## Training

In [26]:
best_vf1 = 0.0
best_epoch = 0
patience_count = 0

def compute_combined_clf_loss(logits, targets, mix_info=None, use_focal=False):
    # mix_info: (mode, y_a, y_b, lam) or None
    if mix_info is None:
        if use_focal:
            return clf_loss_focal(logits, targets)
        else:
            return clf_loss_ce(logits, targets)
    else:
        # mixup/cutmix: soft labels
        y_a, y_b, lam = mix_info
        if use_focal:
            # focal is not designed for soft labels; approximate by weighted CE
            loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
        else:
            loss = lam * clf_loss_ce(logits, y_a) + (1 - lam) * clf_loss_ce(logits, y_b)
        return loss
for epoch in range(1, EPOCHS+1):
    # freeze/unfreeze strategy
    if epoch <= FREEZE_EPOCHS:
        # freeze early layers of backbones
        for name, p in model.named_parameters():
            if any(k in name for k in ['backA.features.conv0','backA.features.norm0','backA.features.denseblock1']):
                p.requires_grad = False
    else:
        for p in model.parameters():
            p.requires_grad = True


    model.train()
    running_loss = 0.0
    y_true, y_pred = [], []
    n_batches = 0

    for weak_imgs, strong_imgs, labels, masks in tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}"):
        weak_imgs = weak_imgs.to(device); strong_imgs = strong_imgs.to(device)
        labels = labels.to(device)
        if masks is not None:
            masks = masks.to(device)

        # combine weak and strong optionally for the classifier path; we'll feed weak to model for main forward
        imgs = weak_imgs

        # optionally apply mixup/cutmix on imgs (on weak view)
        mix_info = None
        rand = random.random()
        if rand < PROB_MIXUP:
            imgs, y_a, y_b, lam = apply_mixup(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)
        elif rand < PROB_MIXUP + PROB_CUTMIX:
            imgs, y_a, y_b, lam = apply_cutmix(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)

        with torch.cuda.amp.autocast(enabled=(device=="cuda")):
            out = model(imgs)  # returns logits, feat, seg, domain_logits
            logits = out["logits"]
            feat = out["feat"]
            seg_out = out.get("seg", None)
            domain_logits = out.get("domain_logits", None)

            # classification loss (label-smoothing or focal)
            clf_loss = compute_combined_clf_loss(logits, labels, mix_info=mix_info, use_focal=False)

            # segmentation loss if available & mask present
            seg_loss = 0.0
            if USE_SEGMENTATION and (masks is not None):
                seg_pred = out["seg"]
                seg_loss = seg_loss_fn(seg_pred, masks)
            # supcon loss on features (use features from weak)
            supcon_loss = supcon_loss_fn(feat, labels)

            # consistency: forward strong view and compare predictions
            out_strong = model(strong_imgs)
            logits_strong = out_strong["logits"]
            probs_weak = F.softmax(logits.detach(), dim=1)
            probs_strong = F.softmax(logits_strong, dim=1)
            # L2 between probability vectors (could be KL)
            cons_loss = F.mse_loss(probs_weak, probs_strong)
            # domain adversarial: need domain labels; for now assume source-only (skip) unless domain label available
            # To support domain adaptation, user should provide target dataloader and stack batches with domain labels
            dom_loss = 0.0
            # (If domain labels are provided, compute dom logits after GRL: domain_logits_grl = domain_head(grad_reverse(feat, l)))
            # then dom_loss = criterion(domain_logits_grl, domain_labels)

            total_loss = clf_loss + seg_loss + BETA_SUPCON * supcon_loss + ETA_CONS * cons_loss + ALPHA_DOM * dom_loss

        opt.zero_grad()
        scaler.scale(total_loss).backward()
        # gradient clipping
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()

        running_loss += total_loss.item()
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
        n_batches += 1

    scheduler.step()

    # metrics
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Train Loss: {running_loss/max(1,n_batches):.4f} Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # -------------------
    # VALIDATION
    # -------------------
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0.0
    with torch.no_grad():
        for weak_imgs, _, labels, masks in val_loader:
            imgs = weak_imgs.to(device)
            labels = labels.to(device)
            if masks is not None:
                masks = masks.to(device)

            out = model(imgs)
            logits = out["logits"]
            feat = out["feat"]
            seg_out = out.get("seg", None)
            loss = compute_combined_clf_loss(logits, labels, mix_info=None, use_focal=False)
            if USE_SEGMENTATION and (masks is not None):
                loss += seg_loss_fn(seg_out, masks)
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(logits.argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/max(1,len(val_loader)):.4f} Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

    # early stopping & save best
    if vf1 > best_vf1:
        best_vf1 = vf1
        best_epoch = epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "best_vf1": best_vf1
        }, SAVE_PATH)
        patience_count = 0
        print(f"Saved best model at epoch {epoch} (F1 {best_vf1:.4f})")
    else:
        patience_count += 1
        if patience_count >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("Training finished. Best val F1:", best_vf1, "at epoch", best_epoch)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 1/20: 100%|██████████| 1000/1000 [14:04<00:00,  1.18it/s]


[Epoch 1] Train Loss: 4.1229 Acc: 0.5057 Prec: 0.5048 Rec: 0.5045 F1: 0.4961
[Epoch 1] Val Loss: 2.4864 Acc: 0.5838 Prec: 0.5766 Rec: 0.5437 F1: 0.5079
Saved best model at epoch 1 (F1 0.5079)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 2/20: 100%|██████████| 1000/1000 [13:18<00:00,  1.25it/s]


[Epoch 2] Train Loss: 2.8040 Acc: 0.6960 Prec: 0.6960 Rec: 0.6961 F1: 0.6960
[Epoch 2] Val Loss: 2.0169 Acc: 0.8076 Prec: 0.8227 Rec: 0.7915 F1: 0.7970
Saved best model at epoch 2 (F1 0.7970)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 3/20: 100%|██████████| 1000/1000 [12:58<00:00,  1.28it/s]


[Epoch 3] Train Loss: 2.4700 Acc: 0.7731 Prec: 0.7733 Rec: 0.7733 F1: 0.7731
[Epoch 3] Val Loss: 1.4709 Acc: 0.9029 Prec: 0.9010 Rec: 0.9023 F1: 0.9016
Saved best model at epoch 3 (F1 0.9016)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 4/20: 100%|██████████| 1000/1000 [12:47<00:00,  1.30it/s]


[Epoch 4] Train Loss: 2.3447 Acc: 0.7886 Prec: 0.7886 Rec: 0.7886 F1: 0.7886
[Epoch 4] Val Loss: 1.5159 Acc: 0.8982 Prec: 0.8979 Rec: 0.8950 F1: 0.8963


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 5/20: 100%|██████████| 1000/1000 [12:49<00:00,  1.30it/s]


[Epoch 5] Train Loss: 2.3592 Acc: 0.7820 Prec: 0.7820 Rec: 0.7820 F1: 0.7820
[Epoch 5] Val Loss: 1.4274 Acc: 0.8771 Prec: 0.8940 Rec: 0.8642 F1: 0.8714


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 6/20: 100%|██████████| 1000/1000 [12:51<00:00,  1.30it/s]


[Epoch 6] Train Loss: 2.3152 Acc: 0.7892 Prec: 0.7894 Rec: 0.7893 F1: 0.7892
[Epoch 6] Val Loss: 1.5198 Acc: 0.9042 Prec: 0.9063 Rec: 0.8992 F1: 0.9020
Saved best model at epoch 6 (F1 0.9020)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 7/20: 100%|██████████| 1000/1000 [12:50<00:00,  1.30it/s]


[Epoch 7] Train Loss: 2.3145 Acc: 0.7952 Prec: 0.7954 Rec: 0.7955 F1: 0.7952
[Epoch 7] Val Loss: 1.4672 Acc: 0.8447 Prec: 0.8792 Rec: 0.8254 F1: 0.8335


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 8/20: 100%|██████████| 1000/1000 [12:45<00:00,  1.31it/s]


[Epoch 8] Train Loss: 2.2546 Acc: 0.8025 Prec: 0.8026 Rec: 0.8024 F1: 0.8024
[Epoch 8] Val Loss: 1.4292 Acc: 0.8896 Prec: 0.8987 Rec: 0.8801 F1: 0.8856


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 9/20: 100%|██████████| 1000/1000 [12:46<00:00,  1.30it/s]


[Epoch 9] Train Loss: 2.2150 Acc: 0.8060 Prec: 0.8062 Rec: 0.8059 F1: 0.8059
[Epoch 9] Val Loss: 1.3897 Acc: 0.9098 Prec: 0.9077 Rec: 0.9097 F1: 0.9086
Saved best model at epoch 9 (F1 0.9086)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 10/20: 100%|██████████| 1000/1000 [12:46<00:00,  1.31it/s]


[Epoch 10] Train Loss: 2.1977 Acc: 0.8115 Prec: 0.8115 Rec: 0.8115 F1: 0.8115
[Epoch 10] Val Loss: 1.3956 Acc: 0.9078 Prec: 0.9080 Rec: 0.9044 F1: 0.9060


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 11/20: 100%|██████████| 1000/1000 [13:00<00:00,  1.28it/s]


[Epoch 11] Train Loss: 2.1152 Acc: 0.8240 Prec: 0.8243 Rec: 0.8237 F1: 0.8238
[Epoch 11] Val Loss: 1.3947 Acc: 0.9184 Prec: 0.9168 Rec: 0.9179 F1: 0.9173
Saved best model at epoch 11 (F1 0.9173)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 12/20: 100%|██████████| 1000/1000 [12:58<00:00,  1.29it/s]


[Epoch 12] Train Loss: 2.0701 Acc: 0.8265 Prec: 0.8265 Rec: 0.8265 F1: 0.8265
[Epoch 12] Val Loss: 1.3643 Acc: 0.9180 Prec: 0.9213 Rec: 0.9127 F1: 0.9160


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 13/20: 100%|██████████| 1000/1000 [13:00<00:00,  1.28it/s]


[Epoch 13] Train Loss: 2.0014 Acc: 0.8449 Prec: 0.8449 Rec: 0.8448 F1: 0.8449
[Epoch 13] Val Loss: 1.3647 Acc: 0.9100 Prec: 0.9176 Rec: 0.9021 F1: 0.9071


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 14/20: 100%|██████████| 1000/1000 [12:57<00:00,  1.29it/s]


[Epoch 14] Train Loss: 1.9578 Acc: 0.8387 Prec: 0.8387 Rec: 0.8387 F1: 0.8387
[Epoch 14] Val Loss: 1.3658 Acc: 0.9129 Prec: 0.9139 Rec: 0.9090 F1: 0.9111


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 15/20: 100%|██████████| 1000/1000 [13:01<00:00,  1.28it/s]


[Epoch 15] Train Loss: 1.9128 Acc: 0.8504 Prec: 0.8504 Rec: 0.8504 F1: 0.8504
[Epoch 15] Val Loss: 1.3242 Acc: 0.9293 Prec: 0.9290 Rec: 0.9273 F1: 0.9281
Saved best model at epoch 15 (F1 0.9281)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 16/20: 100%|██████████| 1000/1000 [13:01<00:00,  1.28it/s]


[Epoch 16] Train Loss: 1.8889 Acc: 0.8501 Prec: 0.8501 Rec: 0.8501 F1: 0.8501
[Epoch 16] Val Loss: 1.3285 Acc: 0.9200 Prec: 0.9208 Rec: 0.9166 F1: 0.9184


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 17/20: 100%|██████████| 1000/1000 [13:03<00:00,  1.28it/s]


[Epoch 17] Train Loss: 1.8370 Acc: 0.8626 Prec: 0.8627 Rec: 0.8627 F1: 0.8626
[Epoch 17] Val Loss: 1.3108 Acc: 0.9273 Prec: 0.9279 Rec: 0.9244 F1: 0.9259


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 18/20: 100%|██████████| 1000/1000 [13:13<00:00,  1.26it/s]


[Epoch 18] Train Loss: 1.8399 Acc: 0.8602 Prec: 0.8602 Rec: 0.8602 F1: 0.8602
[Epoch 18] Val Loss: 1.3038 Acc: 0.9273 Prec: 0.9296 Rec: 0.9231 F1: 0.9257


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 19/20: 100%|██████████| 1000/1000 [13:13<00:00,  1.26it/s]


[Epoch 19] Train Loss: 1.8066 Acc: 0.8625 Prec: 0.8625 Rec: 0.8625 F1: 0.8625
[Epoch 19] Val Loss: 1.3017 Acc: 0.9296 Prec: 0.9298 Rec: 0.9270 F1: 0.9283
Saved best model at epoch 19 (F1 0.9283)


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Train 20/20: 100%|██████████| 1000/1000 [13:14<00:00,  1.26it/s]


[Epoch 20] Train Loss: 1.8155 Acc: 0.8497 Prec: 0.8498 Rec: 0.8497 F1: 0.8497
[Epoch 20] Val Loss: 1.3001 Acc: 0.9298 Prec: 0.9301 Rec: 0.9271 F1: 0.9285
Saved best model at epoch 20 (F1 0.9285)
Training finished. Best val F1: 0.9284544716860835 at epoch 20


In [27]:
# Test-time augmentation (TTA) helper
# -----------------------------
def tta_predict(model, img_pil, device=device, scales=[224, 288, 320], flip=True):
    model.eval()
    logits_accum = None
    with torch.no_grad():
        for s in scales:
            tf = transforms.Compose([
                transforms.Resize((s, s)),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])
            x = tf(img_pil).unsqueeze(0).to(device)
            out = model(x)
            logits = out["logits"]
            if flip:
                x_f = torch.flip(x, dims=[3])
                logits_f = model(x_f)["logits"]
                logits = (logits + logits_f) / 2.0
            if logits_accum is None:
                logits_accum = logits
            else:
                logits_accum += logits
    logits_accum /= len(scales)
    return logits_accum

In [28]:
# Grad-CAM helper (very simple)
# -----------------------------
def get_gradcam_heatmap(model, input_tensor, target_class=None, layer_name='backA.features.denseblock4'):
    """
    Very light Grad-CAM: find a conv layer by name, register hook, compute gradients wrt target logit.
    Returns upsampled heatmap (H,W) normalized in [0,1].
    """
    model.eval()
    # find layer
    target_module = None
    for name, module in model.named_modules():
        if name == layer_name:
            target_module = module
            break
    if target_module is None:
        raise RuntimeError("Layer not found for Grad-CAM: " + layer_name)

    activations = []
    gradients = []

    def forward_hook(module, input, output):
        activations.append(output.detach())
    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0].detach())

    h1 = target_module.register_forward_hook(forward_hook)
    h2 = target_module.register_full_backward_hook(backward_hook)

    out = model(input_tensor)
    logits = out["logits"]
    if target_class is None:
        target_class = logits.argmax(1).item()
    loss = logits[:, target_class].sum()
    model.zero_grad()
    loss.backward(retain_graph=True)

    act = activations[0]  # [B,C,H,W]
    grad = gradients[0]   # [B,C,H,W]
    weights = grad.mean(dim=(2,3), keepdim=True)  # [B,C,1,1]
    cam = (weights * act).sum(dim=1, keepdim=True)  # [B,1,H,W]
    cam = F.relu(cam)
    cam = F.interpolate(cam, size=(input_tensor.size(2), input_tensor.size(3)), mode='bilinear', align_corners=False)
    cam = cam.squeeze().cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    h1.remove(); h2.remove()
    return cam