In [10]:
import math
from shutil import make_archive

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
class Transformer(nn.Module):
    """Transformer model skeleton (work in progress).

    So far it only builds the positional-encoding layer; ``forward`` is not
    implemented yet.

    Args:
        d_model: embedding size handed to ``PositionalEncoding``.
        dropout: dropout probability used by the positional encoding.
        max_len: longest sequence the positional encoding supports.
    """

    def __init__(self, d_model=512, dropout=0.1, max_len=5000):
        super().__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        # PositionalEncoding requires d_model; calling it with no arguments raises
        # TypeError. NOTE: PositionalEncoding is defined in a later cell, so that
        # cell must be run before a Transformer is instantiated.
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout, max_len=max_len)

    def forward(self, x):
        """Not implemented yet; fail loudly instead of silently returning None."""
        raise NotImplementedError('Transformer.forward is not implemented yet')

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first embeddings.

    Uses the encoding from "Attention Is All You Need":

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding size (last dimension of the input).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Register as a buffer (not a Parameter): it follows .to(device)/.cuda()
        # and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + PE[:seq_len])``.

        Args:
            x: tensor of shape (batch, seq_len, d_model), batch-first like
                MultiHeadAttention; seq_len must not exceed ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if seq_len is larger than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len {self.pe.size(0)}'
            )
        # pe[:seq_len] is (seq_len, d_model) and broadcasts over the batch dim.
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    Queries and keys are projected to ``num_heads`` heads of size ``d_k``, values
    to ``num_heads`` heads of size ``d_v``. Each head attends independently; the
    heads are then concatenated and projected back to ``d_model``.

    Args:
        d_model: input/output feature size.
        d_k: per-head query/key size.
        d_v: per-head value size.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, d_k, d_v, num_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_k
        self.d_v = d_v
        self.dropout = nn.Dropout(p=dropout)

        self.Q_matrix = nn.Linear(d_model, d_k * num_heads)
        self.K_matrix = nn.Linear(d_model, d_k * num_heads)
        self.V_matrix = nn.Linear(d_model, d_v * num_heads)
        self.out_matrix = nn.Linear(d_v * num_heads, d_model)

        # Xavier-initialise the projection weights (biases keep nn.Linear's
        # default init). The layers are the *_matrix attributes; there is no
        # self.Q / self.K / self.V / self.out.
        for layer in (self.Q_matrix, self.K_matrix, self.V_matrix, self.out_matrix):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: (n, q_len, d_model) queries.
            k: (n, k_len, d_model) keys.
            v: (n, k_len, d_model) values; must have the same length as ``k``.
            mask: optional (n, q_len, k_len) tensor. Nonzero/True entries mark
                key positions that must NOT be attended to. A query row whose
                keys are all masked produces NaN after the softmax.

        Returns:
            (n, q_len, d_model) tensor.
        """
        n = q.size(0)
        q_len, k_len = q.size(1), k.size(1)
        d_k, d_v = self.d_k, self.d_v
        num_heads = self.num_heads

        # (n, len, d_model) -> (n, num_heads, len, d_head)
        q = self.Q_matrix(q).view(n, -1, num_heads, d_k).transpose(1, 2)
        k = self.K_matrix(k).view(n, -1, num_heads, d_k).transpose(1, 2)
        v = self.V_matrix(v).view(n, -1, num_heads, d_v).transpose(1, 2)

        # (n, num_heads, q_len, k_len), scaled by sqrt(d_k)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
        if mask is not None:
            assert mask.size() == (n, q_len, k_len), (
                f'mask must have shape {(n, q_len, k_len)}, got {tuple(mask.size())}'
            )
            # unsqueeze(1) broadcasts the mask over heads without copying it.
            scores = scores.masked_fill(mask.unsqueeze(1).bool(), float('-inf'))
        attentions = self.dropout(F.softmax(scores, dim=-1))

        # (n, num_heads, q_len, d_v) -> (n, q_len, num_heads * d_v)
        output = torch.matmul(attentions, v)
        output = output.transpose(1, 2).contiguous().reshape(n, -1, d_v * num_heads)
        return self.out_matrix(output)