In [9]:
import os
import pandas as pd
from os.path import exists
import torch
import torch.nn as nn
import copy
import math

In [2]:
def CloneModule(module, n):
    '''
    Build an nn.ModuleList containing n independent deep copies of `module`.

    Every copy owns its own parameters, initialised to the source module's
    current values, so the copies can diverge during training. The source
    module itself is not placed in the list.
    '''
    copies = []
    for _ in range(n):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

In [3]:
# TODO: Change the param from layer size to something meaningful.
class LayerNorm(nn.Module):
    '''
    Layer normalisation over the last dimension with a learned scale/shift.

    Computes gamma * (X - mean) / (std + eps) + beta, where mean and std are
    taken over the last axis. Note two differences from torch.nn.LayerNorm:
    `Tensor.std` defaults to the Bessel-corrected (unbiased) estimator, and
    eps is added to the standard deviation rather than inside a square root.

    Args:
        layer_size: shape of the learned `gamma`/`beta` parameters; it must
            broadcast against the last dimension of the input.
        eps: small constant added to the std to avoid division by zero.
    '''

    def __init__(self, layer_size, eps=1e-5):
        super().__init__()
        # Registered as nn.Parameters, so they show up in .parameters() and
        # the state dict under the names 'gamma' and 'beta'.
        self.gamma = nn.Parameter(torch.ones(layer_size))
        self.beta = nn.Parameter(torch.zeros(layer_size))
        self.eps = eps

    def forward(self, X):
        centered = X - X.mean(-1, keepdim=True)
        spread = X.std(-1, keepdim=True) + self.eps
        return self.gamma * centered / spread + self.beta

In [4]:
class ResidualandNormalSublayer(nn.Module):
    '''
    Pre-norm residual wrapper around an arbitrary sub-layer.

    forward(X, sublayer) returns X + Dropout(sublayer(LayerNorm(X))).
    The original Transformer paper uses post-norm, LayerNorm(x + Sublayer(x)).
    This module instead normalises the sub-layer's input and leaves the
    residual path untouched, for code simplicity.

    Args:
        size: feature size handed to the internal LayerNorm.
        dropout: dropout probability applied to the sub-layer output before
            the residual addition.
    '''

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, sublayer):
        '''Add the dropped-out sub-layer output of normalised X back onto X.

        `sublayer` must be a callable; its output must have the same shape
        as X so the residual addition is valid.
        '''
        normed = self.norm(X)
        update = self.dropout(sublayer(normed))
        return X + update

In [5]:
class EncoderLayer(nn.Module):
    '''
    One encoder layer: self-attention followed by a position-wise feed-forward
    network, each wrapped in a pre-norm residual sub-layer.

    Args:
        size: model/feature size; also exposed as `self.size` so that a
            containing stack can size its final LayerNorm.
        self_attn: module called as self_attn(query, key, value, mask).
        feed_forward: module/callable applied to the normalised input.
        dropout: dropout probability for the residual sub-layers.
    '''

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        # Bug fix: this previously read `self.feed_forward`, an attribute
        # that was never set, so construction raised AttributeError.
        self.feed_forward = feed_forward
        # The add & norm wrappers around the attention / feed-forward blocks.
        self.sublayer = CloneModule(ResidualandNormalSublayer(size, dropout), 2)
        self.size = size

    def forward(self, X, mask):
        '''Run self-attention then feed-forward, each with a residual connection.'''
        # The residual wrapper calls `sublayer(norm(X))`, so it must receive a
        # callable. Bug fix: passing a precomputed attention tensor here failed
        # with TypeError (and would have bypassed the pre-norm anyway).
        X = self.sublayer[0](X, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](X, self.feed_forward)


In [6]:
class Encoder(nn.Module):
    '''
    Encoder made of a stack of n EncoderLayers, followed by a final LayerNorm.

    Args:
        layer: prototype layer, deep-copied n times. It must expose `.size`,
            which is used to size the final LayerNorm.
        n: number of stacked layers.
    '''

    def __init__(self, layer, n):
        super().__init__()
        self.layers = CloneModule(layer, n)
        self.norm = LayerNorm(layer.size)

    def forward(self, X, mask):
        '''Feed X (and the same mask) through every layer in order, then normalise.'''
        hidden = X
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, mask)
        return self.norm(hidden)

In [7]:
class DecoderLayer(nn.Module):
    '''
    One decoder layer: masked self-attention, attention over the encoder
    output (`memory`), then a position-wise feed-forward network. Each of the
    three is wrapped in a pre-norm residual sub-layer.

    Args:
        size: model/feature size; exposed as `self.size` for the containing stack.
        self_attn: module called as self_attn(query, key, value, mask).
        src_attn: module called as src_attn(query, key, value, mask), with
            key/value taken from the encoder memory.
        feed_forward: module/callable applied to the normalised input.
        dropout: dropout probability for the residual sub-layers.
    '''

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = CloneModule(ResidualandNormalSublayer(size, dropout), 3)

    def forward(self, X, memory, src_mask, tgt_mask):
        '''Apply self-attention, source attention and feed-forward with residuals.'''
        m = memory
        # The residual wrapper calls `sublayer(norm(X))`, so each step must be
        # a callable. Bug fix: tensors were previously passed here, which
        # raised TypeError (and would have skipped the pre-norm anyway).
        X = self.sublayer[0](X, lambda x: self.self_attn(x, x, x, tgt_mask))
        # Queries come from the decoder stream; keys/values come from memory.
        X = self.sublayer[1](X, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](X, self.feed_forward)

In [8]:
class Decoder(nn.Module):
    '''
    Decoder made of a stack of n DecoderLayers, followed by a final LayerNorm.

    Args:
        layer: prototype layer, deep-copied n times. It must expose `.size`,
            which is used to size the final LayerNorm.
        n: number of stacked layers.
    '''

    def __init__(self, layer, n):
        super().__init__()
        self.layers = CloneModule(layer, n)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        '''Run x through each layer with the shared memory and masks, then normalise.'''
        hidden = x
        for dec_layer in self.layers:
            hidden = dec_layer(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In [10]:
def Attention(query, key, value, mask=None, dropout=None):
    '''
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query, key, value: tensors whose last two dims are (seq, features);
            d_k is taken from query's last dimension.
        mask: optional tensor broadcastable to the score matrix. Positions
            where mask == 0 get a score of -1e9 before the softmax, which
            effectively zeroes their weight.
        dropout: optional callable applied to the attention weights.

    Returns:
        (output, weights): the weighted sum of `value`, and the attention
        weights after dropout (if any).
    '''
    scale = math.sqrt(query.size(-1))
    logits = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

In [None]:
class MultiHeadAttention(nn.Module):
    '''
    Multi-head attention built on `Attention`.

    Args:
        h: number of heads; must divide d_model.
        d_model: model width. Each head works on d_k = d_model // h features.
        dropout: dropout probability applied to the attention weights.

    After each forward call, `self.attn` holds the attention weights that
    `Attention` returned for that call.
    '''

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Four d_model -> d_model projections: query, key, value, and output.
        self.linears = CloneModule(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p = dropout)

    def forward(self, query, key, value, mask = None):
        '''Project, split into heads, attend, merge heads, and project back.'''
        if mask is not None:
            # Insert a head axis so one mask broadcasts over all h heads.
            # Presumably the mask is (batch, q_len or 1, k_len) -- confirm with callers.
            mask = mask.unsqueeze(1)
        batch = query.size(0)

        def split_heads(proj, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            return proj(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q_proj, k_proj, v_proj, out_proj = self.linears
        q = split_heads(q_proj, query)
        k = split_heads(k_proj, key)
        v = split_heads(v_proj, value)

        context, self.attn = Attention(q, k, v, mask = mask, dropout = self.dropout)

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k) before the output projection.
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return out_proj(merged)