In [1]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm

In [6]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

Device: cuda


In [8]:
IMG_SIZE = 224
BATCH_SIZE = 8          
EPOCHS = 10
NUM_WORKERS = 0     
LR = 3e-4
LABEL_SMOOTH = 0.1
USE_SEGMENTATION = True

In [9]:
info_dir  = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir = "/kaggle/input/cod10k/COD10K-v3/Train"
test_dir  = "/kaggle/input/cod10k/COD10K-v3/Test"

# these exist in Info/
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

In [10]:
def merge_txts(txt_list, root_dir, transform, use_masks=True):
    samples = []
    for txt_file, label in txt_list:
        with open(txt_file, "r") as f:
            for line in f:
                fname = line.strip().split()[0]
                samples.append((fname, label))
    return COD10KDataset(root_dir, txt_file=None, transform=transform, use_masks=use_masks)

Noise + Transform

In [11]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)

    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),         # noise added after converting to tensor
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])


In [12]:
class COD10KDataset(Dataset):
    """
    Parses lines of the form:
      <filename> <label>
    or where label missing, infers from filename containing 'CAM' or 'NonCAM'.
    Loads optional masks from GT_Object/<basename>.png if present.
    """
    def __init__(self, root_dir, txt_file, transform=None, use_masks=True):
        self.root_dir = root_dir
        self.transform = transform
        self.use_masks = use_masks
        self.samples = []

        if not os.path.exists(txt_file):
            raise RuntimeError(f"TXT file not found: {txt_file}")

        with open(txt_file, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 0:
                    continue
                if len(parts) >= 2:
                    fname = parts[0]
                    try:
                        lbl = int(parts[1])
                    except:
                        # fallback: try to parse last token as int
                        lbl = 1 if "CAM" in fname else 0
                else:
                    fname = parts[0]
                    # CORRECTED: use parts[0] to decide label, don't refer to undefined variable
                    lbl = 1 if "CAM" in parts[0] else 0
                img_path = os.path.join(self.root_dir, "Image", fname)
                if os.path.exists(img_path):
                    self.samples.append((fname, lbl))
                else:
                    # Skip missing images (print once or collect for user)
                    pass

        if len(self.samples) == 0:
            raise RuntimeError(f"No samples found. Check {txt_file} and the Image folder under {root_dir}.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        fname, lbl = self.samples[idx]
        img_path = os.path.join(self.root_dir, "Image", fname)
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            x = self.transform(img)
        else:
            x = transforms.ToTensor()(img)

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(fname)[0] + ".png"
            mask_path = os.path.join(self.root_dir, "GT_Object", mask_name)
            if os.path.exists(mask_path):
                m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                m = np.array(m).astype(np.float32) / 255.0
                mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)

        return x, lbl, mask

In [13]:
train_ds = COD10KDataset(train_dir, train_cam_txt, transform=train_tf, use_masks=USE_SEGMENTATION)
val_ds   = COD10KDataset(test_dir,  test_cam_txt,  transform=val_tf,   use_masks=USE_SEGMENTATION)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
print("Train samples:", len(train_ds), "Val samples:", len(val_ds))

Train samples: 3038 Val samples: 2026


## Backbones

In [14]:
class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.densenet201(pretrained=pretrained).features
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in ["denseblock1","denseblock2","denseblock3","denseblock4"]:
                feats.append(x)
        return feats

class MobileNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.mobilenet_v3_large(pretrained=pretrained).features
    def forward(self, x):
        feats = []
        out = x
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in (2,5,9,12):
                feats.append(out)
        if len(feats) < 4:
            feats.append(out)
        return feats

In [15]:
class SwinExtractor(nn.Module):
    def __init__(self, model_name="swin_tiny_patch4_window7_224", pretrained=True):
        super().__init__()
        # timm features_only model returns list of feature maps
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=True)
    def forward(self, x):
        return self.model(x)

In [16]:
class CBAMlite(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, max(channels//reduction,4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(channels//reduction,4), channels, 1),
            nn.Sigmoid()
        )
        self.spatial = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 1, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return x * self.se(x) * self.spatial(x)

In [17]:
class GatedFusion(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, max(dim//4, 4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(dim//4,4), dim, 1),
            nn.Sigmoid()
        )
    def forward(self, H, X):
        # ensure same spatial
        if H.shape[2:] != X.shape[2:]:
            X = F.interpolate(X, size=H.shape[2:], mode='bilinear', align_corners=False)
        g = self.g_fc(H)
        return g * H + (1 - g) * X


In [18]:
class CrossAttention(nn.Module):
    def __init__(self, d_cnn, d_swin, d_out):
        super().__init__()
        self.q = nn.Linear(d_cnn, d_out)
        self.k = nn.Linear(d_swin, d_out)
        self.v = nn.Linear(d_swin, d_out)
        self.scale = d_out ** -0.5
    def forward(self, feat_cnn, feat_swin):
        # feat_cnn: [B, Cc, H, W] -> [B, Nq, Cc]
        # feat_swin: [B, Cs, Hs, Ws] -> [B, Ns, Cs]
        B, Cc, H, W = feat_cnn.shape
        q = feat_cnn.permute(0,2,3,1).reshape(B, H*W, Cc)  # [B, Nq, Cc]
        if feat_swin.dim() == 4:
            Bs, Cs, Hs, Ws = feat_swin.shape
            kv = feat_swin.permute(0,2,3,1).reshape(Bs, Hs*Ws, Cs)  # [B, Ns, Cs]
        else:
            kv = feat_swin
        K = self.k(kv)    # [B, Ns, d_out]
        V = self.v(kv)
        Q = self.q(q)     # [B, Nq, d_out]
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # [B, Nq, Ns]
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, V)  # [B, Nq, d_out]
        out = out.reshape(B, H, W, -1).permute(0,3,1,2)  # [B, d_out, H, W]
        return out

## Segmentation Decoder

In [19]:
class SegDecoder(nn.Module):
    def __init__(self, in_channels_list, mid_channels=128):
        super().__init__()
        self.projs = nn.ModuleList([nn.Conv2d(c, mid_channels, 1) for c in in_channels_list])
        self.conv = nn.Sequential(nn.Conv2d(mid_channels * len(in_channels_list), mid_channels, 3, padding=1), nn.ReLU(inplace=True))
        self.out = nn.Conv2d(mid_channels, 1, 1)
    def forward(self, feat_list):
        target_size = feat_list[0].shape[2:]
        ups = []
        for f, p in zip(feat_list, self.projs):
            x = p(f)
            if x.shape[2:] != target_size:
                x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)
            ups.append(x)
        x = torch.cat(ups, dim=1)
        x = self.conv(x)
        x = self.out(x)
        return x  # [B,1,H,W]

## Probing Backbones

In [20]:
dnet = DenseNetExtractor().to(device).eval()
mnet = MobileNetExtractor().to(device).eval()
snet = SwinExtractor().to(device).eval()
with torch.no_grad():
    dummy = torch.randn(1,3,IMG_SIZE,IMG_SIZE).to(device)
    featsA = dnet(dummy)
    featsB = mnet(dummy)
    featsS = snet(dummy)
chA = [f.shape[1] for f in featsA]
chB = [f.shape[1] for f in featsB]
chS = [f.shape[1] for f in featsS]
print("DenseNet channels:", chA)
print("MobileNet channels:", chB)
print("Swin channels:", chS)

Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /root/.cache/torch/hub/checkpoints/densenet201-c1103571.pth
100%|██████████| 77.4M/77.4M [00:00<00:00, 207MB/s]
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 133MB/s] 


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

DenseNet channels: [256, 512, 1792, 1920]
MobileNet channels: [24, 40, 80, 112]
Swin channels: [56, 28, 14, 7]


# Fusion Model (DenseNet + MobileNet + Swin cross attention)

In [21]:
class FusionWithSwin(nn.Module):
    def __init__(self, dense_chs, mobile_chs, swin_chs, d=256, use_seg=True):
        super().__init__()
        self.backA = DenseNetExtractor()
        self.backB = MobileNetExtractor()
        self.backS = SwinExtractor()
        L = min(len(dense_chs), len(mobile_chs), len(swin_chs))
        self.L = L
        self.d = d
        self.alignA = nn.ModuleList([nn.Conv2d(c, d, 1) for c in dense_chs[:L]])
        self.alignB = nn.ModuleList([nn.Conv2d(c, d, 1) for c in mobile_chs[:L]])
        self.cbamA = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.cbamB = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.gates = nn.ModuleList([GatedFusion(d) for _ in range(L)])
        self.cross_atts = nn.ModuleList([CrossAttention(d, swin_chs[i], d) for i in range(L)])
        self.reduce = nn.Conv2d(d * L, d, 1)
        self.classifier = nn.Sequential(
            nn.Linear(d, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, 2)
        )
        self.use_seg = use_seg
        if self.use_seg:
            self.segdecoder = SegDecoder([d] * L, mid_channels=128)
    def forward(self, x):
        fa = self.backA(x)
        fb = self.backB(x)
        fs = self.backS(x)  # timm returns list [B, C, H, W]
        fused_feats = []
        aligned_for_dec = []
        for i in range(self.L):
            a = self.alignA[i](fa[i])
            a = self.cbamA[i](a)
            b = self.alignB[i](fb[i])
            b = self.cbamB[i](b)
            if b.shape[2:] != a.shape[2:]:
                b = F.interpolate(b, size=a.shape[2:], mode='bilinear', align_corners=False)
            fused = self.gates[i](a, b)
            swin_feat = fs[i]
            swin_att = self.cross_atts[i](fused, swin_feat)
            if swin_att.shape[2:] != fused.shape[2:]:
                swin_att = F.interpolate(swin_att, size=fused.shape[2:], mode='bilinear', align_corners=False)
            fused = fused + swin_att
            fused_feats.append(fused)
            aligned_for_dec.append(fused)
        # fuse multi-scale
        target = fused_feats[-1]
        upsampled = [F.interpolate(f, size=target.shape[2:], mode='bilinear', align_corners=False) if f.shape[2:] != target.shape[2:] else f for f in fused_feats]
        concat = torch.cat(upsampled, dim=1)
        fused = self.reduce(concat)
        z = F.adaptive_avg_pool2d(fused, (1,1)).view(fused.size(0), -1)
        logits = self.classifier(z)
        out = {"logits": logits, "feat": z}
        if self.use_seg:
            out["seg"] = self.segdecoder(aligned_for_dec)
        return out

# instantiate model
model = FusionWithSwin(dense_chs=chA, mobile_chs=chB, swin_chs=chS, d=256, use_seg=USE_SEGMENTATION).to(device)
print("Model parameters (M):", sum(p.numel() for p in model.parameters())/1e6)


Model parameters (M): 51.520309


In [22]:
class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

def dice_loss_logits(pred_logits, target):
    pred = torch.sigmoid(pred_logits)
    target = target.float()
    inter = (pred * target).sum(dim=(1,2,3))
    denom = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3))
    dice = (2 * inter + 1e-6) / (denom + 1e-6)
    return 1.0 - dice.mean()

clf_loss_fn = LabelSmoothingCE(LABEL_SMOOTH)
seg_bce = nn.BCEWithLogitsLoss()

# optimizer (single group for simplicity)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    # Upsample predictions to GT mask size (224×224)
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)



## Training

In [None]:
best_f1 = 0.0
for epoch in range(1, EPOCHS+1):
    # -------------------
    # TRAIN
    # -------------------
    model.train()
    running_loss = 0
    y_true, y_pred = [], []

    for imgs, labels, masks in tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}"):
        imgs, labels = imgs.to(device), labels.to(device)

        out = model(imgs)
        loss = clf_loss_fn(out["logits"], labels)
        if masks is not None:
            masks = masks.to(device)
            loss += seg_loss_fn(out["seg"], masks)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(out["logits"].argmax(1).cpu().numpy())
     # metrics
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )

    print(f"[Epoch {epoch}] Train Loss: {running_loss/len(train_loader):.4f} "
      f"Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # VALIDATION
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0
    with torch.no_grad():
        for imgs, labels, masks in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            out = model(imgs)

            loss = clf_loss_fn(out["logits"], labels)
            if masks is not None:
                masks = masks.to(device)
                loss += seg_loss_fn(out["seg"], masks)
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(out["logits"].argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/len(val_loader):.4f} "
          f"Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

Train 1/10: 100%|██████████| 380/380 [05:58<00:00,  1.06it/s]


[Epoch 1] Train Loss: 1.3524 Acc: 0.9984 Prec: 0.5000 Rec: 0.4992 F1: 0.4996
[Epoch 1] Val Loss: 1.2867 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


Train 2/10: 100%|██████████| 380/380 [05:04<00:00,  1.25it/s]


[Epoch 2] Train Loss: 1.2338 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 2] Val Loss: 1.2286 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


Train 3/10: 100%|██████████| 380/380 [05:04<00:00,  1.25it/s]


[Epoch 3] Train Loss: 1.2070 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 3] Val Loss: 1.1754 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


Train 4/10: 100%|██████████| 380/380 [05:02<00:00,  1.26it/s]


[Epoch 4] Train Loss: 1.1912 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 4] Val Loss: 1.2142 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


Train 5/10: 100%|██████████| 380/380 [05:02<00:00,  1.26it/s]


[Epoch 5] Train Loss: 1.1860 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 5] Val Loss: 1.1775 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


Train 6/10: 100%|██████████| 380/380 [06:03<00:00,  1.05it/s]


[Epoch 6] Train Loss: 1.1725 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
[Epoch 6] Val Loss: 1.2277 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000


Train 7/10: 100%|██████████| 380/380 [05:03<00:00,  1.25it/s]


[Epoch 7] Train Loss: 1.1597 Acc: 1.0000 Prec: 1.0000 Rec: 1.0000 F1: 1.0000
