# Datasets to test on
- WMT en-fr
- WMT en-gr
- WikiText-103

In [1]:
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
from torch import nn
from torch.utils.data import DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence
from tokenizers import Tokenizer
from sklearn.metrics import f1_score

In [3]:
from datasets import load_dataset

dataset = load_dataset('wmt14', 'fr-en')

train_data = dataset['train'].select(range(1000000))
test_data = dataset['test']
print(train_data[0])



  from .autonotebook import tqdm as notebook_tqdm


{'translation': {'en': 'Resumption of the session', 'fr': 'Reprise de la session'}}


In [4]:
def extract_text(dataset, src="en", tgt="fr"):
    for example in dataset:
        yield example["translation"][src]
        yield example["translation"][tgt]


In [5]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=32000,
    special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"]
)

tokenizer.train_from_iterator(
    extract_text(train_data),
    trainer=trainer
)







In [6]:
from sklearn.metrics import f1_score
from nltk.translate.bleu_score import corpus_bleu
import numpy as np

def decode_to_tokens(ids_list, tokenizer):
    """
    Converts a list of token IDs into a list of string tokens,
    skipping special tokens like [PAD], [BOS], [EOS].
    """
    # .decode() handles the skipping of special tokens automatically
    # .split() turns the sentence string into a list of tokens for BLEU
    decoded_text = tokenizer.decode(ids_list, skip_special_tokens=True)
    return decoded_text.split()

In [7]:
#tokenizer.save("wmt14_bpe.json")
PAD_ID = tokenizer.token_to_id("[PAD]")
BOS_ID = tokenizer.token_to_id("[BOS]")
EOS_ID = tokenizer.token_to_id("[EOS]")

def tokenize_sentence(sentence, add_special_tokens=True):
    encoding = tokenizer.encode(sentence)
    token_ids = encoding.ids

    if add_special_tokens:
        token_ids = [BOS_ID] + token_ids + [EOS_ID]

    return token_ids



In [8]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    """
    batch: list of dicts from wmt14 
    Example item: {'translation': {'en': 'Hello', 'fr': 'Bonjour'}}
    """
    src_list = []
    tgt_list = []
    
    for item in batch:
        en_text = item['translation']['en']
        fr_text = item['translation']['fr']
        
        src_ids = tokenize_sentence(en_text)[:128] 
        tgt_ids = tokenize_sentence(fr_text)[:128]
        
        src_list.append(torch.tensor(src_ids, dtype=torch.long))
        tgt_list.append(torch.tensor(tgt_ids, dtype=torch.long))

    src_padded = pad_sequence(
        src_list,
        batch_first=True,
        padding_value=PAD_ID
    )

    tgt_padded = pad_sequence(
        tgt_list,
        batch_first=True,
        padding_value=PAD_ID
    )

    return src_padded, tgt_padded

In [9]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_data,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

In [41]:
#def generate_compact_dataset(num_samples=10000):
#    data_tokens = torch.arange(1, 11) 
#    
#    all_inputs = []
#    all_targets = []
#
#    for _ in range(num_samples):
#        perm = torch.randperm(10)
#        sample_data = data_tokens[perm[:6]]
#        
#        is_relational = torch.rand(1) > 0.5
#        
#        if is_relational:
#            cmd = torch.tensor([12])
#            # Pick a key from the first 5 (so there is a neighbor at +1)
#            key_idx = torch.randint(0, 5, (1,)).item()
#            query = sample_data[key_idx].view(1)
#            target = sample_data[key_idx + 1]
#        else:
#            # POSITIONAL: Input[7] is an Index (1-6); Target is data at that index
#            cmd = torch.tensor([11])
#            idx_to_pull = torch.randint(0, 6, (1,)).item()
#            query = torch.tensor([idx_to_pull + 1])
#            target = sample_data[idx_to_pull]
#
#        full_input = torch.cat([sample_data, cmd, query])
#        
#        all_inputs.append(full_input)
#        all_targets.append(target)
#
#    return torch.stack(all_inputs), torch.stack(all_targets)
#
## Generate the 10,000 samples
#inputs, targets = generate_compact_dataset(10000)
#
#print(f"Dataset Shape: {inputs.shape}") # [10000, 8]
#print(f"Sample 0 (Input): {inputs[0].tolist()} -> Target: {targets[0].item()}")

In [42]:
#indices = torch.randperm(len(inputs))
#
#train_size = int(0.5*len(inputs))
#
#train_idx = indices[:train_size]
#test_idx = indices[train_size:]
#
#train_inputs, train_targets = inputs[train_idx], targets[train_idx]
#test_inputs,  test_targets  = inputs[test_idx],  targets[test_idx]

In [10]:
def get_pe(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(1000000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)

In [11]:
class embeddings(nn.Module):
    def __init__(self, d, vocab_size=32000, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d)
        self.register_buffer('pe', get_pe(max_len, d))

    def forward(self, x):
        te = self.token_emb(x)
        return te + self.pe[:, :x.size(1), :]

In [12]:
class relative_matrix(nn.Module):
    def __init__(self, e_dim: int):
        super().__init__()
        self.linear = nn.Linear(e_dim, e_dim)
        self.bias = nn.Parameter(torch.zeros(e_dim))

    def forward(self, x):
        x = self.linear(x)
        rel = x.unsqueeze(2) - x.unsqueeze(1) 

        out = rel.sum(dim=2) + self.bias
        return out

In [13]:
class attention_matrix(nn.Module):
    def __init__(self, e_dim):
        super().__init__()
        self.e_dim = e_dim

    def forward(self, x, y):
        return torch.softmax((x @ y.transpose(-2, -1))/math.sqrt(self.e_dim), -1)

In [14]:
class transformer(nn.Module):
    def __init__(self, e_dim: int, vocab_size=32000):
        super().__init__()
        self.embeddings = embeddings(e_dim, vocab_size)
        self.relative = relative_matrix(e_dim)
        self.attention = attention_matrix(e_dim)
        self.V = nn.Linear(e_dim, e_dim)
        self.norm1 = nn.LayerNorm(e_dim)
        self.norm2 = nn.LayerNorm(e_dim)
        self.dropout = nn.Dropout(0.1)
        self.mlp = nn.Sequential(
            nn.Linear(e_dim, e_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(e_dim * 4, e_dim)
        )
        self.e_dim = e_dim
        self.unembed = nn.Linear(e_dim, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        x_emb = self.embeddings(x) * math.sqrt(self.e_dim)
        rel_key = self.relative(x_emb)
        value = self.V(x_emb)
        attn_map = self.attention(x_emb, rel_key) 
        attn_out = attn_map @ value
        x = self.norm1(x_emb + self.dropout(attn_out))
        mlp_out = self.mlp(x)
        x = self.norm2(x + self.dropout(mlp_out))
        logits = self.unembed(x)
        return logits

# Test without scheduled LR

In [15]:
epochs = 5
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = transformer(512).to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=PAD_ID)

In [17]:
#def noam_lambda(step, d_model=512, warmup=4000):
#    # The scheduler passes the current_step (starting at 0)
#    # Adding 1 prevents division by zero
#    step += 1 
#    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup ** -1.5))

# Use 1.0 as the base LR so the lambda controls the absolute value
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.98), eps=1e-9, weight_decay = 0.01)
#scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)

In [18]:
model.to(DEVICE)
global_step = 0
log_interval = 1000
accumulation_steps = 4 

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad() 
    
    for i, (src, tgt) in enumerate(train_loader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        
        decoder_input  = tgt[:, :-1]  
        decoder_target = tgt[:, 1:] 
        logits = model(decoder_input) 

        # 1. Standard Loss Calculation
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            decoder_target.reshape(-1)
        )
        
        # Scale loss so gradients are averaged across the full effective batch
        loss_scaled = loss / accumulation_steps 
        loss_scaled.backward()

        # Train Accuracy (for the current mini-batch)
        with torch.no_grad():
            pred_tokens = logits.argmax(-1)
            mask = decoder_target != PAD_ID
            train_acc = ((pred_tokens == decoder_target) & mask).float().sum() / mask.sum()

        # 2. Only update weights and LOG every 'accumulation_steps'
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            # scheduler.step() # Uncomment if using Noam
            optimizer.zero_grad()
            
            global_step += 1
            
            # --- EVALUATION BLOCK: Now strictly inside the update check ---
            if global_step % log_interval == 0:
                model.eval()
                total_test_loss = 0
                all_preds_ids, all_trues_ids = [], []
                bleu_hypotheses, bleu_references = [], []
                
                with torch.no_grad():
                    for test_src, test_tgt in test_loader:
                        test_src, test_tgt = test_src.to(DEVICE), test_tgt.to(DEVICE)
                        dec_in_test, dec_tgt_test = test_tgt[:, :-1], test_tgt[:, 1:]
                        
                        test_logits = model(dec_in_test)
                        t_loss = criterion(test_logits.reshape(-1, test_logits.size(-1)), dec_tgt_test.reshape(-1))
                        total_test_loss += t_loss.item()

                        test_pred_ids = test_logits.argmax(-1)
                        test_mask = (dec_tgt_test != PAD_ID)
                        
                        all_preds_ids.extend(test_pred_ids[test_mask].cpu().numpy())
                        all_trues_ids.extend(dec_tgt_test[test_mask].cpu().numpy())

                        for b in range(test_pred_ids.size(0)):
                            hyp_tokens = decode_to_tokens(test_pred_ids[b].tolist(), tokenizer)
                            ref_tokens = decode_to_tokens(dec_tgt_test[b].tolist(), tokenizer)
                            bleu_hypotheses.append(hyp_tokens)
                            bleu_references.append([ref_tokens])

                # Metric Calculations
                avg_test_loss = total_test_loss / len(test_loader)
                test_acc = (np.array(all_preds_ids) == np.array(all_trues_ids)).mean()
                test_f1 = f1_score(all_trues_ids, all_preds_ids, average='weighted', zero_division=0)
                test_bleu = corpus_bleu(bleu_references, bleu_hypotheses)
                current_lr = optimizer.param_groups[0]['lr']
                
                output_str = (
                    f"| Step: {global_step:7d} | Epoch: {epoch:3d} | LR: {current_lr:.8f} |\n"
                    f"| Train Acc: {train_acc.item():.4f} | Train Loss: {loss.item():.4f} |\n"
                    f"| Test Acc:  {test_acc:.4f} | Test Loss:  {avg_test_loss:.4f} |\n"
                    f"| Test F1:   {test_f1:.4f} | Test BLEU:  {test_bleu:.4f} |\n"
                    f"{'-'*75}"
                )
                print(output_str)
                with open("step_results_logging.log", "a") as f:
                    f.write(output_str + "\n")
                
                model.train() # Switch back to training mode

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


| Step:    1000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.0928 | Train Loss: 6.9658 |
| Test Acc:  0.1076 | Test Loss:  7.6481 |
| Test F1:   0.0725 | Test BLEU:  0.0000 |
---------------------------------------------------------------------------


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


| Step:    2000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.1362 | Train Loss: 6.8205 |
| Test Acc:  0.1158 | Test Loss:  7.6387 |
| Test F1:   0.0726 | Test BLEU:  0.0000 |
---------------------------------------------------------------------------


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


| Step:    3000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.1163 | Train Loss: 6.7675 |
| Test Acc:  0.1150 | Test Loss:  7.5443 |
| Test F1:   0.0732 | Test BLEU:  0.0000 |
---------------------------------------------------------------------------


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


| Step:    4000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.1149 | Train Loss: 6.7400 |
| Test Acc:  0.1185 | Test Loss:  7.4861 |
| Test F1:   0.0602 | Test BLEU:  0.0000 |
---------------------------------------------------------------------------
| Step:    5000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.1304 | Train Loss: 6.4409 |
| Test Acc:  0.1232 | Test Loss:  7.4052 |
| Test F1:   0.0782 | Test BLEU:  0.0006 |
---------------------------------------------------------------------------
| Step:    6000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.1142 | Train Loss: 6.4474 |
| Test Acc:  0.1245 | Test Loss:  7.3393 |
| Test F1:   0.0829 | Test BLEU:  0.0007 |
---------------------------------------------------------------------------
| Step:    7000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.1621 | Train Loss: 6.1763 |
| Test Acc:  0.1355 | Test Loss:  7.2472 |
| Test F1:   0.0906 | Test BLEU:  0.0007 |
----------------------------------------------------------------

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


| Step:   19000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1738 | Train Loss: 5.9245 |
| Test Acc:  0.1499 | Test Loss:  7.0895 |
| Test F1:   0.1036 | Test BLEU:  0.0000 |
---------------------------------------------------------------------------
| Step:   20000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1621 | Train Loss: 6.0374 |
| Test Acc:  0.1475 | Test Loss:  7.0787 |
| Test F1:   0.1029 | Test BLEU:  0.0014 |
---------------------------------------------------------------------------
| Step:   21000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1545 | Train Loss: 5.8794 |
| Test Acc:  0.1502 | Test Loss:  7.0587 |
| Test F1:   0.1015 | Test BLEU:  0.0011 |
---------------------------------------------------------------------------
| Step:   22000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1851 | Train Loss: 5.8509 |
| Test Acc:  0.1508 | Test Loss:  7.0950 |
| Test F1:   0.1010 | Test BLEU:  0.0010 |
----------------------------------------------------------------

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


| Step:   27000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1805 | Train Loss: 6.0552 |
| Test Acc:  0.1527 | Test Loss:  7.0344 |
| Test F1:   0.1036 | Test BLEU:  0.0000 |
---------------------------------------------------------------------------
| Step:   28000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1850 | Train Loss: 5.9360 |
| Test Acc:  0.1536 | Test Loss:  7.0551 |
| Test F1:   0.1041 | Test BLEU:  0.0012 |
---------------------------------------------------------------------------
| Step:   29000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1760 | Train Loss: 5.9262 |
| Test Acc:  0.1530 | Test Loss:  7.0852 |
| Test F1:   0.1049 | Test BLEU:  0.0008 |
---------------------------------------------------------------------------
| Step:   30000 | Epoch:   1 | LR: 0.00030000 |
| Train Acc: 0.1898 | Train Loss: 5.8565 |
| Test Acc:  0.1532 | Test Loss:  7.0356 |
| Test F1:   0.1018 | Test BLEU:  0.0009 |
----------------------------------------------------------------

In [None]:
torch.save(model.state_dict(), 'model_weights.pth')