<a href="https://colab.research.google.com/github/unofficial-Jona/thesis/blob/main/Transformer_tryout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer tryout

## imports

In [None]:
%%capture
! pip install einops
! pip install positional-encodings[pytorch]

In [None]:
import numpy as np
import math
import pdb
from tqdm import tqdm


import torch 
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader
from torch import Tensor
from torchvision.transforms import Compose, Resize, ToTensor
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from torchvision.datasets.mnist import MNIST

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from positional_encodings.torch_encodings import PositionalEncodingPermute1D, PositionalEncoding1D

# dataset
from torchvision.datasets import CIFAR10




## Utility functions
- patching (`patchify`)
- positional encoding (`get_positional_embeddings`)

In [None]:
def patchify(images, n_patches):
    """Split a batch of square images into a flattened grid of square patches.

    Patches are ordered row-major over the ``n_patches x n_patches`` grid
    (patch index ``i * n_patches + j`` is grid row ``i``, column ``j``). Each
    patch is flattened channel-first, i.e. ``patch[:, rows, cols].flatten()``.

    Args:
        images: tensor of shape ``(N, C, H, W)`` with ``H == W``.
        n_patches: number of patches along each spatial side; must divide ``H``.

    Returns:
        Tensor of shape ``(N, n_patches**2, C * (H // n_patches)**2)`` in the
        default torch dtype, on the same device as ``images``.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"
    assert h % n_patches == 0, "Image side must be divisible by n_patches"

    patch_size = h // n_patches

    # (N, C, gi, pi, gj, pj) -> (N, gi, gj, C, pi, pj): grid axes first, then each
    # patch's own (C, rows, cols) block so flattening matches patch.flatten().
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    patches = grid.permute(0, 2, 4, 1, 3, 5).reshape(n, n_patches ** 2, c * patch_size ** 2)
    # The original filled a torch.zeros buffer, i.e. the default float dtype.
    return patches.to(torch.get_default_dtype())

def get_positional_embeddings(sequence_length, d):
    """Return a fixed sinusoidal positional-embedding table.

    Even column ``c`` holds ``sin(pos / 10000**(c / d))``; odd column ``c``
    holds ``cos(pos / 10000**((c - 1) / d))``, reusing the frequency of the
    preceding even column.

    Args:
        sequence_length: number of positions (rows).
        d: embedding width (columns).

    Returns:
        Tensor of shape ``(sequence_length, d)`` in the default torch dtype.
    """
    table = torch.ones(sequence_length, d)

    def _entry(pos, col):
        if col % 2 == 0:
            return np.sin(pos / (10000 ** (col / d)))
        return np.cos(pos / (10000 ** ((col - 1) / d)))

    # Each row is a view into ``table``; copy_ writes the computed values in place.
    for pos, row in enumerate(table):
        row.copy_(torch.tensor([float(_entry(pos, col)) for col in range(d)]))
    return table


## Self-coded multi-head self-attention (MSA) module

In [None]:
class MyMSA(nn.Module):
    """Multi-head self-attention with independent Q/K/V projections per head.

    The last (feature) dimension of width ``d`` is split into ``n_heads``
    contiguous slices of width ``d // n_heads``. Each slice gets its own
    ``nn.Linear`` query/key/value projection and scaled dot-product attention.
    The per-head results are concatenated back to width ``d``.

    Args:
        d: token feature width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d, n_heads=2):
        super().__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply multi-head self-attention.

        Args:
            sequences: tensor of shape ``(N, seq_length, d)``. An unbatched
                ``(seq_length, d)`` tensor also works, because only the last
                two dimensions are used by the attention product.

        Returns:
            Tensor with the same shape as ``sequences``, formed by
            concatenating the per-head outputs along the last dimension.
        """
        scale = self.d_head ** 0.5
        head_outputs = []
        for head, (q_mapping, k_mapping, v_mapping) in enumerate(
            zip(self.q_mappings, self.k_mappings, self.v_mappings)
        ):
            # Slice this head's features for the whole batch at once, instead of
            # looping over samples in Python.
            seq = sequences[..., head * self.d_head: (head + 1) * self.d_head]
            q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

            attention = self.softmax(q @ k.transpose(-2, -1) / scale)
            head_outputs.append(attention @ v)
        # Concatenating over heads works for an empty batch (N == 0), where the
        # old per-sample torch.cat([]) raised.
        return torch.cat(head_outputs, dim=-1)


In [None]:
class MyViTBlock(nn.Module):
    """Pre-norm transformer encoder block built on ``MyMSA``.

    ``x -> x + MSA(LN(x)) -> y + MLP(LN(y))``, where the MLP is
    ``Linear(hidden_d, mlp_ratio*hidden_d) -> GELU -> Linear(back to hidden_d)``.

    Args:
        hidden_d: token feature width.
        n_heads: number of attention heads passed to ``MyMSA``.
        mlp_ratio: expansion factor of the MLP hidden layer.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        expanded_d = mlp_ratio * hidden_d
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, expanded_d),
            nn.GELU(),
            nn.Linear(expanded_d, hidden_d),
        )

    def forward(self, x):
        attended = x + self.mhsa(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


## model composition

In [None]:
class MLP(nn.Module):
    """Two-layer MLP with a configurable non-linearity (GELU by default).

    Computes ``softmax(dropout(fc2(act(fc1(x)))))`` over the last dimension.

    Args:
        in_features: width of the input features.
        hidden_features: width of the hidden layer.
        out_features: width of the output.
        activation: ``nn.Module`` class (not an instance). It is instantiated
            with no arguments to build the non-linearity.
        drop: dropout probability applied after ``fc2``.
    """

    def __init__(self, in_features, hidden_features, out_features, activation=nn.GELU, drop=0):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.activation = activation
        # Keep the probability under its own name; ``self.drop`` is the Dropout module.
        self.drop_rate = drop

        # layers
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = self.activation()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        # Explicit dim: an implicit-dim nn.Softmax() is deprecated and picks
        # dim=0 for 3-D inputs, which would normalize across the batch.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return self.softmax(x)

class MyViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Steps:
    1. Split images into an ``n_patches x n_patches`` grid of patches.
    2. Linearly embed each patch to ``hidden_d``.
    3. Prepend a learnable class token.
    4. Add fixed sinusoidal positional embeddings.
    5. Run a transformer encoder over the token sequence.
    6. Classify from the class-token output.

    Args:
        chw: input shape ``(C, H, W)``; ``H`` and ``W`` must be divisible by
            ``n_patches``.
        n_patches: number of patches along each spatial side.
        n_blocks: number of encoder layers.
        hidden_d: token embedding width.
        n_heads: number of attention heads (must divide ``hidden_d``).
        out_d: number of output classes.
        use_torch_encoder: if True (default), use ``nn.TransformerEncoder``;
            otherwise use the hand-written ``MyViTBlock`` stack.

    ``forward`` returns unnormalized logits of shape ``(N, out_d)``. These are
    suitable for ``nn.CrossEntropyLoss``; apply ``softmax(dim=-1)`` for
    probabilities.
    """

    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10,
                 use_torch_encoder=True):
        super().__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d
        self.use_torch_encoder = use_torch_encoder

        # Input and patch sizes (integers, so no float arithmetic downstream)
        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper: flattened patch (C * ph * pw) -> hidden_d
        self.input_d = chw[0] * self.patch_size[0] * self.patch_size[1]
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Fixed (non-trainable) positional embedding for class token + patches
        self.pos_embed = nn.Parameter(get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).clone())
        self.pos_embed.requires_grad = False

        # 4a) Hand-written transformer encoder blocks
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 4b) PyTorch transformer encoder. Sizes come from the constructor args
        # (previously hard-coded to 8 / 4 / 4). batch_first=True is required
        # because tokens are (N, seq, d); the default (seq, N, d) layout made
        # attention mix different images in the batch.
        self.encoder_layer = TransformerEncoderLayer(d_model=hidden_d, nhead=n_heads, batch_first=True)
        self.encoder_blocks = TransformerEncoder(self.encoder_layer, num_layers=n_blocks)

        # 5) Classification head. A single Linear: the old second
        # Linear(hidden_d, out_d) received out_d features and failed whenever
        # hidden_d != out_d. No Softmax: CrossEntropyLoss expects raw logits.
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
        )

    def forward(self, images):
        """Classify a batch of images of shape ``(N, C, H, W)``; returns ``(N, out_d)`` logits."""
        n = images.shape[0]
        patches = patchify(images, self.n_patches).to(self.pos_embed.device)

        # Map each flattened patch to the hidden size: (N, n_patches**2, hidden_d)
        tokens = self.linear_mapper(patches)

        # Prepend the class token to every sequence: (N, 1 + n_patches**2, hidden_d)
        class_tokens = self.class_token.expand(n, 1, -1)
        tokens = torch.cat((class_tokens, tokens), dim=1)

        # (seq, d) positional table broadcasts over the batch dimension
        out = tokens + self.pos_embed

        if self.use_torch_encoder:
            out = self.encoder_blocks(out)
        else:
            for block in self.blocks:
                out = block(out)

        # Classify from the class-token position only
        return self.mlp(out[:, 0])
    

## main function
- data download
- training loop
- testing loop

In [None]:
def main():
    """Train MyViT on MNIST, then report test loss and accuracy.

    Downloads MNIST into ``./../datasets`` (``download=True``). Trains for
    ``N_EPOCHS`` epochs with Adam and cross-entropy loss, then evaluates on
    the test split with gradients disabled.
    """
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    model.train()  # enable training-mode layers (e.g. dropout in the encoder)
    for epoch in tqdm(range(N_EPOCHS), desc="Training"):
        train_loss = 0.0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # Running mean of the per-batch loss over the epoch
            train_loss += loss.item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop: eval() disables dropout so evaluation is deterministic
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for x, y in tqdm(test_loader, desc="Testing"):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.item() / len(test_loader)

            correct += (y_hat.argmax(dim=1) == y).sum().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [None]:
# Run the end-to-end pipeline: download MNIST, train the model, then evaluate it.
main()

Using device:  cpu 


Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:00<06:04,  1.28it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:01<06:06,  1.27it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:02<06:06,  1.27it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:03<05:57,  1.30it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:03<05:31,  1.40it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:04<05:23,  1.43it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:05<05:29,  1.40it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:05<05:33,  1.38it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:06<05:38,  1.36it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:07<05:33,  1.38it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:07<05:16,  1.45it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:08<05:20,  1.43it/s][A
Epoch 1 in training: 

Epoch 1/5 loss: 2.30



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<06:01,  1.30it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:01<05:23,  1.44it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:01<05:00,  1.55it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:02<04:49,  1.61it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:03<04:42,  1.64it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:03<04:38,  1.66it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:04<04:47,  1.61it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:05<05:03,  1.52it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:05<04:55,  1.56it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:06<04:46,  1.60it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:06<04:39,  1.64it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:07<04:33,  1.67it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:08<04:32,  1.68it/s

Epoch 2/5 loss: 2.30



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<04:53,  1.59it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:01<05:04,  1.54it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:02<05:24,  1.43it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:02<05:33,  1.39it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:03<05:41,  1.36it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:04<05:42,  1.35it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:05<05:43,  1.35it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:05<05:44,  1.34it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:06<05:45,  1.33it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:07<05:43,  1.33it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:08<05:48,  1.32it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:08<05:45,  1.32it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:09<05:44,  1.32it/s

Epoch 3/5 loss: 2.30



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<04:51,  1.60it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:01<04:44,  1.64it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:01<04:44,  1.64it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:02<04:41,  1.65it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:03<04:41,  1.65it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:03<04:37,  1.67it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:04<04:37,  1.66it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:04<04:46,  1.61it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:05<05:04,  1.51it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:06<04:59,  1.53it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:07<05:09,  1.48it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:07<05:16,  1.44it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:08<05:05,  1.49it/s

Epoch 4/5 loss: 2.30



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<04:49,  1.62it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:01<04:42,  1.65it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:01<04:40,  1.66it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:02<04:39,  1.66it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:03<04:37,  1.67it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:03<04:37,  1.67it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:04<04:36,  1.67it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:04<04:36,  1.67it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:05<04:35,  1.67it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:05<04:33,  1.68it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:06<04:34,  1.67it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:07<04:33,  1.67it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:07<04:32,  1.67it/s

Epoch 5/5 loss: 2.30


Testing: 100%|██████████| 79/79 [00:30<00:00,  2.60it/s]

Test loss: 2.30
Test accuracy: 11.35%



