In [59]:
import torch
from torch import nn
from torch.nn import Softmax

In [86]:
# Toy example: y_pred holds raw scores (softmax is applied inside the loss),
# y_hat holds a binary mask. NOTE: despite its name, y_hat is passed as the
# *target* (second argument) to the criterion below, not as a prediction.
y_data     = [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]
y_hat_data = [[[1., 1.], [0., 0.]], [[1., 1.], [1., 1.]]]
# torch.tensor is the recommended factory; torch.Tensor(...) is the legacy constructor.
y_pred = torch.tensor(y_data, dtype=torch.float32)
y_hat = torch.tensor(y_hat_data, dtype=torch.float32)

In [87]:
class WeightedTverskyLoss(nn.Module):
    """Tversky loss function from arXiv:1803.11078v1.

    Returns ``1 - TI``, where the Tversky index is

        TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    with, over all flattened elements,
        TP = sum(p0 * g0), FP = sum(p0 * g1), FN = sum(p1 * g0),
    ``p0`` the softmax of ``input`` along ``dim``, ``p1 = 1 - p0``,
    ``g0`` the target and ``g1 = 1 - g0``. With ``alpha = beta = 0.5``
    the index reduces to the Dice coefficient.

    Args:
        weight: ``(alpha, beta)``; alpha weights false positives, beta
            weights false negatives. Defaults to ``(0.5, 0.5)``.
        dim: dimension along which softmax is applied to ``input``.
            Defaults to 2 (the previously hard-coded value).
        smooth: constant added to numerator and denominator to avoid 0/0
            (e.g. an all-zero target with ``alpha == 0``). Defaults to 0,
            which reproduces the unsmoothed formula exactly.
    """

    def __init__(self, weight: tuple = (0.5, 0.5), dim: int = 2, smooth: float = 0.0):
        super(WeightedTverskyLoss, self).__init__()
        self.alpha, self.beta = weight
        self.dim = dim
        self.smooth = smooth

    def forward(self, input, target):
        """Compute the scalar Tversky loss.

        Args:
            input: raw scores; softmax is applied along ``self.dim``.
            target: tensor with the same number of elements as ``input``;
                1 marks a lacune voxel, 0 a non-lacune voxel.

        Returns:
            0-dim tensor ``1 - TI``.
        """
        # reshape (not view): works for non-contiguous tensors, e.g. after permute/transpose.
        # NOTE(review): p0 contains every softmax entry along `dim`, not only a single
        # "lacune" channel — confirm this is the intended pairing with `target`.
        p0 = torch.softmax(input, dim=self.dim).reshape(-1)  # probability that voxel is a lacune
        p1 = 1 - p0  # probability that voxel is a non-lacune
        g0 = target.reshape(-1)  # 1 if voxel is a lacune, 0 if voxel is a non-lacune
        g1 = 1 - g0  # 0 if voxel is a lacune, 1 if voxel is a non-lacune

        true_pos = (p0 * g0).sum()
        false_pos = (p0 * g1).sum()
        false_neg = (p1 * g0).sum()

        tversky_index = (true_pos + self.smooth) / (
            true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth
        )
        return 1 - tversky_index

In [88]:
# alpha=0.3 weights false positives, beta=0.7 weights false negatives.
criterion = WeightedTverskyLoss(weight=(0.3, 0.7))
loss = criterion(y_pred, y_hat)  # y_hat is the ground-truth mask here
# Bare last expression so the cell actually renders the value shown in its output.
loss

tensor(0.4444)
