In this tutorial, we train a `nn.TransformerEncoder` model on a causal language modeling task.
Please note that this tutorial does not cover the training of `nn.TransformerDecoder`, as depicted in the right half of the diagram above.

The language modeling task is to assign a probability to a given word (or a sequence of words) following a sequence of words.
A sequence of tokens is first passed to the embedding layer, followed by a positional encoding layer to account for the order of the words (see the next paragraph for more details).
The `nn.TransformerEncoder` consists of multiple layers of `nn.TransformerEncoderLayer`.
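
To make the positional encoding step concrete, here is a minimal, dependency-free sketch. It assumes the sinusoidal scheme from the original Transformer paper, which this section does not state explicitly; the helper name `sinusoidal_positional_encoding` is hypothetical and is not the `PositionalEncoding` module used by the model.

```python
import math
from typing import List


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> List[List[float]]:
    """Return a ``max_len x d_model`` table of sinusoidal position encodings.

    Hypothetical helper for illustration only; assumes an even ``d_model``.
    """
    if d_model % 2 != 0:
        raise ValueError("this sketch assumes an even d_model")
    table = []
    for pos in range(max_len):
        row = [0.0] * d_model
        for i in range(0, d_model, 2):
            # Dimensions i and i + 1 share one wavelength, which grows
            # geometrically with i.
            angle = pos / (10000 ** (i / d_model))
            row[i] = math.sin(angle)
            row[i + 1] = math.cos(angle)
        table.append(row)
    return table


pe = sinusoidal_positional_encoding(max_len=4, d_model=6)
```

Each row of `pe` would be added to the embedding of the token at that position, so two identical tokens at different positions end up with different inputs to the encoder.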

Along with the input sequence, a square attention mask is required because, for causal language modeling, the self-attention layers in `nn.TransformerEncoder` must only attend to earlier positions in the sequence.
For the language modeling task, any tokens at future positions must be masked.
This masking, combined with the fact that the output embeddings are offset with later positions, ensures that the predictions for position i can depend only on the known outputs at positions less than i.
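
The shape of such a mask is easy to see in plain Python. The sketch below builds the same pattern the model code later requests from `nn.Transformer.generate_square_subsequent_mask`: `0.0` where attention is allowed and `-inf` where it is blocked. The function name `causal_mask` is hypothetical and used only for illustration.

```python
from typing import List

NEG_INF = float("-inf")


def causal_mask(size: int) -> List[List[float]]:
    """Additive attention mask: row = query position, column = key position.

    Entry is 0.0 when the key is at or before the query, -inf otherwise.
    """
    return [
        [0.0 if col <= row else NEG_INF for col in range(size)]
        for row in range(size)
    ]


mask = causal_mask(3)
for row in mask:
    print(row)
```

Because the mask is added to the attention scores before the softmax, a `-inf` entry drives the corresponding attention weight to zero.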

To produce a probability distribution over output words, the output of the `nn.TransformerEncoder` model is passed through a linear layer to output unnormalized logits.
The log-softmax function isn’t applied here due to the later use of `CrossEntropyLoss`, which requires the inputs to be unnormalized logits.
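
The reason raw logits suffice is that the cross-entropy loss applies log-softmax internally before taking the negative log-likelihood of the target. The following dependency-free sketch illustrates that computation for a single position; the helper names and the example logits are hypothetical, and this is not PyTorch's implementation.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-probability of the target class, computed from raw logits."""
    return -log_softmax(logits)[target]


logits = [2.0, 0.5, -1.0]  # hypothetical scores over a 3-word vocabulary
loss = cross_entropy(logits, target=0)
```

Applying log-softmax in the model as well would normalize the scores twice, which is why the model stops at the linear layer.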

In [1]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

In [2]:
class TransformerModel(nn.Module):

    def __init__(self,
                 ntoken: int,
                 d_model: int,
                 nhead: int,
                 d_hid: int,
                 nlayers: int,
                 dropout: float = 0.5
                ) -> None:
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Generate a square causal mask for the sequence. Masked positions
            # are filled with float('-inf'); unmasked positions with float(0.0).
            # The mask is moved to the input's device so that ``forward`` does
            # not depend on a global ``device`` variable.
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

The `vocab` object is built from the training dataset and is used to numericalize tokens into tensors. WikiText-2 represents rare tokens as `<unk>`.
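
As a rough illustration of numericalization with an unknown-token fallback, here is a small dictionary-based sketch. It is not torchtext's implementation: the helpers `build_vocab` and `numericalize`, the toy sentences, and the frequency-based index order are all assumptions made for this example.

```python
from collections import Counter
from typing import Dict, Iterable, List


def build_vocab(token_lists: Iterable[List[str]], unk: str = "<unk>") -> Dict[str, int]:
    """Assign an index to every token; ``unk`` is reserved at index 0."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    vocab = {unk: 0}
    # Most frequent tokens get the smallest indices (ties broken alphabetically).
    for tok, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def numericalize(tokens: List[str], vocab: Dict[str, int], unk: str = "<unk>") -> List[int]:
    """Map tokens to indices, falling back to the ``unk`` index for unseen tokens."""
    return [vocab.get(tok, vocab[unk]) for tok in tokens]


toy_vocab = build_vocab([["the", "cat", "sat"], ["the", "dog"]])
ids = numericalize(["the", "bird", "sat"], toy_vocab)
```

Here "bird" never appeared in the toy training data, so it maps to the same index as `<unk>`, which mirrors the role of `vocab.set_default_index(vocab['<unk>'])` in the code below.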

Given a 1-D vector of sequential data, `batchify()` arranges the data into `batch_size` columns. If the data does not divide evenly into `batch_size` columns, the data is trimmed to fit. For instance, with the alphabet as the data (total length of 26) and `batch_size=4`, we would divide the alphabet into 4 sequences of length 6, dropping the last two letters (Y and Z).

Batching enables more parallelizable processing. However, batching means that the model treats each column independently; for example, the dependence of G on F cannot be learned in the example above, because F ends one column and G begins the next.
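
The alphabet example can be reproduced without tensors. The sketch below mirrors the trim, reshape, and transpose steps of `batchify()` using plain Python sequences; `batchify_list` is a hypothetical stand-in written for illustration, not the function used in this tutorial.

```python
import string
from typing import List, Sequence


def batchify_list(data: Sequence[str], bsz: int) -> List[List[str]]:
    """Arrange ``data`` into ``bsz`` columns, trimming any leftover elements.

    Returns ``len(data) // bsz`` rows, each holding one element per column.
    """
    seq_len = len(data) // bsz
    trimmed = data[:seq_len * bsz]
    # Cut the trimmed data into bsz contiguous chunks; each chunk is a column.
    columns = [trimmed[i * seq_len:(i + 1) * seq_len] for i in range(bsz)]
    # Transpose so that each row holds the same time step across all columns.
    return [list(row) for row in zip(*columns)]


rows = batchify_list(string.ascii_uppercase, 4)
for row in rows:
    print(" ".join(row))
```

The first row is `A G M S` and the last is `F L R X`. F and G sit in different columns, so the model never sees them as neighbors, and Y and Z are trimmed away.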

In [3]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [4]:
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

HTTPError: 404 Client Error: Not Found for url: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
This exception is thrown by __iter__ of HTTPReaderIterDataPipe(skip_on_error=False, source_datapipe=OnDiskCacheHolderIterDataPipe, timeout=None)

In [None]:
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

In [None]:
# ``train_iter`` was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2('/home/tesla/dataset/wikitext-2.zip')
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def batchify(data: Tensor, bsz: int) -> Tensor:
    """Divides the data into ``bsz`` separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

In [None]:
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)