In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# Model Architecture

## Multi-Head Self-Attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into ``h`` heads of size ``d_k = d_model // h``.
    Scaled dot-product attention is applied independently per head. The
    heads are then concatenated back to ``d_model`` and mixed by a final
    linear layer.

    Args:
        h: Number of attention heads.
        d_model: Hidden size of the model; must be divisible by ``h``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, h, d_model):
        super().__init__()

        # Each head receives an equal slice of the hidden dimension.
        if d_model % h != 0:
            raise ValueError('d_model must be divisible by h')

        # Hyperparameters of the multi-head attention.
        self.h = h
        self.d_model = d_model
        self.d_k = d_model // h

        # Query/key/value projections for all heads at once (d_model -> h * d_k).
        self.w_Q = nn.Linear(d_model, d_model)
        self.w_K = nn.Linear(d_model, d_model)
        self.w_V = nn.Linear(d_model, d_model)

        # Output projection applied to the concatenated heads.
        self.w_O = nn.Linear(d_model, d_model)

    def self_attention(self, Q, K, V, mask=None, dropout=None):
        """Compute scaled dot-product attention ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            Q, K, V: Query/key/value tensors whose last dimension is the
                per-head size; ``forward`` passes tensors shaped
                (batch, h, seq_len, d_k).
            mask: Optional tensor broadcastable to the score shape
                (..., seq_len_q, seq_len_k). Positions where ``mask == 0``
                receive (effectively) zero attention weight.
            dropout: Optional dropout probability applied to the attention
                weights. It is only active while the module is in training
                mode.

        Returns:
            Tuple ``(output, attn)``. ``attn`` holds the softmax-normalised
            attention weights and ``output = attn @ V``.
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Use the most negative finite value for this dtype rather than
            # -inf or a hard-coded constant. Masked positions get ~0 weight
            # after softmax, fp16 does not overflow, and fully-masked rows
            # become uniform instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalise over the key dimension.
        attn = torch.softmax(scores, dim=-1)

        if dropout:
            # Functional dropout honours self.training, so .eval() disables it
            # (a freshly constructed nn.Dropout would always be in training mode).
            attn = nn.functional.dropout(attn, p=dropout, training=self.training)

        return torch.matmul(attn, V), attn

    def forward(self, input, mask=None, dropout=None):
        """Apply multi-head self-attention to ``input``.

        Args:
            input: Tensor of shape (batch, seq_len, d_model).
            mask: Optional mask forwarded to ``self_attention``; it must be
                broadcastable to (batch, h, seq_len, seq_len).
            dropout: Optional dropout probability on the attention weights.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = input.size()

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            return t.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)

        # Project the input, then split each projection into h heads.
        Q = split_heads(self.w_Q(input))
        K = split_heads(self.w_K(input))
        V = split_heads(self.w_V(input))

        x, _ = self.self_attention(Q, K, V, mask, dropout)

        # Concatenate heads back to (batch, seq, d_model) before the output projection.
        x = x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_O(x)
