## FocalLoss
- $\mathrm{FL}\left(p_{\mathrm{t}}\right)=-\left(1-p_{\mathrm{t}}\right)^{\gamma} \log \left(p_{\mathrm{t}}\right)$

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [30]:
class FocalLoss(nn.Module):

    def __init__(self, weight=None, reduction='mean', gamma=2, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

In [31]:
# 假设词汇表的大小为3， 语料包含两个单词"2 0"
y = [2, 0, 1, 1]
# 假设模型对两个单词预测时，产生的logit分别是[2.0, -1.0, 3.0]和[1.0, 0.0, -0.5]
y_logits = [[2.0, -1.0, 3.0], [1.0, 0.0, -0.5], [2.0, 1.0, -0.5], [1, 8, 2]]

In [32]:
F.cross_entropy(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))

tensor(0.5415)

In [33]:
FocalLoss(weight=torch.Tensor([0.25, 0.75, 0.75]), gamma=2)(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))

tensor(0.1015)

In [34]:
class focal_loss(nn.Module):
    """https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py"""
    def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):
        """
        focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
        步骤详细的实现了 focal_loss损失函数.
        :param alpha:  阿尔法α, 类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25
        :param gamma:  伽马γ, 难易分样本调节参数. retainnet中设置为2
        :param num_classes:     类别数量
        :param size_average:    损失计算方式,默认取均值
        """
        super(focal_loss,self).__init__()
        self.size_average = size_average
        # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
        if isinstance(alpha, list):
            assert len(alpha) == num_classes
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha < 1   # 如果α为一个常数,则降低第一类的影响
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
            self.alpha[1:] += (1-alpha) 
        self.gamma = gamma

    def forward(self, preds, labels):
        """
        focal_loss损失计算
        :param preds:   预测类别. size:[B,N,C] or [B,C] 分别对应与检测与分类任务, B 批次, N检测框数, C类别数
        :param labels:  实际类别. size:[B,N] or [B]
        :return:
        """
        # assert preds.dim()==2 and labels.dim()==1
        preds = preds.view(-1, preds.size(-1))
        self.alpha = self.alpha.to(preds.device)
        preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
        preds_softmax = torch.exp(preds_logsoft)    # softmax
        # 这部分实现nll_loss ( crossempty = log_softmax + nll )
        preds_softmax = preds_softmax.gather(1, labels.view(-1,1))   
        preds_logsoft = preds_logsoft.gather(1, labels.view(-1,1))
        self.alpha = self.alpha.gather(0, labels.view(-1))
        print("alpha", self.alpha)
        # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ
        loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)  
        print("loss:", loss)
        loss = torch.mul(self.alpha, loss.t())
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

In [35]:
fl = focal_loss()

In [36]:
fl(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))

alpha tensor([0.7500, 0.2500, 0.7500, 0.7500])
loss: tensor([[2.5347e-02],
        [6.4078e-02],
        [7.6386e-01],
        [3.8641e-08]])


tensor(0.1520)

In [37]:
labels = torch.tensor(y,dtype=torch.int64)
labels

tensor([2, 0, 1, 1])

In [38]:
a = torch.tensor([0.2500, 0.7500, 0.75])
a

tensor([0.2500, 0.7500, 0.7500])

In [39]:
a.gather(0, labels.view(-1))

tensor([0.7500, 0.2500, 0.7500, 0.7500])