# Binary and multiclass loss functions in PyTorch


In [1]:
import torch
from torch import nn

# Binary classification

Assume we have 4 samples of classes 0, 1, 1, 0, respectively

In [2]:
# Ground-truth labels of the four samples, stored as a float64 tensor.
y_true = torch.tensor([0.0, 1.0, 1.0, 0.0], dtype=torch.float64)
y_true

tensor([0., 1., 1., 0.], dtype=torch.float64)

## Sigmoid activation

Suppose the model outputs (before sigmoid activation function) are as follows

In [3]:
# Raw model outputs (one score per sample), before the sigmoid is applied.
logits = torch.tensor([-0.1, 2.0, 3.0, 0.5], dtype=torch.float64)

So the probabilities of being class `1` are as follows

In [4]:
# Map every logit to the probability of the sample belonging to class 1.
y_pred = logits.sigmoid()
y_pred

tensor([0.4750, 0.8808, 0.9526, 0.6225], dtype=torch.float64)

**Question** What is the value of *cross entropy loss*?

## Binary cross entropy (BCE) loss

We should calculate the *distance* between two distributions:
* `[0, 1, 1, 0]` and
* `[0.4750, 0.8808, 0.9526, 0.6225]`

## BCE using the definition

In [5]:
torch.log(y_pred)

tensor([-0.7444, -0.1269, -0.0486, -0.4741], dtype=torch.float64)

Let's calculate the loss for each sample

In [6]:
# Per-sample binary cross entropy, written straight from the definition.
-(
    y_true * y_pred.log()
    + (1 - y_true) * (1.0 - y_pred).log()
)

tensor([0.6444, 0.1269, 0.0486, 0.9741], dtype=torch.float64)

Now we calculate the average loss value

In [7]:
# Average the per-sample BCE terms into the scalar loss value.
(-(
    y_true * y_pred.log()
    + (1 - y_true) * (1.0 - y_pred).log()
)).mean()

tensor(0.4485, dtype=torch.float64)

One can check that if `y_pred=[0.01, 0.9, 0.99, 0.1]` then the loss value is much smaller

In [8]:
# Create the alternative predictions with the same dtype as `y_true` (float64).
# Without it the tensor defaults to float32: the logarithms are then evaluated
# in lower precision, the result relies on implicit type promotion, and the
# tensor cannot be passed together with the float64 `y_true` to `BCELoss` /
# `F.binary_cross_entropy`, which expect input and target of matching dtype.
y_pred_new = torch.tensor([0.01, 0.9, 0.99, 0.1], dtype=y_true.dtype)
torch.mean( - ( y_true*torch.log(y_pred_new) + (1-y_true)*torch.log(1.0-y_pred_new) ) )

tensor(0.0577, dtype=torch.float64)

## Using `BCELoss` or `BCEWithLogitsLoss`
See https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss

One can use the *modular* or *functional* API from PyTorch

### Modular API

In [9]:
# Module-style losses: `loss_fn` is applied to probabilities below, whereas
# `loss_fn_with_logits` is applied to the raw logits.
loss_fn, loss_fn_with_logits = nn.BCELoss(), nn.BCEWithLogitsLoss()

In [10]:
loss_fn(y_pred, y_true)

tensor(0.4485, dtype=torch.float64)

In [11]:
loss_fn_with_logits(logits, y_true)

tensor(0.4485, dtype=torch.float64)

### Functional API

In [12]:
from torch.nn import functional as F

In [13]:
F.binary_cross_entropy(y_pred, y_true)

tensor(0.4485, dtype=torch.float64)

In [14]:
F.binary_cross_entropy_with_logits(logits, y_true)

tensor(0.4485, dtype=torch.float64)

# Multiclass classification

Consider we have 5 classes.
Assume we have 4 samples of classes 0, 2, 3, 1, respectively

In [15]:
num_classes = 5

In [16]:
y_true = torch.tensor([0, 2, 3, 1])

We can encode the same information using *one-hot* encoding

In [17]:
# Expand every class index into a 0/1 row of length `num_classes`.
y_true_one_hot = F.one_hot(y_true, num_classes)
y_true_one_hot

tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0]])

In the case of *multiclass* (categorical) classification we create a model with `num_classes` neurons in the last layer. Assume the model outputs are as follows:

In [18]:
# Raw model scores: one row per sample, one column per class.
logits = torch.tensor([
    [-5.0, 10.0, 1.2,  3.0, -4.0],
    [ 5.0, -4.0, 0.1,  1.0,  0.1],
    [ 1.0,  3.0, 2.5, -0.1,  2.5],
    [-7.0,  2.5, 0.1,  3.1,  0.2],
])
logits

tensor([[-5.0000, 10.0000,  1.2000,  3.0000, -4.0000],
        [ 5.0000, -4.0000,  0.1000,  1.0000,  0.1000],
        [ 1.0000,  3.0000,  2.5000, -0.1000,  2.5000],
        [-7.0000,  2.5000,  0.1000,  3.1000,  0.2000]])

## Softmax activation

In *multiclass* classification we use the **Softmax** activation function, which:
* gives a nice probabilistic interpretation, i.e. the output for each sample is a probability distribution over the `num_classes` classes (values are positive and sum to 1.0)
* preserves the *max*: the class with the largest logit also gets the largest probability.

In [19]:
# Softmax numerator: element-wise exponential of the logits.
numerator = logits.exp()
numerator

tensor([[6.7379e-03, 2.2026e+04, 3.3201e+00, 2.0086e+01, 1.8316e-02],
        [1.4841e+02, 1.8316e-02, 1.1052e+00, 2.7183e+00, 1.1052e+00],
        [2.7183e+00, 2.0086e+01, 1.2182e+01, 9.0484e-01, 1.2182e+01],
        [9.1188e-04, 1.2182e+01, 1.1052e+00, 2.2198e+01, 1.2214e+00]])

In [20]:
numerator.shape

torch.Size([4, 5])

In [21]:
# Softmax denominator: exponentiated logits summed over the class axis.
denominator = logits.exp().sum(dim=1)
denominator

tensor([22049.8965,   153.3601,    48.0736,    36.7079])

In [22]:
denominator.shape

torch.Size([4])

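In fact, it is better to keep the reduced dimension (`keepdim=True`), so that the sum has shape `[4, 1]` and broadcasts correctly against the `[4, 5]` numerator: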

In [23]:
# keepdim=True leaves the class axis in place with size 1, so the result can
# be broadcast against the per-class numerator.
denominator = logits.exp().sum(dim=1, keepdim=True)
denominator

tensor([[22049.8965],
        [  153.3601],
        [   48.0736],
        [   36.7079]])

In [24]:
denominator.shape

torch.Size([4, 1])

In [25]:
# Normalise each row of exponentials by its row sum: the softmax probabilities.
y_pred = torch.div(numerator, denominator)
y_pred

tensor([[3.0558e-07, 9.9894e-01, 1.5057e-04, 9.1091e-04, 8.3065e-07],
        [9.6774e-01, 1.1943e-04, 7.2064e-03, 1.7725e-02, 7.2064e-03],
        [5.6544e-02, 4.1781e-01, 2.5341e-01, 1.8822e-02, 2.5341e-01],
        [2.4842e-05, 3.3188e-01, 3.0107e-02, 6.0472e-01, 3.3274e-02]])

In [26]:
torch.sum(y_pred, dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000])

## `torch.softmax`, `torch.nn.Softmax` and `F.softmax`

In [27]:
softmax = nn.Softmax(dim=1)
softmax(logits)

tensor([[3.0558e-07, 9.9894e-01, 1.5057e-04, 9.1091e-04, 8.3065e-07],
        [9.6774e-01, 1.1943e-04, 7.2064e-03, 1.7725e-02, 7.2064e-03],
        [5.6544e-02, 4.1781e-01, 2.5341e-01, 1.8822e-02, 2.5341e-01],
        [2.4842e-05, 3.3188e-01, 3.0107e-02, 6.0472e-01, 3.3274e-02]])

In [28]:
F.softmax(logits, dim=1)

tensor([[3.0558e-07, 9.9894e-01, 1.5057e-04, 9.1091e-04, 8.3065e-07],
        [9.6774e-01, 1.1943e-04, 7.2064e-03, 1.7725e-02, 7.2064e-03],
        [5.6544e-02, 4.1781e-01, 2.5341e-01, 1.8822e-02, 2.5341e-01],
        [2.4842e-05, 3.3188e-01, 3.0107e-02, 6.0472e-01, 3.3274e-02]])

In [29]:
torch.softmax(logits, dim=1)

tensor([[3.0558e-07, 9.9894e-01, 1.5057e-04, 9.1091e-04, 8.3065e-07],
        [9.6774e-01, 1.1943e-04, 7.2064e-03, 1.7725e-02, 7.2064e-03],
        [5.6544e-02, 4.1781e-01, 2.5341e-01, 1.8822e-02, 2.5341e-01],
        [2.4842e-05, 3.3188e-01, 3.0107e-02, 6.0472e-01, 3.3274e-02]])

## `argmax` = `predict_class`

In [30]:
# The predicted class is the index of the largest probability in each row.
y_pred = logits.softmax(dim=1)
y_pred.argmax(dim=1)

tensor([1, 0, 1, 3])

## Cross entropy loss

In [31]:
y_pred

tensor([[3.0558e-07, 9.9894e-01, 1.5057e-04, 9.1091e-04, 8.3065e-07],
        [9.6774e-01, 1.1943e-04, 7.2064e-03, 1.7725e-02, 7.2064e-03],
        [5.6544e-02, 4.1781e-01, 2.5341e-01, 1.8822e-02, 2.5341e-01],
        [2.4842e-05, 3.3188e-01, 3.0107e-02, 6.0472e-01, 3.3274e-02]])

In [32]:
y_true_one_hot

tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0]])

In [33]:
y_pred * y_true_one_hot

tensor([[3.0558e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 7.2064e-03, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.8822e-02, 0.0000e+00],
        [0.0000e+00, 3.3188e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00]])

In [34]:
# The one-hot mask is applied *before* the logarithm, so every non-target
# entry becomes log(0) = -inf and shows up as `inf` after negation; only the
# true-class entries stay finite.
- torch.log( y_pred * y_true_one_hot )

tensor([[15.0011,     inf,     inf,     inf,     inf],
        [    inf,     inf,  4.9328,     inf,     inf],
        [    inf,     inf,     inf,  3.9727,     inf],
        [    inf,  1.1030,     inf,     inf,     inf]])

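We should keep only the finite entries, i.e. the negative log-probability of the true class for each sample (the `inf` values come from `log(0)` where the one-hot mask is zero).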

In [35]:
# Boolean mask picks, row by row, the probability assigned to the true class.
y_pred[y_true_one_hot == 1]

tensor([3.0558e-07, 7.2064e-03, 1.8822e-02, 3.3188e-01])

In [36]:
# Negative log-probability of the true class, one value per sample.
-(y_pred[y_true_one_hot == 1].log())

tensor([15.0011,  4.9328,  3.9727,  1.1030])

In [37]:
# Cross entropy loss: mean negative log-probability of the true classes.
(-(y_pred[y_true_one_hot == 1].log())).mean()

tensor(6.2524)

## `logits` + `torch.nn.CrossEntropyLoss`

PyTorch provides an efficient implementation of the *cross entropy* loss.
Remember not to apply the `Softmax` activation yourself: pass the raw `logits` directly.

In [38]:
loss_fn = nn.CrossEntropyLoss()

In [39]:
loss_fn(logits, y_true)

tensor(6.2524)

In [40]:
F.cross_entropy(logits, y_true)

tensor(6.2524)

## `logits` + `log_softmax` + `NLLLoss`

Alternatively, you can use `torch.nn.LogSoftmax` and `torch.nn.NLLLoss`.

In [41]:
# Log-probabilities over the class axis, computed in a single call.
log_likelihood = logits.log_softmax(dim=1)

In [42]:
torch.softmax(logits, dim=1)

tensor([[3.0558e-07, 9.9894e-01, 1.5057e-04, 9.1091e-04, 8.3065e-07],
        [9.6774e-01, 1.1943e-04, 7.2064e-03, 1.7725e-02, 7.2064e-03],
        [5.6544e-02, 4.1781e-01, 2.5341e-01, 1.8822e-02, 2.5341e-01],
        [2.4842e-05, 3.3188e-01, 3.0107e-02, 6.0472e-01, 3.3274e-02]])

In [43]:
log_likelihood

tensor([[-1.5001e+01, -1.0631e-03, -8.8011e+00, -7.0011e+00, -1.4001e+01],
        [-3.2789e-02, -9.0328e+00, -4.9328e+00, -4.0328e+00, -4.9328e+00],
        [-2.8727e+00, -8.7273e-01, -1.3727e+00, -3.9727e+00, -1.3727e+00],
        [-1.0603e+01, -1.1030e+00, -3.5030e+00, -5.0299e-01, -3.4030e+00]])

In [44]:
# Sanity check: log_softmax agrees with log(softmax(.)) up to float tolerance.
assert torch.allclose(log_likelihood, torch.softmax(logits, dim=1).log())

In [45]:
# NLLLoss is applied to the log-probabilities and the integer class labels.
nll_loss = nn.NLLLoss()
nll_loss(log_likelihood, y_true)

tensor(6.2524)

In [46]:
F.nll_loss(log_likelihood, y_true)

tensor(6.2524)