In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def focal_loss(input: torch.Tensor, target: torch.Tensor, gamma: float =2.0, reduction: str='mean', eps:float=1e-6):
    """Binary focal loss computed from raw (pre-sigmoid) logits.

    With ``p = sigmoid(input)``, the per-element loss is::

        -( target * (1 - p)**gamma * log(p)
           + (1 - target) * p**gamma * log(1 - p) )

    Args:
        input: Raw logits.
        target: Tensor broadcastable against ``input``; 1 selects the
            positive-class term and 0 the negative-class term.
        gamma: Focusing exponent applied to the modulating factors
            ``(1 - p)`` and ``p``.
        reduction: ``'none'`` (the legacy spelling ``'None'`` is still
            accepted) returns the per-element loss, ``'mean'`` its mean,
            ``'sum'`` its sum.
        eps: Kept only for backward compatibility. The log-probabilities are
            now evaluated in log-space, so no additive epsilon is needed and
            this value is ignored.

    Returns:
        The per-element loss tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        NotImplementedError: If ``reduction`` is not one of the modes above.
    """
    # log(p) and log(1 - p) = log(sigmoid(-x)), both taken straight from
    # logsigmoid. Going through exp() and back through log(. + eps) would cap
    # the loss at -log(eps) for confidently wrong predictions and lose
    # precision when p is close to 0 or 1.
    log_prob = F.logsigmoid(input)
    log_not_prob = F.logsigmoid(-input)

    pos_weight = torch.pow(torch.exp(log_not_prob), gamma)  # (1 - p) ** gamma
    neg_weight = torch.pow(torch.exp(log_prob), gamma)      # p ** gamma

    focal = -(target * pos_weight * log_prob + (1 - target) * neg_weight * log_not_prob)

    # 'none' matches the spelling used by torch.nn losses; 'None' is kept so
    # existing callers of the old spelling keep working.
    if reduction in ('none', 'None'):
        return focal
    elif reduction == 'mean':
        return torch.mean(focal)
    elif reduction == 'sum':
        return torch.sum(focal)
    else:
        raise NotImplementedError("Invalid reduction mode: {}".format(reduction))

In [3]:
class BinaryFocalLoss(nn.Module):
    """``nn.Module`` wrapper that evaluates ``focal_loss`` on logits and targets.

    Attributes:
        gamma: Focusing exponent forwarded to ``focal_loss``.
        reduction: Reduction mode forwarded to ``focal_loss``.
        eps: Fixed at ``1e-10`` and forwarded to ``focal_loss``.
    """

    def __init__(self, gamma: float = 2.0, reduction: str = 'mean') -> None:
        """Store the loss hyper-parameters on the module."""
        super().__init__()
        # Exposed as an instance attribute so the loss object carries a name.
        self.__name__ = 'BinaryFocalLoss'
        self.eps = 1e-10
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ``focal_loss(input, target)`` using the stored settings."""
        loss = focal_loss(
            input,
            target,
            gamma=self.gamma,
            reduction=self.reduction,
            eps=self.eps,
        )
        return loss

In [4]:
# Instantiate the focal loss with its defaults (gamma=2.0, 'mean' reduction)
# and PyTorch's BCE-with-logits loss so both can be evaluated on the same data.
bfloss = BinaryFocalLoss()
bceloss = nn.BCEWithLogitsLoss()

In [5]:
# Random binary target of shape (4, 1, 224, 224), cast to float32.
# NOTE(review): no seed is set, so the tensor differs on every run.
target = torch.randint(0, 2, size=(4, 1, 224, 224)).type(torch.float32)
# Constant logit per class: -0.3 where target == 1, -0.8 where target == 0.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = target*(-0.3)+(1-target)*(-0.8)

In [6]:
# Display the randomly generated binary target tensor.
target

tensor([[[[1., 0., 1.,  ..., 0., 0., 1.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [0., 1., 1.,  ..., 1., 0., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 1.],
          [0., 0., 1.,  ..., 0., 0., 1.]]],


        [[[0., 0., 1.,  ..., 1., 1., 0.],
          [1., 1., 0.,  ..., 1., 1., 0.],
          [0., 1., 0.,  ..., 0., 0., 1.],
          ...,
          [0., 1., 1.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 0., 0.],
          [0., 1., 1.,  ..., 1., 0., 1.]]],


        [[[1., 0., 0.,  ..., 1., 1., 1.],
          [1., 0., 0.,  ..., 1., 0., 0.],
          [1., 0., 0.,  ..., 0., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 0., 1., 0.],
          [1., 1., 1.,  ..., 1., 0., 1.],
          [0., 0., 1.,  ..., 0., 1., 0.]]],


        [[[1., 0., 0.,  ..., 1., 0., 0.],
          [1., 1., 1.,  ..., 0., 1., 1.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 1., 1., 0.],
    

In [7]:
# Probabilities implied by the logits: one constant value per target class.
torch.sigmoid(input)

tensor([[[[0.4256, 0.3100, 0.4256,  ..., 0.3100, 0.3100, 0.4256],
          [0.4256, 0.4256, 0.4256,  ..., 0.3100, 0.3100, 0.3100],
          [0.3100, 0.4256, 0.4256,  ..., 0.4256, 0.3100, 0.4256],
          ...,
          [0.4256, 0.4256, 0.4256,  ..., 0.4256, 0.3100, 0.3100],
          [0.3100, 0.3100, 0.3100,  ..., 0.3100, 0.3100, 0.4256],
          [0.3100, 0.3100, 0.4256,  ..., 0.3100, 0.3100, 0.4256]]],


        [[[0.3100, 0.3100, 0.4256,  ..., 0.4256, 0.4256, 0.3100],
          [0.4256, 0.4256, 0.3100,  ..., 0.4256, 0.4256, 0.3100],
          [0.3100, 0.4256, 0.3100,  ..., 0.3100, 0.3100, 0.4256],
          ...,
          [0.3100, 0.4256, 0.4256,  ..., 0.3100, 0.3100, 0.4256],
          [0.3100, 0.4256, 0.4256,  ..., 0.3100, 0.3100, 0.3100],
          [0.3100, 0.4256, 0.4256,  ..., 0.4256, 0.3100, 0.4256]]],


        [[[0.4256, 0.3100, 0.3100,  ..., 0.4256, 0.4256, 0.4256],
          [0.4256, 0.3100, 0.3100,  ..., 0.4256, 0.3100, 0.3100],
          [0.4256, 0.3100, 0.3100,  ..

In [8]:
# Mean focal loss for the first set of logits (-0.3 / -0.8).
bfloss(input, target)

tensor(0.1585)

In [9]:
# Mean BCE-with-logits loss on the same logits, for comparison.
bceloss(input, target)

tensor(0.6122)

In [10]:
# Rebuild the logits with the negative-class value moved from -0.8 to -0.7;
# the positive-class logit stays at -0.3.
input = target*(-0.3)+(1-target)*(-0.7)

In [11]:
# Mean focal loss for the updated logits (-0.3 / -0.7).
bfloss(input, target)

tensor(0.1629)

In [12]:
# Mean BCE-with-logits loss for the updated logits, for comparison.
bceloss(input, target)

tensor(0.6283)