In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

import torch.nn.functional as F

In [29]:
class Projector(nn.Module):
    """
    Making projection matrix(Q, K, V) for each attention head
    When you call this class, it returns projection matrix of each attention head
    For example, if you call this class with 8 heads, it returns 8 set of projection matrices (Q, K, V)
    Args:
        num_heads: number of heads in MHA, default 8
        dim_head: dimension of each attention head, default 64
    """
    def __init__(self, num_heads: int = 8, dim_head: int = 64) -> None:
        super(Projector, self).__init__()
        self.dim_model = num_heads * dim_head
        self.num_heads = num_heads
        self.dim_head = dim_head

    def __call__(self):
        fc_q = nn.Linear(self.dim_model, self.dim_head)
        fc_k = nn.Linear(self.dim_model, self.dim_head)
        fc_v = nn.Linear(self.dim_model, self.dim_head)
        return fc_q, fc_k, fc_v


class MultiHeadAttention(nn.Module):
    """
    Class for multi-head attention (MHA) module in vanilla transformer
    We apply linear transformation to input vector by each attention head's projection matrix (8, 512, 64)
    Other approaches are possible, such as using one projection matrix for all attention heads (1, 512, 512)
    and then split into each attention heads (8. 512, 64)
    Args:
        dim_model: dimension of model's latent vector space, default 512 from official paper
        num_heads: number of heads in MHA, default 8 from official paper
        dropout: dropout rate, default 0.1
    Math:
        MHA(Q, K, V) = Concat(Head1, Head2, ... Head8) * W_concat
    Reference:
        https://arxiv.org/abs/1706.03762
    """
    def __init__(self, dim_model: int = 512, num_heads: int = 8, dropout: float = 0.1) -> None:
        super(MultiHeadAttention, self).__init__()
        self.dim = dim_model
        self.num_heads = num_heads
        self.dropout = dropout
        self.dim_head = int(self.dim / self.num_heads)  # dimension of each attention head
        self.dot_scale = torch.sqrt(torch.tensor(self.dim_head))  # scale factor for Q•K^T Result

        # linear combination: projection matrix(Q_1, K_1, V_1, ... Q_n, K_n, V_n) for each attention head
        self.projector = Projector(self.num_heads, self.dim_head)  # init instance
        self.projector_list = [list(self.projector()) for _ in range(self.num_heads)]  # call instance
        self.fc_concat = nn.Linear(self.dim, self.dim)  # for concatenation of each attention head

    def forward(self, x: torch.Tensor, mask: bool = None) -> torch.Tensor:
        """
        1) make Q, K, V matrix for each attention head: [BS, HEAD, SEQ_LEN, DIM_HEAD], ex) [10, 8, 512, 64]
        2) Do self-attention in each attention head
            - Matmul (Q, K^T) with scale factor (sqrt(DIM_HEAD))
            - Mask for padding token (Option for Decoder)
            - Softmax
            - Matmul (Softmax, V)
        3) Concatenate each attention head & linear transformation (512, 512)
        """
        # 1) make Q, K, V matrix for each attention head
        Q, K, V = [], [], []

        for i in range(self.num_heads):
            Q.append(self.projector_list[i][0](x))
            K.append(self.projector_list[i][1](x))
            V.append(self.projector_list[i][2](x))

        Q = torch.stack(Q, dim=1)
        K = torch.stack(K, dim=1)
        V = torch.stack(V, dim=1)
        # 2) Do self-attention in each attention head
        attention_score = torch.matmul(Q, K.transpose(-1, -2)) / self.dot_scale
        if mask is not None:  # for padding token
            attention_score[mask] = float('-inf')
        attention_dist = F.softmax(attention_score, dim=-1)  # [BS, HEAD, SEQ_LEN, SEQ_LEN]
        attention_matrix = torch.matmul(attention_dist, V).transpose(1, 2).reshape(x.shape[0], x.shape[1], self.dim)  # [BS, SEQ_LEN, DIM]

        # 3) Concatenate each attention head & linear transformation (512, 512)
        x = self.fc_concat(attention_matrix)
        return x

In [30]:
""" Debug for MultiHeadAttention """

x = torch.randn(10, 512, 512)
test_head = MultiHeadAttention()
test_result = test_head(x)
test_result, test_result.shape

(tensor([[[-0.0150,  0.0380, -0.0050,  ..., -0.0405, -0.0748, -0.0368],
          [-0.0525,  0.0485, -0.0082,  ..., -0.0480, -0.0552, -0.0323],
          [-0.0359,  0.0470, -0.0003,  ..., -0.0431, -0.0615, -0.0242],
          ...,
          [-0.0433,  0.0437, -0.0037,  ..., -0.0463, -0.0659, -0.0287],
          [-0.0339,  0.0363, -0.0071,  ..., -0.0500, -0.0526, -0.0325],
          [-0.0377,  0.0323, -0.0083,  ..., -0.0482, -0.0634, -0.0343]],
 
         [[-0.0337,  0.0676, -0.0072,  ..., -0.0437, -0.0553, -0.0332],
          [-0.0287,  0.0731, -0.0007,  ..., -0.0512, -0.0519, -0.0123],
          [-0.0415,  0.0732, -0.0023,  ..., -0.0372, -0.0561, -0.0157],
          ...,
          [-0.0395,  0.0665, -0.0115,  ..., -0.0480, -0.0525, -0.0363],
          [-0.0365,  0.0787, -0.0015,  ..., -0.0318, -0.0453, -0.0211],
          [-0.0487,  0.0730, -0.0136,  ..., -0.0303, -0.0521, -0.0293]],
 
         [[-0.0191,  0.0416, -0.0245,  ..., -0.0369, -0.0279, -0.0012],
          [-0.0197,  0.0395,