In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [16]:
class AttentionHead(nn.Module):
    """A single causal (masked) self-attention head.

    Queries and keys are projected to ``key_dim`` features; values are projected
    to ``in_dim`` features, so the head's output has the same feature size as
    its input.

    Args:
        in_dim: feature size of the input (last dimension of ``x``).
        key_dim: size of the query/key projections; attention scores are
            scaled by ``key_dim ** -0.5``.
        ctx_size: maximum sequence length supported by the causal mask.
    """

    def __init__(self, in_dim, key_dim, ctx_size) -> None:
        super().__init__()

        self.scaler = key_dim**-0.5
        # Boolean mask that is True strictly above the diagonal, i.e. position i
        # is forbidden from attending to any later position j > i.
        self.register_buffer('mask', torch.tril(torch.ones(ctx_size, ctx_size)) == 0)

        self.q = nn.Linear(in_dim, key_dim, bias=False)
        self.k = nn.Linear(in_dim, key_dim, bias=False)
        self.v = nn.Linear(in_dim, in_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape ``(..., T, in_dim)`` with ``T <= ctx_size``
                (typically ``(batch, T, in_dim)``).

        Returns:
            Tensor of shape ``(..., T, in_dim)``.
        """
        seq_len = x.shape[-2]
        Q, K, V = self.q(x), self.k(x), self.v(x)
        A = (Q @ K.transpose(-2, -1)) * self.scaler
        # Use only the top-left T x T corner of the precomputed mask so that
        # sequences shorter than ctx_size are supported (the full mask would
        # fail to broadcast against a (.., T, T) score matrix).
        A = A.masked_fill(self.mask[:seq_len, :seq_len], -torch.inf)
        A = F.softmax(A, dim=-1)
        return A @ V

In [17]:
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run side by side, then linearly projected.

    Each of the ``n_heads`` heads is an ``AttentionHead`` whose query/key size is
    ``key_dim // n_heads`` (floor division) and which returns ``in_dim`` features.
    The ``n_heads * in_dim`` concatenated features are projected back to
    ``in_dim`` by a biased linear layer.
    """

    def __init__(self, n_heads, in_dim, key_dim, ctx_size) -> None:
        super().__init__()

        per_head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            [AttentionHead(in_dim, per_head_key_dim, ctx_size) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(in_dim * n_heads, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every head on ``x``, concatenate along the feature axis, project.

        NOTE(review): concatenation is on dim=2, so this assumes a 3-D input,
        presumably ``(batch, seq, in_dim)`` as used later in the notebook.
        """
        head_outputs = [head(x) for head in self.heads]
        stacked = torch.cat(head_outputs, dim=2)
        return self.proj(stacked)

In [18]:
class Block(nn.Module):
    """Residual multi-head attention followed by a residual two-layer MLP.

    The MLP is ``Linear(in_dim, h_dim) -> ReLU -> Linear(h_dim, in_dim)``.
    No normalization layer is applied anywhere in the block.
    """

    def __init__(self, n_heads, in_dim, key_dim, h_dim, ctx_size) -> None:
        super().__init__()

        self.attention = MultiHeadAttention(n_heads, in_dim, key_dim, ctx_size)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, in_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the attention output to ``x``, then add the MLP output to that."""
        attended = x + self.attention(x)
        return attended + self.mlp(attended)

In [19]:
# Model / data hyper-parameters.
ctx_size = 10        # maximum sequence length (size of each head's causal mask)
in_dim = 100         # feature size of every token vector

n_blocks = 4         # number of stacked Blocks
n_heads = 4          # attention heads per Block
key_dim = 32         # total query/key size, split across heads as key_dim // n_heads
h_dim = 2*in_dim     # hidden width of each Block's MLP

# Each head gets key_dim // n_heads query/key features; fail loudly instead of
# silently dropping features through the floor division.
assert key_dim % n_heads == 0, 'key_dim must be divisible by n_heads'

# Model weights and the test input below are random; seed for reproducible reruns.
SEED = 0
torch.manual_seed(SEED)

In [20]:
# Positional arguments shared by every Block: (n_heads, in_dim, key_dim, h_dim, ctx_size).
args = [n_heads, in_dim, key_dim, h_dim, ctx_size]
# n_blocks separately initialised Blocks, applied one after another.
modules = [Block(*args) for _block_idx in range(n_blocks)]
model = nn.Sequential(*modules)

print('model size: ', sum(param.numel() for param in model.parameters()))

model size:  507200


In [21]:
batch_size = 32

x = torch.rand(batch_size, ctx_size, in_dim)
print(f'{x.shape=}')

# Shape smoke test only: no_grad avoids building an autograd graph we never use.
with torch.no_grad():
    x = model(x)
print(f'{x.shape=}')

# Every Block adds its outputs back onto its input, so the shape must be preserved.
assert x.shape == (batch_size, ctx_size, in_dim)

x.shape=torch.Size([32, 10, 100])
x.shape=torch.Size([32, 10, 100])
