In [12]:
import torch 
import numpy as np
import matplotlib.pyplot as plt

In [13]:
from torch import nn
from torch.nn import functional as F

In [14]:
# self attention
class self_attention(nn.Module):
    '''
    Module to apply self attention to an input sequence of vectors
    
    parameters:
    
    emb_dim = dimension of the embedding vector
    h = number of self attention heads
    
    '''
    def __init__(self, emb_dim, h):
        super().__init__()
        self.emb_dim = emb_dim
        self.h = h
        self.red_vec_size = emb_dim//h
        
        # Querry vector
        self.WQ = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        self.WK = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        self.WV = nn.Linear(emb_dim, self.red_vec_size, bias = False)
        
    def forward(self, x):
        # x has shape (batch_size, seq_len, emb_dim)
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        querries = self.WQ(x)
        keys = self.WK(x)
        values = self.WV(x)
        att_scores = F.softmax((querries@keys.permute(0,2,1)).permute(0,2,1)\
                               /torch.sqrt(self.red_vec_size), dim = 2)
        ctx_vecs = att_scores @ values 
        assert ctx_vecs.shape == (batch_size, seq_len, self.red_vec_size ) 
        return querries, keys, values, ctx_vecs

In [15]:
batch_size = 5
seq_len = 3
emb_dim = 512
h = 8
x = torch.randn((batch_size, seq_len, emb_dim))
attn = self_attention(emb_dim, h)

In [16]:
attn

self_attention(
  (WQ): Linear(in_features=512, out_features=64, bias=False)
  (WK): Linear(in_features=512, out_features=64, bias=False)
  (WV): Linear(in_features=512, out_features=64, bias=False)
)

In [17]:
querries, keys, values, ctx_vecs = attn(x)

TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not int

In [8]:
querries.shape, keys.shape, values.shape, ctx_vecs.shape

(torch.Size([5, 3, 64]),
 torch.Size([5, 3, 64]),
 torch.Size([5, 3, 64]),
 torch.Size([5, 3, 64]))

In [11]:
sqrt(2)

NameError: name 'sqrt' is not defined