In [4]:
import torch
import torch.nn as nn

# 损失函数

<img src="imgs/loss.png" width="900" height="400" align="bottom">

In [2]:
class DefaultConfig():
    # backbone
    backbone = "darknet19"
    pretrained = True

    # fpn
    fpn_out_channels = 256
    use_p5 = True
    
    # head
    class_num = 20
    use_GN_head = False
    prior = 0.01
    add_centerness = True
    cnt_on_reg = True

    # training
    strides = [8, 16, 32, 64, 128]
    limit_range = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, 999999]]

    # inference
    score_threshold = 0.3
    nms_iou_threshold = 0.2
    max_detection_boxes_num = 150

In [4]:
class LOSS(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        if config is None:
            self.config = DefaultConfig
        else:
            self.config = config

    def forward(self, inputs):
        """
        inputs list
        [0]preds:  ....
        [1]targets : list contains three elements [[batch_size,sum(_h*_w),1],[batch_size,sum(_h*_w),1],[batch_size,sum(_h*_w),4]]
        """
        preds, targets = inputs
        cls_logits, cnt_logits, reg_preds = preds
        cls_targets, cnt_targets, reg_targets = targets
        # mask用来区分正负样本，负样本不计算cnt和reg损失
        mask_pos = (cnt_targets > -1).squeeze(dim=-1)  # [batch_size,sum(_h*_w)]
        cls_loss = compute_cls_loss(cls_logits, cls_targets, mask_pos).mean()  # []
        cnt_loss = compute_cnt_loss(cnt_logits, cnt_targets, mask_pos).mean()
        reg_loss = compute_reg_loss(reg_preds, reg_targets, mask_pos).mean()
        
        if self.config.add_centerness:
            total_loss = cls_loss + cnt_loss + reg_loss
            return cls_loss, cnt_loss, reg_loss, total_loss
        else:
            total_loss = cls_loss + reg_loss + cnt_loss * 0.0
            return cls_loss, cnt_loss, reg_loss, total_loss

## 分类损失

In [7]:
def focal_loss_from_logits(preds, targets, gamma=2.0, alpha=0.25):
    '''
    Args:
    preds: [n,class_num] 
    targets: [n,class_num]
    '''
    # FocalLoss的定义分正负样本，但是这个函数仅使用一个公式便定义了两种情况
    preds = preds.sigmoid()
    pt =  preds * targets + (1.0 - preds) * (1.0 -targets)
    # 这里的w其实就是focal loss里的alpha的两种情况写到了一个表达式里
    w = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    loss = -w * torch.pow((1.0 - pt), gamma) * pt.log()
    return loss.sum()

In [8]:
def compute_cls_loss(preds, targets, mask):
    '''
    Args  
    preds: list contains five level pred [batch_size,class_num,_h,_w]
    targets: [batch_size,sum(_h*_w),1]
    mask: [batch_size,sum(_h*_w)]
    '''
    batch_size = targets.shape[0]
    
    preds_reshape = []
    class_num = preds[0].shape[1]
    mask = mask.unsqueeze(dim=-1)  # unsqueeze() 添加列维度，变成[batch_size,sum(_h*_w),1]
    
    # 计算正样本个数，最后的损失需要除以这个数字
    num_pos = torch.sum(mask, dim=[1, 2]).clamp_(min=1).float()  # [batch_size,]
    
    for pred in preds:
        # 将pred从[batch_size, class_num,_h,_w]形状转换成[batch_size, -1, class_num]
        pred = pred.permute(0, 2, 3, 1)
        pred = torch.reshape(pred, [batch_size, -1, class_num])
        preds_reshape.append(pred)
    preds = torch.cat(preds_reshape, dim=1)  # [batch_size,sum(_h*_w),class_num]
    
    # 此时preds的形状是[batch_size,sum(_h*_w),class_num]，我们需要把target也转换成这个形状
    # 才可以让它们计算损失
    assert preds.shape[:2] == targets.shape[:2]
    loss = []
    for batch_index in range(batch_size):
        pred_pos = preds[batch_index]  # [sum(_h*_w),class_num]
        target_pos = targets[batch_index]  # [sum(_h*_w),1]
        target_pos = (torch.arange(1, class_num + 1, device=target_pos.device)[None,:]
                      .type(torch.float32) == target_pos).float()  # sparse-->onehot
        loss.append(focal_loss_from_logits(pred_pos, target_pos).view(1))
    return torch.cat(loss, dim=0) / num_pos  # [batch_size,]

In [9]:
preds = [torch.ones([2, 5, 4, 4])] * 5  
targets = torch.ones([2, 80, 1])
mask = torch.ones([2, 80], dtype=torch.uint8)
compute_cls_loss(preds, targets, mask)

tensor([2.1113, 2.1113])

In [10]:
batch_size = targets.shape[0]
class_num = preds[0].shape[1]  # = 2 

In [11]:
mask = torch.ones([2, 80], dtype=torch.uint8)
mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)

In [12]:
mask = mask.unsqueeze(dim=-1)
mask.shape

torch.Size([2, 80, 1])

In [13]:
num_pos = torch.sum(mask, dim=[1, 2]).clamp_(min=1).float()  # [batch_size,]  # _ 表示 inplace操作
num_pos

tensor([80., 80.])

In [14]:
preds_reshape = []
for pred in preds:
    pred = pred.permute(0, 2, 3, 1)
    print(pred.shape)
    pred = torch.reshape(pred, [batch_size, -1, class_num])
    print(pred.shape)
    preds_reshape.append(pred)
        
preds = torch.cat(preds_reshape, dim=1)  # [batch_size,sum(_h*_w),class_num]
preds.shape

torch.Size([2, 4, 4, 5])
torch.Size([2, 16, 5])
torch.Size([2, 4, 4, 5])
torch.Size([2, 16, 5])
torch.Size([2, 4, 4, 5])
torch.Size([2, 16, 5])
torch.Size([2, 4, 4, 5])
torch.Size([2, 16, 5])
torch.Size([2, 4, 4, 5])
torch.Size([2, 16, 5])


torch.Size([2, 80, 5])

In [15]:
assert preds.shape[:2] == targets.shape[:2]

In [16]:
pred_pos = preds[0]  # [sum(_h*_w),class_num]
target_pos = targets[0]  # [sum(_h*_w),1]
print(target_pos.device, target_pos.dtype)

cpu torch.float32


In [17]:
tensor = torch.arange(1, class_num + 1, device=target_pos.device)
tensor

tensor([1, 2, 3, 4, 5])

In [18]:
tensor[None,:]

tensor([[1, 2, 3, 4, 5]])

In [19]:
print(tensor[None,:].dtype)

torch.int64


In [20]:
target_pos = (tensor[None,:].type(torch.float32) == target_pos).float()  # sparse-->onehot
target_pos.shape

torch.Size([80, 5])

In [21]:
def focal_loss_from_logits(preds, targets, gamma=2.0, alpha=0.25):
    '''
    Args:
    preds: [n,class_num] 
    targets: [n,class_num]
    '''
    preds = preds.sigmoid()
    pt = preds * targets + (1.0 - preds) * (1.0 - targets)
    w = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    loss = -w * torch.pow((1.0 - pt), gamma) * pt.log()
    return loss.sum()

In [22]:
focal_loss_from_logits(pred_pos, target_pos).view(1)

tensor([168.9016])

In [23]:
torch.cat(loss, dim=0) / num_pos

NameError: name 'loss' is not defined

## 中心度损失
* 与分类损失基本相同，区别：分类损失--多分类；中心度损失--二分类。

In [21]:
def compute_cnt_loss(preds, targets, mask):
    '''
    Args  
    preds: list contains five level pred [batch_size,1,_h,_w]
    targets: [batch_size,sum(_h*_w),1]
    mask: [batch_size,sum(_h*_w)]
    '''
    batch_size = targets.shape[0]
    c = targets.shape[-1]
    preds_reshape = []
    mask = mask.unsqueeze(dim=-1) #[batch_size,sum(_h*_w),1]
    
    # 同样计算正样本个数，最后损失除以这个数字
    num_pos = torch.sum(mask, dim=[1, 2]).clamp_(min=1).float()  # [batch_size,]
    for pred in preds:
        pred = pred.permute(0, 2, 3, 1)
        pred = torch.reshape(pred, [batch_size, -1, c])
        preds_reshape.append(pred)
    preds = torch.cat(preds_reshape, dim=1)
    assert preds.shape == targets.shape  # [batch_size,sum(_h*_w),1]
    loss = []
    for batch_index in range(batch_size):
        pred_pos = preds[batch_index][mask[batch_index]]  # [num_pos_b,]
        target_pos = targets[batch_index][mask[batch_index]]  # [num_pos_b,]
        assert len(pred_pos.shape) == 1
        loss.append(
            nn.functional.binary_cross_entropy_with_logits(input=pred_pos, target=target_pos, reduction='sum').view(1))
    return torch.cat(loss, dim=0) / num_pos  # [batch_size,]

In [22]:
loss = compute_cnt_loss([torch.ones([2, 1, 4, 4])] * 5, torch.ones([2, 80, 1]),
                        torch.ones([2, 80], dtype=torch.uint8))
print(loss)

tensor([0.3133, 0.3133])




## 定位损失

In [24]:
def iou_loss(preds, targets):
    '''
    Args:
    preds: [n,4] ltrb
    targets: [n,4]
    '''
    lt = torch.min(preds[:, :2], targets[:, :2])
    rb = torch.min(preds[:, 2:], targets[:, 2:])
    wh = (rb + lt).clamp(min=0)
    
    overlap = wh[:, 0] * wh[:, 1]  # [n]
    area1 = (preds[:, 2] + preds[:, 0]) * (preds[:, 3] + preds[:, 1])
    area2 = (targets[:, 2] + targets[:, 0]) * (targets[:, 3] + targets[:, 1])
    iou = overlap / (area1 + area2 - overlap)
    
    loss = -iou.clamp(min=1e-6).log()
    return loss.sum()

In [27]:
def compute_reg_loss(preds, targets, mask, mode='iou'):
    '''
    Args  
    preds: list contains five level pred [batch_size,4,_h,_w]
    targets: [batch_size,sum(_h*_w),4]
    mask: [batch_size,sum(_h*_w)]
    '''
    batch_size = targets.shape[0]
    c = targets.shape[-1]
    preds_reshape = []
    # mask=targets>-1#[batch_size,sum(_h*_w),4]
    num_pos = torch.sum(mask, dim=1).clamp_(min=1).float()  # [batch_size,]
    for pred in preds:
        pred = pred.permute(0, 2, 3, 1)
        pred = torch.reshape(pred, [batch_size, -1, c])
        preds_reshape.append(pred)
    preds = torch.cat(preds_reshape, dim=1)
    assert preds.shape == targets.shape  # [batch_size,sum(_h*_w),4]
    loss = []
    for batch_index in range(batch_size):
        pred_pos = preds[batch_index][mask[batch_index]]  # [num_pos_b,4]
        target_pos = targets[batch_index][mask[batch_index]]  # [num_pos_b,4]
        assert len(pred_pos.shape) == 2
        if mode == 'iou':
            loss.append(iou_loss(pred_pos, target_pos).view(1))
        elif mode == 'giou':
            loss.append(giou_loss(pred_pos, target_pos).view(1))
        else:
            raise NotImplementedError("reg loss only implemented ['iou','giou']")
    return torch.cat(loss, dim=0) / num_pos  # [batch_size,]

In [28]:
loss = compute_reg_loss([torch.ones([2, 4, 4, 4])] * 5, torch.ones([2, 80, 4]),
                        torch.ones([2, 80], dtype=torch.uint8))
print(loss)

tensor([0., 0.])


