In [4]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

class MyMultiHeadAttention(nn.Module):
    def __init__(self, d_model:int, n_heads:int, d_k:int, d_v:int, enable_gpu:bool):
        super(MyMultiHeadAttention, self).__init__()
        self.d_model = d_model  # 300
        self.n_heads = n_heads  # 20
        self.d_k = d_k  # 20
        self.d_v = d_v  # 20
        self.enable_gpu = enable_gpu
        self.W_Q = nn.Linear(d_model, d_k * n_heads)  # 300, 400
        self.W_K = nn.Linear(d_model, d_k * n_heads)  # 300, 400
        self.W_V = nn.Linear(d_model, d_v * n_heads)  # 300, 400

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1)

    def forward(self, Q, K=None, V=None, mask=None):
        #       Q, K, V: [bz, seq_len, 300] -> W -> [bz, seq_len, 400]-> q_s: [bz, 20, seq_len, 20]
        if K is None:
            K = Q
        if V is None:
            V = Q
        batch_size, seq_len, _ = Q.shape

        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads,
                               self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads,
                               self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads,
                               self.d_v).transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) #  [bz, seq_len, seq_len]
            mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [bz, 20, seq_len, seq_len]

        context = F.scaled_dot_product_attention(q_s, k_s, v_s, mask)  # [bz, 20, seq_len, 20]
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_heads * self.d_v)  # [bz, seq_len, 400]
        #         output = self.fc(context)
        return context  #self.layer_norm(output + residual)

In [9]:
mh1 = MyMultiHeadAttention(768, 12, 64, 64, True)
input1 = torch.rand((20,40, 768))
output1 = mh1(input1)

In [10]:
output1.shape

torch.Size([20, 40, 768])

In [13]:
mh2 = nn.MultiheadAttention(768, 12)
output2 = mh2(input1, input1, input1)
print(output2[0].shape, output2[1].shape)

torch.Size([20, 40, 768]) torch.Size([40, 20, 20])
