# SINet COD10K Detection
Questo notebook implementa SINet per il rilevamento di oggetti mimetizzati.

In [1]:
import os
import glob
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Report the available accelerator and select the device used by the rest of the notebook.
# Falls back to the CPU so the notebook can still run on machines without CUDA.
cuda_available = torch.cuda.is_available()
print(cuda_available)
device = torch.device('cuda:0' if cuda_available else 'cpu')
if cuda_available:
    print(torch.cuda.get_device_name(device))
# Smoke test: allocate a small tensor on the selected device.
tensor = torch.randn((2, 3), device=device)
print(tensor.device)
# Let cuDNN autotune convolution algorithms; the dataset resizes every input to a fixed size.
torch.backends.cudnn.benchmark = True

True
NVIDIA GeForce RTX 3050 Laptop GPU
cuda:0


### Dataset Class

In [3]:
# class CODDataset(Dataset):
#     def __init__(self, img_dir, mask_dir, transform=None):

#         super(CODDataset, self).__init__()
#         self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.jpg')))
#         self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
#         self.transform = transform
        
#         assert len(self.img_paths) == len(self.mask_paths), \
#             "Numero di immagini e maschere non coincide!"

#     def __len__(self):
#         return len(self.img_paths)

#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]
#         mask_path = self.mask_paths[idx]

#         image = Image.open(img_path).convert('RGB')
#         mask = Image.open(mask_path).convert('L') 

#         if self.transform:
#             image = self.transform(image)
#             mask = self.transform(mask)
#         else:
#             image = transforms.ToTensor()(image)
#             mask = transforms.ToTensor()(mask)

#         mask = (mask > 0.5).float()

#         return image, mask
class CODDataset(Dataset):
    """Camouflaged-object dataset pairing RGB images with binary ground-truth masks.

    Images are the ``*.jpg`` files of ``img_dir`` and masks the ``*.png`` files of
    ``mask_dir``. Both lists are sorted and paired by position, and each pair is
    checked to share the same file stem so a misaligned directory fails loudly.

    Args:
        img_dir: directory containing the ``.jpg`` images.
        mask_dir: directory containing the ``.png`` masks.
        transform: optional transform applied to the resized image. When
            ``mask_transform`` is None it is also applied to the mask.
        size: (height, width) that every image and mask is resized to.
        mask_transform: optional transform applied to the mask instead of
            ``transform``. Use it when ``transform`` contains image-only steps
            (e.g. ``Normalize``) that must not touch a single-channel mask.

    Each item is ``(image, mask)``. Without transforms, ``image`` is a float tensor
    in [0, 1] and ``mask`` is a float tensor of 0/1 values.
    """

    def __init__(self, img_dir, mask_dir, transform=None, size=(512, 512), mask_transform=None):
        super().__init__()
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.jpg')))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
        self.transform = transform
        self.mask_transform = mask_transform

        assert len(self.img_paths) == len(self.mask_paths), \
            "Numero di immagini e maschere non coincide!"

        # Equal counts do not guarantee a correct pairing: require matching file stems.
        for img_path, mask_path in zip(self.img_paths, self.mask_paths):
            img_stem = os.path.splitext(os.path.basename(img_path))[0]
            mask_stem = os.path.splitext(os.path.basename(mask_path))[0]
            if img_stem != mask_stem:
                raise ValueError(f"Image/mask mismatch: {img_path} vs {mask_path}")

        # Resizing to a common size lets samples be stacked into batches.
        self.resize = transforms.Resize(size)
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # `convert` returns a new in-memory image, so the files can be closed right away.
        with Image.open(self.img_paths[idx]) as img_file:
            image = img_file.convert('RGB')
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = mask_file.convert('L')  # single-channel grayscale mask

        image = self.resize(image)
        mask = self.resize(mask)

        mask_transform = self.mask_transform if self.mask_transform is not None else self.transform
        image = self.transform(image) if self.transform else self._to_tensor(image)
        mask = mask_transform(mask) if mask_transform else self._to_tensor(mask)

        # Resizing interpolates mask values; threshold back to a binary 0/1 mask.
        mask = (mask > 0.5).float()

        return image, mask


### Backbone feature extraction

In [4]:
class ResNetBackbone(nn.Module):
    """torchvision ResNet-50 split into five stages whose outputs feed SINet.

    Args:
        pretrained: if True, load the ImageNet weights that the legacy
            ``pretrained=True`` flag selected (``IMAGENET1K_V1``); otherwise the
            network is randomly initialised.

    ``forward(x)`` returns ``(x1, x2, x3, x4, x5)``: the stem (conv1/bn1/relu)
    output followed by the outputs of ``layer1``..``layer4``. The max-pool runs
    between the stem and ``layer1``, but its output is not returned.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = self._build_resnet50(pretrained)

        self.stage1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.pool = resnet.maxpool
        self.stage2 = resnet.layer1
        self.stage3 = resnet.layer2
        self.stage4 = resnet.layer3
        self.stage5 = resnet.layer4

    @staticmethod
    def _build_resnet50(pretrained):
        """Build ResNet-50 via the weights API, falling back to the legacy flag on old torchvision."""
        weights_enum = getattr(torchvision.models, 'ResNet50_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 has no weights enums; only `pretrained` exists there.
            return torchvision.models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is what `pretrained=True` loaded (DEFAULT would pick V2).
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return torchvision.models.resnet50(weights=weights)

    def forward(self, x):
        x1 = self.stage1(x)
        x2 = self.stage2(self.pool(x1))
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)
        return x1, x2, x3, x4, x5

### Search Class

In [5]:
class SearchModule(nn.Module):
    """Fuse three backbone feature maps into a coarse single-channel probability map.

    Each input is projected to ``mid_channels`` by a 3x3 convolution. The second
    and third projections are bilinearly resized to the spatial size of the first.
    The three are summed, and a 1x1 convolution followed by a sigmoid yields a map
    with values in (0, 1) at the resolution of ``x2``.

    Args:
        in_channels_list: channel counts of ``(x2, x3, x4)``; must have exactly
            three entries. A tuple default avoids a shared mutable default.
        mid_channels: width of the fused feature map.

    Raises:
        ValueError: if ``in_channels_list`` does not have exactly three entries.
    """

    def __init__(self, in_channels_list=(256, 512, 1024), mid_channels=256):
        super().__init__()
        in_channels_list = tuple(in_channels_list)
        if len(in_channels_list) != 3:
            raise ValueError(
                f"SearchModule expects 3 input channel counts, got {len(in_channels_list)}"
            )
        self.conv_list = nn.ModuleList([
            nn.Conv2d(in_ch, mid_channels, kernel_size=3, padding=1)
            for in_ch in in_channels_list
        ])
        self.out_conv = nn.Conv2d(mid_channels, 1, kernel_size=1)

    def forward(self, x2, x3, x4):
        fused = self.conv_list[0](x2)
        target_size = fused.shape[2:]
        # Accumulate the deeper projections, upsampled to x2's resolution.
        for conv, feat in zip(self.conv_list[1:], (x3, x4)):
            fused = fused + F.interpolate(conv(feat), size=target_size,
                                          mode='bilinear', align_corners=False)
        return torch.sigmoid(self.out_conv(fused))

### Identification Class

In [6]:
class IdentificationModule(nn.Module):
    """Refine the coarse search map using the deepest backbone features.

    ``x5`` is projected to ``mid_channels`` and bilinearly resized to the spatial
    size of ``coarse_map``. It is then concatenated with the map and passed
    through a 3x3 refinement conv, then a 1x1 conv and a sigmoid.

    Resizing to the coarse map's size (rather than a fixed x8 scale) keeps the
    concatenation valid for any input resolution. For inputs whose sides are
    multiples of 32 the result is identical to an x8 upsample.

    Args:
        in_channels: channel count of ``x5``.
        mid_channels: width of the intermediate feature maps.

    Returns (from ``forward``):
        A 1-channel map with values in (0, 1), at the spatial size of ``coarse_map``.
    """

    def __init__(self, in_channels=2048, mid_channels=256):
        super().__init__()
        self.conv_deep = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        # +1 input channel for the concatenated coarse map.
        self.refine_conv = nn.Conv2d(mid_channels + 1, mid_channels, kernel_size=3, padding=1)
        self.out_conv = nn.Conv2d(mid_channels, 1, kernel_size=1)

    def forward(self, x5, coarse_map):
        x5_ = self.conv_deep(x5)
        x5_up = F.interpolate(x5_, size=coarse_map.shape[2:], mode='bilinear', align_corners=False)

        refine_feat = self.refine_conv(torch.cat([x5_up, coarse_map], dim=1))
        return torch.sigmoid(self.out_conv(refine_feat))

### SINet Class

In [7]:
class SINet(nn.Module):
    """Search-and-identification network for camouflaged object detection.

    ``forward(x)`` returns ``(out_final, coarse_map)``:

    * ``coarse_map``: output of the search module, built from backbone stages 2-4,
      at the resolution of stage 2.
    * ``out_final``: refined map from the identification module, bilinearly
      resized to the spatial size of the input ``x``.

    Both maps are 1-channel with values in (0, 1) (sigmoid outputs).

    Args:
        backbone_pretrained: forwarded to ``ResNetBackbone(pretrained=...)``.
    """

    def __init__(self, backbone_pretrained=True):
        super().__init__()
        self.backbone = ResNetBackbone(pretrained=backbone_pretrained)
        self.search = SearchModule(in_channels_list=[256, 512, 1024])
        self.identify = IdentificationModule(in_channels=2048)

    def forward(self, x):
        # The stem output (first element) is not used by SINet.
        _, x2, x3, x4, x5 = self.backbone(x)

        coarse_map = self.search(x2, x3, x4)
        refine_map = self.identify(x5, coarse_map)

        # Resize to the exact input size so the output always aligns with the mask,
        # even when the input sides are not multiples of 32.
        out_final = F.interpolate(refine_map, size=x.shape[-2:], mode='bilinear', align_corners=False)

        return out_final, coarse_map

### Evaluation Methods

In [8]:
def compute_batch_metrics(pred, target, threshold=0.5):
    """Compute per-sample segmentation metrics and return their batch means.

    ``pred`` is binarised with ``pred >= threshold``. Each sample (first
    dimension) is flattened, and its confusion counts TP/FP/FN/TN are computed
    against ``target`` (expected to hold 0/1 values). The metrics are accuracy,
    precision, recall, F1 and IoU, each with a small epsilon in the denominator
    so empty classes yield 0 instead of dividing by zero.

    Args:
        pred: predicted probabilities, batch-first.
        target: ground-truth mask with the same number of elements per sample.
        threshold: binarisation threshold for ``pred``.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1', 'iou', each the
        mean over the batch of the per-sample value.
    """
    names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    eps = 1e-7
    batch_size = pred.shape[0]
    if batch_size == 0:
        return {name: np.float64('nan') for name in names}

    # Vectorised over the batch: one host transfer instead of 4 `.item()` syncs per
    # sample. float64 keeps the ratios as precise as Python-float arithmetic;
    # `reshape` also accepts non-contiguous tensors.
    p = (pred >= threshold).reshape(batch_size, -1).double()
    t = target.reshape(batch_size, -1).double()

    tp = (p * t).sum(dim=1)
    fp = (p * (1 - t)).sum(dim=1)
    fn = ((1 - p) * t).sum(dim=1)
    tn = ((1 - p) * (1 - t)).sum(dim=1)

    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    iou = tp / (tp + fp + fn + eps)

    per_sample = torch.stack([acc, prec, rec, f1, iou]).cpu().numpy()
    return {name: np.mean(row) for name, row in zip(names, per_sample)}

### Dice Loss

In [9]:
def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss pooled over every element of the batch.

    Computes ``1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``
    over all elements at once, so the Dice is batch-level rather than per-sample.
    ``smooth`` keeps the ratio defined when both inputs are all zeros.

    Returns:
        A 0-dim tensor (differentiable w.r.t. ``pred``).
    """
    # `reshape` instead of `view` so non-contiguous inputs are accepted.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - ((2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))

### Train and Validation Methods

In [10]:
def _sinet_loss(out_final, out_coarse, masks, coarse_weight=0.5):
    """Deep-supervision loss: Dice + BCE on the final map plus weighted Dice + BCE on the coarse map.

    The mask is resized (nearest) once to the coarse map's resolution and reused.
    """
    masks_coarse = F.interpolate(masks, size=out_coarse.shape[2:], mode='nearest')
    loss_final = dice_loss(out_final, masks) + F.binary_cross_entropy(out_final, masks)
    loss_coarse = dice_loss(out_coarse, masks_coarse) + F.binary_cross_entropy(out_coarse, masks_coarse)
    return loss_final + coarse_weight * loss_coarse


def train_one_epoch(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    ``model(images)`` must return ``(out_final, out_coarse)`` probability maps.
    Raises ZeroDivisionError if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    for images, masks in dataloader:
        # non_blocking overlaps the copy with compute when the loader pins memory.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad()
        out_final, out_coarse = model(images)
        loss = _sinet_loss(out_final, out_coarse, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


def validate_one_epoch(model, dataloader, device):
    """Evaluate the model without gradients.

    Returns:
        (avg_loss, avg_metrics): ``avg_loss`` is the mean per-batch loss. ``avg_metrics``
        holds accuracy/precision/recall/f1/iou (threshold 0.5), averaged over samples,
        i.e. batch means are weighted by batch size so a short last batch is not
        over-weighted.
    """
    model.eval()
    val_loss = 0.0
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    per_batch = {name: [] for name in metric_names}
    batch_sizes = []

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            out_final, out_coarse = model(images)
            val_loss += _sinet_loss(out_final, out_coarse, masks).item()

            batch_metrics = compute_batch_metrics(out_final, masks, threshold=0.5)
            for name in metric_names:
                per_batch[name].append(batch_metrics[name])
            batch_sizes.append(images.shape[0])

    avg_loss = val_loss / len(dataloader)
    avg_metrics = {
        name: np.average(per_batch[name], weights=batch_sizes) for name in metric_names
    }
    return avg_loss, avg_metrics

### Test Method

In [11]:
# def test_model(model, dataloader, device, threshold=0.5):
#     model.eval()
#     all_acc, all_prec, all_rec, all_f1, all_iou = [], [], [], [], []

#     with torch.no_grad():
#         print(len(dataloader))
#         for i, (images, masks) in enumerate(dataloader):
#             print(f"Processing batch {i+1}/{len(dataloader)}...")  # DEBUG
#             images = images.to(device)
#             masks = masks.to(device)

#             out_final, out_coarse = model(images)
#             print("Model inference done.")  # DEBUG

#             batch_metrics = compute_batch_metrics(out_final, masks, threshold=threshold)
#             print("Metrics computed.")  # DEBUG

#             all_acc.append(batch_metrics['accuracy'])
#             all_prec.append(batch_metrics['precision'])
#             all_rec.append(batch_metrics['recall'])
#             all_f1.append(batch_metrics['f1'])
#             all_iou.append(batch_metrics['iou'])

#     avg_metrics = {
#         'accuracy': np.mean(all_acc),
#         'precision': np.mean(all_prec),
#         'recall': np.mean(all_rec),
#         'f1': np.mean(all_f1),
#         'iou': np.mean(all_iou)
#     }
#     return avg_metrics

def test_model(model, dataloader, device, threshold=0.5, num_samples=5, verbose=True):
    """Evaluate ``model`` on ``dataloader`` and plot the first ``num_samples`` batches.

    Args:
        model: returns ``(out_final, out_coarse)``; only ``out_final`` is evaluated.
        dataloader: yields ``(images, masks)`` batches.
        device: device the batches are moved to.
        threshold: binarisation threshold, used for both the metrics and the plots.
        num_samples: number of leading batches to plot.
        verbose: print progress messages for each batch.

    Returns:
        dict of accuracy/precision/recall/f1/iou averaged over samples (batch means
        weighted by batch size), or NaN for every metric if the loader is empty.
    """
    model.eval()
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    per_batch = {name: [] for name in metric_names}
    batch_sizes = []

    with torch.no_grad():
        if verbose:
            print(len(dataloader))
        for i, (images, masks) in enumerate(dataloader):
            if verbose:
                print(f"Processing batch {i+1}/{len(dataloader)}...")
            images = images.to(device)
            masks = masks.to(device)

            out_final, _ = model(images)
            if verbose:
                print("Model inference done.")

            batch_metrics = compute_batch_metrics(out_final, masks, threshold=threshold)
            if verbose:
                print("Metrics computed.")

            for name in metric_names:
                per_batch[name].append(batch_metrics[name])
            batch_sizes.append(images.shape[0])

            if i < num_samples:
                plot_results(images.cpu(), masks.cpu(), out_final.cpu(), i, threshold=threshold)

    if not batch_sizes:
        return {name: np.float64('nan') for name in metric_names}
    return {name: np.average(per_batch[name], weights=batch_sizes) for name in metric_names}

def plot_results(images, masks, predictions, batch_idx, threshold=0.5):
    """
    Show, for each sample of a batch, the input image, the ground-truth mask and
    the model prediction binarised with ``predictions >= threshold``.

    ``images`` are expected channel-first (they are permuted to HWC for display).
    """
    batch_size = images.shape[0]
    # squeeze=False always yields a 2-D axes grid, even for a single row.
    fig, axes = plt.subplots(batch_size, 3, figsize=(10, batch_size * 3), squeeze=False)

    for i, row in enumerate(axes):
        img = images[i].permute(1, 2, 0).numpy()  # CHW tensor -> HWC array for imshow
        mask = masks[i].squeeze().numpy()
        pred = (predictions[i].squeeze().numpy() >= threshold).astype(np.uint8)

        first_title = ("Immagine originale" if batch_size == 1
                       else f"Immagine {batch_idx * batch_size + i}")
        panels = (
            (img, first_title, None),
            (mask, "Maschera reale", "gray"),
            (pred, "Previsione modello", "gray"),
        )
        for ax, (data, title, cmap) in zip(row, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so repeated calls do not accumulate figures

### Main

In [None]:
def main(data_root='C:/Users/Nicholas/Desktop/COD10K-v3', batch_size=16, num_epochs=10,
         lr=1e-4, val_fraction=0.1, seed=42):
    """Train SINet on COD10K, validate on a held-out split, save it, then test.

    The validation set is a seeded random ``val_fraction`` of the training images
    (previously validation reused the full training set, so it measured training
    performance). The test set is ``<data_root>/Test``.
    """
    from torch.utils.data import random_split

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_img_dir = os.path.join(data_root, 'Train', 'Image')
    train_mask_dir = os.path.join(data_root, 'Train', 'GT_Object')
    test_img_dir = os.path.join(data_root, 'Test', 'Image')
    test_mask_dir = os.path.join(data_root, 'Test', 'GT_Object')

    print(f"Train images: {len(os.listdir(train_img_dir))}")
    print(f"Train masks: {len(os.listdir(train_mask_dir))}")

    full_train_dataset = CODDataset(train_img_dir, train_mask_dir, transform=None)
    full_train_dataset[0]  # sanity check: fail early if the first sample cannot be loaded
    print("Sample loaded")

    # Hold out part of the training images for validation (seeded for reproducibility).
    n_val = max(1, int(len(full_train_dataset) * val_fraction))
    n_train = len(full_train_dataset) - n_val
    train_dataset, val_dataset = random_split(
        full_train_dataset, [n_train, n_val], generator=torch.Generator().manual_seed(seed)
    )
    test_dataset = CODDataset(test_img_dir, test_mask_dir, transform=None)

    pin_memory = device.type == 'cuda'  # pinned memory only helps host-to-GPU copies
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)

    model = SINet(backbone_pretrained=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_metrics = validate_one_epoch(model, val_loader, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val F1: {val_metrics['f1']:.4f}, Val IoU: {val_metrics['iou']:.4f}")

    torch.save(model.state_dict(), "sinet_camouflage.pth")
    print("Training completato e modello salvato.")

    test_metrics = test_model(model, test_loader, device, threshold=0.5)
    print("RISULTATI TEST FINALI:")
    print(f"  Accuracy = {test_metrics['accuracy']:.3f}")
    print(f"  Precision = {test_metrics['precision']:.3f}")
    print(f"  Recall = {test_metrics['recall']:.3f}")
    print(f"  F1-score = {test_metrics['f1']:.3f}")
    print(f"  IoU = {test_metrics['iou']:.3f}")

if __name__ == '__main__':
    main()

Train images: 6000
Train masks: 6000
Sample loaded


