## Basic Transformer

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In this exercise, we give a network an input sequence of characters (e.g., **aabbccdd**), the network is trained to produce the same sequence in reverse order (**ddccbbaa**).

### Data Preparation

In [3]:
# create a dataset of 500 examples using a small vocabulary
# we will try training on sequences of length 10 and testing on sequences of length 15
# this setup tests whether the model has actually learned an algorithm to reverse its input
vocab = {'a': 0, 'b': 1, 'c':2, 'd':3, 'e':4, '<sos>':5, '<eos>':6}
idx_to_w = dict((v, k) for (k,v) in vocab.items())
train_seq_len = 10
num_train_examples = 5

In [4]:
# generate toy data
train_inputs = torch.LongTensor(num_train_examples, train_seq_len).random_(0
, len(vocab)-2) # random sequences
inv_idx = torch.arange(train_seq_len-1, -1, -1).long()
train_outputs = train_inputs[:, inv_idx] # outputs are just the reverse of the input
sos_vec = torch.LongTensor(num_train_examples, 1)
sos_vec[:] = vocab['<sos>']
eos_vec = torch.LongTensor(num_train_examples, 1)
eos_vec[:] = vocab['<eos>']
train_encoder_input = torch.cat((train_inputs, eos_vec), 1)
train_decoder_input = torch.cat((sos_vec, train_outputs), 1)
train_targets = torch.cat((train_outputs, eos_vec), 1)

In [5]:
print('encoder input :', ' '.join([idx_to_w[w] for w in train_encoder_input[0].numpy()]))
print('decoder input:', ' '.join([idx_to_w[w] for w in train_decoder_input[0].numpy()]))
print('decoder target:', ' '.join([idx_to_w[w] for w in train_targets[0].numpy()]))

encoder input : e c a a e e b e a b <eos>
decoder input: <sos> b a e b e e a a c e
decoder target: b a e b e e a a c e <eos>


### Build Model

In [6]:
torch.set_printoptions(precision=2, sci_mode=False)

In [7]:
# class for our vanilla seq2seq
class Seq2Seq(nn.Module):
    def __init__(self, char_dim, hidden_size, vocab_size):
        super().__init__()
        
        self.char_dim = char_dim
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # character embeddings
        self.char_embeds = nn.Embedding(vocab_size, char_dim)
        
        # position embeddings
        self.pos_embeds = nn.Embedding(15, char_dim) # add these to enc/dec
        
        # decoder attention
        self.query = nn.Linear(char_dim, hidden_size)
        self.key = nn.Linear(char_dim, hidden_size)
        self.value = nn.Linear(char_dim, hidden_size)
        
        # output layer (softmax will be applied after this)
        self.cls = nn.Linear(hidden_size, vocab_size)
    
    # a vectorized way of computing self attention for all queries efficiently
    def smart_unmasked_attn(self, qs, ks, vs):
        # here, queries are decoder states, keys and values are encoder representations
        scores = qs @ ks.t() # get all dot products at once, N x N
        scores = F.softmax(scores, dim=1)
        return scores @ vs # N x hidden_size  
    
    # a vectorized way of computing **target-side self-attention**
    # we need to implement some masking to avoid cheating!
    def smart_masked_attn(self, qs, ks, vs):
        max_len = qs.size(0)
        mask = torch.tril(torch.ones(max_len, max_len))
        scores = qs @ ks.t() # get all UNMASKED dot products at once, max_len X max_len
        scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=1)
        return scores @ vs

    
    def forward(self, inputs, decoder_inputs):
        
        batch_size, max_len = inputs.size()

        positions = torch.arange(0, inputs.size(1))
        pos_embeds = self.pos_embeds(positions)
        
        # we'll just consider this the output of our encoder
        # of course in a real transformer this would be computed
        # through multiple self attention blocks
        e_embeds = self.char_embeds(inputs).squeeze(0)
        e_embeds = e_embeds + pos_embeds
        e_keys = self.key(e_embeds)
        e_values = self.value(e_embeds)
        
        # we'll use the same weights to project decoder embeddings to q,k,v
        d_embeds = self.char_embeds(decoder_inputs).squeeze(0)
        d_embeds = d_embeds + pos_embeds
        d_queries = self.query(d_embeds)
        d_keys = self.key(d_embeds)
        d_values = self.value(d_embeds)

        # compute target side self attention
        fast_decoder_states = self.smart_masked_attn(d_queries, d_keys, d_values)

        # source attention, queries come from decoder, keys/values from encoder
        source_attn = self.smart_unmasked_attn(fast_decoder_states, e_keys, e_values)
        
        # combine decoder self attention w/ source attention
        source_attn = source_attn + fast_decoder_states

        # now do prediction over decoder states (reshape to 2d first)
        source_attn = source_attn.transpose(0, 1).contiguous().view(-1, self.hidden_size)
        decoder_preds = self.cls(source_attn)
        decoder_preds = F.log_softmax(decoder_preds, dim=1)

        return decoder_preds

### Train the model

In [8]:
def training_loop(net):

    # set some hyperparameters for training the network
    idx_to_w = dict((v,k) for (k,v) in vocab.items())
    loss_fn = nn.NLLLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    num_epochs = 10
    
    # okay, let's train the network!
    for ep in range(num_epochs):
        ep_loss = 0.

        for start in range(0, len(train_inputs)):
            e_in_batch = train_encoder_input[start].unsqueeze(0)
            d_in_batch = train_decoder_input[start].unsqueeze(0)
            d_targ_batch = train_targets[start].unsqueeze(0)
            
            preds = net(e_in_batch, d_in_batch)
            batch_loss = loss_fn(preds, d_targ_batch.view(-1))
            ep_loss += batch_loss

            # compute gradients
            optimizer.zero_grad() # reset the gradients from the last batch
            batch_loss.backward() # does backprop!!!
            optimizer.step() # updates parameters using gradients

        print('epoch %d, loss %f\n' % (ep, ep_loss))

In [9]:
# build the network
net = Seq2Seq(32, 64, len(vocab))
training_loop(net)

epoch 0, loss 9.770307

epoch 1, loss 5.881337

epoch 2, loss 4.329915

epoch 3, loss 3.894290

epoch 4, loss 2.730859

epoch 5, loss 1.853093

epoch 6, loss 1.146887

epoch 7, loss 0.653309

epoch 8, loss 0.368380

epoch 9, loss 0.222095

