In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import wandb

import lovely_tensors
lovely_tensors.monkey_patch()

from vit import ViT

In [5]:
vit = ViT(
    img_size=32, 
    patch_size=4,    
    in_channels=3, 
    embed_size=256, 
    num_heads=8, 
    depth=6,
    n_classes=10
)

input_tensor = torch.randn(1, 3, 32, 32)
out = vit(input_tensor)
out

tensor[1, 10] x∈[-0.584, 0.611] μ=-0.083 σ=0.406 grad AddmmBackward0 [[-0.445, 0.457, -0.584, -0.556, -0.276, 0.100, 0.101, -0.105, -0.137, 0.611]]

In [6]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
wandb.init(project="MRO-VIT")
model = ViT(
    img_size=32, 
    patch_size=4,    
    in_channels=3, 
    embed_size=256, 
    num_heads=8, 
    depth=6,
    n_classes=10
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.002)
num_epochs = 165
lr_drop_epochs = [100, 150]
best_test_loss = float('inf')

progressbar = tqdm(total=num_epochs)
for epoch in range(num_epochs):
    model.train()
    train_loss, train_correct = 0.0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_correct += (predicted == labels).sum().item()

        progressbar.set_description(f"train loss: {loss.item():.4f}")

    model.eval()
    test_loss, test_correct = 0.0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_correct += (predicted == labels).sum().item()

            progressbar.set_description(f"test loss: {loss.item():.4f}")

    if epoch in lr_drop_epochs:
        print("LR drop")
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10

    logs = {
        "epoch": epoch,
        "lr": optimizer.param_groups[0]['lr'], 
        "loss/train": train_loss / len(train_loader),
        "loss/test": test_loss / len(test_loader),
        "acc/train": train_correct / len(train_dataset),
        "acc/test": test_correct / len(test_dataset),
    }
    wandb.log(logs)

    if logs["loss/test"] < best_test_loss:
        best_test_loss = logs["loss/test"]
        model_path = f"best_model_epoch_{epoch:03d}_test_loss{best_test_loss:.1e}.pth"
        torch.save(model.state_dict(), model_path)
        wandb.save(model_path)

        
    progressbar.update(1)
    print(", ".join([f"{k}: {v:}" for k, v in logs.items()]))


wandb.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwoj-jasinski[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/165 [00:00<?, ?it/s]

epoch: 0, lr: 0.002, loss/train: 1.8000186492719918, loss/test: 1.5120775171473055, acc/train: 0.3376, acc/test: 0.4495
epoch: 1, lr: 0.002, loss/train: 1.4737884120258224, loss/test: 1.3767229635504228, acc/train: 0.46202, acc/test: 0.5016
epoch: 2, lr: 0.002, loss/train: 1.3660748270161622, loss/test: 1.3069208920756472, acc/train: 0.50728, acc/test: 0.5294
epoch: 3, lr: 0.002, loss/train: 1.3007862763026792, loss/test: 1.2506943111178241, acc/train: 0.53264, acc/test: 0.5469
epoch: 4, lr: 0.002, loss/train: 1.2493867419869698, loss/test: 1.2306218207636965, acc/train: 0.55136, acc/test: 0.5581
epoch: 5, lr: 0.002, loss/train: 1.200251328060999, loss/test: 1.1777015237868587, acc/train: 0.56978, acc/test: 0.5824
epoch: 6, lr: 0.002, loss/train: 1.1635121986689165, loss/test: 1.1462715366218663, acc/train: 0.5854, acc/test: 0.5878
epoch: 7, lr: 0.002, loss/train: 1.133148251896929, loss/test: 1.116226367558105, acc/train: 0.59752, acc/test: 0.6015
epoch: 8, lr: 0.002, loss/train: 1.11

VBox(children=(Label(value='100.871 MB of 100.871 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
acc/test,▁▃▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████
acc/train,▁▃▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇██████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss/test,█▆▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/train,█▆▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█████████████████████████▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁

0,1
acc/test,0.7812
acc/train,0.86518
epoch,164.0
loss/test,0.69445
loss/train,0.38207
lr,2e-05
