In [None]:
import os
import time
import numpy as np
import time

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset, SubsetRandomSampler, random_split
from torch.optim import lr_scheduler
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.optim as optim

from sklearn.model_selection import KFold

import transformers

from tqdm import tqdm

import matplotlib.pyplot as plt

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

%matplotlib notebook

In [None]:
# Plot function for visualising accuracies over epochs.

def plot(val_accs, train_accs, save_folder="plots", filename="accuracy_plot.png"):
    """Plot per-epoch training (and optionally validation) accuracy, save and show it.

    Args:
        val_accs: Validation accuracy per epoch. May be empty (no validation run),
            in which case only the training curve is drawn; otherwise it must have
            one entry per training epoch.
        train_accs: Training accuracy per epoch.
        save_folder: Directory the image is written to; created if it does not exist.
        filename: Name of the image file written inside ``save_folder``.

    Raises:
        ValueError: If ``val_accs`` is non-empty and its length differs from
            ``train_accs`` (both curves share the same epoch x-axis).
    """
    # Check up front so a mismatch fails with a clear message before any file
    # or figure is created (matplotlib would otherwise raise a less obvious error).
    if len(val_accs) > 0 and len(val_accs) != len(train_accs):
        raise ValueError(
            f"val_accs has {len(val_accs)} entries but train_accs has {len(train_accs)}"
        )

    os.makedirs(save_folder, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 5))

    # Epochs are numbered from 1 on the x-axis.
    epochs = np.arange(1, len(train_accs) + 1)

    if len(val_accs) > 0:
        ax.plot(epochs, val_accs, label='Validation')
    ax.plot(epochs, train_accs, label='Training', linestyle='--')

    ax.set_title('Accuracy Scores')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.legend()

    # Operate on this figure explicitly rather than pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(os.path.join(save_folder, filename))
    plt.show()

## Load and augment data

In [None]:
data_dir = "./dataset"

input_shape = (3, 224, 224)  # (C, H, W); Resize below takes (height, width)

# NOTE(review): these mean/std values look like CIFAR-10 channel statistics; the
# pretrained ViT checkpoint may expect a different normalisation — check its
# image-processor config.
norm_mean = [0.4914, 0.4822, 0.4465]
norm_std = [0.2023, 0.1994, 0.2010]

# Training-time preprocessing: resize, random flips/rotation for augmentation, normalise.
transformations = transforms.Compose(
    [
        transforms.Resize((input_shape[1], input_shape[2])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(40),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
)

# Evaluation-time preprocessing: identical resize/normalisation but no random
# augmentation, so evaluating the same image always yields the same prediction.
eval_transformations = transforms.Compose(
    [
        transforms.Resize((input_shape[1], input_shape[2])),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
)

train_data = ImageFolder(os.path.join(data_dir, 'train'), transform=transformations)
# The 'val' folder is merged into the training pool below, so it keeps the training transform.
val_data = ImageFolder(os.path.join(data_dir, 'val'), transform=transformations)
# The 'test' folder is not part of the training pool in this notebook; it is only
# evaluated, so it uses the deterministic evaluation transform.
test_data = ImageFolder(os.path.join(data_dir, 'test'), transform=eval_transformations)

# Pool the train and val folders; this pool is randomly re-split into
# train/val/test subsets later. Those subsets share these ImageFolder objects
# and therefore their (augmenting) transform.
dataset = ConcatDataset([train_data, val_data])

In [None]:
# Evaluation helpers for computing the accuracy of a model.
class AverageMeter:
    """Running, count-weighted mean of a stream of scalar values.

    Attributes:
        sum: Weighted sum of all values seen so far.
        cnt: Total weight (number of samples) seen so far.
        avg: ``sum / cnt`` as of the most recent update (0 before any update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self.sum = 0
        self.cnt = 0
        self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the running average."""
        self.sum = self.sum + val * n
        self.cnt = self.cnt + n
        self.avg = self.sum / self.cnt


def accuracy(logits, labels):
    """Return the fraction of rows whose arg-max logit matches the label.

    Args:
        logits: Tensor of scores, classes along dim 1.
        labels: Tensor of integer class indices, one per row of ``logits``.

    Returns:
        0-d float tensor with the batch accuracy.
    """
    predicted_classes = logits.argmax(dim=1)
    n_correct = (predicted_classes == labels).sum()
    return n_correct / len(labels)


def eval_fn(model, loader, device):
    """Return the sample-weighted mean accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    A tqdm progress bar shows the running score.

    Args:
        model: Module mapping an image batch to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Overall accuracy (float); 0 if ``loader`` yields no batches.
    """
    meter = AverageMeter()
    model.eval()

    progress = tqdm(loader)
    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
            batch_acc = accuracy(model(batch_images), batch_labels)
            # Weight by batch size so a short final batch is not over-counted.
            meter.update(batch_acc.item(), batch_images.size(0))
            progress.set_description(f'(=> Test) Score: {meter.avg:.4f}')

    return meter.avg

# Train function
def train_fn(model, optimizer, criterion, loader, device):
    """Run one training epoch and return its mean accuracy and loss.

    Puts ``model`` in training mode and takes one optimizer step per batch.
    It prints the wall-clock duration of the epoch.

    Args:
        model: Module mapping an image batch to class logits.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: Loss function called as ``criterion(logits, labels)``.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        Tuple ``(accuracy, loss)`` of sample-weighted averages over the epoch.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted while the epoch is running.
    time_begin = time.perf_counter()
    score = AverageMeter()
    losses = AverageMeter()
    model.train()

    t = tqdm(loader)
    for images, labels in t:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        acc = accuracy(logits, labels)
        n = images.size(0)
        # Weight by batch size so a short final batch is not over-counted.
        losses.update(loss.item(), n)
        score.update(acc.item(), n)

        t.set_description('(=> Training) Loss: {:.4f}'.format(losses.avg))

    time_train = time.perf_counter() - time_begin
    print('training time: ' + str(time_train))
    return score.avg, losses.avg

## Define the model

In [None]:
class CustomViTForImageClassification(nn.Module):
    """Vision Transformer classifier with a frozen pretrained backbone and a new head.

    Loads a pretrained Hugging Face ``ViTForImageClassification`` checkpoint and
    freezes all of its parameters. It then replaces its ``classifier`` with a fresh
    ``nn.Linear`` producing ``num_classes`` outputs; that head is the only
    trainable part, used to fine-tune on the custom dataset.

    Args:
        num_classes: Number of output classes of the new classification head.
        pretrained_model_name: Hugging Face model id of the checkpoint to load.
    """

    def __init__(self, num_classes, pretrained_model_name='WinKawaks/vit-small-patch16-224'):
        super(CustomViTForImageClassification, self).__init__()
        # Load the pretrained ViT model.
        self.vit_model = transformers.ViTForImageClassification.from_pretrained(pretrained_model_name)
        # Freeze everything loaded from the checkpoint.
        for param in self.vit_model.parameters():
            param.requires_grad = False
        # The new head is created after freezing, so its parameters remain trainable.
        self.vit_model.classifier = nn.Linear(self.vit_model.config.hidden_size, num_classes)
        # Keep the wrapped model's label count consistent with the new head; the
        # transformers ViT implementation reads num_labels when computing a loss from `labels`.
        self.vit_model.num_labels = num_classes
        self.vit_model.config.num_labels = num_classes

    def forward(self, pixel_values, labels=None):
        """Return the classification logits (unnormalised scores, not probabilities).

        ``labels`` is forwarded to the wrapped model (which then also computes a
        loss internally), but only the logits are returned. Logits are read by
        name because, when ``labels`` is given, the first element of the model
        output is the loss rather than the logits.
        """
        outputs = self.vit_model(pixel_values, labels=labels, return_dict=True)
        return outputs.logits

In [None]:
# One output per class folder found in the training data (10 in our dataset).
num_classes = len(train_data.classes)
model = CustomViTForImageClassification(num_classes).to(device)

# Only the new classifier head has requires_grad=True, so hand just those
# parameters to the optimizer (frozen ones never receive gradients anyway).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.007, momentum=0.7)
# NOTE(review): the training loop steps this scheduler once per epoch and
# num_epochs is 10, so with T_max=100 only the first tenth of the cosine
# decay is used — confirm this is intended.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

param_count = sum(p.numel() for p in model.parameters())
trainable_count = sum(p.numel() for p in trainable_params)
print(f"Total parameters: {param_count}")
print(f"Trainable parameters: {trainable_count}")

In [None]:
# Parameters for training

train_criterion = torch.nn.CrossEntropyLoss().to(device)
num_epochs = 10
batch_size = 16
# Seed for the random dataset split, so the same images land in each subset on every run.
split_seed = 42

# split_ratio with 3 entries -> random train/val/test split of the pooled data;
# with 2 entries -> train/test only, and no validation during training.
split_ratio = [0.6, 0.2, 0.2]

split_generator = torch.Generator().manual_seed(split_seed)

if len(split_ratio) == 3:
    train_dataset, val_dataset, test_dataset = random_split(dataset, split_ratio, generator=split_generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
elif len(split_ratio) == 2:
    train_dataset, test_dataset = random_split(dataset, split_ratio, generator=split_generator)
else:
    raise ValueError(f"split_ratio must have 2 or 3 entries, got {len(split_ratio)}")

# Reshuffle the training subset every epoch; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

## Train model

In [None]:
val_accs = []     # per-epoch validation accuracy (stays empty without a validation split)
train_accs = []   # per-epoch training accuracy
train_losss = []  # per-epoch mean training loss

for epoch in range(num_epochs):
    print('#' * 50)
    print(f"Epoch {epoch + 1}/{num_epochs}, lr: {scheduler.get_last_lr()}")

    train_score, train_loss = train_fn(model, optimizer, train_criterion, train_loader, device)
    print(f'Train accuracy: {train_score:f}')

    # Validation only exists for a 3-way split (val_loader is not created otherwise),
    # so the validation score is recorded only inside this branch.
    if len(split_ratio) > 2:
        val_score = eval_fn(model, val_loader, device)
        print(f'Validation accuracy: {val_score:f}')
        val_accs.append(val_score)

    # The LR scheduler is advanced once per epoch.
    scheduler.step()

    train_accs.append(train_score)
    train_losss.append(train_loss)

print('--------------------------------')
# Compare validation and train accuracy to detect overfitting.
plot(val_accs, train_accs)

## Eval model

In [None]:
eval_batch_size = 128
test_loader = DataLoader(dataset=test_dataset, batch_size=eval_batch_size, shuffle=False)
full_loader = DataLoader(
    dataset=ConcatDataset([train_data, val_data, test_data]),
    batch_size=eval_batch_size,
    shuffle=False,
)

# 'test': the random test subset, which must never be used in training.
# 'full': every image from all three folders; a large share was trained on,
# so this score is expected to be higher than the test score.
# NOTE(review): the dedicated 'test' folder (test_data) is only scored as part
# of 'full', never on its own — consider reporting it separately.
for split_name, split_loader in (('test', test_loader), ('full', full_loader)):
    score = eval_fn(model, split_loader, device)
    print(f'Avg accuracy {split_name} dataset:', str(score * 100) + '%')