In [1]:
import torch 
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from torch import nn
from torch.nn import functional as F

In [3]:
# self attention
class self_attention(nn.Module):
    '''
    Module to apply self attention to an input sequence of vectors
    
    parameters:
    
    emb_dim = dimension of the embedding vector
    h = number of self attention heads
    
    '''
    def __init__(self, emb_dim, h):
        super().__init__()
        self.emb_dim = emb_dim
        self.h = h
        self.red_vec_size = emb_dim//h
        
        # Querry vector
        self.WQ = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        self.WK = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        self.WV = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        
    def forward(self, x):
        # x has shape (batch_size, seq_len, emb_dim)
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        querries = self.WQ(x)
        keys = self.WK(x)
        values = self.WV(x)
        att_scores = F.softmax((querries@keys.permute(0,2,1)).permute(0,2,1)\
                               /np.sqrt(self.red_vec_size), dim = 2)
        ctx_vecs = att_scores @ values 
        assert ctx_vecs.shape == (batch_size, seq_len, self.red_vec_size ) 
        return querries, keys, values, ctx_vecs

In [4]:
batch_size = 5
seq_len = 3
emb_dim = 4
h = 1
x = torch.randn((batch_size, seq_len, emb_dim))
attn = self_attention(emb_dim, h)

In [5]:
attn

self_attention(
  (WQ): Linear(in_features=4, out_features=4, bias=False)
  (WK): Linear(in_features=4, out_features=4, bias=False)
  (WV): Linear(in_features=4, out_features=4, bias=False)
)

In [6]:
querries, keys, values, ctx_vecs = attn(x)

In [7]:
querries.shape, keys.shape, values.shape, ctx_vecs.shape

(torch.Size([5, 3, 4]),
 torch.Size([5, 3, 4]),
 torch.Size([5, 3, 4]),
 torch.Size([5, 3, 4]))

In [8]:
ctx_vecs

tensor([[[ 0.2804,  0.5018, -0.7626,  0.0181],
         [ 0.4935, -0.0438,  0.1872, -0.3386],
         [ 0.4075,  0.2587, -0.2906, -0.1660]],

        [[-0.2464, -0.2251,  0.3710, -0.0988],
         [-0.1341, -0.1451,  0.2347, -0.1433],
         [-0.1541, -0.1500,  0.2305, -0.1329]],

        [[ 0.0843, -0.2047,  0.1024,  0.1585],
         [ 0.0067, -0.3096,  0.1880,  0.1146],
         [ 0.1305, -0.1400,  0.0396,  0.2153]],

        [[-0.4525, -0.3964,  0.3201, -0.1685],
         [-0.4658, -0.4039,  0.3224, -0.1578],
         [-0.5057, -0.4148,  0.3209, -0.1241]],

        [[ 0.0791,  0.1095, -0.2467,  0.0517],
         [ 0.1067,  0.0871, -0.2460,  0.0682],
         [ 0.0070,  0.3619, -0.5443,  0.2537]]], grad_fn=<UnsafeViewBackward>)

In [9]:
attn(x)[3]

tensor([[[ 0.2804,  0.5018, -0.7626,  0.0181],
         [ 0.4935, -0.0438,  0.1872, -0.3386],
         [ 0.4075,  0.2587, -0.2906, -0.1660]],

        [[-0.2464, -0.2251,  0.3710, -0.0988],
         [-0.1341, -0.1451,  0.2347, -0.1433],
         [-0.1541, -0.1500,  0.2305, -0.1329]],

        [[ 0.0843, -0.2047,  0.1024,  0.1585],
         [ 0.0067, -0.3096,  0.1880,  0.1146],
         [ 0.1305, -0.1400,  0.0396,  0.2153]],

        [[-0.4525, -0.3964,  0.3201, -0.1685],
         [-0.4658, -0.4039,  0.3224, -0.1578],
         [-0.5057, -0.4148,  0.3209, -0.1241]],

        [[ 0.0791,  0.1095, -0.2467,  0.0517],
         [ 0.1067,  0.0871, -0.2460,  0.0682],
         [ 0.0070,  0.3619, -0.5443,  0.2537]]], grad_fn=<UnsafeViewBackward>)

In [10]:
class multi_head_attn(nn.Module):
    '''
    Module to create multiple attention heads
    
    parameters:
    
    emb_dim = dimension of the embedding vectors
    h = number of attention heads
    parallelize = parallelize the computations for differnt heads 
    
    '''
    def __init__(self, emb_dim, h, parallelize = 'False'):
        super().__init__()
        self.emb_dim = emb_dim
        self.h = h
        self.red_vec_size = emb_dim // h 
        
        self.heads = [self_attention(emb_dim, h) for i in range(h)]
        
        # transform the contatenated context vectors to have same size as emb_sim
        # this is to be able to enable implement a skip-connection between the input and output
        self.Wo = nn.Linear(self.red_vec_size*h, emb_dim, bias = False) 
        
        # layer norm
        # should we apply 
        self.LNorm = nn.LayerNorm(emb_dim)
        
    def forward(self, x):
        ctx_vecs = torch.cat([head(x)[3] for head in self.heads], dim = 2)
        transformed = self.Wo(ctx_vecs)
        
        return self.LNorm(x + transformed)

In [11]:
batch_size = 5
seq_len = 3
emb_dim = 6
h = 2
x = torch.randn((batch_size, seq_len, emb_dim))
multihead = multi_head_attn(emb_dim, h)

In [12]:
ctx = multihead(x)

In [13]:
ctx.shape

torch.Size([5, 3, 6])

In [14]:
ctx

tensor([[[-1.3967, -0.4712, -0.7870,  0.4672,  1.6398,  0.5479],
         [-0.9126,  0.7557, -1.0113,  1.1812, -1.0501,  1.0370],
         [-0.3855, -0.5054,  2.0535,  0.3557, -0.5357, -0.9826]],

        [[-1.6022,  0.9865,  1.3913,  0.0374, -0.0966, -0.7164],
         [-0.0444,  1.2961,  0.6694,  0.3283, -0.3399, -1.9095],
         [-0.6499, -0.5765,  0.4959, -0.8299,  2.0238, -0.4634]],

        [[-0.0690, -1.8813,  1.2001, -0.3236,  0.9458,  0.1280],
         [ 0.7996, -0.6632,  0.8283, -0.2290,  1.0307, -1.7663],
         [-1.1234,  1.3501, -1.3677,  0.9640,  0.3114, -0.1345]],

        [[-1.7464, -0.4007,  0.6703, -0.3558,  0.3994,  1.4332],
         [ 0.3436, -1.4343,  1.6375, -0.9994,  0.3713,  0.0814],
         [ 0.7205, -2.0331,  0.0393,  0.9465,  0.6087, -0.2819]],

        [[ 1.1007,  0.5273, -0.2811,  1.1312, -1.0375, -1.4406],
         [-0.6268, -0.2386,  0.2656,  0.6721, -1.6213,  1.5490],
         [-0.0276, -0.2385, -1.3998,  1.7999, -0.6727,  0.5388]]],
       grad_fn=

In [15]:
multihead

multi_head_attn(
  (Wo): Linear(in_features=6, out_features=6, bias=False)
  (LNorm): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
)

In [16]:
list(multihead.LNorm.parameters())

[Parameter containing:
 tensor([1., 1., 1., 1., 1., 1.], requires_grad=True), Parameter containing:
 tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)]

In [17]:
ctx.mean(dim = 2)

tensor([[-4.9671e-08,  1.9868e-08,  3.9736e-08],
        [ 3.9736e-08, -3.9736e-08,  1.9868e-08],
        [-3.2286e-08,  0.0000e+00, -2.9802e-08],
        [-1.9868e-08,  8.6923e-09,  1.9868e-08],
        [ 1.9868e-08,  0.0000e+00, -9.9341e-09]], grad_fn=<MeanBackward1>)

In [18]:
ctx.std(dim = 2)

tensor([[1.0954, 1.0954, 1.0954],
        [1.0954, 1.0954, 1.0954],
        [1.0954, 1.0954, 1.0954],
        [1.0954, 1.0954, 1.0954],
        [1.0954, 1.0954, 1.0954]], grad_fn=<StdBackward1>)