In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

# import

In [2]:
# export
import torch

In [3]:
# export
from torch.nn import functional as F

In [4]:
# export
from torch import tensor

In [5]:
from matplotlib import pyplot as plt

In [6]:
from IPython.core import debugger as idb

In [7]:
# export
import numpy as np

# functions

## dice_loss

In [8]:
# export
def dice_coef(input, target):
    """
    Soft Dice coefficient between sigmoid(input) and target.

    input: tensor that is squashed with sigmoid before use
           (presumably raw logits -- confirm against callers).
    target: tensor of the same shape; cast to float before use.

    Every element of both tensors is summed together, so the result is a
    single 0-dim tensor rather than one value per sample. An additive
    smoothing term of 1 is applied to both numerator and denominator.
    """
    eps = 1.

    probs = input.sigmoid()
    truth = target.float()

    overlap = (probs * truth).sum()
    total = probs.sum() + truth.sum()
    return (2. * overlap + eps) / (total + eps)

In [9]:
# export
def dice_loss(input, target):
    """
    Dice loss: one minus the soft Dice coefficient of input and target.
    """
    coef = dice_coef(input, target)
    return 1 - coef

## balance_bce

In [10]:
# export
def weighted_bce(input, target, pos_weight=0):
    """
    Binary cross entropy on logits with an optional weight for positives.

    input: logits, forwarded to F.binary_cross_entropy_with_logits.
    target: mask tensor; cast to float before use.
    pos_weight: weight of a positive element relative to a negative one
                (whose weight is 1). Values <= 0 disable weighting and
                give the plain, unweighted loss.

    When weighting is enabled, the per-element weights are rescaled so
    that they sum to the number of elements (i.e. their mean is 1).
    """
    labels = target.float()

    # None is the default `weight` of the underlying call, i.e. no weighting.
    elem_weight = None
    if pos_weight > 0:
        raw = (labels * pos_weight + (1 - labels))
        elem_weight = raw / raw.sum() * labels.numel()

    return F.binary_cross_entropy_with_logits(input, labels, weight=elem_weight)

In [11]:
# export
def balance_bce(input, target, balance_ratio=0):
    """
    Weighted BCE whose positive weight is derived from the class counts.

    input: logits, forwarded to weighted_bce.
    target: mask tensor; cast to float before use.
    balance_ratio: when > 0, the positive weight is set to
                   balance_ratio * (#negatives / #positives), so that the
                   summed weight of positives is balance_ratio times that
                   of negatives when both classes are present.
                   Values <= 0 fall back to the unweighted loss.

    Both counts are clamped to at least 1 so the ratio stays finite when
    one of the classes is absent.
    """
    labels = target.float()

    # No balancing requested: defer to the unweighted path.
    if not balance_ratio > 0:
        return weighted_bce(input, labels)

    n_pos = labels.sum().clamp(1)
    n_neg = (1 - labels).sum().clamp(1)
    return weighted_bce(input, labels, balance_ratio * n_neg / n_pos)

## combo_loss

In [12]:
# export
def combo_loss(input, target, balance_ratio=0):
    """
    Sum of dice_loss and balance_bce for the same input/target pair.

    balance_ratio is passed through to balance_bce unchanged.
    """
    dice_term = dice_loss(input, target)
    bce_term = balance_bce(input, target, balance_ratio)
    return dice_term + bce_term

## mask_iou

In [13]:
# export
def mask_iou(input, target):
    """
    Intersection over union (IoU) for binary segmentation.

    input: tensor thresholded at 0 to obtain the predicted mask
           (presumably logits, so > 0 corresponds to probability > 0.5 --
           confirm against callers).
    target: tensor thresholded at 0 to obtain the ground-truth mask.

    The intersection and union are counted over all elements, giving a
    single 0-dim float tensor.

    When both masks are empty the union is 0; in that case 1 is returned
    (prediction and target agree perfectly) instead of the NaN that 0/0
    would produce, so that averaging the metric over batches stays finite.
    """
    pred_mask = input > 0
    target_mask = target > 0

    i = (pred_mask & target_mask).float().sum()
    u = (pred_mask | target_mask).float().sum()

    # u is an element count, so it is either 0 or >= 1: clamping to 1 only
    # changes the empty-union case, which is then overwritten with 1.
    return (i / u.clamp(min=1)).masked_fill(u == 0, 1.)

# test

## dice_loss

In [14]:
# Random inputs and a random {0, 1} target of the same shape.
x = torch.randn((4,1,512,512))
y = torch.randint_like(x,0,2)

# dice_loss is 1 - dice_coef, so the two displayed values should sum to 1.
dice_coef(x,y),dice_loss(x,y)

(tensor(0.5004), tensor(0.4996))

## balance_bce

In [15]:
x = torch.randn((4,1,512,512))+2 # shift x to a mean of 2 so that the mean probability sigmoid(x) is > 0.5; positives then incur a small loss and negatives a large one
y = (torch.randint_like(x,0,3)>0).type(torch.float32) # positives make up about 2/3 of y, negatives about 1/3

In [16]:
# With the default arguments neither function applies any weighting,
# so both reduce to the plain BCE and the two values should be equal.
weighted_bce(x,y), balance_bce(x,y)

(tensor(0.8477), tensor(0.8477))

In [17]:
# About 2/3 of y is positive, so #neg/#pos is roughly 0.5 and
# balance_ratio=1 yields a positive weight close to 0.5: the two values
# should be close but not necessarily identical.
weighted_bce(x,y,0.5), balance_bce(x,y,1)

(tensor(1.1806), tensor(1.1823))

## combo_loss

In [18]:
# Random inputs and a random {0, 1} target of the same shape.
x = torch.randn((4,1,512,512))
y = torch.randint_like(x,0,2)

# combo_loss should equal the sum of its two components computed separately.
combo_loss(x,y,2), dice_loss(x,y)+balance_bce(x,y,2)

(tensor(1.3072), tensor(1.3072))

## mask_iou

In [19]:
# Random inputs (positive about half the time) and a random {0, 1} target.
x = torch.randn((4,1,512,512))
y = torch.randint_like(x,0,2)
# For independent masks that are each ~50% positive the IoU should be close to 1/3.
mask_iou(x,y)

tensor(0.3334)

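About 50% of x is positive and about 50% of y is positive, so their intersection covers ≈25% of the elements and their union ≈75%; hence intersection/union ≈ 1/3.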

# export

In [20]:
!python ../notebook2script.py --fname 'loss_metrics.ipynb' --outputDir '../exp/'

Converted loss_metrics.ipynb to exp/nb_loss_metrics.py
