In [5]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 

In [6]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention over token encodings.

    Three bias-free linear layers project the encodings into queries, keys and
    values. Query/key dot products are divided by the square root of the key
    size along ``col_dim``, converted to weights with a softmax over
    ``col_dim``, and used to form a weighted sum of the values.

    Args:
        d_model: Number of embedding values per token; used as both the input
            and the output width of every projection.
        row_dim: Index of the dimension used to access rows.
        col_dim: Index of the dimension used to access columns.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        ## Remember which dimensions are treated as rows and columns
        self.row_dim = row_dim
        self.col_dim = col_dim

        ## Projection layers, created in query -> key -> value order so that a
        ## fixed random seed always yields the same initial weights for each
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, token_encodings):
        """Return the attention values for ``token_encodings``.

        Args:
            token_encodings: Tensor whose last dimension has ``d_model``
                entries (required by the linear projections).

        Returns:
            Tensor holding, for every token, the softmax-weighted combination
            of the value vectors.
        """
        ## Project the encodings into queries, keys and values
        queries = self.W_q(token_encodings)
        keys = self.W_k(token_encodings)
        values = self.W_v(token_encodings)

        ## Similarity of every query with every key (keys are transposed so the
        ## matrix product pairs each query row with each key row)
        keys_t = torch.transpose(keys, self.row_dim, self.col_dim)
        similarities = queries @ keys_t

        ## Scale by the square root of the key size along col_dim
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scaled_similarities = similarities / scale

        ## Softmax decides what percentage of each token's value goes into the
        ## final attention values
        weights = F.softmax(scaled_similarities, dim=self.col_dim)

        return weights @ values

In [7]:
## Calculate Self-Attention

## Fix the RNG state so the randomly initialised projection weights are reproducible
torch.manual_seed(42)

## Matrix of token encodings (one row per token, two embedding values each)
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

## Build the attention layer for 2-value embeddings, then run it on the encodings
selfAttention = SelfAttention(d_model=2, row_dim=0, col_dim=1)
selfAttention(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [8]:
## Printing out weights and verifying calculations 
## Weight matrix to create the queries (shown transposed, i.e. in the
## orientation that the token encodings are multiplied by)
selfAttention.W_q.weight.transpose(0, 1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [9]:
## Weight matrix to create the keys (shown transposed, i.e. in the
## orientation that the token encodings are multiplied by)
selfAttention.W_k.weight.transpose(0, 1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [12]:
## Weight matrix to create the values (shown transposed, i.e. in the
## orientation that the token encodings are multiplied by)
selfAttention.W_v.weight.transpose(0, 1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [13]:
# Calculate the queries: pass the token encodings through the query projection
selfAttention.W_q(encodings_matrix)

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [14]:
# Calculate the keys: pass the token encodings through the key projection
selfAttention.W_k(encodings_matrix)

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [15]:
# Calculate the values: pass the token encodings through the value projection
selfAttention.W_v(encodings_matrix)

tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)

In [16]:
# Keep the queries in `q` so the attention steps can be reproduced by hand
q = selfAttention.W_q(encodings_matrix)
q

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [17]:
# Keep the keys in `k` so the attention steps can be reproduced by hand
k = selfAttention.W_k(encodings_matrix)
k

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [18]:
# Similarity scores: dot product of every query (row of q) with every key (row of k)
sims = q @ k.transpose(0, 1)
sims

tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)

In [20]:
# Scale the similarities by the square root of the key width. The width is read
# from `k` instead of hard-coding 2, so this hand check uses the same expression
# as SelfAttention.forward and stays valid if d_model changes.
scaled_sims = sims / torch.tensor(k.size(1) ** 0.5)
scaled_sims

tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)

In [21]:
# Softmax along dim 1 so that each row of attention percentages sums to 1
attention_percents = scaled_sims.softmax(dim=1)
attention_percents

tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)

In [22]:
# Weight the values by the attention percentages; the result should reproduce
# the output of selfAttention(encodings_matrix)
attention_percents @ selfAttention.W_v(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)