In [1]:
import torch
import torch.nn as nn

In [2]:
# Defining a class draws no random numbers; the seed that actually fixes the
# weights is the one set right before `SelfAttention()` is instantiated.
torch.manual_seed(40)

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q @ K^T / sqrt(d_k)) @ V``, where Q, K and V are
    bias-free linear projections of the same input.

    Args:
        d_model: width of each token encoding; also the width of Q, K and V.
        row_dim: dimension of the input that indexes tokens.
        col_dim: dimension of the input that indexes features. The softmax over
            the similarity matrix is taken along this dimension. Because
            ``nn.Linear`` projects along the last dimension, this is expected
            to be the last dimension of the input.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        # No bias term: the original paper uses plain weight matrices W^Q, W^K, W^V.
        # Keep creation order q -> k -> v so weight initialisation consumes the
        # global RNG in the same order (reproducible weights for a given seed).
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token):
        """Apply self-attention to every token of ``token`` at once.

        Args:
            token: the full input encodings (not a single token). With the
                default ``row_dim=0, col_dim=1`` this is an
                ``(n_tokens, d_model)`` matrix, one row per token.

        Returns:
            Attention output ``weights @ V``. For the default 2-D layout this
            has shape ``(n_tokens, d_model)``.
        """
        q = self.W_q(token)
        k = self.W_k(token)
        v = self.W_v(token)

        # Q @ K^T: pairwise similarity between every query and every key.
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Divide by sqrt(d_k). A plain Python scalar suffices; no need to
        # allocate a tensor on every forward call.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)
        # Each query's weights over all keys sum to 1.
        attn_weights = torch.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attn_weights, v)


In [3]:
# Re-seed immediately before construction: nn.Linear initialises its weights
# from the global RNG, so this pins the weights (and every number printed below).
torch.manual_seed(seed=40)

self_Attn = SelfAttention(d_model=2, row_dim=0, col_dim=1)

In [4]:
torch.manual_seed(seed=40)

# One row per token: 5 tokens, each a 2-wide encoding (matches self_Attn's d_model=2).
encoding_matrix = torch.randn((5, 2))

# Bare last expression so the notebook renders the attention output.
self_Attn(encoding_matrix)

tensor([[-0.2011,  0.2057],
        [-0.2040,  0.1979],
        [-0.1110,  0.3437],
        [-0.1842,  0.2515],
        [-0.2032,  0.1883]], grad_fn=<MmBackward0>)

In [5]:
# nn.Linear stores weight as (out_features, in_features) and computes x @ weight^T,
# so transpose to get W in the `x @ W` orientation used by the attention formulas.
W_q = self_Attn.W_q.weight.transpose(0, 1)  # query weights
W_k = self_Attn.W_k.weight.transpose(0, 1)  # key weights
W_v = self_Attn.W_v.weight.transpose(0, 1)  # value weights

print(f"Query weights: {W_q}")
print(f"Key weights: {W_k}")
print(f"Value weights: {W_v}")

Q = encoding_matrix @ W_q
K = encoding_matrix @ W_k
V = encoding_matrix @ W_v

# Check the hand-computed projections against the module's own layers instead
# of merely claiming they are equivalent.
torch.testing.assert_close(Q, self_Attn.W_q(encoding_matrix))
torch.testing.assert_close(K, self_Attn.W_k(encoding_matrix))
torch.testing.assert_close(V, self_Attn.W_v(encoding_matrix))

print(f"\nQuery matrix: {Q}")
print(f"\nKey matrix: {K}")
print(f"\nValue matrix: {V}")

Query weights: tensor([[-0.1868, -0.4614],
        [ 0.5177,  0.3051]], grad_fn=<TransposeBackward0>)
Key weights: tensor([[ 0.5161,  0.0709],
        [-0.0173, -0.5207]], grad_fn=<TransposeBackward0>)
Value weights: tensor([[-0.2973,  0.4265],
        [-0.6072, -0.2483]], grad_fn=<TransposeBackward0>)

Query matrix: tensor([[-0.3542, -0.5357],
        [-0.4196, -0.5584],
        [ 0.5658,  0.4835],
        [ 0.0497, -0.3867],
        [-0.5662, -0.5234]], grad_fn=<MmBackward0>)

Key matrix: tensor([[ 0.4864,  0.2473],
        [ 0.4655,  0.3184],
        [-0.2368, -0.5190],
        [ 0.6020, -0.1886],
        [ 0.2944,  0.5062]], grad_fn=<MmBackward0>)

Value matrix: tensor([[-0.0653,  0.4834],
        [ 0.0348,  0.4995],
        [-0.4428, -0.4153],
        [-0.6698,  0.3750],
        [ 0.3850,  0.4535]], grad_fn=<MmBackward0>)
