In [1]:
import torch
from torch import nn
import torchvision
from PIL import Image

# Pick the best available accelerator: Apple MPS first (the original target),
# then CUDA, then CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Both CIFAR-10 splits need the same PIL -> tensor conversion; without it the
# default DataLoader collate function cannot batch the PIL images.
to_tensor = torchvision.transforms.ToTensor()

full_train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=to_tensor, download=True)

# Hold out 10% of the official training split for validation.
train_size = int(len(full_train_dataset) * 0.9)
val_size = len(full_train_dataset) - train_size

# Seed the split so the same samples land in train/val on every kernel restart.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=to_tensor, download=True)

print(len(train_dataset), len(val_dataset), len(test_dataset))

# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

45000 5000 10000


In [2]:
batch_size = 32

# Options shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# Train and validation batches are drawn in shuffled order; the test loader
# keeps the DataLoader default (no shuffling).
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

# for batch, labels in train_loader:
#     # print(batch)
#     print(labels)
#     # print(batch.shape)
#     print(labels.shape)
#     break

In [4]:
from vision_transformer import VisionTransformer

# Training-run configuration.
train_iters = 2     # number of full passes over train_loader
eval_interval = 1   # eval() is called on every pass whose index is a multiple of this
eval_iters = 2      # max batches per loader that eval() averages over

# Prefer Apple MPS (the original target), then CUDA, then CPU, so this cell
# does not fail on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Model, loss and optimizer; Adam is created with the library's default
# hyper-parameters since none are passed.
model = VisionTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

@torch.no_grad()
def eval(iteration):
    """Print and return the mean train/val loss over a few batches.

    Draws up to ``eval_iters`` batches from ``train_loader`` and from
    ``val_loader``, runs them through ``model`` with gradients disabled and
    averages ``criterion`` over the batches actually seen. The model is put
    in eval mode for the duration and its previous mode is restored
    afterwards (this only matters if the model contains mode-dependent
    layers such as dropout or batch norm).

    Note: this name shadows the built-in ``eval``; it is kept because the
    training loop calls it by this name.

    Args:
        iteration: Value echoed in the printed line (the caller's loop index).

    Returns:
        Tuple ``(avg_train_loss, avg_val_loss)`` of Python floats. An entry is
        NaN if the corresponding loader yielded no batches.
    """

    def _mean_loss(loader):
        # Average the criterion over at most `eval_iters` batches of `loader`.
        losses = []
        # zip() consults range() first, so once `eval_iters` batches have been
        # consumed the loader is not asked for another one.
        for _, (batch, labels) in zip(range(eval_iters), loader):
            out = model(batch.to(device))
            losses.append(criterion(out, labels.to(device)).item())
        # Mean over the batches actually seen, so a short loader does not
        # pull the average toward zero.
        return torch.tensor(losses).mean().item()

    was_training = model.training
    model.eval()
    try:
        avg_train_loss = _mean_loss(train_loader)
        avg_val_loss = _mean_loss(val_loader)
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    print("Iteration " + str(iteration) + " - Train loss: " + str(avg_train_loss) + ", Val loss: " + str(avg_val_loss))

    return avg_train_loss, avg_val_loss


# Each outer iteration is one full pass over train_loader. Losses are reported
# *before* the pass, so the line printed for pass k reflects k completed passes.
for epoch in range(train_iters):

    if epoch % eval_interval == 0:
        eval(epoch)

    for batch, labels in train_loader:
        x, targets = batch.to(device), labels.to(device)

        # Clear gradients left over from the previous step, then do one
        # forward/backward pass and a parameter update.
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, targets)
        loss.backward()
        optimizer.step()


# x = torch.randn(size=(2, 3, 32, 32), device=device)

# out = model(x)

# print(out)
# print(out.shape)

Iteration 0 - Train loss: 2.308490514755249, Val loss: 2.3256478309631348
Iteration 1 - Train loss: 2.3022377490997314, Val loss: 2.3034329414367676
Iteration 0 - Train loss: 2.330874443054199, Val loss: 2.3251724243164062
Iteration 1 - Train loss: 2.3024582862854004, Val loss: 2.303649425506592
