In [1]:
import os
import math
import time
import random
import pickle
import datetime

from datasets import load_dataset
from tokenizers import Tokenizer, ByteLevelBPETokenizer
from tokenizers.processors import TemplateProcessing

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Sampler

import sacrebleu

In [2]:
# Train on the GPU if possible
DEVICE = "cuda"

VOCAB_SIZE = 37_000
# For memory control (More than 99% of WMT14 EN-DE sequences have seq_len <= 500)
MAX_SEQ_LEN = 500
MAX_TOKENS_PER_BATCH = 10_000

D_MODEL = 512
# During warm-up 100 M tokens are consumed
WARMUP_STEPS = 100_000_000 // MAX_TOKENS_PER_BATCH
LABEL_SMOOTHING = 0.1

# Base model is trained with 2.5 billion tokens
# MAX_STEPS = 2_500_000_000 // MAX_TOKENS_PER_BATCH
MAX_STEPS = 75_000
EVAL_STEPS = 1_000

## 1. Tokenizer

In [3]:
# Loading WMT4
raw_dataset = load_dataset("wmt14", "de-en")

# Create BPE data for training
if not os.path.exists("bpe_data/train.en") or not os.path.exists("bpe_data/train.de"):
    os.makedirs("bpe_data", exist_ok=True)
    with open("bpe_data/train.en", "w", encoding="utf-8") as f_en, open("bpe_data/train.de", "w", encoding="utf-8") as f_de:
        for example in raw_dataset["train"]:
            f_en.write(example["translation"]["en"] + "\n")
            f_de.write(example["translation"]["de"] + "\n")

# Train tokenizer with shared source and target sequences
if not os.path.exists("bpe_data/tokenizer.json"):    
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train(["bpe_data/train.en", "bpe_data/train.de"],
                    vocab_size=VOCAB_SIZE,
                    show_progress=True,
                    special_tokens=["<PAD>", "<START>", "<END>", "<UNK>"])
    
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"<START>:0 $A:0 <END>:0",
        pair=f"<START>:0 $A:0 <END>:0 <START>:1 $B:1 <END>:1",
        special_tokens=[
            ("<START>", tokenizer.token_to_id("<START>")),
            ("<END>", tokenizer.token_to_id("<END>")),
        ],
    )    
    tokenizer.save("bpe_data/tokenizer.json")

tokenizer = Tokenizer.from_file("bpe_data/tokenizer.json")
tokenizer.enable_truncation(max_length=MAX_SEQ_LEN)

PAD_ID = tokenizer.token_to_id("<PAD>")
START_ID = tokenizer.token_to_id("<START>")
END_ID = tokenizer.token_to_id("<END>")
UNK_ID = tokenizer.token_to_id("<UNK>")

### 2. Data Loader

In [4]:
class WMT14Dataset(Dataset):
    def __init__(self, split="train", tokenizer=None):
        super().__init__()
        self.data = raw_dataset[split].flatten()
        self.tokenizer = tokenizer
        self.data = self.data.map(self._tokenize_batch, batched=True)
        self.length = self.data["length"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src_ids = self.data[idx]["input_ids"]
        tgt_ids = self.data[idx]["labels"]
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

    def _tokenize_batch(self, example):
        src_enc = tokenizer.encode_batch(example["translation.de"])
        src_ids, src_len = [], []
        for e in src_enc:
            src_ids.append(e.ids)
            src_len.append(len(e.ids))
        
        tgt_enc = tokenizer.encode_batch(example["translation.en"])
        tgt_ids, tgt_len = [], []
        for e in tgt_enc:
            tgt_ids.append(e.ids)
            tgt_len.append(len(e.ids))

        seq_len = [max(len(a), len(b)) for a, b in zip(src_ids, tgt_ids)]
        return {"input_ids": src_ids, "labels": tgt_ids, "length": seq_len}

In [5]:
class MaxTokensBucketSampler(Sampler):
    def __init__(self, length, max_tokens=25000, shuffle=True):
        self.length = length
        self.max_tokens = max_tokens
        self.shuffle = shuffle
        self.bucket_batches = []
        self._build_buckets()

    def _build_buckets(self):
        sorted_len_indices = sorted(range(len(self.length)), key=lambda i: self.length[i])
        batch = []
        tokens = 0
        for idx in sorted_len_indices:
            length = self.length[idx]
            tokens += length
            batch.append(idx)
            if tokens >= self.max_tokens:
                self.bucket_batches.append(batch)
                batch = []
                tokens = 0
        if batch:
            self.bucket_batches.append(batch)

    def __iter__(self):
        bucket_indices = list(range(len(self.bucket_batches)))
        if self.shuffle:
            random.shuffle(bucket_indices)
        for idx in bucket_indices:
            yield self.bucket_batches[idx]

    def __len__(self):
        return len(self.length)

In [6]:
def wmt_collate(batch, pad_token_id, batch_first=False):
    src_batch, tgt_batch = zip(*batch)
    src_padded = pad_sequence(src_batch, padding_value=pad_token_id, batch_first=batch_first)
    tgt_padded = pad_sequence(tgt_batch, padding_value=pad_token_id, batch_first=batch_first)
    return src_padded, tgt_padded

In [7]:
MAX_SEQ_LEN, MAX_TOKENS_PER_BATCH

(500, 10000)

In [8]:
train_dataset = WMT14Dataset(split="train", tokenizer=tokenizer)
train_sampler = MaxTokensBucketSampler(train_dataset.length, max_tokens=MAX_TOKENS_PER_BATCH, shuffle=True)
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, collate_fn=lambda batch: wmt_collate(batch, PAD_ID))

In [9]:
eval_dataset = WMT14Dataset(split="validation", tokenizer=tokenizer)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False, collate_fn=lambda batch: wmt_collate(batch, PAD_ID))

### 3. Modeling

In [10]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        den = torch.exp(- torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pos = torch.arange(0, max_len).reshape(max_len, 1)
        pos_embedding = torch.zeros((max_len, d_model))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        
        self.dropout = nn.Dropout(dropout)        
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, x):
        return self.dropout(x + self.pos_embedding[:x.size(0), :])

In [11]:
class TransformerModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        pad_token_id,
        d_model=512,
        num_heads=8,
        num_layers=6,
        dim_ff=2048,
        dropout=0.1
    ):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.d_model = d_model
        self.shared_embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_ff,
            dropout=dropout
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        tgt,
        src_mask=None,
        tgt_mask=None,
        memory_mask=None,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None
    ):
        src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask = self._create_masks(src, tgt, self.pad_token_id, src.device)
        
        src_emb = self.pos_encoder(self.shared_embed(src) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoder(self.shared_embed(tgt) * math.sqrt(self.d_model))
        output = self.transformer(src_emb,
                                  tgt_emb,
                                  src_mask=src_mask,
                                  tgt_mask=tgt_mask,
                                  memory_mask=memory_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=memory_key_padding_mask)
        return self.fc_out(output)

    def _create_masks(self, src, tgt, pad_id, device):
        src_seq_len = src.shape[0]
        tgt_seq_len = tgt.shape[0]
    
        src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool).to(device)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_seq_len).type(torch.bool).to(device)
        
        src_key_padding_mask = (src == pad_id).transpose(0, 1).to(device)
        tgt_key_padding_mask = (tgt == pad_id).transpose(0, 1).to(device)

        return src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask
    
    def encode(self, src, src_mask=None, src_key_padding_mask=None):
        return self.transformer.encoder(self.pos_encoder(self.shared_embed(src) * math.sqrt(self.d_model)),
                                        mask=src_mask,
                                        src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        return self.transformer.decoder(self.pos_encoder(self.shared_embed(tgt) * math.sqrt(self.d_model)),
                                        memory,
                                        tgt_mask=tgt_mask,
                                        memory_mask=memory_mask,
                                        tgt_key_padding_mask=tgt_key_padding_mask,
                                        memory_key_padding_mask=memory_key_padding_mask)

In [12]:
class InverseSqrtScheduler:
    def __init__(self, model_size, tokens_per_step, warmup, optimizer):
        self.model_size = model_size
        self.tokens_per_step = tokens_per_step
        self.warmup = warmup
        self.optimizer = optimizer
        self._step = 0
        # Reference paper uses 25_000 tokens per step and 4_000 warm-up steps
        self.factor = (self.tokens_per_step / 25_000) * (self.warmup / 4_000) ** 0.5

    def step(self, scaler=None):
        self._step += 1
        lr = self.factor * (self.model_size ** -0.5) * min(self._step ** -0.5, self._step * (self.warmup ** -1.5))
        for p in self.optimizer.param_groups:
            p['lr'] = lr
        if scaler:
            scaler.step(self.optimizer)
        else:
            self.optimizer.step()

    def lr_at_step(self, step):
        lr = self.factor * (self.model_size ** -0.5) * min(step ** -0.5, step * (self.warmup ** -1.5))
        return lr

### 4. Training

In [13]:
VOCAB_SIZE, D_MODEL, WARMUP_STEPS, LABEL_SMOOTHING, PAD_ID

(37000, 512, 10000, 0.1, 0)

In [14]:
# Preparing
scaler = torch.amp.GradScaler(DEVICE)
model = TransformerModel(VOCAB_SIZE, PAD_ID).to(DEVICE)
opt = InverseSqrtScheduler(D_MODEL, MAX_TOKENS_PER_BATCH, WARMUP_STEPS, optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING, ignore_index=PAD_ID)



In [15]:
EVAL_STEPS, MAX_STEPS

(1000, 75000)

In [17]:
# Step control
step = 0
peak_memory_overall = 0
val_losses, step_checkpoints = [], []
best_val_loss = 1e6

# # tell CUDA to start recording memory allocations
# torch.cuda.memory._record_memory_history(enabled='all')

print("Training started:", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

model.train()
torch.cuda.reset_peak_memory_stats()
start_time = time.time()

while step < MAX_STEPS:
    for src, tgt in train_loader:
        # print("-- beginning")
        # torch.cuda.reset_peak_memory_stats()
        # before = torch.cuda.memory_allocated()
        
        # Beggining allocation
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        
        # after = torch.cuda.memory_allocated()
        # allocated = (after - before) / (1024**2)
        # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
        # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
        # print(f"src shape: {src.shape} and tgt shape: {tgt.shape}")
    
        # Auto casting to float16 for memory reduction
        with torch.amp.autocast(DEVICE, torch.float16):
            # print("-- forward")
            # torch.cuda.reset_peak_memory_stats()
            # before = torch.cuda.memory_allocated()
            
            # Forward pass
            output = model(src, tgt[:-1, :])
            
            # after = torch.cuda.memory_allocated()
            # allocated = (after - before) / (1024**2)
            # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
            # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
            
            # print("-- zero grad")
            # torch.cuda.reset_peak_memory_stats()
            # before = torch.cuda.memory_allocated()
            
            # Zeroing the gradient
            opt.optimizer.zero_grad()
            
            # after = torch.cuda.memory_allocated()
            # allocated = (after - before) / (1024**2)
            # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
            # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
            
            # print("-- compute loss")
            # torch.cuda.reset_peak_memory_stats()
            # before = torch.cuda.memory_allocated()
            
            # Loss computation
            logits_flat = output.reshape(-1, output.shape[-1])
            tgt_flat = tgt[1:, :].reshape(-1)
            loss = criterion(logits_flat, tgt_flat)
            
            # after = torch.cuda.memory_allocated()
            # allocated = (after - before) / (1024**2)
            # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
            # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
        
        # print("-- backward")
        # torch.cuda.reset_peak_memory_stats()
        # before = torch.cuda.memory_allocated()
    
        # Loss backward
        scaler.scale(loss).backward()
        
        # after = torch.cuda.memory_allocated()
        # allocated = (after - before) / (1024**2)
        # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
        # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
        
        # print("-- optimizer step")
        # torch.cuda.reset_peak_memory_stats()
        # before = torch.cuda.memory_allocated()
        
        # Optimizer step
        opt.step(scaler)
        
        # after = torch.cuda.memory_allocated()
        # allocated = (after - before) / (1024**2)
        # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
        # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
        
        # print("-- scaler update")
        # torch.cuda.reset_peak_memory_stats()
        # before = torch.cuda.memory_allocated()
        
        # Scaler update
        scaler.update()
        
        # after = torch.cuda.memory_allocated()
        # allocated = (after - before) / (1024**2)
        # peak_memory = torch.cuda.max_memory_allocated() / (1024**2)
        # print(f"allocated: {allocated:.0f}M, peak: {peak_memory:.0f}M")
        
        # print(f"***************** finish train step: {step}")
        # print()
        
        torch.cuda.empty_cache()
        step += 1
        
        if step % EVAL_STEPS == 0:
            # Saving some stats
            delta_time = (time.time() - start_time) / EVAL_STEPS
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
            if peak_memory > peak_memory_overall:
                peak_memory_overall = peak_memory
            step_checkpoints.append(step)
            
            # Evaluation
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for src, tgt in eval_loader:
                    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                    output = model(src, tgt[:-1, :])
                    logits_flat = output.reshape(-1, output.shape[-1])
                    tgt_flat = tgt[1:, :].reshape(-1)
                    loss = criterion(logits_flat, tgt_flat)
                    val_loss += loss.item()
                    del output, loss
                    torch.cuda.empty_cache()
            avg_val_loss = val_loss / len(eval_loader)
            val_losses.append(avg_val_loss)
    
            # Report every EVAL STEPS
            print(f"Step {step}: val loss = {val_losses[-1]:.4f}, time per step = {delta_time:.4f} s, peak memory = {peak_memory:.2f} GB")
            # Once training is done, we want to save out the model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), f"models/wmt14_maxseqlen{MAX_SEQ_LEN}_maxtokens{MAX_TOKENS_PER_BATCH}_maxsteps{MAX_STEPS}_best.pt")
            
            # Back to training
            model.train()
            torch.cuda.reset_peak_memory_stats()
            start_time = time.time()

        # Stop batch iteration
        if step >= MAX_STEPS:
            break
        
# # Save a snapshot of the memory allocations
# snapshot = torch.cuda.memory._snapshot()
# with open("snapshot.pickle", "wb") as f_out:
#     pickle.dump(snapshot, f_out)
# # Tell CUDA to stop recording memory allocations now
# torch.cuda.memory._record_memory_history(enabled=None)

print("Training finished:", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
print(f"Overall peak memory was = {peak_memory_overall:.2f} GB.")
torch.cuda.empty_cache()

torch.save(model.state_dict(), f"models/wmt14_maxseqlen{MAX_SEQ_LEN}_maxtokens{MAX_TOKENS_PER_BATCH}_maxsteps{MAX_STEPS}_last.pt")

Training started: 2025-08-04 05:45:59
Step 1000: val loss = 7.5100, time per step = 0.3723 s, peak memory = 10.30 GB
Step 2000: val loss = 6.8585, time per step = 0.3554 s, peak memory = 9.81 GB
Step 3000: val loss = 6.5398, time per step = 0.3505 s, peak memory = 9.88 GB
Step 4000: val loss = 6.3124, time per step = 0.3514 s, peak memory = 9.88 GB
Step 5000: val loss = 6.0964, time per step = 0.3495 s, peak memory = 9.76 GB
Step 6000: val loss = 5.7165, time per step = 0.3513 s, peak memory = 9.88 GB
Step 7000: val loss = 5.1879, time per step = 0.3517 s, peak memory = 10.04 GB
Step 8000: val loss = 4.7147, time per step = 0.3526 s, peak memory = 9.88 GB
Step 9000: val loss = 4.2205, time per step = 0.3514 s, peak memory = 9.88 GB
Step 10000: val loss = 4.2175, time per step = 0.3531 s, peak memory = 10.03 GB
Step 11000: val loss = 4.0104, time per step = 0.3478 s, peak memory = 10.04 GB
Step 12000: val loss = 3.8327, time per step = 0.3530 s, peak memory = 10.19 GB
Step 13000: val lo

### 5. Translating

In [18]:
MAX_SEQ_LEN, MAX_TOKENS_PER_BATCH, MAX_STEPS

(500, 10000, 75000)

In [19]:
state_dict = torch.load(f"models/wmt14_maxseqlen{MAX_SEQ_LEN}_maxtokens{MAX_TOKENS_PER_BATCH}_maxsteps{MAX_STEPS}_best.pt", map_location=torch.device(DEVICE), weights_only=True)

In [20]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [21]:
# Most probable next token
def greedy_decode(model, src_ids, bos_token_id, eos_token_id, device='cuda'):
    model.eval()

    with torch.no_grad():
        # Tokenizing the input
        src_tensor = torch.tensor(src_ids).view(-1, 1).to(device)
        num_tokens = src_tensor.shape[0]
        max_len = num_tokens + 50
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool).to(device)
        
        # Encoding
        memory = model.encode(src_tensor, src_mask=src_mask)
        # Starting with <START>
        generated = torch.ones(1, 1).fill_(bos_token_id).type(torch.long).to(device)
    
        for cur_len in range(1, max_len):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(cur_len).type(torch.bool).to(device)
            output = model.decode(generated, memory, tgt_mask=tgt_mask).transpose(0, 1)        
            logits = model.fc_out(output[:, -1])
            _, next_token = torch.max(logits, dim=1)
            next_token = next_token.item()
            
            generated = torch.cat([generated, torch.ones(1, 1).type(torch.long).fill_(next_token).to(device)], dim=0)
            if next_token == eos_token_id:
                break

        torch.cuda.empty_cache()
    
    return generated.view(1, -1)[0]

In [22]:
src_ids = tokenizer.encode("Das ist ein Beispieltext.").ids # This is an example text.
preds = greedy_decode(model, src_ids, START_ID, END_ID, device=DEVICE)
tokenizer.decode(preds.cpu().numpy(), skip_special_tokens=True)

'This is an example text.'

In [23]:
src_ids = tokenizer.encode("Die Natur ist wunderschön.").ids # Nature is beautiful.
preds = greedy_decode(model, src_ids, START_ID, END_ID, device=DEVICE)
tokenizer.decode(preds.cpu().numpy(), skip_special_tokens=True)

'Nature is beautiful.'

In [24]:
# Adapted from https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
def generate_beam(model, src_ids, pad_token_id, start_token_id, end_token_id, beam_size=4, alpha=0.6, early_stopping=True, device="cuda"):
    # check inputs
    assert beam_size >= 1

    model.eval()
    
    with torch.no_grad():    
        src_ids = src_ids.to(device) # shape: (src_len, batch_size)
        
        # batch size
        src_len, batch_size = src_ids.shape
        max_len = src_len + 50
        src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device) # shape: (src_len, src_len)
        src_key_padding_mask = (src_ids == pad_token_id).transpose(0, 1).to(device) # shape: (batch_size, src_len)
    
        # encode source
        memory = model.encode(src_ids, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask) # shape: (src_len, batch_size, d_model)
        
        # repeating beam size times
        memory = memory.repeat_interleave(beam_size, dim=1)                             # shape: (src_len, beam_size * batch_size, d_model)
        src_key_padding_mask = src_key_padding_mask.repeat_interleave(beam_size, dim=0) # shape: (beam_size * batch_size, src_len)
        
        # generated sentences (batch with beam current hypotheses)
        generated = torch.full((max_len, batch_size * beam_size), pad_token_id, dtype=torch.long, device=device)  # upcoming output
        generated[0].fill_(start_token_id)
        
        # generated hypotheses
        generated_hyps = [BeamHypotheses(beam_size, max_len, alpha, early_stopping) for _ in range(batch_size)]
        
        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, beam_size), device=device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)
        
        # current position
        cur_len = 1
        
        # done sentences
        done = [False for _ in range(batch_size)]
        
        while cur_len < max_len:
            tgt_input = generated[:cur_len, :]
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(cur_len).type(torch.bool).to(device)
            tgt_key_padding_mask = (tgt_input == pad_token_id).transpose(0, 1).to(device)
            output = model.decode(tgt_input, memory,
                                  tgt_mask=tgt_mask,
                                  memory_mask=None,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_key_padding_mask)
            logits = model.fc_out(output[-1])
            scores = F.log_softmax(logits, dim=-1) # shape: (batch_size * beam_size, vocab_size)
            vocab_size = scores.shape[-1]
    
            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)             # shape: (batch_size * beam_size, vocab_size)
            _scores = _scores.view(batch_size, beam_size * vocab_size)            # shape: (batch_size, beam_size * vocab_size)
    
            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
    
            # next batch beam content
            # list of (batch_size * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []
    
            # for each sentence
            for sent_id in range(batch_size):
    
                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, pad_token_id, 0)] * beam_size)  # pad the batch
                    continue
    
                # next sentence beam content
                next_sent_beam = []
    
                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
    
                    # get beam and word IDs
                    beam_id = idx // vocab_size
                    word_id = idx % vocab_size
    
                    # end of sentence, or next word
                    if word_id == end_token_id or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
    
                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break
    
                # update next beam content
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_token_id, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)

            # sanity check / prepare next batch
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = torch.Tensor([x[2] for x in next_batch_beam]).long()
    
            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
    
            # update current length
            cur_len = cur_len + 1
    
            # stop when we are done with each sentence
            if all(done):
                break
        
        torch.cuda.empty_cache()
    
    # select the best hypotheses
    tgt_len = torch.zeros(batch_size).long()
    best = []

    for i, hypotheses in enumerate(generated_hyps):
        best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
        tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
        best.append(best_hyp)

    # generate target batch
    decoded = torch.zeros(tgt_len.max().item(), batch_size).fill_(pad_token_id).long()
    for i, hypo in enumerate(best):
        decoded[:tgt_len[i] - 1, i] = hypo
        decoded[tgt_len[i] - 1, i] = end_token_id
    
    return decoded.T

class BeamHypotheses(object):

    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_len = max_len - 1  # ignoring <BOS>
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs):
        """
        If there are enough hypotheses and that none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty

### 6. BLEU Score

In [25]:
test_dataset = WMT14Dataset(split="test", tokenizer=tokenizer)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=lambda batch: wmt_collate(batch, PAD_ID))

### 6.1 Greedy decoding

In [26]:
# # Testing
# greedy_preds = []
# for src in test_dataset.data["translation.de"]:
#     src_ids = tokenizer.encode(src).ids
#     greedy_preds.append(greedy_decode(model, src_ids, START_ID, END_ID, device=DEVICE))
#     if len(greedy_preds) == 2:
#         break

In [27]:
start_time = time.time()

greedy_preds = []
for src in test_dataset.data["translation.de"]:
    src_ids = tokenizer.encode(src).ids
    preds = greedy_decode(model, src_ids, START_ID, END_ID, device=DEVICE)
    decoded = tokenizer.decode(preds.cpu().numpy(), skip_special_tokens=True)
    greedy_preds.append(decoded)

delta_time = (time.time() - start_time)
print(f"Greedy decoding finished in {delta_time:.2f} s")

Greedy decoding finished in 368.23 s


In [28]:
bleu = sacrebleu.corpus_bleu(greedy_preds, [test_dataset.data["translation.en"]])
print(f"BLEU score: {bleu.score:.2f}")

BLEU score: 26.80


### 6.2 Beam Search decoding

In [29]:
# # Testing
# beam_preds = []
# for src, _ in test_loader:
#     decoded = generate_beam(model, src, PAD_ID, START_ID, END_ID, beam_size=4, alpha=0.6, device=DEVICE)
#     beam_preds += tokenizer.decode_batch(decoded.cpu().numpy(), skip_special_tokens=True)
#     break

In [30]:
start_time = time.time()

beam_preds = []
for src, _ in test_loader:
    preds = generate_beam(model, src, PAD_ID, START_ID, END_ID, beam_size=4, alpha=0.6, device=DEVICE)
    decoded = tokenizer.decode_batch(preds.cpu().numpy(), skip_special_tokens=True)
    beam_preds += decoded

delta_time = (time.time() - start_time)    
print(f"Beam search decoding finished in {delta_time:.2f} s")

Beam search decoding finished in 205.47 s


In [31]:
bleu = sacrebleu.corpus_bleu(beam_preds, [test_dataset.data["translation.en"]])
print(f"BLEU score: {bleu.score:.2f}")

BLEU score: 27.50
