In [None]:
from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from time import time
from typing import Optional
import torch
import segmentation_models_pytorch

class FocalLossVectorised(segmentation_models_pytorch.losses.FocalLoss):
    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
    ):
        
        super().__init__(mode = mode,alpha = alpha,gamma = gamma, ignore_index = ignore_index,reduction = reduction,
                         normalized = normalized,reduced_threshold = reduced_threshold)
    
    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            y_true = y_true.view(-1)
            y_pred = y_pred.view(-1)

            if self.ignore_index is not None:
                # Filter predictions with ignore label from loss computation
                not_ignored = y_true != self.ignore_index
                y_pred = y_pred[not_ignored]
                y_true = y_true[not_ignored]

            loss = self.focal_loss_fn(y_pred, y_true)

        elif self.mode == MULTICLASS_MODE:
            num_classes = y_pred.size(1)

            if self.ignore_index is not None:
                y_true[y_true == self.ignore_index] = num_classes
                y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)
                y_true_one_hot = y_true_one_hot[ : , : , : , : -1]

            else:     
                y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes)

            y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2))
            loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot)

        return loss

In [2]:
num_classes = 20
batch_size = 128
resolution = 512
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

In [3]:
vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass',ignore_index = num_classes)
loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass',ignore_index = num_classes)

In [4]:
predictions = torch.randn((batch_size,num_classes,resolution,resolution)).to(device = device)
labels = torch.randint(low = 0,high = num_classes+1,size = (batch_size,resolution,resolution)).to(device = device)

In [5]:
def benchmark(function,predictions,labels,benchmark_iterations = 100):
    start_time = time()

    for _ in range(benchmark_iterations):
        loss = function(predictions,labels)

    end_time = time()

    average_time_taken = (end_time - start_time) / (benchmark_iterations)

    print(f"Average time taken by function {function} is {average_time_taken} seconds")

In [6]:
benchmark(loss_fn,predictions,labels)

Average time taken by function FocalLoss() is 0.3390256547927856 seconds


In [7]:
benchmark(vectorised_loss_fn,predictions,labels)

Average time taken by function FocalLossVectorised() is 0.11771584510803222 seconds


##### CHECKING THAT OUTPUT OF NEW CLASS IS CONSISTENT WITH THE OLD ONE

In [8]:
output_from_vectorised_fn = vectorised_loss_fn(predictions,labels)
output_from_old_fn = loss_fn(predictions,labels)

assert torch.allclose(output_from_vectorised_fn,output_from_old_fn)