To better understand the Transformer architecture and how to train it, I designed this simple sequence-to-sequence example, which takes in a string and sorts its characters in alphabetical order.

It learns to do the job well within 4 epochs. Here is the output on unseen data after 4 epochs:

Epoch: 4, Train loss: 0.10, Validation loss: 0.05

Input string: sortingtest - Sorted string: eginorssttt

Input string: abcdef - Sorted string: abcdef

Input string: fedcba - Sorted string: abcdef

Input string: prithvi - Sorted string: hiiprtv

In [92]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import string
import itertools
from torch.utils.data import DataLoader
import math

# Constants
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Longest generated string; also the fixed length every sequence is padded to
# and the default number of positions in PositionalEncoding.
MAX_LENGTH = 20
BATCH_SIZE = 64
# Number of randomly generated (string, sorted string) pairs in each split.
TRAIN_SAMPLES = 50000
VALIDATION_SAMPLES = 500
EPOCHS = 5
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 0.001

# Model hyperparameters.
NUM_ENCODER_LAYERS = 3
# NOTE(review): not referenced anywhere in the visible code (the model only
# builds an encoder stack) -- confirm whether it can be dropped.
NUM_DECODER_LAYERS = 3
# Passed to the model both as the embedding width and as the feed-forward width.
HIDDEN_SIZE = 256
NHEAD = 4
DROPOUT = 0.1

# Data Preparation
def generate_sample():
    """Create one random (unsorted, sorted) training pair.

    Returns:
        A ``(sequence, sorted_sequence)`` tuple of strings: a random run of
        lowercase ASCII letters, 1 to ``MAX_LENGTH`` characters long, and the
        same characters arranged in ascending order.
    """
    length = random.randint(1, MAX_LENGTH)
    letters = random.choices(string.ascii_lowercase, k=length)
    return "".join(letters), "".join(sorted(letters))

def pad_sequence(seq, max_length):
    """Return a new list: ``seq`` right-padded with zeros to ``max_length`` items.

    Args:
        seq: List of token indices.
        max_length: Target length of the padded list.

    Returns:
        ``seq`` followed by as many ``0`` entries as are needed to reach
        ``max_length``. A list that is already that long (or longer) comes
        back with no padding added and is not truncated.
    """
    pad_count = max_length - len(seq)
    # A non-positive count multiplies out to an empty list.
    zeros = [0] * pad_count
    return seq + zeros

def tensor_from_string_padded(s, char_to_idx, max_length):
    """Encode a string as a zero-padded 1-D tensor of token indices.

    Args:
        s: String to encode; every character must be a key of ``char_to_idx``.
        char_to_idx: Mapping from character to integer index.
        max_length: Length the index list is padded to with zeros.

    Returns:
        A ``torch.long`` tensor on ``DEVICE`` holding the padded indices.

    Raises:
        KeyError: If ``s`` contains a character missing from ``char_to_idx``.
    """
    token_ids = pad_sequence(list(map(char_to_idx.__getitem__, s)), max_length)
    return torch.tensor(token_ids, dtype=torch.long, device=DEVICE)

def tensors_from_pair_padded(input_str, output_str, char_to_idx, max_length):
    """Encode an (input, target) string pair as two padded index tensors.

    Args:
        input_str: Source string.
        output_str: Target string.
        char_to_idx: Mapping from character to integer index.
        max_length: Length both encodings are padded to.

    Returns:
        A ``(input_tensor, output_tensor)`` tuple, each produced by
        ``tensor_from_string_padded``.
    """
    return tuple(
        tensor_from_string_padded(text, char_to_idx, max_length)
        for text in (input_str, output_str)
    )

def generate_data_padded(num_samples, char_to_idx, max_length):
    """Build a dataset of random sorting examples as padded tensor pairs.

    Args:
        num_samples: Number of examples to generate.
        char_to_idx: Mapping from character to integer index.
        max_length: Length every encoded sequence is padded to.

    Returns:
        A list of ``num_samples`` ``(input_tensor, target_tensor)`` tuples.
    """
    dataset = []
    for _ in range(num_samples):
        unsorted_text, sorted_text = generate_sample()
        dataset.append(
            tensors_from_pair_padded(unsorted_text, sorted_text, char_to_idx, max_length)
        )
    return dataset

def custom_collate_fn(batch):
    """Collate (input, target) tensor pairs into two stacked tensors.

    Args:
        batch: Sequence of ``(input_tensor, target_tensor)`` pairs whose
            tensors all share one shape.

    Returns:
        ``(inputs, targets)``, where the samples are stacked along a new
        dimension 1, so 1-D samples of length ``L`` become an ``(L, batch)``
        tensor.
    """
    source_seqs, target_seqs = zip(*batch)
    # Stacking on dim=1 puts the sequence axis first and the batch axis second.
    return torch.stack(source_seqs, dim=1), torch.stack(target_seqs, dim=1)

# Vocabulary: letters are numbered from 1 so that index 0 stays free for padding.
chars = string.ascii_lowercase
char_to_idx = {c: i for i, c in enumerate(chars, start=1)}
idx_to_char = {i: c for c, i in char_to_idx.items()}

# Pre-generate both splits as lists of padded (input, target) tensor pairs.
train_data = generate_data_padded(TRAIN_SAMPLES, char_to_idx, MAX_LENGTH)
validation_data = generate_data_padded(VALIDATION_SAMPLES, char_to_idx, MAX_LENGTH)

train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
# NOTE(review): the validation loader is shuffled as well; shuffling is not
# needed for evaluation -- confirm whether this is intentional.
validation_loader = DataLoader(
    validation_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [93]:
# Model
class TransformerModel(nn.Module):
    """Encoder-only Transformer that emits one token prediction per position.

    The pipeline is: embedding -> positional encoding -> Transformer encoder
    stack -> linear projection back to the vocabulary. Token index 0 is
    treated as padding and is masked out as an attention key.

    Args:
        ntoken: Vocabulary size (including the padding index).
        ninp: Embedding / model width.
        nhead: Number of attention heads per encoder layer.
        nhid: Width of the feed-forward sub-layer in each encoder layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability used inside the encoder layers.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.src_mask = None
        self.ninp = ninp
        self.encoder = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout),
            nlayers,
        )
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        """Return an additive ``(sz, sz)`` causal mask.

        Entries on and below the diagonal are ``0.0``; entries above the
        diagonal are ``-inf``.
        """
        visible = torch.tril(torch.ones(sz, sz)).bool()
        return torch.zeros(sz, sz).masked_fill(~visible, float('-inf'))

    def init_weights(self):
        """Re-initialise the embedding and output projection.

        Weights are drawn uniformly from ``[-0.1, 0.1]``; the output bias is
        set to zero.
        """
        bound = 0.1
        nn.init.uniform_(self.encoder.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, src, apply_softmax=False):
        """Compute per-position vocabulary scores.

        Args:
            src: 2-D tensor of token indices laid out as (sequence, batch);
                index 0 marks padding.
            apply_softmax: When true, return probabilities over the last
                dimension instead of raw scores.

        Returns:
            Tensor with the vocabulary scores (or probabilities) in the last
            dimension, one row per input position.
        """
        # Additive key-padding mask in (batch, sequence) layout:
        # 0 for real tokens, -inf for padding.
        is_token = (src != 0).permute(1, 0)
        key_padding = torch.zeros(
            is_token.shape, dtype=torch.float32, device=src.device
        ).masked_fill(~is_token, float('-inf'))

        embedded = self.pos_encoder(self.encoder(src) * math.sqrt(self.ninp))
        encoded = self.transformer_encoder(embedded, src_key_padding_mask=key_padding)
        scores = self.decoder(encoded)

        return torch.softmax(scores, dim=-1) if apply_softmax else scores

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. The table is
    stored as a non-trainable buffer named ``pe`` of shape
    ``(max_len, 1, d_model)``.

    Args:
        d_model: Feature width of the inputs. Both even and odd widths are
            supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=MAX_LENGTH):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so trim
        # the cosine term to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Insert a singleton dimension 1 so the table broadcasts over it.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose dimension 0 indexes positions and whose last
                dimension has width ``d_model``.
        """
        return x + self.pe[:x.size(0), :]


# Training
def train_step(model, train_loader, optimizer, criterion, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module mapping a batch of inputs to per-position scores.
        train_loader: Iterable of ``(input, target)`` batches; must support ``len``.
        optimizer: Optimizer updating the model's parameters.
        criterion: Loss taking ``(scores, targets)`` with flattened positions.
        epoch: Current epoch number (not used by this function).

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    num_classes = len(chars) + 1
    batch_losses = []
    for src_batch, tgt_batch in train_loader:
        src_batch = src_batch.to(DEVICE)
        tgt_batch = tgt_batch.to(DEVICE)

        optimizer.zero_grad()
        scores = model(src_batch)
        # Flatten all positions so each one is scored as its own sample.
        loss = criterion(scores.view(-1, num_classes), tgt_batch.reshape(-1))
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_loader)

def evaluate(model, data_loader, criterion):
    """Compute the mean per-batch loss of ``model`` over ``data_loader``.

    The model is switched to eval mode and gradients are disabled.

    Args:
        model: Module mapping a batch of inputs to per-position scores.
        data_loader: Iterable of ``(input, target)`` batches; must support ``len``.
        criterion: Loss taking ``(scores, targets)`` with flattened positions.

    Returns:
        The mean of the per-batch losses.
    """
    model.eval()
    num_classes = len(chars) + 1
    batch_losses = []
    with torch.no_grad():
        for src_batch, tgt_batch in data_loader:
            src_batch = src_batch.to(DEVICE)
            tgt_batch = tgt_batch.to(DEVICE)
            scores = model(src_batch)
            loss = criterion(scores.view(-1, num_classes), tgt_batch.reshape(-1))
            batch_losses.append(loss.item())
    return sum(batch_losses) / len(data_loader)

def sort_string(input_str, model, char_to_idx, idx_to_char, max_length):
    """Sort the characters of ``input_str`` using the trained model.

    Args:
        input_str: String to sort; every character must be in ``char_to_idx``.
        model: Model called as ``model(tensor, apply_softmax=True)`` on a
            ``(max_length, 1)`` index tensor.
        char_to_idx: Mapping from character to integer index.
        idx_to_char: Inverse mapping from non-zero index to character.
        max_length: Length the encoded input is padded to.

    Returns:
        The string decoded from the model's highest-scoring token at each
        position, with padding predictions (index 0) dropped. An empty input
        yields an empty string.

    Raises:
        KeyError: If ``input_str`` contains a character missing from
            ``char_to_idx``.
    """
    model.eval()
    if not input_str:
        # Nothing to sort; skip the forward pass on an all-padding input.
        return ""

    with torch.no_grad():
        input_tensor = tensor_from_string_padded(input_str, char_to_idx, max_length).unsqueeze(1)
        output_tensor = model(input_tensor, apply_softmax=True).squeeze(1)

        # argmax over the vocabulary already gives one index per position.
        # Do not squeeze again: with max_length == 1 that would produce a 0-d
        # tensor, which cannot be iterated.
        predicted = torch.argmax(output_tensor, dim=-1).tolist()

    return "".join(idx_to_char[idx] for idx in predicted if idx != 0)

# Build the model, the loss and the optimizer.
model = TransformerModel(
    ntoken=len(chars) + 1,  # 26 letters plus the padding index
    ninp=HIDDEN_SIZE,
    nhead=NHEAD,
    nhid=HIDDEN_SIZE,
    nlayers=NUM_ENCODER_LAYERS,
    dropout=DROPOUT,
).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train, then report the losses and a few sample predictions after each epoch.
for epoch in range(1, EPOCHS + 1):
    train_loss = train_step(model, train_loader, optimizer, criterion, epoch)
    validation_loss = evaluate(model, validation_loader, criterion)
    print('******************')
    print(f'Epoch: {epoch}, Train loss: {train_loss:.2f}, Validation loss: {validation_loss:.2f}')
    for input_str in ('sortingtest', 'abcdef', 'fedcba', 'prithvi'):
        sorted_str = sort_string(input_str, model, char_to_idx, idx_to_char, MAX_LENGTH)
        print(f"Input string: {input_str} - Sorted string: {sorted_str}")

# Save the model
torch.save(model.state_dict(), 'sequence_transformer.pth')

******************
Epoch: 1, Train loss: 0.36, Validation loss: 0.10
Input string: sortingtest - Sorted string: eginorssttt
Input string: abcdef - Sorted string: abcdef
Input string: fedcba - Sorted string: abcdef
Input string: prithvi - Sorted string: hiiprtv
******************
Epoch: 2, Train loss: 0.16, Validation loss: 0.11
Input string: sortingtest - Sorted string: eginrrssttt
Input string: abcdef - Sorted string: abcdef
Input string: fedcba - Sorted string: abceef
Input string: prithvi - Sorted string: hiiprtv
******************
Epoch: 3, Train loss: 0.12, Validation loss: 0.07
Input string: sortingtest - Sorted string: eginorssttt
Input string: abcdef - Sorted string: abcdef
Input string: fedcba - Sorted string: abcdef
Input string: prithvi - Sorted string: hiiprtv
******************
Epoch: 4, Train loss: 0.10, Validation loss: 0.05
Input string: sortingtest - Sorted string: eginorssttt
Input string: abcdef - Sorted string: abcdef
Input string: fedcba - Sorted string: abcdef
Inp