In [5]:
import torch
import torch.nn as nn
import einops

In [6]:
# Toy problem sizes used by the smoke-test cell below.
# The random input is built as (test_batches, test_seq, test_dim).
test_batches = 5  # sequences per batch
test_seq = 15     # positions per sequence
test_heads = 2    # attention heads; the layer uses dim / heads as the per-head size
test_dim = 40     # embedding size passed to the layer as `dim`

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each passed through a bias-free linear layer,
    the embedding axis is split into `heads` equally sized heads, and
    softmax(Q K^T / sqrt(head_len)) V is computed independently per head.
    The heads are then concatenated and passed through a final dense layer.

    Args:
        dim: size of the embeddings (last axis of every input and of the output).
        heads: number of attention heads; must divide `dim` evenly.

    Raises:
        ValueError: if `dim` or `heads` is not positive, or if `dim` is not
            divisible by `heads`.
    """

    def __init__(self, dim, heads):
        super().__init__()

        # Fail early with a clear message instead of silently truncating
        # head_len and crashing later inside the head-splitting rearrange.
        if dim <= 0 or heads <= 0:
            raise ValueError(f"dim and heads must be positive, got dim={dim}, heads={heads}")
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")

        self.dim = dim # size of embeddings
        self.heads = heads

        self.head_len = self.dim // self.heads
        self.norm_scale = self.head_len ** -0.5 # divide by sqrt of query size during normalization

        # nn.Linear acts on the last axis only; all leading axes are treated as batch axes.
        # It is equivalent to feeding one word embedding at a time, but vectorized.
        self.q_linear = nn.Linear(self.dim, self.dim, bias = False)
        self.k_linear = nn.Linear(self.dim, self.dim, bias = False)
        self.v_linear = nn.Linear(self.dim, self.dim, bias = False)

        # softmax over the last axis, which is the key sequence axis of q_dot_k
        self.softmax = nn.Softmax(dim=-1)

        # input and output size are the same. They could differ if head_len were
        # a free parameter instead of being dim / heads.
        self.dense = nn.Linear(self.dim, self.dim)


    def forward(self, key, query, value):
        """Apply attention. Note the argument order: key, then query, then value.

        Args:
            key: tensor of shape (batch, kv_seq, dim).
            query: tensor of shape (batch, q_seq, dim).
            value: tensor of shape (batch, kv_seq, dim).

        Returns:
            Tensor of shape (batch, q_seq, dim).

        Raises:
            AssertionError: if an input is not 3-D with last axis `dim`, if the
                batch sizes differ, or if key and value have different
                sequence lengths.
        """
        # Shapes are derived from the inputs rather than from notebook globals,
        # so the layer works for any batch size and sequence length.
        for name, tensor in (("key", key), ("query", query), ("value", value)):
            assert tensor.dim() == 3 and tensor.shape[-1] == self.dim, (
                f"{name} must have shape (batch, seq, {self.dim}), got {tuple(tensor.shape)}"
            )

        batch, q_seq, _ = query.shape
        kv_seq = key.shape[1]

        assert key.shape[0] == batch and value.shape[0] == batch, (
            "key, query and value must share the same batch size"
        )
        assert value.shape[1] == kv_seq, "key and value must share the same sequence length"

        # pass through linear layers
        q = self.q_linear(query)
        k = self.k_linear(key)
        v = self.v_linear(value)

        assert q.shape == (batch, q_seq, self.dim)
        assert k.shape == (batch, kv_seq, self.dim)
        assert v.shape == (batch, kv_seq, self.dim)

        # split heads and reshape. (seq, head_len) is the 2d matrix that gets multiplied;
        # batch and head are carried along as batch axes.
        q = einops.rearrange(q, 'b seq (head head_len) -> b head seq head_len', head = self.heads)
        k = einops.rearrange(k, 'b seq (head head_len) -> b head seq head_len', head = self.heads)
        v = einops.rearrange(v, 'b seq (head head_len) -> b head seq head_len', head = self.heads)

        assert q.shape == (batch, self.heads, q_seq, self.head_len)
        assert k.shape == (batch, self.heads, kv_seq, self.head_len)
        assert v.shape == (batch, self.heads, kv_seq, self.head_len)

        # transpose k for matmul: swap the last two axes so we multiply
        # (q_seq, head_len) by (head_len, kv_seq)
        k = einops.rearrange(k, 'b head seq head_len -> b head head_len seq')

        assert k.shape == (batch, self.heads, self.head_len, kv_seq)

        # matmul between q and k, scaled by 1 / sqrt(head_len).
        # Out-of-place multiply: avoids mutating the matmul result in place.
        q_dot_k = torch.matmul(q, k) * self.norm_scale

        assert q_dot_k.shape == (batch, self.heads, q_seq, kv_seq)

        # softmax of q_dot_k: each query position gets weights over the key positions
        attention_scores = self.softmax(q_dot_k)

        assert attention_scores.shape == (batch, self.heads, q_seq, kv_seq)

        # matmul by value to obtain the per-head result
        result = torch.matmul(attention_scores, v)

        assert result.shape == (batch, self.heads, q_seq, self.head_len)

        # concatenate all heads. We get back the entire dim size
        result_concat = einops.rearrange(result, 'b head seq head_len -> b seq (head head_len)')

        assert result_concat.shape == (batch, q_seq, self.dim)

        # pass through final dense layer
        out = self.dense(result_concat)

        assert out.shape == (batch, q_seq, self.dim)

        return out


In [8]:
# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Seed before building the layer and the input so that a fresh
# Restart & Run All starts from the same weights and the same test tensor.
torch.manual_seed(0)

attention = MultiHeadAttention(dim=test_dim, heads=test_heads).to(device)

# Random input, created on CPU and then moved to the selected device.
test_tensor = torch.rand(test_batches, test_seq, test_dim).to(device)

# Self-attention smoke test: the same tensor is used as key, query and value.
attention(test_tensor, test_tensor, test_tensor)

tensor([[[ 0.0998,  0.1155,  0.0078,  ..., -0.2253, -0.1911, -0.2133],
         [ 0.1004,  0.1154,  0.0071,  ..., -0.2250, -0.1884, -0.2126],
         [ 0.1003,  0.1152,  0.0070,  ..., -0.2252, -0.1884, -0.2132],
         ...,
         [ 0.1006,  0.1136,  0.0072,  ..., -0.2255, -0.1888, -0.2133],
         [ 0.1005,  0.1148,  0.0080,  ..., -0.2251, -0.1895, -0.2133],
         [ 0.0998,  0.1158,  0.0078,  ..., -0.2248, -0.1903, -0.2126]],

        [[ 0.0967,  0.0855, -0.0097,  ..., -0.1902, -0.1999, -0.2322],
         [ 0.0966,  0.0837, -0.0106,  ..., -0.1884, -0.1979, -0.2319],
         [ 0.0971,  0.0836, -0.0099,  ..., -0.1891, -0.2002, -0.2329],
         ...,
         [ 0.0981,  0.0837, -0.0114,  ..., -0.1892, -0.1996, -0.2331],
         [ 0.0969,  0.0863, -0.0116,  ..., -0.1901, -0.1998, -0.2318],
         [ 0.0973,  0.0845, -0.0101,  ..., -0.1900, -0.2011, -0.2329]],

        [[ 0.0937,  0.1357, -0.0062,  ..., -0.2130, -0.1774, -0.1880],
         [ 0.0928,  0.1348, -0.0054,  ..., -0