# Imports

In [1]:
import torch
from torch import nn
from torch.nn.functional import cross_entropy, softmax, binary_cross_entropy_with_logits

# Debug weighted BCE (WBCE) by disease

In [13]:
# Execute wbce.py inside this kernel so that the names it defines become available here.
# WeigthedBCEByDiseaseLoss (used in the next cell) is presumably defined there - verify in wbce.py.
%run ./wbce.py

In [14]:
# Build the weighted-BCE-by-disease loss with its default arguments.
# The class is not imported in the imports cell, so it presumably comes from wbce.py;
# "Weigthed" is the spelling of the class name itself, not a typo in this cell.
loss = WeigthedBCEByDiseaseLoss()

In [20]:
# Random batch for exercising the loss: `bs` rows, 4 columns.
bs = 5
# Standard-normal values used as raw logits.
out = torch.randn(bs, 4)
# Binary labels: an entry is 1 when its uniform [0, 1) draw exceeds 0.75, else 0.
target = torch.rand(bs, 4).gt(0.75).long()
# Display logits, their sigmoid probabilities, and the labels together.
out, out.sigmoid(), target

(tensor([[ 2.1791, -1.3991,  1.4183, -0.5403],
         [-0.8847, -0.5324, -0.2359,  0.7764],
         [-0.0602,  0.5243, -0.3063,  0.0197],
         [ 0.3352, -2.6139, -0.4627, -1.4814],
         [-0.6498,  0.6365, -0.1319,  0.2359]]),
 tensor([[0.8984, 0.1980, 0.8051, 0.3681],
         [0.2922, 0.3700, 0.4413, 0.6849],
         [0.4850, 0.6282, 0.4240, 0.5049],
         [0.5830, 0.0683, 0.3863, 0.1852],
         [0.3430, 0.6540, 0.4671, 0.5587]]),
 tensor([[0, 1, 1, 0],
         [0, 0, 1, 0],
         [0, 1, 0, 0],
         [0, 0, 1, 0],
         [1, 0, 0, 1]]))

In [21]:
# Evaluate the loss on the random batch. In the recorded output the result is a 3-tuple:
#   [0] a per-element loss matrix laid out like `out` (batch x column),
#   [1] a per-column vector equal to batch_size / (number of positives in that column),
#   [2] a per-column vector equal to batch_size / (number of negatives in that column).
# NOTE(review): this reading is inferred from the recorded numbers only - confirm against wbce.py.
loss(out, target)

(tensor([[2.8578, 4.0492, 0.3614, 0.5738],
         [0.4320, 0.7700, 1.3634, 1.4436],
         [0.8294, 1.1624, 1.3792, 0.8788],
         [1.0934, 0.1178, 1.5851, 0.2560],
         [5.3496, 1.7687, 1.5734, 2.9107]]),
 tensor([5.0000, 2.5000, 1.6667, 5.0000]),
 tensor([1.2500, 1.6667, 2.5000, 1.2500]))

## BCE with weights

In [30]:
# Smaller random batch (3 rows, 4 columns) for comparing plain and weighted BCE.
bs = 3
# Standard-normal values used as raw logits.
out = torch.randn(bs, 4)
# Binary labels: an entry is 1 when its uniform [0, 1) draw exceeds 0.75, else 0.
target = torch.rand(bs, 4).gt(0.75).long()
# Display logits, their sigmoid probabilities, and the labels together.
out, out.sigmoid(), target

(tensor([[ 1.0918, -0.6534, -0.9670, -0.4186],
         [ 0.0961, -0.0819,  2.0618, -1.0035],
         [ 0.1126, -1.3838, -0.2212,  1.1069]]),
 tensor([[0.7487, 0.3422, 0.2755, 0.3969],
         [0.5240, 0.4795, 0.8871, 0.2682],
         [0.5281, 0.2004, 0.4449, 0.7515]]),
 tensor([[1, 1, 0, 0],
         [1, 0, 1, 1],
         [1, 1, 0, 0]]))

In [33]:
# Unweighted baseline: reduction='none' keeps one loss value per (row, column) entry
# instead of averaging, so it can be compared entry-by-entry with the weighted variant.
loss = nn.BCEWithLogitsLoss(reduction='none')
# The integer labels are cast to float32 before being passed as the target.
loss(out, target.to(torch.float32))

tensor([[0.2894, 1.0723, 0.3222, 0.5056],
        [0.6462, 0.6530, 0.1198, 1.3158],
        [0.6384, 1.6074, 0.5886, 1.3925]])

In [38]:
# Smoke-test that the loss module can be moved to an accelerator.
# Module.to() moves the module in place and returns it, so the cell displays the module repr.
# Fall back to the CPU when no CUDA device is present: a hard-coded 'cuda' raises on
# CPU-only machines and would stop a Restart & Run All at this cell.
# NOTE(review): the execution counters show this cell (In [38]) was last run after the
# pos_weight cell (In [37]), so the recorded output may refer to that later `loss`.
loss.to('cuda' if torch.cuda.is_available() else 'cpu')

BCEWithLogitsLoss()

In [37]:
# Weighted variant: pos_weight holds one multiplier per column and scales only the
# positive (target == 1) entries - in the recorded output those entries are 2x/3x/4x the
# unweighted values while target == 0 entries are unchanged.
# The weights are written as floats so the registered buffer has the same floating dtype
# as the logits, instead of an int64 buffer that relies on implicit type promotion.
loss = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([1.0, 2.0, 3.0, 4.0]),
    reduction='none',
)
# The integer labels are cast to float before being passed as the target.
loss(out, target.float())

tensor([[0.2894, 2.1446, 0.3222, 0.5056],
        [0.6462, 0.6530, 0.3593, 5.2634],
        [0.6384, 3.2148, 0.5886, 1.3925]])

# Debug Focal loss

In [2]:
# Execute focal.py inside this kernel so that the names it defines become available here.
# FocalLoss (used in the cells below) is presumably defined there - verify in focal.py.
%run focal.py

In [3]:
# Focal-loss hyperparameters kept for experimentation.
# Note: the cells visible below pass their own alpha/gamma literals and do not read these.
alpha, gamma = 0.75, 2

## Multilabel case

In [108]:
# Hand-written model outputs for the multilabel check: 2 samples x 6 labels, float32.
outputs = torch.tensor(
    [
        [1, 1, 0, 0, 0, 1],
        [0, 0, 1, 0, 0, 1],
    ],
    dtype=torch.float32,
)
# Matching 0/1 ground-truth labels, also stored as float32.
targets = torch.tensor(
    [
        [1, 1, 1, 0, 1, 0],
        [0, 1, 0, 0, 1, 0],
    ],
    dtype=torch.float32,
)
# Display the model outputs for reference.
outputs

tensor([[1., 1., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 1.]])

In [116]:
# Multilabel focal loss with alpha=0.87; reduction='none' keeps one value per element.
# gamma is left at FocalLoss's default - the recorded numbers are consistent with
# gamma == 2, TODO confirm in focal.py.
# NOTE(review): the call appears to print three intermediate tensors before the result.
# In the recorded run they match, in order: the per-element alpha weight (0.87 where
# target == 1, otherwise 0.13), the modulating factor (1 - p_t)^2, and the plain BCE term;
# the returned tensor equals their elementwise product. Looks like leftover debug
# printing inside focal.py - confirm there.
focal = FocalLoss(alpha=0.87, reduction='none')
focal(outputs, targets)

tensor([[0.8700, 0.8700, 0.8700, 0.1300, 0.8700, 0.1300],
        [0.1300, 0.8700, 0.1300, 0.1300, 0.8700, 0.1300]])
tensor([[0.0723, 0.0723, 0.2500, 0.2500, 0.2500, 0.5344],
        [0.2500, 0.2500, 0.5344, 0.2500, 0.2500, 0.5344]])
tensor([[0.3133, 0.3133, 0.6931, 0.6931, 0.6931, 1.3133],
        [0.6931, 0.6931, 1.3133, 0.6931, 0.6931, 1.3133]])


tensor([[0.0197, 0.0197, 0.1508, 0.0225, 0.1508, 0.0912],
        [0.0225, 0.1508, 0.0912, 0.0225, 0.1508, 0.0912]])

## Multiclass case

In [38]:
# Hand-written model outputs for the multiclass check: 3 samples x 3 classes, float32.
outputs = torch.tensor(
    [
        [0, 1, 0],
        [-3, 1, 2],
        [1, 11.1, 0.5],
    ],
    dtype=torch.float32,
)
# One integer class index per sample (int64).
targets = torch.tensor([1, 2, 0], dtype=torch.long)

In [39]:
# Multiclass focal loss (multilabel=False) with alpha=0.9 and gamma=2; reduction='none'
# yields one value per sample. `targets` holds one class index per row of `outputs`.
# NOTE(review): in the recorded output each value equals alpha * (1 - p)^gamma * CE, where
# p is the softmax probability of the target class and CE its cross-entropy - inferred
# from the numbers only, confirm in focal.py.
focal = FocalLoss(alpha=0.9, gamma=2, multilabel=False, reduction='none')
focal(outputs, targets)

tensor([0.0892, 0.0213, 9.0893])