In [1]:
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to the input, then apply dropout.

    Same formulation as the Annotated Transformer blog post
    (https://nlp.seas.harvard.edu/2018/04/03/attention.html).

    Args:
        d_model: size of the last (feature) dimension; must be even because
            features are filled in sin/cos pairs.
        dropout: dropout probability applied to the summed output.
        max_len: number of positions precomputed in the table.

    The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        assert (d_model % 2) == 0, 'd_model should be an even number.'
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Column vector of positions times a row vector of frequencies gives
        # one angle per (position, sin/cos pair).
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2).float()
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even feature indices carry the sine, odd indices the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # A singleton middle axis lets the table broadcast over dim 1 of the input.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        Dim 0 of ``x`` is used as the position axis (the table is sliced by
        ``x.size(0)``).
        """
        seq_len = x.size(0)
        shifted = x + self.pe[:seq_len, :]
        return self.dropout(shifted)

In [11]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer is wrapped as ``dropout(LayerNorm(x + sublayer(x)))``.

    Args:
        n_features: feature size of the input; used as the embedding dimension
            of the attention module and of both LayerNorms.
        n_heads: number of attention heads passed to ``nn.MultiheadAttention``.
        n_hidden: width of the hidden layer in the feed-forward network.
        dropout: dropout probability used inside the feed-forward network and
            after each LayerNorm.
    """
    def __init__(self, n_features, n_heads, n_hidden = 64, dropout=0.1):
        super(EncoderBlock, self).__init__()
        # Built with the default batch_first=False, so the attention module
        # treats dim 0 of its input as the sequence axis.
        self.attn = nn.MultiheadAttention(n_features, n_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(n_hidden, n_features)
        )
        self.norm1 = nn.LayerNorm(n_features)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(n_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: input tensor whose last dimension is ``n_features``.
            mask: optional 2-D tensor. When given, its transpose is passed to
                the attention module as ``attn_mask``; when omitted, attention
                is unmasked.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # NOTE(review): if callers pass a (seq_len, batch) padding mask, its
        # transpose has the (batch, seq_len) layout that MultiheadAttention
        # documents for `key_padding_mask`, not `attn_mask` - confirm against
        # the caller before relying on padding being ignored.
        attn_mask = None if mask is None else mask.T
        attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)

        # Residual connection, then normalisation, then dropout.
        x = x + attn_output
        x = self.norm1(x)
        x = self.dropout1(x)

        feed_forward_output = self.feed_forward(x)
        x = x + feed_forward_output
        x = self.norm2(x)
        x = self.dropout2(x)
        return x

In [None]:

def clones(module, N):
    """Build an ``nn.ModuleList`` of N deep copies of ``module``.

    Each copy is produced with ``copy.deepcopy``, so the copies start with the
    same parameter values as ``module`` but do not share storage with it or
    with each other.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

class Encoder(nn.Module):
    """Stack of encoder blocks on top of a token embedding.

    The forward pass embeds the input ids, adds positional encoding, runs the
    result through ``n_blocks`` ``EncoderBlock`` layers and applies a final
    LayerNorm.

    Args:
        src_vocab_size: number of rows in the embedding table.
        n_blocks: number of ``EncoderBlock`` layers.
        n_features: embedding size, shared by every layer.
        n_heads: number of attention heads in each block.
        n_hidden: hidden width of each block's feed-forward network.
        dropout: dropout probability for the positional encoding and the blocks.
        max_length: number of positions precomputed by the positional encoding.
    """
    def __init__(self, src_vocab_size, n_blocks, n_features, n_heads, n_hidden=64, dropout=0.1, max_length = 5000):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(src_vocab_size, n_features)
        self.position = PositionalEncoding(n_features, dropout, max_length)
        # Each block is constructed separately, so blocks are initialised
        # independently rather than being copies of one another.
        self.blocks = nn.ModuleList([EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)])
        self.norm = nn.LayerNorm(n_features)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: integer tensor of token ids, looked up in the embedding table.
            mask: passed unchanged to every block.

        Returns:
            The normalised output of the last block, with ``n_features`` as
            its last dimension.
        """
        x = self.embedding(x)
        # Call the module itself rather than `.forward` so that any hooks
        # registered on the positional encoding are executed.
        x = self.position(x)
        for block in self.blocks:
            x = block(x, mask)
        return self.norm(x)

In [None]:
#Classifier on top of the encoder
class Classifier(nn.Module):
    '''
    A classifier which builds on top of the transformer encoder.
    A simple feedforward network with dropout and ReLU activation is used
    with the number of layers being a hyperparameter.

    Args:
        encoder: module producing features whose last dimension is embed_size.
        embed_size: input size of the first linear layer.
        num_classes: output size of the final linear layer.
        hidden_size: width of every hidden layer.
        num_layers: total number of hidden layers; the first is `fc` and the
            remaining num_layers-1 are copies made by `clones`.
        dropout: dropout probability applied after every hidden layer.
    '''
    def __init__(self, encoder, embed_size, num_classes, hidden_size, num_layers, dropout=0.2):
        super(Classifier, self).__init__()
        self.encoder = encoder
        self.fc = nn.Linear(embed_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        #We can choose the number of layers
        # The extra layers are deep copies of a single nn.Linear, so they all
        # start from the same initial weights.
        self.layers = clones(nn.Linear(hidden_size, hidden_size), num_layers-1)
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x, mask=None):
        '''
        Encode x and map the encoder output to class scores.

        Args:
            x: input passed to the encoder.
            mask: optional mask forwarded to the encoder as its second
                positional argument. When omitted the encoder is called with
                x only, so encoders whose forward takes a single argument
                keep working.

        Returns:
            The encoder output with its last dimension mapped to num_classes.
            The linear layers act on the last dimension only, so any leading
            dimensions of the encoder output are preserved.
        '''
        #Pass through encoder
        if mask is None:
            x = self.encoder(x)
        else:
            x = self.encoder(x, mask)

        #Pass through the classifier
        x = self.fc(x)
        x = self.relu(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)
            x = self.relu(x)
            x = self.dropout(x)
        x = self.out(x)
        return x
    