In [1]:
!pip install timm pycocotools torchmetrics



In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
import timm
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm
import math
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



In [3]:
def get_transforms(train):
    transforms = []
    transforms.append(T.Resize((600, 600)))
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float))
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

In [4]:
class VOCDataset(Dataset):
    def __init__(self, root, split, transforms=None):
        self.root = root
        self.split = split
        self.transforms = transforms
        
        self.img_dir = os.path.join(root, 'images', split)
        self.label_dir = os.path.join(root, 'labels', split)
        
        self.image_files = sorted([f for f in os.listdir(self.img_dir) if f.endswith('.jpg')])
        
        self.voc_classes = [
            "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", 
            "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", 
            "pottedplant", "sheep", "sofa", "train", "tvmonitor"
        ]
        
        self.target_classes = ["person", "dog", "cat", "tvmonitor", "bird"]
        
        self.class_to_idx = {cls: i + 1 for i, cls in enumerate(self.target_classes)}
        
        self.voc_to_model_map = {}
        for idx, cls_name in enumerate(self.voc_classes):
            if cls_name in self.target_classes:
                self.voc_to_model_map[idx] = self.class_to_idx[cls_name]

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        label_path = os.path.join(self.label_dir, img_name.replace('.jpg', '.txt'))
        
        img = Image.open(img_path).convert("RGB")
        w, h = img.size
        
        boxes = []
        labels = []
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    data = line.strip().split()
                    if len(data) < 5: continue
                    
                    cls_idx = int(data[0])
                    
                    if cls_idx in self.voc_to_model_map:
                        cx, cy, bw, bh = map(float, data[1:5])
                        
                        xmin = (cx - bw/2) * w
                        ymin = (cy - bh/2) * h
                        xmax = (cx + bw/2) * w
                        ymax = (cy + bh/2) * h
                        
                        boxes.append([xmin, ymin, xmax, ymax])
                        labels.append(self.voc_to_model_map[cls_idx])

        if len(boxes) > 0:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            area = torch.zeros((0,), dtype=torch.float32)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["area"] = area
        target["iscrowd"] = torch.zeros((len(labels),), dtype=torch.int64)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.image_files)

def collate_fn(batch):
    return tuple(zip(*batch))

In [5]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.stride = stride
        
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
            
        out += identity
        out = self.relu(out)
        return out

class CCTTBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.c_stage1 = self._make_c_layer(96, 96, 2)
        self.c_stage2 = self._make_c_layer(96, 192, 2, stride=2)
        
        swin_config = timm.models.swin_transformer.SwinTransformer(
            img_size=224, patch_size=4, in_chans=3, num_classes=0,
            embed_dim=96, depths=[2, 2, 9, 3], num_heads=[3, 6, 12, 24],
            window_size=7
        )
        
        self.t_stage3 = swin_config.layers[2]
        self.t_stage4 = swin_config.layers[3]
        
        self.out_channels = 768

    def _make_c_layer(self, in_c, out_c, blocks, stride=1):
        layers = []
        layers.append(ConvBlock(in_c, out_c, stride))
        for _ in range(1, blocks):
            layers.append(ConvBlock(out_c, out_c))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)        
        c1 = self.c_stage1(x)   
        c2 = self.c_stage2(c1)  
        
        x_in = c2.permute(0, 2, 3, 1) 
        
        t3 = self.t_stage3(x_in) 
        t4 = self.t_stage4(t3)   
        
        t4_out = t4.permute(0, 3, 1, 2)
        
        return {
            "0": t4_out 
        }

In [6]:
def get_cctt_model(num_classes):
    backbone = CCTTBackbone()
    backbone.out_channels = 768
    
    anchor_generator = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),)
    )
    
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=['0'],
        output_size=7,
        sampling_ratio=2
    )
    
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool_init_fn=None,
        box_roi_pool=roi_pooler
    )
    return model

In [7]:
def gradient_calibration(model, dataloader, device, alpha=0.25):
    model.train()
    images, targets = next(iter(dataloader))
    images = list(image.to(device) for image in images)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    model.zero_grad()
    
    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())
    losses.backward()

    layer_norms = []
    layers = []
    
    for name, param in model.named_parameters():
        if param.grad is not None and param.requires_grad:
            norm = torch.norm(param.grad)
            layer_norms.append(norm.item())
            layers.append(param)
    
    if len(layer_norms) == 0:
        return

    c_tilde = np.exp(np.mean(np.log(np.array(layer_norms) + 1e-8)))
    
    with torch.no_grad():
        for param, norm in zip(layers, layer_norms):
            if norm > 0:
                r_k = (norm / c_tilde) ** alpha
                param.data.mul_(r_k)
    
    model.zero_grad()

In [8]:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_total = 0
    
    pbar = tqdm(data_loader, desc=f"Epoch {epoch}")
    for images, targets in pbar:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        loss_total += losses.item()
        pbar.set_postfix({'Loss': losses.item()})
        
    return loss_total / len(data_loader)

@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()
    metric = MeanAveragePrecision()
    
    for images, targets in tqdm(data_loader, desc="Evaluating"):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        outputs = model(images)
        metric.update(outputs, targets)
        
    return metric.compute()

In [9]:
root_dir = "/kaggle/input/pascal-voc-2012/VOC2012"

if not os.path.exists(root_dir):
    raise FileNotFoundError(f"Could not find dataset at {root_dir}")

dataset_train = VOCDataset(root_dir, "train", get_transforms(True))
dataset_val = VOCDataset(root_dir, "val", get_transforms(False))

data_loader_train = DataLoader(
    dataset_train, 
    batch_size=2, 
    shuffle=True, 
    num_workers=2, 
    collate_fn=collate_fn
)

data_loader_val = DataLoader(
    dataset_val, 
    batch_size=2, 
    shuffle=False, 
    num_workers=2, 
    collate_fn=collate_fn
)

print(f"Training samples: {len(dataset_train)}")
print(f"Validation samples: {len(dataset_val)}")

Training samples: 5717
Validation samples: 5823


In [10]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = get_cctt_model(num_classes=6)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(params, lr=0.0001, weight_decay=0.05)
scaler = torch.cuda.amp.GradScaler()

gradient_calibration(model, data_loader_train, device)

num_epochs = 12

for epoch in range(num_epochs):
    model.train()
    loss_total = 0
    pbar = tqdm(data_loader_train, desc=f"Epoch {epoch}")
    
    for images, targets in pbar:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_total += losses.item()
        pbar.set_postfix({'Loss': losses.item()})
        
    avg_loss = loss_total / len(data_loader_train)
    print(f"Epoch {epoch} Loss: {avg_loss:.4f}")
    
    checkpoint_path = f"cctt_voc_checkpoint_epoch_{epoch}.pth"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    
    if (epoch + 1) % 4 == 0:
        mAP = evaluate(model, data_loader_val, device)
        print(f"mAP: {mAP['map']:.4f}")

torch.save(model.state_dict(), "cctt_voc_final_model.pth")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Epoch 0: 100%|██████████| 2859/2859 [11:32<00:00,  4.13it/s, Loss=1.21]


Epoch 0 Loss: 0.2890


Epoch 1: 100%|██████████| 2859/2859 [11:35<00:00,  4.11it/s, Loss=0.702]


Epoch 1 Loss: 0.2438


Epoch 2: 100%|██████████| 2859/2859 [11:35<00:00,  4.11it/s, Loss=0.0352]


Epoch 2 Loss: 0.2384


Epoch 3: 100%|██████████| 2859/2859 [11:35<00:00,  4.11it/s, Loss=0.5]


Epoch 3 Loss: 0.2406


Evaluating: 100%|██████████| 2912/2912 [07:38<00:00,  6.35it/s]


mAP: 0.0057


Epoch 4: 100%|██████████| 2859/2859 [11:36<00:00,  4.11it/s, Loss=0.37]


Epoch 4 Loss: 0.2369


Epoch 5: 100%|██████████| 2859/2859 [11:36<00:00,  4.11it/s, Loss=0.0882]


Epoch 5 Loss: 0.2358


Epoch 6: 100%|██████████| 2859/2859 [11:36<00:00,  4.11it/s, Loss=0.365]


Epoch 6 Loss: 0.2347


Epoch 7: 100%|██████████| 2859/2859 [11:37<00:00,  4.10it/s, Loss=0.0651]


Epoch 7 Loss: 0.2336


Evaluating: 100%|██████████| 2912/2912 [08:01<00:00,  6.04it/s]


mAP: 0.0125


Epoch 8: 100%|██████████| 2859/2859 [11:38<00:00,  4.09it/s, Loss=0.255]


Epoch 8 Loss: 0.2320


Epoch 9: 100%|██████████| 2859/2859 [11:38<00:00,  4.09it/s, Loss=0.414]


Epoch 9 Loss: 0.2269


Epoch 10: 100%|██████████| 2859/2859 [11:38<00:00,  4.09it/s, Loss=0.0578]


Epoch 10 Loss: 0.2244


Epoch 11: 100%|██████████| 2859/2859 [11:39<00:00,  4.09it/s, Loss=0.288]


Epoch 11 Loss: 0.2225


Evaluating: 100%|██████████| 2912/2912 [07:52<00:00,  6.16it/s]


mAP: 0.0183
