# Vision Transformer

## Imports

In [1]:
import os
from PIL import Image
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification, AdamW
import pandas as pd
from torch.utils.data import DataLoader
import torch
import pytorch_lightning as pl
import torch.nn as nn
from torchvision.transforms import (Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomVerticalFlip,
                                    GaussianBlur,
                                    RandomRotation, #  rotation
                                    Resize, 
                                    ToTensor)

NUM_WORKERS = 7

## Data Preparation

In [2]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch32-384")
labels = pd.read_csv("data/img_info.csv", index_col=0)
labels.sort_values("shuf_file_number", inplace=True)

labels


Unnamed: 0,orig_file,new_file,delta_x,delta_y,orig_crop_side,side,scaling,orig_file_number,shuf_file_number,labels_string,labels_int
8848,DEV13781.jpg,SHUF00000,176,-144,2272,2272,1.0,13781,0,NRG,0.0
1097,DEV12858.jpg,SHUF00001,-89,-904,3150,3150,1.0,12858,1,NRG,0.0
13185,DEV02899.jpg,SHUF00002,-59,-685,2486,2486,1.0,2899,2,NRG,0.0
6229,DEV12129.jpg,SHUF00003,168,-152,2256,2256,1.0,12129,3,NRG,0.0
12359,DEV04490.jpg,SHUF00004,-66,-886,3166,3166,1.0,4490,4,NRG,0.0
...,...,...,...,...,...,...,...,...,...,...,...
11147,DEV04506.jpg,SHUF14995,-44,-905,3175,3175,1.0,4506,14995,NRG,0.0
5655,DEV12149.jpg,SHUF14996,-45,-857,3207,3207,1.0,12149,14996,RG,1.0
8920,DEV00409.jpg,SHUF14997,-95,-870,2987,2987,1.0,409,14997,RG,1.0
4030,DEV01824.jpg,SHUF14998,197,-127,2338,2338,1.0,1824,14998,NRG,0.0


In [3]:
images = []
for filename in tqdm(os.listdir("data/shuffled_square_75")[:300]):
    file = f"data/shuffled_square_75/{filename}"
    img = Image.open(file)
    images.append(img)

images_train = images[:100]
images_val = images[100:200]
images_test = images[200:]

labels_list = labels.labels_int.to_list()
labels_train = labels_list[:100]
labels_val = labels_list[100:200]
labels_test = labels_list[200:]

100%|██████████| 300/300 [00:00<00:00, 474.40it/s]


In [4]:
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

_train_transforms = Compose(
        [
            Resize(feature_extractor.size),
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            RandomRotation(45),
            GaussianBlur(5, (0.1, 0.2)),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(images):
    return [_train_transforms(img) for img in tqdm(images)]

def val_transforms(images):
    return [_val_transforms(img) for img in tqdm(images)]

In [5]:
train_transformed = train_transforms(images_train)
val_transformed = val_transforms(images_val)
test_transformed = val_transforms(images_test)

100%|██████████| 100/100 [00:33<00:00,  3.00it/s]
100%|██████████| 100/100 [01:28<00:00,  1.12it/s]
100%|██████████| 100/100 [00:57<00:00,  1.73it/s]


In [6]:
train_ds = [{"pixel_values": train_transformed[i], "label": int(labels_train[i])} for i in range(100)]
val_ds = [{"pixel_values": val_transformed[i], "label": int(labels_val[i])} for i in range(100)]
test_ds = [{"pixel_values": test_transformed[i], "label": int(labels_test[i])} for i in range(100)]
test_ds[0]

{'pixel_values': tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],
 
         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],
 
         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]]),
 'label': 0}

In [7]:
def collate_fn(images):
    pixel_values = torch.stack([img["pixel_values"] for img in images])
    labels = torch.LongTensor([img["label"] for img in images])
    return {"pixel_values": pixel_values, "labels": labels}


train_batch_size = 2
eval_batch_size = 2

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size, num_workers = NUM_WORKERS)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

In [9]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

RuntimeError: DataLoader worker (pid(s) 14596, 17012, 13608, 9468, 10076, 11592, 19944) exited unexpectedly

## Model Definition

In [21]:
id2label = {0: "NRG", 1:"RG"}
label2id = {"NRG": 0, "RG": 1}

class ViTLightningModule(pl.LightningModule):
    def __init__(self, num_labels=10):
        super(ViTLightningModule, self).__init__()
        self.vit = ViTForImageClassification.from_pretrained('google/vit-base-patch32-384',
                                                              num_labels=2,
                                                              id2label=id2label,
                                                              label2id=label2id,
                                                              ignore_mismatched_sizes=True)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits
        
    def common_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

    def test_dataloader(self):
        return test_dataloader

## Training

In [22]:
# Start tensorboard.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [23]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=True,
    mode='min'
)

model = ViTLightningModule()
trainer = Trainer(gpus=1, callbacks=[EarlyStopping(monitor='validation_loss')], max_epochs=30)
trainer.fit(model)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch32-384 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: c:\Users\Valen\Documents\Master AI VU\Medical Imaging\AIROGSLite-AI4MI-VU-2022\lightning_logs

  | Name | Type                      | Params
---------------------------------------------------
0 | vit  | ViTForImageClassification | 87.5 M
------------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## Evaluation

In [None]:
trainer.test()

In [None]:
%tensorboard --logdir lightning_logs/