# SINet COD10K Detection
Questo notebook implementa SINet per il rilevamento di oggetti mimetizzati.

In [None]:
%pip install -qqq pytorch_wavelets

In [2]:
import os
import glob
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import shutil
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score
from skimage.metrics import structural_similarity as ssim
import pywt
from pytorch_wavelets import DWTForward
import time
from datetime import timedelta


In [3]:
# Pick the compute device: Apple-silicon MPS if available, otherwise CUDA,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Dataset Class

In [4]:

class CODDataset(Dataset):
    """Image/mask pairs for camouflaged object detection (COD10K folder layout).

    Images (``*.jpg`` in *img_dir*) and masks (``*.png`` in *mask_dir*) are
    paired by sorted filename. The pairing is checked file by file, so a
    missing or extra file fails loudly instead of silently shifting every
    later image onto the wrong mask.

    Args:
        img_dir: directory holding the RGB images.
        mask_dir: directory holding the ground-truth masks.
        transform: optional callable applied to BOTH the resized image and the
            resized mask (PIL images). It must therefore accept a
            single-channel image and return a tensor. When ``None`` both are
            converted with ``ToTensor``.

    Raises:
        ValueError: if the image and mask lists differ in length, or if a pair
            does not share the same filename stem.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        super().__init__()
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.jpg')))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
        self.transform = transform

        self._validate_pairs(self.img_paths, self.mask_paths)

        # Fixed network input resolution used for every sample.
        self.resize = transforms.Resize((224, 224))
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _validate_pairs(img_paths, mask_paths):
        """Raise ValueError unless image i and mask i share a filename stem."""
        if len(img_paths) != len(mask_paths):
            raise ValueError("Numero di immagini e maschere non coincide!")
        for img_path, mask_path in zip(img_paths, mask_paths):
            img_stem = os.path.splitext(os.path.basename(img_path))[0]
            mask_stem = os.path.splitext(os.path.basename(mask_path))[0]
            if img_stem != mask_stem:
                raise ValueError(
                    f"Immagine e maschera non corrispondono: {img_path} vs {mask_path}"
                )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors; the mask is binarised at 0.5."""
        with Image.open(self.img_paths[idx]) as img:
            image = img.convert('RGB')
        with Image.open(self.mask_paths[idx]) as msk:
            mask = msk.convert('L')  # single-channel (grayscale) mask

        image = self.resize(image)
        mask = self.resize(mask)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        else:
            image = self.to_tensor(image)
            mask = self.to_tensor(mask)

        # Resizing interpolates mask values, so re-binarise the mask.
        mask = (mask > 0.5).float()

        return image, mask


### Backbone feature extraction

In [5]:
class ResNetBackbone(nn.Module):
    """ResNet-50 feature extractor that also returns wavelet sub-band features.

    ``forward`` returns the five stage outputs ``x1..x5``, followed by one
    frequency tensor for each of ``x2..x5``. A frequency tensor concatenates
    the DWT low-pass band and the three level-1 detail bands along the channel
    axis, so it has 4x the channels of its source feature map. It is resized
    back to the source feature map's spatial size.

    Args:
        pretrained: if True, load torchvision's ``ResNet50_Weights.DEFAULT``
            weights; if False, build a randomly initialised network.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Load the pretrained weights only when requested.
        weights = "ResNet50_Weights.DEFAULT" if pretrained else None
        resnet = torchvision.models.resnet50(weights=weights)

        self.stage1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.pool = resnet.maxpool
        self.stage2 = resnet.layer1
        self.stage3 = resnet.layer2
        self.stage4 = resnet.layer3
        self.stage5 = resnet.layer4

        # Single-level 2-D DWT with the Daubechies-2 ('db2') wavelet; Haar
        # would be 'db1'/'haar'. As a registered submodule it is moved by
        # model.to(...) together with the rest of the network
        # (pytorch_wavelets stores its filters as buffers), so no global
        # device is hard-wired here.
        self.dwt = DWTForward(J=1, mode='zero', wave='db2')

    def _wavelet_features(self, feature_map):
        """Concatenate the DWT bands of *feature_map*, resized to its own size.

        Returns a tensor with 4x the channels of *feature_map*: the low-pass
        band followed by the three detail bands (LH, HL, HH, in the order
        DWTForward produces them), each bilinearly resized to the input's
        spatial size.
        """
        size = feature_map.shape[2:]
        low, highs = self.dwt(feature_map)
        # highs[0] is (N, C, 3, H', W'): split the band axis into three maps.
        bands = (low, *highs[0].unbind(dim=2))
        resized = [
            F.interpolate(band, size=size, mode='bilinear', align_corners=False)
            for band in bands
        ]
        return torch.cat(resized, dim=1)

    def forward(self, x):
        """Return ``(x1, x2, x3, x4, x5, x2_freq, x3_freq, x4_freq, x5_freq)``."""
        x1 = self.stage1(x)
        x2 = self.stage2(self.pool(x1))
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)

        x2_freq = self._wavelet_features(x2)
        x3_freq = self._wavelet_features(x3)
        x4_freq = self._wavelet_features(x4)
        x5_freq = self._wavelet_features(x5)

        return x1, x2, x3, x4, x5, x2_freq, x3_freq, x4_freq, x5_freq

### Search Class

In [6]:
class SearchModule(nn.Module):
    """Search module: fuse three backbone stages with their wavelet features.

    For each stage, the wavelet tensor (4x the stage's channels) is compressed
    back to the stage's channel count with a 1x1 conv. It is then concatenated
    with the stage features and projected to 256 channels with a 3x3 conv.
    The deeper two projections are resized to ``x2``'s spatial size and the
    three are summed. A 1x1 conv followed by a sigmoid gives a single-channel
    coarse map in (0, 1) at ``x2``'s resolution.

    Args:
        in_channels_list: channel counts of the three stages (x2, x3, x4).
            Exactly three values are required because ``forward`` takes three
            stages. A tuple default avoids a shared mutable default.

    Raises:
        ValueError: if *in_channels_list* does not have exactly three entries.
    """

    def __init__(self, in_channels_list=(256, 512, 1024)):
        super().__init__()
        in_channels_list = tuple(in_channels_list)
        if len(in_channels_list) != 3:
            raise ValueError(
                f"SearchModule expects 3 stage channel counts, got {len(in_channels_list)}"
            )

        # 4*C -> C: keep as much wavelet information as the spatial features carry.
        self.freq_compression = nn.ModuleList([
            nn.Conv2d(4 * in_ch, in_ch, kernel_size=1)
            for in_ch in in_channels_list
        ])

        # [features, compressed wavelets] (2*C channels) -> 256 channels.
        self.conv_list = nn.ModuleList([
            nn.Conv2d(2 * in_ch, 256, kernel_size=3, padding=1)
            for in_ch in in_channels_list
        ])

        self.out_conv = nn.Conv2d(256, 1, kernel_size=1)

    def forward(self, x2, x3, x4, x2_freq, x3_freq, x4_freq):
        x2_freq = self.freq_compression[0](x2_freq)
        x3_freq = self.freq_compression[1](x3_freq)
        x4_freq = self.freq_compression[2](x4_freq)

        # Project every stage to 256 channels; bring x3/x4 to x2's grid.
        x2_ = self.conv_list[0](torch.cat([x2, x2_freq], dim=1))
        x3_ = F.interpolate(self.conv_list[1](torch.cat([x3, x3_freq], dim=1)),
                            size=x2.shape[2:], mode='bilinear', align_corners=False)
        x4_ = F.interpolate(self.conv_list[2](torch.cat([x4, x4_freq], dim=1)),
                            size=x2.shape[2:], mode='bilinear', align_corners=False)

        fused = x2_ + x3_ + x4_
        coarse_map = self.out_conv(fused)
        coarse_map = torch.sigmoid(coarse_map)

        return coarse_map


### Identification Class

In [7]:
class IdentificationModule(nn.Module):
    """Identification module: refine the coarse map using the deepest stage.

    ``x5`` is concatenated with a 4-channel 1x1-conv compression of its
    wavelet tensor, projected to 256 channels with a 3x3 conv and upsampled
    x8. The coarse map is lifted to 256 channels with a 1x1 conv and resized
    to that same grid. The two are fused by a 3x3 conv and reduced by a
    1x1 conv + sigmoid to a single-channel map in (0, 1).

    Args:
        in_channels: channel count of ``x5``. Its wavelet tensor must carry
            4x as many channels.
    """

    def __init__(self, in_channels=2048):
        super().__init__()
        # Layers are created in the original order so that random
        # initialisation draws are unchanged.
        self.freq_compression = nn.Conv2d(4 * in_channels, 4, kernel_size=1)
        self.conv_deep = nn.Conv2d(in_channels + 4, 256, kernel_size=3, padding=1)
        self.refine_conv = nn.Conv2d(256 + 256, 256, kernel_size=3, padding=1)
        self.out_conv = nn.Conv2d(256, 1, kernel_size=1)

        self.coarse_map_expand = nn.Conv2d(1, 256, kernel_size=1)

    def forward(self, x5, x5_freq, coarse_map):
        deep_size = x5.shape[2:]

        # NOTE(review): the coarse map is first shrunk to x5's (much coarser)
        # resolution and only later upsampled again. This discards most of the
        # spatial detail produced by the search module. Changing it would
        # alter outputs of existing checkpoints — confirm intent.
        guide = F.interpolate(coarse_map, size=deep_size, mode='bilinear', align_corners=False)
        guide = self.coarse_map_expand(guide)

        freq_small = self.freq_compression(x5_freq)
        deep = self.conv_deep(torch.cat((x5, freq_small), dim=1))
        deep = F.interpolate(deep, scale_factor=8, mode='bilinear', align_corners=False)

        guide = F.interpolate(guide, size=deep.shape[2:], mode='bilinear', align_corners=False)

        fused = self.refine_conv(torch.cat((deep, guide), dim=1))
        return torch.sigmoid(self.out_conv(fused))


### SINet Class

In [8]:
class SINet(nn.Module):
    """Backbone + search + identification network for camouflaged objects.

    ``forward`` returns ``(out_final, coarse_map)``:

    - ``out_final`` is the refined sigmoid map, resized to the input's
      spatial size.
    - ``coarse_map`` is the search module's sigmoid map at the resolution of
      backbone stage ``x2``; it is used for deep supervision.

    Args:
        backbone_pretrained: forwarded to ``ResNetBackbone(pretrained=...)``.
    """

    def __init__(self, backbone_pretrained=True):
        super().__init__()
        self.backbone = ResNetBackbone(pretrained=backbone_pretrained)
        self.search = SearchModule(in_channels_list=[256, 512, 1024])
        self.identify = IdentificationModule(in_channels=2048)

    def forward(self, x):
        _, x2, x3, x4, x5, x2_freq, x3_freq, x4_freq, x5_freq = self.backbone(x)
        coarse_map = self.search(x2, x3, x4, x2_freq, x3_freq, x4_freq)
        refine_map = self.identify(x5, x5_freq, coarse_map)
        # Resize straight to the input resolution rather than a fixed x4, so
        # the output always matches the mask size. For inputs divisible by 32
        # (e.g. 224x224) this is exactly the same as scale_factor=4.
        out_final = F.interpolate(refine_map, size=x.shape[2:], mode='bilinear', align_corners=False)

        return out_final, coarse_map


### Evaluation Methods

In [9]:
def compute_batch_metrics(pred, target, threshold=0.5):
    """Per-image pixel metrics, averaged over the batch.

    *pred* is binarised with ``pred >= threshold``; *target* is used as is
    (expected to be 0/1). For each image, TP/FP/FN/TN are counted over all of
    its pixels. Accuracy, precision, recall, F1 and IoU are derived from these
    counts, with a 1e-7 epsilon in every denominator. The batch mean of each
    metric is returned.

    All counts are computed in one vectorised pass and copied to the host
    once, instead of four ``.item()`` device syncs per image. The arithmetic
    is done in float64, as the Python-float version did.

    Args:
        pred: tensor of shape ``(B, ...)`` with scores (e.g. probabilities).
        target: tensor with the same shape as *pred*.
        threshold: binarisation threshold for *pred*.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``
        and ``'iou'``.
    """
    eps = 1e-7
    batch_size = pred.shape[0]

    def _per_image(x):
        # (B, N): one row of pixels per image. reshape (unlike view) also
        # accepts non-contiguous tensors; the explicit 0 keeps empty tensors
        # reshapeable.
        return x.reshape(batch_size, -1 if x.numel() else 0)

    p = _per_image((pred >= threshold).float())
    t = _per_image(target)

    counts = torch.stack((
        (p * t).sum(dim=1),              # TP
        (p * (1 - t)).sum(dim=1),        # FP
        ((1 - p) * t).sum(dim=1),        # FN
        ((1 - p) * (1 - t)).sum(dim=1),  # TN
    ))
    # .cpu() before .double(): the MPS backend has no float64 support.
    tp, fp, fn, tn = counts.detach().cpu().double().numpy()

    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    iou = tp / (tp + fp + fn + eps)

    return {
        'accuracy': np.mean(acc),
        'precision': np.mean(prec),
        'recall': np.mean(rec),
        'f1': np.mean(f1),
        'iou': np.mean(iou),
    }

def S_measure(pred, gt):
    """Return the mean per-image SSIM between predicted maps and ground truth.

    NOTE: despite the name, this is the structural similarity index
    (``skimage.metrics.structural_similarity``). It is not the
    Structure-measure (S_alpha) of Fan et al. 2017 usually reported for COD,
    so treat it only as a proxy.

    Each image is scored independently as a 2-D map and the scores are
    averaged. Passing a squeezed ``(B, H, W)`` batch to skimage at once would
    make it a single 3-D volume: neighbouring images would share the 7-pixel
    window, and batches with fewer than 7 images would raise ValueError.

    Args:
        pred: predicted maps in [0, 1], e.g. ``(B, 1, H, W)``; the last two
            axes are treated as the image plane.
        gt: ground-truth maps with the same shape.

    Returns:
        Mean SSIM over the images, as a Python float.
    """
    pred_np = pred.detach().cpu().numpy()
    gt_np = gt.detach().cpu().numpy()
    height, width = pred_np.shape[-2:]
    pred_maps = pred_np.reshape(-1, height, width)
    gt_maps = gt_np.reshape(-1, height, width)
    scores = [ssim(p, g, data_range=1.0) for p, g in zip(pred_maps, gt_maps)]
    return float(np.mean(scores))

def E_measure(pred, gt):
    """Return ``1 - MAE`` between the min-max normalised prediction and *gt*.

    NOTE: despite the name, this is not the enhanced-alignment measure (E_phi)
    of Fan et al. 2018. The squeezed prediction is rescaled to [0, 1] with a
    single global min/max (over the whole input, batch included). It is then
    compared pixel-wise with *gt* by mean absolute difference. Higher is
    better; 1 means the normalised prediction equals *gt*.
    """
    p = pred.squeeze().cpu().numpy()
    g = gt.squeeze().cpu().numpy()
    lo, hi = p.min(), p.max()
    p = (p - lo) / (hi - lo + 1e-8)
    return 1 - np.abs(p - g).mean()

def weighted_F_measure(pred, gt, beta=1):
    """Return the F-beta score of the binarised prediction against *gt*.

    NOTE: despite the name, no pixel weighting is applied (this is not the
    weighted F-measure of Margolin et al. 2014). Both maps are thresholded at
    0.5, and TP/FP/FN are pooled over every pixel of the squeezed input (the
    whole batch together). Then
    ``F = (1 + beta^2) * P * R / (beta^2 * P + R)``, with 1e-8 epsilons in
    the denominators.
    """
    p = pred.squeeze().cpu().numpy() > 0.5
    g = gt.squeeze().cpu().numpy() > 0.5

    tp = np.sum(p & g)
    fp = np.sum(p & ~g)
    fn = np.sum(~p & g)

    b2 = beta ** 2
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return (1 + b2) * (precision * recall) / (b2 * precision + recall + 1e-8)

def mean_absolute_error(pred, gt):
    """Return the mean of ``|pred - gt|`` over all elements, as a Python float."""
    abs_diff = (pred - gt).abs()
    return abs_diff.mean().item()

### Loss Functions (Dice, Boundary, Combined)

In [10]:
def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss computed over all elements of the batch at once.

    ``1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``

    The whole batch is flattened into one vector, so this is a batch-level
    Dice, not a per-image average. ``reshape`` is used instead of ``view`` so
    that non-contiguous inputs (e.g. transposed or sliced tensors) are
    accepted.

    Args:
        pred: predicted probabilities.
        target: ground truth with the same number of elements.
        smooth: additive smoothing; also keeps the loss defined when both
            tensors are all zero.

    Returns:
        0-dim tensor.
    """
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return 1 - ((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))

def boundary_loss(pred, target):
    """L1 distance between the Sobel edge maps of *pred* and *target*.

    Both horizontal and vertical gradients are used (Sobel-x and its
    transpose, Sobel-y), so edges of every orientation are penalised. The
    loss is the mean absolute difference over both gradient channels.

    The kernels are built with *pred*'s dtype and device, so the loss also
    works for float64 and float16 inputs; *target* is cast to that dtype.

    Args:
        pred: single-channel maps of shape ``(B, 1, H, W)``. ``F.conv2d``
            with a one-input-channel kernel requires exactly one channel.
        target: ground-truth maps with the same shape.

    Returns:
        0-dim tensor.
    """
    sobel_x = torch.tensor(
        [[-1, 0, 1],
         [-2, 0, 2],
         [-1, 0, 1]],
        dtype=pred.dtype, device=pred.device,
    )
    # (2, 1, 3, 3): output channel 0 = d/dx, channel 1 = d/dy.
    kernels = torch.stack((sobel_x, sobel_x.t())).unsqueeze(1)

    edge_pred = torch.abs(F.conv2d(pred, kernels, padding=1))
    edge_target = torch.abs(F.conv2d(target.to(pred.dtype), kernels, padding=1))

    return F.l1_loss(edge_pred, edge_target)


def combined_loss(pred, target):
    """Return Dice + binary cross-entropy + Sobel boundary loss for *pred*.

    *pred* must already be a probability map (``F.binary_cross_entropy``
    expects values in [0, 1]).
    """
    dice_term = dice_loss(pred, target)
    bce_term = F.binary_cross_entropy(pred, target)
    edge_term = boundary_loss(pred, target)
    return dice_term + bce_term + edge_term


### Train and Validation Methods

In [11]:
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over *dataloader*; return the mean batch loss.

    The loss uses deep supervision:
    ``(Dice + BCE)(final map, masks) + 0.5 * (Dice + BCE)(coarse map, masks)``.
    For the coarse term, the masks are nearest-neighbour resized to the
    coarse map's resolution.
    """
    model.train()
    running = 0.0

    for batch_images, batch_masks in dataloader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        pred_full, pred_coarse = model(batch_images)

        # Ground truth on the coarse map's grid (computed once, used twice).
        masks_small = F.interpolate(batch_masks, size=pred_coarse.shape[2:], mode='nearest')

        full_term = dice_loss(pred_full, batch_masks) + F.binary_cross_entropy(pred_full, batch_masks)
        coarse_term = dice_loss(pred_coarse, masks_small) + F.binary_cross_entropy(pred_coarse, masks_small)
        loss = full_term + 0.5 * coarse_term

        loss.backward()
        optimizer.step()

        running += loss.item()

    return running / len(dataloader)

def validate_one_epoch(model, dataloader, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    The loss is the same deep-supervision loss as in training. Pixel metrics
    come from ``compute_batch_metrics`` at threshold 0.5. Both are averaged
    weighted by batch size, so a smaller final batch does not count as much
    as a full one. (A plain mean of per-batch means would over-weight it.)

    Returns:
        ``(avg_loss, avg_metrics)`` where *avg_metrics* has the keys
        ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'`` and ``'iou'``.

    Raises:
        ZeroDivisionError: if *dataloader* yields no samples.
    """
    model.eval()
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    loss_sum = 0.0
    metric_sums = dict.fromkeys(metric_names, 0.0)
    n_samples = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            batch_n = images.shape[0]

            out_final, out_coarse = model(images)

            # Ground truth on the coarse map's grid (computed once, used twice).
            coarse_masks = F.interpolate(masks, size=out_coarse.shape[2:], mode='nearest')

            loss_final = dice_loss(out_final, masks) + F.binary_cross_entropy(out_final, masks)
            loss_coarse = dice_loss(out_coarse, coarse_masks) + F.binary_cross_entropy(out_coarse, coarse_masks)
            loss = loss_final + 0.5 * loss_coarse
            loss_sum += loss.item() * batch_n

            batch_metrics = compute_batch_metrics(out_final, masks, threshold=0.5)
            for name in metric_names:
                metric_sums[name] += batch_metrics[name] * batch_n
            n_samples += batch_n

    avg_loss = loss_sum / n_samples
    avg_metrics = {name: metric_sums[name] / n_samples for name in metric_names}
    return avg_loss, avg_metrics

### Test Method

In [12]:
def compute_iou(mask_np, pred_np):
    """Binary IoU (Jaccard) of flattened *pred_np* against *mask_np*.

    Returns None when *mask_np* has no foreground pixel, so callers can skip
    such inputs.
    """
    y_true = mask_np.reshape(-1)
    if y_true.sum() == 0:
        return None
    y_pred = pred_np.reshape(-1)
    return jaccard_score(y_true, y_pred, zero_division=1)

def evaluate_segmentation(mask_np, pred_np):
    """Binary IoU, Dice, precision and recall of *pred_np* against *mask_np*.

    Both arrays are flattened and compared pixel-wise with scikit-learn
    (positive label 1).

    ``zero_division`` is set explicitly for every score:

    - IoU and Dice both return 1 when neither array has a positive pixel
      (a perfect empty prediction), so the two agree with each other
      (Dice = 2·IoU / (1 + IoU)).
    - Precision and recall keep scikit-learn's default value of 0 for an
      undefined ratio, without emitting ``UndefinedMetricWarning``.

    Returns:
        dict with keys ``"IoU"``, ``"Dice"``, ``"Precision"``, ``"Recall"``.
    """
    mask_np = mask_np.flatten()
    pred_np = pred_np.flatten()

    iou = jaccard_score(mask_np, pred_np, average="binary", zero_division=1)
    dice = f1_score(mask_np, pred_np, average="binary", zero_division=1)
    precision = precision_score(mask_np, pred_np, average="binary", zero_division=0)
    recall = recall_score(mask_np, pred_np, average="binary", zero_division=0)

    return {
        "IoU": iou,
        "Dice": dice,
        "Precision": precision,
        "Recall": recall
    }

def test_model(model, dataloader, device, threshold=0.5, num_samples=20):
    """Evaluate *model* on *dataloader* and plot the first batches.

    For every batch, three groups of metrics are collected:

    1. ``compute_batch_metrics`` (per image, then batch mean) at *threshold*;
    2. scikit-learn IoU / Dice / precision / recall on the whole flattened
       batch, with predictions binarised by the same ``>= threshold`` rule;
    3. ``S_measure``, ``E_measure``, ``weighted_F_measure`` (fixed 0.5
       threshold internally) and ``mean_absolute_error``.

    Each group is averaged over batches. The function also prints the mean
    batch IoU restricted to batches containing at least one foreground pixel
    ("filtered"), or a notice if there is no such batch.

    Args:
        model: network returning ``(out_final, out_coarse)``.
        dataloader: yields ``(images, masks)`` batches.
        device: device to run inference on.
        threshold: binarisation threshold for the predictions.
        num_samples: number of leading *batches* (not images) to plot.

    Returns:
        ``(avg_metrics, avg_metrics2, avg_metrics3)``: three dicts holding the
        averages of groups 1, 2 and 3 above.
    """
    model.eval()
    all_acc, all_prec, all_rec, all_f1, all_iou = [], [], [], [], []
    all_prec2, all_rec2, all_dice, all_iou2 = [], [], [], []  # forse da eliminare se non utile alla valutazione
    all_s_measure, all_e_measure, all_weighted_f_measure, all_mean_absolute_error = [], [], [], []  # forse da eliminare se non utile alla valutazione

    iou_scores = []  # batch IoU, only for batches that contain foreground

    with torch.no_grad():
        print(len(dataloader))
        for i, (images, masks) in enumerate(dataloader):
            print(f"Processing batch {i+1}/{len(dataloader)}...")  # DEBUG
            images = images.to(device)
            masks = masks.to(device)

            out_final, _ = model(images)  # the coarse map is not evaluated here
            print("Model inference done.")  # DEBUG

            batch_metrics = compute_batch_metrics(out_final, masks, threshold=threshold)
            print("Metrics computed.")  # DEBUG

            all_s_measure.append(S_measure(out_final, masks))
            all_e_measure.append(E_measure(out_final, masks))
            all_weighted_f_measure.append(weighted_F_measure(out_final, masks))
            all_mean_absolute_error.append(mean_absolute_error(out_final, masks))

            # Binarise predictions with the caller's threshold and the same
            # ">=" rule as compute_batch_metrics, so all groups agree.
            mask_np = (masks.squeeze().cpu().numpy() > 0.5).astype(int)
            pred_np = (out_final.squeeze().cpu().numpy() >= threshold).astype(int)

            iou = compute_iou(mask_np, pred_np)
            if iou is not None:
                iou_scores.append(iou)

            second_batch_metrics = evaluate_segmentation(mask_np, pred_np)   # forse da eliminare se non utile alla valutazione
            all_prec2.append(second_batch_metrics['Precision'])
            all_rec2.append(second_batch_metrics['Recall'])
            all_dice.append(second_batch_metrics['Dice'])
            all_iou2.append(second_batch_metrics['IoU'])

            all_acc.append(batch_metrics['accuracy'])
            all_prec.append(batch_metrics['precision'])
            all_rec.append(batch_metrics['recall'])
            all_f1.append(batch_metrics['f1'])
            all_iou.append(batch_metrics['iou'])

            # Plot the first num_samples batches.
            if i < num_samples:
                plot_results(images.cpu(), masks.cpu(), out_final.cpu(), i)

    avg_metrics = {
        'accuracy': np.mean(all_acc),
        'precision': np.mean(all_prec),
        'recall': np.mean(all_rec),
        'f1': np.mean(all_f1),
        'iou': np.mean(all_iou)
    }

    avg_metrics2 = {                             # forse da eliminare se non utile alla valutazione
        'Precision': np.mean(all_prec2),
        'Recall': np.mean(all_rec2),
        'Dice': np.mean(all_dice),
        'IoU': np.mean(all_iou2)
    }

    avg_metrics3 = {                             # forse da eliminare se non utile alla valutazione
        'S-Measure': np.mean(all_s_measure),
        'E-Measure': np.mean(all_e_measure),
        'WFM': np.mean(all_weighted_f_measure),
        'MAE': np.mean(all_mean_absolute_error)
    }

    # Guard against division by zero when no batch contains foreground pixels.
    if iou_scores:
        print(f"Mean IoU (filtered): {sum(iou_scores) / len(iou_scores):.2f}")
    else:
        print("Mean IoU (filtered): n/a (no batch contains foreground pixels)")
    return avg_metrics, avg_metrics2, avg_metrics3

def plot_results(images, masks, predictions, batch_idx):
    """Show a row per image: input, ground-truth mask, thresholded prediction.

    Predictions are binarised at 0.5 before display. With a single-image
    batch the first panel is titled "Immagine originale"; otherwise it shows
    the image's running index ``batch_idx * batch_size + row``.
    """
    n_rows = images.shape[0]
    # squeeze=False keeps the axes grid 2-D even for a single row.
    fig, grid = plt.subplots(n_rows, 3, figsize=(10, n_rows * 3), squeeze=False)

    for row in range(n_rows):
        picture = images[row].permute(1, 2, 0).numpy()  # CHW tensor -> HWC array
        truth = masks[row].squeeze().numpy()
        guess = (predictions[row].squeeze().numpy() > 0.5).astype(np.uint8)

        if n_rows == 1:
            first_title = "Immagine originale"
        else:
            first_title = f"Immagine {batch_idx * n_rows + row}"

        panels = (
            (picture, None, first_title),
            (truth, "gray", "Maschera reale"),
            (guess, "gray", "Previsione modello"),
        )
        for ax, (data, cmap, title) in zip(grid[row], panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)

        for ax in grid[row]:
            ax.axis("off")

    plt.tight_layout()
    plt.show()


### Train/Validation Split Function

In [13]:
def split_train_dataset(train_dir, val_ratio=0.2, seed=42):
    """Move a random *val_ratio* fraction of the training pairs to ``../Validation``.

    WARNING: files are MOVED, not copied. Calling this twice moves another
    *val_ratio* of whatever is left in the training folders.

    Images (``Image/*.jpg``) and masks (``GT_Object/*.png``) are paired by
    sorted filename. The pairing is validated before anything is moved. The
    shuffle uses a private ``random.Random(seed)``: the permutation is the
    same as with ``random.seed(seed)`` + ``random.shuffle``, but the global
    ``random`` state is left untouched.

    Args:
        train_dir: folder containing ``Image`` and ``GT_Object``.
        val_ratio: fraction of pairs moved to validation.
        seed: seed of the shuffle.

    Returns:
        ``(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir)``.

    Raises:
        ValueError: if image and mask counts differ or a pair's filename
            stems do not match.
    """
    train_img_dir = os.path.join(train_dir, 'Image')
    train_mask_dir = os.path.join(train_dir, 'GT_Object')
    val_img_dir = os.path.join(train_dir, '../Validation/Image')
    val_mask_dir = os.path.join(train_dir, '../Validation/GT_Object')

    image_files = sorted(glob.glob(os.path.join(train_img_dir, '*.jpg')))
    mask_files = sorted(glob.glob(os.path.join(train_mask_dir, '*.png')))

    # Validate before touching the filesystem; assert would vanish under -O.
    if len(image_files) != len(mask_files):
        raise ValueError("Il numero di immagini e maschere non coincide!")
    for img_path, mask_path in zip(image_files, mask_files):
        img_stem = os.path.splitext(os.path.basename(img_path))[0]
        mask_stem = os.path.splitext(os.path.basename(mask_path))[0]
        if img_stem != mask_stem:
            raise ValueError(f"Immagine e maschera non corrispondono: {img_path} vs {mask_path}")

    os.makedirs(val_img_dir, exist_ok=True)
    os.makedirs(val_mask_dir, exist_ok=True)

    indices = list(range(len(image_files)))
    random.Random(seed).shuffle(indices)

    split_idx = int(len(indices) * (1 - val_ratio))
    train_indices, val_indices = indices[:split_idx], indices[split_idx:]

    for idx in val_indices:
        shutil.move(image_files[idx], os.path.join(val_img_dir, os.path.basename(image_files[idx])))
        shutil.move(mask_files[idx], os.path.join(val_mask_dir, os.path.basename(mask_files[idx])))

    print(f"Train: {len(train_indices)}, Validation: {len(val_indices)}")
    return train_img_dir, train_mask_dir, val_img_dir, val_mask_dir

### Main

In [None]:
# Esempio di utilizzo
train_dataset_dir = 'COD10K-v3/Train'  # Sostituire con il percorso reale


def main():
    """Train SINet on COD10K, validate after each epoch, then evaluate on the test set.

    The validation split is carved out of the test set (20%, fixed seed 42)
    by MOVING files. This happens only while the validation image folder
    holds no ``*.jpg``; later runs reuse the existing split instead of
    moving another 20% of the remaining test images each time.
    """
    batch_size = 16
    num_epochs = 30
    lr = 1e-4

    # Alternative: carve validation out of the training set instead with
    # split_train_dataset(train_dataset_dir).
    train_img_dir = 'COD10K-v3/Train/Image'
    train_mask_dir = 'COD10K-v3/Train/GT_Object'
    val_img_dir = 'COD10K-v3/Validation/Image'
    val_mask_dir = 'COD10K-v3/Validation/GT_Object'

    test_img_dir = 'COD10K-v3/Test/Image'
    test_mask_dir = 'COD10K-v3/Test/GT_Object'

    os.makedirs(val_img_dir, exist_ok=True)
    os.makedirs(val_mask_dir, exist_ok=True)

    # Split only once: files are moved, so re-running must not shrink the
    # test set again.
    if not glob.glob(os.path.join(val_img_dir, '*.jpg')):
        image_files = sorted(glob.glob(os.path.join(test_img_dir, '*.jpg')))
        mask_files = sorted(glob.glob(os.path.join(test_mask_dir, '*.png')))

        # Check pairing before moving anything.
        if len(image_files) != len(mask_files):
            raise ValueError("Numero di immagini e maschere non coincide!")

        # Private RNG seeded like random.seed(42): same split, global state untouched.
        indices = list(range(len(image_files)))
        random.Random(42).shuffle(indices)

        # 80% stays in test, 20% goes to validation.
        split_idx = int(len(indices) * 0.8)
        for idx in indices[split_idx:]:
            shutil.move(image_files[idx], os.path.join(val_img_dir, os.path.basename(image_files[idx])))
            shutil.move(mask_files[idx], os.path.join(val_mask_dir, os.path.basename(mask_files[idx])))

    print(f"Train images: {len(os.listdir(train_img_dir))}")
    print(f"Train masks: {len(os.listdir(train_mask_dir))}")

    train_dataset = CODDataset(train_img_dir, train_mask_dir, transform=None)
    sample = train_dataset[0]
    print(f"Sample keys: {sample.keys()}" if isinstance(sample, dict) else "Sample loaded")
    val_dataset = CODDataset(val_img_dir, val_mask_dir, transform=None)
    test_dataset = CODDataset(test_img_dir, test_mask_dir, transform=None)

    # Pinned host memory only helps host->CUDA copies; on MPS PyTorch just
    # warns that it is unsupported.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)

    model = SINet(backbone_pretrained=True).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("Starting training...")

    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_metrics = validate_one_epoch(model, val_loader, device)
        epoch_duration = timedelta(seconds=time.time() - start_time)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val F1: {val_metrics['f1']:.4f}, Val IoU: {val_metrics['iou']:.4f}, "
              f"Epoch Duration: {epoch_duration}")

    torch.save(model.state_dict(), "sinet_camouflage.pth")
    print("Training completato e modello salvato.")

    # test_model returns a 3-tuple of dicts: unpack before indexing by name.
    test_metrics, test_metrics2, test_metrics3 = test_model(model, test_loader, device, threshold=0.5)
    print("RISULTATI TEST FINALI:")
    print(f"  Accuracy = {test_metrics['accuracy']:.3f}")
    print(f"  Precision = {test_metrics['precision']:.3f}")
    print(f"  Recall = {test_metrics['recall']:.3f}")
    print(f"  F1-score = {test_metrics['f1']:.3f}")
    print(f"  IoU = {test_metrics['iou']:.3f}")

    print(f"  Precision = {test_metrics2['Precision']:.3f}")
    print(f"  Recall = {test_metrics2['Recall']:.3f}")
    print(f"  Dice = {test_metrics2['Dice']:.3f}")
    print(f"  IoU = {test_metrics2['IoU']:.3f}")

    print(f"  S-Measure = {test_metrics3['S-Measure']:.3f}")
    print(f"  E-Measure = {test_metrics3['E-Measure']:.3f}")
    print(f"  WFM = {test_metrics3['WFM']:.3f}")
    print(f"  MAE = {test_metrics3['MAE']:.3f}")

if __name__ == '__main__':
    main()

In [None]:
# # model = SINet(backbone_pretrained=True).to(device)
# # model.load_state_dict(torch.load('sinet_pretrained_120epoch.pth'))

# test_img_dir = 'COD10K-v3/Test/Image'
# test_mask_dir = 'COD10K-v3/Test/GT_Object'

# test_dataset = CODDataset(test_img_dir, test_mask_dir, transform=None)

# test_loader = DataLoader(test_dataset, batch_size=40, shuffle=False, num_workers=0, pin_memory=True)

# test_metrics, test_metrics2, test_metrics3 = test_model(model, test_loader, device, threshold=0.5)
# print("RISULTATI TEST FINALI:")
# print(f"  Accuracy = {test_metrics['accuracy']:.3f}")
# print(f"  Precision = {test_metrics['precision']:.3f}")
# print(f"  Recall = {test_metrics['recall']:.3f}")
# print(f"  F1-score = {test_metrics['f1']:.3f}")
# print(f"  IoU = {test_metrics['iou']:.3f}")

# print(f"  Precision = {test_metrics2['Precision']:.3f}")       # forse da eliminare se non utile alla valutazione
# print(f"  Recall = {test_metrics2['Recall']:.3f}")
# print(f"  Dice = {test_metrics2['Dice']:.3f}")
# print(f"  IoU = {test_metrics2['IoU']:.3f}")

# print(f"  S-Measure = {test_metrics3['S-Measure']:.3f}")       # forse da eliminare se non utile alla valutazione
# print(f"  E-Measure = {test_metrics3['E-Measure']:.3f}")
# print(f"  WFM = {test_metrics3['WFM']:.3f}")
# print(f"  MAE = {test_metrics3['MAE']:.3f}")

