In [2]:
import numpy as np
import torch

## Multi-head Attention

In [3]:
def scaled_dot_product_attention(q, k, v, mask) :
    """Batched scaled dot-product attention.

    Evaluates ``softmax(q @ k^T / sqrt(d_k)) @ v`` using ``torch.bmm``, where
    ``d_k`` is the size of the last dimension of ``k``.

    Args:
        q: query tensor; batch-multiplied with ``k`` transposed on dims 1 and 2.
        k: key tensor.
        v: value tensor, combined using the attention distribution.
        mask: tensor handed to ``masked_fill_`` (positions where it is true
            are excluded from attention), or ``None`` for no masking.

    Returns:
        Tuple ``(output, attention_weights)``; the weights are normalised
        over the last dimension.
    """
    key_dim = k.size(-1)
    logits = torch.bmm(q, k.transpose(1, 2)) / np.sqrt(key_dim)

    if mask is not None :
        # A very large negative logit drives the softmax weight of masked
        # positions to zero. Done in place on the freshly computed logits.
        logits.masked_fill_(mask, -1e9)

    weights = torch.nn.functional.softmax(logits, dim=-1)
    return torch.bmm(weights, v), weights

In [4]:
class MultiHeadAttention(torch.nn.Module) :
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Queries, keys and values are each projected by their own linear layer,
    split into ``num_heads`` slices of width ``depth = d_model // num_heads``,
    attended independently per head, re-assembled and passed through a final
    linear output projection.

    Args:
        d_model: width of the input and output features.
        num_heads: number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model=512, num_heads=8) :
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        # Separate learned projections for queries, keys and values.
        self.wq = torch.nn.Linear(d_model, d_model, bias=True)
        self.wk = torch.nn.Linear(d_model, d_model, bias=True)
        self.wv = torch.nn.Linear(d_model, d_model, bias=True)

        # Output projection: mixes the concatenated per-head results back
        # into a single d_model-wide representation.
        self.linear = torch.nn.Linear(d_model, d_model, bias=True)

    def _split_heads(self, x, batch_size) :
        """Reshape (batch, seq, d_model) to (num_heads * batch, seq, depth), head-major."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.depth)

    def forward(self, q, k, v, mask=None) :
        """Apply multi-head attention.

        Args:
            q: query tensor, viewed as (batch_size, q_len, d_model).
            k: key tensor, viewed as (batch_size, k_len, d_model).
            v: value tensor; its first dimension defines ``batch_size``.
            mask: optional 3-D mask with ``batch_size`` as its first
                dimension; it is tiled ``num_heads`` times along dim 0 so the
                same mask is applied to every head.

        Returns:
            Tuple ``(output, attention_weights)`` where ``output`` is
            (batch_size, q_len, d_model) and ``attention_weights`` is
            (num_heads * batch_size, q_len, k_len) in head-major order.
        """
        batch_size = v.size(0)

        # Each input goes through its own projection (keys via wk, values
        # via wv) before being split into heads.
        q = self._split_heads(self.wq(q), batch_size)
        k = self._split_heads(self.wk(k), batch_size)
        v = self._split_heads(self.wv(v), batch_size)

        if mask is not None :
            # Tiling along dim 0 matches the head-major layout produced by
            # _split_heads.
            mask = mask.repeat(self.num_heads, 1, 1)

        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Undo the head split: (heads * batch, q_len, depth) -> (batch, q_len, d_model).
        scaled_attention = scaled_attention.view(self.num_heads, batch_size, -1, self.depth)
        scaled_attention = scaled_attention.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.d_model)

        output = self.linear(scaled_attention)

        return output, attention_weights

In [5]:
# Smoke test: run the layer as self-attention by passing the same random
# tensor as query, key and value. No seed is set in this cell, so the values
# shown below change from run to run.
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = torch.rand((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, y, y, mask=None)

# Show the output shape and the attention-weight shape, then the raw input.
display(out.shape, attn.shape)
display(y)
# Bare last expression: rendered as this cell's output.
out

torch.Size([1, 60, 512])

torch.Size([8, 60, 60])

tensor([[[0.4955, 0.2698, 0.0228,  ..., 0.3468, 0.9099, 0.3282],
         [0.3973, 0.7943, 0.4576,  ..., 0.6460, 0.6003, 0.4776],
         [0.8922, 0.0058, 0.7936,  ..., 0.4749, 0.1825, 0.7920],
         ...,
         [0.4431, 0.0522, 0.9016,  ..., 0.9424, 0.2426, 0.5663],
         [0.5014, 0.5712, 0.6320,  ..., 0.9312, 0.5130, 0.4896],
         [0.4478, 0.1749, 0.9620,  ..., 0.9004, 0.7717, 0.8591]]])

tensor([[[-0.0054, -0.1599, -0.0259,  ..., -0.1342,  0.3016,  0.1085],
         [-0.0043, -0.1593, -0.0252,  ..., -0.1349,  0.3035,  0.1090],
         [-0.0065, -0.1619, -0.0244,  ..., -0.1344,  0.3040,  0.1080],
         ...,
         [-0.0054, -0.1629, -0.0252,  ..., -0.1342,  0.3032,  0.1081],
         [-0.0041, -0.1610, -0.0249,  ..., -0.1353,  0.3033,  0.1094],
         [-0.0052, -0.1614, -0.0252,  ..., -0.1342,  0.3028,  0.1093]]],
       grad_fn=<AddBackward0>)

## Decoder

In [16]:
class DecoderLayer(torch.nn.Module) :
    """Single decoder block: embedding -> LSTM -> attention over encoder output -> classifier.

    Args:
        num_classes: vocabulary size; used for both the embedding table and
            the width of the output distribution.
        d_model: feature width used throughout the layer.
        num_heads: number of heads for the attention sub-layer.
        dropout_p: dropout probability applied to the embedded input.
    """

    def __init__(self, num_classes, d_model=1024, num_heads=4, dropout_p=0.3):
        super().__init__()
        self.d_model = d_model

        self.embedding = torch.nn.Embedding(num_classes, d_model)
        self.input_dropout = torch.nn.Dropout(dropout_p)

        # nn.LSTM only applies its `dropout` between stacked layers; with a
        # single layer a non-zero value does nothing except emit a warning,
        # so it is left at 0 here.
        self.uniDirLSTM = torch.nn.LSTM(input_size=d_model, hidden_size=d_model, num_layers=1, bias=True, batch_first=True, dropout=0.0, bidirectional=False)

        self.mha = MultiHeadAttention(d_model, num_heads)

        self.layernorm1 = torch.nn.LayerNorm(d_model, eps=1e-6)
        self.layernorm2 = torch.nn.LayerNorm(d_model, eps=1e-6)

        self.linear1 = torch.nn.Linear(d_model, d_model, bias=True)
        self.linear2 = torch.nn.Linear(d_model, num_classes, bias=False)

    def forward(self, input_var, hidden, enc_output, training) :
        """Decode one or more target steps.

        Args:
            input_var: integer token indices, (batch_size, target_seq_len).
            hidden: LSTM state from the previous call, or ``None`` to start
                from the default initial state.
            enc_output: encoder features, (batch_size, input_seq_len, d_model).
            training: when true, ``flatten_parameters()`` is called on the
                LSTM before it runs.

        Returns:
            Tuple ``(output, hidden, attn_weights_block)``. ``output`` holds
            log-probabilities of shape (batch_size, target_seq_len,
            num_classes); the sequence dimension is squeezed away when
            ``target_seq_len == 1``. ``hidden`` is the updated LSTM state and
            ``attn_weights_block`` the attention weights from the
            attention sub-layer.
        """
        embedded = self.input_dropout(self.embedding(input_var))

        if training :
            self.uniDirLSTM.flatten_parameters()

        out1, hidden = self.uniDirLSTM(embedded, hidden)  # (batch_size, target_seq_len, d_model)

        # Attend from the decoder states (queries) to the encoder output
        # (keys and values), then add the residual and normalise.
        context, attn_weights_block = self.mha(out1, enc_output, enc_output)
        out2 = self.layernorm1(context + out1)

        # Position-wise projection with a second residual connection.
        out_proj = self.linear1(out2)
        features = self.layernorm2(out_proj + out2)

        # Every op here acts on the last dimension only, so the
        # (batch_size, target_seq_len, ...) layout is kept throughout.
        logits = self.linear2(torch.tanh(features))
        output = torch.nn.functional.log_softmax(logits, dim=-1)

        # Single-step decoding returns (batch_size, num_classes).
        output = output.squeeze(1)

        return output, hidden, attn_weights_block

In [20]:
class Decoder(torch.nn.Module) :
    """Greedy step-by-step decoder around a single ``DecoderLayer``.

    Args:
        num_classes: vocabulary size forwarded to the decoder layer.
        max_length: accepted for interface compatibility; not used by this
            class (the number of steps is taken from ``inputs`` instead).
        d_model: feature width forwarded to the decoder layer.
        num_heads: attention head count forwarded to the decoder layer.
        num_layers: stored on the instance; only one ``DecoderLayer`` is
            actually built.
        dropout_p: dropout probability forwarded to the decoder layer.
    """

    def __init__(self, num_classes, max_length=150, d_model=1024, num_heads=4, num_layers=2, dropout_p=0.3):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        # Keyword arguments keep each value bound to the intended
        # DecoderLayer parameter regardless of its positional order.
        self.dec_layer = DecoderLayer(
            num_classes=num_classes,
            d_model=d_model,
            num_heads=num_heads,
            dropout_p=dropout_p,
        )

    def forward(self, inputs, enc_outputs, training) :
        """Decode ``inputs.size(1) - 1`` steps, feeding each prediction back in.

        Args:
            inputs: integer token indices, (batch_size, seq_len); column 0 is
                used as the first decoder input.
            enc_outputs: encoder features passed to every decoding step.
            training: forwarded unchanged to the decoder layer.

        Returns:
            List with one entry per decoding step, each the step output
            returned by the decoder layer.
        """
        # Both tensors are dereferenced below, so both are required.
        assert inputs is not None and enc_outputs is not None, "inputs and enc_outputs are both required"

        hidden = None
        result = []

        max_lengths = inputs.size(1) - 1 # minus the start of sequence symbol

        input_var = inputs[:, 0].unsqueeze(1)

        for _step in range(max_lengths):
            # The layer also returns its attention weights, unused here.
            step_output, hidden, _attn = self.dec_layer(input_var, hidden, enc_outputs, training)
            result.append(step_output)
            # Greedy choice: index of the highest-scoring class becomes the
            # next input token.
            input_var = step_output.topk(1)[1]

        return result