# Visual Transformer with Linformer

Training Visual Transformer on *Mayo Clinit*

* Effecient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention

In [None]:
!pip -q install vit_pytorch linformer
!pip install nystrom-attention

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


## Import Libraries

In [None]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT


In [None]:
print(f"Torch: {torch.__version__}")

Torch: 1.12.1+cu113


In [None]:
# Training settings
batch_size = 3
epochs = 10
lr = 3e-5
gamma = 0.7
seed = 42

In [None]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [None]:
device = 'cuda'
# !gdown 1L3bnEKVwq3g59GR6KW5EL1JratM3WRam
# !unzip data.zip

In [None]:
%pwd

'/content'

In [None]:
! pwd

/content


## Load Data

In [None]:
train_dir = '/content/data/train'
test_dir = '/content/data/test'

train_dirArg = '/content/gdrive/MyDrive/Tham/mayo_strip/augment_4096/train'
test_dir = '/content/data/test'

In [None]:
train_list = glob.glob(os.path.join(train_dir,'*.png'))
test_list = glob.glob(os.path.join(test_dir, '*.png'))
train_listArg = glob.glob(os.path.join(train_dirArg,'*.png'))
train_list = train_list + train_listArg 
len(train_list)


2296

In [None]:
print(f"Train Data: {len(train_list)}")

print(f"Test Data: {len(test_list)}")

Train Data: 2296
Test Data: 4


In [None]:
import pandas as pd 
train = pd.read_csv('/content/data/train.csv')
train

Unnamed: 0,image_id,center_id,patient_id,image_num,label
0,006388_0,11,006388,0,CE
1,008e5c_0,11,008e5c,0,CE
2,00c058_0,11,00c058,0,LAA
3,01adc5_0,11,01adc5,0,LAA
4,026c97_0,4,026c97,0,CE
...,...,...,...,...,...
749,fe9645_0,3,fe9645,0,CE
750,fe9bec_0,4,fe9bec,0,LAA
751,ff14e0_0,6,ff14e0,0,CE
752,ffec5c_0,7,ffec5c,0,LAA


In [None]:
labels = [path.split('/')[-1].split('.')[0] for path in train_list] 
Y = []
for label in labels: 
    label = label[:8]
    y = train.loc[train['image_id'] == label ]['label'].values[0]
    Y.append(1) if y == "CE" else Y.append(0)

labels = np.array(Y) 
labels.shape
    

(2296,)

## Random Plots

In [None]:
# random_idx = np.random.randint(1, len(train_list), size=9)
# fig, axes = plt.subplots(3, 3, figsize=(16, 12))

# for idx, ax in enumerate(axes.ravel()):
#     img = Image.open(train_list[idx])
#     ax.set_title(labels[idx])
#     ax.imshow(img)


## Split

In [None]:
train_list, valid_list = train_test_split(train_list, 
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=seed)

In [None]:
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

Train Data: 1836
Validation Data: 460
Test Data: 4


## Image Augumentation

In [None]:
resize = 2048
train_transforms = transforms.Compose(
    [
        transforms.Resize((resize, resize)),
        transforms.RandomResizedCrop(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
    ]
)


## Load Datasets

In [None]:
class MayoDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        label = img_path.split("/")[-1].split(".")[0]
        label = train.loc[train['image_id'] == label[:8]]['label'].values[0]
        label = 1 if label == "CE" else 0

        return img_transformed, label


In [None]:
train_data = MayoDataset(train_list, transform=train_transforms)
valid_data = MayoDataset(valid_list, transform=test_transforms)
test_data = MayoDataset(test_list, transform=test_transforms)

In [None]:
train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

In [None]:
print(len(train_data), len(train_loader))

1836 612


In [None]:
print(len(valid_data), len(valid_loader))

460 154


In [None]:
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256
)

model = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 2,
    transformer = efficient_transformer
).to('cuda')
model.eval()

# img = torch.randn(1, 3, 2048, 2048).to("cuda") # your high resolution picture
# model(img) # (1, 2)
# model.eval()

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): Linear(in_features=3072, out_features=512, bias=True)
  )
  (transformer): Nystromformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): NystromAttention(
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
            (res_conv): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=512, out_features=2048, bias=True)
 

### Training

In [None]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [None]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        #print(data.shape)
        #print(label.shape)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)


    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
    writer.add_scalar("Loss/train", epoch_loss, epoch+1)
    writer.add_scalar("Acc/train", epoch_accuracy, epoch+1)
    writer.add_scalar("Loss/Validation", epoch_val_loss, epoch+1)
    writer.add_scalar("Acc/Validation", epoch_val_accuracy, epoch+1)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )


  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 1 - loss : 0.5977 - acc: 0.7386 - val_loss : 0.6273 - val_acc: 0.7468



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 2 - loss : 0.5794 - acc: 0.7489 - val_loss : 0.5633 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 3 - loss : 0.5774 - acc: 0.7489 - val_loss : 0.5610 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 4 - loss : 0.5737 - acc: 0.7489 - val_loss : 0.5693 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 5 - loss : 0.5753 - acc: 0.7489 - val_loss : 0.5652 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 6 - loss : 0.5725 - acc: 0.7489 - val_loss : 0.5923 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 7 - loss : 0.5735 - acc: 0.7489 - val_loss : 0.5613 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 8 - loss : 0.5711 - acc: 0.7489 - val_loss : 0.5618 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 9 - loss : 0.5682 - acc: 0.7489 - val_loss : 0.5664 - val_acc: 0.7511



  0%|          | 0/612 [00:00<?, ?it/s]

Epoch : 10 - loss : 0.5692 - acc: 0.7489 - val_loss : 0.5598 - val_acc: 0.7511



In [None]:
writer.flush()
torch.save(model.state_dict(), "/content/gdrive/MyDrive/Ben/CVProject/VIT_2048_Agument.pt")
writer.close()

In [None]:
!tensorboard --logdir=runs 
