In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [44]:
class FocalLoss(nn.Module):
    """ Focal Loss: https://arxiv.org/abs/1708.02002
    
    Parameters
    ----------
    gamma: int, float
        It modulates the loss for each classes.
    
    size_average: bool
        If True, it returns loss.mean(), else it returns loss.sum().
    
    Returns
    -------
    loss: torch.Tensor
        Calculated FocalLoss.
    """
    
    def __init__(self, gamma=2.0, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
    
    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        pt = F.log_softmax(y_pred, dim=1).exp()
        loss = -1 * ((1 - pt)**gamma * torch.log(pt))

        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [45]:
criterion = FocalLoss()
criterion(y_true, y_pred)

tensor(0.2879)

In [46]:
y_true = torch.Tensor([0, 0, 1])
y_pred = torch.Tensor([
    [1.0, 0.0],
    [0.1, 0.9],
    [0.3, 0.7]
], requires_grad=True)

TypeError: new() received an invalid combination of arguments - got (list, requires_grad=bool), but expected one of:
 * (torch.device device)
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, torch.device device)
      didn't match because some of the keywords were incorrect: requires_grad
 * (object data, torch.device device)
      didn't match because some of the keywords were incorrect: requires_grad


In [41]:
pm = torch.nn.functional.log_softmax(y_pred, dim=1).exp()

In [42]:
gamma = 2.0

In [43]:
-1 * ((1 - pm)**gamma * torch.log(pm)).mean()

tensor(0.2879)