Install ViT (the *vit_pytorch* package) together with Linformer.

In [None]:
!pip -q install vit_pytorch linformer

Import the required Python dependencies.

In [None]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT

## Data Download
The data set used for this PoC comes from a Kaggle competition: https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
Before proceeding with the download, please follow these steps:


1.   Login to Kaggle and go to the **Account** section of your profile.
2.   Scroll to the **API** section and click on the **Create New API Token** button. This action will create a *kaggle.json* file and download it to your local disk.



Then install the Python *kaggle* package.

In [None]:
! pip install -q kaggle

Upload the generated *kaggle.json* file to Colab.

In [None]:
from google.colab import files

files.upload()

Make a directory named *.kaggle* in your home directory and copy the *kaggle.json* file there. Then restrict the permissions of the file.

In [None]:
# Create the Kaggle config directory; -p keeps the cell re-runnable when the
# directory already exists.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Make the API token readable and writable by the owner only.
!chmod 600 ~/.kaggle/kaggle.json

Check that everything is OK.

In [None]:
 !kaggle datasets list

Download the data providing the name of the competition the data set belongs to.

In [None]:
!kaggle competitions download -c 'dogs-vs-cats-redux-kernels-edition'

Extract the train and test data.

In [None]:
!unzip train.zip

In [None]:
!unzip test.zip

## Data Loading

In [None]:
train_dir = './train'
test_dir = './test'

In [None]:
# glob returns paths in arbitrary, OS-dependent order; sorting makes the
# file order (and therefore any seeded split of it) reproducible.
train_list = sorted(glob.glob(os.path.join(train_dir, '*.jpg')))
test_list = sorted(glob.glob(os.path.join(test_dir, '*.jpg')))

In [None]:
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

In [None]:
# The label is the first dot-separated token of the file name (presumably
# "<label>.<id>.jpg" -- confirm against the data set). basename handles the
# platform's path separator instead of assuming '/'.
labels = [os.path.basename(path).split('.')[0] for path in train_list]

Plot some random images.

In [None]:
# Sample nine indices over the whole list (index 0 included).
random_idx = np.random.randint(0, len(train_list), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))

# Pair every subplot with one sampled index so the grid really shows random
# images rather than the first nine files.
for idx, ax in zip(random_idx, axes.ravel()):
    img = Image.open(train_list[idx])
    ax.set_title(labels[idx])
    ax.imshow(img)

# Training

Set the seed.

In [None]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Sets the ``PYTHONHASHSEED`` environment variable, seeds Python's
    ``random`` module, NumPy and PyTorch (CPU plus CUDA), and switches cuDNN
    to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python and NumPy generators.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch generators: CPU, current CUDA device, then all CUDA devices.
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    torch.backends.cudnn.deterministic = True


seed = 42
seed_everything(seed)

Set training params.

In [None]:
batch_size = 64  # batch size handed to every DataLoader
epochs = 20  # number of passes of the training loop
lr = 3e-5  # learning rate given to the Adam optimizer
gamma = 0.7  # decay factor given to the StepLR scheduler

Set the device.

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Split the labelled data into training and validation sets.

In [None]:
# Hold out 20% of the labelled files for validation. The split is stratified
# on the labels and seeded so it can be repeated.
train_list, valid_list = train_test_split(
    train_list,
    random_state=seed,
    stratify=labels,
    test_size=0.2,
)

In [None]:
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

Define the image transforms (resize to 224x224, random resized crop and random horizontal flip) to use.

In [None]:
# Training pipeline: the random crop and flip augment the data, so every
# epoch sees a different view of each image.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Validation pipeline: deterministic (no random crop/flip) so that metrics
# are comparable from one epoch to the next.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


# Test pipeline: deterministic for the same reason, so predictions for a
# given image do not change between runs.
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

Load the data sets.

In [None]:
class CatsDogsDataset(Dataset):
    """Image dataset whose binary label is derived from each file name.

    Args:
        file_list: Sequence of image file paths.
        transform: Optional callable applied to the loaded PIL image. When
            ``None`` the RGB PIL image is returned unchanged.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    @property
    def filelength(self):
        """Number of files; kept as an attribute for backward compatibility."""
        return len(self.file_list)

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file name's first dot-separated token is "dog", else 0."""
        name = os.path.basename(img_path)
        return 1 if name.split(".")[0] == "dog" else 0

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``."""
        img_path = self.file_list[idx]

        # Close the file handle as soon as the pixels are loaded, and force
        # three channels so grayscale/CMYK/RGBA files give a consistent input.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")

        # The transform is optional (its default is None).
        if self.transform is not None:
            img = self.transform(img)

        return img, self._label_from_path(img_path)

In [None]:
# One dataset per split, each with the transform pipeline defined for it
# (the validation split uses val_transforms, not the test pipeline).
train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

In [None]:
# Only the training data needs shuffling. Keeping validation and test
# batches in file order makes the results repeatable and lets predictions be
# matched back to their files.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [None]:
print(len(train_data), len(train_loader))

In [None]:
print(len(valid_data), len(valid_loader))

Define the model.
Linformer:

In [None]:
# Linformer used as the transformer backbone of the ViT model.
efficient_transformer = Linformer(
    dim=128,
    depth=12,
    heads=8,
    k=64,
    seq_len=49 + 1,  # 7x7 patches + 1 cls-token
)

Vision Transformer:

In [None]:
# ViT over 224x224 three-channel images with 32x32 patches and two output
# classes, built on the Linformer backbone.
model = ViT(
    image_size=224,
    patch_size=32,
    channels=3,
    dim=128,
    num_classes=2,
    transformer=efficient_transformer,
)
# Move the parameters to the selected device.
model = model.to(device)

Choose the loss function.

In [None]:
criterion = nn.CrossEntropyLoss()

Choose the optimizer.

In [None]:
optimizer = optim.Adam(model.parameters(), lr=lr)

Set the scheduler.

In [None]:
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

Start the training.

In [None]:
# Train for `epochs` passes; after each pass evaluate on the validation set,
# decay the learning rate and report the running averages.
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    # Training mode for any layers that behave differently in train and eval.
    model.train()
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        # .item() stores plain floats, so the running totals do not keep
        # autograd history alive across the epoch.
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # Evaluation mode, and no gradient tracking, for the validation pass.
    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    # Apply the per-epoch decay; without this call the StepLR scheduler
    # never changes the learning rate.
    scheduler.step()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

Save the model.

In [None]:
torch.save(model.state_dict(), './pretrained-vit-net.pt')

Run a prediction on a random dummy input.

In [None]:
img = torch.randn(1, 3, 224, 224)

In [None]:
# Inference only: switch to evaluation mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    pred = model(img.to(device))

In [None]:
pred

In [None]:
# Take the arg-max over the class dimension of the single-image batch and
# convert it to a plain Python int (detach() replaces the legacy `.data`).
prediction = pred.detach().argmax(dim=1).item()

In [None]:
prediction

In [None]:
# `labels` holds one entry per training file, so indexing it with a class id
# would return some file's label rather than the class name. Map the class id
# directly instead: the dataset encodes "dog" as 1 and everything else as 0.
['cat', 'dog'][prediction]