Imports:

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

BCEWithLogitsLoss

In [2]:
# Binary cross-entropy that takes raw logits as input (the sigmoid is applied
# inside the loss), using the constructor's default settings.
bce_loss_fn = nn.BCEWithLogitsLoss()

Dice Loss:

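The Dice coefficient measures overlap between prediction and ground truth:

$$\text{Dice} = \frac{2\sum_i p_i t_i + \epsilon}{\sum_i p_i + \sum_i t_i + \epsilon}, \qquad \text{DiceLoss} = 1 - \text{Dice}$$
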
In [3]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    The Dice coefficient is evaluated once over all elements of the batch
    (every sample is flattened into a single vector) and the loss is
    ``1 - dice``.

    Args:
        smooth: Small constant added to both numerator and denominator so
            the ratio stays defined when prediction and target are both
            empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return the scalar Dice loss.

        Args:
            logits: Raw output from the model (no sigmoid applied),
                e.g. shape (B, 1, H, W).
            targets: Ground-truth mask with the same number of elements
                as ``logits``, e.g. shape (B, 1, H, W).

        Returns:
            A 0-dim tensor equal to ``1 - dice``.
        """
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(logits)

        # Flatten with reshape rather than view: view raises on
        # non-contiguous tensors (permuted / transposed / sliced masks),
        # whereas reshape copies only when it has to.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        # Soft intersection = sum(p * t)
        intersection = (probs * targets).sum()

        # Dice coefficient, smoothed to avoid 0 / 0 on empty masks
        dice = (2 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )

        return 1 - dice

Combined Loss (BCE + Dice)

- BCE handles pixel-wise classification
- Dice handles region overlap
- Combining the two typically gives better IoU than using either loss alone

In [4]:
class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits loss and Dice loss.

    Both terms are computed from the same raw logits. With the default
    weights the result is simply ``bce + dice``.

    Args:
        bce_weight: Multiplier applied to the BCE term.
        dice_weight: Multiplier applied to the Dice term.
    """

    def __init__(self, bce_weight=1.0, dice_weight=1.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, logits, targets):
        """Return the combined scalar loss.

        Args:
            logits: Raw output from the model (no sigmoid applied).
            targets: Ground-truth mask with the same shape as ``logits``.
                Integer or boolean masks are accepted.

        Returns:
            A 0-dim tensor: ``bce_weight * bce + dice_weight * dice``.
        """
        # BCEWithLogitsLoss expects floating-point targets; masks are often
        # loaded as uint8 / long / bool, so align them with the logits first.
        targets = targets.to(dtype=logits.dtype)

        loss_bce = self.bce(logits, targets)
        loss_dice = self.dice(logits, targets)
        return self.bce_weight * loss_bce + self.dice_weight * loss_dice

Thresholding Utility:

The model outputs raw logits, so binary predictions are obtained as:

    logits → sigmoid → threshold at 0.5 (default)

In [5]:
def apply_threshold(logits, threshold=0.5):
    """
    Binarize raw logits.

    A sigmoid maps the logits to probabilities; every element strictly
    greater than `threshold` becomes 1.0 and all others become 0.0.
    The result is a float32 tensor with the same shape as `logits`.
    """
    above_cutoff = logits.sigmoid().gt(threshold)
    return above_cutoff.to(torch.float32)

IoU Metric (Intersection over Union):

In [6]:
def compute_iou(preds, targets, threshold=0.5, eps=1e-6):
    """
    Intersection over Union between thresholded predictions and targets.

    The score is computed once over all elements (the whole batch is
    flattened together), not averaged per sample.

    preds: raw logits (B,1,H,W)
    targets: ground truth (B,1,H,W)
    threshold: probability cut-off used to binarize the predictions
    eps: smoothing constant; keeps the ratio defined when the union is empty

    Returns the IoU as a Python float.
    """
    preds = apply_threshold(preds, threshold)

    # reshape instead of view: view raises on non-contiguous tensors
    # (permuted / transposed / sliced inputs), reshape handles them.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    intersection = (preds * targets).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = preds.sum() + targets.sum() - intersection

    iou = (intersection + eps) / (union + eps)
    return iou.item()

Pixel Accuracy:

Pixel accuracy = (number of correctly classified pixels) / (total number of pixels)

In [None]:
def compute_pixel_accuracy(preds, targets, threshold=0.5):
    """
    Fraction of elements whose thresholded prediction equals the target.

    preds: raw logits; binarized via sigmoid + `threshold`
    targets: ground truth mask with the same number of elements as `preds`
    threshold: probability cut-off used to binarize the predictions

    Returns the accuracy as a Python float.
    """
    preds = apply_threshold(preds, threshold)

    # reshape instead of view: view raises on non-contiguous tensors
    # (permuted / transposed / sliced inputs), reshape handles them.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    correct = (preds == targets).float().sum()
    total = preds.numel()

    return (correct / total).item()