In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np

In [2]:
class ScaledDotProdAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dropout = 0.1):
        super(ScaledDotProdAttention, self).__init__()

        # Must be a module: forward() calls self.dropout(...). Storing the raw
        # float here made every forward pass fail with "'float' object is not callable".
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (..., seq_q, d_k).
            key: Tensor of shape (..., seq_k, d_k).
            value: Tensor of shape (..., seq_k, d_v).
            mask: Optional tensor broadcastable to (..., seq_q, seq_k). Positions
                where ``mask == 0`` receive (effectively) zero attention weight.

        Returns:
            Tuple ``(output, attention)``. ``output`` has shape (..., seq_q, d_v).
            ``attention`` holds the post-dropout weights, of shape (..., seq_q, seq_k).
        """
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))

        if mask is not None:
            # Use the most negative finite value of the scores' dtype: a fixed
            # -1e10 is not representable in float16 and makes masked_fill raise.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention = F.softmax(attention_scores, dim = -1)
        attention = self.dropout(attention)

        return torch.matmul(attention, value), attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention over batch-first inputs.

    The module works in four steps:

    1. Projects query, key and value with separate linear layers.
    2. Splits each projection into ``nhead`` heads of size ``d_model // nhead``.
    3. Runs scaled dot-product attention per head.
    4. Concatenates the heads and applies an output projection followed by dropout.

    Args:
        d_model: Feature dimension of the inputs and of the output. Must be
            divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability. It is used on the attention weights and on
            the projected output.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dropout = 0.1):
        super(MultiHeadAttention, self).__init__()

        if d_model % nhead != 0:
            # Otherwise the head split/concat in forward() silently reshapes to
            # the wrong sizes (or fails with an opaque view error).
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead
        self.d_v = d_model // nhead

        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)

        self.scaled_dot_prod_attention = ScaledDotProdAttention(dropout)

        self.linear_layer = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, head_dim):
        """Reshape (batch, seq, nhead * head_dim) -> (batch, nhead, seq, head_dim)."""
        return x.view(x.size(0), -1, self.nhead, head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask = None, key_padding = None):
        """Run multi-head attention.

        Args:
            query: Tensor of shape (batch, seq_q, d_model).
            key: Tensor of shape (batch, seq_k, d_model).
            value: Tensor of shape (batch, seq_k, d_model).
            mask: Optional mask. Positions where ``mask == 0`` are blocked. A 2-D
                (seq_q, seq_k) or 4-D mask is used as-is (it must broadcast to
                (batch, nhead, seq_q, seq_k)). A 3-D (batch, seq_q, seq_k) mask
                is shared across all heads.
            key_padding: Currently ignored. NOTE(review): never applied; define
                its convention (0 = pad vs True = pad) and fold it into ``mask``.

        Returns:
            Tensor of shape (batch, seq_q, d_model). The attention weights are
            not returned.
        """
        batch_size = query.size(0)

        query = self._split_heads(self.linear_q(query), self.d_k)
        key = self._split_heads(self.linear_k(key), self.d_k)
        value = self._split_heads(self.linear_v(value), self.d_v)

        if mask is not None and mask.dim() == 3:
            # Insert a head axis so a per-batch mask broadcasts over all heads.
            mask = mask.unsqueeze(1)

        # The mask was previously dropped here, so masking had no effect.
        output, _ = self.scaled_dot_prod_attention(query, key, value, mask)

        # (batch, nhead, seq_q, d_v) -> (batch, seq_q, nhead * d_v)
        output_concat = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        output_concat = self.linear_layer(output_concat)

        return self.dropout(output_concat)


In [4]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    The encodings follow "Attention Is All You Need":

    - PE[pos, 2i] = sin(pos / 10000^(2i / d_model))
    - PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Feature dimension of the input. Odd values are supported.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Longest sequence length that can be encoded.
        batch_first: If False (the default, the original behaviour), ``x`` is
            (seq_len, batch, d_model). If True, ``x`` is (batch, seq_len, d_model).
    """

    def __init__(self, d_model, dropout = 0.1, max_len = 100, batch_first = False):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first

        position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column;
        # without this slice the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (max_len, 1, d_model). A buffer (not a Parameter) so it is
        # saved/moved with the module but not updated by the optimizer.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + PE)`` for the first ``seq_len`` positions.

        Raises:
            ValueError: If the sequence is longer than the precomputed ``max_len``.
        """
        seq_len = x.size(1) if self.batch_first else x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)} of PositionalEncoding"
            )

        pe = self.pe[:seq_len]  # (seq_len, 1, d_model): broadcasts over batch
        if self.batch_first:
            pe = pe.transpose(0, 1)  # (1, seq_len, d_model)

        return self.dropout(x + pe)
        