# The Base Classification Model

In [1]:
import torch
from d2l import torch as d2l

In [4]:
class Classifier(d2l.Module):
    """The base class of classification models"""

    def validation_step(self, batch):
        Y_hat = self(*batch[:-1])
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('accuracy', self.accuracy(Y_hat, batch[-1]), train=False)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def accuracy(self, Y_hat, Y, averaged=True):
        # Collapse any leading dimensions so Y_hat has shape (num_examples, num_classes)
        Y_hat = Y_hat.reshape(-1, Y_hat.shape[-1])
        # Take the index of the highest-scoring class and cast it to the dtype of Y
        preds = Y_hat.argmax(axis=1).type(Y.dtype)
        compare = (preds == Y.reshape(-1)).type(torch.float32)
        return compare.mean() if averaged else compare
        

## Accuracy

Although a model may estimate class probabilities internally, its final prediction must be a hard category. The _accuracy_ of the model is simply the fraction of predictions that it gets correct. Accuracy is hard to optimize directly because it is not differentiable: it is piecewise constant in the model's outputs, so its gradient is zero almost everywhere. Nevertheless, it is often _the_ metric we care about most when assessing a model.
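
To make the definition concrete, here is a minimal sketch in plain Python (no tensors) that computes accuracy by hand. The helper names `argmax` and `accuracy` and the toy numbers are assumptions made up for illustration; they are not part of the `Classifier` class above. The second call suggests why accuracy gives no useful gradient: nudging the probabilities without changing which class scores highest leaves the accuracy unchanged.

```python
def argmax(row):
    """Return the index of the largest entry in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(probs, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    preds = [argmax(row) for row in probs]
    correct = [float(p == y) for p, y in zip(preds, labels)]
    return sum(correct) / len(correct)


# Toy data (hypothetical): four examples, two classes.
labels = [0, 1, 1, 0]
probs = [
    [0.70, 0.30],  # predicts class 0, correct
    [0.40, 0.60],  # predicts class 1, correct
    [0.80, 0.20],  # predicts class 0, wrong (label is 1)
    [0.55, 0.45],  # predicts class 0, correct
]
acc = accuracy(probs, labels)

# Shift every probability a little without changing any row's top class.
nudged = [
    [0.60, 0.40],
    [0.45, 0.55],
    [0.70, 0.30],
    [0.51, 0.49],
]
acc_nudged = accuracy(nudged, labels)

print(acc, acc_nudged)
```

Both calls give the same value because each row's top-scoring class is the same in `probs` and `nudged`. A differentiable surrogate such as the cross-entropy loss, by contrast, would change between the two inputs, which is why training typically minimizes a loss while accuracy is only reported.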