## Introduction

This notebook implements the [Vision Transformer](https://arxiv.org/abs/2010.11929) model in order to predict the classes of bounding boxes. The code for the model has been taken from [here](https://github.com/lucidrains/vit-pytorch) and customized. 

I cropped the training images to their ground-truth bounding boxes and applied the model to flattened 16 × 16 patches of each crop. To avoid running out of memory (OOM), I capped the number of patches per crop at 2048, which means that very large bounding boxes are ignored. 

The positional encodings are a challenge with this setup. Because bounding boxes have very different shapes, the position of a patch in the flattened sequence does not convey much information. I experimented with 2D relative (Euclidean) distance encodings and 2D sinusoidal positional encodings. However, the results do not appear to be better than with simple learnable 1D positional embeddings.

The model is trained from scratch for 12 epochs (on 80% of all bounding boxes in the training data; no pretrained weights used).

Images in jpg-format have been kindly provided by [Md Awsafur Rahman](https://www.kaggle.com/awsaf49) ([dataset](https://www.kaggle.com/awsaf49/vinbigdata-original-image-dataset)) (Thanks for sharing!).

### We start by quickly importing some modules.

In [None]:
import numpy as np
import pandas as pd
import os
from glob import glob
import shutil, os
import matplotlib.pyplot as plt
from sklearn.model_selection import GroupKFold
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.image as mpimg
from collections import defaultdict
import dill
import math

import torch
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchvision
from torch.utils.data import DataLoader, Dataset
import torch.utils.data as utils
from torchvision import transforms

from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split, GroupKFold
from PIL import Image
import cv2
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

!pip install einops
from einops import rearrange, repeat

import warnings
warnings.filterwarnings("ignore")

### GPU use is of course recommended.

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
print("GPU is available" if use_gpu else "GPU not available, CPU used")

### Loading the data and deleting "no finding" bounding boxes

In [None]:
# Load the bounding-box annotations.
train_df = pd.read_csv('../input/vinbigdata-chest-xray-abnormalities-detection/train.csv')

# Drop the rows with class_id 14 (the "no finding" entries this notebook ignores).
train_df = train_df[train_df['class_id'] != 14].reset_index(drop=True)

# Box corners are used as slice indices later, so store them as integers.
# `astype` is vectorised and replaces the deprecated per-cell `applymap(int)`.
bbox_columns = ['x_min', 'y_min', 'x_max', 'y_max']
train_df[bbox_columns] = train_df[bbox_columns].astype(int)

### GroupKFold train-validation split in order to ensure that bounding boxes of the same image are not in training and validation data

In [None]:
# Assign every bounding box to one of five folds, grouping by image so that
# all boxes of one image land in the same fold.
gkf = GroupKFold(n_splits=5)

train_df['fold'] = -1
fold_splits = gkf.split(
    train_df[['x_min', 'y_min', 'x_max', 'y_max']],
    train_df['class_id'],
    groups=train_df.image_id.to_list(),
)
for fold, (train, val) in enumerate(fold_splits):
    train_df.loc[val, 'fold'] = fold

# Fold 0 is held out for validation; the remaining folds are used for training.
fold = 0
in_validation_fold = train_df['fold'] == fold
val_df = train_df.loc[in_validation_fold].reset_index(drop=True)
train_df = train_df.loc[~in_validation_fold].reset_index(drop=True)

## The model

In [None]:
class Residual(nn.Module):
    """Wrap a callable and add its input back onto its output (skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; keyword arguments are forwarded to ``fn``."""
        transformed = self.fn(x, **kwargs)
        return transformed + x

class PreNorm(nn.Module):
    """Apply layer normalisation over the last ``dim`` features, then call ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and pass it, with any keyword arguments, to ``fn``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> hidden_dim -> dim`` with GELU and dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Layer order (and therefore the state_dict keys net.0 / net.3) is fixed:
        # expand, activate, drop, project back, drop.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension must have size ``dim``."""
        return self.net(x)

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional padding mask.

    Args:
        dim: feature size of the input and output tokens.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads, dim_head, dropout):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # Scaled dot-product attention divides by sqrt(d_k), the width of a single
        # head's key vectors, not by the width of the whole model.
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout))

    def forward(self, x, mask = None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (batch, tokens, dim).
            mask: optional boolean tensor of shape (batch, tokens) that is True
                at padded positions. Padded positions are excluded as keys for
                every query, so they cannot influence any other token.

        Returns:
            Tensor of shape (batch, tokens, dim).
        """
        b, n, _ = x.shape
        h = self.heads

        # (batch, tokens, heads * dim_head) -> (batch, heads, tokens, dim_head)
        q, k, v = (
            t.reshape(b, n, h, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim = -1)
        )

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # Broadcast over heads and queries: column j is blocked whenever key j
            # is padding. A large finite value (rather than -inf) keeps the softmax
            # free of NaNs even if a whole row were to be masked.
            key_is_padding = mask.to(torch.bool)[:, None, None, :]
            dots = dots.masked_fill(key_is_padding, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # (batch, heads, tokens, dim_head) -> (batch, tokens, heads * dim_head)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

In [None]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers (attention, then feed-forward).

    Each sub-layer is wrapped as ``Residual(PreNorm(dim, ...))``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        # One [attention, feed-forward] pair per layer, built in that order.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        """Run ``x`` through every layer; ``mask`` is passed to the attention sub-layers only."""
        for layer in self.layers:
            attention, feed_forward = layer
            x = attention(x, mask = mask)
            x = feed_forward(x)
        return x

In [None]:
class Vision_Transformer(nn.Module):
    """Vision Transformer classifier over variable-length sequences of flattened patches.

    Args:
        patch_size: side length of a square patch; each input token carries
            ``patch_size ** 2`` values.
        max_patches: largest number of patches per sample (sizes the learnable
            positional embedding, which has one extra slot for the class token).
        dim: token embedding size.
        depth: number of transformer layers.
        heads: number of attention heads.
        mlp_dim: hidden size of the feed-forward sub-layers.
        pool: ``'mean'`` averages the token outputs; any other value uses the
            class-token output.
        dim_head: feature size per attention head.
        dropout: dropout inside the transformer layers.
        emb_dropout: dropout applied to the embedded tokens.
        num_classes: number of output logits (defaults to 14).
    """

    def __init__(self, patch_size, max_patches, dim, depth, heads, mlp_dim, pool = 'cls', dim_head = 64, dropout = 0., emb_dropout = 0., num_classes = 14):

        super().__init__()

        self.patch_size = patch_size
        self.pool = pool

        self.linear_projection = nn.Linear(patch_size**2, dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_patches + 1, dim))
        # Not used in forward(); kept so the parameter set (and therefore saved
        # checkpoints and the initialisation order) stays as before.
        self.linear_projection_emb = nn.Linear(dim//2, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))

    def forward(self, img):
        """Classify a batch of patch sequences.

        Args:
            img: float tensor of shape (batch, patches, patch_size ** 2); rows
                that are entirely zero are treated as padding.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Padding has to be detected on the raw input: after the biased linear
        # projection and the positional embedding no token is zero any more.
        # NOTE(review): a genuine patch whose pixels are all 0 is also treated
        # as padding - confirm this is acceptable for the data.
        patch_is_padding = (img == 0).all(dim=-1)

        x = self.linear_projection(img)

        cls_emb = self.cls_token.repeat(x.size(0), 1, 1)
        x = torch.cat((cls_emb, x), dim=1)

        x = x + self.pos_embedding[:,:x.size(1),:]

        x = self.dropout(x)

        # The class token sits at position 0 and is never padding.
        cls_is_padding = patch_is_padding.new_zeros(patch_is_padding.size(0), 1)
        mask = torch.cat((cls_is_padding, patch_is_padding), dim=1)
        x = self.transformer(x, mask)

        if self.pool == 'mean':
            # Average over real tokens only; the class token guarantees a
            # non-zero count for every sample.
            keep = (~mask).unsqueeze(-1).to(x.dtype)
            x = (x * keep).sum(dim = 1) / keep.sum(dim = 1)
        else:
            x = x[:, 0, :]

        x = self.mlp_head(x)

        return x

In [None]:
# Credits to: https://nlp.seas.harvard.edu/2018/04/03/attention.html

class NoamOpt:
    """Optimizer wrapper that sets the learning rate before every update.

    The rate at step ``s`` is::

        factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)

    i.e. it grows linearly for ``warmup`` steps and then decays with the
    inverse square root of the step number.

    Args:
        model_size: model width used in the ``model_size ** -0.5`` factor.
        factor: constant multiplier of the schedule.
        warmup: number of warm-up steps.
        optimizer: wrapped optimizer; must expose ``param_groups`` and ``step()``.
    """
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Advance the step counter, write the new rate into every param group, then update."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        """Return the learning rate for ``step`` (default: the current step).

        Steps below 1 return 0.0: that is the value of the warm-up branch at
        step 0, and it avoids the ZeroDivisionError of ``0 ** -0.5`` when the
        rate is queried before the first update.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

In [None]:
# Model and batching hyperparameters.
d_model = 256        # token embedding size (dim)
batch_size = 16
patch_size = 16      # side length of a square patch
max_patches = 2048   # longest patch sequence the model accepts
num_layers = 4       # transformer depth
heads = 4
feed_forward = 512   # hidden size of the feed-forward sub-layers (mlp_dim)

# Build the model and the loss directly on the selected device.
transformer = Vision_Transformer(
    patch_size, max_patches, d_model, num_layers, heads, feed_forward
).to(device)

criterion = nn.CrossEntropyLoss().to(device)

# The initial Adam learning rate is 0 because NoamOpt overwrites it on every step.
warm_up_steps = 4000
optimizer = NoamOpt(
    d_model, 1, warm_up_steps, optim.Adam(transformer.parameters(), lr=0)
)

### Preprocessing the data

In [None]:
class ImageData(Dataset):
    """Dataset yielding one cropped bounding box per row as flattened patches.

    Args:
        labels_df: DataFrame with the columns ``image_id``, ``class_id``,
            ``x_min``, ``y_min``, ``x_max`` and ``y_max``.
        data_dir: directory prefix (including the trailing separator) of the
            ``<image_id>.jpg`` files.

    The patch side length is read from the module-level ``patch_size``.
    """

    def __init__(self, labels_df, data_dir):
        super().__init__()
        self.labels_df = labels_df
        self.data_dir = data_dir

    def __len__(self):
        return len(self.labels_df)

    @staticmethod
    def _pad_amount(size, patch):
        """Return the per-side padding that makes ``size`` cover whole patches.

        For an odd ``size`` one extra pixel is added per side; the resulting
        surplus is cut off again when the patches are extracted.
        """
        remainder = size % patch
        if remainder == 0:
            return 0
        return (patch - remainder) // 2 + size % 2

    def __getitem__(self, index):
        """Return ``(patches, label)`` for the row at position ``index``.

        ``patches`` is a long tensor of shape (num_patches, patch_size ** 2).
        """
        df = self.labels_df
        # Positional access for every column: a sampler hands out positions, and
        # label-based lookup would only agree with them on a 0..n-1 index.
        img_name = df.image_id.values[index]
        label = df.class_id.values[index]
        y_min, y_max = df.y_min.values[index], df.y_max.values[index]
        x_min, x_max = df.x_min.values[index], df.x_max.values[index]

        # assumes the file decodes to a 2-D (grayscale) array - TODO confirm
        img = plt.imread(self.data_dir + img_name + ".jpg")
        img = img[y_min:y_max + 1, x_min:x_max + 1]  # box corners are inclusive
        img = torch.tensor(img, dtype=torch.long)

        pad_rows = self._pad_amount(img.size(0), patch_size)
        pad_cols = self._pad_amount(img.size(1), patch_size)
        # F.pad takes the padding of the last dimension (columns) first.
        img = F.pad(img, (pad_cols, pad_cols, pad_rows, pad_rows))

        # Cut into non-overlapping patch_size x patch_size tiles, row by row,
        # and flatten every tile into one vector.
        img = img.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size)
        img = img.reshape(-1, patch_size * patch_size)
        return img, label

In [None]:
def my_collate(batch, max_len=None):
    """Collate ``(patches, label)`` samples into one zero-padded batch.

    Samples with more patches than the limit are dropped, the rest are padded
    with zeros to the longest remaining sequence.

    Args:
        batch: sequence of ``(patches, label)`` pairs where ``patches`` is a
            tensor of shape (num_patches, features).
        max_len: largest allowed ``num_patches``; defaults to the module-level
            ``max_patches``.

    Returns:
        ``[data, target]`` with ``data`` a float tensor of shape
        (kept, longest, features) and ``target`` a long tensor of shape (kept,).

    Raises:
        ValueError: if no sample in ``batch`` is within the limit.
    """
    limit = max_patches if max_len is None else max_len
    kept = [(item[0], item[1]) for item in batch if item[0].size(0) <= limit]
    if not kept:
        raise ValueError(
            f"every sample in the batch has more than {limit} patches"
        )
    sequences, labels = zip(*kept)
    target = torch.tensor(labels, dtype=torch.long)
    # `.float()` converts the padded tensor directly; wrapping an existing tensor
    # in torch.tensor() copy-constructs it and emits a warning.
    data = pad_sequence(list(sequences), batch_first=True, padding_value=0).float()
    return [data, target]

In [None]:
# Training and validation boxes are cropped from the same image directory.
train_image_dir = '../input/vinbigdata-original-image-dataset/vinbigdata/train/'

traindata = ImageData(train_df, train_image_dir)
trainset = DataLoader(
    dataset=traindata,
    batch_size=batch_size,
    collate_fn=my_collate,
    shuffle=True,  # reshuffle the training boxes every epoch
    num_workers=8,
    pin_memory=True,
)

valdata = ImageData(val_df, train_image_dir)
valset = DataLoader(
    dataset=valdata,
    batch_size=batch_size,
    collate_fn=my_collate,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

### Training starts..

In [None]:
epochs = 12

# Per-epoch history: mean batch loss (training / validation) and validation accuracy.
train_loss = []
valid_loss = []
valid_acc = []

for epoch in tqdm(range(1, epochs+1)):

    ############################### Training ##################################
    transformer.train()

    train_loss_batches = []
    train_hits = []  # per-sample 0/1 correctness, one tensor per batch

    for image_bboxes, target in tqdm(trainset):
        target = target.to(device)
        image_bboxes = image_bboxes.to(device)

        # NoamOpt only wraps the optimizer, so gradients are cleared on the
        # wrapped instance.
        optimizer.optimizer.zero_grad()
        output = transformer(image_bboxes)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss_batches.append(loss.item())
        train_hits.append((output.argmax(dim=1) == target).float())

    # Concatenate once per epoch instead of growing a tensor on every batch.
    acc_train_epoch = torch.cat(train_hits).mean().item()
    loss_train = np.mean(train_loss_batches)
    train_loss.append(loss_train)

    ############################### Validation ################################

    transformer.eval()
    with torch.no_grad():
        valid_loss_batches = []
        valid_hits = []
        for image_bboxes, target in valset:
            target = target.to(device)
            image_bboxes = image_bboxes.to(device)

            output = transformer(image_bboxes)

            batch_loss_val = criterion(output, target)
            # Record the validation batch loss, not the last training `loss`.
            valid_loss_batches.append(batch_loss_val.item())
            valid_hits.append((output.argmax(dim=1) == target).float())

        acc_val_epoch = torch.cat(valid_hits).mean().item()
        loss_val = np.mean(valid_loss_batches)

        valid_loss.append(loss_val)
        valid_acc.append(acc_val_epoch)

    print(f"Epoch: {epoch}............. Loss: {loss_train:.4f} - ACC: {acc_train_epoch:.4f} - Loss val: {loss_val:.4f} - ACC val: {acc_val_epoch:.4f}")

In [None]:
# Report the best validation accuracy and the (1-based) epoch it was reached in.
best_val_acc = max(valid_acc)
best_epoch = valid_acc.index(best_val_acc) + 1
print(f"The highest Val ACC of {best_val_acc:.4f} has been reached after {best_epoch} epochs.")

In [None]:
# Serialise the whole module object (not only its state_dict) to the working directory.
model_path = "/kaggle/working/transformer_model.pt"
torch.save(transformer, model_path)

In [None]:
# credits to: https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb

# Plot the per-epoch training and validation loss; epochs are numbered from 1.
fig = plt.figure(figsize=(10,8))
train_epochs = range(1, len(train_loss) + 1)
valid_epochs = range(1, len(valid_loss) + 1)
plt.plot(train_epochs, train_loss, label='Training Loss')
plt.plot(valid_epochs, valid_loss, label='Validation Loss')

# Mark the epoch with the lowest validation loss.
lowest_valid_loss = min(valid_loss)
minposs = valid_loss.index(lowest_valid_loss) + 1
plt.axvline(minposs, linestyle='--', color='r', label='Lowest validation loss')

# Axis labels and fixed limits so that plots from different runs are comparable.
plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 3)
plt.xlim(1, len(train_loss))
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
fig.savefig('loss_plot.png', bbox_inches='tight')