In [1]:
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


# Test weighted criterions

## Cross-entropy

In [2]:
def to_one_hot(y, n_class=2):
    oh = np.zeros((y.shape[0], n_class), np.float32)
    oh[np.arange(y.shape[0]), y] = 1
    return oh


In [3]:
class WeightedCrossEntropyLoss(nn.Module):
    def forward(self, input, one_hot, weight):
        log_soft = - F.log_softmax(input, dim=1)
        element_loss = torch.sum(log_soft * one_hot, 1) * weight
        loss = torch.mean(element_loss)
        return loss


In [4]:
class WeightedCrossEntropyLoss2(nn.Module):
    def forward(self, input, target, weight):
        element_loss = F.cross_entropy(input, target, reduce=False)
        loss = torch.mean(element_loss * weight)
        return loss


In [5]:
n_samples = 4096
n_class = 5
inpt = Variable(torch.randn(n_samples, n_class), requires_grad=True)
target = Variable(torch.LongTensor(n_samples).random_(n_class))
one_hot = Variable( torch.from_numpy( to_one_hot(target.data.numpy(), n_class=n_class) ) )
W = Variable(torch.ones(n_samples))

In [6]:
loss_0 = nn.CrossEntropyLoss(reduce=True)
loss_1 = WeightedCrossEntropyLoss()
loss_2 = WeightedCrossEntropyLoss2()


In [7]:
output_0 = loss_0(inpt, target)
output_1 = loss_1(inpt, one_hot, W)
output_2 = loss_2(inpt, target, W)


In [8]:
print(output_0)
print(output_1)
print(output_2)


Variable containing:
 1.9797
[torch.FloatTensor of size 1]

Variable containing:
 1.9797
[torch.FloatTensor of size 1]

Variable containing:
 1.9797
[torch.FloatTensor of size 1]



In [9]:
%timeit loss_0(inpt, target)

%timeit loss_1(inpt, one_hot, W)

%timeit loss_2(inpt, target, W)

436 µs ± 55.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
460 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
394 µs ± 25 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Binary Cross-entropy

In [10]:
class WeightedBCEWithLogitsLoss(nn.Module):
    def forward(self, input, target, weight):
        input = input.view(target.size())
        log_sigm = F.logsigmoid(input) * target + (1-target) * (-input - F.softplus(-input))
        element_loss = -log_sigm * weight
        loss = torch.mean(element_loss)
        return loss


In [11]:
class WeightedBCEWithLogitsLoss2(nn.Module):
    def forward(self, input, target, weight):
        element_loss = F.binary_cross_entropy(F.sigmoid(input), target, reduce=False)
        loss = torch.mean(element_loss * weight)
        return loss


In [36]:
n_samples = 4096
n_class = 1
inpt = Variable(torch.randn(n_samples, n_class), requires_grad=True)
target = Variable(torch.LongTensor(n_samples).random_(n_class))
one_hot = Variable( torch.from_numpy( to_one_hot(target.data.numpy(), n_class=n_class) ) )
W = Variable(torch.ones(n_samples))/2

In [37]:
loss_0 = nn.BCEWithLogitsLoss()
loss_1 = WeightedBCEWithLogitsLoss()
loss_2 = WeightedBCEWithLogitsLoss2()


In [38]:
output_0 = loss_0(inpt, one_hot)
output_1 = loss_1(inpt, one_hot, W)


In [39]:
print(output_0)
print(output_1)


Variable containing:
 0.7967
[torch.FloatTensor of size 1]

Variable containing:
 0.3983
[torch.FloatTensor of size 1]



In [40]:
%timeit loss_0(inpt, one_hot)

%timeit loss_1(inpt, one_hot, W)

# %timeit loss_2(inpt, target, W)

213 µs ± 5.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
71.3 ms ± 6.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## MSE

In [41]:
class WeightedMSELoss(nn.Module):
    def forward(self, input, target, weight):
        weight = weight.view( (-1,) + (1,) * (target.dim() - weight.dim() ) )
        loss = (input - target)**2
        element_loss = loss * weight
        loss = torch.mean(element_loss)
        return loss

In [42]:
class WeightedMSELoss2(nn.Module):
    def forward(self, input, target, weight):
        weight = weight.view( (-1,) + (1,) * (target.dim() - weight.dim() ) )
        element_loss = F.mse_loss(input, target, reduce=False)
        loss = torch.mean(element_loss * weight)
        return loss


In [43]:
n_samples = 4096
n_dim = 5
inpt = Variable(torch.randn(n_samples, n_dim), requires_grad=True)
target = Variable(torch.randn(n_samples, n_dim))
W = Variable(torch.ones(n_samples))/2

In [44]:
loss_0 = nn.MSELoss(reduce=True)
loss_1 = WeightedMSELoss()
loss_2 = WeightedMSELoss2()


In [45]:
output_0 = loss_0(inpt, target)
output_1 = loss_1(inpt, target, W)
output_2 = loss_2(inpt, target, W)


In [46]:
print(output_0)
print(output_1)
print(output_2)


Variable containing:
 2.0083
[torch.FloatTensor of size 1]

Variable containing:
 1.0041
[torch.FloatTensor of size 1]

Variable containing:
 1.0041
[torch.FloatTensor of size 1]



In [49]:
%timeit loss_0(inpt, target)

%timeit loss_1(inpt, target, W)

%timeit loss_2(inpt, target, W)

37.3 µs ± 4.26 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
115 µs ± 9.35 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
125 µs ± 4.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
