In [None]:
# install packages
!pip install vit-pytorch
!pip install linformer


In [None]:
import os
import cv2

import pandas as pd
import numpy as np
from random import shuffle
import torch
from torch.utils.data import DataLoader, Dataset


import pandas as pd
import numpy as np
from random import shuffle
from sklearn.metrics import precision_recall_curve, roc_curve
import matplotlib.pyplot as plt
import wandb

# from vit_pytorch.efficient import ViT as ViT_eff
# from vit_pytorch import ViT
# from linformer import Linformer

# Dataset

Custom dataset class for the APTOS 2019 Blindness Detection data.


### Data utils

In [None]:
class_list = ['No DR', 'DR']


def binarize(x):
    """Collapse a multi-grade diagnosis into a binary label.

    Grade 0 is passed through unchanged; every other value becomes 1.
    """
    return 1 if x != 0 else x
    
def balance(df, target='binary_target', random_state=101):
    """Undersample so that both binary classes have the same number of rows.

    Args:
        df: DataFrame holding a 0/1 label column.
        target: name of the label column.
        random_state: seed used for the undersampling and the final shuffle.

    Returns:
        A new DataFrame with ``min(#class0, #class1)`` rows of each class,
        rows shuffled and index reset to ``0..n-1``.
    """
    df_0 = df[df[target] == 0]
    df_1 = df[df[target] == 1]
    # Sample each class down to the size of the smaller one, so we never ask
    # ``sample`` for more rows than a class has (whichever class is larger).
    n_per_class = min(len(df_0), len(df_1))
    df_0 = df_0.sample(n_per_class, random_state=random_state)
    df_1 = df_1.sample(n_per_class, random_state=random_state)
    df_data = pd.concat([df_0, df_1], axis=0)
    # ``random.shuffle`` shuffles a list in place and returns None, so it
    # cannot shuffle a DataFrame; shuffle rows with pandas instead.
    return df_data.sample(frac=1, random_state=random_state).reset_index(drop=True)

In [None]:
def train_validate_test_split(df, train_percent=.6, validate_percent=.2,
                              seed=None):
    """Randomly split ``df`` row-wise into train / validation / test frames.

    Args:
        df: DataFrame to split; its index may hold arbitrary labels.
        train_percent: fraction of rows assigned to the training split.
        validate_percent: fraction of rows assigned to the validation split;
            the remaining rows form the test split.
        seed: optional seed for the permutation; ``None`` gives a fresh,
            non-reproducible split.

    Returns:
        ``(train, validate, test)`` DataFrames; original index labels are kept.
    """
    # A private RandomState avoids re-seeding NumPy's global RNG as a side
    # effect; for a given seed it yields the same permutation that
    # ``np.random.seed(seed); np.random.permutation(...)`` produced.
    rng = np.random.RandomState(seed)
    m = len(df.index)
    # Permute *positions*: ``iloc`` is positional, so permuting index labels
    # only worked when the index happened to be exactly 0..m-1.
    perm = rng.permutation(m)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end
    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]

    return train, validate, test


def get_splits(csv_path, target_col='diagnosis', balance=False, seed=None):
    """Read the label CSV and return binary-labelled train/val/test splits.

    Args:
        csv_path: path to a CSV whose columns are an image id plus ``target_col``.
        target_col: multi-grade diagnosis column; it is dropped and replaced by a
            0/1 ``binary_target`` column in the returned frames.
        balance: if True, undersample to equal class counts with the
            module-level ``balance`` helper before splitting.
        seed: optional seed forwarded to ``train_validate_test_split`` so the
            split can be reproduced (``None`` keeps the previous random split).

    Returns:
        ``(train_df, validate_df, test_df)``, each with a fresh ``0..n-1`` index.
    """
    df = pd.read_csv(csv_path)
    # TODO: make binarizing the labels optional.
    df['binary_target'] = df[target_col].apply(binarize)

    if balance:
        # The ``balance`` parameter shadows the module-level ``balance``
        # function, so ``balance(df)`` here would try to call the bool.
        # Look the helper up in the module namespace instead.
        df = globals()['balance'](df)

    # Split the entire dataframe into train, val and test splits.
    df_data = df.drop(target_col, axis=1).reset_index(drop=True)
    train_df, validate_df, test_df = train_validate_test_split(df_data, seed=seed)

    return (train_df.reset_index(drop=True),
            validate_df.reset_index(drop=True),
            test_df.reset_index(drop=True))

In [None]:
class AptosDataset(Dataset):
    """Map-style dataset yielding ``(transformed_image, label)`` pairs.

    Args:
        df_data: DataFrame whose rows unpack to exactly ``(image_id, label)``.
        data_dir: directory containing ``<image_id>.png`` files.
        transform: callable applied to the decoded RGB image array.
    """

    def __init__(self, df_data, data_dir, transform):
        super().__init__()
        self.df = df_data.values
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img_name, label = self.df[index]
        img_path = os.path.join(self.data_dir, img_name + '.png')
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here with the offending path.
            raise FileNotFoundError("Could not read image: {}".format(img_path))
        # OpenCV decodes to BGR; the transforms (ToPILImage(mode='RGB')) and the
        # pretrained ViT mean/std normalisation expect RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image)

        return image, label

## DataLoader

In [None]:
## Data Module
def collate_fn(examples):
    """Batch ``(image_tensor, label)`` samples into the dict the loops consume.

    Returns ``{"pixel_values": stacked images, "labels": label tensor}``.
    """
    images = [sample[0] for sample in examples]
    targets = [sample[1] for sample in examples]
    batch = {}
    batch["pixel_values"] = torch.stack(images)
    batch["labels"] = torch.tensor(targets)
    return batch

In [None]:
class AptosDataModule():
    """Hold the train/validation/test datasets and build their DataLoaders.

    Args:
        batch_size: batch size used by every loader.
        train_df, valid_df, test_df: split DataFrames of ``(image_id, label)``.
        _train_transforms: transform applied to training images.
        _val_transforms: transform applied to validation and test images.
        data_dir: directory containing the ``.png`` images.
    """

    def __init__(self, batch_size, train_df, valid_df, test_df, _train_transforms, _val_transforms, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_df, self.valid_df, self.test_df = train_df, valid_df, test_df
        self._train_transforms = _train_transforms
        self._val_transforms = _val_transforms

        # Validation and test share the deterministic evaluation transforms.
        self.train_set = AptosDataset(self.train_df, self.data_dir, transform=self._train_transforms)
        self.validate_set = AptosDataset(self.valid_df, self.data_dir, transform=self._val_transforms)
        self.test_set = AptosDataset(self.test_df, self.data_dir, transform=self._val_transforms)

    def _loader(self, dataset, shuffle):
        # The three loaders differ only in their dataset and shuffling.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          collate_fn=collate_fn, num_workers=2)

    def train_dataloader(self):
        return self._loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.validate_set, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_set, shuffle=False)




In [None]:
class AptosDataModule():
    """Hold the train/validation/test datasets and build their DataLoaders.

    NOTE(review): ``AptosDataModule`` is declared twice in this notebook; this
    later declaration is the one in effect. Keep a single copy.

    Args:
        batch_size: batch size used by every loader.
        train_df, valid_df, test_df: split DataFrames of ``(image_id, label)``.
        _train_transforms: transform applied to training images.
        _val_transforms: transform applied to validation and test images.
        data_dir: directory containing the ``.png`` images.
        num_workers: DataLoader worker processes (default 2, as before;
            0 loads data in the main process).
    """

    def __init__(self, batch_size, train_df, valid_df, test_df, _train_transforms, _val_transforms, data_dir: str = './', num_workers: int = 2):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_df = train_df
        self.valid_df = valid_df
        self.test_df = test_df
        self._train_transforms = _train_transforms
        self._val_transforms = _val_transforms

        self.train_set = AptosDataset(self.train_df, self.data_dir, transform=self._train_transforms)
        self.validate_set = AptosDataset(self.valid_df, self.data_dir, transform=self._val_transforms)
        self.test_set = AptosDataset(self.test_df, self.data_dir, transform=self._val_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          shuffle=True, collate_fn=collate_fn, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.validate_set, batch_size=self.batch_size,
                          shuffle=False, collate_fn=collate_fn, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size,
                          shuffle=False, collate_fn=collate_fn, num_workers=self.num_workers)



In [None]:
from transformers import ViTFeatureExtractor

In [None]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize,
                                    ToPILImage,
                                    ToTensor,)

In [None]:
def transform(pre_train_model):
    """Build train/eval torchvision pipelines matching a pretrained ViT checkpoint.

    Args:
        pre_train_model: Hugging Face model id or path passed to
            ``ViTFeatureExtractor.from_pretrained``; its ``size``, ``image_mean``
            and ``image_std`` drive resizing and normalisation.

    Returns:
        ``(_train_transforms, _val_transforms)``; both start with
        ``ToPILImage(mode='RGB')`` and end with ``ToTensor`` + ``Normalize``.
    """
    feature_extractor = ViTFeatureExtractor.from_pretrained(pre_train_model)

    # Depending on the transformers version, ``size`` is an int or a dict such
    # as {"height": 224, "width": 224} / {"shortest_edge": 224}; torchvision
    # transforms need an int or an (h, w) tuple.
    size = feature_extractor.size
    if isinstance(size, dict):
        if "shortest_edge" in size:
            size = size["shortest_edge"]
        elif size["height"] == size["width"]:
            size = size["height"]
        else:
            size = (size["height"], size["width"])

    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    _train_transforms = Compose(
            [
                ToPILImage(mode='RGB'),
                RandomResizedCrop(size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )

    _val_transforms = Compose(
            [
                ToPILImage(mode='RGB'),
                Resize(size),
                CenterCrop(size),
                ToTensor(),
                normalize,
            ]
        )

    return _train_transforms, _val_transforms

In [None]:
pre_train_model = "google/vit-base-patch16-224-in21k" #"google/vit-large-patch16-224-in21k"

In [None]:
_train_transforms, _val_transforms = transform(pre_train_model)

In [None]:

# --- Run configuration ----------------------------------------------------
Data_Path = '../input/aptos2019-blindness-detection'
train_dir = f"{Data_Path}/train_images"
csv_file = f"{Data_Path}/train.csv"

target_col = 'diagnosis'
target_names = ['No DR', 'DR']   # position i names binary label i

batch_size = 16
num_epochs = 1
# NOTE(review): 5e-3 is high for fine-tuning a pretrained ViT with Adam;
# values around 1e-5..1e-4 are more common — confirm this is intended.
learning_rate = 5e-3
weight_decay = 2e-4

# get_splits binarises ``target_col`` into ``binary_target`` before splitting.
train_df, validate_df, test_df = get_splits(csv_path=csv_file, target_col=target_col)


In [None]:
data_module = AptosDataModule(batch_size,
                            train_df,
                            validate_df, 
                            test_df,
                            _train_transforms,
                            _val_transforms,
                            train_dir
                            )

In [None]:
train_data_loader = data_module.train_dataloader()
val_data_loader = data_module.val_dataloader()
test_data_loader = data_module.test_dataloader()

# Pretrained ViT Model

In [None]:
import torch.nn as nn
from transformers import ViTModel
from transformers import ViTConfig
from transformers import PretrainedConfig
from transformers import PreTrainedModel

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
vit = ViTModel.from_pretrained(pre_train_model,return_dict=True)
config = ViTConfig().from_pretrained(pre_train_model,
                                    attention_probs_dropout_prob = 0.2)
print(config)

In [None]:
# save_pretrained
config.save_pretrained(save_directory="./")

In [None]:
class ViTModuleImageClassification(PreTrainedModel):
    """Pretrained ``ViTModel`` backbone + dropout + linear head on the [CLS] token.

    Args:
        pre_train_model: Hugging Face id/path of the ViT backbone.
        drop_out: dropout probability applied before the classifier.
        num_classes: number of output logits.
        config: ``PretrainedConfig`` required by the ``PreTrainedModel`` base.
            NOTE(review): the backbone is loaded from ``pre_train_model`` with
            its own checkpoint config, so settings changed on ``config``
            (e.g. ``attention_probs_dropout_prob``) presumably do not reach it.
    """

    def __init__(self, pre_train_model, drop_out, num_classes, config):
        super(ViTModuleImageClassification, self).__init__(config)
        self.vit = ViTModel.from_pretrained(pre_train_model, return_dict=True)
        self.dropout = nn.Dropout(drop_out)
        self.num_classes = num_classes
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_classes)

    def forward(self, pixel_values, output_attentions=False):
        """Return ``(B, num_classes)`` logits.

        With ``output_attentions=True`` return
        ``{'logits': logits, 'att_weights': per-layer attentions}`` instead.
        """
        # Call the full ViTModel rather than an nn.Sequential of its first two
        # children: that bypassed the pretrained final layer norm on the
        # [CLS] token and could never forward ``output_attentions``.
        outputs = self.vit(pixel_values=pixel_values,
                           output_attentions=output_attentions)
        cls_token = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(cls_token)
        if output_attentions:
            return {'logits': logits, 'att_weights': outputs.attentions}
        return logits

In [None]:
NUM_CLASSES = 2
model = ViTModuleImageClassification(pre_train_model,config=config, drop_out=0.1, num_classes=NUM_CLASSES)
model = model.to(device)

## Training and evaluation functions

In [None]:
def train(model, criterion, device, train_data_loader, optimizer, epoch):
    """Run one training epoch, print and return accuracy and mean loss.

    Args:
        model: module mapping ``pixel_values`` to ``(B, num_classes)`` logits.
        criterion: loss returning the batch *mean* (e.g. a default
            ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.
        train_data_loader: yields dicts with ``'pixel_values'`` and ``'labels'``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only in the printed messages.

    Returns:
        ``(accuracy_percent, mean_per_sample_loss)`` for the epoch.
    """
    model.train()
    loss_sum = 0.0
    correct_train = 0

    for batch in train_data_loader:
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(pixel_values)
        _, predictions = torch.max(logits, 1)

        loss = criterion(logits, labels)
        # ``loss`` is a batch mean; weight it by the batch size so that the
        # epoch total divided by the dataset size is a true per-sample mean
        # (the last batch may be smaller than the others).
        loss_sum += loss.item() * labels.size(0)
        correct_train += (predictions == labels).sum().item()

        loss.backward()
        optimizer.step()

    # scheduler.step()
    n_samples = len(train_data_loader.dataset)
    accuracy = 100 * float(correct_train) / n_samples
    loss_ = loss_sum / n_samples
    print("Training accuracy for epoch {} is {}".format(epoch, accuracy))
    print("Training Loss for epoch {} is {}".format(epoch, loss_))
    return accuracy, loss_

def evaluate(model, criterion, device, val_data_loader, epoch):
    """Evaluate without gradients; print and return accuracy and mean loss.

    Args:
        model: module mapping ``pixel_values`` to ``(B, num_classes)`` logits.
        criterion: loss returning the batch *mean* (e.g. a default
            ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.
        val_data_loader: yields dicts with ``'pixel_values'`` and ``'labels'``.
        epoch: epoch index, used only in the printed messages.

    Returns:
        ``(accuracy_percent, mean_per_sample_loss)``.
    """
    model.eval()
    loss_sum = 0.0
    correct_test = 0

    with torch.no_grad():
        for batch in val_data_loader:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            logits = model(pixel_values)
            _, predictions = torch.max(logits, 1)

            loss = criterion(logits, labels)
            # Batch-mean loss -> weight by batch size for a per-sample mean.
            loss_sum += loss.item() * labels.size(0)
            correct_test += (predictions == labels).sum().item()

    n_samples = len(val_data_loader.dataset)
    acc = 100 * float(correct_test) / n_samples
    loss_ = loss_sum / n_samples

    print("Val accuracy for epoch {} is {}".format(epoch, acc))
    print("Val Loss for epoch {} is {}".format(epoch, loss_))
    return acc, loss_


In [None]:
def test(test_data_loader, model):
    predictions, targets = [], []

    with torch.no_grad():
        model.eval()

        for i, (batch) in enumerate(test_data_loader):
            # evaluate the model on the test set
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            logits = model(pixel_values)  # [bs, n_class]
            _,pred = torch.max(logits, 1)  # [bs]


            targets.extend(labels.cpu().numpy())
            predictions.extend(pred.cpu().numpy())

    return predictions, targets


In [None]:
import torch.nn as nn

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.1, verbose=True)
criterion = nn.CrossEntropyLoss().to(device)


In [None]:
for epoch in range(num_epochs):
    train(model, criterion, device, train_data_loader, optimizer, epoch)
    evaluate(model, criterion, device, val_data_loader, epoch)


In [None]:
model.save_pretrained(save_directory='./')

In [None]:
from transformers import ViTForImageClassification

In [None]:
# Load ViT
vit = ViTModel.from_pretrained("./config.json").to(
    device
)
vit.eval()

**Vision Transformer explainability**

[Reference Link](https://github.com/jacobgil/vit-explain)

In [None]:
# for name, module in model.named_modules():
#     print(name)

In [None]:


# Capture the outputs and gradients of every module whose qualified name
# contains ``att_name`` into the module-level lists below.
# NOTE(review): for Hugging Face ViT this substring presumably matches the
# attention output projection (and its sub-modules), not the attention
# probabilities — confirm with model.named_modules().
# NOTE(review): the hooks are never removed; re-running this cell registers
# another set, and every later forward/backward pass keeps appending.
att_name = 'attention.output'
attentions = []           # forward outputs of the hooked modules, on CPU
attention_gradients = []  # grad_input[0] of the hooked modules, on CPU
def get_attention(module, input, output):
    """Forward hook: append the module's output (moved to CPU)."""
    attentions.append(output.cpu())
    
    
def get_attention_gradient(module, grad_input, grad_output):
    """Backward hook: append the first entry of ``grad_input`` (moved to CPU)."""
    attention_gradients.append(grad_input[0].cpu())
    
for name, module in model.named_modules():
    if att_name in name:
        module.register_forward_hook(get_attention)
        # NOTE(review): ``register_backward_hook`` is deprecated in PyTorch in
        # favour of ``register_full_backward_hook``.
        module.register_backward_hook(get_attention_gradient)
        
        


In [None]:
attentions

In [None]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2

def grad_rollout(attentions, gradients, discard_ratio):
    """Gradient-weighted attention rollout (adapted from jacobgil/vit-explain).

    Args:
        attentions: per-layer attention tensors, presumably shaped
            ``(B, heads, N, N)`` with token 0 the class token — TODO confirm
            for the hooked module.
        gradients: per-layer gradients paired one-to-one (same order and
            shape) with ``attentions``.
        discard_ratio: fraction of the lowest fused entries zeroed per layer
            (only batch element 0 is pruned).

    Returns:
        ``(w, w)`` NumPy mask of class-token attention to the ``w*w`` patches,
        scaled so that its maximum is 1.
    """
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            # Weight by the gradient, average over heads, keep positive parts.
            attention_heads_fused = (attention * grad).mean(axis=1)
            attention_heads_fused[attention_heads_fused < 0] = 0

            # Drop the lowest attentions (the class token is not protected
            # here; see the commented-out line).
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            #indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            # Row-normalise. keepdim=True divides each row by its own sum;
            # without it the (B, N) sums broadcast along the last axis and
            # column j was divided by the sum of row j.
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask

class VITAttentionGradRollout:
    """Collect attention maps and their gradients via hooks; run ``grad_rollout``.

    Args:
        model: classifier to explain; may return a logits tensor or a dict
            with a ``'logits'`` entry.
        attention_layer_name: substring; every module whose qualified name
            contains it gets a forward and a full backward hook.
            NOTE(review): ``grad_rollout`` multiplies each hooked output by the
            gradient w.r.t. that module's input and expects
            ``(B, heads, N, N)`` maps. The default targets a dense layer whose
            input and output widths presumably differ — for Hugging Face ViT a
            name like ``'attention.attention.dropout'`` is likely intended.
        discard_ratio: forwarded to ``grad_rollout``.
    """

    def __init__(self, model, attention_layer_name='vit.encoder.layer.11.output.dense',
        discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        self.attentions = []
        self.attention_gradients = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)
                # Full backward hooks report gradients w.r.t. the module's
                # inputs; the legacy ``register_backward_hook`` is deprecated.
                module.register_full_backward_hook(self.get_attention_gradient)

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor, category_index):
        # Reset the buffers so repeated calls don't mix in tensors captured
        # during earlier forward/backward passes.
        self.attentions = []
        self.attention_gradients = []
        self.model.zero_grad()
        output = self.model(input_tensor)
        # Accept a plain logits tensor as well as a dict-like output.
        if isinstance(output, dict):
            output = output['logits']
        # Build the one-hot mask on the logits' own device and dtype.
        category_mask = torch.zeros_like(output)
        category_mask[:, category_index] = 1
        loss = (output * category_mask).sum()
        loss.backward()

        # Backward hooks fire from the last hooked layer to the first; reverse
        # so each gradient lines up with its forward-ordered attention.
        return grad_rollout(self.attentions, self.attention_gradients[::-1],
            self.discard_ratio)

In [None]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2

def rollout(attentions, discard_ratio, head_fusion):
    """Attention rollout over layers (adapted from jacobgil/vit-explain).

    Args:
        attentions: per-layer attention tensors, presumably shaped
            ``(B, heads, N, N)`` with token 0 the class token — TODO confirm.
        discard_ratio: fraction of the lowest fused entries zeroed per layer
            (class token kept; only batch element 0 is pruned).
        head_fusion: how heads are combined: ``"mean"``, ``"max"`` or ``"min"``.

    Returns:
        ``(w, w)`` NumPy mask of class-token attention to the ``w*w`` patches,
        scaled so that its maximum is 1.

    Raises:
        ValueError: if ``head_fusion`` is not one of the supported values.
    """
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1)[0]
            else:
                # ``raise "<str>"`` is itself a TypeError in Python 3.
                raise ValueError("Attention head fusion type Not supported")

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            # Row-normalise; keepdim=True makes each row divide by its own sum
            # (without it column j was divided by the sum of row j).
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask

class VITAttentionRollout:
    """Collect attention maps with forward hooks and run ``rollout`` on them.

    Args:
        model: model to explain.
        attention_layer_name: substring matched against qualified module
            names; every matching module gets a forward hook. The default
            ``'attn_drop'`` follows timm's ViT naming.
            NOTE(review): Hugging Face ``ViTModel`` presumably needs something
            like ``'attention.attention.dropout'`` — verify with
            ``model.named_modules()``.
        head_fusion: ``"mean"``, ``"max"`` or ``"min"``; forwarded to ``rollout``.
        discard_ratio: forwarded to ``rollout``.
    """

    def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean",
        discard_ratio=0.9):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attention_layer_name = attention_layer_name
        self.attentions = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def __call__(self, input_tensor):
        self.attentions = []
        with torch.no_grad():
            self.model(input_tensor)
        if not self.attentions:
            # Otherwise ``rollout`` fails with an opaque IndexError.
            raise RuntimeError(
                "No attention maps were captured: no hooked module containing "
                "{!r} ran during the forward pass.".format(self.attention_layer_name))
        return rollout(self.attentions, self.discard_ratio, self.head_fusion)

In [None]:
model.eval()
# One test image on the model's device (the rollout helpers only use the
# first batch element); ``x`` itself is only defined by a later cell.
x_explain = next(iter(test_data_loader))['pixel_values'][:1].to(device)

# Use a distinct variable name: assigning to ``grad_rollout`` would shadow the
# ``grad_rollout`` function that VITAttentionGradRollout.__call__ relies on.
# NOTE(review): in Hugging Face ViT the per-head attention probabilities
# presumably pass through modules named '...attention.attention.dropout'
# (eager attention only) — confirm against model.named_modules().
grad_rollout_explainer = VITAttentionGradRollout(
    model, attention_layer_name='attention.attention.dropout', discard_ratio=0.9)
mask = grad_rollout_explainer(x_explain, category_index=0)

att_name = 'vit.embeddings.patch_embeddings'#vit.encoder.layer.11.output.dense'
attentions = []


def get_attention(module, input, output):
    attentions.append(output.cpu())


for name, module in model.named_modules():
    if att_name in name:
        module.register_forward_hook(get_attention)

attentions

## Trying from scratch: extracting attention weights

In [None]:
# from torch.autograd import Variable
# import torch

# """I cannot load the attention weights with this."""
# mod = ViTModule(pre_train_model, drop_out=0.1, num_classes=2)
# mod = mod.to(device)

# x=Variable(torch.FloatTensor(1, 3, 224,224))
# x = x.to(device)

# reuslts = mod(x)
# # reuslts['logits'].size()
# # reuslts['att_weights']

# from torch.autograd import Variable
# import torch

# """I cannot load the attention weights with this."""
# model = ViTModule(pre_train_model, drop_out=0.1, num_classes=2)
# model = model.to(device)

# x=Variable(torch.FloatTensor(16, 3, 224,224))
# x = x.to(device)

# logits, att_weights = model(x, output_attentions=True)
# att_weights


# """ From using ViTForImageCalssification
#     we can load the atention weights"""
# from transformers import ViTForImageClassification

# vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(
#     device
# )

# result = vit(x, output_attentions=True)
# attention_probs = torch.stack(result[1]).squeeze(1)
# attention_probs.size()
# result[0].size()

In [None]:
class ViTModule(nn.Module):
    """Pretrained ``ViTModel`` + dropout + linear head that can expose attentions.

    Constructed as ``ViTModule(pre_train_model, drop_out=0.1, num_classes=2)``.

    Args:
        pre_train_model: Hugging Face id/path of the ViT backbone.
        drop_out: dropout probability applied before the classifier.
        num_classes: number of output logits.
    """

    def __init__(self, pre_train_model, drop_out, num_classes):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pre_train_model, return_dict=True)
        self.dropout = nn.Dropout(drop_out)
        self.num_classes = num_classes
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_classes)

    def forward(self, pixel_values, output_attentions=False):
        """Return logits, or ``{'logits', 'att_weights'}`` if ``output_attentions``."""
        # Call the whole ViTModel: an nn.Sequential over its children never
        # passes ``output_attentions`` to the encoder, so attentions were None.
        outputs = self.vit(pixel_values=pixel_values,
                           output_attentions=output_attentions)
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)
        if output_attentions:
            return {'logits': logits,
                    'att_weights': outputs.attentions}

        return logits

In [None]:
reuslts['att_weights'].size()

In [None]:
model.save_pretrained("weights_hf_test.pt")

In [None]:
# save model

path = "weights_test.pt"
print("Model saved at {}".format(path))
torch.save(model, path)




In [None]:
model=load_checkpoint('./weights_test.pt', model)
model.save_pretrained("my path")

In [None]:
vit = ViTForImageClassification.from_pretrained('')

In [None]:
batch = next(iter(test_data_loader))
x = batch['pixel_values']
y = batch['labels']
print(x.shape, y.shape)

In [None]:
from transformers import ViTForImageClassification

In [None]:
trained_model = torch.load('./weights_test.pt')

In [None]:
module= ViTModule(pre_train_model, drop_out=0.1, num_classes=2)

model.load_state_dict(torch.load('./weights_test.pt'))

In [None]:
vit = ViTForImageClassification.from_pretrained("google/vit-large-patch16-224-in21k",
                                               state_dict=trained_model).to(device)
vit.eval()

In [None]:
trained_model.eval()
result = trained_model(x.to(device), output_attentions=True)

In [None]:
att = result.last_hidden_state

In [None]:
class_id = result[0].argmax()

In [None]:
class_id

In [None]:

attention_probs = torch.stack(result[1]).squeeze(1)

# Average the attention at each layer over all heads
attention_probs = torch.mean(attention_probs, dim=1)
residual = torch.eye(attention_probs.size(-1)).to(device)
attention_probs = 0.5 * attention_probs + 0.5 * residual

## Testing the model

In [None]:
# Test Model Phase
predictions, targets = test(test_data_loader, model)

In [None]:
from sklearn.metrics import accuracy_score

In [None]:
accuracy = accuracy_score(targets, predictions)

In [None]:
accuracy

In [None]:
from sklearn.metrics import classification_report

In [None]:
# confusion matrix
print('This is the  classification report:...')
print(classification_report(targets, predictions, digits=3, target_names=target_names), '\n')



In [None]:
from sklearn.metrics import precision_recall_curve

In [None]:
# wandb.sklearn.plot_confusion_matrix(targets, predictions, target_names)

# Precision-recall curve.
# A PR curve needs a continuous score per sample; hard 0/1 predictions give a
# single operating point. Collect P(class 1) on the test set. The test loader
# is not shuffled, so the order matches ``targets``.
model.eval()
pos_scores = []
with torch.no_grad():
    for batch in test_data_loader:
        logits = model(batch['pixel_values'].to(device))
        pos_scores.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())

precision, recall, thresholds = precision_recall_curve(targets, pos_scores)

fig, ax = plt.subplots()
ax.plot(recall, precision, color='purple')
ax.set_title('Precision-Recall Curve')
ax.set_ylabel('Precision')
ax.set_xlabel('Recall')
plt.show()

In [None]:
from sklearn.metrics import PrecisionRecallDisplay

In [None]:
# NOTE(review): ``from_predictions`` expects a continuous score (e.g. the
# predicted probability of class 1) as its second argument; with hard 0/1
# predictions the "curve" has a single operating point.
display = PrecisionRecallDisplay.from_predictions(targets, predictions, name="ViT")
_ = display.ax_.set_title("2-class Precision-Recall curve")

In [None]:
# ROC curve from positive-class probabilities (hard 0/1 predictions give a
# single operating point). The test loader is not shuffled, so the score
# order matches ``targets``.
model.eval()
roc_scores = []
with torch.no_grad():
    for batch in test_data_loader:
        logits = model(batch['pixel_values'].to(device))
        roc_scores.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())

fpr, tpr, _ = roc_curve(targets, roc_scores)
# Label the curves so ``plt.legend`` has entries to show.
plt.plot(fpr, tpr, lw=2, label="ViT")
plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="chance")

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()

In [None]:
!nvidia-smi
