In [20]:
# SelfAttention 
# learn which words are important 
# eg : the cat sat on the mat ,what should  "mat " pay ateention to 
# mainy t"the", "sat ", and maybe "cat"
# self attention learns which words matter for each word.

import torch 
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embedding_dim, head_dim):
        super().__init__()
        self.head_dim = head_dim

        # Linear projections 
        self.query = nn.Linear(embedding_dim,head_dim, bias= False)
        self.key = nn.Linear(embedding_dim,head_dim, bias= False)
        self.value = nn.Linear(embedding_dim,head_dim, bias= False)
        # casual mask (prevent looking at future tkens)
        # lowerr traingular matrix 

    def forward(self, x):
        batch, seq_len, emb_dim = x.shape 

        # each word is transformed into three versions:
        # Q = Query - What am I looking for ?
        #  K : Key  - "what information do I have?"
        # V : Value : " What my actual contetn?"

        # step 1: Project to Q, K, V
        Q = self.query(x) # (batch, seq_len, head_dim)
        K = self.query(x) # (batch, seq_len, head_dim)
        V = self.query(x) # (batch, seq_len, head_dim)




        # step2: Compute attention score 
        # for each word ask: "How similar am I to every other word?"
        # This is : Query @ Key^T 
        # eg with 3 words 
        # Word 1 query : [0.1, 0.5, 0.2]
        # word 2 key : [0.2, 0.4, 0.3]
        # word 3 key : [0.0, 0.8, 0.1]

        # Similarity  to word 1 : dot prodct with itself 
        # 0.1*0.1 + 0.5*0.5 + 0.2*0.2 
        # Similarity to word 2: 0.1*0.2 + 0.5*0.4 + 0.2*0.3 = 0.24
        # Similarity to word 3: 0.1*0.0 + 0.5*0.8 + 0.2*0.1 = 0.42



        scores = Q @ K.transpose(-2, -1)

        # step 3 : Scale 
        # Divide by sqrt(head_dim) so score don't explode
        # this prevents softmax from becoming too confident (all between 0 and  1)

        scores = scores / (self.head_dim ** 0.5)

        # apply causal mask
        # to avoid early tokens see future tokens 
        # - char1 can see only char  1 
        #  char2 can see char 1 and 2 
        # char 3 can see 1,2,3 and so on 
        # mask li=ooks like lower traingular matrix 
        # [1 0 0 0] can see position 0 only 
        # [1 1 0 0] can see0 & 1 
        # [1 1 1 0] can see 0,1,2
        # [1 1 1 1] can see all 
        mask = torch.tril(torch.ones(seq_len, seq_len))
        scores = scores.masked_fill(mask ==0, float('-inf'))

        # step 5: Softmax (get Probablitits)
        # should sum to 1 
        # make upper trangle -inf so when appled softamx it gives 0 
        # high score become high probablity 
        # eg  

        attention_weights = F.softmax(scores, dim=-1)


        # step 6;  Weighted sum of values 
        # for each word, combine all other words' values using attention weights
        # high weight = take more of that word's content 
        #  low weight = take less of that word's content  
        output = attention_weights @ V 

        return output, attention_weights

In [23]:
# test 
embedding_dim =  8 
head_dim = 9 
attention = SelfAttention(embedding_dim, head_dim)

# dummy data 
X = torch.randn(1,4, embedding_dim)

# forward pass 
output,weights = attention(X)

weights[0]


tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.3319, 0.6681, 0.0000, 0.0000],
        [0.0205, 0.1064, 0.8731, 0.0000],
        [0.0546, 0.1591, 0.4197, 0.3666]], grad_fn=<SelectBackward0>)