# Multi-class neural networks

Earlier, we encountered binary classification models that could pick between one of two possible choices, such as whether:

- A given email is spam or not spam.
- A given tumor is malignant or benign.

In this module, we'll investigate multi-class classification, which can pick from multiple possibilities. For example:

- Is this dog a beagle, a basset hound, or a bloodhound?
- Is this flower a Siberian Iris, Dutch Iris, Blue Flag Iris, or Dwarf Bearded Iris?
- Is that plane a Boeing 747, Airbus 320, Boeing 777, or Embraer 190?
- Is this an image of an apple, bear, candy, dog, or egg?

Some real-world multi-class problems entail choosing from millions of separate classes.

## One vs. All

One vs. all provides a way to leverage binary classification. Given a classification problem with $N$ possible solutions, a one-vs-all solution consists of $N$ separate binary classifiers — one binary classifier for each possible outcome. During training, the model runs through a sequence of binary classifiers, training each to answer a separate classification question. For example, given a picture of a dog, five different recognizers might be trained, four seeing the image as a negative example (not a dog but something else) and one seeing the image as a positive example (a dog).

This approach is fairly reasonable when the total number of classes is small, but becomes increasingly inefficient as the number of classes rises.

We can create a significantly more efficient one-vs-all model with a deep neural network in which each output node represents a different class.

## Softmax

Recall that logistic regression produces a decimal between 0 and 1.0. For example, a logistic regression output of 0.8 from an email classifier suggests an 80% chance of an email being spam and a 20% chance of it being not spam. Clearly, the sum of the probabilities of an email being either spam or not spam is 1.0.

`Softmax` extends this idea into a multi-class world. That is, Softmax assigns decimal probabilities to each class in a multi-class problem. Those decimal probabilities **must add up to 1.0**. This additional constraint helps training converge more quickly than it otherwise would.

Softmax is implemented through a neural network layer just before the output layer. The Softmax layer must have the same number of nodes as the output layer.

Softmax assumes that each example is a member of exactly one class, i.e. `single-label, multi-class classification`. Some examples, however, can simultaneously be a member of multiple classes. In the `multi-label classification` case, we can't use softmax and must rely on multiple logistic regressions.

We can use `candidate sampling` to improve the efficiency of softmax when there's a large number of classes. Candidate sampling means that Softmax calculates a probability for all the positive labels but **only for a random sample of negative labels**. For example, if we are interested in determining whether an input image is a beagle or a bloodhound, we don't have to provide probabilities for every non-doggy example.

# MNIST

MNIST is a classic dataset of 60000 hand-written digits. Each example contains

- Label: integer between 0 and 9.
- Image: a $28 \times 28$ map, where each pixel is an integer between 0 and 255. Pixel values are on a grayscale.

This is a multi-class classification problem with 10 classes.

In [1]:
%load_ext lab_black

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytorch_lightning.metrics.sklearns as plm

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

In [3]:
class MNISTModel(pl.LightningModule):
    def __init__(self, hparams, *args, **kwargs):
        super().__init__()
        self.hparams = hparams

        self.l1 = nn.Linear(28 * 28, 256)
        self.l2 = nn.Linear(256, 128)
        self.l_drop = nn.Dropout(p=self.hparams.dropout_rate)
        self.l3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.l1(x.view(x.size(0), -1)))
        x = F.relu(self.l2(x))
        x = self.l_drop(x)
        x = self.l3(x)
        return x

    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)  # or just self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)

        y_pred = F.softmax(y_hat, dim=1).argmax(dim=1)

        logs = {"train_loss": loss, "train_accuracy": plm.Accuracy()(y_pred, y)}
        return {"loss": loss, "log": logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_pred = F.softmax(y_hat, dim=1).argmax(dim=1)
        return {
            "val_loss": nn.CrossEntropyLoss()(y_hat, y),
            "val_accuracy": plm.Accuracy()(y_pred, y),
        }

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        logs = {"val_loss": avg_loss, "val_accuracy": avg_accuracy}
        return {"avg_val_loss": avg_loss, "log": logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_pred = F.softmax(y_hat, dim=1).argmax(dim=1)
        return {
            "test_loss": nn.CrossEntropyLoss()(y_hat, y),
            "test_accuracy": plm.Accuracy()(y_pred, y),
        }

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["test_accuracy"] for x in outputs]).mean()
        logs = {"test_loss": avg_loss, "test_accuracy": avg_accuracy}
        return {"avg_test_loss": avg_loss, "log": logs}

    def train_dataloader(self):
        # 0.1307 is the training set overall mean / 255
        # 0.3081 is the overall standard deviation / 255
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        mnist_train = MNIST(
            os.getcwd(), train=True, download=False, transform=transform
        )
        dat_train, _ = random_split(mnist_train, [55000, 5000])
        return DataLoader(dat_train, batch_size=self.hparams.batch_size, num_workers=8)

    def val_dataloader(self):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        mnist_train = MNIST(
            os.getcwd(), train=True, download=False, transform=transform
        )
        _, dat_val = random_split(mnist_train, [55000, 5000])
        return DataLoader(dat_val, batch_size=self.hparams.batch_size, num_workers=4)

    def test_dataloader(self):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        dat_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)
        return DataLoader(dat_test, batch_size=self.hparams.batch_size, num_workers=8)

In [4]:
# Hyperparameters
hparams = {"learning_rate": 0.03, "batch_size": 4000, "dropout_rate": 0.2}
epochs = 50

# Train model
mnist_trainer = pl.Trainer(gpus=1, max_epochs=epochs)
mnist_model = MNISTModel(hparams)

mnist_trainer.fit(mnist_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type    | Params
-----------------------------------
0 | l1     | Linear  | 200 K 
1 | l2     | Linear  | 32 K  
2 | l_drop | Dropout | 0     
3 | l3     | Linear  | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [5]:
mnist_trainer.test()

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_loss': tensor(0.2007, device='cuda:0'),
 'test_accuracy': tensor(0.9637),
 'test_loss': tensor(0.2007, device='cuda:0')}
--------------------------------------------------------------------------------



{'avg_test_loss': 0.20070426166057587,
 'test_loss': 0.20070426166057587,
 'test_accuracy': 0.9636666774749756}