## Vision Transformer

In [None]:
import torch
import torchvision

# Report which torchvision release this notebook is running against.
print(f"{torchvision.__version__}")

In [None]:
# Use the GPU when PyTorch reports one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Root folders of the train / test splits (one sub-folder per class, as read by ImageFolder).
train_dir, test_dir = './data/train', './data/test'

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

# Create the Data Loader

In [None]:
BATCH_SIZE = 16


def create_dataloader(train_dir=train_dir, test_dir=test_dir, transform=None, batch_size=BATCH_SIZE):
    """Build train and test DataLoaders from two ImageFolder directory trees.

    Args:
        train_dir: Root folder of the training split.
        test_dir: Root folder of the test split.
        transform: Transform applied to every image of both splits.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_dataloader, test_dataloader)``. The training loader shuffles,
        the test loader keeps file order. Both use ``drop_last=True``, so a
        trailing partial batch is skipped.

    NOTE(review): ``drop_last=True`` on the test loader discards up to
    ``batch_size - 1`` evaluation images. Confirm this is intended.
    """
    def _loader(root, shuffle):
        # One ImageFolder dataset wrapped in its DataLoader; only shuffling differs per split.
        dataset = ImageFolder(root=root, transform=transform, target_transform=None)
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)

    return _loader(train_dir, True), _loader(test_dir, False)


In [None]:
IMG_SIZE = (224, 224)

# Preprocessing shared by both splits: resize every image to IMG_SIZE,
# then convert it to a tensor.
man_transform = transforms.Compose(
    [
        transforms.Resize(size=IMG_SIZE),
        transforms.ToTensor(),
    ]
)

In [None]:
train_dataloader, test_dataloader = create_dataloader(transform=man_transform)

# len() of a DataLoader is its number of batches; print train first, then test, one per line.
print(*map(len, (train_dataloader, test_dataloader)), sep='\n')

In [None]:
from torch import nn
PATCH_SIZE = (16, 16)


class patched_embeddings(nn.Module):
    """Turn a batch of images into the token sequence fed to a Vision Transformer.

    Steps:
        1. A strided ``nn.Conv2d`` cuts the image into patches and projects
           each patch to ``embedding_size`` features.
        2. The spatial grid is flattened into a sequence.
        3. One learnable class token is prepended.
        4. A learnable positional embedding, with one vector per token
           position, is added.

    Args:
        embedding_size: Width of every output token.
        stride_length: Conv2d stride, an int or an ``(h, w)`` pair.
        kernel_size: Conv2d kernel (patch) size, an int or an ``(h, w)`` pair.
        batch_size: Ignored. It is accepted only so existing callers that pass
            it keep working. The class token is shared and broadcast, so any
            batch size is accepted at call time.
        img_size: ``(height, width)`` of the input images, or an int for
            square images. It fixes the number of patch tokens and therefore
            the length of the positional embedding.
        in_channels: Number of channels in the input images.

    Shapes:
        input:  ``(batch, in_channels, img_size[0], img_size[1])``
        output: ``(batch, num_patches + 1, embedding_size)``

    Raises:
        ValueError: raised by ``forward`` when the input produces a different
            number of patches than ``img_size`` implies.
    """

    def __init__(self, embedding_size=768, stride_length=PATCH_SIZE, kernel_size=PATCH_SIZE,
                 batch_size=None, img_size=(224, 224), in_channels=3):
        super().__init__()

        self.patch_layer = nn.Conv2d(in_channels=in_channels, out_channels=embedding_size,
                                     kernel_size=kernel_size, stride=stride_length)
        # (B, E, H', W') -> (B, E, H' * W')
        self.flatten_layer = nn.Flatten(start_dim=2, end_dim=3)

        img_h, img_w = self._to_pair(img_size)
        k_h, k_w = self._to_pair(kernel_size)
        s_h, s_w = self._to_pair(stride_length)
        # Conv2d output grid size for padding=0, dilation=1.
        self.num_patches = ((img_h - k_h) // s_h + 1) * ((img_w - k_w) // s_w + 1)

        # A single class token shared by every sample. It is expanded over the
        # batch in forward(), so the module works for any batch size.
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_size), requires_grad=True)
        # One positional vector per token position (class token + every patch).
        # A length-1 sequence axis here would broadcast a single vector to every
        # position and give the encoder no information about where a patch came from.
        self.positional_encoding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embedding_size), requires_grad=True
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, E, H', W') -> (B, E, N) -> (B, N, E)
        x = self.flatten_layer(self.patch_layer(x)).permute(0, 2, 1)
        if x.shape[1] != self.num_patches:
            raise ValueError(
                f"input yields {x.shape[1]} patches but the positional embedding was "
                f"built for {self.num_patches}; construct with a matching img_size"
            )
        # Broadcast the shared class token to the whole batch (a view, no copy).
        cls_tokens = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        return x + self.positional_encoding

    @staticmethod
    def _to_pair(value):
        """Return ``value`` as an ``(a, b)`` tuple; an int ``v`` becomes ``(v, v)``."""
        if isinstance(value, int):
            return (value, value)
        first, second = value
        return (first, second)



In [None]:
class MSABlock(nn.Module):
    """Pre-norm multi-head self-attention sub-layer of a Transformer encoder.

    The input is layer-normalised and then attends to itself
    (query = key = value). Attention weights are not returned, and no
    residual connection is added inside this module.

    Args:
        embedding_dim: Feature width of each token.
        num_heads: Number of attention heads.
        attn_dropout: Dropout probability inside the attention.

    ``batch_first=True`` means tokens are laid out as ``(batch, seq, embedding_dim)``.
    Both sub-modules are created on the notebook-level ``device``.
    """

    def __init__(self, embedding_dim = 768, num_heads=12, attn_dropout=0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embedding_dim, device=device)
        self.multiheaded_attn = nn.MultiheadAttention(
            embedding_dim,
            num_heads,
            dropout=attn_dropout,
            batch_first=True,
            device=device,
        )

    def forward(self, x):
        normed = self.layer_norm(x)
        # nn.MultiheadAttention returns (output, weights); only the output is used.
        attended = self.multiheaded_attn(normed, normed, normed, need_weights=False)[0]
        return attended


In [None]:
class MLPBlock(nn.Module):
    """Pre-norm feed-forward sub-layer of a Transformer encoder.

    Computes LayerNorm -> Linear(embedding_dim -> mlp_size) -> GELU ->
    Dropout -> Linear(mlp_size -> embedding_dim). No residual connection is
    added inside this module.

    Args:
        embedding_dim: Feature width of each token (input and output).
        mlp_dropout: Dropout probability after the GELU.
        mlp_size: Width of the hidden expansion layer.

    All layers are created on the notebook-level ``device``.
    """

    def __init__(self, embedding_dim=768, mlp_dropout=0.2, mlp_size=3072):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim, device=device)
        expand = nn.Linear(embedding_dim, mlp_size, device=device)
        project = nn.Linear(mlp_size, embedding_dim, device=device)
        self.MLP = nn.Sequential(expand, nn.GELU(), nn.Dropout(p=mlp_dropout), project)

    def forward(self, x):
        return self.MLP(self.layer_norm(x))

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder layer: self-attention then MLP, each with a residual.

    Args:
        embedding_dim: Feature width of each token.
        num_heads: Number of attention heads in the MSA sub-layer.
        attn_dropout: Dropout probability inside attention.
        mlp_dropout: Dropout probability inside the MLP sub-layer.
        mlp_size: Hidden width of the MLP sub-layer. It was previously fixed at
            MLPBlock's default of 3072; that value is kept as the default here.

    The output has the same shape as the input.
    """

    def __init__(self, embedding_dim=768, num_heads=12, attn_dropout = 0, mlp_dropout=0.2, mlp_size=3072):
        super().__init__()

        self.msa_block = MSABlock(embedding_dim=embedding_dim, num_heads=num_heads, attn_dropout=attn_dropout)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim, mlp_dropout=mlp_dropout, mlp_size=mlp_size)

    def forward(self, x):
        # Residual connections around each sub-layer; neither sub-block adds its own.
        x = x + self.msa_block(x)
        x = x + self.mlp_block(x)
        return x

In [None]:
class SimpleViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder stack -> MLP head.

    Args:
        embedding_size: Token width used throughout the model.
        stride_length: Patch-embedding Conv2d stride.
        kernel_size: Patch-embedding Conv2d kernel (patch) size.
        batch_size: Forwarded to ``patched_embeddings``.
        num_heads: Attention heads per encoder block.
        attn_dropout: Dropout inside attention.
        mlp_dropout: Dropout inside encoder MLPs and the classification head.
        num_encoders: Number of stacked ``TransformerEncoderBlock``s.
        out_classes: Number of output classes.
        hidden_size: Head widths are ``2 * hidden_size`` then ``hidden_size``.
        apply_softmax: When True (the default, matching the original model)
            the head ends in ``nn.Softmax`` and ``forward`` returns class
            probabilities. Pass False to get raw logits. Logits are what
            ``nn.CrossEntropyLoss`` expects: it applies log-softmax itself, so
            feeding it probabilities double-softmaxes and weakens gradients.

    ``forward`` returns a ``(batch, out_classes)`` tensor computed from the
    token at sequence index 0 (the class token).
    """

    def __init__(self, embedding_size=768, stride_length=PATCH_SIZE, kernel_size=PATCH_SIZE, batch_size = BATCH_SIZE, num_heads=12, attn_dropout=0, mlp_dropout=0.2, num_encoders=12, out_classes = 3, hidden_size = 1024, apply_softmax=True):
        super().__init__()

        self.patcher = patched_embeddings(embedding_size=embedding_size, stride_length=stride_length, kernel_size=kernel_size, batch_size=batch_size)
        self.encoder_blocks = nn.Sequential(*[
            TransformerEncoderBlock(embedding_dim=embedding_size, num_heads=num_heads, attn_dropout=attn_dropout, mlp_dropout=mlp_dropout)
            for _ in range(num_encoders)
        ])

        head_layers = [
            nn.Linear(in_features=embedding_size, out_features=hidden_size * 2),
            nn.GELU(),
            nn.Dropout(mlp_dropout),
            nn.Linear(in_features=hidden_size * 2, out_features=hidden_size),
            nn.GELU(),
            nn.Dropout(mlp_dropout),
            nn.Linear(in_features=hidden_size, out_features=out_classes),
        ]
        # Softmax has no parameters, so toggling it leaves state_dict keys unchanged.
        if apply_softmax:
            head_layers.append(nn.Softmax(dim=-1))
        self.mlp_head = nn.Sequential(*head_layers)

    def forward(self, x):
        tokens = self.encoder_blocks(self.patcher(x))
        # patched_embeddings prepends the class token at sequence index 0; classify from it.
        return self.mlp_head(tokens[:, 0])

In [None]:
from torchinfo import summary

model_0 = SimpleViT()
model_0 = model_0.to(device)

# Layer-by-layer shape / parameter report for one batch of 3x224x224 images
# (left as the cell's last expression so the notebook displays it).
summary(model_0, input_size=(BATCH_SIZE, 3, 224, 224))

In [None]:
from tqdm import tqdm

In [None]:
# Pull one (images, labels) batch from the training loader and display its labels.
(X, y) = next(iter(train_dataloader))
y

In [None]:
from torchmetrics import Accuracy

# Multiclass accuracy over the 3 classes, moved to the same `device` the model was moved to.
accuracy_fn = Accuracy(num_classes=3, task='multiclass')
accuracy_fn = accuracy_fn.to(device)