# Language modeling using a transformer encoder

We build a language model on top of PyTorch's transformer encoder (`nn.TransformerEncoder`) and train it on WikiText-2 to predict the next token of a sequence.

In [2]:
!pip install torchdata
import math
import copy
import time
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F 
from torch.nn import TransformerEncoder, TransformerEncoderLayer 
from torch.utils.data import dataset

import torchtext
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdata
  Downloading torchdata-0.3.0-py3-none-any.whl (47 kB)
[K     |████████████████████████████████| 47 kB 2.9 MB/s 
Collecting urllib3>=1.25
  Downloading urllib3-1.26.9-py2.py3-none-any.whl (138 kB)
[K     |████████████████████████████████| 138 kB 8.8 MB/s 
  Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)
[K     |████████████████████████████████| 127 kB 58.4 MB/s 
[?25hInstalling collected packages: urllib3, torchdata
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.24.3
    Uninstalling urllib3-1.24.3:
      Successfully uninstalled urllib3-1.24.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0m
Successful

# Transformer (encoder) model
Structure of the `TransformerModel`:

* A source sequence of tokens is first passed to the embedding layer,
* followed by a positional encoding layer that accounts for the order of the words;
* the positionally encoded source sequence and a corresponding square source mask are then passed to the transformer encoder,
* and a final linear layer maps the encoder output to logits over the vocabulary.

Along with the source sequence, a square attention mask is required. Without it, the self-attention layers in `nn.TransformerEncoder` could attend to every position in the sequence, including future ones. For the language modeling task, each position may only attend to itself and earlier positions, so all tokens at future positions are masked out.

In [3]:
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings to a sequence of embeddings,
    then applies dropout.

    Args:
        d_model: embedding dimension. The positional encodings have the same 
            dimension as the embeddings so that the two can be summed. May be
            odd; the last column is then a sine term with no cosine partner.
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        # Frequencies 10000^(-2k / d_model) for k = 0 .. ceil(d_model / 2) - 1.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # Middle dim of size 1 broadcasts over the batch dimension in forward().
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than there are
        # frequencies, so only the first d_model // 2 are used for cosines
        # (for even d_model this is all of them).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: follows .to(device) and state_dict, but is not a parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]; seq_len is
                expected not to exceed max_len.

        Returns:
            Tensor with the same shape as x.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """
    Language model built from a stack of transformer encoder layers.

    Token indices are embedded, scaled by sqrt(d_model), combined with
    sinusoidal positional encodings, run through ``nlayers`` encoder layers
    under a caller-supplied attention mask, and finally projected to one
    logit per vocabulary entry.

    Args:
        ntoken: vocabulary size (number of tokens).
        d_model: embedding dimension.
        nhead: number of attention heads in each encoder layer.
        d_hid: width of the feed-forward sublayer inside each
            nn.TransformerEncoderLayer.
        nlayers: how many nn.TransformerEncoderLayer modules the
            nn.TransformerEncoder stacks.
        dropout: dropout probability for the positional encoding and the
            encoder layers.
    """
    def __init__(self, 
                 ntoken: int, 
                 d_model: int, 
                 nhead: int, 
                 d_hid: int,
                 nlayers: int, 
                 dropout: float = 0.1
    ) -> None:

        super().__init__()
        # Submodules are created in this exact order so that parameter
        # registration order and random-initialisation order stay fixed.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, d_hid, dropout),
            nlayers,
        )
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Re-draws embedding and output weights from U(-0.1, 0.1); zeroes the output bias."""
        bound = 0.1
        with torch.no_grad():
            nn.init.uniform_(self.encoder.weight, -bound, bound)
            nn.init.zeros_(self.decoder.bias)
            nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor of token indices, shape [seq_len, batch_size]
            src_mask: attention mask for the encoder, shape [seq_len, seq_len]
                (e.g. from generate_square_subsequent_mask)

        Returns:
            logits Tensor of shape [seq_len, batch_size, ntoken]
        """
        scale = math.sqrt(self.d_model)
        hidden = self.pos_encoder(self.encoder(src) * scale)
        hidden = self.transformer_encoder(hidden, src_mask)  # [seq_len, batch_size, d_model]
        return self.decoder(hidden)  # [seq_len, batch_size, ntoken]


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Builds a causal attention mask of shape [sz, sz].

    Entry (i, j) is -inf when j > i (a later position) and 0.0 otherwise, so
    adding it to attention scores hides every future position.
    """
    all_blocked = torch.full((sz, sz), float('-inf'))
    return all_blocked.triu(diagonal=1)

# Data processing

#### Build vocabulary using `torchtext.vocab.Vocab` object

In [4]:
# build the vocabulary from the WikiText-2 training split
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=['<unk>'],
)
# map out-of-vocabulary tokens to the '<unk>' index
vocab.set_default_index(vocab['<unk>'])

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor of token indices.

    Each item of ``raw_text_iter`` (any iterable of strings, e.g. a list) is
    tokenized with the module-level ``tokenizer`` and mapped to indices with
    the module-level ``vocab``. Items that yield no tokens are skipped; the
    rest are concatenated in order.

    Returns:
        1-D ``torch.long`` tensor; empty if no item produced any token.
    """
    pieces = [
        torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
        for item in raw_text_iter
    ]
    non_empty = [t for t in pieces if t.numel() > 0]
    if not non_empty:
        # torch.cat raises on an empty sequence of tensors.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(non_empty)

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Splits a flat token tensor into ``bsz`` equal-length columns.

    Trailing elements that do not fill a complete column are dropped. Column
    j holds the j-th contiguous chunk of ``data``, so reading down a column
    follows the original order.

    Args:
        data: Tensor, shape [N]
        bsz: int, number of columns (batch size)

    Returns:
        Tensor of shape [N // bsz, bsz], moved to the module-level ``device``.
    """
    rows = data.size(0) // bsz
    usable = data[: rows * bsz]
    columns = usable.view(bsz, rows)
    return columns.t().contiguous().to(device)

# get train, validation, and test data as flat tensors of token indices
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = (
    data_process(split_iter) for split_iter in (train_iter, val_iter, test_iter)
)

# arrange each split into columns: shape [rows, batch_size]
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data, test_data = (batchify(split, eval_batch_size) for split in (val_data, test_data))

# Model setup, training and evaluation functions

In [54]:
# model 
ntokens = len(vocab)  # size of vocabulary
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 on every scheduler.step() call (the
# training loop calls it once per epoch). step_size is documented as an int.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

In [55]:
BPTT = 35  # max "backprop through time", i.e. the longest sequence fed to the model

def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Slices one (input, target) pair of at most BPTT rows out of ``source``.

    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: row offset at which the chunk starts (the loops pass multiples of
            BPTT); it is not a batch counter

    Returns:
        tuple (data, target): ``data`` is ``source[i:i+seq_len]`` with shape
        [seq_len, batch_size]; ``target`` is the same window shifted one row
        down and flattened to shape [seq_len * batch_size]. ``seq_len`` is
        BPTT, or smaller near the end of ``source``.
    """
    stop = i + min(BPTT, len(source) - 1 - i)
    data = source[i:stop]
    target = source[i + 1:stop + 1].reshape(-1)
    return data, target

def train(model: nn.Module) -> None:
    """
    Runs one epoch of training over the module-level ``train_data``.

    The data is consumed in consecutive chunks of up to ``BPTT`` rows (see
    ``get_batch``). Each chunk gets one optimizer step using the module-level
    ``criterion`` and ``optimizer``, with gradient-norm clipping at 0.5. Every
    ``log_interval`` batches, the mean loss and perplexity over the batches
    since the previous log line are printed.

    Note: the progress line reads the module-level ``epoch`` variable (set by
    the training loop), so it must exist when a log line is printed.
    """
    model.train()  # turn on train mode
    total_loss = 0.
    batches_since_log = 0
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(BPTT).to(device)

    num_batches = len(train_data) // BPTT
    for batch, i in enumerate(range(0, train_data.size(0) - 1, BPTT)):
        data, targets = get_batch(train_data, i)
        seq_len = data.size(0)
        if seq_len != BPTT:  # only on last batch
            src_mask = src_mask[:seq_len, :seq_len]
        output = model(data, src_mask)
        # flatten to [seq_len * batch_size, vocab size] to match the flat targets
        loss = criterion(output.view(-1, output.size(-1)), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        batches_since_log += 1
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            # Divide by the number of batches actually accumulated: the first
            # window (batches 0..log_interval) holds log_interval + 1 batches.
            ms_per_batch = (time.time() - start_time) * 1000 / batches_since_log
            cur_loss = total_loss / batches_since_log
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0.
            batches_since_log = 0
            start_time = time.time()

def evaluate(model: TransformerModel, eval_data: Tensor) -> float:
    """
    Returns the mean per-token loss (module-level ``criterion``) of ``model``
    on ``eval_data``.

    ``eval_data`` is walked in chunks of up to ``BPTT`` rows via ``get_batch``.
    Each chunk's loss is weighted by its length, so the result averages over
    all ``len(eval_data) - 1`` predicted rows. Gradients are not tracked, and
    the model is left in eval mode.
    """
    model.eval()  # turn on evaluation mode
    causal_mask = generate_square_subsequent_mask(BPTT).to(device)
    weighted_sum = 0.
    with torch.no_grad():
        for start in range(0, eval_data.size(0) - 1, BPTT):
            inputs, targets = get_batch(eval_data, start)
            chunk_len = inputs.size(0)
            if chunk_len != BPTT:  # shorter final chunk
                causal_mask = causal_mask[:chunk_len, :chunk_len]
            logits = model(inputs, causal_mask).view(-1, ntokens)
            weighted_sum += chunk_len * criterion(logits, targets).item()
    return weighted_sum / (len(eval_data) - 1)

#### Example outputs of some of the functions above

In [15]:
# A few short made-up sentences to illustrate the preprocessing helpers.
sample_data = [
    'today is sunny',
    'my car works',
    'trees are green',
    'book worms fall',
    'moon chair show',
    'plane boat bone',
]

# map sample_data to a flat tensor of token indices
idx_data = data_process(sample_data)
idx_data

tensor([  802,    23, 13787,   447,  1380,   499,  1101,    34,  1213,   244,
        11441,  1343,  2374,  4923,   247,  3502,  1681,  3242])

In [16]:
# Arrange the indexed data into 4 columns; the batch dimension is the second one.
batched_data = batchify(idx_data, bsz=4)
batched_data

tensor([[  802,  1380,  1213,  2374],
        [   23,   499,   244,  4923],
        [13787,  1101, 11441,   247],
        [  447,    34,  1343,  3502]], device='cuda:0')

In [20]:
# Get the first (input, target) pair, starting at row 0
data, target = get_batch(batched_data, 0)
for label, value in (('data', data), ('target', target)):
    print(label, value)

data tensor([[  802,  1380,  1213,  2374],
        [   23,   499,   244,  4923],
        [13787,  1101, 11441,   247]], device='cuda:0')
target tensor([   23,   499,   244,  4923, 13787,  1101, 11441,   247,   447,    34,
         1343,  3502], device='cuda:0')


In [21]:
# The target is easier to read when not flattened. Each target row is the
# data row one step further down batched_data, i.e. for every token the
# model must predict the next token in the same column.
target.reshape(-1, 4)

tensor([[   23,   499,   244,  4923],
        [13787,  1101, 11441,   247],
        [  447,    34,  1343,  3502]], device='cuda:0')

# Training loop

In [56]:
# training loop
epochs = 3
best_val_loss = float('inf')
best_model = None  # deep copy of the model with the lowest validation loss so far

# NOTE: the loop variable must be called `epoch`, because train() reads it.
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time

    separator = '-' * 89
    print(separator)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print(separator)

    if val_loss < best_val_loss:
        best_val_loss, best_model = val_loss, copy.deepcopy(model)

    scheduler.step()

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 15.79 | loss  8.07 | ppl  3211.27
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 15.44 | loss  6.85 | ppl   942.32
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 15.46 | loss  6.44 | ppl   624.99
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 15.53 | loss  6.30 | ppl   542.15
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 15.58 | loss  6.19 | ppl   490.21
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 15.62 | loss  6.16 | ppl   473.61
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 15.64 | loss  6.12 | ppl   453.82
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 15.69 | loss  6.11 | ppl   451.28
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 15.70 | loss  6.03 | ppl   414.16
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 15.73 | loss  6.02 | ppl   409.76
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 15.69 | loss  5.90 | ppl   363.49
| epoch   1 |  2400/ 

### Evaluation on the held-out test set

In [57]:
# score the best checkpoint (by validation loss) on the held-out test split
test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print(f'| After training:  test loss {test_loss:5.2f} | test ppl {test_ppl:8.2f}')

| After training:  test loss  5.51 | test ppl   246.75


# Inference

In [30]:
def infer(text: str, 
          model: TransformerModel, 
          vocab_object: torchtext.vocab.Vocab,
          num_pred: int = 1,) -> str:
    """
    Greedily extends ``text`` by ``num_pred`` predicted tokens.

    The text is tokenized with the module-level ``tokenizer`` and encoded with
    ``vocab_object``. At each step the model runs on the whole sequence so far
    under a causal mask, and the highest-scoring next token is appended.

    Args:
        text: string to use as input sequence; must yield at least one token.
        model: trained transformer model.
        vocab_object: vocabulary used both to encode ``text`` and to decode
            the result.
        num_pred: number of tokens to predict; 0 (or less) returns just the
            tokens of ``text``.

    Returns:
        The input tokens followed by the predicted tokens, joined by spaces.

    Raises:
        ValueError: if ``text`` yields no tokens.
    """
    # Encode with the same vocabulary that is used for decoding below, and
    # keep every input token (no BPTT truncation).
    token_ids = vocab_object(tokenizer(text))
    if not token_ids:
        raise ValueError('text must contain at least one token')
    # shape [seq_len, 1]: sequence first, batch of one, as the model expects
    data = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(1)

    model.eval()
    with torch.no_grad():
        for _ in range(num_pred):
            src_mask = generate_square_subsequent_mask(data.size(0)).to(device)
            output = model(data, src_mask)  # size = [seq_len, batch size, vocab size]
            # highest logit at the last position -> size = [1, batch size]
            pred_idx = output[-1:].argmax(dim=-1)
            # append prediction to input sequence
            data = torch.cat((data, pred_idx), dim=0)

    # data[:, 0] is always 1-D, so tolist() yields a list even for one token
    return ' '.join(vocab_object.lookup_tokens(data[:, 0].tolist()))

Below are some inference examples on arbitrary input phrases; for each one, we predict the next three words.

In [46]:
s = "the construction of ships offered ice"
infer(s, best_model, vocab, 3)

'the construction of ships offered ice age of the'

In [47]:
s = "our fathers have eaten"
infer(s, best_model, vocab, 3)

'our fathers have eaten in the first'

In [48]:
s = "the engine inside"
infer(s, best_model, vocab, 3)

'the engine inside the first quarter'

In [49]:
s = "Those tall trees gave us"
infer(s, best_model, vocab, 3)

'those tall trees gave us $ 1 million'

In [50]:
s = "The men were highly"
infer(s, best_model, vocab, 3)

'the men were highly successful in the'

In [53]:
s = "We ate"
infer(s, best_model, vocab, 3)

'we ate . . .'

Reference: [A detailed guide to PyTorch's nn.Transformer() module](https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1)