## Vision Transformer

In [None]:
import torch
import torchvision

# Log the installed torchvision version so results can be tied to a release.
print("{}".format(torchvision.__version__))

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Root folders for the training and evaluation image sets.
train_dir, test_dir = './data/train', './data/test'

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

# Create the Data Loader

In [None]:
BATCH_SIZE = 16

def create_dataloader(train_dir=train_dir, test_dir=test_dir, transform=None, batch_size=BATCH_SIZE):
    """Build train/test DataLoaders from two ImageFolder-style directories.

    Args:
        train_dir: Root of the training images (one sub-folder per class).
        test_dir: Root of the test images (one sub-folder per class).
        transform: Transform applied to every image of both splits.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        shuffles; the test loader keeps file order. Both drop the final
        incomplete batch.
    """
    splits = {
        'train': ImageFolder(root=train_dir, transform=transform, target_transform=None),
        'test': ImageFolder(root=test_dir, transform=transform, target_transform=None),
    }

    # NOTE(review): drop_last=True on the test split discards up to
    # batch_size - 1 evaluation images; presumably kept so every batch has
    # exactly batch_size samples for models that assume a fixed batch size.
    loaders = [
        DataLoader(dataset=splits[name], batch_size=batch_size, shuffle=(name == 'train'), drop_last=True)
        for name in ('train', 'test')
    ]
    return loaders[0], loaders[1]


In [None]:
IMG_SIZE = (224, 224)

# Resize each image to exactly IMG_SIZE (height, width), then convert it to a tensor.
man_transform = transforms.Compose(
    [
        transforms.Resize(size=IMG_SIZE),
        transforms.ToTensor(),
    ]
)

In [None]:
train_dataloader, test_dataloader = create_dataloader(transform=man_transform)

# Number of batches each loader yields per epoch, one per line.
print(len(train_dataloader), len(test_dataloader), sep="\n")

In [None]:
from torch import nn
PATCH_SIZE = (16,16)


def _as_pair(value):
    """Return ``value`` as a 2-tuple, repeating a scalar for both spatial dims."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


class patched_embeddings(nn.Module):
    """Turn a batch of 3-channel images into a ViT token sequence.

    The image is cut into patches by a strided convolution, and each patch is
    projected to ``embedding_size`` features. A learnable class token is
    prepended, and a learnable positional embedding (one vector per token
    position) is added.

    Args:
        embedding_size: Width of each token embedding.
        stride_length: Convolution stride (int or (h, w)); equal to
            ``kernel_size`` for non-overlapping patches.
        kernel_size: Patch size (int or (h, w)).
        batch_size: Unused; kept only for backward compatibility. The class
            token is broadcast to whatever batch size ``forward`` receives.
        img_size: Input image size (int or (h, w)), used to size the
            positional embedding. ``forward`` must receive images of this
            size, otherwise adding the positional embedding raises a shape
            mismatch.
    """

    def __init__(self, embedding_size = 768, stride_length=PATCH_SIZE, kernel_size=PATCH_SIZE, batch_size = BATCH_SIZE,
                 img_size=IMG_SIZE):
        super().__init__()

        self.patch_layer = nn.Conv2d(in_channels=3, out_channels=embedding_size, kernel_size=kernel_size, stride=stride_length)
        # (B, E, H', W') -> (B, E, H' * W')
        self.flatten_layer = nn.Flatten(start_dim=2, end_dim=3)

        # Patch count produced by the unpadded convolution above:
        # floor((size - kernel) / stride) + 1 along each spatial dim.
        img_h, img_w = _as_pair(img_size)
        k_h, k_w = _as_pair(kernel_size)
        s_h, s_w = _as_pair(stride_length)
        num_patches = ((img_h - k_h) // s_h + 1) * ((img_w - k_w) // s_w + 1)

        # A single class token shared by every sample; expanded per batch in
        # forward (previously one token per batch slot, which also forced a
        # fixed batch size).
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_size), requires_grad=True)
        # One positional vector per token position (class token + patches),
        # broadcast over the batch. The old (batch_size, 1, E) shape added the
        # same vector to every token and so carried no positional information.
        self.positional_encoding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_size), requires_grad=True)

    def forward(self, x):
        # (B, 3, H, W) -> (B, num_patches, E)
        x = self.flatten_layer(self.patch_layer(x)).permute(0, 2, 1)
        class_token = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat((class_token, x), dim=1)
        return x + self.positional_encoding



In [None]:
class MSABlock(nn.Module):
    """Layer-normalised multi-head self-attention sub-layer.

    The input is passed through LayerNorm and then through multi-head
    attention with query = key = value. Only the attention output is
    returned; no residual connection is added here.

    Args:
        embedding_dim: Token embedding width.
        num_heads: Number of attention heads.
        attn_dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embedding_dim=768, num_heads=12, attn_dropout=0):
        super().__init__()

        # Both sub-modules allocate their parameters directly on the
        # module-level `device`.
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim, device=device)
        self.multiheaded_attn = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            device=device,
            batch_first=True,  # tensors are laid out (batch, seq, embed)
        )

    def forward(self, x):
        normed = self.layer_norm(x)
        # need_weights=False: the attention map is not requested, only the output.
        return self.multiheaded_attn(normed, normed, normed, need_weights=False)[0]
