Classification: Binary classification

Now that we've gone over multiclass classification, we'll take a look at binary classification. It can be a bit more confusing because there is more than one way to set it up. We'll look at a few ways to calculate binary cross entropy.

The first option is to have the model output 2 logits: one for the negative class and one for the positive class, as shown below. This is the same workflow as the multiclass scenario from the previous lesson.

In [1]:
import numpy as np
import torch

What Are Logits? In machine learning, the term “logits” refers to the raw outputs of a model before they are transformed into probabilities. Specifically, logits are the unnormalized outputs of the last layer of a neural network.

In [2]:
def make_classification_logits(n_classes, n_samples, pct_correct, confidence=1):
    """
    This function returns simulated logits and classes.

    n_classes: number of classes
    n_samples: number of rows
    pct_correct: float between 0 and 1. The higher it is,
                 the higher the % of logits that will
                 generate the correct output.
    confidence: controls how confident our logits are.
                Closer to 0: less confident
                Larger: more confident
    """
    classes = list(range(n_classes))
    # Randomly make logits
    logits = np.random.uniform(-5., 5., (n_samples, n_classes))
    # Randomly make labels
    labels = np.random.choice(classes, size=(n_samples))
    # Find the max of each row in logits
    maxs = np.abs(logits).max(axis=1)       # axis=1 reduces across columns, giving one max per row
    # For each row...
    for i in range(len(maxs)):
        # If we want the answer to be right...
        if np.random.random() <= pct_correct:   # np.random.random() returns a float in [0.0, 1.0)
            # Make the correct item the highest logit
            logits[i, labels[i]] = maxs[i] + np.random.random()*confidence
        # If we want it to be wrong...
        else:
            # Make the highest logit a different index
            _c = classes.copy()
            _c.remove(classes[labels[i]])
            _i = np.random.choice(_c)
            logits[i, _i] = maxs[i] + np.random.random()/10

    # Return logits and labels
    return torch.FloatTensor(logits), torch.tensor(labels)

In [3]:
logits, labels=make_classification_logits(2,10,1., confidence=1)

In [4]:
np.abs(logits).max(axis=1)  # Sanity check: takes the max absolute value in each row

torch.return_types.max(
values=tensor([4.3765, 2.9773, 4.7288, 3.8919, 4.6187, 3.5896, 5.0186, 2.5354, 5.3304,
        3.9719]),
indices=tensor([0, 0, 0, 0, 0, 1, 0, 1, 1, 0]))

In [5]:
logits

tensor([[ 4.3765, -3.3314],
        [ 2.9773,  2.2280],
        [ 4.7288,  4.6093],
        [ 3.8919, -2.4916],
        [ 4.6187, -4.1310],
        [ 1.0369,  3.5896],
        [ 5.0186, -2.1783],
        [ 2.0699,  2.5354],
        [ 1.7808,  5.3304],
        [ 3.9719,  3.2910]])

In [7]:
logits.softmax(dim=1)

tensor([[9.9955e-01, 4.4908e-04],
        [6.7904e-01, 3.2096e-01],
        [5.2984e-01, 4.7016e-01],
        [9.9831e-01, 1.6864e-03],
        [9.9984e-01, 1.5849e-04],
        [7.2246e-02, 9.2775e-01],
        [9.9925e-01, 7.4831e-04],
        [3.8567e-01, 6.1433e-01],
        [2.7935e-02, 9.7206e-01],
        [6.6395e-01, 3.3605e-01]])

In [6]:
labels

tensor([0, 0, 0, 0, 0, 1, 0, 1, 1, 0])

In [None]:
import torch.nn.functional as F

In [None]:
F.cross_entropy(logits,labels)

tensor(0.1623)
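
To make explicit what `F.cross_entropy` computes, the sketch below reproduces it by hand. The tensors `example_logits` and `example_labels` are made up for illustration and are not the simulated data above. The steps are: take the log-softmax of each row, pick out the log-probability of the correct class, negate it, and average over rows.

```python
import torch
import torch.nn.functional as F

# Hypothetical two-class batch, made up for illustration
example_logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [-0.3, 0.8]])
example_labels = torch.tensor([0, 1, 0])

# Log-probabilities for each class, normalized across each row
log_probs = example_logits.log_softmax(dim=1)

# Select the log-probability assigned to the correct class of each row
correct_log_probs = log_probs[torch.arange(len(example_labels)), example_labels]

# Cross entropy is the mean negative log-probability of the correct class
manual_loss = -correct_log_probs.mean()
builtin_loss = F.cross_entropy(example_logits, example_labels)

print(manual_loss, builtin_loss)
```

The two printed values should agree, which is a quick way to check the by-hand version against the built-in one.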

The second option is to give your model only 1 output unit. In that case, you can use binary cross entropy as your loss function. Let's see a small example. In the cell below, we will make some logits and some labels. Here, each number in the logits and labels tensors represents a single item.

In [None]:
#Make some logits and labels, making sure not to get everything correct
logits=torch.tensor([2.8,-1.4,1.1,-.8])
labels=torch.tensor([1.,0.,0.,0.])

In [None]:
#view the logits
logits

tensor([ 2.8000, -1.4000,  1.1000, -0.8000])

Let's take a look at the normalized probability scores for each item. Since each value represents an individual item, we apply a sigmoid to each logit instead of a softmax across a row, so there is no dim to pass.

In [None]:
#Normalize to probabilities
logits.sigmoid()

tensor([0.9427, 0.1978, 0.7503, 0.3100])
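
The one-logit and two-logit setups are closely related: applying a sigmoid to a single logit z gives the same positive-class probability as a softmax over the pair [0, z]. The sketch below checks this, assuming the negative-class logit is pinned at 0. The variable names are illustrative only.

```python
import torch

# Single-output logits, one per item
single_logits = torch.tensor([2.8, -1.4, 1.1, -0.8])

# Equivalent two-logit view: negative class fixed at 0, positive class = z
two_logits = torch.stack([torch.zeros_like(single_logits), single_logits], dim=1)

# Positive-class probability from each formulation
p_sigmoid = single_logits.sigmoid()
p_softmax = two_logits.softmax(dim=1)[:, 1]

print(p_sigmoid)
print(p_softmax)
print(torch.allclose(p_sigmoid, p_softmax))
```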

Finally, let's use some torch methods to calculate the binary cross entropy loss.

In [None]:
#Calculate loss with the logits
F.binary_cross_entropy_with_logits(logits,labels)

tensor(0.5095)

In [None]:
from torch import nn

In [None]:
#Calculate the loss with the logits
nn.BCEWithLogitsLoss()(logits,labels)

tensor(0.5095)
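
For reference, the same number can be reproduced by hand. This is a minimal sketch of the standard binary cross entropy formula, not a look at PyTorch's internal implementation: each item contributes -log(p) when its label is 1 and -log(1 - p) when its label is 0, and the default reduction averages over items.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.8, -1.4, 1.1, -0.8])
labels = torch.tensor([1., 0., 0., 0.])

# Convert logits to probabilities of the positive class
probs = logits.sigmoid()

# Per-item loss: -log(p) when the label is 1, -log(1 - p) when the label is 0
per_item = -(labels * probs.log() + (1 - labels) * (1 - probs).log())

# The default reduction averages over items
manual_loss = per_item.mean()

print(per_item)
print(manual_loss)
```

Printing `per_item` shows that the third item, whose logit is positive while its label is 0, contributes the largest share of the loss.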

We can also calculate the loss from the probabilities. In this small example the result matches to the printed precision. In general, though, calculating cross entropy from logits is more numerically stable than calculating it from normalized probabilities, because the sigmoid can round to exactly 0 or 1 for logits of large magnitude.

In [None]:
F.binary_cross_entropy(logits.sigmoid(),labels)

tensor(0.5095)

In [None]:
nn.BCELoss()(logits.sigmoid(),labels)

tensor(0.5095)

The sanity check below confirms that the binary cross entropy computed from the logits and the one computed from the probabilities are close.

In [None]:
# Values are close...
torch.allclose(F.binary_cross_entropy_with_logits(logits, labels), F.binary_cross_entropy(logits.sigmoid(), labels))

True

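An exact equality check shows that, for these particular values, the loss calculated with the logits and the loss calculated with the normalized probabilities are in fact identical. This will not hold in general: for logits of large magnitude the two computations can diverge.

In [None]:
# ... and for these values, exactly equal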
F.binary_cross_entropy_with_logits(logits, labels) == F.binary_cross_entropy(logits.sigmoid(), labels)

tensor(True)

Finally, let's calculate the difference between the outputs of the two loss functions.

In [None]:
F.binary_cross_entropy_with_logits(logits, labels) - F.binary_cross_entropy(logits.sigmoid(), labels)

tensor(0.)
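
The moderate logits used here do not separate the two computations. A gap shows up with extreme values. The sketch below uses a hypothetical, deliberately large logit (50, chosen only for illustration) paired with the wrong label. In float32 the sigmoid rounds to exactly 1.0, so the probability-based loss can no longer recover how wrong the prediction was. PyTorch's documentation notes that `binary_cross_entropy` clamps its log outputs at -100, so the second value should come out capped rather than infinite, but either way it no longer matches the logit-based loss.

```python
import torch
import torch.nn.functional as F

# Hypothetical extreme case: a very confident prediction that is wrong
big_logit = torch.tensor([50.0])
label = torch.tensor([0.0])

# sigmoid(50) rounds to exactly 1.0 in float32, so 1 - p becomes 0
prob = big_logit.sigmoid()
print(prob)

# Working from the logit keeps the loss finite and accurate (about 50 here)
stable = F.binary_cross_entropy_with_logits(big_logit, label)

# Working from the rounded probability loses that information
naive = F.binary_cross_entropy(prob, label)

print(stable, naive)
print(torch.allclose(stable, naive))
```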

In this lesson, we reviewed binary cross entropy loss in depth. We also saw that the loss can be calculated either from logits or from normalized probabilities, and that calculating it from logits is the more numerically stable choice.