In [1]:
# Fetch the PyTorch-Architectures repository and switch the working directory to
# its modeling_ViT folder. The later `from model import ViT` presumably resolves
# against this directory -- TODO confirm that model.py lives there.
# NOTE(review): re-running this cell without restarting is not idempotent; the
# clone target already exists and the relative %cd path is no longer valid.
!git clone https://github.com/vishal-burman/PyTorch-Architectures.git
%cd PyTorch-Architectures/modeling_ViT

Cloning into 'PyTorch-Architectures'...
remote: Enumerating objects: 249, done.[K
remote: Counting objects: 100% (249/249), done.[K
remote: Compressing objects: 100% (157/157), done.[K
remote: Total 680 (delta 140), reused 174 (delta 79), pack-reused 431[K
Receiving objects: 100% (680/680), 8.39 MiB | 5.15 MiB/s, done.
Resolving deltas: 100% (416/416), done.
/kaggle/working/PyTorch-Architectures/modeling_ViT


In [3]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset

from torchvision import datasets, transforms

from model import ViT

In [7]:
# ---- Run configuration ----

# Optimisation hyperparameters
batch_size = 32
num_epochs = 5
learning_rate = 3e-4

# Number of target categories the classifier predicts
num_classes = 10

# Compute device: the first CUDA GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
# Instantiate the ViT classifier and move it to the configured device.
# image_size=64 matches the 64x64 crops produced by the dataset transforms.
model = ViT(
    image_size=64,
    patch_size=8,
    # Use the value from the settings cell instead of repeating the literal,
    # so changing the number of classes only has to be done in one place.
    num_classes=num_classes,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)

# Report the model size: the element count summed over every parameter tensor.
total_params = sum(p.numel() for p in model.parameters())
print("Total Parameters = ", total_params)

# Adam over all model parameters with the learning rate from the settings cell.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Total Parameters =  54833320


In [9]:
# Dataset
# Items 0..47,999 of the CIFAR-10 training split are used for training and
# items 48,000..49,999 are held out for validation.
train_indices = torch.arange(0, 48000)
valid_indices = torch.arange(48000, 50000)

# Training pipeline: resize to 70x70, then take a random 64x64 crop.
train_transform = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.RandomCrop((64, 64)),
    transforms.ToTensor(),
])
# Evaluation pipeline: same resize, but a deterministic centre crop so that
# validation/test accuracy does not change from one pass to the next.
test_transform = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.CenterCrop((64, 64)),
    transforms.ToTensor(),
])

# First instantiation downloads the archive if it is not already under "data".
train_and_valid = datasets.CIFAR10(root="data", train=True, transform=train_transform, download=True)
# A second view of the same training split that applies the evaluation
# transform; the validation subset is drawn from it so held-out images are not
# randomly augmented. The files are already on disk, hence download=False.
valid_source = datasets.CIFAR10(root="data", train=True, transform=test_transform, download=False)

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(valid_source, valid_indices)
test_dataset = datasets.CIFAR10(root="data", train=False, transform=test_transform, download=False)

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-10-python.tar.gz to data


In [10]:
# Sanity check: pull a single batch from the training loader and show its
# tensor shapes, then report how many batches each loader yields per epoch.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print("Image Dimensions: ", images.shape)
    print("Label Dimensions: ", labels.shape)

for split_name, loader in (("Train", train_loader), ("Valid", valid_loader)):
    print("Length of " + split_name + " DataLoader: ", len(loader))

Image Dimensions:  torch.Size([32, 3, 64, 64])
Label Dimensions:  torch.Size([32])
Length of Train DataLoader:  1500
Length of Valid DataLoader:  63


In [None]:
def compute_accuracy(model, data_loader, device):
    """Return the top-1 classification accuracy of ``model`` on ``data_loader``.

    Args:
        model: Callable that maps a batch of features to per-class scores
            laid out as (batch, num_classes).
        data_loader: Iterable yielding ``(features, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Percentage (0-100) of examples whose highest-scoring class
        equals the label; ``0.0`` if the loader yields no examples.

    Note:
        The train/eval mode of ``model`` is left untouched; the caller is
        responsible for calling ``model.eval()`` beforehand if desired.
    """
    correct, total = 0, 0
    # Gradients are never needed for evaluation; disabling them saves memory.
    with torch.no_grad():
        for features, labels in data_loader:
            features = features.to(device)
            labels = labels.to(device)
            predictions = model(features).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    # Guard against division by zero on an empty loader.
    if total == 0:
        return 0.0
    return 100.0 * correct / total

# Train for `num_epochs` passes over `train_loader`. After every epoch the
# train/validation accuracy and the running wall-clock time are printed.
start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    # enumerate() supplies the batch counter; the loader itself only yields
    # (features, labels) pairs.
    for batch_idx, (features, labels) in enumerate(train_loader):

        features = features.to(device)
        labels = labels.to(device)

        # Forward pass and loss
        logits = model(features)
        cost = F.cross_entropy(logits, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # LOGGING: once every 200 batches (batch 0 included)
        if batch_idx % 200 == 0:
            print("Batch: %04d/%04d || Epoch: %04d/%04d" % (batch_idx, len(train_loader), epoch+1, num_epochs))

    # Switch to evaluation mode and disable autograd for the accuracy pass.
    model.eval()
    with torch.no_grad():
        train_acc = compute_accuracy(model, train_loader, device)
        valid_acc = compute_accuracy(model, valid_loader, device)
        print("Train Accuracy: ", train_acc)
        print("Valid Accuracy: ", valid_acc)
    # Minutes since training started (cumulative across epochs, not per epoch).
    elapsed_time = (time.time() - start_time) / 60
    print("Epoch Elapsed Time: ", elapsed_time)

# Total wall-clock time of the whole run, in minutes.
elapsed_time = (time.time() - start_time) / 60
print("Total Training Time: ", elapsed_time)