In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import math

from tqdm import tqdm, trange


from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

from ViT import ViT

In [2]:
LR = 5e-5
NUM_EPOCHS = 40
CONVERGENCE_THRESH = 5
ACC_THRESH = 1

In [3]:
const_epochs = 0
max_acc = 0
last_acc = 0

In [4]:
if __name__ == "__main__":
    transform = ToTensor()
    train_set = MNIST(root='./datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./datasets', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, shuffle=True, batch_size=50)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=50)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using {device}")
    # device = torch.device("cpu")
    model = ViT((28, 28), device, in_channels=1, n_encoders=3, hidden_dim=512, n_heads=8, patch_dim=(7, 7)).to(device)
    optimizer = opt.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        # Create a tqdm progress bar for the training batches
        with tqdm(train_loader, unit="batch") as t_bar:
            for x, y in t_bar:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                total_loss += loss.item()
                predicted = torch.argmax(outputs, dim=-1)
                total += y.size(0)
                correct += (predicted == y).sum().item()

                # Update tqdm progress bar description
                t_bar.set_description(f"Epoch {epoch+1}/{NUM_EPOCHS}")
                t_bar.set_postfix(loss=total_loss / (total + 1e-8), accuracy=100 * correct / total)
        acc = 100.0 * correct / total
        if acc > max_acc:
            torch.save(model.state_dict(), f"./{acc}.pt")
        if last_acc - acc > ACC_THRESH:
            const_epochs = 0
        else:
            const_epochs += 1
        if const_epochs == CONVERGENCE_THRESH:
            break
        last_acc = acc
                
        # Print epoch-level information
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {total_loss / (total + 1e-8):.4f}, Accuracy: {100 * correct / total:.2f}%")

using cuda


Epoch 1/40: 100%|██████████| 1200/1200 [00:27<00:00, 43.79batch/s, accuracy=91.3, loss=0.00567]


Epoch 1/40 - Loss: 0.0057, Accuracy: 91.33%


Epoch 2/40:  35%|███▌      | 425/1200 [00:09<00:17, 44.85batch/s, accuracy=97.2, loss=0.00179]


KeyboardInterrupt: 

In [None]:
print(model)

In [None]:
with torch.no_grad():
    with tqdm(test_loader, unit="batch") as t_bar:
        for x, y in t_bar:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            predicted = torch.argmax(outputs, dim=-1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

            # Update tqdm progress bar description
            t_bar.set_description(f"Testing...")
            t_bar.set_postfix(loss=total_loss / (total + 1e-8), accuracy=100 * correct / total)
    acc = 100.0 * correct / total
    print(f"Accuracy: {100.0 * correct / total:.2f}%")