In [1]:
import torch
from torch import nn
import torchvision
from PIL import Image

device = torch.device('mps')

full_train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)

train_size = int(len(full_train_dataset) * 0.9)
val_size = len(full_train_dataset) - train_size


train_dataset, val_dataset = torch.utils.data.random_split(full_train_dataset, [train_size, val_size])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

print(len(train_dataset), len(val_dataset), len(test_dataset))

%load_ext autoreload
%autoreload 2

45000 5000 10000


In [2]:
# Settings shared by all three loaders.
batch_size = 32
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# Train and validation batches are drawn in shuffled order; the test loader
# passes no shuffle argument and therefore keeps the DataLoader default.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, **_loader_kwargs)

# Sanity check (run manually when needed): take one (batch, labels) pair from
# train_loader and print labels and labels.shape.

In [3]:
from vision_transformer import VisionTransformer

# Training-schedule settings.
train_iters = 5     # number of full passes over train_loader made by the training loop
eval_interval = 3   # validation runs when the pass index is a multiple of this (and on the last pass)
eval_iters = 2      # number of val_loader batches averaged per validation run

# Use the MPS backend when available; fall back to the CPU instead of failing
# on machines or PyTorch builds that do not provide it.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

model = VisionTransformer().to(device)
criterion = nn.CrossEntropyLoss()
# Adam with its default hyperparameters (no learning rate is passed explicitly).
optimizer = torch.optim.Adam(model.parameters())

# Count only the parameters that will receive gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params}")

@torch.no_grad()
def eval(iteration):
    """Estimate and print the validation loss from the first `eval_iters` batches.

    Reads the notebook-level `model`, `criterion`, `val_loader`, `device` and
    `eval_iters`. Gradients are disabled by the decorator.

    NOTE: the name shadows Python's builtin `eval`; it is kept because the
    training loop calls it under this name.

    Args:
        iteration: Value printed after "Epoch " in the log line.

    Returns:
        The mean loss over the batches actually evaluated as a Python float,
        or NaN when no batch was evaluated.
    """
    # Remember the current mode so it can be restored afterwards, rather than
    # forcing training mode unconditionally.
    was_training = model.training
    model.eval()

    batch_losses = []
    try:
        for i, (batch, labels) in enumerate(val_loader):
            if i == eval_iters:
                break

            x = batch.to(device)
            targets = labels.to(device)
            out = model(x)
            batch_losses.append(criterion(out, targets).item())
    finally:
        # Restore the previous mode even if a batch raised.
        model.train(was_training)

    # Average over the batches that were really seen. A pre-sized zero buffer
    # would drag the mean towards 0 whenever val_loader yields fewer than
    # eval_iters batches.
    if batch_losses:
        avg_val_loss = sum(batch_losses) / len(batch_losses)
    else:
        avg_val_loss = float('nan')

    print("Epoch " + str(iteration) + " - Val loss: " + str(avg_val_loss))
    return avg_val_loss


# Main optimisation loop: one full pass over train_loader per outer iteration.
for epoch in range(train_iters):

    running_loss = 0.0

    for inputs, labels in train_loader:
        x, targets = inputs.to(device), labels.to(device)

        # Clear gradients left over from the previous step before computing new ones.
        optimizer.zero_grad()

        loss = criterion(model(x), targets)
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Report the mean per-batch loss for this pass.
    mean_train_loss = running_loss / len(train_loader)
    print("Epoch " + str(epoch) + " - Train loss: " + str(mean_train_loss))

    # Validate every `eval_interval` passes and always after the final one.
    is_last_epoch = epoch == train_iters - 1
    if is_last_epoch or epoch % eval_interval == 0:
        eval(epoch)


Trainable parameters: 164746
Epoch 0 - Train loss: 1.9710850785959673
Epoch 0Val loss: 1.8361599445343018
Epoch 1 - Train loss: 1.6834926872958338


KeyboardInterrupt: 