The code below is my extended implementation of sigmoid focal loss.

`torchvision`'s (v0.11.3) `sigmoid_focal_loss` implementation reduces the per-pixel losses with a simple mean or sum across all pixels. This is not ideal with severe foreground/background imbalance. There IS an `alpha` parameter that scales foreground pixel losses by `alpha` and background pixel losses by `1 - alpha`, but this approach requires very high `alpha` values to combat imbalance. With such a high `alpha`, the loss is driven by a handful of foreground pixels, and their number varies from image to image (think going from 20 objects to 5 objects). This results in noisy loss values.
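
For reference, the per-pixel loss and the mean reduction described above can be written as follows. This is the usual focal loss definition as I understand it, not a transcription of the `torchvision` source; the symbols $x_i$ (logit), $p_i$, $y_i$ (label) and $N$ (number of pixels) are notation introduced here for illustration.

$$
p_i = \sigma(x_i), \qquad
p_{t,i} = \begin{cases} p_i & y_i = 1 \\ 1 - p_i & y_i = 0 \end{cases}, \qquad
\alpha_{t,i} = \begin{cases} \alpha & y_i = 1 \\ 1 - \alpha & y_i = 0 \end{cases}
$$

$$
\ell_i = -\alpha_{t,i}\,(1 - p_{t,i})^{\gamma}\,\log p_{t,i}, \qquad
L_{\text{mean}} = \frac{1}{N} \sum_{i=1}^{N} \ell_i
$$

Because every pixel shares the same denominator $N$, the foreground share of $L_{\text{mean}}$ depends on how many foreground pixels the image happens to contain.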

Applying a class weight AFTER averaging pixel losses by class should provide more effective and stable loss values than applying a class weight BEFORE averaging.
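
A minimal sketch of the difference, in plain Python and with made-up numbers: every foreground pixel is assumed to have loss 1.0 and every background pixel loss 0.1, and the focal modulation is ignored. The helper names `weight_before` and `weight_after` are hypothetical and only serve this illustration.

```python
def weight_before(fg, bg, alpha):
    # Weight each pixel loss by its class weight, then average over all pixels.
    weighted = [alpha * x for x in fg] + [(1 - alpha) * x for x in bg]
    return sum(weighted) / len(weighted)


def weight_after(fg, bg, beta):
    # Average each class separately (0 if the class is absent), then combine.
    fg_mean = sum(fg) / len(fg) if fg else 0.0
    bg_mean = sum(bg) / len(bg) if bg else 0.0
    return beta * fg_mean + (1 - beta) * bg_mean


n_pixels = 10_000
results = {}
for n_fg in (50, 200):                   # e.g. 5 objects vs 20 objects
    fg = [1.0] * n_fg                    # assumed foreground pixel losses
    bg = [0.1] * (n_pixels - n_fg)       # assumed background pixel losses
    results[n_fg] = (weight_before(fg, bg, alpha=0.75),
                     weight_after(fg, bg, beta=0.75))

for n_fg, (before, after) in results.items():
    print(f"n_fg={n_fg:4d}  before={before:.6f}  after={after:.6f}")
```

Under these assumptions the prediction quality is identical in both cases, yet the weight-before-averaging value changes with the foreground count while the weight-after-averaging value stays the same.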

Other loss functions may also benefit from similar treatment.

import torch
from torch import Tensor
from torchvision.ops import sigmoid_focal_loss


def balanced_focal_loss(inputs: Tensor,
                        targets: Tensor,
                        alpha: float = 0.5,
                        beta: float = 0.75,
                        gamma: float = 2) -> Tensor:
    """
    Compute per-sample binary focal loss as
    the weighted sum of the average foreground and average background loss.

    This function wraps torchvision.ops.sigmoid_focal_loss.

    Args:
        inputs:  float tensor of shape (N, ...),
                 the predictions (logits) for each example
        targets: float tensor with the same shape as inputs,
                 the binary classification label for each element in inputs,
                 (0 for the negative class and 1 for the positive class)
        alpha:   weighting factor in range (0,1) to balance
                 positive vs negative pixels
        beta:    weighting factor in range (0,1) to balance
                 positive vs negative pixel averages
        gamma:   exponent of the modulating factor (1 - p_t)
                 to balance easy vs hard examples.

    Returns:
        float tensor of shape (N,), one loss value per sample
    """
    pixel_loss = sigmoid_focal_loss(
        inputs=inputs,
        targets=targets,
        alpha=alpha,
        gamma=gamma,
        reduction='none')

    # mean() of an empty selection is nan (e.g. a sample with no foreground);
    # nansum() on that scalar turns the nan into 0
    fg_loss = torch.stack(
        [loss[t == 1].mean().nansum() for loss, t in zip(pixel_loss, targets)])
    bg_loss = torch.stack(
        [loss[t == 0].mean().nansum() for loss, t in zip(pixel_loss, targets)])

    return beta * fg_loss + (1 - beta) * bg_loss
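
The per-class reduction, including the empty-class case, can be checked in isolation. The sketch below uses only `torch`; the `pixel_loss` values are made up and stand in for the output of `sigmoid_focal_loss(..., reduction='none')`, and `masked_mean` is a hypothetical helper introduced for readability.

```python
import torch

# Made-up per-pixel losses for a batch of two 2x2 "images".
pixel_loss = torch.tensor([[[0.8, 0.2], [0.1, 0.1]],
                           [[0.3, 0.1], [0.2, 0.2]]])
targets = torch.tensor([[[1., 0.], [0., 0.]],    # one foreground pixel
                        [[0., 0.], [0., 0.]]])   # no foreground at all


def masked_mean(loss, mask):
    # mean() over an empty selection gives nan; nansum() maps it to 0.
    return loss[mask].mean().nansum()


# One value per sample, as in the function above.
fg_loss = torch.stack(
    [masked_mean(loss, t == 1) for loss, t in zip(pixel_loss, targets)])
bg_loss = torch.stack(
    [masked_mean(loss, t == 0) for loss, t in zip(pixel_loss, targets)])

print(fg_loss, bg_loss)
```

The second sample has no foreground pixels, so its foreground term is 0 and only the background average contributes to its loss. Both results have shape `(N,)`, so a further `.mean()` or `.sum()` is needed to get a scalar for `backward()`.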