# The Base Classification Model
:label:`sec_classification`

## Section Summary
This section defines `Classifier`, a base class for classification models, together with a method that computes a model's accuracy. Its `validation_step` method reports both the loss and the classification accuracy on a validation batch. The code also sets stochastic gradient descent (SGD) as the default optimizer and uses `argmax` to select the class with the highest predicted probability. Accuracy is the fraction of predictions that are correct. The text also provides exercises to test understanding of the material.
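Written out as a formula (the notation is introduced here for illustration): for $n$ examples, let $\hat{y}_{ij}$ be the predicted probability that example $i$ belongs to class $j$, and let $y_i$ be its label. The accuracy is the average of an indicator that equals 1 when the highest-probability class matches the label:

$$\textrm{accuracy} = \frac{1}{n} \sum_{i=1}^{n} \mathbf{1}\left(\mathop{\mathrm{argmax}}_{j} \hat{y}_{ij} = y_i\right).$$

For instance, 3 correct predictions out of 4 give an accuracy of $3/4 = 0.75$.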

```python
import torch
from d2l import torch as d2l
```

## The `Classifier` Class

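In `validation_step`, we report both the loss value and the classification accuracy on a validation batch. By convention, the last element of `batch` holds the labels and all preceding elements are inputs to the model.

```python
class Classifier(d2l.Module):  #@save
    """The base class of classification models."""
    def validation_step(self, batch):
        # All but the last element of the batch are model inputs
        Y_hat = self(*batch[:-1])
        # The last element holds the labels; log validation loss and accuracy
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
```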
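To see what a validation step computes without the `d2l` plotting machinery, here is a standalone sketch. This section does not define `self.loss`, so the sketch assumes a cross-entropy loss (`torch.nn.functional.cross_entropy`) and a toy `nn.Linear` model; neither is part of the code above. The helper `validation_metrics` is hypothetical: it follows the same batch convention as `validation_step` (inputs first, labels last) but returns the two numbers instead of plotting them, and it disables gradient tracking because evaluation needs no gradients.

```python
import torch
from torch import nn
import torch.nn.functional as F

def validation_metrics(model, batch):
    """Hypothetical stand-in for Classifier.validation_step that returns
    (loss, accuracy) instead of plotting them."""
    with torch.no_grad():  # no gradients are needed for evaluation
        # Same convention as validation_step: inputs first, labels last
        Y_hat = model(*batch[:-1])
        Y = batch[-1]
        # Assumption: cross-entropy loss computed on raw scores (logits)
        loss = F.cross_entropy(Y_hat, Y)
        # Fraction of rows whose highest score is at the label's index
        acc = (Y_hat.argmax(dim=1) == Y).float().mean()
    return loss.item(), acc.item()

torch.manual_seed(0)
model = nn.Linear(4, 3)          # toy classifier: 4 features, 3 classes
X = torch.randn(8, 4)            # a validation batch of 8 examples
Y = torch.randint(0, 3, (8,))    # integer class labels in {0, 1, 2}
loss, acc = validation_metrics(model, (X, Y))
print(f'val loss {loss:.3f}, val acc {acc:.3f}')
```

Because the model is untrained, the accuracy here is roughly what random guessing over three classes would give; only its range and the loss's sign are meaningful.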

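The cells that follow attach methods to existing classes with the `d2l.add_to_class` decorator, which lets a class be extended in later cells. A minimal sketch of the idea, assuming the decorator simply registers the function as a class attribute with `setattr`, is shown below; `add_to_class` here is a local re-implementation and `Counter` is a hypothetical example class.

```python
def add_to_class(Class):
    """Sketch (assumption): register the decorated function as a method of Class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj  # keep the module-level name usable as well
    return wrapper

class Counter:
    def __init__(self):
        self.n = 0

c = Counter()  # created before the method exists

@add_to_class(Counter)
def increment(self, k=1):
    self.n += k
    return self.n

c.increment()
print(c.increment(2))  # 3
```

Since Python looks up methods on the class at call time, even the instance created before decoration gains the new method.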
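Because `configure_optimizers` is added to `d2l.Module` rather than to `Classifier`, every model derived from `d2l.Module` uses, by default, a stochastic gradient descent optimizer over all of its parameters with learning rate `self.lr`.

```python
@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    # Plain SGD over every registered parameter, using the model's learning rate
    return torch.optim.SGD(self.parameters(), lr=self.lr)
```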
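With its default arguments (no momentum, no weight decay), `torch.optim.SGD` updates each parameter as $w \leftarrow w - \eta \, \partial L / \partial w$, where $\eta$ is the learning rate. The sketch below checks this on a single toy parameter; the loss $L(w) = w^2$ and the learning rate 0.1 are illustrative choices, not values used elsewhere in this section.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)  # a single toy parameter
lr = 0.1                                     # illustrative learning rate
optimizer = torch.optim.SGD([w], lr=lr)

loss = (w ** 2).sum()  # L(w) = w^2, so dL/dw = 2w = 2.0 at w = 1
optimizer.zero_grad()  # clear any previously accumulated gradients
loss.backward()        # populate w.grad
optimizer.step()       # w <- w - lr * w.grad = 1.0 - 0.1 * 2.0

print(w.item())  # approximately 0.8
```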

## Accuracy

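Given the predicted probabilities `Y_hat`, we take the class with the highest predicted probability in each row as the prediction, using `argmax`. Accuracy is the fraction of these predictions that match the labels `Y`. With `averaged=False`, the method instead returns a per-example tensor of 1s (correct) and 0s (incorrect).

```python
@d2l.add_to_class(Classifier)  #@save
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the fraction of correct predictions (or per-example correctness)."""
    # Flatten leading dimensions so each row holds one example's class scores
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    # The predicted class is the index of the largest score in each row
    preds = Y_hat.argmax(dim=1).type(Y.dtype)
    # 1.0 where the prediction matches the label, 0.0 otherwise
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare
```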
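To try the accuracy logic on concrete numbers without constructing a `Classifier`, the sketch below copies the same steps into a free function (dropping `self`); the probabilities and labels are made up for illustration. Rows 0 and 1 predict the correct class, while row 2 predicts class 1 for a label of 0.

```python
import torch

def accuracy(Y_hat, Y, averaged=True):
    """Standalone copy of the Classifier.accuracy logic, without `self`."""
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(dim=1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

# Three examples, two classes; each row holds predicted probabilities
Y_hat = torch.tensor([[0.1, 0.9],
                      [0.8, 0.2],
                      [0.3, 0.7]])
Y = torch.tensor([1, 0, 0])  # true labels

print(accuracy(Y_hat, Y, averaged=False))  # tensor([1., 1., 0.])
print(accuracy(Y_hat, Y))                  # 2 of 3 correct, about 0.6667
# Leading dimensions are flattened, so a (1, 3, 2) batch gives the same result
print(accuracy(Y_hat.reshape(1, 3, 2), Y.reshape(1, 3)))
```

The `reshape` at the start is what allows the same method to handle predictions that carry extra leading dimensions, such as a batch of sequences.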