In [1]:
import torch
import torch.nn.functional as F

class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q @ K^T / sqrt(dim_out)) @ V``, where ``Q``, ``K`` and ``V``
    are bias-free linear projections of the same input.

    Args:
        embedding_size: Feature size of the input embeddings.
        dim_out: Feature size of the query/key/value projections and of the output.
    """

    def __init__(self, embedding_size: int = 4, dim_out: int = 4):
        super().__init__()
        # Creation order (q, k, v) is kept fixed so seeded runs reproduce the same weights.
        self.wq = torch.nn.Linear(embedding_size, dim_out, bias=False)
        self.wk = torch.nn.Linear(embedding_size, dim_out, bias=False)
        self.wv = torch.nn.Linear(embedding_size, dim_out, bias=False)
        self.dim_out = dim_out

    def forward(self, embeddings):
        """Apply self-attention to ``embeddings``.

        Args:
            embeddings: Tensor of shape ``(seq_len, embedding_size)``, optionally
                with leading batch dimensions ``(..., seq_len, embedding_size)``.

        Returns:
            Tensor of shape ``(..., seq_len, dim_out)``.
        """
        query = self.wq(embeddings)
        key = self.wk(embeddings)
        value = self.wv(embeddings)

        # Swap only the last two axes. ``key.T`` reverses *all* axes, which is
        # wrong (and deprecated) once a batch dimension is present.
        sims = torch.matmul(query, key.transpose(-2, -1))
        # Normalise each query row over the keys (last axis). ``dim=1`` would only
        # be the key axis for 2-D input.
        attention_weights = torch.softmax(sims / (self.dim_out ** 0.5), dim=-1)

        return torch.matmul(attention_weights, value)


In [2]:
# A short token sequence; building it does not draw from the RNG.
token_ids = torch.tensor([1, 2, 3])

# Seed before creating any parameters so the embedding table and the
# attention head's projections come out the same on every run.
torch.manual_seed(1)
embeddings = torch.nn.Embedding(num_embeddings=1000, embedding_dim=4)
attention = SelfAttention(embedding_size=4, dim_out=4)

# Look up one 4-dim vector per token id.
encodings = embeddings(input=token_ids)
encodings

tensor([[-0.1002, -0.6092, -0.9798, -1.6091],
        [-0.7121,  0.3037, -0.7773, -0.2515],
        [-0.2223,  1.6871,  0.2284,  0.4676]], grad_fn=<EmbeddingBackward0>)

In [3]:
# Run the seeded 4-dim head over the three token embeddings from the previous cell.
attention(embeddings=encodings)

tensor([[ 0.3959, -0.0747, -0.1569, -0.0704],
        [ 0.3421, -0.0247, -0.1320, -0.1292],
        [ 0.2429,  0.0757, -0.0831, -0.2359]], grad_fn=<MmBackward0>)

In [4]:
# Hand-checkable example: three 2-D input vectors and a 2-dim head, small
# enough that every intermediate in the following cells can be verified by hand.
encodings = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

# Re-seed immediately before building the module so its weights are reproducible.
torch.manual_seed(42)
attention = SelfAttention(embedding_size=2, dim_out=2)

attention(embeddings=encodings)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [34]:
# nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T;
# transposing shows W_q in the `encodings @ W_q` orientation used for hand calculation.
attention.wq.weight.permute(1, 0)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<PermuteBackward0>)

In [35]:
# W_k in `encodings @ W_k` orientation (nn.Linear keeps it stored transposed).
attention.wk.weight.permute(1, 0)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<PermuteBackward0>)

In [36]:
# W_v in `encodings @ W_v` orientation (nn.Linear keeps it stored transposed).
attention.wv.weight.permute(1, 0)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<PermuteBackward0>)

In [41]:
# Queries: one projected row per input vector (assign and display in one step).
(q := attention.wq(encodings))

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [42]:
# Keys: one projected row per input vector (assign and display in one step).
(k := attention.wk(encodings))

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [44]:
# Raw similarity of every query (row) with every key (column).
sims = q @ k.T
sims

tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)

In [49]:
# Scale by sqrt(dim_out), taken from the module rather than hard-coded, so this
# cell stays in sync with `attention`. Then normalise each query row over the keys.
scaled_sim = F.softmax(sims / (attention.dim_out ** 0.5), dim=-1)
scaled_sim

tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)

In [50]:
# Finish the hand computation (attention weights @ V) and check that it
# reproduces the module's own forward pass instead of comparing outputs by eye.
manual_output = torch.matmul(scaled_sim, attention.wv(encodings))
torch.testing.assert_close(manual_output, attention(encodings))
manual_output

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

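## Multi-Head Attention

Several independent `SelfAttention` heads run on the same input, and their outputs are concatenated along the feature dimension.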

In [4]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention built from independent ``SelfAttention`` heads.

    Each head has its own query/key/value projections. The per-head outputs are
    concatenated along the last (feature) axis, giving ``num_heads * dim_out``
    output features.

    Args:
        embedding_size: Feature size of the input embeddings.
        dim_out: Output feature size of each individual head.
        num_heads: Number of heads.
    """

    def __init__(self, embedding_size: int = 4, dim_out: int = 4, num_heads: int = 4):
        super().__init__()
        # Heads are created in order, so under a fixed seed head 0 gets the same
        # weights a standalone SelfAttention would.
        self.heads = torch.nn.ModuleList(
            [SelfAttention(embedding_size, dim_out) for _ in range(num_heads)]
        )

    def forward(self, embeddings):
        """Run every head on ``embeddings`` and concatenate their outputs.

        Args:
            embeddings: Input passed unchanged to each head; for 2-D input,
                shape ``(seq_len, embedding_size)``.

        Returns:
            For 2-D input, a tensor of shape ``(seq_len, num_heads * dim_out)``.
        """
        # Concatenate on the last axis. For 2-D head outputs this is the same as
        # dim=1, and it remains the feature axis if heads return tensors with
        # leading batch dimensions.
        return torch.cat([head(embeddings) for head in self.heads], dim=-1)

In [5]:
# Two heads on the same hand-checkable inputs as the single-head example.
# Re-seeding with 42 gives the first head the same weights as that example,
# so the first two output columns reproduce its result.
encodings = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])

torch.manual_seed(42)
attention = MultiHeadAttention(embedding_size=2, dim_out=2, num_heads=2)

attention(embeddings=encodings)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)