In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class InputEmbeddings(nn.Module):

    def __init__(self, vocab_size, embdim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embdim)

    def forward(self, x):
        return self.embeddings(x)

In [3]:
class PositionEmbeddings(nn.Module):

    def __init__(self, seq_len, embdim):
        super().__init__()
        # self.pe is the lookup matrix where each row represents seq[i]
        self.pe = torch.zeros(seq_len, embdim, dtype=torch.float32)
        # positions is just the sequences of seq_len
        positions = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
        # skipdim, skips by 2 starting 0
        emb_skip_dim = torch.arange(0, embdim, step=2, dtype=torch.float32)

        z = positions / (10000 ** (emb_skip_dim / embdim))
        self.pe[:, 0::2] = torch.sin(z)
        self.pe[:, 1::2] = torch.cos(z)
        self.pe = self.pe.unsqueeze(0)
        self.pe = nn.Parameter(self.pe, requires_grad=False)

    def forward(self, x):
        B, T, C = x.shape

        x = x + self.pe[:, :T, :]
        return x

In [36]:
class MultiHeadAttention(nn.Module):

    def __init__(self, embdim, num_heads):
        super().__init__()
        self.query = nn.Linear(embdim, embdim)
        self.key = nn.Linear(embdim, embdim)
        self.value = nn.Linear(embdim, embdim)
        self.num_heads = num_heads
        assert embdim % num_heads == 0, "Make sure embdim is divisible by num_heads"
        self.head_dim = embdim // num_heads

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        print(q.shape, k.shape, v.shape)

In [37]:
B, T, C = 4, 2, 8
x = torch.randn(B, T, C)

In [38]:
x.shape

torch.Size([4, 2, 8])

In [39]:
model = MultiHeadAttention(C, 2)

In [40]:
model(x)

torch.Size([4, 2, 2, 4]) torch.Size([4, 2, 2, 4]) torch.Size([4, 2, 2, 4])


In [41]:
x[0]

tensor([[-0.2698, -1.8405, -0.3761, -0.2407, -2.5391, -0.7688,  0.6951, -1.1937],
        [ 0.6633, -1.1746, -1.2367,  0.2515,  0.9736, -0.5950,  1.4202, -0.4203]])

In [None]:
mask = [False ]