### Define the model
In this tutorial, we train an `nn.TransformerEncoder` model on a language modeling task. The language modeling task is to assign a probability to a given word (or sequence of words) following a given sequence of words.

A sequence of tokens is first passed to the embedding layer, followed by a positional encoding layer that accounts for the order of the words (see the Positional Encoding section below for more details).

`nn.TransformerEncoder` consists of multiple layers of `nn.TransformerEncoderLayer`. Along with the input sequence, a square attention mask is required because the self-attention layers in `nn.TransformerEncoder` are only allowed to attend to the earlier positions in the sequence. For the language modeling task, any tokens in future positions must be masked.

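To obtain a prediction over actual words, the output of the `nn.TransformerEncoder` model is sent to the final Linear layer, which produces one unnormalized score (logit) per vocabulary word. The model itself contains no explicit log-Softmax layer: `nn.CrossEntropyLoss`, used as the training criterion below, applies the log-softmax to these scores.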

In [1]:
# Show which PyTorch version this notebook is running against.
import torch 
print(torch.__version__)

1.3.0.post2


In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):
    '''
    Language model made of a token embedding, a positional encoding, a stack
    of nn.TransformerEncoderLayer and a final linear layer over the vocabulary.

    Args:
        vocab_size: number of distinct tokens; size of the embedding table and
            of the output layer.
        embed_dim: width of the token embeddings and of the encoder layers.
        n_heads: number of attention heads in each encoder layer.
        hidden_dim: width of the feed-forward network in each encoder layer.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability passed to the positional encoding and to
            the encoder layers.
    '''

    def __init__(self, vocab_size, embed_dim, n_heads, hidden_dim, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()

        # Set up fields
        self.embed_dim = embed_dim
        self.model_type = 'Transformer'

        # The attention mask is built lazily in forward() and cached here
        self.src_mask = None

        # Init positional encoder
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)

        # Init embedding encoder
        self.encoder = nn.Embedding(vocab_size, embed_dim)

        # Init encoder
        encoder_layers = TransformerEncoderLayer(embed_dim, n_heads, hidden_dim, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        # Init decoder
        self.decoder = nn.Linear(embed_dim, vocab_size)

        # Init weights for encoder + decoder
        self.init_weights()

    def _generate_mask(self, sz):
        '''
        Returns a (sz, sz) float mask that is 0.0 on and below the diagonal
        and -inf above it, so position i cannot attend to positions > i.
        '''
        # triu(diagonal=1) keeps only the entries strictly above the diagonal
        # and sets everything else to 0.
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def init_weights(self):
        '''
        Re-initialises the embedding and output weights uniformly in
        [-0.1, 0.1] and zeroes the output bias.
        '''
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        '''
        Runs the model on a batch of token ids.

        Args:
            src: token ids with the sequence along dimension 0 (the mask size
                is taken from len(src)).

        Returns:
            The output of the final linear layer: one unnormalized score per
            vocabulary entry for every position of src.
        '''
        seq_len = len(src)

        # Rebuild the cached mask when it is missing, has the wrong size, or
        # lives on a different device than the input. Without the device
        # check, moving the model and data to another device while keeping
        # the same sequence length would reuse a mask on the old device.
        mask = self.src_mask
        if mask is None or mask.size(0) != seq_len or mask.device != src.device:
            mask = self._generate_mask(seq_len).to(src.device)
            self.src_mask = mask

        # Embed sentence (scaled by sqrt(embed_dim)) and add positions
        src = self.encoder(src) * math.sqrt(self.embed_dim)
        src = self.pos_encoder(src)

        # Run through encoder
        output = self.transformer_encoder(src, mask)

        # Run through decoder
        output = self.decoder(output)
        return output

### Positional Encoding
The `PositionalEncoding` module injects some information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. Here, we use sine and cosine functions of different frequencies.

In [3]:
class PositionalEncoding(nn.Module):
    '''
    Adds fixed sine/cosine position values to its input, then applies dropout.

    Args:
        d_model: size of the last dimension of the inputs (may be even or odd).
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed; inputs longer than this
            along dimension 0 are not supported.
    '''

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Init zero table, one row per position
        pe = torch.zeros(max_len, d_model)

        # Init position vector as [0, 1, 2, ..., max_len - 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # One frequency per even column index: 10000 ** (-k / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        # Even columns get the sine, odd columns the cosine. For an odd
        # d_model there is one fewer odd column than even column, so the
        # cosine side only uses the first d_model // 2 frequencies.
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Reshape to (max_len, 1, d_model) so it broadcasts over dimension 1
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Register as a buffer: unlike a plain attribute it follows the
        # module across .to(device) and is saved in the state_dict, but it
        # is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        Adds the first x.size(0) position rows to x and applies dropout.
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

### Load and batch data
The training process uses the WikiText-2 dataset from `torchtext`. The vocab object is built from the training dataset and is used to numericalize tokens into tensors.

Starting from sequential data, the `batchify()` function arranges the dataset into columns, trimming off any tokens remaining after the data has been divided into batches of size `batch_size`. 

For instance, with the alphabet as the sequence (total length of 26) and a batch size of 4, we would divide the alphabet into 4 sequences of length 6, dropping the two leftover letters (Y and Z).

These columns are treated as independent by the model, which means that the dependence of G on F cannot be learned (F ends one column and G starts the next), but this allows more efficient batch processing.

In [4]:
import torchtext
from torchtext.data.utils import get_tokenizer
# Field describing how raw text is processed: basic-English tokenizer,
# lower-casing, and explicit start/end-of-sequence tokens.
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
# WikiText-2 train / validation / test splits, processed with TEXT.
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
# The vocabulary is built from the training split only.
TEXT.build_vocab(train_txt)
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    '''
    Turns a dataset into a 2-D tensor with bsz columns.

    The token list of the first example in data is numericalized with TEXT,
    cut down to a multiple of bsz, and laid out so that each column is one
    contiguous run of the original sequence. The result is moved to device.
    '''
    # presumably numericalize() returns the token ids along dimension 0;
    # verify against the torchtext version in use
    token_ids = TEXT.numericalize([data.examples[0].text])

    # Largest multiple of bsz that fits; the remainder is dropped.
    usable_len = (token_ids.size(0) // bsz) * bsz
    trimmed = token_ids[:usable_len]

    # One row per column-to-be, then transpose so columns run down dim 0.
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.to(device)

# Number of columns passed to batchify(): 20 for training, 10 for
# validation and test.
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)

### Functions to generate input and target sequence
The `get_batch()` function generates the input and target sequences for the transformer model. It subdivides the source data into chunks of length `bptt`. For the language modeling task, the target for each position is the word that follows it. For example, with a `bptt` value of 2 and `i = 0`, the input is rows 0–1 of the source and the target is rows 1–2, flattened into one dimension.

It should be noted that the chunks are along dimension 0, consistent with the S dimension in the Transformer model. The batch dimension N is along dimension 1.

In [5]:
# Maximum number of rows (sequence positions) in one chunk
bptt = 35
def get_batch(source, i):
    '''
    Returns the chunk of source starting at row i and its next-token targets.

    The chunk has at most bptt rows and is shortened near the end so that
    every input row still has a following row to serve as its target.
    The targets are the same rows shifted down by one, flattened to 1-D.
    '''
    chunk_len = min(bptt, len(source) - 1 - i)
    stop = i + chunk_len
    inputs = source[i:stop]
    labels = source[i + 1:stop + 1].view(-1)
    return inputs, labels

### Initiate an instance
The model is set up with the hyperparameter below. The vocab size is equal to the length of the vocab object.

In [6]:
# Hyperparameters passed to TransformerModel below
ntokens = len(TEXT.vocab.stoi) # vocabulary size: number of entries in the token-to-index map
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value

# Build the model and move it to the selected device
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

In [9]:
# Loss, optimizer and LR schedule used by the smoke tests below
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 on every scheduler step (step size 1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [10]:
# Test the forward pass on the chunk of training data starting at row 1
src, tar = get_batch(train_data, 1)
print(src.shape, tar.shape)

# Calls forward() directly rather than model(src)
out = model.forward(src)
out.shape

torch.Size([35, 20]) torch.Size([700])


torch.Size([35, 20, 28785])

In [11]:
# Test the backprop: one real optimisation step on the output computed above.
# Note that optimizer.step() does update the model weights here.
loss = criterion(out.view(-1, ntokens), tar)
loss.backward()
# Clip the total gradient norm to 0.5 before the update
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()

### Run the model
`CrossEntropyLoss` is applied to track the loss, and SGD implements the stochastic gradient descent method as the optimizer. The initial learning rate is set to 5.0. StepLR is applied to adjust the learning rate through epochs. During training, we use the `nn.utils.clip_grad_norm_` function to scale all the gradients together to prevent them from exploding.

In [15]:
def log_progress():
    '''
    Prints one progress line for the current logging interval.

    Takes no arguments: epoch, batch, total_loss, start_time, log_interval,
    train_data, bptt and scheduler are all read as module-level names at
    call time, so they must exist at module level when this is called.
    '''
    cur_loss = total_loss / log_interval
    elapsed = time.time() - start_time

    # Report the learning rate the optimizer is actually using.
    # scheduler.get_lr() is intended to be called from inside
    # scheduler.step(); on PyTorch >= 1.4 calling it directly emits a
    # warning and can return a value that is one decay step ahead.
    current_lr = scheduler.optimizer.param_groups[0]['lr']

    print('| epoch {:3d} | {:5d}/{:5d} batches | '
            'lr {:02.2f} | ms/batch {:5.2f} | '
            'loss {:5.2f} | ppl {:8.2f}'
            .format(
                epoch, batch,
                len(train_data) // bptt, current_lr,
                elapsed * 1000 / log_interval,
                cur_loss,
                math.exp(cur_loss)
            )
         )

In [18]:
# Fresh loss, optimizer and LR schedule for the training run; rebinding these
# names replaces any optimizer/scheduler objects created in earlier cells.
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 on every scheduler step (step size 1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
# A progress line is printed every log_interval batches
log_interval = 50

import time
def train_one_epoch():
    '''
    Trains the model for one epoch.

    Walks over train_data in chunks of bptt rows; for each chunk it runs a
    forward pass, backpropagates the loss, clips the total gradient norm to
    0.5 and takes one optimizer step. Every log_interval batches it calls
    log_progress() and restarts the running loss and the timer.

    log_progress() also reads an `epoch` name at module level, which this
    function does not set.
    '''
    # log_progress() takes no arguments and looks these names up at module
    # level. As plain locals they were invisible to it, and the first
    # logging call raised NameError, so they are published as globals.
    global total_loss, start_time, batch

    # Turn on the train mode
    model.train()

    # Reset running loss and timer
    total_loss = 0.
    start_time = time.time()

    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        # Get batch data
        data, targets = get_batch(train_data, i)

        # Reset gradient
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Backprop
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # Keep track of loss. Only used for logging, not training itself
        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            log_progress()
            total_loss = 0
            start_time = time.time()

In [25]:
def evaluate(eval_model, val_set):
    '''
    Returns the average loss of eval_model over a batchified data set.

    The model is switched to evaluation mode and left in that mode. The set
    is processed in chunks of bptt rows without tracking gradients; each
    chunk's loss is weighted by its number of rows, and the sum is divided
    by the number of rows that have a target (len(val_set) - 1).
    '''
    # Turn on the evaluation mode
    eval_model.eval()

    n_target_rows = len(val_set) - 1
    loss_sum = 0.

    # Stop calculating gradients
    with torch.no_grad():
        for start in range(0, val_set.size(0) - 1, bptt):
            inputs, labels = get_batch(val_set, start)

            # Forward pass, flattened to one row of scores per target
            scores = eval_model(inputs).view(-1, ntokens)

            # Weight the chunk's mean loss by the chunk length
            chunk_loss = criterion(scores, labels).item()
            loss_sum += len(inputs) * chunk_loss

    return loss_sum / n_target_rows

In [22]:
def log_epoch_progress():
    '''
    Prints the end-of-epoch summary between two separator lines.

    Takes no arguments: epoch, epoch_start_time and val_loss are read as
    module-level names at call time.
    '''
    seconds = time.time() - epoch_start_time
    separator = '-' * 100
    template = ('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}')

    print(separator)
    print(template.format(epoch, seconds, val_loss, math.exp(val_loss)))
    print(separator)

def train_n_epochs(n_epochs=3):
    '''
    Trains for a specified number of epochs.
    After each epoch the model is validated on the validation set.

    Returns a copy of the model as it was after the epoch with the lowest
    validation loss, or None if no epoch improved on the initial best
    (for example when n_epochs is less than 1).
    '''
    import copy

    # log_progress() and log_epoch_progress() take no arguments and read
    # these names at module level. As plain locals they were invisible to
    # the helpers, which then raised NameError.
    global epoch, epoch_start_time, val_loss

    best_model = None
    # Must exist before the first comparison below; any finite loss beats it.
    best_val_loss = float('inf')
    print("Training for %i epochs..." % n_epochs)
    for epoch in range(1, n_epochs + 1):
        epoch_start_time = time.time()
        train_one_epoch()
        val_loss = evaluate(model, val_data)
        log_epoch_progress()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot the weights: `model` keeps being trained in later
            # epochs, so storing a reference would not preserve the best
            # state.
            best_model = copy.deepcopy(model)

        scheduler.step()

    return best_model

In [23]:
# Train for three epochs and keep the model returned by train_n_epochs()
best_model = train_n_epochs(3)

Training for 3 epochs...


KeyboardInterrupt: 

In [None]:
# Score the model returned by training on the test split. test_loss was
# never assigned anywhere before being printed.
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)