## imports

In [25]:
import torch # '2.2.1+cu121'
import torch.nn.functional as F #' 2.2.1+cu121'
import torch.nn as nn #' 2.2.1+cu121'

## definitions


$$
focal\_loss = - (1-p_t)^{\gamma} * log(p_t)
$$

$$
focal\_loss = (1-p_t)^{\gamma} * CE(p_t)
$$


## operations

$$
(1-p_t)^{\gamma}
$$

In [18]:
# case 1

GAMMA = 2
focal_term = (1 - (-1))**GAMMA

print(f"focal term: {focal_term}")

CE_Loss = F.binary_cross_entropy_with_logits(torch.tensor([-1.0]), torch.tensor([1.0]))

print(f"CE loss: {CE_Loss}")

final_loss = focal_term * CE_Loss

print(f"final loss: {final_loss}")

focal term: 4
CE loss: 1.31326162815094
final loss: 5.25304651260376


In [19]:
#case 2

GAMMA = 2
focal_term = (1 - 0.2)**GAMMA

print(f"focal term: {focal_term}")

CE_Loss = F.binary_cross_entropy_with_logits(torch.tensor([0.2]), torch.tensor([1.0]))

print(f"CE loss: {CE_Loss}")

final_loss = focal_term * CE_Loss

print(f"final loss: {final_loss}")

focal term: 0.6400000000000001
CE loss: 0.5981389284133911
final loss: 0.3828088939189911


**this loss rewards the cases where it is more confident, and where it is wrong, it penalises it heavily.**

## understanding alpha

$$
\alpha(1-p_t)^{\gamma}
$$

In [30]:
# one hot encoding imitation
targets = torch.zeros(7)
targets[4] = 1
targets

tensor([0., 0., 0., 0., 1., 0., 0.])

In [31]:
alpha = 0.25
targets * alpha + (1 - alpha) * (1 - targets)

tensor([0.7500, 0.7500, 0.7500, 0.7500, 0.2500, 0.7500, 0.7500])

In [36]:
alpha = 0.8
targets * alpha + (1 - alpha) * (1 - targets)

tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.8000, 0.2000, 0.2000])

In [33]:
alpha = 0.5
targets * alpha + (1 - alpha) * (1 - targets)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])

## torch implementation

In [48]:

class FocalLoss(nn.Module):
    '''
    Focal Loss - https://arxiv.org/abs/1708.02002
    '''

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_logits, target):
        # compute CE
        ce = F.binary_cross_entropy_with_logits(pred_logits, target, reduction='none')
        
        #alpha = target * self.alpha + (1 - target) * (1 - self.alpha)

        # if selector
        pred = pred_logits.sigmoid() # activation
        pt = torch.where(target == 1, pred, 1 - pred) 

        return alpha * ( (1 - pt) ** self.gamma ) * ce
    

In [49]:
loss = FocalLoss(alpha=1, gamma=2)

result = loss(torch.tensor([0.2]), torch.tensor([1.0]))
result

tensor([0.0970])