In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from tqdm import tqdm
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

import numpy as np 

from ViT import ViT
import os




np.random.seed(420)
torch.manual_seed(420)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fe4f0b2c330>

In [2]:
DOWNLOAD_PATH = 'data'

if not os.path.exists(DOWNLOAD_PATH):
    os.mkdir(DOWNLOAD_PATH)

transform = ToTensor()
train_set = MNIST(root=DOWNLOAD_PATH, train=True, download=True, transform=transform)
test_set = MNIST(root=DOWNLOAD_PATH, train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
test_loader = DataLoader(test_set, shuffle=False, batch_size=16)




In [3]:
model = ViT((1,28,28), n_patches=7, hidden_d=8, n_heads=2, out_d=10)

N_EPOCHS = 20
LR = 0.01

# train loop
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()

for epoch in range(N_EPOCHS):
    train_loss = 0.0
    for _ ,batch in enumerate(tqdm(train_loader)):
        x,y = batch
        y_hat = model(x)
        loss = criterion(y_hat,y)/ len(x)

        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'EPOCH {epoch+1}/{N_EPOCHS} loss: {train_loss:.2f}')


100%|██████████| 3750/3750 [01:52<00:00, 33.43it/s]


EPOCH 1/20 loss: 403.91


100%|██████████| 3750/3750 [01:48<00:00, 34.63it/s]


EPOCH 2/20 loss: 371.46


100%|██████████| 3750/3750 [01:48<00:00, 34.41it/s]


EPOCH 3/20 loss: 368.45


100%|██████████| 3750/3750 [01:47<00:00, 34.77it/s]


EPOCH 4/20 loss: 367.36


100%|██████████| 3750/3750 [01:47<00:00, 34.87it/s]


EPOCH 5/20 loss: 366.19


100%|██████████| 3750/3750 [01:48<00:00, 34.65it/s]


EPOCH 6/20 loss: 365.52


100%|██████████| 3750/3750 [01:48<00:00, 34.47it/s]


EPOCH 7/20 loss: 364.85


100%|██████████| 3750/3750 [01:47<00:00, 34.83it/s]


EPOCH 8/20 loss: 364.10


100%|██████████| 3750/3750 [01:46<00:00, 35.10it/s]


EPOCH 9/20 loss: 363.70


100%|██████████| 3750/3750 [01:49<00:00, 34.24it/s]


EPOCH 10/20 loss: 363.24


100%|██████████| 3750/3750 [01:49<00:00, 34.22it/s]


EPOCH 11/20 loss: 363.18


100%|██████████| 3750/3750 [01:50<00:00, 34.01it/s]


EPOCH 12/20 loss: 362.97


100%|██████████| 3750/3750 [01:51<00:00, 33.64it/s]


EPOCH 13/20 loss: 362.99


100%|██████████| 3750/3750 [01:50<00:00, 33.90it/s]


EPOCH 14/20 loss: 363.20


100%|██████████| 3750/3750 [01:42<00:00, 36.70it/s]


EPOCH 15/20 loss: 362.57


100%|██████████| 3750/3750 [01:50<00:00, 33.79it/s]


EPOCH 16/20 loss: 362.45


100%|██████████| 3750/3750 [01:53<00:00, 33.18it/s]


EPOCH 17/20 loss: 362.50


100%|██████████| 3750/3750 [01:53<00:00, 32.91it/s]


EPOCH 18/20 loss: 362.48


100%|██████████| 3750/3750 [01:53<00:00, 33.17it/s]


EPOCH 19/20 loss: 362.40


100%|██████████| 3750/3750 [01:55<00:00, 32.44it/s]

EPOCH 20/20 loss: 362.55





In [4]:
# test loop
correct, total = 0, 0
test_loss = 0.0

for batch in test_loader:
    x,y = batch
    y_hat = model(x)
    loss = criterion(y_hat, y)
    test_loss += loss / len(x)

    correct += torch.sum(torch.argmax(y_hat, dim=1)== y).item()
    total+= len(x)
print(f'Test loss: {test_loss:.2f}')
print(f'Test accuracy: {correct / total * 100:.2f}%')

Test loss: 60.05
Test accuracy: 92.31%
