In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [32]:
# Demo sizes: N is the number of samples in the batch, C the number of classes.
N = 16
C = 5

In [33]:
# Random (N, C) scores drawn uniformly from [0, 1); used as the class scores fed to the losses.
pred = torch.rand(N,C)
pred

tensor([[0.2535, 0.5691, 0.5391, 0.4035, 0.9875],
        [0.9943, 0.6627, 0.6124, 0.3754, 0.1337],
        [0.6902, 0.9178, 0.7491, 0.6236, 0.6748],
        [0.3196, 0.9011, 0.0685, 0.7702, 0.8129],
        [0.9854, 0.6996, 0.9426, 0.9156, 0.6943],
        [0.3129, 0.1314, 0.3183, 0.7093, 0.7391],
        [0.6043, 0.5104, 0.6273, 0.3179, 0.6112],
        [0.1817, 0.1793, 0.2405, 0.0728, 0.8850],
        [0.1469, 0.1037, 0.0426, 0.2138, 0.5344],
        [0.1279, 0.6379, 0.3452, 0.2874, 0.8062],
        [0.3860, 0.8059, 0.5053, 0.4707, 0.8393],
        [0.0996, 0.6392, 0.3099, 0.4323, 0.1879],
        [0.0845, 0.9669, 0.9097, 0.3467, 0.9879],
        [0.1883, 0.6860, 0.5933, 0.8607, 0.2761],
        [0.2217, 0.9066, 0.8370, 0.0201, 0.6622],
        [0.5075, 0.5571, 0.3007, 0.2858, 0.0426]])

In [35]:
# Random integer class labels in [0, C), one per sample (class indices, not one-hot).
target = torch.randint(0,C,(N,))
# target = F.one_hot(target,num_classes=5)
target

tensor([1, 1, 3, 2, 3, 0, 2, 1, 0, 0, 4, 0, 2, 2, 4, 3])

## self_defined focal loss(classification)

In [89]:
class focal_loss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t)**gamma * (-log p_t)``.

    ``p_t`` is the softmax probability assigned to the true class.

    Args:
        alpha: per-class weights of shape ``(C,)``. They are forwarded to
            ``F.nll_loss`` as ``weight``, which needs one entry per class;
            a single number is not broadcast.
        gamma: focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'`` reduce the per-sample losses;
            any other value returns the unreduced per-sample tensor.

    Shape:
        - pred: ``(N, C)`` raw, unnormalized class scores.
        - target: ``(N,)`` integer class indices.
    """

    def __init__(self, alpha:torch.Tensor, gamma:float,reduction='none'):
        super().__init__()
        self.gamma = gamma
        # Registered as a buffer so that ``module.to(device)`` / ``.double()``
        # carries the weights along with the module. ``persistent=False``
        # keeps the state_dict exactly as it was (no extra key).
        self.register_buffer('alpha', alpha, persistent=False)
        self.reduction = reduction

    def forward(self,pred:torch.Tensor, target:torch.Tensor):
        """Return the focal loss of ``pred`` against ``target``."""
        log_prob = F.log_softmax(pred, dim=1)  # dim=-1 is the same for 2-D input
        prob = log_prob.exp()

        # Weighted cross entropy per sample: -alpha_t * log(p_t).
        # nll_loss already carries the minus sign, so none is added below.
        ce_loss = F.nll_loss(log_prob, target, weight=self.alpha, reduction='none')

        # Pick the true-class probability of every row; the index is created
        # on the same device as the scores.
        rows = torch.arange(pred.shape[0], device=pred.device)
        pt = prob[rows, target]
        focal_weight = (1 - pt).pow(self.gamma)

        loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

In [29]:
# Uniform per-class alpha of 0.2 for the 5 classes, gamma = 2; reduction keeps its default ('none').
fl = focal_loss(torch.tensor([0.2]*5),2)

In [36]:
# Per-sample losses from the self-defined implementation.
my_result = fl(pred,target)
my_result

tensor([0.2091, 0.1910, 0.2324, 0.3391, 0.1922, 0.2431, 0.1862, 0.2485, 0.2240,
        0.2877, 0.1565, 0.2657, 0.1639, 0.1966, 0.1888, 0.2221])

## github focal loss(classification) <br>
https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py#L36

In [95]:
class FocalLoss(nn.Module):
    """ Focal Loss, as described in https://arxiv.org/abs/1708.02002.
    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Args:
        alpha: optional per-class weights of shape (C,), forwarded to
            ``nn.NLLLoss`` as ``weight``. ``None`` means unweighted.
        gamma: focusing exponent applied to ``(1 - pt)``.
        reduction: one of ``'mean'``, ``'sum'``, ``'none'``.
        ignore_index: label value whose positions contribute no loss.
            With ``'none'`` those positions are 0 in the output; with
            ``'mean'`` the average is taken over the remaining positions
            only (0 if every position is ignored).

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
        - output: scalar for 'mean'/'sum'; for 'none' a flat tensor with
          one entry per row of the flattened (rows, C) scores.

    Raises:
        ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
    """

    def __init__(self,
                 alpha: 'torch.Tensor | None' = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):

        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f'{k}={v}' for k, v in zip(arg_keys, arg_vals)]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.ndim > 2:
            # x:(N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # y:(N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (not view) so non-contiguous label tensors work too.
            y = y.reshape(-1)

        # Weighted cross entropy term: -alpha * log(pt). alpha is applied
        # inside self.nll_loss through its `weight` argument, and ignored
        # positions already come back as 0 from it.
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # Ignored labels (e.g. -100) are not valid column indices, so swap
        # them for class 0 before indexing. Their loss is 0 anyway because
        # `ce` is 0 at those positions.
        valid = y != self.ignore_index
        safe_y = y.masked_fill(~valid, 0)

        # Take the true-class column from each row (one row per sample).
        # log_p is 2-D (rows, C) and log_pt is 1-D (rows,). This selection
        # is what makes the loss multi-class and not multi-label: exactly
        # one class is picked per row. Compared with the binary case, pt is
        # simply chosen among C classes instead of 2.
        all_rows = torch.arange(len(x), device=x.device)
        log_pt = log_p[all_rows, safe_y]

        # Focal term: (1 - pt)^gamma. In the multi-class case there is no
        # implicit background class, so pt is just the true-class
        # probability. In the binary case pt = p when GT == 1 and
        # pt = 1 - p when GT == 0.
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma

        # The full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            kept = loss[valid]
            # An empty selection would average to NaN; its sum is a clean 0
            # that is still attached to the autograd graph.
            loss = kept.mean() if kept.numel() > 0 else kept.sum()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss


In [38]:
# Same settings as the self-defined loss: uniform alpha of 0.2, gamma = 2, no reduction.
fl_github = FocalLoss(torch.tensor([0.2]*5),2,reduction='none')

In [39]:
# Per-sample losses from the GitHub implementation, on the same pred/target.
github_result = fl_github(pred,target)
github_result

tensor([0.2091, 0.1910, 0.2324, 0.3391, 0.1922, 0.2431, 0.1862, 0.2485, 0.2240,
        0.2877, 0.1565, 0.2657, 0.1639, 0.1966, 0.1888, 0.2221])

## 结果比较

In [40]:
# Element-wise exact equality check between the two per-sample loss tensors.
my_result == github_result

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

## seg-focal-loss

In [261]:
# 二分类
class Focal_loss(nn.Module):
    """Binary focal loss computed on probabilities.

    Per element the loss term is ``|t - p|**pow * log(|1 - t - p| + eps)``.
    For a label ``t`` in {0, 1} this is ``(1 - p)**pow * log(p)`` on
    positives and ``p**pow * log(1 - p)`` on negatives, so hard examples of
    both kinds are emphasised.

    Args:
        pow: focusing exponent (gamma in the focal-loss paper).

    Shape:
        - pred: ``(b, c, w, h)`` probabilities; the notebook passes sigmoid
          outputs.
        - true: same number of elements per sample as ``pred``, values 0/1.
        - output: scalar.
    """

    def __init__(self,pow=1):
        super(Focal_loss,self).__init__()
        self.power = pow

    def forward(self,pred,true):
        b,c,w,h= pred.size()
        # reshape (not view) so non-contiguous inputs are accepted too.
        p = pred.reshape(b, -1)
        t = true.reshape(b, -1)
        eps = 1e-6  # keeps the log finite when its argument reaches 0

        # Modulating factor |t - p|**power. The absolute value keeps it
        # non-negative for odd exponents as well; for power == 2 it equals
        # (t - p)**2.
        modulator = (t - p).abs().pow(self.power)
        # |1 - t - p| is p when t == 1 and 1 - p when t == 0.
        log_pt = torch.log(((1 - t) - p).abs() + eps)

        # Sum over each sample's elements, normalised by the spatial size.
        per_sample = (modulator * log_pt).sum(1) / (w * h)
        # The log terms are negative; negate the batch mean for a positive loss.
        return -per_sample.mean(0)
    
# 多分类
#TO check
class Focal_loss_multi_class(nn.Module):
    """Per-class focal-style term, summed over classes.

    For every class channel the mean of ``(1 - p)**pow * log(p)`` is taken
    over the entries with ``p > 0.5`` (all other entries contribute 0 to
    that mean). The negated sum of the per-class means is returned.

    NOTE(review): ``true`` is accepted but never used. The mask comes from
    thresholding ``probs`` and not from the ground truth, so this is not
    yet a supervised loss. The notebook marks this class as still to be
    checked.

    Args:
        device: device on which the result is returned.
        pow: focusing exponent.

    Shape:
        - probs: ``(b, C, ...)`` probabilities, class axis at dim 1.
        - output: float32 scalar on ``device``.
    """

    def __init__(self,device,pow=2):
        super(Focal_loss_multi_class,self).__init__()
        self.power = pow
        self.device = device

    def forward(self,probs,true):
        # Entries at or below the threshold are excluded. They are replaced
        # by 1 before the log so they yield (1 - 1)**pow * log(1) = 0.
        # Multiplying a 0 mask by log(0) = -inf would give NaN instead.
        # NaN inputs fail the `<=` test and therefore still propagate.
        excluded = probs <= 0.5
        safe_probs = torch.where(excluded, torch.ones_like(probs), probs)
        term = (1 - safe_probs).pow(self.power) * safe_probs.log()

        # Mean over everything except the class axis: (b, C, ...) -> (C, -1).
        per_class = term.transpose(0, 1).flatten(1).mean(dim=1)
        per_class = per_class.to(device=self.device, dtype=torch.float32)

        return -per_class.sum()

修改前的Focal loss是github上找的。<br>
它的计算方式是错误的，因为它的计算方法为 t*(1-p).pow(self.power)，<br>
因为这个“t*”的存在，所以最终体现“难例挖掘”的"(1-p).pow(self.power)"只会含有正样本，<br>
而根据原论文的意思(由pt的定义易懂)，难例可以是正样本，也可以是负样本，所以这里的写法是错误的！<br>
<br>
并且，同时也要修改p.log()，因为此时对于正样本是p.log(),对于负样本则是(1-p).log()

In [253]:
# Synthetic binary-segmentation batch: 8 single-channel 512x512 logit maps,
# their sigmoid probabilities, and random 0/1 ground-truth masks.
out = torch.randn((8,1,512,512),dtype=torch.float)
probs = torch.sigmoid(out)
true_mask = torch.randint(0,2,(8,1,512,512))

In [254]:
# Peek at one pixel of the first sample: label, logit, probability.
true_mask[0,...,0,0],out[0,...,0,0],probs[0,...,0,0]

(tensor([0]), tensor([0.4765]), tensor([0.6169]))

In [255]:
# Reference computation of the binary focal loss with exponent 2.
lambdaa = 1.0
# loss_focal_batch = ((((true_mask-probs)**2)*torch.log(probs)).mean(axis=(1,2,3))) * lambdaa
# Positive pixels (label 1): (1 - p)^2 * log(p).
loss1 = (true_mask*(1-probs)**2)*torch.log(probs)
# Negative pixels (label 0): p^2 * log(1 - p).
loss2 = (1-true_mask)*(probs**2)*torch.log(1-probs)
# Per-sample mean over the channel and spatial axes, scaled by lambdaa.
loss_focal_batch = (loss1+loss2).mean(axis=(1,2,3)) * lambdaa
print(len(loss_focal_batch),loss_focal_batch)
# The log terms are negative, so negate the batch mean to get a positive loss.
loss_focal = -loss_focal_batch.mean()
loss_focal 

8 tensor([-0.3471, -0.3457, -0.3456, -0.3461, -0.3454, -0.3459, -0.3471, -0.3461])


tensor(0.3461)

上面这个计算结果作为 GT（参考值），用来检验下面 `Focal_loss` 类的输出。

---------------------------------------

In [262]:
# Binary focal loss module constructed with pow = 2.
f1 = Focal_loss(2)

In [263]:
# Evaluate the module on the same probabilities and masks as the reference computation.
f1(probs,true_mask)

tensor(0.3461)

经过多次计算，两个 focal loss 输出值的差值始终不超过 0.0004，因此可以认为两者等价。<br>
总结起来，二分类 focal loss 的核心公式为（最终的 loss 还需对其 batch 均值取负号）：

In [119]:
# Core binary focal-loss terms with exponent 2.
# Positive pixels (label 1): (1 - p)^2 * log(p).
loss1 = (true_mask*(1-probs)**2)*torch.log(probs)
# Negative pixels (label 0): p^2 * log(1 - p).
loss2 = (1-true_mask)*(probs**2)*torch.log(1-probs)
# Per-sample mean over the channel and spatial axes, scaled by lambdaa.
loss_focal_batch = (loss1+loss2).mean(axis=(1,2,3)) * lambdaa

--------------------

TODO

多分类focal loss<br>
沿用上面的数据probs和true_mask

In [250]:
# Random 4-class probability maps and labels for a smoke test.
a = torch.sigmoid(torch.randint(-20,200,(3,4,8,8)).to(torch.float))
true = torch.randint(0,4,(3,4,8,8))
# No cell in this notebook defines `device`, so the original call raised a
# NameError on a fresh kernel. Use the device the inputs live on instead.
mlfl1 = Focal_loss_multi_class(a.device,2)
out1 = mlfl1(a,true)
out1

tensor(0.0002)