# Dice Loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet_DL(nn.Module):
    """U-Net segmentation network defined in the Dice-loss section.

    Layout: four encoder stages (64, 128, 256, 512 channels), a
    1024-channel bridge, and four decoder stages.  Each decoder stage
    doubles the resolution with a stride-2 transposed convolution,
    concatenates the encoder output of the same resolution and refines
    the result with a double 3x3 convolution.  A 1x1 convolution produces
    the final per-pixel logits.

    Because of the four 2x poolings, the input height and width must be
    divisible by 16 for the skip concatenations to line up.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of logit maps produced (one per class).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        stage_widths = (64, 128, 256, 512)

        # Encoder stages enc1..enc4: each one consumes the previous width.
        feeding = in_channels
        for stage, width in enumerate(stage_widths, start=1):
            setattr(self, f"enc{stage}", self.conv_block(feeding, width))
            feeding = width

        # Bridge at the lowest resolution.
        self.bridge = self.conv_block(512, 1024)

        # Decoder stages, deepest first.  The conv block takes 2 * width
        # channels because the upsampled map is concatenated with the skip.
        for stage in (4, 3, 2, 1):
            width = stage_widths[stage - 1]
            setattr(self, f"up{stage}",
                    nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2))
            setattr(self, f"dec{stage}", self.conv_block(2 * width, width))

        # 1x1 projection to class logits.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Build two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch (B, in_channels, H, W) to logits (B, out_channels, H, W)."""
        # Contracting path; every stage output is kept as a skip connection.
        skip1 = self.enc1(x)                          # (B,64,H,W)
        skip2 = self.enc2(F.max_pool2d(skip1, 2))     # (B,128,H/2,W/2)
        skip3 = self.enc3(F.max_pool2d(skip2, 2))     # (B,256,H/4,W/4)
        skip4 = self.enc4(F.max_pool2d(skip3, 2))     # (B,512,H/8,W/8)

        bottom = self.bridge(F.max_pool2d(skip4, 2))  # (B,1024,H/16,W/16)

        # Expanding path: upsample, concatenate the skip first, then refine.
        y = self.dec4(torch.cat([skip4, self.up4(bottom)], dim=1))  # (B,512,H/8,W/8)
        y = self.dec3(torch.cat([skip3, self.up3(y)], dim=1))       # (B,256,H/4,W/4)
        y = self.dec2(torch.cat([skip2, self.up2(y)], dim=1))       # (B,128,H/2,W/2)
        y = self.dec1(torch.cat([skip1, self.up1(y)], dim=1))       # (B,64,H,W)

        return self.out(y)                                          # (B,out_channels,H,W)

def dice_loss(pred, target, smooth=1.0):
    """
    Soft Dice loss for multi-class segmentation.

    The Dice score is evaluated separately for every (sample, class) pair
    on softmax probabilities, then averaged; the loss is one minus that
    average.

    Parameters:
      - pred: logits tensor of shape (B, C, H, W)
      - target: integer tensor of shape (B, H, W) holding class indices
        (F.one_hot requires an int64 tensor)
      - smooth: constant added to numerator and denominator so that an
        empty class does not divide by zero

    Returns:
      - scalar tensor: 1 - mean Dice score
    """
    n_classes = pred.shape[1]

    # Probabilities and one-hot labels, both laid out as (B, C, H, W).
    probs = F.softmax(pred, dim=1)
    onehot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

    # Collapse the spatial axes: (B, C, H*W).
    probs = probs.contiguous().view(probs.shape[0], probs.shape[1], -1)
    onehot = onehot.contiguous().view(onehot.shape[0], onehot.shape[1], -1)

    overlap = (probs * onehot).sum(dim=2)
    mass = probs.sum(dim=2) + onehot.sum(dim=2)

    per_class_dice = (2 * overlap + smooth) / (mass + smooth)
    return 1 - per_class_dice.mean()


# Jaccard Loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet_JL(nn.Module):
    """U-Net segmentation network defined in the Jaccard-loss section.

    Encoder widths are 64/128/256/512 with a 1024-channel bridge.  The
    decoder mirrors the encoder: a stride-2 transposed convolution
    doubles the resolution, the matching encoder output is concatenated
    in front of it and a double 3x3 convolution halves the channel count
    again.  Input height and width must be divisible by 16.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of logit maps produced (one per class).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024

        # Contracting path.
        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)

        # Lowest-resolution stage.
        self.bridge = self.conv_block(c4, c5)

        # Expanding path: each conv block sees skip + upsampled channels.
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(c4 + c4, c4)

        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(c3 + c3, c3)

        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(c2 + c2, c2)

        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(c1 + c1, c1)

        # Per-pixel class logits.
        self.out = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return Conv3x3 -> ReLU -> Conv3x3 -> ReLU (padding 1, in-place ReLU)."""
        first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        return nn.Sequential(first, nn.ReLU(inplace=True), second, nn.ReLU(inplace=True))

    def forward(self, x):
        """Compute logits of shape (B, out_channels, H, W) for input (B, in_channels, H, W)."""
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4)
        decoders = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )

        # Going down: remember each stage output, then halve the resolution.
        skips = []
        feat = x
        for encode in encoders:
            feat = encode(feat)
            skips.append(feat)
            feat = F.max_pool2d(feat, 2)

        feat = self.bridge(feat)  # (B,1024,H/16,W/16)

        # Going up: the most recent skip matches the current resolution.
        for upsample, refine in decoders:
            skip = skips.pop()
            feat = refine(torch.cat([skip, upsample(feat)], dim=1))

        return self.out(feat)

def jaccard_loss(pred, target, smooth=1e-6):
    """
    Soft Jaccard (IoU) loss for multi-class segmentation.

    The Jaccard index is computed on softmax probabilities for every
    (sample, class) pair and averaged; the loss is one minus that mean.

    Parameters:
      - pred: logits tensor of shape (B, C, H, W)
      - target: integer tensor of shape (B, H, W) holding class indices
      - smooth: small constant added to numerator and denominator to
        avoid division by zero

    Returns:
      - scalar tensor: 1 - mean Jaccard index
    """
    # Class probabilities, (B, C, H, W).
    probs = F.softmax(pred, dim=1)

    # Labels as one-hot planes in the same (B, C, H, W) layout.
    onehot = F.one_hot(target, num_classes=pred.shape[1])
    onehot = onehot.permute(0, 3, 1, 2).float()

    # Merge the spatial axes: (B, C, H*W).
    probs = probs.view(probs.shape[0], probs.shape[1], -1)
    onehot = onehot.view(onehot.shape[0], onehot.shape[1], -1)

    overlap = (probs * onehot).sum(dim=2)
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = (probs + onehot).sum(dim=2) - overlap

    per_class_iou = (overlap + smooth) / (union + smooth)
    return 1 - per_class_iou.mean()


# Focal Loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet_FL(nn.Module):
    """U-Net segmentation network defined in the Focal-loss section.

    Four encoder stages (64 to 512 channels) are separated by 2x2 max
    pooling, followed by a 1024-channel bridge.  Four decoder stages each
    upsample by 2 with a transposed convolution, concatenate the encoder
    output of equal resolution and apply a double 3x3 convolution.  The
    last layer is a 1x1 convolution yielding one logit map per class.
    Input height and width must be divisible by 16.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of logit maps produced (one per class).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def upsampler(channels_in, channels_out):
            # Stride-2, kernel-2 transposed conv: exactly doubles H and W.
            return nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)

        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Bridge
        self.bridge = self.conv_block(512, 1024)

        # Decoder (upsampler first, then the block that fuses the skip)
        self.up4 = upsampler(1024, 512)
        self.dec4 = self.conv_block(1024, 512)

        self.up3 = upsampler(512, 256)
        self.dec3 = self.conv_block(512, 256)

        self.up2 = upsampler(256, 128)
        self.dec2 = self.conv_block(256, 128)

        self.up1 = upsampler(128, 64)
        self.dec1 = self.conv_block(128, 64)

        # Final 1x1 projection to logits
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Create a double (3x3 conv, padding 1 -> in-place ReLU) block."""
        block = nn.Sequential()
        block.add_module("0", nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        block.add_module("1", nn.ReLU(inplace=True))
        block.add_module("2", nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        block.add_module("3", nn.ReLU(inplace=True))
        return block

    def forward(self, x):
        """Return logits (B, out_channels, H, W) for an input (B, in_channels, H, W)."""
        # Encoder: features e*, pooled copies p*.
        e1 = self.enc1(x)            # (B,64,H,W)
        p1 = F.max_pool2d(e1, 2)
        e2 = self.enc2(p1)           # (B,128,H/2,W/2)
        p2 = F.max_pool2d(e2, 2)
        e3 = self.enc3(p2)           # (B,256,H/4,W/4)
        p3 = F.max_pool2d(e3, 2)
        e4 = self.enc4(p3)           # (B,512,H/8,W/8)
        p4 = F.max_pool2d(e4, 2)

        b = self.bridge(p4)          # (B,1024,H/16,W/16)

        # Decoder: upsample (u*), prepend the skip, refine (d*).
        u4 = self.up4(b)
        d4 = self.dec4(torch.cat([e4, u4], dim=1))   # (B,512,H/8,W/8)

        u3 = self.up3(d4)
        d3 = self.dec3(torch.cat([e3, u3], dim=1))   # (B,256,H/4,W/4)

        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([e2, u2], dim=1))   # (B,128,H/2,W/2)

        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([e1, u1], dim=1))   # (B,64,H,W)

        return self.out(d1)

# ---------------- Focal Loss Implementation ----------------

class FocalLoss(nn.Module):
    """
    Focal loss for multi-class classification / segmentation.

    Per element the loss is ``alpha * (1 - p_t) ** gamma * CE`` where CE
    is the unreduced cross entropy and ``p_t = exp(-CE)`` is the
    probability assigned to the true class.

    Args:
        alpha (float): Scalar multiplier applied uniformly to every
            element's loss (this implementation does not weight classes
            individually). Default: 0.25.
        gamma (float): Exponent of the modulating factor (1 - p_t). Default: 2.0.
        reduction (str): 'mean' averages and 'sum' adds all elements; any
            other value (e.g. 'none') returns the unreduced loss. Default: 'mean'.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs (Tensor): Raw logits of shape (B, C, H, W).
            targets (Tensor): Class indices of shape (B, H, W), each in [0, C-1].

        Returns:
            Tensor: Scalar for 'mean'/'sum', otherwise a tensor shaped like ``targets``.
        """
        # Unreduced cross entropy, one value per target element.
        per_elem_ce = F.cross_entropy(inputs, targets, reduction='none')

        # CE = -log(p_t)  =>  p_t = exp(-CE)
        true_class_prob = torch.exp(-per_elem_ce)

        # Down-weight well-classified elements.
        weighted = self.alpha * (1 - true_class_prob) ** self.gamma * per_elem_ce

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted



# Tversky Loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet_TL(nn.Module):
    """U-Net segmentation network defined in the Tversky-loss section.

    The encoder has four double-convolution stages (64, 128, 256, 512
    channels) separated by 2x2 max pooling and ends in a 1024-channel
    bridge.  The decoder has four stages; each upsamples by 2, prepends
    the encoder output of the same resolution along the channel axis and
    refines with a double convolution.  A 1x1 convolution emits the
    logits.  Input height and width must be divisible by 16.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of logit maps produced (one per class).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder
        self.enc1 = self.conv_block(in_channels=in_channels, out_channels=64)
        self.enc2 = self.conv_block(in_channels=64, out_channels=128)
        self.enc3 = self.conv_block(in_channels=128, out_channels=256)
        self.enc4 = self.conv_block(in_channels=256, out_channels=512)

        # Bridge
        self.bridge = self.conv_block(in_channels=512, out_channels=1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(in_channels=1024, out_channels=512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(in_channels=512, out_channels=256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(in_channels=256, out_channels=128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(in_channels=128, out_channels=64)

        # Output head
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return a Sequential of two (3x3 conv with padding 1, in-place ReLU) pairs."""
        channel_pairs = ((in_channels, out_channels), (out_channels, out_channels))
        layers = []
        for c_in, c_out in channel_pairs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _fuse(upsample, refine, deeper, skip):
        """Upsample ``deeper``, put ``skip`` in front along channels, then refine."""
        raised = upsample(deeper)
        return refine(torch.cat([skip, raised], dim=1))

    def forward(self, x):
        """Produce logits (B, out_channels, H, W) from an input (B, in_channels, H, W)."""
        # ---------------- Encoder ----------------
        enc1 = self.enc1(x)                                    # (B,64,H,W)
        enc2 = self.enc2(F.max_pool2d(enc1, kernel_size=2))    # (B,128,H/2,W/2)
        enc3 = self.enc3(F.max_pool2d(enc2, kernel_size=2))    # (B,256,H/4,W/4)
        enc4 = self.enc4(F.max_pool2d(enc3, kernel_size=2))    # (B,512,H/8,W/8)

        # ---------------- Bridge ----------------
        deepest = self.bridge(F.max_pool2d(enc4, kernel_size=2))  # (B,1024,H/16,W/16)

        # ---------------- Decoder ----------------
        dec4 = self._fuse(self.up4, self.dec4, deepest, enc4)  # (B,512,H/8,W/8)
        dec3 = self._fuse(self.up3, self.dec3, dec4, enc3)     # (B,256,H/4,W/4)
        dec2 = self._fuse(self.up2, self.dec2, dec3, enc2)     # (B,128,H/2,W/2)
        dec1 = self._fuse(self.up1, self.dec1, dec2, enc1)     # (B,64,H,W)

        return self.out(dec1)                                  # (B,out_channels,H,W)

class TverskyLoss(nn.Module):
    """
    Tversky loss for multi-class segmentation.

    For every class the (soft) Tversky index is

        TI = TP / (TP + α * FN + β * FP)

    with TP, FN and FP accumulated from softmax probabilities over the
    whole batch and all pixels.  The loss is the class mean of 1 - TI.

    Args:
        alpha: Weight α of the false negatives. Default: 0.5.
        beta: Weight β of the false positives. Default: 0.5.
        smooth: Constant added to numerator and denominator. Default: 1e-6.

    Typically α + β = 1; α = β = 0.5 makes the index equal to the Dice score.
    """

    def __init__(self, alpha=0.5, beta=0.5, smooth=1e-6):
        super().__init__()
        self.alpha, self.beta, self.smooth = alpha, beta, smooth

    def forward(self, inputs, targets):
        """
        Args:
            inputs (Tensor): Raw logits of shape (B, C, H, W).
            targets (Tensor): Class indices of shape (B, H, W), each in [0, C-1].

        Returns:
            Tensor: Scalar loss, the mean over classes of 1 - TI.
        """
        # One-hot labels rearranged to match the logits layout (B, C, H, W).
        n_classes = inputs.size(1)
        onehot = F.one_hot(targets, n_classes).permute(0, 3, 1, 2).float()

        probs = F.softmax(inputs, dim=1)

        # Reduce over everything except the class axis.
        reduce_axes = (0, 2, 3)
        tp = torch.sum(probs * onehot, reduce_axes)
        fn = torch.sum(onehot * (1 - probs), reduce_axes)
        fp = torch.sum((1 - onehot) * probs, reduce_axes)

        numerator = tp + self.smooth
        denominator = tp + self.alpha * fn + self.beta * fp + self.smooth
        per_class_loss = 1 - numerator / denominator
        return per_class_loss.mean()

# ---------------- Example Usage ----------------


# Lovász-Softmax Loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------------- U-Net Model ---------------------
class UNet_LSL(nn.Module):
    """U-Net segmentation network defined in the Lovász-Softmax section.

    Submodules: ``enc1``..``enc4`` (64, 128, 256, 512 channels),
    ``bridge`` (1024 channels), and for every level 4..1 a transposed
    convolution ``up<k>`` plus a double-convolution ``dec<k>`` that
    consumes the skip connection concatenated with the upsampled map.
    ``out`` is a 1x1 convolution to the class logits.  Input height and
    width must be divisible by 16.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of logit maps produced (one per class).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        widths = [64, 128, 256, 512]

        # Encoder: stage k maps the previous width to widths[k - 1].
        incoming = [in_channels] + widths[:-1]
        for level, (c_in, c_out) in enumerate(zip(incoming, widths), start=1):
            self.add_module(f"enc{level}", self.conv_block(c_in, c_out))

        # Bridge doubles the deepest encoder width.
        self.add_module("bridge", self.conv_block(widths[-1], 2 * widths[-1]))

        # Decoder, from the deepest level upwards.
        for level in range(len(widths), 0, -1):
            c = widths[level - 1]
            self.add_module(f"up{level}",
                            nn.ConvTranspose2d(2 * c, c, kernel_size=2, stride=2))
            self.add_module(f"dec{level}", self.conv_block(2 * c, c))

        # Final output layer
        self.add_module("out", nn.Conv2d(widths[0], out_channels, kernel_size=1))

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU."""
        conv_kwargs = dict(kernel_size=3, padding=1)
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return logits (B, out_channels, H, W) for an input (B, in_channels, H, W)."""
        # ---------------- Encoder ----------------
        # The first stage sees the raw input; later stages see a pooled map.
        skips = {1: self.enc1(x)}                                  # (B, 64, H, W)
        for level in (2, 3, 4):
            pooled = F.max_pool2d(skips[level - 1], 2)
            skips[level] = getattr(self, f"enc{level}")(pooled)

        # ---------------- Bridge ----------------
        feat = self.bridge(F.max_pool2d(skips[4], 2))              # (B, 1024, H/16, W/16)

        # ---------------- Decoder ----------------
        for level in (4, 3, 2, 1):
            raised = getattr(self, f"up{level}")(feat)
            merged = torch.cat([skips[level], raised], dim=1)
            feat = getattr(self, f"dec{level}")(merged)

        # ---------------- Output ----------------
        return self.out(feat)                                      # (B, out_channels, H, W)

# ----------------- Lovász-Softmax Loss -----------------
# Helper function: compute gradient of the Lovász extension w.r.t sorted errors
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovász extension of the Jaccard loss.

    Args:
        gt_sorted (Tensor): 1-D binary ground-truth vector, ordered by
            decreasing prediction error.

    Returns:
        Tensor: Float vector of the same length.  Entry i is the increase
        of the Jaccard loss when the i-th sorted element is added to the
        set of errors, i.e. the first difference of the cumulative loss.
    """
    n = len(gt_sorted)
    total_fg = gt_sorted.sum()

    # Foreground still unmatched / union size after the first i+1 elements.
    remaining = total_fg - gt_sorted.float().cumsum(0)
    covered = total_fg + (1 - gt_sorted).float().cumsum(0)
    jac = 1. - remaining / covered

    if n <= 1:
        return jac

    # Turn cumulative losses into per-element increments (entry 0 is kept).
    jac[1:n] = jac[1:n] - jac[:-1]
    return jac

def flatten_probas(probas, labels, ignore=None):
    """
    Flatten per-pixel predictions and labels, optionally dropping a label.

    Args:
        probas (Tensor): Class probabilities of shape (B, C, H, W).
        labels (Tensor): Ground-truth labels of shape (B, H, W).
        ignore (int, optional): Pixels whose label equals this value are
            removed from both outputs.  ``None`` keeps every pixel.

    Returns:
        Tuple[Tensor, Tensor]: Probabilities of shape [P, C] and labels of
        shape [P], where P is the number of retained pixels.
    """
    n_classes = probas.size(1)
    # Move the class axis last so that every pixel owns one row of C values.
    channels_last = probas.permute(0, 2, 3, 1)

    if ignore is not None:
        keep = labels != ignore
        flat_probas = channels_last[keep].contiguous().view(-1, n_classes)
        return flat_probas, labels[keep]

    flat_probas = channels_last.contiguous().view(-1, n_classes)
    return flat_probas, labels.view(-1)

def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Computes the Lovász-Softmax loss from flattened predictions and labels.

    For each selected class the absolute errors between the binary
    foreground mask and the predicted probability are sorted in
    descending order and dotted with the Lovász gradient of the equally
    sorted mask; the per-class results are averaged.

    Args:
        probas (Tensor): Flattened class probabilities, shape [P, C].
        labels (Tensor): Flattened ground truth labels, shape [P].
        classes (str or list): 'all' to average over every class,
            'present' to average only over classes that occur in
            ``labels``, or an explicit list of class indices.

    Returns:
        Tensor: Scalar Lovász-Softmax loss.  When there are no pixels
        (P == 0) the result is a scalar zero that stays attached to
        ``probas`` in the autograd graph; when no class contributes, a
        constant scalar zero on ``probas``' device is returned.
    """
    if probas.numel() == 0:
        # Only void pixels: the loss is zero.  Return a *scalar* (the
        # previous `probas * 0.` had shape (0, C), which cannot be stacked
        # with scalar losses in per-image mode nor passed to .item() or
        # .backward()).  Summing keeps the value connected to `probas`.
        return probas.sum() * 0.

    C = probas.size(1)
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes

    losses = []
    for c in class_to_sum:
        fg = (labels == c).float()  # binary foreground mask for class c
        if classes == 'present' and fg.sum() == 0:
            # Class absent from the labels: skip it in 'present' mode.
            continue
        errors = torch.abs(fg - probas[:, c])
        errors_sorted, perm = torch.sort(errors, descending=True)
        fg_sorted = fg[perm]
        grad = lovasz_grad(fg_sorted)
        losses.append(torch.dot(errors_sorted, grad))

    if not losses:
        # No class contributed (none present / empty class list).
        return torch.tensor(0.).to(probas.device)
    return sum(losses) / len(losses)

def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovász-Softmax loss.

    Args:
        probas (Tensor): Class probabilities at each pixel, shape [B, C, H, W].
        labels (Tensor): Ground truth labels, shape [B, H, W].
        classes (str or list): Forwarded to lovasz_softmax_flat ('all',
            'present', or a list of class indices).
        per_image (bool): If True, the loss is evaluated for every image
            separately and the results are averaged; otherwise all pixels
            of the batch are pooled into one loss.
        ignore (int, optional): Label value excluded from the loss.

    Returns:
        Tensor: Lovász-Softmax loss.
    """
    if not per_image:
        # Pool every pixel of the batch into a single flattened problem.
        flat_probas, flat_labels = flatten_probas(probas, labels, ignore)
        return lovasz_softmax_flat(flat_probas, flat_labels, classes=classes)

    image_losses = []
    for prob, lab in zip(probas, labels):
        # Re-add the batch axis so flatten_probas sees [1, C, H, W] / [1, H, W].
        flat_probas, flat_labels = flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore)
        image_losses.append(lovasz_softmax_flat(flat_probas, flat_labels, classes=classes))
    return torch.mean(torch.stack(image_losses))

class LovaszSoftmaxLoss(nn.Module):
    """
    Module wrapper around lovasz_softmax that accepts raw logits.

    Args:
        per_image (bool): Evaluate the loss per image and average, instead
            of pooling all pixels of the batch.
        ignore_index (int, optional): Label value excluded from the loss.
        classes (str or list): 'all', 'present', or a list of class
            indices to include in the loss computation.
    """

    def __init__(self, per_image=False, ignore_index=None, classes='present'):
        super().__init__()
        self.per_image, self.ignore_index, self.classes = per_image, ignore_index, classes

    def forward(self, logits, labels):
        """Softmax the logits over the class axis and return the Lovász-Softmax loss."""
        return lovasz_softmax(
            F.softmax(logits, dim=1),
            labels,
            classes=self.classes,
            per_image=self.per_image,
            ignore=self.ignore_index,
        )

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd
from tqdm import tqdm
from PIL import Image
import gc  # For garbage collection

# ---- Training configuration ----
# Every image and mask is resized to IMG_WIDTH x IMG_HEIGHT before use.
IMG_HEIGHT = IMG_WIDTH = 640
BATCH_SIZE = 2
EPOCHS = 100
NUM_CLASSES = 3
LEARNING_RATE = 1e-3
PATIENCE = 10  # epochs without improvement tolerated by early stopping
# Prefer the GPU when one is available.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- Dataset locations (relative to the working directory) ----
_DATASET_ROOT = 'CWD-3HSV'
IMAGE_DIR = f'{_DATASET_ROOT}/train/images'
MASK_DIR = f'{_DATASET_ROOT}/train/Morphed_Images'
VALID_IMAGE_DIR = f'{_DATASET_ROOT}/valid/images'
VALID_MASK_DIR = f'{_DATASET_ROOT}/valid/Morphed_Images'

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs loaded from parallel lists of file paths.

    Args:
        image_files: Sequence of image file paths.
        mask_files: Sequence of mask file paths; ``mask_files[i]`` belongs
            to ``image_files[i]``.
        transform: Optional callable applied to the resized RGB image.

    Each item is ``(image, mask)``, both resized to
    ``(IMG_WIDTH, IMG_HEIGHT)``.  With a transform, ``image`` is the
    transform's output and ``mask`` is an int64 tensor of the mask's grey
    values; without one, both are returned as PIL images.
    """

    def __init__(self, image_files, mask_files, transform=None):
        self.image_files = image_files
        self.mask_files  = mask_files
        self.transform   = transform

    def __len__(self):
        """Number of samples (length of the image list)."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and (optionally) tensorize the idx-th pair."""
        target_size = (IMG_WIDTH, IMG_HEIGHT)  # PIL expects (width, height)

        # `with` closes the underlying files as soon as the pixel data has
        # been copied by convert(), instead of leaving that to the GC.
        with Image.open(self.image_files[idx]) as raw_image:
            image = raw_image.convert('RGB').resize(target_size)

        with Image.open(self.mask_files[idx]) as raw_mask:
            # Masks hold labels, not intensities: nearest-neighbour
            # resampling keeps the original values, whereas an
            # interpolating filter would blend neighbouring labels into
            # values that are not valid classes.
            mask = raw_mask.convert('L').resize(target_size, resample=Image.NEAREST)

        if self.transform:
            image = self.transform(image)
            # NOTE(review): grey values are used directly as class indices,
            # so they are assumed to lie in [0, NUM_CLASSES) — confirm
            # against how the mask files were generated.
            mask = torch.tensor(np.array(mask), dtype=torch.long)

        return image, mask

# Prepare training file paths: every .jpg image and its expected mask name.
image_files = [f for f in os.listdir(IMAGE_DIR) if f.endswith('.jpg')]
mask_files  = [f.replace('.jpg', '_morphed.png') for f in image_files]

# Scan the mask directory once; a set gives O(1) membership tests instead
# of re-listing the directory for every image.
_available_masks = set(os.listdir(MASK_DIR))

# "valid_*" = training pairs whose mask file actually exists
# (not to be confused with the validation split, "val_*", below).
valid_image_files = []
valid_mask_files  = []
for img_file, mask_file in zip(image_files, mask_files):
    if mask_file in _available_masks:
        valid_image_files.append(os.path.join(IMAGE_DIR, img_file))
        valid_mask_files.append(os.path.join(MASK_DIR,  mask_file))

# Validation split: list the directory once so that image and mask paths
# are guaranteed to be built from the same names in the same order.
_val_names = [f for f in os.listdir(VALID_IMAGE_DIR) if f.endswith('.jpg')]
val_image_files = [os.path.join(VALID_IMAGE_DIR, f) for f in _val_names]
val_mask_files  = [os.path.join(VALID_MASK_DIR,  f.replace('.jpg', '_morphed.png'))
                   for f in _val_names]

# Images are only converted to tensors; no augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = SegmentationDataset(valid_image_files, valid_mask_files, transform=transform)
val_dataset   = SegmentationDataset(val_image_files,   val_mask_files,   transform=transform)

train_loader  = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader    = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

def calculate_iou(outputs, masks, num_classes):
    """Mean intersection-over-union of the arg-max prediction.

    Args:
        outputs: Per-class scores of shape (B, num_classes, H, W).
        masks: Integer class labels, compared element-wise with the
            arg-max of ``outputs`` over dim 1.
        num_classes: Classes 0 .. num_classes - 1 are evaluated.

    Returns:
        The mean IoU over all classes that occur in the prediction or in
        the labels; classes absent from both are left out of the mean.
    """
    predicted = torch.argmax(outputs, dim=1)

    scores = []
    for cls in range(num_classes):
        pred_is_cls = predicted == cls
        mask_is_cls = masks == cls

        union = (pred_is_cls | mask_is_cls).sum().item()
        if union == 0:
            # Class appears nowhere: mark it so nanmean skips it.
            scores.append(float('nan'))
            continue

        intersection = (pred_is_cls & mask_is_cls).sum().item()
        scores.append(intersection / union)

    return np.nanmean(scores)

def calculate_iou_loss(outputs, masks, num_classes):
    """Return 1 - mean IoU of the arg-max prediction (see calculate_iou)."""
    return 1 - calculate_iou(outputs, masks, num_classes)

def train_epoch(model, data_loader, optimizer, criterion):
    """Run one optimisation pass over ``data_loader`` and collect metrics.

    Args:
        model: Network to train; put into training mode here.
        data_loader: Iterable yielding ``(images, masks)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(outputs, masks)`` returning the
            loss that is back-propagated.

    Returns:
        Tuple ``(epoch_loss, epoch_accuracy, epoch_iou, epoch_iou_loss)``:
        mean criterion value per batch, pixel accuracy in percent over all
        pixels, mean per-batch IoU, and mean per-batch ``1 - IoU``.
    """
    model.train()
    running_loss     = 0.0
    running_iou_loss = 0.0
    correct          = 0
    total            = 0
    iou_score        = 0

    for images, masks in tqdm(data_loader, desc="Training", leave=False):
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss    = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # Metrics.  The IoU is computed once and reused for both the score
        # and the "IoU loss" (1 - IoU); evaluating it twice per batch only
        # doubled the device-to-host synchronisations.
        batch_iou = calculate_iou(outputs, masks, NUM_CLASSES)
        predicted = torch.argmax(outputs, dim=1)

        running_loss     += loss.item()
        running_iou_loss += 1 - batch_iou
        iou_score        += batch_iou
        total            += masks.numel()
        correct          += (predicted == masks).sum().item()

    num_batches     = len(data_loader)
    epoch_loss      = running_loss / num_batches
    epoch_iou_loss  = running_iou_loss / num_batches
    epoch_accuracy  = correct / total * 100
    epoch_iou       = iou_score / num_batches
    return epoch_loss, epoch_accuracy, epoch_iou, epoch_iou_loss

def evaluate(model, data_loader, criterion):
    """Evaluate ``model`` on ``data_loader`` without updating parameters.

    Args:
        model: Network to evaluate; put into eval mode here.
        data_loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Callable ``criterion(outputs, masks)`` returning the loss.

    Returns:
        Tuple ``(epoch_loss, epoch_accuracy, epoch_iou, epoch_iou_loss)``:
        mean criterion value per batch, pixel accuracy in percent over all
        pixels, mean per-batch IoU, and mean per-batch ``1 - IoU``.
    """
    model.eval()
    running_loss     = 0.0
    running_iou_loss = 0.0
    correct          = 0
    total            = 0
    iou_score        = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, masks in tqdm(data_loader, desc="Validation", leave=False):
            images, masks = images.to(DEVICE), masks.to(DEVICE)

            outputs = model(images)
            loss    = criterion(outputs, masks)

            # Compute the IoU once and derive the "IoU loss" from it
            # instead of running the per-class IoU loop twice per batch.
            batch_iou = calculate_iou(outputs, masks, NUM_CLASSES)
            predicted = torch.argmax(outputs, dim=1)

            running_loss     += loss.item()
            running_iou_loss += 1 - batch_iou
            iou_score        += batch_iou
            total            += masks.numel()
            correct          += (predicted == masks).sum().item()

    num_batches     = len(data_loader)
    epoch_loss      = running_loss / num_batches
    epoch_iou_loss  = running_iou_loss / num_batches
    epoch_accuracy  = correct / total * 100
    epoch_iou       = iou_score / num_batches
    return epoch_loss, epoch_accuracy, epoch_iou, epoch_iou_loss

def train_and_evaluate_model(model_name, model_class,
                             train_loader, val_loader,
                             epochs=EPOCHS, patience=PATIENCE):
    """Train one model with early stopping and write its artefacts to disk.

    A folder named ``model_name`` is created. Whenever the validation loss
    improves, the whole model object is saved there as
    ``unet_best_model.pth``; after training, the per-epoch metrics are written
    to ``Training_Metrics.xlsx`` in the same folder.

    Args:
        model_name: Name used for the output folder and in log messages.
        model_class: Callable invoked as
            ``model_class(in_channels=3, out_channels=NUM_CLASSES)``.
        train_loader: Loader passed to ``train_epoch``.
        val_loader: Loader passed to ``evaluate``.
        epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs that stops training.
    """
    # 1) Output folder for this model
    model_dir = model_name
    os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, "unet_best_model.pth")

    # 2) Model, optimizer and loss
    model = model_class(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    lowest_val_loss = float('inf')
    epochs_without_gain = 0
    history = []

    for epoch_no in range(1, epochs + 1):
        print(f"\n[{model_name}] Epoch {epoch_no}/{epochs}")

        train_loss, train_accuracy, train_iou, train_iou_loss = train_epoch(
            model, train_loader, optimizer, criterion
        )
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
              f"Train IoU: {train_iou:.4f}, Train IoU Loss: {train_iou_loss:.4f}")

        val_loss, val_accuracy, val_iou, val_iou_loss = evaluate(model, val_loader, criterion)
        print(f"Val Loss:   {val_loss:.4f}, Val Accuracy:   {val_accuracy:.2f}%, "
              f"Val IoU:   {val_iou:.4f}, Val IoU Loss:   {val_iou_loss:.4f}")

        # Order must match the column list used for the Excel sheet below.
        history.append([
            epoch_no,
            train_loss, train_accuracy,
            val_loss, val_accuracy,
            train_iou_loss, train_iou,
            val_iou_loss, val_iou,
        ])

        # Early stopping driven by the validation loss
        improved = val_loss < lowest_val_loss
        if improved:
            lowest_val_loss = val_loss
            torch.save(model, checkpoint_path)
            print(f"  [*] Best model saved at {checkpoint_path}")
            epochs_without_gain = 0
        else:
            epochs_without_gain += 1

        if epochs_without_gain >= patience:
            print(f"  [!] Early stopping for {model_name}")
            break

    # 3) Persist the per-epoch metrics next to the checkpoint
    excel_path = os.path.join(model_dir, "Training_Metrics.xlsx")
    metric_columns = [
        "Epoch",
        "Training Loss",
        "Training Accuracy",
        "Validation Loss",
        "Validation Accuracy",
        "Training IoU loss",
        "Mean Training IoU",
        "Validation IoU loss",
        "Mean Validation IoU",
    ]
    pd.DataFrame(history, columns=metric_columns).to_excel(excel_path, index=False)
    print(f"  Metrics saved to {excel_path}")

    print(f"Done training {model_name}.\n")

    # 4) Drop the model and release cached GPU memory before the next run
    del model
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == "__main__":
    # Output-folder name -> model class; the models are trained one after another.
    models_to_train = {
        "Unet-DL": UNet_DL,
        "Unet-JL": UNet_JL,
        "Unet-FL": UNet_FL,
        "Unet-TL": UNet_TL,
        "Unet-LSL": UNet_LSL,
    }

    for model_name in models_to_train:
        model_class = models_to_train[model_name]
        train_and_evaluate_model(
            model_name,
            model_class,
            train_loader,
            val_loader,
            epochs=EPOCHS,
            patience=PATIENCE,
        )
        # Second clean-up pass so nothing from this run lingers in CUDA's cache
        gc.collect()
        torch.cuda.empty_cache()


# Saving Prediction Images Of Each Model

In [None]:
import os
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

import matplotlib
matplotlib.use('Agg')  # Turn off interactive backend (no pop-up windows)

# Folders holding the test images and their ground-truth masks.
TEST_IMAGE_FOLDER       = 'CWD-3HSV/test/images'
GROUND_TRUTH_MASK_FOLDER= 'CWD-3HSV/test/Morphed_Images'
# Every image and mask is resized to IMG_HEIGHT x IMG_WIDTH before use.
IMG_HEIGHT              = 640
IMG_WIDTH               = 640
NUM_CLASSES             = 3
# Use the GPU when CUDA is available, otherwise the CPU.
DEVICE                  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Image pipeline: resize to (IMG_HEIGHT, IMG_WIDTH), then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])


def load_full_model(model_path):
    """Load a fully pickled model from ``model_path`` and put it in eval mode.

    The training cell stores the complete module object with
    ``torch.save(model, path)``, not a ``state_dict``, so the file has to be
    unpickled as a whole. ``weights_only=False`` is passed explicitly because
    newer PyTorch releases default to ``weights_only=True``, which refuses
    such files.

    Args:
        model_path: Path to the ``.pth`` file.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    # NOTE: full unpickling can execute arbitrary code; only load checkpoints
    # produced by this project.
    model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model = model.to(DEVICE)
    model.eval()
    return model

def preprocess_image(image_path):
    """Read an image as RGB and return it as a one-element batch on ``DEVICE``.

    The image goes through the module-level ``transform`` and gets a leading
    batch axis via ``unsqueeze(0)``.
    """
    rgb = Image.open(image_path).convert('RGB')
    batch = transform(rgb).unsqueeze(0)
    return batch.to(DEVICE)


def load_ground_truth_mask(mask_path):
    """Read a mask as 8-bit grayscale, resize it, and return a NumPy array.

    Nearest-neighbour resampling is used, so resizing only copies existing
    pixel values and never blends them.
    """
    grayscale = Image.open(mask_path).convert('L')
    resized = grayscale.resize((IMG_WIDTH, IMG_HEIGHT), Image.NEAREST)
    return np.array(resized)


def generate_segmentation_mask(model, image):
    """Run ``model`` on ``image`` and return the per-pixel class indices.

    Args:
        model: Callable returning logits with the class scores on axis 1.
        image: Input batch passed to ``model`` unchanged.

    Returns:
        NumPy array of argmax class ids. For a batch of one the batch axis is
        removed; spatial axes are always kept, even when they have size 1.
    """
    with torch.no_grad():
        output = model(image)        # (B, NUM_CLASSES, H, W)
        pred   = torch.argmax(output, dim=1)
        # squeeze(0) drops only a size-1 batch axis; a bare squeeze() would
        # also collapse a height or width of 1.
        return pred.squeeze(0).cpu().numpy()


def visualize_and_save_comparison(
    model, input_image, gt_mask, class_rgb_mapping, input_image_path,
    save_folder, save_predictions=True
):
    """Draw original / ground truth / prediction side by side and save them.

    Args:
        model: Model handed to ``generate_segmentation_mask``.
        input_image: Preprocessed input for the model.
        gt_mask: 2-D array of ground-truth class ids.
        class_rgb_mapping: Dict mapping class id -> RGB tuple.
        input_image_path: Path of the source image; it is reloaded for display
            and its base name is used for the output file names.
        save_folder: Folder receiving ``<stem>_compare.png`` and
            ``<stem>_predmask.png``.
        save_predictions: When false, nothing is written to disk.
    """
    def colorize(label_map):
        # Pixels whose id is not in the mapping stay black (the zero fill).
        canvas = np.zeros((label_map.shape[0], label_map.shape[1], 3), dtype=np.uint8)
        for class_id, rgb_value in class_rgb_mapping.items():
            canvas[label_map == class_id] = rgb_value
        return canvas

    # Reload the untouched image for display at the working resolution
    display_image = Image.open(input_image_path).convert('RGB')
    display_image = display_image.resize((IMG_WIDTH, IMG_HEIGHT))

    pred_mask = generate_segmentation_mask(model, input_image)
    rgb_pred_mask = colorize(pred_mask)
    rgb_gt_mask = colorize(gt_mask)

    # Three panels, left to right
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = (
        (display_image, 'Original'),
        (rgb_gt_mask, 'Ground Truth'),
        (rgb_pred_mask, 'Predicted'),
    )
    for axis, (picture, caption) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(caption)
        axis.axis('off')

    fig.tight_layout()

    # Write the files (the figure is never shown interactively)
    if save_predictions:
        stem = os.path.splitext(os.path.basename(input_image_path))[0]

        fig_save_path = os.path.join(save_folder, stem + '_compare.png')
        fig.savefig(fig_save_path, bbox_inches='tight')

        pred_save_path = os.path.join(save_folder, stem + '_predmask.png')
        Image.fromarray(rgb_pred_mask).save(pred_save_path)

    plt.close(fig)

# -------------------------------------------------------------------
# Class-to-RGB mapping
# Keys are class ids found in the masks; values are the (R, G, B)
# colour used to paint that class in the saved figures.
# -------------------------------------------------------------------
class_rgb_mapping = {
    0: (0, 0, 0),    # black
    1: (0, 255, 0),  # green
    2: (255, 0, 0),  # red
}

# -------------------------------------------------------------------
# MAIN: 
# 1) Find all "Unet-..." directories
# 2) For each, load "unet_best_model.pth"
# 3) Randomly pick 5 test images, generate predictions
# 4) Save side-by-side figure + predicted mask
# -------------------------------------------------------------------
if __name__ == "__main__":
    # The built-in ``print`` is deliberately left alone: this block never
    # prints, and rebinding ``print`` at module level would shadow the
    # built-in for all code executed afterwards in this namespace.

    # 1) Find directories whose name starts with "Unet"
    all_dirs = [d for d in os.listdir('.') if os.path.isdir(d) and d.startswith("Unet")]

    # 2) Choose the test images once, so every model is shown on the same samples
    test_image_files = [
        f for f in os.listdir(TEST_IMAGE_FOLDER)
        if f.endswith(('.jpg', '.jpeg', '.png'))
    ]
    # With 5 or fewer images use them all; otherwise draw 5 at random
    if len(test_image_files) <= 5:
        selected_images = test_image_files
    else:
        selected_images = random.sample(test_image_files, 5)

    for model_dir in all_dirs:
        model_path = os.path.join(model_dir, "unet_best_model.pth")
        if not os.path.isfile(model_path):
            continue  # skip if no best model in that dir

        # Create a subfolder "Predictions" inside model_dir
        pred_save_folder = os.path.join(model_dir, "Predictions")
        os.makedirs(pred_save_folder, exist_ok=True)

        # Load model
        model = load_full_model(model_path)

        # For each selected image, compare
        for image_file in selected_images:
            input_image_path = os.path.join(TEST_IMAGE_FOLDER, image_file)
            # The ground-truth mask shares the image's stem: <stem>_morphed.png
            gt_mask_name = os.path.splitext(image_file)[0] + '_morphed.png'
            gt_mask_path = os.path.join(GROUND_TRUTH_MASK_FOLDER, gt_mask_name)

            input_image = preprocess_image(input_image_path)
            ground_truth_mask = load_ground_truth_mask(gt_mask_path)

            # Visualize and save
            visualize_and_save_comparison(
                model=model,
                input_image=input_image,
                gt_mask=ground_truth_mask,
                class_rgb_mapping=class_rgb_mapping,
                input_image_path=input_image_path,
                save_folder=pred_save_folder,
                save_predictions=True
            )


# Evaluating Each Model

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd
from tqdm import tqdm
from PIL import Image
import time

from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib
matplotlib.use('Agg')  # Turn off interactive display
import matplotlib.pyplot as plt
import seaborn as sns


# Folders holding the test images and their ground-truth masks.
TEST_IMAGES_DIR = 'CWD-3HSV/test/images'
TEST_MASKS_DIR  = 'CWD-3HSV/test/Morphed_Images'
# Every image and mask is resized to IMG_HEIGHT x IMG_WIDTH before use.
IMG_HEIGHT      = 640
IMG_WIDTH       = 640
NUM_CLASSES     = 3
# Use the GPU when CUDA is available, otherwise the CPU.
DEVICE          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Silent stand-in callable. The built-in ``print`` is deliberately NOT rebound
# to it: a module-level ``print = no_op`` would shadow the built-in for all
# code executed afterwards in this namespace.
def no_op(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    pass

# Image pipeline: resize to (IMG_HEIGHT, IMG_WIDTH), then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])


def load_model(model_path):
    """Load a fully pickled model from ``model_path`` and put it in eval mode.

    The training cell stores the complete module object with
    ``torch.save(model, path)``, so the file must be unpickled as a whole.
    ``weights_only=False`` is passed explicitly because newer PyTorch releases
    default to ``weights_only=True``, which refuses such files.

    Args:
        model_path: Path to the ``.pth`` file.

    Returns:
        The model on ``DEVICE`` in evaluation mode.
    """
    # NOTE: full unpickling can execute arbitrary code; only load checkpoints
    # produced by this project.
    model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model.to(DEVICE)
    model.eval()
    return model

def preprocess_image(image_path):
    """Read an image as RGB and return it as a one-element batch on ``DEVICE``."""
    rgb = Image.open(image_path).convert('RGB')
    tensor = transform(rgb)
    return tensor.unsqueeze(0).to(DEVICE)

def preprocess_mask(mask_path):
    """Read a mask as 8-bit grayscale, resize it (nearest neighbour), return an array."""
    grayscale = Image.open(mask_path).convert('L')
    target_size = (IMG_WIDTH, IMG_HEIGHT)  # PIL expects (width, height)
    return np.array(grayscale.resize(target_size, Image.NEAREST))


def generate_predictions(model, image):
    """Run ``model`` on ``image`` and return the per-pixel class indices.

    Args:
        model: Callable returning logits with the class scores on axis 1.
        image: Input batch passed to ``model`` unchanged.

    Returns:
        NumPy array of argmax class ids. For a batch of one the batch axis is
        removed; spatial axes are always kept, even when they have size 1.
    """
    with torch.no_grad():
        output = model(image)            # [B, NUM_CLASSES, H, W]
        pred   = torch.argmax(output, 1) # pick class with max logit
        # squeeze(0) drops only a size-1 batch axis; a bare squeeze() would
        # also collapse a height or width of 1.
        return pred.squeeze(0).cpu().numpy()


def calculate_iou(pred_mask, gt_mask, num_classes):
    """Return a list with the intersection-over-union of each class id.

    A class that appears in neither mask (empty union) scores 0.
    """
    def class_iou(cls):
        in_pred = pred_mask == cls
        in_gt = gt_mask == cls
        union = np.logical_or(in_pred, in_gt).sum()
        if union > 0:
            return np.logical_and(in_pred, in_gt).sum() / union
        return 0

    return [class_iou(cls) for cls in range(num_classes)]

def calculate_dice(pred_mask, gt_mask, num_classes):
    """Return a list with the Dice coefficient of each class id.

    Dice is ``2 * |pred AND gt| / (|pred| + |gt|)``; a class absent from both
    masks (zero denominator) scores 0.
    """
    scores = []
    for label in range(num_classes):
        pred_region = pred_mask == label
        gt_region = gt_mask == label
        overlap = np.logical_and(pred_region, gt_region).sum()
        size_sum = pred_region.sum() + gt_region.sum()
        if size_sum > 0:
            scores.append(2.0 * overlap / size_sum)
        else:
            scores.append(0)
    return scores

def calculate_jaccard(pred_mask, gt_mask, num_classes):
    """Return a list with the Jaccard index of each class id.

    The index is ``|pred AND gt| / |pred OR gt|``; a class absent from both
    masks (empty union) scores 0.
    """
    scores = []
    for label in range(num_classes):
        pred_region = pred_mask == label
        gt_region = gt_mask == label
        shared = np.logical_and(pred_region, gt_region).sum()
        covered = np.logical_or(pred_region, gt_region).sum()
        if covered > 0:
            scores.append(shared / covered)
        else:
            scores.append(0)
    return scores


def plot_confusion_matrix(cm, class_names, save_path, title="Confusion Matrix", fmt='d'):
    """Render ``cm`` as an annotated heatmap and save it to ``save_path``.

    Args:
        cm: Confusion matrix; rows are labelled "Actual", columns "Predicted".
        class_names: Tick labels for both axes.
        save_path: Destination file of the figure.
        title: Figure title.
        fmt: Cell format, e.g. 'd' for counts or '.2f' for percentages.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set(xlabel="Predicted", ylabel="Actual", title=title)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)


def evaluate_model(model, test_images_dir, test_masks_dir, num_classes, results_dir):
    """Evaluate ``model`` on every ``.jpg`` image in ``test_images_dir``.

    For each image the mask ``<stem>_morphed.png`` is read from
    ``test_masks_dir``. Per-image, per-class scores are averaged over the
    images; overall scores are computed over all pixels pooled together. Two
    confusion-matrix PNGs (counts and row percentages) are written to
    ``results_dir``.

    Args:
        model: Model handed to ``generate_predictions``.
        test_images_dir: Folder containing the test images.
        test_masks_dir: Folder containing the ground-truth masks.
        num_classes: Number of class ids (``0 .. num_classes - 1``).
        results_dir: Folder receiving the confusion-matrix images.

    Returns:
        A 16-tuple, in this order: ``accuracy, accuracy_per_class, precision,
        precision_per_class, recall, recall_per_class, f1, f1_per_class,
        iou_per_class, mean_iou, weighted_iou, frequency_weighted_iou,
        dice_per_class, jaccard_per_class, mean_dice, mean_jaccard``.
        ``accuracy_per_class`` is true positives divided by the number of
        ground-truth pixels of that class.

    Raises:
        ValueError: If the folder holds no ``.jpg`` image, or a mask contains
            a label outside ``0 .. num_classes - 1``.
    """
    image_files = [img for img in os.listdir(test_images_dir) if img.endswith('.jpg')]
    if not image_files:
        raise ValueError(f"No .jpg images found in {test_images_dir!r}")
    total_samples = len(image_files)

    # One flat array per image; concatenated once after the loop. Extending a
    # Python list pixel by pixel is far slower and uses much more memory.
    pred_chunks = []
    gt_chunks   = []

    iou_per_class       = np.zeros(num_classes)
    dice_per_class      = np.zeros(num_classes)
    jaccard_per_class   = np.zeros(num_classes)
    accuracy_per_class  = np.zeros(num_classes)
    precision_per_class = np.zeros(num_classes)
    recall_per_class    = np.zeros(num_classes)
    f1_per_class        = np.zeros(num_classes)

    for img_name in image_files:
        image_path = os.path.join(test_images_dir, img_name)
        # splitext swaps only the extension; str.replace would also rewrite a
        # ".jpg" that happens to occur earlier in the file name.
        mask_name  = os.path.splitext(img_name)[0] + '_morphed.png'
        mask_path  = os.path.join(test_masks_dir, mask_name)

        image     = preprocess_image(image_path)
        gt_mask   = preprocess_mask(mask_path)
        pred_mask = generate_predictions(model, image)

        pred_chunks.append(pred_mask.ravel())
        gt_chunks.append(gt_mask.ravel())

        # iou/dice/jaccard
        iou_per_class     += np.array(calculate_iou(pred_mask,     gt_mask, num_classes))
        dice_per_class    += np.array(calculate_dice(pred_mask,    gt_mask, num_classes))
        jaccard_per_class += np.array(calculate_jaccard(pred_mask, gt_mask, num_classes))

        # per-class metrics (the 1e-6 keeps the ratios finite for absent classes)
        for cls in range(num_classes):
            tp = np.sum((pred_mask == cls) & (gt_mask == cls))
            fp = np.sum((pred_mask == cls) & (gt_mask != cls))
            fn = np.sum((pred_mask != cls) & (gt_mask == cls))
            total_class_pixels = np.sum(gt_mask == cls)

            accuracy_per_class[cls]  += tp / (total_class_pixels + 1e-6)
            precision_per_class[cls] += tp / (tp + fp + 1e-6)
            recall_per_class[cls]    += tp / (tp + fn + 1e-6)
            f1_per_class[cls]        += 2*tp / (2*tp + fp + fn + 1e-6)

    # Normalize: turn the per-image sums into per-image means
    accuracy_per_class  /= total_samples
    precision_per_class /= total_samples
    recall_per_class    /= total_samples
    f1_per_class        /= total_samples
    iou_per_class       /= total_samples
    dice_per_class      /= total_samples
    jaccard_per_class   /= total_samples

    mean_dice    = np.mean(dice_per_class)
    mean_jaccard = np.mean(jaccard_per_class)

    all_gt_arr   = np.concatenate(gt_chunks)
    all_pred_arr = np.concatenate(pred_chunks)

    # Pixel count per class. minlength keeps one slot per class even when the
    # highest class id never occurs in the ground truth.
    class_counts = np.bincount(all_gt_arr, minlength=num_classes)
    if class_counts.size != num_classes:
        raise ValueError(
            f"Ground-truth masks contain labels outside 0..{num_classes - 1}"
        )

    # Both IoU averages weight each class by its ground-truth pixel count.
    frequency_weighted_iou = np.average(iou_per_class, weights=class_counts)
    weighted_iou           = np.average(iou_per_class, weights=class_counts)
    mean_iou               = np.mean(iou_per_class)

    # overall metrics
    accuracy  = (all_pred_arr == all_gt_arr).sum()/len(all_gt_arr)
    precision = precision_score(all_gt_arr, all_pred_arr, average='weighted', zero_division=1)
    recall    = recall_score(all_gt_arr,    all_pred_arr, average='weighted', zero_division=1)
    f1        = f1_score(all_gt_arr,        all_pred_arr, average='weighted', zero_division=1)

    # confusion matrix: rows = ground truth, columns = prediction.
    # np.add.at accumulates repeated (row, col) pairs in one vectorised call.
    confusion_matrix_counts = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(confusion_matrix_counts, (all_gt_arr, all_pred_arr), 1)

    # Row-normalised percentages; rows without any pixel stay at 0
    confusion_matrix_percent = np.zeros_like(confusion_matrix_counts, dtype=float)
    for r in range(num_classes):
        row_sum = confusion_matrix_counts[r, :].sum()
        if row_sum > 0:
            confusion_matrix_percent[r, :] = (confusion_matrix_counts[r, :] / row_sum) * 100

    class_names = [f"Class {i}" for i in range(num_classes)]

    # Save confusion matrices as PNG
    cm_counts_path  = os.path.join(results_dir, "Confusion_Matrix_Counts.png")
    cm_percent_path = os.path.join(results_dir, "Confusion_Matrix_Percent.png")
    plot_confusion_matrix(confusion_matrix_counts,  class_names, save_path=cm_counts_path,  title="Confusion Matrix (Counts)", fmt='d')
    plot_confusion_matrix(confusion_matrix_percent, class_names, save_path=cm_percent_path, title="Confusion Matrix (Percent)", fmt='.2f')

    return (
        accuracy, accuracy_per_class, precision, precision_per_class,
        recall, recall_per_class, f1, f1_per_class,
        iou_per_class, mean_iou, weighted_iou, frequency_weighted_iou,
        dice_per_class, jaccard_per_class, mean_dice, mean_jaccard
    )

def save_results_to_excel(
    model_name,
    accuracy, accuracy_per_class,
    precision, precision_per_class,
    recall, recall_per_class,
    f1, f1_per_class,
    iou_per_class, mean_iou,
    weighted_iou, frequency_weighted_iou,
    dice_per_class, jaccard_per_class,
    mean_dice, mean_jaccard,
    save_directory
):
    """Write one model's metrics as a single-row Excel sheet.

    The file ``Performance_Evaluation_Metrics.xlsx`` inside ``save_directory``
    is overwritten. Every ``*_per_class`` sequence expands into one column per
    class, named ``"<Metric> Class <i>"``.
    """
    excel_path = os.path.join(save_directory, "Performance_Evaluation_Metrics.xlsx")

    def per_class(prefix, values):
        # One "<prefix> Class <i>" entry for each element of ``values``
        return {f'{prefix} Class {i}': value for i, value in enumerate(values)}

    # Insertion order of this dict is the column order of the sheet.
    row = {'Model Name': model_name, 'Accuracy': accuracy}
    row.update(per_class('Accuracy', accuracy_per_class))
    row['Precision'] = precision
    row.update(per_class('Precision', precision_per_class))
    row['Recall'] = recall
    row.update(per_class('Recall', recall_per_class))
    row['F1 Score'] = f1
    row.update(per_class('F1 Score', f1_per_class))
    row.update(per_class('IoU', iou_per_class))
    row['Mean IoU'] = mean_iou
    row['Weighted IoU'] = weighted_iou
    row['Frequency Weighted IoU'] = frequency_weighted_iou
    row.update(per_class('Dice Coefficient', dice_per_class))
    row['Mean Dice'] = mean_dice
    row.update(per_class('Jaccard Index', jaccard_per_class))
    row['Mean Jaccard'] = mean_jaccard

    new_data = pd.DataFrame([row], columns=list(row))

    # Overwrite any existing file so it holds only this row
    new_data.to_excel(excel_path, index=False, header=True)

if __name__ == "__main__":
    # The built-in ``print`` is deliberately left alone: this block never
    # prints, and rebinding ``print`` at module level would shadow the
    # built-in for all code executed afterwards in this namespace.

    # 1) Find all directories whose name starts with "Unet"
    all_dirs = [d for d in os.listdir('.') if os.path.isdir(d) and d.startswith("Unet")]

    for model_dir in all_dirs:
        # 2) Load "unet_best_model.pth" in that directory, if exists
        model_path = os.path.join(model_dir, "unet_best_model.pth")
        if not os.path.isfile(model_path):
            continue  # skip if no best model found

        model = load_model(model_path)

        # 3) Evaluate
        results = evaluate_model(
            model=model,
            test_images_dir=TEST_IMAGES_DIR,
            test_masks_dir=TEST_MASKS_DIR,
            num_classes=NUM_CLASSES,
            results_dir=model_dir  # store confusion matrix PNG in same directory
        )

        # 4) Unpack (same order as evaluate_model's return tuple)
        (
            accuracy, accuracy_per_class,
            precision, precision_per_class,
            recall, recall_per_class,
            f1, f1_per_class,
            iou_per_class, mean_iou,
            weighted_iou, frequency_weighted_iou,
            dice_per_class, jaccard_per_class,
            mean_dice, mean_jaccard
        ) = results

        # 5) Save row to "Performance_Evaluation_Metrics.xlsx" in model_dir
        save_results_to_excel(
            model_name=model_dir,
            accuracy=accuracy,
            accuracy_per_class=accuracy_per_class,
            precision=precision,
            precision_per_class=precision_per_class,
            recall=recall,
            recall_per_class=recall_per_class,
            f1=f1,
            f1_per_class=f1_per_class,
            iou_per_class=iou_per_class,
            mean_iou=mean_iou,
            weighted_iou=weighted_iou,
            frequency_weighted_iou=frequency_weighted_iou,
            dice_per_class=dice_per_class,
            jaccard_per_class=jaccard_per_class,
            mean_dice=mean_dice,
            mean_jaccard=mean_jaccard,
            save_directory=model_dir
        )


# Saving Training Curves Of Each Model

In [None]:
import os
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # So figures don't pop up; they are just saved
import matplotlib.pyplot as plt

def plot_training_curves_for_model(excel_path, output_dir):
    """Plot the training history stored in ``excel_path`` into ``output_dir``.

    Four PNG files are written, each comparing a training series with its
    validation counterpart over the epochs: loss, accuracy, IoU loss and
    mean IoU.
    """
    df = pd.read_excel(excel_path)
    epochs = df['Epoch']

    # Every column is read here, before any figure is created, so a missing
    # column fails before the first file is written.
    # (title, y-axis label, file name, (legend, series) train, (legend, series) val)
    chart_specs = [
        (
            'Training vs Validation Loss', 'Loss', 'Train_vs_Val_Loss.png',
            ('Training Loss', df['Training Loss']),
            ('Validation Loss', df['Validation Loss']),
        ),
        (
            'Training vs Validation Accuracy', 'Accuracy (%)', 'Train_vs_Val_Accuracy.png',
            ('Training Accuracy', df['Training Accuracy']),
            ('Validation Accuracy', df['Validation Accuracy']),
        ),
        (
            'Training vs Validation IoU Loss', 'IoU Loss', 'Train_vs_Val_IoU_Loss.png',
            ('Training IoU Loss', df['Training IoU loss']),
            ('Validation IoU Loss', df['Validation IoU loss']),
        ),
        (
            'Mean Training IoU vs Mean Validation IoU', 'IoU', 'Mean_Train_vs_Val_IoU.png',
            ('Mean Training IoU', df['Mean Training IoU']),
            ('Mean Validation IoU', df['Mean Validation IoU']),
        ),
    ]

    for title, y_label, file_name, train_curve, val_curve in chart_specs:
        train_legend, train_values = train_curve
        val_legend, val_values = val_curve

        plt.figure(figsize=(8,6))
        plt.plot(epochs, train_values, label=train_legend, marker='o')
        plt.plot(epochs, val_values, label=val_legend, marker='s')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

        plt.savefig(os.path.join(output_dir, file_name), bbox_inches='tight')
        plt.close()

def main():
    """Create training-curve plots for every "Unet*" folder that has metrics."""
    for entry in os.listdir('.'):
        # Only directories whose name starts with "Unet" are model folders
        if not (os.path.isdir(entry) and entry.startswith('Unet')):
            continue

        metrics_file = os.path.join(entry, 'Training_Metrics.xlsx')
        if not os.path.isfile(metrics_file):
            continue  # nothing recorded for this model

        # Plots go into a "Training_Curves" subfolder of the model folder
        curves_dir = os.path.join(entry, 'Training_Curves')
        os.makedirs(curves_dir, exist_ok=True)
        plot_training_curves_for_model(metrics_file, curves_dir)

if __name__ == "__main__":
    main()


# All Models Performance Evaluation Sheet

In [None]:
import os
import pandas as pd

# 1) Model folders: directories in the working directory named "Unet..."
model_directories = [
    entry for entry in os.listdir('.')
    if entry.startswith("Unet") and os.path.isdir(entry)
]

# 2) Per-model sheet name and location of the merged result
excel_filename = "Performance_Evaluation_Metrics.xlsx"
results_folder = "Results"
os.makedirs(results_folder, exist_ok=True)
merged_excel_path = os.path.join(results_folder, "All_Models_Performance_Evaluation_Metrics.xlsx")

# 3) Collect the bottom row of every model's sheet
merged_rows = []
for model_dir in model_directories:
    excel_path = os.path.join(model_dir, excel_filename)

    if not os.path.isfile(excel_path):
        print(f"Warning: {excel_path} not found. Skipping.")
        continue

    df = pd.read_excel(excel_path)
    if df.empty:
        print(f"Warning: {excel_path} is empty. Skipping.")
        continue

    # tail(1) keeps the last row as a one-row DataFrame
    last_row = df.tail(1).copy()
    merged_rows.append(last_row)

# 4) Stack the rows (an empty frame when nothing was collected) and overwrite the merged file
merged_df = pd.concat(merged_rows, ignore_index=True) if merged_rows else pd.DataFrame()
merged_df.to_excel(merged_excel_path, index=False)
print(f"Merged file saved to: {merged_excel_path}")



# Saving Predictions For All Models

In [None]:
import os
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1) Automatically gather all directories starting with "Unet-"
#    models_info pairs each directory name with the checkpoint path inside it;
#    whether that file exists is checked later, when the models are loaded.
# ------------------------------------------------------------------------------
all_dirs = [d for d in os.listdir('.') if os.path.isdir(d)]
model_dirs = [d for d in all_dirs if d.startswith("Unet-")]
models_info = [(d, os.path.join(d, "unet_best_model.pth")) for d in model_dirs]

# ------------------------------------------------------------------------------
# 2) Define directories / file paths and hyperparameters
# ------------------------------------------------------------------------------
TEST_IMAGE_FOLDER        = 'CWD-3HSV/test/images/'
GROUND_TRUTH_MASK_FOLDER = 'CWD-3HSV/test/Morphed_Images/'
# Folder that receives the merged comparison figure; created if missing.
PREDICTION_SAVE_FOLDER   = 'Predictions'
os.makedirs(PREDICTION_SAVE_FOLDER, exist_ok=True)

# Every image and mask is resized to IMG_HEIGHT x IMG_WIDTH before use.
IMG_HEIGHT  = 640
IMG_WIDTH   = 640
NUM_CLASSES = 3
# Use the GPU when CUDA is available, otherwise the CPU.
DEVICE      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------------------------------------------------------
# 3) Define transformation (as used during training)
#    Resize to (IMG_HEIGHT, IMG_WIDTH), then convert to a tensor.
# ------------------------------------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])

# ------------------------------------------------------------------------------
# 4) Helper functions
# ------------------------------------------------------------------------------
def load_full_model(model_path):
    """Load a fully pickled model from ``model_path`` and put it in eval mode.

    The training cell stores the complete module object with
    ``torch.save(model, path)``, so the file must be unpickled as a whole.
    ``weights_only=False`` is passed explicitly because newer PyTorch releases
    default to ``weights_only=True``, which refuses such files.

    Args:
        model_path: Path to the ``.pth`` file.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    # NOTE: full unpickling can execute arbitrary code; only load checkpoints
    # produced by this project.
    model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model = model.to(DEVICE)
    model.eval()
    return model

def preprocess_image(image_path):
    """Read an image as RGB and return it as a one-element batch on ``DEVICE``."""
    rgb = Image.open(image_path).convert('RGB')
    return transform(rgb).unsqueeze(0).to(DEVICE)

def load_ground_truth_mask(mask_path):
    """Read a mask as 8-bit grayscale, resize it (nearest neighbour), return an array."""
    target_size = (IMG_WIDTH, IMG_HEIGHT)  # PIL expects (width, height)
    grayscale = Image.open(mask_path).convert('L')
    return np.array(grayscale.resize(target_size, Image.NEAREST))

def generate_segmentation_mask(model, image_tensor):
    """Run ``model`` on ``image_tensor`` and return the per-pixel class indices.

    Args:
        model: Callable returning logits with the class scores on axis 1.
        image_tensor: Input batch passed to ``model`` unchanged.

    Returns:
        NumPy array of argmax class ids. For a batch of one the batch axis is
        removed, giving shape (H, W); spatial axes are always kept, even when
        they have size 1.
    """
    with torch.no_grad():
        output = model(image_tensor)  # shape: (B, NUM_CLASSES, H, W)
        pred = torch.argmax(output, dim=1)
        # squeeze(0) drops only a size-1 batch axis; a bare squeeze() would
        # also collapse a height or width of 1.
        return pred.squeeze(0).cpu().numpy()

def mask_to_rgb(mask_array):
    """Turn a 2-D array of class ids into an (H, W, 3) uint8 colour image.

    Class 0 is painted black, class 1 green and class 2 red; any other value
    is left black.
    """
    h, w = mask_array.shape
    palette = (
        (0, (0, 0, 0)),      # black: background / class 0
        (1, (0, 255, 0)),    # green: class 1
        (2, (255, 0, 0)),    # red:   class 2
    )
    canvas = np.zeros((h, w, 3), dtype=np.uint8)
    for label, color in palette:
        canvas[mask_array == label] = color
    return canvas

# ------------------------------------------------------------------------------
# 5) Main execution: load models, select images, and create merged figure
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Load every model whose checkpoint file exists
    loaded_models = []
    for model_name, model_path in models_info:
        if not os.path.isfile(model_path):
            continue
        model = load_full_model(model_path)
        loaded_models.append((model_name, model))
    if not loaded_models:
        # SystemExit is the portable way to stop; the ``exit`` helper is added
        # by the site module / interactive shells and is not always present.
        raise SystemExit(0)

    # Gather test images
    test_image_files = [f for f in os.listdir(TEST_IMAGE_FOLDER) if f.lower().endswith(('.jpg','.jpeg','.png'))]
    if not test_image_files:
        raise SystemExit(0)

    # Randomly select 5 images, or take them all when fewer than 5 exist
    selected_images = random.sample(test_image_files, 5) if len(test_image_files) >= 5 else test_image_files

    # Set up figure: one row per image; columns = input, ground truth, one per model.
    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    n_rows = len(selected_images)
    n_cols = 2 + len(loaded_models)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4.2 * n_rows), squeeze=False)

    # Process each image:
    for row_idx, image_file in enumerate(selected_images):
        input_image_path = os.path.join(TEST_IMAGE_FOLDER, image_file)
        gt_mask_name = os.path.splitext(image_file)[0] + '_morphed.png'
        gt_mask_path = os.path.join(GROUND_TRUTH_MASK_FOLDER, gt_mask_name)

        # Load and resize input image and ground truth mask for display.
        # Converting to RGB makes grayscale/palette/alpha files display like
        # the RGB input the models receive.
        original_image = Image.open(input_image_path).convert('RGB').resize((IMG_WIDTH, IMG_HEIGHT))
        gt_mask_np = load_ground_truth_mask(gt_mask_path)
        gt_rgb = mask_to_rgb(gt_mask_np)

        # Preprocess image for model inference
        input_tensor = preprocess_image(input_image_path)

        # Column 0: Input image
        axes[row_idx][0].imshow(original_image)
        if row_idx == 0:
            axes[row_idx][0].set_title("Input Image", fontsize=27)
        axes[row_idx][0].axis('off')

        # Column 1: Ground truth mask
        axes[row_idx][1].imshow(gt_rgb)
        if row_idx == 0:
            axes[row_idx][1].set_title("Ground Truth Mask", fontsize=27)
        axes[row_idx][1].axis('off')

        # Next columns: Predictions from each model
        for model_i, (model_name, model_obj) in enumerate(loaded_models):
            pred_mask = generate_segmentation_mask(model_obj, input_tensor)
            pred_rgb = mask_to_rgb(pred_mask)
            col_idx = 2 + model_i
            axes[row_idx][col_idx].imshow(pred_rgb)
            if row_idx == 0:
                axes[row_idx][col_idx].set_title(f"{model_name} \nPredicted Mask", fontsize=27)
            axes[row_idx][col_idx].axis('off')

    fig.tight_layout()
    merged_filename = os.path.join(PREDICTION_SAVE_FOLDER, "All_Models_Predictions.png")
    fig.savefig(merged_filename, bbox_inches='tight',dpi=200)
    # Display the figure as well, then release it so it does not stay in memory
    plt.show()
    plt.close(fig)


# Saving Predictions of Each Model

In [None]:
import os
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------
# 1) Collect every sub-directory of the working directory whose name begins
#    with "Unet" (each one is later probed for a saved model checkpoint)
# ------------------------------------------------------------------------------
model_dirs = [
    entry
    for entry in os.listdir('.')
    if entry.startswith("Unet") and os.path.isdir(entry)
]

# ------------------------------------------------------------------------------
# 2) Dataset locations and inference settings
# ------------------------------------------------------------------------------
TEST_IMAGE_FOLDER = 'CWD-3HSV/test/images/'
GROUND_TRUTH_MASK_FOLDER = 'CWD-3HSV/test/Morphed_Images/'
# Each model's merged prediction figure is written to a "Predictions"
# sub-folder of that model's own directory.
IMG_HEIGHT = IMG_WIDTH = 640
NUM_CLASSES = 3
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# ------------------------------------------------------------------------------
# 3) Image preprocessing pipeline (intended to mirror the one used in training)
# ------------------------------------------------------------------------------
transform = transforms.Compose(
    [
        # Bring every image to the fixed (IMG_HEIGHT, IMG_WIDTH) input size.
        transforms.Resize(size=(IMG_HEIGHT, IMG_WIDTH)),
        # Turn the PIL image into a tensor the model can consume.
        transforms.ToTensor(),
    ]
)

# ------------------------------------------------------------------------------
# 4) Helper functions
# ------------------------------------------------------------------------------
def load_full_model(model_dir):
    """Load the fully pickled model stored at ``<model_dir>/unet_best_model.pth``.

    Args:
        model_dir: Directory expected to contain ``unet_best_model.pth``.

    Returns:
        The loaded model, moved to ``DEVICE`` and switched to eval mode, or
        ``None`` when the checkpoint file does not exist.
    """
    model_path = os.path.join(model_dir, "unet_best_model.pth")
    if not os.path.isfile(model_path):
        return None
    # The file stores a whole pickled module (it is used directly as a model
    # below), not just a state dict. PyTorch >= 2.6 defaults to
    # weights_only=True, which refuses such files, so request full unpickling.
    # NOTE(security): full unpickling can execute arbitrary code; only load
    # checkpoints that come from a trusted source.
    try:
        model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the weights_only keyword.
        model = torch.load(model_path, map_location=DEVICE)
    model = model.to(DEVICE)
    model.eval()
    return model

def preprocess_image(image_path):
    """Load an image file and turn it into a single-image batch on ``DEVICE``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The result of the module-level ``transform`` applied to the RGB image,
        with a leading batch axis added and moved to ``DEVICE``.
    """
    # The context manager closes the underlying file as soon as the pixel data
    # has been copied out by convert(), instead of leaving the handle open
    # until garbage collection.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')
    return transform(rgb_image).unsqueeze(0).to(DEVICE)

def load_ground_truth_mask(mask_path):
    """Read a ground-truth mask as a single-channel array at display size.

    Args:
        mask_path: Path of the mask image file.

    Returns:
        A NumPy array of the mask converted to 8-bit grayscale and resized to
        (IMG_WIDTH, IMG_HEIGHT).
    """
    # Close the file handle deterministically; convert()/resize() return new
    # in-memory images, so nothing refers to the file afterwards.
    with Image.open(mask_path) as source:
        mask = source.convert('L')
        # Nearest-neighbour resampling never blends neighbouring pixel values,
        # so no new label values are invented by the resize.
        mask = mask.resize((IMG_WIDTH, IMG_HEIGHT), Image.NEAREST)
    return np.array(mask)

def generate_segmentation_mask(model, image_tensor):
    """Run ``model`` on ``image_tensor`` and return the per-pixel class indices.

    Args:
        model: Callable returning class scores with the class axis at dim 1
            (expected shape (B, NUM_CLASSES, H, W)).
        image_tensor: Input batch passed to ``model`` unchanged.

    Returns:
        A NumPy array of arg-max class indices. For a batch of one the batch
        axis is removed, giving shape (H, W); larger batches keep (B, H, W).
    """
    with torch.no_grad():
        # model outputs logits of shape (B, NUM_CLASSES, H, W)
        output = model(image_tensor)
        pred = torch.argmax(output, dim=1)
        # Drop only the batch axis: a bare squeeze() would also remove H or W
        # whenever one of them is 1 and break the documented (H, W) shape.
        return pred.squeeze(0).cpu().numpy()

def mask_to_rgb(mask_array):
    """Colour a 2-D class-index mask for display.

    Class 0 is drawn black, class 1 green and class 2 red. Any other value in
    ``mask_array`` is left black.

    Args:
        mask_array: 2-D array of class indices.

    Returns:
        A uint8 array of shape (H, W, 3) holding the RGB rendering.
    """
    # Position in the tuple is the class index it colours.
    palette = (
        (0, 0, 0),      # Black
        (0, 255, 0),    # Green
        (255, 0, 0),    # Red
    )
    height, width = mask_array.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(palette):
        canvas[mask_array == class_id] = rgb
    return canvas

# ------------------------------------------------------------------------------
# 5) Main execution: Process each model directory
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Gather test images (all jpg/jpeg/png)
    test_image_files = [
        f for f in os.listdir(TEST_IMAGE_FOLDER)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    if not test_image_files:
        # Nothing to render. Raise SystemExit directly rather than calling the
        # interactive-only exit() helper: exit() is not guaranteed to exist in
        # plain scripts and shuts the kernel down when run inside Jupyter.
        raise SystemExit(0)

    # Randomly select up to 5 images; the same selection is reused for every
    # model so the saved figures are directly comparable.
    selected_images = random.sample(test_image_files, 5) if len(test_image_files) >= 5 else test_image_files

    # Process each model directory collected in model_dirs
    for model_dir in model_dirs:
        model = load_full_model(model_dir)
        if model is None:
            # No checkpoint in this directory: skip it.
            continue

        # Create a Predictions subfolder inside the model directory
        predictions_dir = os.path.join(model_dir, "Predictions")
        os.makedirs(predictions_dir, exist_ok=True)

        # One row per image and 3 columns (Input, Ground Truth, Predicted).
        # squeeze=False always yields a 2-D grid of axes, so a single-row
        # figure needs no special handling.
        n_rows = len(selected_images)
        n_cols = 3
        fig, axes = plt.subplots(
            nrows=n_rows,
            ncols=n_cols,
            figsize=(4 * n_cols, 4.2 * n_rows),
            squeeze=False,
        )

        for row_idx, image_file in enumerate(selected_images):
            input_image_path = os.path.join(TEST_IMAGE_FOLDER, image_file)
            # Ground-truth masks share the image's base name plus "_morphed.png"
            gt_mask_name = os.path.splitext(image_file)[0] + '_morphed.png'
            gt_mask_path = os.path.join(GROUND_TRUTH_MASK_FOLDER, gt_mask_name)

            # Load original image and ground truth for display; the context
            # manager releases the file handle once the resized copy exists.
            with Image.open(input_image_path) as source_image:
                original_image = source_image.resize((IMG_WIDTH, IMG_HEIGHT))
            gt_mask_np = load_ground_truth_mask(gt_mask_path)
            gt_rgb = mask_to_rgb(gt_mask_np)

            # Preprocess image for inference
            input_tensor = preprocess_image(input_image_path)

            # Column 0: Input image
            axes[row_idx][0].imshow(original_image)
            if row_idx == 0:
                axes[row_idx][0].set_title("Input Image", fontsize=27)
            axes[row_idx][0].axis('off')

            # Column 1: Ground truth mask
            axes[row_idx][1].imshow(gt_rgb)
            if row_idx == 0:
                axes[row_idx][1].set_title("Ground Truth Mask", fontsize=27)
            axes[row_idx][1].axis('off')

            # Column 2: Predicted mask from this model
            pred_mask = generate_segmentation_mask(model, input_tensor)
            pred_rgb = mask_to_rgb(pred_mask)
            axes[row_idx][2].imshow(pred_rgb)
            if row_idx == 0:
                axes[row_idx][2].set_title(f"{model_dir} \nPredicted Mask", fontsize=27)
            axes[row_idx][2].axis('off')

        plt.tight_layout()
        merged_filename = os.path.join(predictions_dir, "All_Models_Predictions.png")
        # Save the figure at high resolution
        plt.savefig(merged_filename, bbox_inches='tight', dpi=300)
        # Close the figure so open figures do not accumulate across models.
        plt.close(fig)


# Saving Results of Each Model

In [None]:
import os
import pandas as pd

# ------------------------------------------------------------------------------
# 1) Automatically gather all directories whose name starts with "Unet"
# ------------------------------------------------------------------------------
model_directories = [
    name
    for name in os.listdir('.')
    if os.path.isdir(name) and name.startswith("Unet")
]

# Name of the Excel file in each model directory
excel_filename = "Performance_Evaluation_Metrics.xlsx"

# Columns copied unchanged into the overall summary, in output order.
overall_metric_names = [
    "Accuracy",
    "Precision",
    "Recall",
    "F1 Score",
    "Mean IoU",
    "Frequency Weighted IoU",
    "Mean Dice",
    "Mean Jaccard",
]

# (summary column, source column prefix) pairs for the per-class table; the
# source column is "<prefix> Class <n>".
per_class_sources = [
    ("Accuracy", "Accuracy"),
    ("Precision", "Precision"),
    ("Recall", "Recall"),
    ("F1 Score", "F1 Score"),
    ("IoU", "IoU"),
    ("Dice", "Dice Coefficient"),
    ("Jaccard", "Jaccard Index"),
]
class_ids = (0, 1, 2)

# ------------------------------------------------------------------------------
# 2) Build the summary tables for every model directory that has a workbook
# ------------------------------------------------------------------------------
for model_dir in model_directories:
    excel_path = os.path.join(model_dir, excel_filename)

    # Directories without the workbook are skipped.
    if not os.path.isfile(excel_path):
        continue

    df = pd.read_excel(excel_path)
    if df.empty:
        continue

    # Only the bottom row of the sheet is summarised.
    last_row = df.iloc[-1]

    # ------------------ Overall Metrics ------------------
    overall_metrics = {
        "Metric": list(overall_metric_names),
        "Value": [last_row[metric] for metric in overall_metric_names],
    }
    overall_df = pd.DataFrame(overall_metrics)

    # ------------------ Per-Class Metrics ------------------
    per_class_data = {"Class": [f"Class {class_id}" for class_id in class_ids]}
    for summary_column, source_prefix in per_class_sources:
        per_class_data[summary_column] = [
            last_row[f"{source_prefix} Class {class_id}"] for class_id in class_ids
        ]
    per_class_df = pd.DataFrame(per_class_data)

    # --------------------------------------------------------------------------
    # 3) Save the new Excel files in a Results subfolder of the model directory
    # --------------------------------------------------------------------------
    results_dir = os.path.join(model_dir, "Results")
    os.makedirs(results_dir, exist_ok=True)

    overall_excel_path = os.path.join(results_dir, "Overall_Metrics.xlsx")
    per_class_excel_path = os.path.join(results_dir, "Per_Class_Metrics.xlsx")

    overall_df.to_excel(overall_excel_path, index=False)
    per_class_df.to_excel(per_class_excel_path, index=False)


# Separating All Models Performance Evaluation Sheet

In [None]:
import os
import pandas as pd

# ------------------------------------------------------------------------------
# 1) Location of the merged workbook covering all models
# ------------------------------------------------------------------------------
merged_excel_path = os.path.join("Results", "All_Models_Performance_Evaluation_Metrics.xlsx")

# ------------------------------------------------------------------------------
# 2) Load the workbook
# ------------------------------------------------------------------------------
df = pd.read_excel(merged_excel_path)

# ------------------------------------------------------------------------------
# 3) Replace checkpoint file names with readable model names
# ------------------------------------------------------------------------------
name_mapping = {
    "unet_best_model.pth": "Unet",
    "unetplusplus_best_model.pth": "Unet++",
    "manet_best_model.pth": "MAnet",
    "linknet_best_model.pth": "Linknet",
    "fpn_best_model.pth": "FPN",
    "pspnet_best_model.pth": "PSPNet",
    "pan_best_model.pth": "PAN",
    "deeplabv3_best_model.pth": "DeepLabV3",
    "deeplabv3plus_best_model.pth": "DeepLabV3+",
    "upernet_best_model.pth": "UPerNet",
    "segformer_best_model.pth": "Segformer",
}

# Names without an entry in the mapping pass through unchanged.
df["Model Name"] = df["Model Name"].map(
    lambda original_name: name_mapping.get(original_name, original_name)
)

# ------------------------------------------------------------------------------
# 4) Overall metrics: one row per model, aggregate columns only
# ------------------------------------------------------------------------------
overall_columns = ["Model Name"] + [
    "Accuracy",
    "Precision",
    "Recall",
    "F1 Score",
    "Mean IoU",
    "Weighted IoU",
    "Frequency Weighted IoU",
    "Mean Dice",
    "Mean Jaccard",
]
overall_df = df.loc[:, overall_columns].copy()

# ------------------------------------------------------------------------------
# 5) Per-class metrics: "<metric> Class <n>" for each metric and class 0-2
# ------------------------------------------------------------------------------
per_class_metric_prefixes = (
    "Accuracy",
    "Precision",
    "Recall",
    "F1 Score",
    "IoU",
    "Dice Coefficient",
    "Jaccard Index",
)
per_class_columns = ["Model Name"] + [
    f"{prefix} Class {class_id}"
    for prefix in per_class_metric_prefixes
    for class_id in range(3)
]
per_class_df = df.loc[:, per_class_columns].copy()

# ------------------------------------------------------------------------------
# 6) Write both tables to the Results folder (existing files are overwritten)
# ------------------------------------------------------------------------------
overall_excel_path = os.path.join("Results", "Overall_Metrics.xlsx")
per_class_excel_path = os.path.join("Results", "Per_Class_Metrics.xlsx")

for table, destination in (
    (overall_df, overall_excel_path),
    (per_class_df, per_class_excel_path),
):
    table.to_excel(destination, index=False)

print(f"Saved overall metrics to: {overall_excel_path}")
print(f"Saved per-class metrics to: {per_class_excel_path}")
