# Concise Implementation of Softmax Regression

Just as PyTorch made it much easier to implement linear regression in notebook ``concise_implementation_of_linear_regression``, we will find it similarly convenient for implementing classification models.

Let's import the necessary packages.

In [None]:
import torch
import random
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from IPython import display
import d2l
import matplotlib.pyplot as plt

Let’s stick with the Fashion-MNIST dataset and keep the batch size at 256.
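A quick arithmetic sketch of what batch size 256 means per epoch. It assumes the standard Fashion-MNIST split of 60,000 training and 10,000 test images and PyTorch's `DataLoader` default `drop_last=False`, which keeps a smaller final batch. That leftover batch is why `train_epoch` below weights each batch's mean loss by `len(y)` rather than by `batch_size`. The helper names are hypothetical.

```python
import math

# Assumed Fashion-MNIST split sizes: 60,000 training and 10,000 test images
num_train, num_test = 60_000, 10_000
batch_size = 256

def batches_per_epoch(n, batch_size):
    """Number of minibatches when the last, smaller batch is kept (drop_last=False)."""
    return math.ceil(n / batch_size)

def last_batch_size(n, batch_size):
    """Size of the final minibatch; equals batch_size when n divides evenly."""
    r = n % batch_size
    return r if r else batch_size

print(batches_per_epoch(num_train, batch_size), last_batch_size(num_train, batch_size))  # 235 96
print(batches_per_epoch(num_test, batch_size), last_batch_size(num_test, batch_size))    # 40 16
```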

In [None]:
def load_data_fashion_mnist(batch_size, resize=None, num_workers=0):
    """Download Fashion-MNIST and return (train_iter, test_iter) DataLoaders."""
    transform_list = []
    if resize:
        transform_list.append(transforms.Resize(resize))
    # ToTensor turns each PIL image into a (1, H, W) float tensor with values in [0, 1]
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    mnist_train = datasets.FashionMNIST('~/datasets/F_MNIST/', download=True,
                                        train=True, transform=transform)
    mnist_test = datasets.FashionMNIST('~/datasets/F_MNIST/', download=True,
                                       train=False, transform=transform)
    # Shuffle only the training set; evaluation order does not matter
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                             shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
                                            shuffle=False, num_workers=num_workers)
    return train_iter, test_iter

In [None]:
batch_size = 256

train_iter, test_iter = # insert your code here

## Define Model

In [None]:
input_dim = # insert your code here
output_dim = # insert your code here
net = # insert your code here

## The Softmax

In notebook ``implementation_of_softmax_from_scratch``, we calculated our model's output and then ran this output through the cross-entropy
loss. At its heart it uses ``torch.mean(-torch.log(y_hat.gather(1, y.view(-1, 1))))``. Mathematically, that's a perfectly reasonable thing
to do. Computationally, however, exponentiation can cause numerical stability problems (see, e.g., Section 4.5). Recall that the softmax function calculates $\hat{y}_j = \frac{e^{z_j}}{\sum_{i=1}^n e^{z_i}}$, where $\hat{y}_j$ is the j-th element
of the softmax output $\hat{y}$ and $z_j$ is the j-th element of the input ``y_linear`` variable. If some $z_i$ is large, $e^{z_i}$ overflows to infinity; if $z_j$ is much smaller than the largest $z_i$, $\hat{y}_j$ can underflow to $0$, and its logarithm becomes $-\infty$.
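A short worked step shows why the computation can be made stable at all: subtracting any constant $c$ from every $z_i$ leaves the softmax unchanged, because the factor $e^{-c}$ cancels between numerator and denominator. Choosing $c = \max_i z_i$ is the usual choice and is assumed here.

$$\hat{y}_j = \frac{e^{z_j}}{\sum_{i=1}^n e^{z_i}} = \frac{e^{-c}\, e^{z_j}}{e^{-c} \sum_{i=1}^n e^{z_i}} = \frac{e^{z_j - c}}{\sum_{i=1}^n e^{z_i - c}}$$

With $c = \max_i z_i$, every exponent $z_i - c$ is at most $0$, so no term can overflow, and the largest term equals $e^0 = 1$, so the denominator is at least $1$. Taking logarithms gives the shifted log-sum-exp $\log \left(\sum_{i=1}^n e^{z_i}\right) = c + \log \left(\sum_{i=1}^n e^{z_i - c}\right)$.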

Our salvation is that even though we're computing these exponential functions, we ultimately plan to take
their log in the cross-entropy loss. By combining the two operators, softmax and cross-entropy, we can
escape the numerical stability issues that might otherwise plague us during backpropagation. As the
equation below shows, because the logarithm undoes the exponential in the numerator, we never need to
compute $e^{z_j}$ and can use $z_j$ directly:

$$\log (\hat{y}_j) = \log \left(\frac{e^{z_j}}{\sum_{i=1}^n e^{z_i}} \right) = z_j - \log \left(\sum_{i=1}^n e^{z_i} \right).$$
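A minimal pure-Python sketch of the equation above, assuming the max-subtraction trick for the log-sum-exp term (a standard technique, not something this notebook defines). `naive_log_softmax`, `stable_log_softmax`, and `cross_entropy` are hypothetical helper names for illustration only. The naive version exponentiates first, so logits around 1000 overflow; the stable version returns the same log-probabilities as for the same logits shifted down by 1000.

```python
import math

def naive_log_softmax(z):
    """log(softmax(z)) computed literally: exponentiate, normalize, then take the log."""
    exps = [math.exp(v) for v in z]  # math.exp raises OverflowError for v above ~709
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def stable_log_softmax(z):
    """log(softmax(z)) = z_j - logsumexp(z), with logsumexp shifted by max(z)."""
    m = max(z)
    # Every exponent v - m is <= 0, so math.exp cannot overflow here
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

def cross_entropy(z, label):
    """Cross-entropy loss of a single example: -log(softmax(z))[label]."""
    return -stable_log_softmax(z)[label]

small = [1.0, 2.0, 3.0]
print(naive_log_softmax(small))
print(stable_log_softmax(small))

large = [1000.0, 1001.0, 1002.0]
try:
    naive_log_softmax(large)
except OverflowError as err:
    print("naive version failed:", err)
print(stable_log_softmax(large))  # matches `small` up to rounding
print(cross_entropy(large, 2))
```

Python's `math.exp` raises `OverflowError` on overflow, whereas PyTorch tensors silently produce `inf` and then `nan`; the failure mode differs, but the cause is the same.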

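PyTorch already implements this trick in its [``CrossEntropyLoss``](https://pytorch.org/docs/0.3.1/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss), which combines log-softmax and the negative log-likelihood loss in a single step. It therefore expects the raw outputs $z$ of the linear layer (the logits), not softmax probabilities, so the model itself should not apply a softmax.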
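The sketch below contrasts the from-scratch expression with ``CrossEntropyLoss`` on deliberately large, hypothetical logits for two examples and three classes. The naive route applies softmax first and then takes the log, exactly like the expression quoted earlier; the loss object receives the logits directly.

```python
import torch
from torch import nn

# Hypothetical logits (2 examples, 3 classes); the first row is large on purpose
toy_logits = torch.tensor([[1000.0, 1001.0, 1002.0],
                           [50.0, 0.0, -50.0]])
toy_labels = torch.tensor([2, 0])

# Naive route: softmax first, then -log of the true-class probability
toy_y_hat = torch.exp(toy_logits) / torch.exp(toy_logits).sum(dim=1, keepdim=True)
naive_loss = torch.mean(-torch.log(toy_y_hat.gather(1, toy_labels.view(-1, 1))))

# CrossEntropyLoss takes the raw logits and applies log-softmax internally
stable_loss = nn.CrossEntropyLoss()(toy_logits, toy_labels)

print(naive_loss)   # nan: exp(1000.) overflows to inf in float32, and inf / inf is nan
print(stable_loss)  # finite, roughly 0.2038
```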

In [None]:
loss = # insert your code here

## Optimization Algorithm

We use minibatch stochastic gradient descent with a learning rate of 0.1 as the optimization algorithm.
Note that this is the same choice as for linear regression, which illustrates the general applicability of the
optimizers: they only need the model's parameters and the gradients computed for them.
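A minimal sketch of one plain SGD step with `torch.optim.SGD` on a hypothetical two-element parameter and a toy loss $\sum_k w_k^2$ (gradient $2w$). It checks that the step performs $w \leftarrow w - \eta \nabla_w \ell$; the `toy_` names avoid clobbering the notebook's own `loss`, `lr`, and `optimizer`.

```python
import torch
from torch import optim

# Toy parameter and loss, used only to inspect a single optimizer step
toy_w = torch.tensor([1.0, -2.0], requires_grad=True)
step_size = 0.1
toy_opt = optim.SGD([toy_w], lr=step_size)

toy_loss = (toy_w ** 2).sum()  # gradient is 2 * toy_w = [2., -4.]
toy_opt.zero_grad()
toy_loss.backward()
expected_w = (toy_w - step_size * toy_w.grad).detach()  # plain SGD rule
toy_opt.step()

print(toy_w.detach(), expected_w)  # both approximately [0.8, -1.6]
```

Because the optimizer only touches tensors passed to it and their `.grad` fields, the same object works for a linear regression model or for the softmax regression model here.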

In [None]:
lr = 0.1
optimizer = # insert your code here

## Training

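Next, we use the training functions defined in notebook ``implementation_of_softmax_regression_from_scratch`` to train a model. Note that we made a slight modification to the ``train_epoch`` and ``evaluate_accuracy`` functions: here both flatten each image batch into a matrix with one row per image (784 values for a 28x28 image) before passing it to ``net``, and ``train_epoch`` updates the parameters through the optimizer (``optimizer.zero_grad()`` and ``optimizer.step()``).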
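The reshaping that `train_epoch` and `evaluate_accuracy` apply to each batch, and the arg-max comparison inside `accuracy`, can be checked on fake tensors. The shapes below assume the `(batch, channels, height, width)` layout produced by `ToTensor()` for 28x28 grayscale images; all values are made up.

```python
import torch

# Fake minibatch shaped like Fashion-MNIST after ToTensor(): (4, 1, 28, 28)
fake_X = torch.rand(4, 1, 28, 28)
flat_X = fake_X.view(-1, fake_X.shape[2] * fake_X.shape[3])
print(flat_X.shape)  # torch.Size([4, 784])

# Made-up predictions for 4 examples over 3 classes, and their labels
fake_y_hat = torch.tensor([[0.1, 0.7, 0.2],
                           [0.8, 0.1, 0.1],
                           [0.3, 0.3, 0.4],
                           [0.2, 0.5, 0.3]])
fake_y = torch.tensor([1, 0, 0, 1])
# Same comparison that accuracy() performs: predicted class == label
num_correct = (fake_y_hat.argmax(dim=1) == fake_y).sum()
print(num_correct)  # tensor(3)
```

Flattening with `X.shape[2] * X.shape[3]` relies on the single grayscale channel; `X.view(X.shape[0], -1)` would also handle multi-channel images.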

In [None]:
def accuracy(y_hat, y):
    """Count predictions whose arg-max class matches the label."""
    return (y_hat.argmax(dim=1) == y).sum()

class Accumulator(object):
    """Sum a list of numbers over time"""
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        self.data = [a+b for a, b in zip(self.data, args)]
    def reset(self):
        self.data = [0] * len(self.data)
    def __getitem__(self, i):
        return self.data[i]
    
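def evaluate_accuracy(net, data_iter):
    """Fraction of correctly classified examples in data_iter."""
    metric = Accumulator(2)  # num_correct, num_examples
    with torch.no_grad():  # evaluation needs no gradients
        for X, y in data_iter:
            # Flatten each image into a row vector before the linear layer
            metric.add(float(accuracy(net(X.view(-1, X.shape[2]*X.shape[3])), y)), len(y))
    return metric[0] / metric[1]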

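def train_epoch(net, train_iter, loss, optimizer):
    """Train net for one epoch; return (average loss, accuracy) on the training set."""
    metric = Accumulator(3)  # train_loss_sum, train_acc_sum, num_examples
    for X, y in train_iter:
        # Flatten each image into a row vector before the linear layer
        y_hat = net(X.view(-1, X.shape[2]*X.shape[3]))
        # Compute gradients and update parameters
        l = loss(y_hat, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        # With the default reduction='mean', l is a batch mean; scale it to a batch sum
        l_sum = float(l) * len(y)
        metric.add(l_sum, float(accuracy(y_hat, y)), len(y))
    return metric[0] / metric[2], metric[1] / metric[2]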

class Animator(object):
    def __init__(self, xlabel=None, ylabel=None, legend=[], xlim=None,
                 ylim=None, xscale='linear', yscale='linear', fmts=None,
                 nrows=1, ncols=1, figsize=(3.5, 2.5)):
        """Incrementally plot multiple lines."""
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1: self.axes = [self.axes,]
        # use a lambda to capture arguments
        self.config_axes = lambda : d2l.set_axes(self.axes[0], xlabel, ylabel,
                                                 xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts
    def add(self, x, y):
        """Add multiple data points into the figure."""
        if not hasattr(y, "__len__"): y = [y]
        n = len(y)
        if not hasattr(x, "__len__"): x = [x] * n
        if not self.X: self.X = [[] for _ in range(n)]
        if not self.Y: self.Y = [[] for _ in range(n)]
        if not self.fmts: self.fmts = ['-'] * n
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)
        
def train(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train for num_epochs, plotting training loss, training accuracy, and test accuracy."""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs],
                        ylim=[0.3, 0.9], legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))

In [None]:
num_epochs = 10
# insert your code here