<a href="https://colab.research.google.com/github/rslab-ntua/MSc_GBDA/blob/master/2023/Lab4_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rslab-ntua/MSc_GBDA/blob/master/2023/Lab4_vit.ipynb)

In [None]:
!wget http://madm.dfki.de/files/sentinel/EuroSATallBands.zip
!unzip EuroSATallBands.zip
!rm EuroSATallBands.zip

In [None]:
!pip install rasterio
!pip install lightning
!pip install patchify

%load_ext tensorboard

In [None]:
# Read data
from torch.utils.data import Dataset, DataLoader, random_split
from glob import glob
import os
import rasterio
from typing import Callable, List
import numpy as np
from patchify import patchify

# Directory handed to EuroSAT(data_root=...); it is searched recursively for *.tif files.
# NOTE(review): presumably created by unpacking EuroSATallBands.zip - confirm the layout.
DATA_ROOT = "ds/images/remote_sensing/otherDatasets/sentinel_2/tif/"

# Signature of a sample transform: takes a sample dict and returns a sample dict.
TransformFun = Callable[[dict], dict]

class EuroSAT(Dataset):
    """Map-style dataset over a directory tree of ``.tif`` files.

    Every ``*.tif`` found recursively under ``data_root`` becomes one sample.
    The name of a file's parent directory is its category; categories are
    numbered in alphabetical order.

    Args:
        data_root: Directory searched recursively for ``*.tif`` files.
        transforms: Callables applied, in order, to each sample dict when it
            is fetched. ``None`` (the default) means no transforms.

    Attributes:
        categories: Mapping ``category name -> integer id``.
        db: One dict per sample with the keys ``"url"``, ``"category_name"``
            and ``"category_id"``.
    """

    def __init__(self, data_root, transforms: "List[TransformFun]" = None):
        super().__init__()
        self._build_db(data_root)
        # Keep a private list: no shared mutable default, no aliasing of the
        # caller's list.
        self.transforms = list(transforms) if transforms is not None else []

    def _build_db(self, data_root) -> None:
        """Scan ``data_root`` and fill ``self.categories`` and ``self.db``."""
        sample_urls = sorted(glob(os.path.join(data_root, "**/*.tif"), recursive=True))

        # The category of a sample is the name of the directory that holds it.
        category_names = [os.path.basename(os.path.dirname(url)) for url in sample_urls]

        # Unique category names in alphabetical order -> consecutive ids.
        self.categories = {
            c_name: idx for idx, c_name in enumerate(sorted(set(category_names)))
        }

        self.db = [
            {
                "url": url,
                "category_name": c_name,
                "category_id": self.categories[c_name],
            }
            for url, c_name in zip(sample_urls, category_names)
        ]

    @property
    def num_categories(self):
        """Number of distinct categories found under ``data_root``."""
        return len(self.categories)

    def __getitem__(self, index):
        """Return the sample at ``index`` after running it through the transforms."""
        # Work on a shallow copy: transforms add or overwrite keys on the dict
        # they receive, and writing into the stored record would keep every
        # loaded array alive inside ``self.db``.
        sample = dict(self.db[index])

        for transform in self.transforms:
            sample = transform(sample)

        return sample

    def __len__(self):
        """Number of samples found under ``data_root``."""
        return len(self.db)
    
        

def load_data():
    """Build a transform that reads a sample's raster into the ``"data"`` key.

    Returns:
        A function taking a sample dict with a ``"url"`` key and returning a
        new dict with the same entries plus ``"data"``: the bands at 1-based
        indexes 4, 3, 2 (in that order) as a ``(3, height, width)`` array.
        The input dict is left untouched.
    """
    def apply(x: dict) -> dict:
        assert "url" in x
        with rasterio.open(x["url"]) as dataset:
            # rasterio band indexes are 1-based; requesting only these three
            # avoids decoding every band of the file.
            # NOTE(review): presumably red, green, blue for this product's
            # band order - confirm against the dataset description.
            bands = dataset.read([4, 3, 2])
        return {**x, "data": bands}
    return apply

def normalize(factor=10000):
    """Build a transform that rescales a sample's ``"data"`` array.

    Args:
        factor: Divisor applied to every value after the cast to float32.

    Returns:
        A function that overwrites ``x["data"]`` with
        ``x["data"].astype(np.float32) / factor`` and returns the same dict.
    """
    def apply(x: dict) -> dict:
        assert "data" in x
        as_float = x["data"].astype(np.float32)
        x["data"] = as_float / factor
        return x
    return apply

def patchify_transform(n_patches=4):
    """Build a transform that turns a square image into a sequence of flat patches.

    The ``(channels, height, width)`` array under ``"data"`` is cut into an
    ``n_patches x n_patches`` grid of non-overlapping square patches. Each
    patch keeps all channels and is flattened in ``(channel, row, column)``
    order; patches are listed row by row over the grid.

    Args:
        n_patches: Number of patches along each spatial side. Must divide the
            image height.

    Returns:
        A function that overwrites ``x["data"]`` with an array of shape
        ``(n_patches ** 2, channels * patch_size ** 2)`` and returns the same
        dict.
    """
    def apply(x: dict) -> dict:
        assert "data" in x
        channels, height, width = x["data"].shape
        assert height == width
        assert height % n_patches == 0
        patch_size = height // n_patches

        # (C, H, W) -> (C, grid_row, in_row, grid_col, in_col): split each
        # spatial axis into "which patch" and "offset inside the patch".
        grid = x["data"].reshape(channels, n_patches, patch_size, n_patches, patch_size)
        # -> (grid_row, grid_col, C, in_row, in_col): one (C, p, p) block per
        # patch. Reshaping the untransposed array instead would only chop the
        # flat buffer into chunks that are not spatial patches.
        grid = grid.transpose(1, 3, 0, 2, 4)
        x["data"] = grid.reshape(n_patches ** 2, channels * patch_size ** 2)
        return x
    return apply



In [None]:
# Dataset without the patch transform, so samples keep their image layout
# for the visual inspection in the next cell.
preview_transforms = [load_data(), normalize()]
dset = EuroSAT(data_root=DATA_ROOT, transforms=preview_transforms)
itdata = iter(dset)

In [None]:
import matplotlib.pyplot as plt

n_patches = 8

# Show the next sample as a whole image (channels moved last for imshow).
sample = next(itdata)
fig_im, ax_im = plt.subplots()
ax_im.imshow(sample["data"].transpose(1, 2, 0))
ax_im.set_axis_off()
chw = sample["data"].shape

# Cut the same image into an n_patches x n_patches grid and show every patch
# at its grid position.
patch_side = chw[-1] // n_patches
patches = patchify(sample["data"], (chw[0], patch_side, patch_side), step=patch_side)
cut_size = patches.shape[1:3]
# squeeze=False keeps `axes` two-dimensional even for a 1 x 1 grid.
fig, axes = plt.subplots(*cut_size, squeeze=False)
for i, j in np.ndindex(*cut_size):
    axes[i, j].imshow(patches[0, i, j].transpose(1, 2, 0))
    axes[i, j].set_axis_off()

In [None]:
# Split train/val/test set
import torch

# Fixed seed so that every run produces the same train/val/test partition.
SPLIT_SEED = 42

dset = EuroSAT(data_root=DATA_ROOT, transforms=[load_data(), normalize(), patchify_transform(8)])

train_dset, val_dset, test_dset = random_split(
    dset,
    lengths=[0.7, 0.2, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; evaluation order stays fixed.
train_dloader = DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=2)
val_dloader = DataLoader(val_dset, batch_size=128, shuffle=False, num_workers=2)
test_dloader = DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=2)

In [None]:
import math

from torch import nn
import torch
import torch.nn.functional as F
import lightning.pytorch as pl
from torchmetrics import Accuracy, ConfusionMatrix
from torchsummary import summary

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A ``(1, seq_len, emb_len)`` table is precomputed and stored as the buffer
    ``pe``: even embedding columns hold sines, odd columns hold cosines.

    Args:
        seq_len: Largest sequence length the table covers.
        emb_len: Embedding size (even or odd).
        dropout: Dropout probability applied after the table is added.
    """

    def __init__(self, seq_len: int, emb_len: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_len, 2) * (-math.log(10000.0) / emb_len))
        angles = position * div_term  # (seq_len, ceil(emb_len / 2))

        pe = torch.zeros(1, seq_len, emb_len)
        pe[0, :, 0::2] = torch.sin(angles)
        # For an odd emb_len there is one cosine column fewer than sine columns.
        pe[0, :, 1::2] = torch.cos(angles[:, : emb_len // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``; the
               sequence may be shorter than the ``seq_len`` given at
               construction.
        """
        # Add only the rows that correspond to positions present in x.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

class MultiheadSelfAttention(nn.Module):
    """Multi-head self-attention where every head works on its own feature slice.

    The last dimension of the input is split into ``n_heads`` contiguous
    slices of width ``dim // n_heads``. Each head has its own query/key/value
    linear maps acting on its slice, applies scaled dot-product attention over
    the sequence, and the head outputs are concatenated back to width ``dim``.

    Implementation based on: BrianPulfer/PapersReimplementations/vit

    Args:
        dim: Size of the token embedding; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
    """

    def __init__(self, dim, n_heads=2):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        assert dim % n_heads == 0, f"Can't divide dimension {dim} into {n_heads} heads"

        d_head = dim // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply self-attention to a batch of sequences.

        Args:
            sequences: Tensor of shape ``(N, seq_length, dim)``.

        Returns:
            Tensor of shape ``(N, seq_length, dim)``.
        """
        scale = self.d_head ** 0.5
        head_outputs = []
        for head in range(self.n_heads):
            # Feature slice owned by this head, for the whole batch at once.
            chunk = sequences[..., head * self.d_head: (head + 1) * self.d_head]
            q = self.q_mappings[head](chunk)
            k = self.k_mappings[head](chunk)
            v = self.v_mappings[head](chunk)

            # (N, seq, d_head) @ (N, d_head, seq) -> (N, seq, seq)
            attention = self.softmax(q @ k.transpose(-2, -1) / scale)
            head_outputs.append(attention @ v)

        # Concatenate the heads along the feature axis: back to (N, seq, dim).
        return torch.cat(head_outputs, dim=-1)

class ViTLayer(nn.Module):
    """One encoder block: layer-normed self-attention, then a layer-normed MLP.

    Both sub-blocks are wrapped in a residual connection.

    Args:
        hidden_d: Token embedding size.
        n_heads: Number of attention heads passed to the attention module.
        mlp_ratio: The hidden width of the MLP is ``mlp_ratio * hidden_d``.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        mlp_width = mlp_ratio * hidden_d

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MultiheadSelfAttention(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, hidden_d),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = x + self.mhsa(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

class ViT(pl.LightningModule):
    """Small Vision Transformer classifier wrapped as a LightningModule.

    Flattened patches are linearly embedded, a learnable class token is
    prepended, a positional encoding is added, the sequence passes through
    ``n_layers`` encoder layers, and the class token is mapped to ``out_d``
    logits.

    Args:
        input_d: Length of one flattened input patch.
        out_d: Number of output classes.
        n_patches: Patches per image side; the sequence has
            ``n_patches ** 2`` patch tokens plus the class token.
        hidden_d: Token embedding size.
        n_layers: Number of encoder layers.
        n_heads: Attention heads per layer.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_d``.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, input_d: int, out_d: int, n_patches: int = 8,
                 hidden_d: int = 8, n_layers: int = 2,
                 n_heads: int = 2, mlp_ratio: int = 4, lr=5e-3):
        super().__init__()

        self.lr = lr

        # 1) Linear mapper
        self.embedding = nn.Linear(input_d, hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, hidden_d))

        # 3) Positional encoding (one extra position for the class token)
        self.pos_embed = PositionalEncoding(n_patches ** 2 + 1, hidden_d)

        # 4) Encoder
        self.layers = nn.ModuleList([ViTLayer(hidden_d, n_heads, mlp_ratio) for _ in range(n_layers)])
        self.classifier = nn.Linear(hidden_d, out_d)

        self.train_accuracy = Accuracy(task="multiclass", num_classes=out_d)

        self.val_accuracy = Accuracy(task="multiclass", num_classes=out_d)
        self.val_confusion_matrix = ConfusionMatrix(task="multiclass", num_classes=out_d)

        self.save_hyperparameters()

    def forward(self, x: torch.Tensor):
        """Return class logits.

        Args:
            x: Tensor of shape ``(batch, n_patches ** 2, input_d)``.

        Returns:
            Tensor of shape ``(batch, out_d)``.
        """
        tokens = self.embedding(x)

        # Prepend the (shared) classification token to every sequence.
        cls = self.class_token.unsqueeze(0).expand(tokens.size(0), -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        # Adding positional embedding
        out = self.pos_embed(tokens)

        for layer in self.layers:
            out = layer(out)

        # Classify from the class token, which sits at position 0.
        return self.classifier(out[:, 0])

    def _logits_targets_loss(self, batch):
        """Run the model on a batch dict and return ``(logits, targets, loss)``."""
        y = batch["category_id"]
        logits = self(batch["data"])
        # cross_entropy == nll_loss(log_softmax(logits, dim=-1), y)
        return logits, y, F.cross_entropy(logits, y)

    def training_step(self, batch, batch_idx):
        """Log the epoch-level training loss and accuracy; return the loss."""
        logits, y, loss = self._logits_targets_loss(batch)

        self.log("loss/train", loss, prog_bar=True, on_epoch=True, on_step=False)

        self.train_accuracy(logits, y)
        self.log("accuracy/train", self.train_accuracy, prog_bar=True, on_epoch=True, on_step=False)

        return loss

    def on_validation_epoch_start(self):
        """Clear the confusion matrix so it only covers the epoch about to run."""
        # The matrix is not passed to self.log(), so nothing else resets it;
        # without this it would keep summing over every validation pass.
        self.val_confusion_matrix.reset()

    def validation_step(self, batch, batch_idx):
        """Log the epoch-level validation loss and accuracy; update the confusion matrix."""
        logits, y, loss = self._logits_targets_loss(batch)

        self.log("loss/val", loss, prog_bar=True, on_epoch=True, on_step=False)

        self.val_accuracy(logits, y)
        self.log("accuracy/val", self.val_accuracy, prog_bar=True, on_epoch=True, on_step=False)

        self.val_confusion_matrix.update(logits, y)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
# Draw the precomputed table as an image: rows are positions, columns are
# embedding dimensions.
pe = PositionalEncoding(emb_len=300, seq_len=100)
pe_table = pe.pe[0]
plt.imshow(pe_table)

In [None]:
# Smoke test: 3 random "images" of 64 tokens x 192 features should give one
# row of 10 logits each.
model = ViT(input_d=192, out_d=10, n_patches=8)
x = torch.randn(3, 64, 192)
y = model(x)
print(y.shape)

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

# Early stopping and checkpointing watch the same metric: validation
# accuracy, where higher is better.
monitor_kwargs = dict(monitor="accuracy/val", mode="max")
callbacks = [
    EarlyStopping(patience=5, **monitor_kwargs),
    ModelCheckpoint(save_last=True, **monitor_kwargs),
]

# input_d = 192 is the length of one flattened patch produced by the data pipeline.
model = ViT(
    input_d=192,
    out_d=dset.num_categories,
    n_patches=8,
    hidden_d=16,
    n_layers=2,
    n_heads=2,
)

trainer = pl.Trainer(
    default_root_dir="simple_vit",
    callbacks=callbacks,
    max_epochs=20,
    accelerator="cpu",
    devices=1,
)

trainer.fit(model, train_dataloaders=train_dloader, val_dataloaders=val_dloader)


In [None]:
# Open TensorBoard on the lightning_logs folder inside "simple_vit", the
# default_root_dir given to the Trainer.
%tensorboard --logdir simple_vit/lightning_logs

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay
from matplotlib import pyplot as plt

# Reload the checkpoint with the best validation accuracy and evaluate *that*
# model, not the weights left over from the last training epoch.
best_model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

# validate() runs the validation loop; here it is fed the held-out test split.
trainer.validate(best_model, dataloaders=test_dloader)

# The freshly loaded model has only seen the test split, so its confusion
# matrix contains test predictions only.
cm = best_model.val_confusion_matrix.compute().cpu().numpy()

# Category names in id order (the dict was built by enumerating sorted names).
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=list(dset.categories))
fig_cm, ax = plt.subplots(figsize=(20, 20), dpi=100)

disp.plot(ax=ax)
plt.xticks(rotation=90)
plt.show()