In [1]:
# Standard libraries
import os
import shutil
import numpy as np
import random
import math
import json
from functools import partial
from PIL import Image

# Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

# Torchvision
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import DatasetFolder, ImageFolder
from torch.utils.data import DataLoader

# Imports for ROC AUC
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import roc_curve, roc_auc_score, auc
from itertools import cycle

# Imports for PyTorch Lightning
import pytorch_lightning as pl    
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# Seed Python's `random`, NumPy and torch RNGs so shuffles and weight init are repeatable.
pl.seed_everything(42)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Global seed set to 42


Device: cuda:0


In [2]:
# Kaggle kernel input location: the root folder holding one sub-folder per class.
LENSES_DATASET_PATH = "../input/lensestargz/lenses"

# Per-split output folders (relative to the working directory) that copy_images fills.
TRAIN, TEST, VAL = 'train', 'test', 'val'

# Fractions of the shuffled file list reserved for test and validation; the rest is training.
TEST_SPLIT, VAL_SPLIT = 0.1, 0.01

BATCH_SIZE = 16

# NOTE(review): INPUT_HEIGHT / INPUT_WIDTH look unused in the visible notebook -- the
# transforms resize to 224x224 instead. Confirm before relying on them.
INPUT_HEIGHT = INPUT_WIDTH = 150

In [3]:
# Collect the paths of every .jpg image in the two class folders.
# Both the folder order and the per-folder file order are fixed (a tuple, a sorted listdir).
# Iterating a set literal of strings gives a hash-randomized order that changes between
# Python processes, and os.listdir order depends on the filesystem. Either would feed a
# different list into the seeded shuffle below and so give a different split on every run.
lenses_files = []
for folder in ('no_sub', 'sub'):
    folder_path = os.path.join(LENSES_DATASET_PATH, folder)
    for file in sorted(os.listdir(folder_path)):
        if file.endswith(".jpg"):
            lenses_files.append(os.path.join(folder_path, file))

In [4]:
def copy_images(imagePaths, folder):
    """Copy image files into ``folder/<label>/``, where ``<label>`` is each file's parent folder name.

    Args:
        imagePaths: iterable of source image paths laid out as ``.../<label>/<image name>``.
        folder: destination root. It and the per-label sub-folders are created if missing.

    Files with the same name at the destination are overwritten, so re-running with the same
    split is idempotent. NOTE: files left over from an earlier, different split are *not*
    removed, so stale copies can remain in another split's folder.
    """
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(folder, exist_ok=True)
    for path in imagePaths:
        # basename/dirname handle every separator the platform accepts (e.g. both "/" and
        # "\\" on Windows), unlike splitting on os.path.sep alone.
        imageName = os.path.basename(path)
        label = os.path.basename(os.path.dirname(path))
        labelFolder = os.path.join(folder, label)
        os.makedirs(labelFolder, exist_ok=True)
        shutil.copy(path, os.path.join(labelFolder, imageName))

In [5]:
# Shuffle in place with NumPy's global RNG (seeded via pl.seed_everything in the first cell),
# then size each split. int() truncates, so training absorbs the rounding remainder.
np.random.shuffle(lenses_files)
valPathsLen, testPathsLen = (int(len(lenses_files) * frac) for frac in (VAL_SPLIT, TEST_SPLIT))
trainPathsLen = len(lenses_files) - (valPathsLen + testPathsLen)
print(f"Train : {trainPathsLen}, Test: {testPathsLen}, Val:{valPathsLen}")

In [6]:
# Exclusive end index of the test slice; every path after it belongs to validation.
testInd = sum((trainPathsLen, testPathsLen))

In [7]:
# Contiguous slices of the shuffled list, laid out as [ train | test | val ].
trainPaths, testPaths, valPaths = (
    lenses_files[:trainPathsLen],
    lenses_files[trainPathsLen:testInd],
    lenses_files[testInd:],
)

In [8]:
# Materialize each split on disk as <split>/<label>/<image> so ImageFolder can read it.
copy_images(imagePaths=trainPaths, folder=TRAIN)
copy_images(imagePaths=testPaths, folder=TEST)
copy_images(imagePaths=valPaths, folder=VAL)

In [9]:
# Both pipelines only resize to the 224x224 ViT input and convert to a float tensor.
# No augmentation is applied to the training images.
trainTransforms = transforms.Compose([transforms.Resize(size=(224, 224)), transforms.ToTensor()])
testTransforms = transforms.Compose([transforms.Resize(size=(224, 224)), transforms.ToTensor()])

In [10]:
# ImageFolder derives class labels from the sub-folder names written by copy_images.
# Validation reuses the (non-augmenting) test pipeline.
trainDataset, testDataset, valDataset = (
    ImageFolder(root=TRAIN, transform=trainTransforms),
    ImageFolder(root=TEST, transform=testTransforms),
    ImageFolder(root=VAL, transform=testTransforms),
)

print(f"[INFO] Training dataset contains {len(trainDataset)} samples.")
print(f"[INFO] Test dataset contains {len(testDataset)} samples.")
print(f"[INFO] Validation dataset contains {len(valDataset)} samples.")

In [11]:
# Only training is shuffled. Test/val keep file order so predictions line up with labels
# when the test loader is iterated again later.
trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testDataLoader = DataLoader(
    testDataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)
valDataLoader = DataLoader(
    valDataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

In [12]:
def visualize_batch(batch, classes, dataset_type, ncols=8):
    """Show every image of a batch in a grid, each titled with its class name.

    Args:
        batch: ``(images, labels)`` pair as yielded by the DataLoaders. ``images`` are
            channel-first float tensors (scaled by 255 for display, so presumably in [0, 1])
            and ``labels`` are class indices.
        classes: class names indexed by label (e.g. ``ImageFolder.classes``).
        dataset_type: used in the figure name, e.g. ``"train"`` -> ``"train batch"``.
        ncols: images per grid row. With a full batch of 16 this gives a 2 x 8 grid.

    The number of panels follows the actual batch length, so a final, shorter batch
    (the loaders do not drop the last batch) no longer raises IndexError.
    """
    images, labels = batch[0], batch[1]
    num_images = len(images)
    nrows = max(1, -(-num_images // ncols))  # ceiling division, at least one row
    fig = plt.figure("{} batch".format(dataset_type), figsize=(20, 2.5 * nrows))
    for i in range(num_images):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        # (C, H, W) float -> (H, W, C) uint8, the layout imshow expects.
        image = images[i].cpu().numpy().transpose((1, 2, 0))
        ax.imshow((image * 255.0).astype("uint8"))
        ax.set_title(classes[int(labels[i])])
        ax.axis("off")
    fig.tight_layout()
    plt.show()

In [13]:
# Pull one shuffled training batch. It is displayed here and reused by the patch
# visualization further down.
trainBatch = next(iter(trainDataLoader))
print(trainBatch[0].size())  # presumably (BATCH_SIZE, 3, 224, 224) for a full batch -- verify
visualize_batch(batch=trainBatch, classes=trainDataset.classes, dataset_type="train")

### Breaking each image into patches and then flattening them

In [14]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """Split a batch of images into non-overlapping square patches.

    Args:
        x: image tensor of shape (B, C, H, W).
        patch_size: side length of each square patch, in pixels. Must divide H and W.
        flatten_channels: if True, flatten each patch into a vector of length C * patch_size**2.

    Returns:
        ``(B, num_patches, C * patch_size**2)`` if ``flatten_channels`` else
        ``(B, num_patches, C, patch_size, patch_size)``, where
        ``num_patches = (H // patch_size) * (W // patch_size)``. Patches are ordered row by
        row (left to right, then top to bottom).

    Raises:
        ValueError: if H or W is not a multiple of ``patch_size``.
    """
    B, C, H, W = x.shape
    # Fail with a readable message instead of an opaque reshape size mismatch.
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"Image size {H}x{W} is not divisible by patch_size={patch_size}"
        )
    # torch.div(..., rounding_mode='trunc') instead of `//`: presumably to avoid the
    # floor-division deprecation warning when sizes are traced tensors -- TODO confirm.
    x = x.reshape(B, C,
                  torch.div(H, patch_size, rounding_mode='trunc'), patch_size,
                  torch.div(W, patch_size, rounding_mode='trunc'), patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H/p, W/p, C, p, p)
    x = x.flatten(1, 2)              # (B, num_patches, C, p, p)
    if flatten_channels:
        x = x.flatten(2, 4)          # (B, num_patches, C*p*p)
    return x

# Vision Transformer  
![Image ViT](https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png)

## Visualizing the image patches (the inputs to the patch embedding)

In [15]:
# Cut each 224x224 image into 16x16 patches: (224 / 16)^2 = 196 patches per image.
# Each image gets one subplot row, with its patches tiled in 2 rows of 98.
img_patches = img_to_patch(trainBatch[0], patch_size=16, flatten_channels=False)
n_images, n_patches = img_patches.shape[0], img_patches.shape[1]
patches_per_row = max(1, -(-n_patches // 2))  # ceil(n_patches / 2) -> 98 for 196 patches

# squeeze=False keeps `ax` 2-D even when the batch holds a single image.
fig, ax = plt.subplots(n_images, 1, figsize=(40, 10), squeeze=False)

print('Display patch embedding result: ')

for i in range(n_images):
    img_grid = make_grid(img_patches[i], nrow=patches_per_row, normalize=True, pad_value=0.8)
    ax[i, 0].imshow(img_grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
    ax[i, 0].axis('off')
plt.show()
plt.close(fig)

## The Vision Transformer model

In [16]:
class AttentionBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block.

    Self-attention, then a feed-forward network (FFN), each wrapped in a residual connection.

    Args:
        embed_dim: feature size of each input token and of the attention output.
        hidden_dim: width of the FFN hidden layer (this notebook uses 2 x embed_dim).
        num_heads: number of heads in the multi-head attention layer.
        dropout: dropout rate used inside the FFN. The attention layer is built without dropout.
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        # Sub-module creation order fixes how the seeded RNG is consumed for weight init,
        # and the attribute names are the state_dict keys -- keep both as they are.
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        ffn_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.linear = nn.Sequential(*ffn_layers)

    def forward(self, x):
        # The attention layer is created without batch_first, so x is sequence-first:
        # (seq_len, batch, embed_dim).
        normed = self.layer_norm_1(x)
        attn_out, _attn_weights = self.attn(normed, normed, normed)
        x = x + attn_out
        return x + self.linear(self.layer_norm_2(x))

In [17]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier: patch tokens + learned [CLS] token -> AttentionBlocks -> MLP head.

    Args:
        embed_dim: feature size of every token inside the transformer.
        hidden_dim: width of the hidden layer of each block's feed-forward network.
        num_channels: number of input image channels (3 for the RGB images used here).
        num_heads: number of heads in each multi-head attention block.
        num_layers: number of stacked AttentionBlock layers.
        num_classes: number of output logits.
        patch_size: side length, in pixels, of the square patches the image is cut into.
        num_patches: maximum number of patches per image. It sizes the positional embedding.
        dropout: dropout applied to the input tokens and inside each block's FFN.
    """

    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        super().__init__()

        self.patch_size = patch_size

        # NOTE: layers draw their random init from the seeded global RNG in creation order;
        # keep this order so seeded runs stay reproducible.
        self.input_layer = nn.Linear(num_channels * (patch_size ** 2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # One position per patch plus one for the [CLS] token.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x):
        """Map a batch of images (B, C, H, W) to class logits of shape (B, num_classes).

        H and W must be multiples of ``patch_size``.

        Raises:
            ValueError: if the image yields more patches than ``num_patches``, so the
                positional embedding would not cover every token.
        """
        x = img_to_patch(x, self.patch_size)  # (B, T, C * patch_size**2)
        B, T, _ = x.shape
        if T + 1 > self.pos_embedding.shape[1]:
            raise ValueError(
                f"Input yields {T} patches but the positional embedding only covers "
                f"{self.pos_embedding.shape[1] - 1}; check image size / num_patches"
            )
        x = self.input_layer(x)

        # expand() broadcasts the single [CLS] token across the batch without copying it.
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, :T + 1]

        x = self.dropout(x)
        # The attention blocks are sequence-first, so go from (B, T+1, E) to (T+1, B, E).
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Classify from the [CLS] token's final representation.
        cls = x[0]
        return self.mlp_head(cls)

In [18]:
# We will be using PyTorch's Lightning module to organize our model code.

class ViT(pl.LightningModule):
    """Lightning wrapper that trains a ``VisionTransformer`` classifier with AdamW.

    Args:
        model_kwargs: keyword arguments forwarded to ``VisionTransformer``.
        lr: AdamW learning rate.

    Test-time outputs are kept for the ROC analysis:
        ``self.predictions``   -- argmax class ids, one tensor per test batch.
        ``self.probabilities`` -- softmax class probabilities, one (batch, num_classes)
                                  tensor per test batch.
    Both lists are emptied at the start of every test epoch.
    """

    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()  # exposes model_kwargs / lr via self.hparams
        self.model = VisionTransformer(**model_kwargs)
        # A real training batch; Lightning uses it to report layer input/output sizes.
        self.example_input_array = next(iter(trainDataLoader))[0]
        self.predictions = []    # hard test predictions (class ids) for ROC AUC
        self.probabilities = []  # soft test scores (softmax) for ROC AUC

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        # Divide the learning rate by 10 at epochs 100 and 150.
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def on_test_epoch_start(self):
        # Start every test run from empty buffers. Otherwise a repeated trainer.test() call
        # would append a second copy of the predictions, and the concatenated results would
        # no longer line up with a single pass over the test labels.
        self.predictions = []
        self.probabilities = []

    def _calculate_loss(self, batch, mode="train"):
        """Compute and log cross-entropy loss and accuracy for one batch.

        In ``mode == 'test'`` the batch's predictions are also stored for the ROC analysis.
        """
        imgs, labels = batch
        logits = self.model(imgs)
        loss = F.cross_entropy(logits, labels)
        pred_ids = logits.argmax(dim=-1)
        acc = (pred_ids == labels).float().mean()

        if mode == 'test':
            self.predictions.append(pred_ids)
            self.probabilities.append(F.softmax(logits, dim=-1).detach())

        self.log(f'{mode}_loss', loss)
        self.log(f'{mode}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

In [19]:
# Root for Lightning logs and checkpoints; train_model writes under <CHECKPOINT_PATH>/ViT.
CHECKPOINT_PATH: str = "./saved_models/"

In [20]:
def train_model(use_best_checkpoint=False, **kwargs):
    """Train a ``ViT`` and report its validation and test accuracy.

    Args:
        use_best_checkpoint: if True, reload the checkpoint with the highest ``val_acc``
            (tracked by ModelCheckpoint) before evaluating. If False (default), evaluate the
            weights from the final training epoch.
        **kwargs: forwarded to ``ViT`` (``model_kwargs`` and ``lr``).

    Returns:
        ``(model, result)`` where ``model`` is the evaluated ``ViT`` and
        ``result`` is ``{"test": test_acc, "val": val_acc}``.
    """
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        gpus=1 if str(device) == "cuda:0" else 0,
        max_epochs=180,
        callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                   LearningRateMonitor("epoch")],
        enable_progress_bar=False,
        log_every_n_steps=4,
    )

    pl.seed_everything(42, workers=True)  # reproducible init, shuffling and worker seeds
    model = ViT(**kwargs)
    trainer.fit(model=model, train_dataloaders=trainDataLoader, val_dataloaders=valDataLoader)

    # ModelCheckpoint keeps the highest-val_acc weights on disk, but `model` still holds the
    # final-epoch weights. Only swap in the best checkpoint when explicitly requested.
    if use_best_checkpoint:
        best_path = trainer.checkpoint_callback.best_model_path
        if best_path:
            model = ViT.load_from_checkpoint(best_path)

    val_result = trainer.validate(model, dataloaders=valDataLoader, verbose=False)
    test_result = trainer.test(model, dataloaders=testDataLoader, verbose=False)

    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["val_acc"]}
    return model, result

In [21]:
# Transformer size.
EMBED_DIM, HIDDEN_DIM = 256, 512
NUM_HEADS, NUM_LAYERS = 8, 6
DROPOUT = 0.2

# Input geometry: RESIZE_IMG should match the 224x224 Resize() used by the transforms.
RESIZE_IMG = 224
PATCH_SIZE = 16
NUM_CHANNELS = 3
NUM_PATCHES = (RESIZE_IMG // PATCH_SIZE) ** 2  # 14 x 14 = 196 patches per image

NUM_CLASSES = 2

## Training the model

In [22]:
# Train with the hyperparameters from the constants cell and AdamW at lr = 3e-4.
model, results = train_model(
    model_kwargs=dict(
        embed_dim=EMBED_DIM,
        hidden_dim=HIDDEN_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        patch_size=PATCH_SIZE,
        num_channels=NUM_CHANNELS,
        num_patches=NUM_PATCHES,
        num_classes=NUM_CLASSES,
        dropout=DROPOUT,
    ),
    lr=3e-4,
)
print("ViT results", results)

## ROC AUC Score

In [23]:
class LB(LabelBinarizer):
    """LabelBinarizer that always returns one indicator column per class.

    sklearn's LabelBinarizer returns a single column for binary targets. This subclass
    expands it to two, so per-class ROC curves work for binary and multiclass data alike.
    Column ``j`` always corresponds to ``self.classes_[j]``, matching the parent's
    multiclass layout and the legend's "class j" labels.
    """

    def transform(self, y):
        Y = super().transform(y)
        if self.y_type_ == 'binary':
            # The parent's single column flags classes_[1]; put its complement (classes_[0])
            # first so the column order follows classes_.
            return np.hstack((1 - Y, Y))
        return Y

    def inverse_transform(self, Y, threshold=None):
        if self.y_type_ == 'binary':
            # Column 1 is the classes_[1] indicator that the parent's binary inverse expects.
            return super().inverse_transform(Y[:, 1], threshold)
        return super().inverse_transform(Y, threshold)

In [24]:
# Hard test-set predictions (argmax class ids) stored by the model's test step, one
# tensor per batch, concatenated into a single NumPy vector.
y_score = torch.cat(model.predictions).cpu().detach().numpy()

# Ground-truth labels in the same order: testDataLoader was built with shuffle=False.
y_test = np.concatenate([labels.cpu().detach().numpy() for _, labels in testDataLoader])

In [25]:
# Fit the binarizer on the ground-truth labels only, then reuse that mapping for the
# predictions. Re-fitting on the predictions builds a different class mapping whenever the
# model never predicts some class (e.g. all-one-class output), which inverts or misaligns
# the indicator columns relative to y_test.
lb = LB()
y_test = lb.fit_transform(y_test)
y_score = lb.transform(y_score)

In [27]:
# One ROC curve per indicator column of the binarized targets (one-vs-rest).
# NOTE: y_score holds hard 0/1 predictions (argmax ids), so each curve has a single
# operating point. Continuous scores such as softmax probabilities would give a full curve.
n_classes = y_test.shape[1]
fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# cycle() reuses the palette so every class is drawn. A plain zip with this two-colour
# list would silently drop classes 2 and up.
colors = ['orange', 'green']
fig, ax = plt.subplots()
for i, color in zip(range(n_classes), cycle(colors)):
    ax.plot(fpr[i], tpr[i], color=color, lw=2,
            label='ROC curve of class {0} (area = {1:0.5f})'.format(i, roc_auc[i]))

ax.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level diagonal
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc="lower right")
plt.show()

### Save Model

In [28]:
# Persist only the learned parameters (the LightningModule's state_dict), not the module object.
torch.save(obj=model.state_dict(), f='st5_model.pth')