In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Hyperparameters shared by the cells below.
batch_size = 16                # sequences per batch
sequence_length = 512          # tokens per sequence
embedding_dimension = 300      # features per input token
number_heads = 8               # parallel attention heads
head_dimension = 32            # features per attention head
# Attention scores are multiplied by 1/sqrt(head_dimension) before the softmax.
scaling_factor = head_dimension ** -0.5

In [4]:
# Uniform-random stand-in for a batch of embedded tokens, laid out as
# (batch_size, sequence_length, embedding_dimension).
input_data = torch.rand(batch_size, sequence_length, embedding_dimension)
print(input_data.shape)

torch.Size([16, 512, 300])


In [15]:
# Query / key / value projections. Each maps an embedding vector to the
# concatenated feature space of all heads (number_heads * head_dimension).
query_matrix = nn.Linear(
    in_features=embedding_dimension,
    out_features=number_heads * head_dimension,
)
key_matrix = nn.Linear(
    in_features=embedding_dimension,
    out_features=number_heads * head_dimension,
)
value_matrix = nn.Linear(
    in_features=embedding_dimension,
    out_features=number_heads * head_dimension,
)

In [25]:
# Step-by-step causal multi-head attention on `input_data`.

# 1. Project the input into query / key / value space.
query = query_matrix(input_data)
key = key_matrix(input_data)
value = value_matrix(input_data)

# 2. Split the projected features into heads and move the head axis in front
#    of the sequence axis: (batch, seq, heads * dim) -> (batch, heads, seq, dim).
q = query.reshape(batch_size, sequence_length, number_heads, head_dimension).transpose(1, 2)
k = key.reshape(batch_size, sequence_length, number_heads, head_dimension).transpose(1, 2)
v = value.reshape(batch_size, sequence_length, number_heads, head_dimension).transpose(1, 2)
print(q.shape)

# 3. Scaled dot-product scores between every pair of positions, per head.
qk = q @ k.transpose(-1, -2) * scaling_factor
print(qk.shape)

# 4. Causal mask: True strictly above the diagonal, i.e. at (row, col) with
#    col > row. Those scores become -inf so softmax gives them zero weight.
mask = torch.triu(
    torch.ones((qk.shape[-1], qk.shape[-1]), dtype=torch.bool),
    diagonal=1,
)
qk_masked = qk.masked_fill(mask, float('-inf'))
qk_sofmax = F.softmax(qk_masked, dim=-1)
print(qk_sofmax.shape)

# 5. Weighted sum of the values, then merge the heads back into one feature
#    axis: (batch, heads, seq, dim) -> (batch, seq, heads * dim).
new_value = (qk_sofmax @ v).transpose(1, 2).reshape(batch_size, sequence_length, -1)
print(new_value.shape)

torch.Size([16, 8, 512, 32])
torch.Size([16, 8, 512, 512])
torch.Size([16, 8, 512, 512])
torch.Size([16, 512, 256])


In [46]:
class MultiHeadsAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is linearly projected to queries, keys and values, split into
    ``heads`` heads of ``head_dimension`` features each, combined with scaled
    dot-product attention under a causal mask (a position only attends to
    itself and earlier positions), and finally projected back to
    ``embedding_dimension`` features.

    Args:
        embedding_dimension: Size of the last axis of the input and output.
        heads: Number of attention heads.
        head_dimension: Number of features per head.
    """

    def __init__(self,embedding_dimension,heads=8,head_dimension=32):
        super().__init__()
        self.heads = heads
        self.head_dimension = head_dimension
        self.embedding_dimension = embedding_dimension
        # Scores are multiplied by 1/sqrt(head_dimension) before the softmax.
        self.scaling_factor = self.head_dimension**-0.5

        inner_dimension = self.heads * self.head_dimension
        self.query_proj = nn.Linear(self.embedding_dimension, inner_dimension)
        self.key_proj = nn.Linear(self.embedding_dimension, inner_dimension)
        self.value_proj = nn.Linear(self.embedding_dimension, inner_dimension)
        self.output_proj = nn.Linear(inner_dimension, self.embedding_dimension)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            input: Tensor of shape (batch, sequence, embedding_dimension).

        Returns:
            Tensor of shape (batch, sequence, embedding_dimension).
        """
        # Take both sizes from the tensor itself so the module works for any
        # batch size / sequence length and does not depend on notebook globals.
        batch_size, sq_length, _ = input.shape

        query = self.query_proj(input)
        key = self.key_proj(input)
        value = self.value_proj(input)

        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        q = query.reshape(batch_size, sq_length, self.heads, self.head_dimension).transpose(2, 1)
        k = key.reshape(batch_size, sq_length, self.heads, self.head_dimension).transpose(2, 1)
        v = value.reshape(batch_size, sq_length, self.heads, self.head_dimension).transpose(2, 1)

        # (batch, heads, seq, seq) scaled dot-product scores.
        qk = q @ k.transpose(-1, -2) * self.scaling_factor

        ##### position embedding here
        # qk = qk + relative_position_embedding

        # Causal mask: True strictly above the diagonal (future positions).
        # Built on the same device as the scores so masked_fill also works
        # when the module lives on an accelerator.
        mask = torch.ones(
            (sq_length, sq_length), dtype=torch.bool, device=qk.device
        ).triu(1)
        qk_masked = qk.masked_fill(mask, float('-inf'))
        qk_softmax = F.softmax(qk_masked, dim=-1)

        # Weighted sum of values, then merge heads:
        # (batch, heads, seq, dim) -> (batch, seq, heads * dim).
        new_value = qk_softmax @ v
        new_value = new_value.transpose(1, 2).reshape(batch_size, sq_length, -1)

        output = self.output_proj(new_value)

        return output

In [5]:
# Smoke test: run the attention module on a random batch of 8 sequences of
# 512 tokens with 300 features each and show the output shape.
# This cell must run after the cell that defines MultiHeadsAttention.
mha = MultiHeadsAttention(embedding_dimension=300)
# NOTE(review): `input` shadows Python's builtin of the same name for the rest
# of the notebook session; consider renaming if nothing else relies on it.
input = torch.rand((8, 512, 300))
print(mha(input).shape)

NameError: name 'MultiHeadsAttention' is not defined