<a href="https://colab.research.google.com/github/ra1ph2/Vision-Transformer/blob/main/VisionTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [58]:
import torch
from torch import nn
from torch import functional as F

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three separate
    linear layers (each optionally followed by an activation), split into
    ``heads`` independent heads, attended per head, and merged back into
    ``embed_dim`` features. No output projection is applied afterwards.

    Args:
        heads: Number of attention heads; must be positive and divide
            ``embed_dim``.
        embed_dim: Size of the input (and output) feature dimension.
        activation: ``'relu'`` applies a ReLU after each projection; any
            other value (e.g. ``None``) leaves the projections linear.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, heads=8, embed_dim=512, activation='relu', dropout=0.1):
        super(Attention, self).__init__()
        # Fail at construction time instead of on the first forward pass.
        if heads <= 0 or embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"number of heads (got heads={heads})"
            )
        self.heads = heads
        self.embed_dim = embed_dim
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.Identity()
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        """Apply self-attention to ``inp``.

        Args:
            inp: Tensor of shape (batch_size, seq_len, embed_dim).

        Returns:
            A tuple ``(out, attention_scores)`` where ``out`` has shape
            (batch_size, seq_len, embed_dim) and ``attention_scores`` has
            shape (batch_size * heads, seq_len, seq_len). The returned
            scores are the softmax output before dropout is applied.

        Raises:
            ValueError: If ``inp`` is not 3-D or its last dimension is not
                ``embed_dim``.
        """
        if inp.dim() != 3 or inp.size(2) != self.embed_dim:
            raise ValueError(
                f"expected input of shape (batch_size, seq_len, "
                f"{self.embed_dim}), got {tuple(inp.shape)}"
            )

        # Each projection: (batch_size, seq_len, embed_dim)
        #   -> (batch_size * heads, seq_len, reduced_dim)
        query = self._reshape_heads(self.activation(self.query(inp)))
        key = self._reshape_heads(self.activation(self.key(inp)))
        value = self._reshape_heads(self.activation(self.value(inp)))

        # Scale the logits by 1/sqrt(reduced_dim) so their magnitude does not
        # grow with the head size and push the softmax into saturation.
        scale = (self.embed_dim // self.heads) ** -0.5
        logits = torch.matmul(query, key.transpose(1, 2)) * scale

        # attention_scores: (batch_size * heads, seq_len, seq_len),
        # normalised along the last dimension.
        attention_scores = self.softmax(logits)

        # out: (batch_size * heads, seq_len, reduced_dim)
        out = torch.matmul(self.dropout(attention_scores), value)

        # Merge the heads back: (batch_size, seq_len, embed_dim)
        out = self._reshape_heads_back(out)

        return out, attention_scores

    def _reshape_heads(self, inp):
        """Split the feature dimension into heads and fold them into the batch.

        Maps (batch_size, seq_len, embed_dim) to
        (batch_size * heads, seq_len, reduced_dim), where
        ``reduced_dim = embed_dim // heads``.
        """
        batch_size, seq_len, _ = inp.size()
        reduced_dim = self.embed_dim // self.heads

        out = inp.reshape(batch_size, seq_len, self.heads, reduced_dim)
        # Bring the head axis next to the batch axis before flattening them.
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(-1, seq_len, reduced_dim)

        # out: (batch_size * heads, seq_len, reduced_dim)
        return out

    def _reshape_heads_back(self, inp):
        """Inverse of ``_reshape_heads``.

        Maps (batch_size * heads, seq_len, reduced_dim) back to
        (batch_size, seq_len, embed_dim).
        """
        batch_size_mul_heads, seq_len, reduced_dim = inp.size()
        batch_size = batch_size_mul_heads // self.heads

        out = inp.reshape(batch_size, self.heads, seq_len, reduced_dim)
        # Put the head axis back beside the per-head features so the final
        # reshape concatenates the heads along the feature dimension.
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(batch_size, seq_len, self.embed_dim)

        # out: (batch_size, seq_len, embed_dim)
        return out

In [59]:
# Instantiate the hand-written module next to PyTorch's built-in layer,
# using the same embedding width and head count for both. The two modules
# are constructed independently; no weights are copied between them.
attention = Attention(embed_dim=4, heads=2, activation=None)
attention_lib = nn.MultiheadAttention(embed_dim=4, num_heads=2)

# One batch containing two identical all-ones tokens of width 4.
inp = torch.ones((1, 2, 4))
print(inp)

# Custom module takes the batch-first tensor directly; show the output
# followed by the attention weights.
out, wts = attention(inp)
print(out, wts, sep="\n")

# The built-in layer is called with the first two axes swapped; query, key
# and value each receive their own permuted view of the same input.
q_in, k_in, v_in = (inp.permute(1, 0, 2) for _ in range(3))
out_lib, out_wts = attention_lib(q_in, k_in, v_in)
print(out_lib.permute(1, 0, 2))  # swap the axes back to match `out`
print(out_wts)

tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[-2.1490,  0.8508,  0.5317, -0.9154],
         [-2.1490,  0.8508,  0.5317, -0.9154]]], grad_fn=<UnsafeViewBackward>)
tensor([[[0.5000, 0.5000],
         [0.5000, 0.5000]],

        [[0.5000, 0.5000],
         [0.5000, 0.5000]]], grad_fn=<SoftmaxBackward>)
tensor([[[-0.6255,  0.0154,  0.3488,  0.4291],
         [-0.6255,  0.0154,  0.3488,  0.4291]]], grad_fn=<PermuteBackward>)
tensor([[[0.5000, 0.5000],
         [0.5000, 0.5000]]], grad_fn=<DivBackward0>)
