# In [3]:
'''
Example shapes used below (batch size 64, 300-dim embedding per token):
k: [64, 10, 300]   10 key tokens
v: [64, 10, 300]   10 value tokens (same length as k)
q: [64, 12, 300]   12 query tokens
'''
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    q, k and v are each projected by a learned linear layer. The hidden
    dimension is split into ``n_heads`` heads of size ``hid_dim // n_heads``.
    Each head is attended independently, then the heads are concatenated and
    passed through a final linear layer ``fc``.

    Args:
        hid_dim: size of the last dimension of q/k/v and of the output.
            Must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``hid_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hid_dim, n_heads, dropout=0.1):
        super().__init__()
        # Each head must get the same whole number of dims.
        if hid_dim % n_heads != 0:
            raise ValueError(
                f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        # Input projections for query, key and value.
        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)

        # Output projection applied after the heads are concatenated.
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)

        # sqrt(head_dim) is registered as a buffer so that .to(device),
        # .cuda(), .half(), etc. move and cast it together with the weights.
        # A plain tensor attribute would stay a float32 CPU tensor and break
        # forward() on a GPU. persistent=False keeps it out of state_dict,
        # so checkpoints are unaffected.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(self.head_dim)], dtype=torch.float32)),
            persistent=False,
        )

    def _split_heads(self, x, bsz):
        """Reshape (bsz, seq_len, hid_dim) -> (bsz, n_heads, seq_len, head_dim)."""
        return x.view(bsz, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, (bsz, q_len, hid_dim).
            k: keys, (bsz, kv_len, hid_dim).
            v: values, (bsz, kv_len, hid_dim); same length as ``k``.
            mask: optional tensor broadcastable to
                (bsz, n_heads, q_len, kv_len). Positions where
                ``mask == 0`` (e.g. padding) receive ~0 attention weight.

        Returns:
            Tensor of shape (bsz, q_len, hid_dim).
        """
        bsz = q.shape[0]

        # Project, then split the hidden dim into heads, e.g. for K:
        # [64, 10, 300] -> [64, 10, 6, 50] -> [64, 6, 10, 50].
        # Putting the head axis before the sequence axis lets one batched
        # matmul handle all heads at once.
        Q = self._split_heads(self.w_q(q), bsz)
        K = self._split_heads(self.w_k(k), bsz)
        V = self._split_heads(self.w_v(v), bsz)

        # Step 1: scaled scores Q @ K^T / sqrt(head_dim).
        # [64, 6, 12, 50] @ [64, 6, 50, 10] -> [64, 6, 12, 10]
        attention = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Masked positions get the most negative value representable in the
        # scores' dtype, so softmax gives them ~0 weight. A fixed -1e10
        # overflows (and raises) when the module runs in float16.
        if mask is not None:
            attention = attention.masked_fill(
                mask == 0, torch.finfo(attention.dtype).min
            )

        # Step 2: normalise over the key axis, then apply dropout to the weights.
        attention = self.do(torch.softmax(attention, dim=-1))

        # Step 3: weighted sum of the values.
        # [64, 6, 12, 10] @ [64, 6, 10, 50] -> [64, 6, 12, 50]
        x = torch.matmul(attention, V)

        # Merge the heads back: [64, 6, 12, 50] -> [64, 12, 6, 50] -> [64, 12, 300]
        x = x.permute(0, 2, 1, 3).contiguous().view(bsz, -1, self.hid_dim)
        return self.fc(x)

if __name__ == "__main__":
    # Demo: 12 query tokens attend over 10 key/value tokens, with batch
    # size 64 and hidden size 300 split across 6 heads of 50 dims each.
    # Guarded so that importing this module does not run the demo.
    query = torch.rand(64, 12, 300)
    key = torch.rand(64, 10, 300)
    value = torch.rand(64, 10, 300)

    attention = MultiHeadAttention(300, 6, 0.1)
    out = attention(query, key, value)
    print(out.shape)  # expected: torch.Size([64, 12, 300])

# Output: torch.Size([64, 12, 300])
