# Minibatch training

## Data

In [3]:
# Download the MNIST dataset
import gzip
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor
from pathlib import Path
from urllib.request import urlretrieve

MNIST_URL = "https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true"

# Local cache location for the downloaded archive
data_path = Path("data/mnist")
# parents=True so this also works when the "data" directory does not exist yet
data_path.mkdir(parents=True, exist_ok=True)
gz_path = data_path / "mnist.pkl.gz"

# Show images in grayscale by default
mpl.rcParams["image.cmap"] = "gray"

# Only download when there is no cached copy on disk
if not gz_path.exists():
    urlretrieve(MNIST_URL, gz_path)

# File contains a tuple of tuples for the x and y, train and validation data
# Images are 28x28
# NOTE: pickle.load can execute arbitrary code, so only load archives from a trusted source
with gzip.open(gz_path, "rb") as file:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(file, encoding="latin-1")

# Put into tensors
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

## Model

In [4]:
# Constants describing the data, read off the training tensors
n_examples, n_pixels = x_train.size()
# Labels start at 0, so the number of distinct values is the largest label plus one
possible_values = 1 + y_train.max()

# Number of units in the hidden layer
n_hidden = 50

n_examples, n_pixels, possible_values

(50000, 784, tensor(10))

In [5]:
from torch import nn
import torch.nn.functional as F


class Model(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        n_in: Number of input features per example.
        n_hidden: Number of units in the hidden layer.
        n_out: Number of output activations per example.
    """

    def __init__(self, n_in, n_hidden, n_out):
        super().__init__()
        # nn.ModuleList (rather than a plain list) registers the layers with the
        # module, so their parameters appear in parameters() and state_dict().
        # It is still iterable, so `for layer in model.layers` keeps working.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_out)]
        )

    def forward(self, inp):
        """Pass `inp` through each layer in order and return the final activations."""
        # Defining forward instead of overriding __call__ keeps nn.Module's own
        # call machinery intact; model(inp) still works exactly as before.
        res = inp
        for layer in self.layers:
            res = layer(res)

        return res

In [6]:
# Size the output layer from the number of classes found in the labels instead
# of repeating the literal 10; int() converts the 0-dim tensor to a plain int
model = Model(n_pixels, n_hidden, int(possible_values))
pred = model(x_train)
pred.shape

torch.Size([50000, 10])

## Cross entropy loss

We need a better loss function, because MSE doesn't make sense here: the numeric distance between an incorrect prediction and the target says nothing about how good or bad the prediction is. For example, if the target is 2, then 3 isn't a better guess than 9.

We output one pred for each possible digit. As only one digit can be correct at a time, our targets are effectively a one-hot encoded matrix. If each row of targets sums to 1, then it makes sense that each row of preds does too, so we take the softmax of each row of preds. For the loss we then compare the softmaxes to the one-hot encoded targets. Because the targets are one-hot, every term except the one for the correct class is multiplied by zero, so we can ignore the other classes and just take the log of the softmax at the target's index.

In [7]:
def log_softmax(x):
    """Naive log-softmax: softmax over the last dimension, then the log."""
    exps = x.exp()
    softmax = exps / exps.sum(-1, keepdim=True)
    return softmax.log()


log_softmax(pred)

tensor([[-2.2991, -2.4952, -2.2559,  ..., -2.3145, -2.4055, -2.0550],
        [-2.2196, -2.3798, -2.1516,  ..., -2.4936, -2.2933, -2.1178],
        [-2.2478, -2.3300, -2.3284,  ..., -2.4150, -2.3338, -2.1713],
        ...,
        [-2.2851, -2.4577, -2.2501,  ..., -2.3418, -2.3174, -2.1544],
        [-2.2231, -2.4281, -2.3472,  ..., -2.3487, -2.3482, -2.1276],
        [-2.2634, -2.4369, -2.1781,  ..., -2.3236, -2.4562, -2.1039]],
       grad_fn=<LogBackward0>)

In [8]:
# As log(a/b) = log(a) - log(b), we can simplify things to:


def log_softmax(x):
    """Log-softmax over the last dimension, using log(a / b) = log(a) - log(b)."""
    log_denominator = x.exp().sum(dim=-1, keepdim=True).log()
    return x - log_denominator


log_softmax(pred)

tensor([[-2.2991, -2.4952, -2.2559,  ..., -2.3145, -2.4055, -2.0550],
        [-2.2196, -2.3798, -2.1516,  ..., -2.4936, -2.2933, -2.1178],
        [-2.2478, -2.3300, -2.3284,  ..., -2.4150, -2.3338, -2.1713],
        ...,
        [-2.2851, -2.4577, -2.2501,  ..., -2.3418, -2.3174, -2.1544],
        [-2.2231, -2.4281, -2.3472,  ..., -2.3487, -2.3482, -2.1276],
        [-2.2634, -2.4369, -2.1781,  ..., -2.3236, -2.4562, -2.1039]],
       grad_fn=<SubBackward0>)

In [9]:
# It's possible for the sum of the exponentials of big activations to overflow.
# PyTorch uses some tricks to avoid this and wraps them up in one function, logsumexp
def log_softmax(x):
    """Log-softmax over the last dimension, built on the tensor's logsumexp method."""
    log_norm = x.logsumexp(dim=-1, keepdim=True)
    return x - log_norm


sm_pred = log_softmax(pred)
pred

tensor([[ 0.0589, -0.1371,  0.1022,  ...,  0.0436, -0.0474,  0.3030],
        [ 0.1145, -0.0457,  0.1824,  ..., -0.1596,  0.0408,  0.2162],
        [ 0.0590, -0.0232, -0.0216,  ..., -0.1082, -0.0269,  0.1356],
        ...,
        [ 0.0561, -0.1165,  0.0911,  ..., -0.0005,  0.0239,  0.1868],
        [ 0.1206, -0.0844, -0.0035,  ..., -0.0050, -0.0045,  0.2161],
        [ 0.0676, -0.1058,  0.1530,  ...,  0.0074, -0.1251,  0.2272]],
       grad_fn=<AddmmBackward0>)

In [10]:
# How do we index into the preds we want
# The y values tell us targets, which are also the index
y_train[:5]

tensor([5, 0, 4, 1, 9])

In [11]:
# So for each i we want row i and column y_train[i], e.g. for the pred at row 0:
sm_pred[0, y_train[0]]

tensor(-2.4828, grad_fn=<SelectBackward0>)

In [21]:
# We can do this for all of them like this
sm_pred[range(y_train.shape[0]), y_train]

tensor([-2.4828, -2.2196, -2.1388,  ..., -2.3174, -2.1172, -2.4562],
       grad_fn=<IndexBackward0>)

In [17]:
# So our cross entropy loss function looks like this
def nll(inp, target):
    """Mean negative log likelihood.

    For every row i this picks inp[i, target[i]] and returns minus the mean
    of the picked values.
    """
    rows = range(target.size(0))
    picked = inp[rows, target]
    return -picked.mean()


loss = nll(sm_pred, y_train)
loss

tensor(2.3149, grad_fn=<NegBackward0>)

In [22]:
# pytorch gives us both of these functions, NLL = negative log likelihood
F.nll_loss(F.log_softmax(pred, -1), y_train)

tensor(2.3149, grad_fn=<NllLossBackward0>)

In [23]:
# And a cross entropy function to do it all in one
F.cross_entropy(pred, y_train)

tensor(2.3149, grad_fn=<NllLossBackward0>)

## Training loop

In [27]:
# Loss used for training, and the number of examples per minibatch
batch_size = 64
loss_func = F.cross_entropy

# Take the first minibatch: the inputs and their matching targets
mini_batch, mini_batch_y = x_train[:batch_size], y_train[:batch_size]

# Forward pass on just that minibatch
preds = model(mini_batch)
preds[0], preds.shape

(tensor([ 0.0589, -0.1371,  0.1022,  0.1158,  0.2218, -0.1247, -0.0787,  0.0436,
         -0.0474,  0.3030], grad_fn=<SelectBackward0>),
 torch.Size([64, 10]))

In [28]:
# Then we would calculate the loss
loss_func(preds, mini_batch_y)

tensor(2.2940, grad_fn=<NllLossBackward0>)

In [31]:
# To see which digit we predicted for each example, take the index of the highest value in each row of preds
preds.argmax(dim=1)

tensor([9, 9, 4, 4, 9, 4, 4, 9, 3, 9, 9, 9, 9, 9, 4, 2, 9, 4, 9, 4, 3, 2, 4, 4,
        9, 9, 4, 9, 0, 3, 3, 2, 9, 4, 9, 4, 9, 9, 9, 9, 9, 9, 9, 4, 3, 3, 4, 9,
        9, 9, 9, 3, 4, 4, 9, 9, 9, 9, 9, 9, 3, 4, 9, 4])

In [34]:
# We can see which we got correct
preds.argmax(dim=1) == mini_batch_y

tensor([False, False,  True, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False, False,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False,  True, False,
        False, False, False,  True,  True, False, False,  True, False, False,
        False,  True, False, False])

In [37]:
# We can use it to calculate accuracy. This isn't needed to train the NN, but it helps us understand what's going on
def accuracy(out, targets):
    """Fraction of rows in `out` whose highest-valued column index equals the target."""
    predicted = out.argmax(dim=1)
    hits = predicted == targets
    return hits.float().mean()


accuracy(preds, mini_batch_y)

tensor(0.1562)

In [54]:
# Train over every minibatch for several epochs, updating the weights after each batch
import torch

lr = 0.5
epochs = 5

for epoch in range(epochs):
    for batch_start in range(0, n_examples, batch_size):
        # The last batch may be shorter, so cap the end at the dataset size
        batch_end = min(batch_start + batch_size, n_examples)
        sl = slice(batch_start, batch_end)
        batch_inp, batch_targets = x_train[sl], y_train[sl]

        # Forward pass and loss
        preds = model(batch_inp)
        loss = loss_func(preds, batch_targets)

        # Report once per epoch, on the first batch
        if batch_start == 0:
            print("loss:", loss.item(), "acc: ", accuracy(preds, batch_targets).item())

        loss.backward()

        # Gradient descent step; no_grad so the update itself is not tracked by autograd
        with torch.no_grad():
            for layer in model.layers:
                # Layers without a weight attribute have nothing to update
                if not hasattr(layer, "weight"):
                    continue
                for param in (layer.weight, layer.bias):
                    param -= param.grad * lr
                    # Reset the gradient so the next backward() starts from zero
                    param.grad.zero_()

loss: 1.6791051626205444 acc:  0.46875
loss: 0.13027404248714447 acc:  0.953125
loss: 0.10441920906305313 acc:  0.921875
loss: 0.0580277256667614 acc:  0.96875
loss: 0.07098586112260818 acc:  0.984375
