In [1]:
# I am trying to run the first model in https://pytorch.org/tutorials/beginner/nn_tutorial.html using GPU

In [2]:
from pathlib import Path
import requests
import math

# Local cache directory for the pickled MNIST archive.
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

# NOTE(review): the original host, http://deeplearning.net/data/mnist/, is no
# longer reliable; the current PyTorch tutorial serves the same archive from
# GitHub. Confirm the mirror if the download ever fails.
URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"

# Download only once; later runs reuse the cached copy.
if not (PATH / FILENAME).exists():
    response = requests.get(URL + FILENAME, timeout=30)
    # Fail loudly on an HTTP error instead of caching an error page as the dataset.
    response.raise_for_status()
    # write_bytes opens, writes and closes the file in one call, so no file
    # handle is left open.
    (PATH / FILENAME).write_bytes(response.content)

In [3]:
import pickle
import gzip

# Unpickle the cached archive into training and validation splits; the third
# element of the pickled tuple is discarded.
# NOTE: pickle can execute arbitrary code while loading, so only open an
# archive obtained from a trusted source.
mnist_archive = PATH / FILENAME
with gzip.open(mnist_archive.as_posix(), "rb") as archive:
    train_split, valid_split, _ = pickle.load(archive, encoding="latin-1")
x_train, y_train = train_split
x_valid, y_valid = valid_split

In [4]:
from matplotlib import pyplot
import numpy as np

# Show the first training row as a 28x28 grayscale image, then report the
# shape of the whole training array.
first_digit = x_train[0].reshape(28, 28)
pyplot.imshow(first_digit, cmap="gray")
print(x_train.shape)

(50000, 784)


In [5]:
import torch

# Convert the loaded training/validation data into torch tensors.
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
# n = number of training rows (drives the batching in fit()); c = values per row.
n, c = x_train.shape
# Sanity checks. Only print() is used: a bare expression in the middle of a
# cell is evaluated but never displayed, so it would be dead work here.
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)


In [6]:
# Use the GPU when CUDA is available and fall back to the CPU otherwise, so
# the notebook runs unchanged on machines with or without a GPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dev

device(type='cuda')

In [7]:
from torch import nn

class Mnist_Logistic(nn.Module):
    """Single linear layer computing ``xb @ weights + bias``.

    Args:
        in_features: Length of each input row. Defaults to 784, the value
            that was previously hard-coded.
        out_features: Number of output scores per row. Defaults to 10, the
            value that was previously hard-coded.
    """

    def __init__(self, in_features=784, out_features=10):
        super().__init__()
        # Random-normal initialisation scaled by 1/sqrt(in_features).
        self.weights = nn.Parameter(
            torch.randn(in_features, out_features) / math.sqrt(in_features)
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, xb):
        """Return the raw output scores for the batch ``xb``."""
        return xb @ self.weights + self.bias

In [8]:
def accuracy(out, yb):
    """Return the fraction of rows in ``out`` whose arg-max index equals ``yb``."""
    predicted_classes = out.argmax(dim=1)
    is_correct = predicted_classes == yb
    return is_correct.float().mean()

In [10]:
import torch.nn.functional as F

# Loss used by print_loss_accuracy() and fit(); it is called with the model's
# output scores and the integer labels.
loss_func = F.cross_entropy

In [21]:
# Training hyperparameters: learning rate, mini-batch size, and the number of
# epochs run by each fit() call.
lr, bs, epochs = 0.5, 64, 2
# Freshly initialised model with its parameters moved to the selected device.
model = Mnist_Logistic().to(dev)

In [22]:
# The full training set (not a mini-batch), copied to the device once so loss
# and accuracy can be evaluated over every training sample.
xt, yt = x_train.to(dev), y_train.to(dev)

def print_loss_accuracy(msg, x, y):
    """Print ``msg`` with the loss and accuracy of the global ``model`` on ``(x, y)``.

    Args:
        msg: Label placed at the start of the printed line.
        x: Inputs passed to ``model``.
        y: Targets passed to ``loss_func`` and ``accuracy``.
    """
    # Evaluation only: without no_grad the forward pass would record an
    # autograd graph that is never used for backward().
    with torch.no_grad():
        preds = model(x)  # predictions
        loss = loss_func(preds, y)
        acc = accuracy(preds, y)
    print(f"{msg} loss {loss} accuracy {acc}")

# Baseline on the full training set, taken before any training step.
print_loss_accuracy("Initial", xt, yt)

Initial loss 2.3585128784179688 accuracy 0.1046999990940094


In [23]:
def fit():
    """Train the global ``model`` with manual mini-batch SGD for ``epochs`` epochs.

    Reads the globals ``model``, ``loss_func``, ``lr``, ``bs``, ``epochs``,
    ``n``, ``xt`` and ``yt``. Updates the model's parameters in place and
    prints the loss and accuracy on ``(xt, yt)`` after every epoch.
    """
    for epoch in range(epochs):
        # Walk the training set in consecutive chunks of ``bs`` rows; the
        # final chunk may be shorter.
        for start_i in range(0, n, bs):
            end_i = start_i + bs
            # Slice the tensors that already live on the device instead of
            # copying a mini-batch from x_train/y_train on every step.
            xb = xt[start_i:end_i]
            yb = yt[start_i:end_i]
            pred = model(xb)
            loss = loss_func(pred, yb)

            loss.backward()
            # Manual SGD step; no_grad keeps the update itself out of autograd.
            with torch.no_grad():
                for p in model.parameters():
                    p -= p.grad * lr
                # Reset gradients so they do not accumulate into the next step.
                model.zero_grad()

        print_loss_accuracy(f"After epoch {epoch}", xt, yt)

# First training run: `epochs` passes over the training set.
fit()

After epoch 0 loss 0.3332004249095917 accuracy 0.905299961566925
After epoch 1 loss 0.30698782205581665 accuracy 0.9134599566459656


In [24]:
# Re-check the trained model on the full training set. xb/yb are not defined
# at notebook scope (they are locals of fit()), so referencing them here
# raises NameError on a fresh kernel; the device copies xt/yt are used instead.
preds = model(xt)  # predictions
print(loss_func(preds, yt))
print(accuracy(preds, yt))

tensor(0.3070, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.9135, device='cuda:0')


In [25]:
# Continue training: fit() does not re-create `model`, so this resumes from the current weights.
fit()

After epoch 0 loss 0.2958129644393921 accuracy 0.9167799949645996
After epoch 1 loss 0.2891461253166199 accuracy 0.9183200001716614


In [26]:
# Continue training for another `epochs` epochs from the current weights.
fit()

After epoch 0 loss 0.28456413745880127 accuracy 0.9194999933242798
After epoch 1 loss 0.2811531722545624 accuracy 0.9204599857330322


In [27]:
# Continue training for another `epochs` epochs from the current weights.
fit()

After epoch 0 loss 0.27847820520401 accuracy 0.9213399887084961
After epoch 1 loss 0.27630120515823364 accuracy 0.9219399690628052


In [28]:
# Continue training for another `epochs` epochs from the current weights.
fit()

After epoch 0 loss 0.2744787335395813 accuracy 0.9225399494171143
After epoch 1 loss 0.2729182839393616 accuracy 0.9230999946594238
