In [1]:
import torch 
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from torch import nn
from torch.nn import functional as F

In [40]:
# self attention
class self_attention(nn.Module):
    '''
    Module to apply self attention to an input sequence of vectors
    
    parameters:
    
    emb_dim = dimension of the embedding vector
    h = number of self attention heads
    
    '''
    def __init__(self, emb_dim, h):
        super().__init__()
        self.emb_dim = emb_dim
        self.h = h
        self.red_vec_size = emb_dim//h
        
        # Querry vector
        self.WQ = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        self.WK = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        self.WV = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        
    def forward(self, x):
        # x has shape (batch_size, seq_len, emb_dim)
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        querries = self.WQ(x)
        keys = self.WK(x)
        values = self.WV(x)
        att_scores = F.softmax((querries@keys.permute(0,2,1)).permute(0,2,1), dim = 2)
        ctx_vecs = att_scores @ values 
        assert ctx_vecs.shape == (batch_size, seq_len, self.red_vec_size ) 
        return querries, keys, values, ctx_vecs

In [26]:
batch_size = 5
seq_len = 3
emb_dim = 512
h = 8
x = torch.randn((batch_size, seq_len, emb_dim))
attn = self_attention(emb_dim, h)

In [27]:
attn

self_attention(
  (WQ): Linear(in_features=512, out_features=64, bias=False)
  (WK): Linear(in_features=512, out_features=64, bias=False)
  (WV): Linear(in_features=512, out_features=64, bias=False)
)

In [28]:
querries, keys, values, att_scores, ctx_vecs = attn(x)

In [29]:
querries.shape, keys.shape, values.shape, ctx_vecs.shape

(torch.Size([5, 3, 64]),
 torch.Size([5, 3, 64]),
 torch.Size([5, 3, 64]),
 torch.Size([5, 3, 64]))

In [30]:
att_scores

tensor([[[6.7062e-03, 8.0831e-01, 1.8499e-01],
         [5.6242e-02, 1.6197e-01, 7.8178e-01],
         [9.7505e-01, 4.1655e-03, 2.0783e-02]],

        [[3.1738e-02, 8.8465e-01, 8.3607e-02],
         [8.6128e-01, 6.6878e-02, 7.1839e-02],
         [8.9534e-01, 4.2649e-02, 6.2008e-02]],

        [[4.3472e-02, 8.5899e-01, 9.7543e-02],
         [1.5764e-03, 1.4066e-02, 9.8436e-01],
         [5.0195e-01, 1.5746e-02, 4.8231e-01]],

        [[9.9267e-01, 1.2288e-03, 6.1012e-03],
         [3.0698e-02, 5.0208e-03, 9.6428e-01],
         [5.6540e-01, 1.0565e-01, 3.2894e-01]],

        [[9.5420e-02, 8.9621e-01, 8.3741e-03],
         [2.1166e-02, 2.6959e-01, 7.0924e-01],
         [1.6179e-03, 9.9821e-01, 1.6967e-04]]], grad_fn=<SoftmaxBackward>)

In [37]:
att_scores[0,0]@values[0]

tensor([ 0.1312, -0.1994, -0.8424, -0.2728, -0.1895, -0.4478, -0.2879,  0.1850,
         0.0476, -0.2647, -0.1008,  1.0653, -0.7457, -0.9336,  0.0162,  0.3052,
         0.1973,  0.1181,  0.4215, -0.4157, -0.7890,  0.1855,  0.7934, -0.1824,
        -0.0045,  0.2831,  0.3987, -0.4449,  0.7939,  0.0970,  0.8292, -0.3705,
        -0.2376,  0.1030,  0.1147,  0.3346,  0.1519, -0.6783,  0.3524, -0.3001,
        -0.1445, -0.1007,  0.0783, -0.3000,  0.1789,  0.1023,  0.8283, -0.3681,
        -0.2386, -0.1363,  0.4252, -0.3139,  0.2909,  0.8021,  0.4274, -1.2823,
         0.4824, -0.0983, -0.6733, -0.4979, -0.9194, -0.4062,  0.4497, -0.0444],
       grad_fn=<SqueezeBackward3>)

In [39]:
att_scores[0,0].sum()

tensor(1., grad_fn=<SumBackward0>)