This notebook breaks down how PyTorch implements the `binary_cross_entropy_with_logits` function (the functional form of `BCEWithLogitsLoss`, commonly used for multilabel classification), and how it relates to applying `sigmoid` followed by `binary_cross_entropy`.
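For reference, these are the per-element quantities computed in the cells that follow, for a logit $x_{ij}$ and a target $y_{ij} \in \{0, 1\}$, with $N$ = `batch_size` and $C$ = `n_classes`:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

$$\ell(x, y) = -\bigl[\, y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr) \bigr]$$

$$\mathcal{L} = \frac{1}{NC} \sum_{i=1}^{N} \sum_{j=1}^{C} \ell(x_{ij}, y_{ij})$$

Averaging over all $NC$ elements, rather than per sample, corresponds to PyTorch's default `reduction='mean'`.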

In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [83]:
batch_size, n_classes = 10, 4
x = torch.randn(batch_size, n_classes)
x.shape

torch.Size([10, 4])

In [84]:
x

tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
        [ 0.0419,  0.0763, -1.0457, -1.6692],
        [-1.0494,  0.8111,  1.5723,  1.2315],
        [ 1.3081,  0.6641,  1.1802, -0.2547],
        [ 0.5292,  0.7636,  0.3692, -0.8318],
        [ 0.5100,  0.9849, -1.2905,  0.2821],
        [ 1.4662,  0.4550,  0.9875,  0.3143],
        [-1.2121,  0.1262,  0.0598, -1.6363],
        [ 0.3214, -0.8689,  0.0689, -2.5094],
        [ 1.1320, -0.6824,  0.1657, -0.0687]])

In [85]:
target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)
target

tensor([1, 1, 3, 0, 2, 0, 2, 2, 1, 2])

In [86]:
y = torch.zeros(batch_size, n_classes)
y[range(y.shape[0]), target] = 1  # one-hot: row i gets a 1 in column target[i]
y

tensor([[0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]])
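The same one-hot matrix can also be built with PyTorch's `F.one_hot`. A minimal sketch, assuming fresh random targets of the same shape: `F.one_hot` returns an integer tensor, so it is cast to float because the BCE losses expect floating-point targets. Since every row here has exactly one 1, this is effectively a single-label target; in a genuinely multilabel setting a row may contain several 1s, and the loss is computed in exactly the same way.

```python
import torch
import torch.nn.functional as F

batch_size, n_classes = 10, 4
target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)

# Fancy-indexing approach: row i gets a 1 in column target[i]
y = torch.zeros(batch_size, n_classes)
y[torch.arange(batch_size), target] = 1

# Built-in alternative; F.one_hot yields int64, so cast to float32
y_one_hot = F.one_hot(target, num_classes=n_classes).float()

print(torch.equal(y, y_one_hot))  # True
```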

### `sigmoid` + `binary_cross_entropy`

In [87]:
def sigmoid(x): return (1 + (-x).exp()).reciprocal()  # 1 / (1 + e^{-x})
def binary_cross_entropy(pred, y): return -(pred.log()*y + (1-y)*(1-pred).log()).mean()  # mean over all elements

pred = sigmoid(x)
loss = binary_cross_entropy(pred, y)
loss

tensor(0.7739)
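The `.mean()` in the hand-written `binary_cross_entropy` averages over all `batch_size * n_classes` elements, not per sample. A minimal check, assuming fresh random data of the same shapes: `F.binary_cross_entropy` with `reduction='none'` returns the per-element losses, which match the formula element by element, and their mean equals the result of the default `reduction='mean'`.

```python
import torch
import torch.nn.functional as F

# Fresh random data with the same shapes as above
batch_size, n_classes = 10, 4
x = torch.randn(batch_size, n_classes)
y = F.one_hot(torch.randint(n_classes, (batch_size,)), n_classes).float()
pred = torch.sigmoid(x)

# Per-element losses, written out by hand and via reduction='none'
manual_elem = -(y * pred.log() + (1 - y) * (1 - pred).log())
torch_elem = F.binary_cross_entropy(pred, y, reduction='none')
print(torch_elem.shape)                          # torch.Size([10, 4])
print(torch.allclose(manual_elem, torch_elem))   # True

# The default reduction='mean' averages over all batch_size * n_classes elements
print(torch.allclose(torch_elem.mean(), F.binary_cross_entropy(pred, y)))  # True
```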

### `torch.sigmoid` + `F.binary_cross_entropy`

The same computation using PyTorch's built-in `torch.sigmoid` and `F.binary_cross_entropy`.

In [88]:
pred = torch.sigmoid(x)
loss = F.binary_cross_entropy(pred, y)
loss

tensor(0.7739)
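The hand-written and built-in versions agree here because the logits are moderate. A minimal sketch of where they diverge, assuming a single large logit of 50.0 with target 0 (an arbitrary value chosen for illustration): in float32, `sigmoid(50.0)` rounds to exactly 1.0, so the hand-written loss takes log(0) and returns inf. According to the PyTorch documentation, `F.binary_cross_entropy` clamps its log outputs to be at least -100, so it returns a finite 100.0, even though the exact loss, log(1 + e^50), is about 50.

```python
import torch
import torch.nn.functional as F

# The hand-written helpers from above
def sigmoid(x): return (1 + (-x).exp()).reciprocal()
def binary_cross_entropy(pred, y): return -(pred.log()*y + (1-y)*(1-pred).log()).mean()

# One confident, wrong prediction: large positive logit, target 0
x_big = torch.tensor([50.0])
y_big = torch.tensor([0.0])

p = sigmoid(x_big)                        # exp(-50) is lost next to 1, so p == 1.0
manual = binary_cross_entropy(p, y_big)   # (1 - p).log() is log(0) = -inf, so the loss is inf
clamped = F.binary_cross_entropy(torch.sigmoid(x_big), y_big)  # log clamped at -100, giving 100.0

print(p.item(), manual.item(), clamped.item())
```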

### `F.binary_cross_entropy_with_logits`

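PyTorch's single `binary_cross_entropy_with_logits` function, which takes the raw logits `x` directly and applies the sigmoid internally. According to the PyTorch documentation, combining the two steps this way is more numerically stable than a plain sigmoid followed by binary cross-entropy, because it can use the log-sum-exp trick.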
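A sketch of why the combined form can be more stable, for a single logit $x$ with target $y$ and $\sigma(x) = 1/(1 + e^{-x})$. Substituting $\log \sigma(x) = -\log(1 + e^{-x})$ and $\log\bigl(1 - \sigma(x)\bigr) = -x - \log(1 + e^{-x})$ into the cross-entropy and collecting terms gives

$$-\bigl[\, y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr) \bigr] = (1 - y)\, x + \log\bigl(1 + e^{-x}\bigr)$$

The term $e^{-x}$ still overflows for large negative $x$, so rewrite $\log(1 + e^{-x}) = \max(-x, 0) + \log(1 + e^{-|x|})$. Writing $\ell(x, y)$ for the per-element loss, this yields

$$\ell(x, y) = \max(x, 0) - x y + \log\bigl(1 + e^{-|x|}\bigr)$$

Here $e^{-|x|} \in (0, 1]$ can never overflow, and the sigmoid is never rounded to exactly 0 or 1 before the logarithm is taken. This is the standard log-sum-exp rearrangement; PyTorch's kernel may evaluate an algebraically equivalent variant rather than this exact expression.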

In [89]:
F.binary_cross_entropy_with_logits(x, y)

tensor(0.7739)
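A minimal sketch of the stable rearrangement in plain PyTorch ops, computing max(x, 0) - x*y + log(1 + exp(-|x|)) per element, checked against `F.binary_cross_entropy_with_logits`. The helper `bce_with_logits_stable` is hypothetical (not a PyTorch API) and only reproduces the default `reduction='mean'` with no `weight` or `pos_weight`. The extreme logits are arbitrary values chosen to show that the result stays finite where sigmoid followed by a log would hit log(0).

```python
import torch
import torch.nn.functional as F

def bce_with_logits_stable(x, y):
    """Hypothetical helper: mean of max(x, 0) - x*y + log(1 + exp(-|x|)).

    exp(-|x|) lies in (0, 1], so it cannot overflow, and log1p keeps
    precision when exp(-|x|) is tiny.
    """
    per_elem = x.clamp(min=0) - x * y + torch.log1p(torch.exp(-x.abs()))
    return per_elem.mean()

torch.manual_seed(0)
batch_size, n_classes = 10, 4
x = torch.randn(batch_size, n_classes)
y = F.one_hot(torch.randint(n_classes, (batch_size,)), n_classes).float()

# Ordinary logits: agrees with PyTorch's function
print(torch.allclose(bce_with_logits_stable(x, y),
                     F.binary_cross_entropy_with_logits(x, y)))  # True

# Extreme logits: per-element losses are about 200, 200, log(2), 0
x_ext = torch.tensor([[200.0, -200.0, 0.0, 30.0]])
y_ext = torch.tensor([[0.0, 1.0, 1.0, 1.0]])
print(bce_with_logits_stable(x_ext, y_ext),
      F.binary_cross_entropy_with_logits(x_ext, y_ext))
```

Both calls on the extreme logits should return a finite mean of roughly (400 + log 2) / 4 ≈ 100.17, whereas the hand-written sigmoid + log version would produce inf for the first two elements.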