In [43]:
# Imports
import pandas as pd
import sentencepiece
import kagglehub
import os
import re
import math
import time
import unicodedata
import sentencepiece as spm
import random
from collections import Counter
from itertools import islice
import numpy as np
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from torch.utils.data import random_split

pd.set_option('display.max_colwidth', None)

# Creating corpus

In [11]:
corpus_path = "corpus/"
english_files = ["commoncrawl.de-en.en", "europarl-v7.de-en.en"]
german_files = ["commoncrawl.de-en.de", "europarl-v7.de-en.de"]

In [12]:
english_sentences = []
german_sentences = []

for i in range(len(english_files)):
    with open(corpus_path + english_files[i], 'r', encoding='utf-8') as f:
        english_sentences.extend(f.read().splitlines())
    print(len(english_sentences))
    
    with open(corpus_path + german_files[i], 'r', encoding='utf-8') as f:
        german_sentences.extend(f.read().splitlines())
    print(len(german_sentences))

if len(english_sentences) != len(german_sentences):
    raise ValueError("Mismatch in number of sentences between English and German files.")


2399123
2399123
4319332
4319332


In [13]:
# News commentary are misaligned and therefore a different approach is taken
english_folder = corpus_path + "news_commentary/English"
german_folder = corpus_path + "news_commentary/German"

english_files = sorted(os.listdir(english_folder))
german_files = sorted(os.listdir(german_folder))

if len(english_files) != len(german_files):
    raise ValueError("Mismatch in number of files between English and German folders for news commentary.")

mismatched_files = []

for eng_file, ger_file in zip(english_files, german_files):
    eng_path = os.path.join(english_folder, eng_file)
    ger_path = os.path.join(german_folder, ger_file)

    # Ensuring same name for each file tuple
    if os.path.splitext(eng_file)[0] != os.path.splitext(ger_file)[0]:
        raise ValueError(f"File mismatch: {eng_file} and {ger_file} do not correspond.")

    with open(eng_path, 'r', encoding='utf-8') as f:
        english_section = [segment.strip() for segment in f.read().split('<P>') if segment.strip()]

    with open(ger_path, 'r', encoding='utf-8') as f:
        german_section = [segment.strip() for segment in f.read().split('<P>') if segment.strip()]

    if len(english_section) != len(german_section):
        mismatched_files.append((eng_file, ger_file, len(english_section), len(german_section)))
    else:
        # If line counts match, add sentences to the lists
        english_sentences.extend(english_section)
        german_sentences.extend(german_section)

print(f"Files with mistmatched number of lines: {mismatched_files}")

if len(english_sentences) != len(german_sentences):
    raise ValueError("Mismatch in number of sentences between English and German files.")

Files with mistmatched number of lines: []


In [14]:
print(f"All together there are {len(english_sentences)} english sentences in the corpus.")
print(f"All together there are {len(german_sentences)} german sentences in the corpus.")

All together there are 4394358 english sentences in the corpus.
All together there are 4394358 german sentences in the corpus.


In [15]:
data = {'en': english_sentences, 'de': german_sentences}
df = pd.DataFrame(data)
display(df)

Unnamed: 0,en,de
0,iron cement is a ready for use paste which is laid as a fillet by putty knife or finger in the mould edges (corners) of the steel ingot mould.,"iron cement ist eine gebrauchs-fertige Paste, die mit einem Spachtel oder den Fingern als Hohlkehle in die Formecken (Winkel) der Stahlguss -Kokille aufgetragen wird."
1,"iron cement protects the ingot against the hot, abrasive steel casting process.","Nach der Aushärtung schützt iron cement die Kokille gegen den heissen, abrasiven Stahlguss ."
2,"a fire restant repair cement for fire places, ovens, open fireplaces etc.","feuerfester Reparaturkitt für Feuerungsanlagen, Öfen, offene Feuerstellen etc."
3,Construction and repair of highways and...,Der Bau und die Reparatur der Autostraßen...
4,An announcement must be commercial character.,die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.
...,...,...
4394353,"The stakes for Africa are enormous.\nSouth Africa has the continent’s largest economy and, until the global financial crisis, posted 10 years of steady economic growth.\nIn an economic slowdown, the country’s severe crime problem might only worsen; so might unemployment, which already tops 20% in the formal economy.","Für Afrika steht Enormes auf dem Spiel.\nSüdafrika ist die größte Ökonomie des Kontinents und bis zur globalen Finanzkrise erlebte man 10 Jahre beständigen Wirtschaftswachstums.\nIn Zeiten des Abschwungs kann sich das immense Kriminalitätsproblem des Landes nur verschärfen. Das gilt auch für die Arbeitslosigkeit, die im Bereich der offiziellen Wirtschaft bereits über 20 Prozent liegt."
4394354,"Zuma senses the urgency of the situation.\nHe is, after all, 67 years old and likely to serve only a single term in office. “We can’t waste time,” he says.","Zuma weiß um die Dringlichkeit der Situation.\nImmerhin ist er 67 Jahre alt und wird wahrscheinlich nur eine Amtszeit dienen. „Wir können uns keine Zeitverschwendung leisten“, sagt er."
4394355,"Yet, according to the political economist Moeletsi Mbeki, at his core, “Zuma is a conservative.” In this sense, Zuma represents yesterday’s South Africa.\nHe is part of the proud generation that defeated apartheid – and then peacefully engineered a transition to durable black-majority rule.\nTheir achievement remains one of the greatest in recent history.","Dem politischen Ökonomen Moeletsi Mbeki zufolge, ist Zuma im Grunde seines Herzens „ein Konservativer“. In diesem Sinne vertritt Zuma das Südafrika von gestern.\nEr ist Mitglied einer stolzen Generation, die die Apartheid bezwang – und der anschließend ein friedlicher Übergang zu einer schwarzen Mehrheitsregierung gelang.\nDas bleibt eine der größten Errungenschaften in der jüngeren Geschichte."
4394356,"At the same time, Zuma’s revolutionary generation still seems uneasy leading South Africa in a post-apartheid era that is now 15 years old.\nIn a region that reveres the elderly, Zuma’s attachment to his rural traditions must be matched by an equal openness to the appetites of the country’s youth.","Gleichzeitig scheint sich Zumas revolutionäre Generation mit der Führung Südafrikas in der nun seit 15 Jahren dauernden Ära nach der Apartheid noch immer unwohl zu fühlen.\nIn einer Region, wo die älteren Menschen sehr verehrt werden, muss Zumas Bindung an landestypische Traditionen eine gleichwertige Offenheit gegenüber den Bedürfnissen der Jugend des Landes gegenüberstehen."


# Normalizing the corpus
- removing newlines large texts.

In [16]:
def normalize_text(text):
        # Converting to lowercase
        #text = text.lower()
    
        # Normalizing unicode chars
        #text = unicodedata.normalize('NFKC', text)
        
        # Remove leading/trailing and extra spaces
        #text = re.sub(r'\s+', ' ', text.strip())
    
        # Add spaces around punctuation
        #text = re.sub(r'([.,!?;:()\-])', r' \1 ', text)
    
        # Replace newline characters with a placeholder
        text = text.replace('\n', ' ')
    
        # Remove non-textual artifacts (allow only alphanumerics and selected punctuations)
        #text = re.sub(r'[^a-zA-Z0-9.,!?;:()\- ]', '', text)
        return text

In [17]:
df_normalized = df.copy(deep=True)
df_normalized['en'] = df_normalized['en'].apply(normalize_text)
df_normalized['de'] = df_normalized['de'].apply(normalize_text)
df_normalized.to_csv("df_normalized.csv", index=False)

In [18]:
display(df_normalized.iloc[4394353])

en                                                                          The stakes for Africa are enormous. South Africa has the continent’s largest economy and, until the global financial crisis, posted 10 years of steady economic growth. In an economic slowdown, the country’s severe crime problem might only worsen; so might unemployment, which already tops 20% in the formal economy.
de    Für Afrika steht Enormes auf dem Spiel. Südafrika ist die größte Ökonomie des Kontinents und bis zur globalen Finanzkrise erlebte man 10 Jahre beständigen Wirtschaftswachstums. In Zeiten des Abschwungs kann sich das immense Kriminalitätsproblem des Landes nur verschärfen. Das gilt auch für die Arbeitslosigkeit, die im Bereich der offiziellen Wirtschaft bereits über 20 Prozent liegt.
Name: 4394353, dtype: object

# Saving corpus for BPE model training

In [19]:
#df_normalized['en_length'] = df_normalized['en'].apply(lambda x: len(x.split()))
#df_normalized['de_length'] = df_normalized['de'].apply(lambda x: len(x.split()))
#df_normalized['total_length'] = df_normalized['en_length'] + df_normalized['de_length']
#df_normalized['en_tokenized'] = df_normalized['en'].apply(lambda x: list(x))
#df_normalized['de_tokenized'] = df_normalized['de'].apply(lambda x: list(x))

# display(df_normalized)

# To create corpus text file
# corpus_df = pd.concat([df_normalized['en'], df_normalized['de']])

# corpus_df.to_csv("corpus.txt", index=False, header=False)

In [20]:
# Train BPE tokenizer

In [68]:
df_normalized = pd.read_csv("df_normalized.csv")

In [None]:
sp = spm.SentencePieceProcessor()
sp.load('bpe_model.model')

vocab_size = sp.get_piece_size()
sb_vocab = [sp.id_to_piece(i) for i in range(vocab_size)]
print(sb_vocab[:10])
sb_vocab_dict = {sb_vocab[i]: i for i in range(vocab_size)}
print(sb_vocab_dict)

In [46]:
# Tokenize using SentencePiece
tokens = sp.encode_as_pieces("—Ich fahre manchmal nach meine Familie, sie sind virklich überrascht.")
print(tokens)
tokens = sp.encode("Ich fahre manchmal nach meine Familie, sie sind virklich überrascht.")
print(tokens)
tokens = sp.decode_pieces(['▁Ich', '▁f', 'ahre', '▁manchmal', '▁nach', '▁meine', '▁Familie', ',', '▁sie', '▁sind', '▁v', 'irk', 'lich', '▁überrascht', '.'])
print(tokens)
tokens = sp.decode([1361, 33, 1724, 10072, 459, 1938, 5720, 36953, 416, 321, 55, 11279, 116, 17860, 36954])
print(tokens)

['▁', '—', 'Ich', '▁f', 'ahre', '▁manchmal', '▁nach', '▁meine', '▁Familie', ',', '▁sie', '▁sind', '▁v', 'irk', 'lich', '▁überrascht', '.']
[1376, 33, 1742, 10181, 459, 1960, 5773, 36926, 416, 321, 55, 11399, 116, 18093, 36927]
Ich fahre manchmal nach meine Familie, sie sind virklich überrascht.
form f pers Fraktionen nach day Conventionö sie sind v liefernlichzweifL


In [42]:
print(sb_vocab_dict['—'])

KeyError: '—'

In [47]:
# Function to encode a sentence
def encode_sentence(sentence):
    return sp.encode_as_pieces(sentence)

# Apply the encoder to each column
df_encoded = df_normalized.applymap(encode_sentence).copy(deep=True)

# Save or inspect the new DataFrame
display(df_encoded)
print(type(df_encoded['en'][0]))
#df_encoded.to_csv("df_encoded.csv", index=False)

Unnamed: 0,en,de
0,"[▁iron, ▁c, ement, ▁is, ▁a, ▁ready, ▁for, ▁use, ▁paste, ▁which, ▁is, ▁laid, ▁as, ▁a, ▁fil, let, ▁by, ▁put, ty, ▁kn, ife, ▁or, ▁finger, ▁in, ▁the, ▁mould, ▁edges, ▁(, c, orn, ers, ), ▁of, ▁the, ▁steel, ▁ing, ot, ▁mould, .]","[▁iron, ▁c, ement, ▁ist, ▁eine, ▁gebrauch, s, -, fert, ige, ▁P, aste, ,, ▁die, ▁mit, ▁einem, ▁Sp, ach, tel, ▁oder, ▁den, ▁F, ing, ern, ▁als, ▁H, ohl, ke, hle, ▁in, ▁die, ▁For, me, cken, ▁(, W, inkel, ), ▁der, ▁Stahl, g, uss, ▁-, K, ok, ille, ▁aufge, tragen, ▁wird, .]"
1,"[▁iron, ▁c, ement, ▁protects, ▁the, ▁ing, ot, ▁against, ▁the, ▁hot, ,, ▁ab, ras, ive, ▁steel, ▁casting, ▁process, .]","[▁Nach, ▁der, ▁Aus, här, tung, ▁schützt, ▁iron, ▁c, ement, ▁die, ▁Kok, ille, ▁gegen, ▁den, ▁he, issen, ,, ▁ab, ras, iven, ▁Stahl, g, uss, ▁.]"
2,"[▁a, ▁fire, ▁rest, ant, ▁repair, ▁c, ement, ▁for, ▁fire, ▁places, ,, ▁o, vens, ,, ▁open, ▁fire, places, ▁etc, .]","[▁fe, uer, f, ester, ▁Reparatur, k, itt, ▁für, ▁Feuer, ungs, anlagen, ,, ▁Ö, fen, ,, ▁offene, ▁Fe, u, erstellen, ▁etc, .]"
3,"[▁Construction, ▁and, ▁repair, ▁of, ▁high, ways, ▁and, ...]","[▁Der, ▁Bau, ▁und, ▁die, ▁Reparatur, ▁der, ▁Aut, ost, ra, ßen, ...]"
4,"[▁An, ▁announcement, ▁must, ▁be, ▁commercial, ▁character, .]","[▁die, ▁Mitteilungen, ▁sollen, ▁den, ▁geschäftlichen, ▁kommerziellen, ▁Charakter, ▁tragen, .]"
...,...,...
95,"[▁Est, ablish, ed, ▁in, ▁1990,, ▁the, ▁office, ▁of, ▁Ha, ide, g, ger, ▁&, ▁Partner, ▁in, ▁Budapest, ▁has, ▁been, ▁providing, ▁a, ▁full, ▁range, ▁of, ▁legal, ▁services, ▁offering, ▁individual, ▁tailored, ▁advice, .]","[▁Die, ▁im, ▁Jahre, ▁1990, ▁gegründete, ▁K, anzlei, ▁Ha, ide, g, ger, ▁&, ▁Partner, ▁bietet, ▁eine, ▁alle, ▁Rechts, gebiete, ▁umfassende, ,, ▁auf, ▁den, ▁individuellen, ▁Bedarf, ▁zugeschnitten, e, ▁Rechts, beratung, .]"
96,"[▁Apart, ▁from, ▁being, ▁Hungary, ’, s, ▁principal, ▁political, ,, ▁commercial, ,, ▁industrial, ▁and, ▁transportation, ▁centre, ,, ▁the, ▁city, ▁of, ▁Budapest, ▁boasts, ▁sites, ,, ▁monuments, ▁and, ▁sp, as, ▁of, ▁worldwide, ▁ren, own, .]","[▁Budapest, ▁ist, ▁nicht, ▁nur, ▁das, ▁politische, ,, ▁wirtschaftliche, ,, ▁industrielle, ▁und, ▁verkehr, stechn, ische, ▁Herz, ▁Ungarn, s, ,, ▁sondern, ▁r, ühm, t, ▁sich, ▁auch, ▁weltweit, ▁bekannter, ▁Sehens, würdigkeiten, ,, ▁Denkm, äler, ▁und, ▁Bäder, .]"
97,"[▁This, ▁statistic, ▁is, ▁based, ▁on, ▁the, ▁68, 19, ▁using, ▁ecommerce, ▁sites, ▁(, esh, ops, ,, ▁distributors, ,, ▁comparison, ▁sites, ,, ▁ecommerce, ▁ASPs, ,, ▁purchase, ▁systems, ,, ▁etc, ), ▁downloading, ▁this, ▁ICEcat, ▁data, -, sheet, ▁since, ▁19, ▁Oct, ▁2007.]","[▁Diese, ▁Statistik, ▁basiert, ▁auf, ▁den, ▁teilnehmenden, ▁E, -, C, ommer, ces, eiten, ▁68, 19, ▁(, E, -, Shops, ,, ▁Distrib, ut, oren, ,, ▁Vergleich, ss, eiten, ,, ▁E, -, Commerce, ▁ASPs, ,, ▁Einkauf, ssysteme, ▁etc, ),, ▁welche, ▁dieses, ▁ICEcat, ▁Daten, blatt, ▁täglich, ▁seit, ▁dem, ▁21., ▁März, ▁2009, ▁herunterladen, ., ▁19, ▁Okt, ▁2007.]"
98,"[▁Only, ▁spons, oring, ▁brands, ▁are, ▁included, ▁in, ▁the, ▁free, ▁Open, ▁ICEcat, ▁content, ▁distribution, ▁as, ▁used, ▁by, ▁63, 28, ▁free, ▁Open, ▁ICEcat, ▁users, ▁.]","[▁Nur, ▁Spons, orm, ark, en, ▁sind, ▁in, ▁der, ▁kostenfreien, ▁Open, ▁ICEcat, ▁Content, ▁Verteilung, ▁vertreten, ▁63, 28, ▁und, ▁werden, ▁von, ▁Open, ▁ICEcat, ▁Nutzern, ▁genutzt, ..]"


<class 'list'>


In [48]:
class TranslationDataset(Dataset):
    def __init__(self, dataframe, vocab, start_token="<s>", end_token="</s>", pad_token="<mask>"):
        """
        Initialize the translation dataset.
        
        Args:
            dataframe: pandas DataFrame with 'en' and 'de' columns containing tokenized sentences
            vocab: shared vocabulary from BPE model
            start_token: token to mark the start of sentences
            end_token: token to mark the end of sentences
            pad_token: token used for padding
        """
        self.en_sentences = dataframe['en'].tolist()
        self.de_sentences = dataframe['de'].tolist()
        self.vocab = vocab
        self.start_token = start_token
        self.end_token = end_token
        self.pad_token = pad_token
        
    def __len__(self):
        return len(self.en_sentences)
    
    def __getitem__(self, idx):
        # Add special tokens to both source and target sentences
        en_tokens = [self.start_token] + self.en_sentences[idx] + [self.end_token]
        de_tokens = [self.start_token] + self.de_sentences[idx] + [self.end_token]
        
        return {
            'en': en_tokens,
            'de': de_tokens
        }

In [49]:
def collate_batch(batch, vocab):
    """
    Custom collate function to create batches with padding.
    
    Args:
        batch: list of dictionaries containing source and target tokens
        vocab: shared vocabulary from BPE model
    
    Returns:
        Dictionary containing padded and converted tensor sequences
    """
    en_sequences = []
    de_sequences = []
    
    for item in batch:
        en_indices = torch.tensor([vocab.get(token, vocab['<unk>']) for token in item['en']])
        de_indices = torch.tensor([vocab.get(token, vocab['<unk>']) for token in item['de']])
        
        en_sequences.append(en_indices)
        de_sequences.append(de_indices)
    
    # Pad sequences to the longest sequence in the batch
    en_padded = pad_sequence(en_sequences, batch_first=True, padding_value=vocab['<mask>'])
    de_padded = pad_sequence(de_sequences, batch_first=True, padding_value=vocab['<mask>'])
    
    return {
        'en': en_padded,
        'de': de_padded,
        'en_lengths': torch.tensor([len(seq) for seq in en_sequences]),
        'de_lengths': torch.tensor([len(seq) for seq in de_sequences])
    }

In [50]:
def create_dataloader(dataframe, vocab, batch_size=32, shuffle=True):
    """
    Create a DataLoader for the translation dataset.
    
    Args:
        dataframe: pandas DataFrame with tokenized sentences
        vocab: shared vocabulary from BPE model
        batch_size: size of batches
        shuffle: whether to shuffle the data
    
    Returns:
        DataLoader object
    """
    dataset = TranslationDataset(dataframe, vocab)
    
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda b: collate_batch(b, vocab)
    )

In [51]:
dataloader = create_dataloader(
    df_encoded,
    sb_vocab_dict,
    batch_size=32
)

In [52]:
for batch in dataloader:
    print(batch)

{'en': tensor([[    1, 36049,    58,  ...,     5,     5,     5],
        [    1,  1898,   283,  ...,     5,     5,     5],
        [    1, 11579, 20834,  ...,     5,     5,     5],
        ...,
        [    1,   290, 33479,  ...,     5,     5,     5],
        [    1, 27870,  4955,  ...,     5,     5,     5],
        [    1,   290, 33479,  ...,     5,     5,     5]]), 'de': tensor([[    1,  1083,  4202,  ...,     5,     5,     5],
        [    1,  7404, 36940,  ...,     5,     5,     5],
        [    1,  7404, 36940,  ...,     5,     5,     5],
        ...,
        [    1,  2374, 13618,  ...,     5,     5,     5],
        [    1, 27870,  5202,  ...,  2281, 36927,     2],
        [    1, 21913,  2983,  ...,     5,     5,     5]]), 'en_lengths': tensor([10, 26, 50, 57, 33, 23, 16, 30, 22, 25, 29, 18, 49, 23, 29, 17, 31, 25,
        16, 11, 27, 25, 41, 39, 23, 17, 32, 25, 32, 54, 49, 23]), 'de_lengths': tensor([13, 22, 40, 47, 27, 29, 14, 44, 21, 24, 55, 20, 55, 24, 37, 22, 32, 23,
       

In [53]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Calculate positional encodings
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        
        # Embedding layers
        self.encoder_embedding = nn.Embedding(vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        
        # Transformer layers
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        # Output layer
        self.output_layer = nn.Linear(d_model, vocab_size)
        
        # Initialize parameters
        self._init_parameters()
        
        self.d_model = d_model
    
    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def create_mask(self, src, tgt):
        # Create source padding mask
        src_mask = src == 0  # Assuming 0 is the padding index
        
        # Create target padding mask
        tgt_mask = tgt == 0
        
        # Create target subsequent mask (for autoregressive property)
        seq_len = tgt.size(1)
        subsequent_mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1).bool()
        subsequent_mask = subsequent_mask.to(tgt.device)
        
        return src_mask, tgt_mask, subsequent_mask
    
    def forward(self, src, tgt):
        # Create masks
        src_key_padding_mask, tgt_key_padding_mask, tgt_mask = self.create_mask(src, tgt)
        
        # Embed and add positional encoding
        src_embedded = self.positional_encoding(self.encoder_embedding(src) * math.sqrt(self.d_model))
        tgt_embedded = self.positional_encoding(self.decoder_embedding(tgt) * math.sqrt(self.d_model))
        
        # Transform
        output = self.transformer(
            src_embedded,
            tgt_embedded,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        
        # Project to vocabulary size
        return self.output_layer(output)

In [65]:
def evaluate(model, val_dataloader, criterion, vocab_size, device):
    """
    Evaluate the model on validation data.
    
    Args:
        model: Transformer model
        val_dataloader: DataLoader for validation data
        criterion: Loss function
        vocab_size: Size of vocabulary
        device: Device to evaluate on
    
    Returns:
        Average validation loss
    """
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for batch in val_dataloader:
            src = batch['en'].to(device)
            tgt = batch['de'].to(device)
            
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            
            output = model(src, tgt_input)
            
            # Calculate loss
            loss = criterion(output.contiguous().view(-1, vocab_size), 
                           tgt_output.contiguous().view(-1))
            
            # Count non-padding tokens
            non_pad_tokens = (tgt_output != 0).sum().item()
            
            total_loss += loss.item() * non_pad_tokens
            total_tokens += non_pad_tokens
    
    return total_loss / total_tokens

def train_transformer(model, train_dataloader, val_dataloader, vocab_size, num_epochs, 
                     save_path, save_interval, patience=5, device='cuda'):
    """
    Train the transformer model with validation.
    
    Args:
        model: Transformer model
        train_dataloader: DataLoader for training data
        val_dataloader: DataLoader for validation data
        vocab_size: Size of vocabulary
        num_epochs: Number of training epochs
        save_path: Path to save model checkpoints
        save_interval: Save model every N iterations
        patience: Number of epochs to wait for improvement before early stopping
        device: Device to train on
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index
    optimizer = torch.optim.Adam(model.parameters(), lr=0.044, betas=(0.9, 0.98), eps=1e-9)
    
    # Learning rate scheduler
    def lr_lambda(step):
        warmup_steps = 4000
        step = max(1, step)
        return min(step ** (-0.5), step * warmup_steps ** (-1.5))
    
    scheduler = LambdaLR(optimizer, lr_lambda)
    
    # Training tracking
    global_step = 0
    start_time = time.time()
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    best_model_state = None
    
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        num_train_batches = 0

    
        for batch_idx, batch in enumerate(train_dataloader):
            src = batch['en'].to(device)
            tgt = batch['de'].to(device)
            
            # Shift target for teacher forcing
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            
            # Forward pass
            optimizer.zero_grad()
            output = model(src, tgt_input)
            
            # Calculate loss
            loss = criterion(output.contiguous().view(-1, vocab_size), 
                           tgt_output.contiguous().view(-1))
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            
            # Update metrics
            total_train_loss += loss.item()
            num_train_batches += 1
            global_step += 1
            
            # Save checkpoint
            # if global_step % save_interval == 0:
            #     checkpoint = {
            #         'epoch': epoch,
            #         'global_step': global_step,
            #         'model_state_dict': model.state_dict(),
            #         'optimizer_state_dict': optimizer.state_dict(),
            #         'scheduler_state_dict': scheduler.state_dict(),
            #         'loss': loss.item()
            #     }
            #     torch.save(checkpoint, f'{save_path}/checkpoint_{global_step}.pt')
            
            # Log progress
            if batch_idx % 100 == 0:
                elapsed = time.time() - start_time
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}, '
                      f'Time: {elapsed:.2f}s, Learning Rate: {scheduler.get_last_lr()[0]:.7f}')
        
        # Validation phase
        val_loss = evaluate(model, val_dataloader, criterion, vocab_size, device)
        avg_train_loss = total_train_loss / num_train_batches
        
        # Log epoch results
        print(f'Epoch {epoch} completed:')
        print(f'  Average Train Loss: {avg_train_loss:.4f}')
        print(f'  Validation Loss: {val_loss:.4f}')
        
        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            best_model_state = model.state_dict()
            
            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
                'train_loss': avg_train_loss
            }, f'{save_path}/best_model.pt')
        else:
            epochs_without_improvement += 1
        
        # Early stopping
        if epochs_without_improvement >= patience:
            print(f'Early stopping triggered after {epoch + 1} epochs')
            model.load_state_dict(best_model_state)  # Restore best model
            break
    
    return model, best_val_loss

def create_train_val_dataloaders(dataset, batch_size, vocab, val_split=0.1, shuffle=True):
    val_length = int(len(dataset) * val_split)
    train_length = len(dataset) - val_length
    
    # Split dataset
    train_dataset, val_dataset = random_split(
        dataset, 
        [train_length, val_length],
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )
    
    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda b: collate_batch(b, vocab)
    )
    
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,  # No need to shuffle validation data
        collate_fn=lambda b: collate_batch(b, vocab)
    )

    print(len(train_dataloader))
    print(len(val_dataloader))
    
    return train_dataloader, val_dataloader

# Example usage
def create_and_train_transformer(dataset, vocab_size, vocab, save_path, batch_size=32, device='cuda'):
    """
    Create and train the transformer model with validation.
    
    Args:
        dataset: The full dataset
        vocab_size: Size of vocabulary
        save_path: Path to save model checkpoints
        batch_size: Batch size for training
        device: Device to train on
    """
    # Create train and validation dataloaders
    train_dataloader, val_dataloader = create_train_val_dataloaders(
        dataset,
        batch_size=batch_size,
        vocab=vocab,
        val_split=0.1
    )
    
    # Initialize model
    model = Transformer(
        vocab_size=vocab_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1
    )
    
    # Train model
    model, best_val_loss = train_transformer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        vocab_size=vocab_size,
        num_epochs=100,  # Adjust as needed
        save_path=save_path,
        save_interval=1000,  # Save every 1000 steps
        patience=5,  # Early stopping patience
        device=device
    )
    
    return model, best_val_loss

In [66]:
save_path = 'checkpoints'  # Make sure this directory exists
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [67]:
dataset = TranslationDataset(df_encoded, sb_vocab)

model, best_val_loss = create_and_train_transformer(
    dataset=dataset,  # Your full dataset
    vocab_size=vocab_size,
    vocab=sb_vocab_dict,
    save_path=save_path,
    batch_size=1,
    device=device
)

90
10
Epoch: 0, Batch: 0, Loss: 10.4878, Time: 1.37s, Learning Rate: 0.0000002
Epoch 0 completed:
  Average Train Loss: 10.3666
  Validation Loss: 10.1594
Epoch: 1, Batch: 0, Loss: 10.0400, Time: 97.84s, Learning Rate: 0.0000158
Epoch 1 completed:
  Average Train Loss: 9.6079
  Validation Loss: 9.4577
Epoch: 2, Batch: 0, Loss: 8.9916, Time: 178.58s, Learning Rate: 0.0000315
Epoch 2 completed:
  Average Train Loss: 8.2168
  Validation Loss: 8.6608
Epoch: 3, Batch: 0, Loss: 7.2922, Time: 257.63s, Learning Rate: 0.0000471
Epoch 3 completed:
  Average Train Loss: 7.0242
  Validation Loss: 8.6421
Epoch: 4, Batch: 0, Loss: 6.8194, Time: 336.36s, Learning Rate: 0.0000628
Epoch 4 completed:
  Average Train Loss: 6.6230
  Validation Loss: 8.9947
Epoch: 5, Batch: 0, Loss: 6.5465, Time: 411.36s, Learning Rate: 0.0000784
Epoch 5 completed:
  Average Train Loss: 6.5254
  Validation Loss: 9.3673
Epoch: 6, Batch: 0, Loss: 6.0017, Time: 486.78s, Learning Rate: 0.0000941
Epoch 6 completed:
  Average Tr

In [None]:
sp = spm.SentencePieceProcessor()
sp.load('bpe_model.model')

# Create a batch of encoded sequences
encoded_en, encoded_de = create_sample_batch(df_normalized, sp)