In [6]:
import os, math, random
import numpy as np
import pandas as pd
from tqdm import tqdm

In [7]:
from PIL import Image

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms, models
import timm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

Device: cpu


In [10]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
seed_everything(42)

# Config/hyperparameters

In [11]:
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 10

train_dir = "/kaggle/input/cod10k/COD10K-v3/Train"
test_dir = "/kaggle/input/cod10k/COD10K-v3/Test"

train_txt = os.path.join(train_dir, "CAM-NonCAM_Instance_Train.txt")
test_txt = os.path.join(test_dir, "CAM-NonCAM_Instance_Test.txt")

In [12]:
class COD10KDataset(Dataset):
    def __init__(self, root_dir, txt_files, transform=None):
        """
        txt_files: list of txt files (e.g. [CAM_train.txt, NonCAM_train.txt])
        """
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, "Image")
        self.transform = transform

        self.samples = []
        for txt_file, label_value in txt_files:
            with open(txt_file, "r") as f:
                for line in f:
                    img_name = line.strip().split()[0]   # first column = filename
                    img_path = os.path.join(self.img_dir, img_name)
                    if os.path.exists(img_path):
                        self.samples.append((img_name, label_value))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.img_dir, img_name)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label


In [13]:
# =============================
# 3. Transforms
# =============================
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

In [14]:
class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        base = models.densenet201(pretrained=pretrained).features
        self.features = base
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in ["denseblock1","denseblock2","denseblock3","denseblock4"]:
                feats.append(x)
        return feats

In [15]:
class MobileNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        m = models.mobilenet_v3_large(pretrained=pretrained)
        self.features = m.features
    def forward(self, x):
        feats = []
        out = x
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in (2,5,9,12):  # 4 scales
                feats.append(out)
        if len(feats)<4: feats.append(out)
        return feats

# CBAM-lite (SE + light spatial)

In [16]:
class CBAMlite(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, max(channels//reduction, 4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(channels//reduction, 4), channels, 1),
            nn.Sigmoid()
        )
        self.spatial = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 1, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return x * self.se(x) * self.spatial(x)



In [17]:
class GatedFusion(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, max(dim//4,4), 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(dim//4,4), dim, 1),
            nn.Sigmoid()
        )
    def forward(self, H, X):
        g = self.g_fc(H)
        return g * H + (1-g) * X



## Backbone extractors

In [19]:
def probe_backbones(img_size=224):
    netA = DenseNetExtractor(True).to(device).eval()
    netB = MobileNetExtractor(True).to(device).eval()
    with torch.no_grad():
        dummy = torch.randn(1,3,img_size,img_size).to(device)
        featsA = netA(dummy)
        featsB = netB(dummy)
    return [f.shape[1] for f in featsA], [f.shape[1] for f in featsB]

dense_chs, mobilenet_chs = probe_backbones()
print("DenseNet channels:", dense_chs)
print("MobileNet channels:", mobilenet_chs)


Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /root/.cache/torch/hub/checkpoints/densenet201-c1103571.pth
100%|██████████| 77.4M/77.4M [00:00<00:00, 175MB/s] 
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 131MB/s] 


DenseNet channels: [256, 512, 1792, 1920]
MobileNet channels: [24, 40, 80, 112]


## Main AGHCT-DA model

In [20]:
class DynamicFusionModel(nn.Module):
    def __init__(self, num_classes, dense_chs, mobilenet_chs, d=256):
        super().__init__()
        self.backA = DenseNetExtractor(True)
        self.backB = MobileNetExtractor(True)
        L = min(len(dense_chs), len(mobilenet_chs))
        self.L = L

        self.alignA = nn.ModuleList([nn.Conv2d(in_c,d,1) for in_c in dense_chs[:L]])
        self.alignB = nn.ModuleList([nn.Conv2d(in_c,d,1) for in_c in mobilenet_chs[:L]])
        self.cbamA = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.cbamB = nn.ModuleList([CBAMlite(d) for _ in range(L)])
        self.gates = nn.ModuleList([GatedFusion(d) for _ in range(L)])

        self.fusion_reduce = nn.Conv2d(d*L, d, 1)
        self.classifier = nn.Sequential(
            nn.Linear(d,512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512,128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128,num_classes)
        )

    def forward(self, x):
        featsA = self.backA(x)
        featsB = self.backB(x)
        fused_feats=[]
        for fA,convA,cbamA,fB,convB,cbamB,gate in zip(
            featsA[:self.L], self.alignA, self.cbamA,
            featsB[:self.L], self.alignB, self.cbamB,
            self.gates):
            a = cbamA(convA(fA))
            b = cbamB(convB(fB))
            if b.shape[2:] != a.shape[2:]:
                b = F.interpolate(b, size=a.shape[2:], mode='bilinear', align_corners=False)
            fused_feats.append(gate(a,b))
        target = fused_feats[-1]
        upsampled=[F.interpolate(f, size=target.shape[2:], mode='bilinear', align_corners=False)
                   if f.shape[2:]!=target.shape[2:] else f for f in fused_feats]
        concat = torch.cat(upsampled, dim=1)
        fused = self.fusion_reduce(concat)
        z = F.adaptive_avg_pool2d(fused,1).view(fused.shape[0],-1)
        return self.classifier(z)



# training

In [21]:
info_dir = "/kaggle/input/cod10k/COD10K-v3/Info"

train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")

test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

# Create datasets
train_ds = COD10KDataset(train_dir, [(train_cam_txt, 1), (train_noncam_txt, 0)], transform=train_tf)
val_ds   = COD10KDataset(test_dir,  [(test_cam_txt, 1),  (test_noncam_txt, 0)],  transform=val_tf)

num_classes = 2
print("Train samples:", len(train_ds), " Test samples:", len(val_ds))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


Train samples: 5998  Test samples: 4000


In [None]:
model = DynamicFusionModel(num_classes=num_classes, dense_chs=dense_chs, mobilenet_chs=mobilenet_chs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    model.train(); total_loss=0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(device), torch.tensor(labels).to(device)
        logits = model(imgs)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} Train Loss: {total_loss/len(train_loader):.4f}")

    # validation
    model.eval(); correct,total=0,0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), torch.tensor(labels).to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds==labels).sum().item()
            total += labels.size(0)
    print(f"Validation Accuracy: {100*correct/total:.2f}%")



  imgs, labels = imgs.to(device), torch.tensor(labels).to(device)
Epoch 1/10: 100%|██████████| 375/375 [18:49<00:00,  3.01s/it]


Epoch 1 Train Loss: 0.4088


  imgs, labels = imgs.to(device), torch.tensor(labels).to(device)


Validation Accuracy: 88.42%


Epoch 2/10: 100%|██████████| 375/375 [19:32<00:00,  3.13s/it]


Epoch 2 Train Loss: 0.3072
Validation Accuracy: 88.92%


Epoch 3/10: 100%|██████████| 375/375 [19:30<00:00,  3.12s/it]


Epoch 3 Train Loss: 0.2623
Validation Accuracy: 89.75%


Epoch 4/10: 100%|██████████| 375/375 [20:19<00:00,  3.25s/it]


Epoch 4 Train Loss: 0.2435
Validation Accuracy: 89.60%


Epoch 5/10: 100%|██████████| 375/375 [20:57<00:00,  3.35s/it]


Epoch 5 Train Loss: 0.2138
Validation Accuracy: 89.38%


Epoch 6/10:   3%|▎         | 13/375 [01:32<21:04,  3.49s/it] 