# Config & Imports

In [None]:
!pip install -U datasets huggingface_hub fsspec

Collecting fsspec
  Using cached fsspec-2025.7.0-py3-none-any.whl.metadata (12 kB)


In [None]:
!pip install torchinfo



In [None]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torchinfo

from tqdm import tqdm
import math
from datasets import load_dataset
from pathlib import Path
import pandas as pd
import numpy as np
import altair as alt

from tokenizers.models import WordPiece
from tokenizers import Tokenizer
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def get_config():
    """Return a freshly built dict with the notebook's hyperparameters and file locations."""
    config = dict(
        # Optimisation settings.
        batch_size=8,
        num_epochs=20,
        lr=2e-4,
        # Model / sequence dimensions.
        seq_len=250,
        d_model=768,
        # Language pair.
        lang_src="en",
        lang_trgt="ru",
        # Checkpoint naming; "preload" is the checkpoint tag to resume from (None = start fresh).
        model_folder='weights',
        model_basename='tmodel_',
        preload=None,
        # "{0}" is filled with the language code when a tokenizer file is located.
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
        checkpoint_every=50,
    )
    return config

In [None]:
def get_weights_file_path(config, epoch):
    """Build the checkpoint path ``./<model_folder>/<model_basename>_<epoch>.pt``.

    Args:
        config: mapping providing 'model_folder' and 'model_basename'.
        epoch: tag inserted into the file name (formatted with ``str``).

    Returns:
        A ``pathlib.Path`` relative to the current directory.
    """
    folder, basename = config['model_folder'], config['model_basename']
    return Path('.').joinpath(folder, "{}_{}.pt".format(basename, epoch))

# Dataset & Tokenization & Dataloaders

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every translation pair in ``ds``.

    Each item is expected to look like
    ``{'translation': {'en': 'Hello', 'de': 'Hallo'}}``; with ``lang='en'``
    that item yields ``'Hello'``.
    """
    yield from (pair['translation'][lang] for pair in ds)

In [None]:
def get_or_build_tokenizer(config, ds, lang):
    """Load the WordPiece tokenizer for ``lang`` from disk, training and saving it first if missing.

    Args:
        config: mapping whose 'tokenizer_file' entry is a format string taking the language code.
        ds: dataset of translation pairs used as the training corpus.
        lang: language code selecting both the file name and the side of each pair.

    Returns:
        The loaded or newly trained ``Tokenizer``.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    # A tokenizer saved by an earlier run is reused as-is.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Otherwise train a whitespace-pre-tokenized WordPiece model on the `lang` sentences.
    tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordPieceTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=2,
    )
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [None]:
class BilingualDataset(Dataset):
    """Fixed-length view over a dataset of translation pairs.

    Every item of ``ds`` is expected to look like
    ``{'translation': {src_lang: str, tgt_lang: str}}``. ``__getitem__``
    tokenizes both sides, adds the special tokens and pads each sequence to
    exactly ``seq_len`` ids.

    NOTE: the [SOS]/[EOS]/[PAD] ids are looked up in ``tokenizer_tgt`` only and
    are reused for the encoder input, so this assumes both tokenizers assign
    the same ids to these special tokens.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Special tokens are stored as 1-element int64 tensors so they can be concatenated directly.
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def _padding(self, count):
        """Return a 1-D int64 tensor holding ``count`` copies of the [PAD] id."""
        # One vectorised fill instead of converting a Python list of tensors element by element.
        return torch.full((count,), self.pad_token.item(), dtype=torch.int64)

    def __getitem__(self, idx):
        """Return the tensors and raw texts for pair ``idx``.

        Raises:
            ValueError: if either tokenized sentence does not fit in ``seq_len``
                once its special tokens are added.
        """
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into token ids
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder input carries both [SOS] and [EOS] ...
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        # ... while the decoder input carries only [SOS] and the label only [EOS].
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # A negative padding count means the sentence does not fit in seq_len.
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)
        dec_padding = self._padding(dec_num_padding_tokens)

        # [SOS] tokens [EOS] [PAD]...
        encoder_input = torch.cat(
            [self.sos_token, enc_ids, self.eos_token, self._padding(enc_num_padding_tokens)],
            dim=0,
        )

        # [SOS] tokens [PAD]...
        decoder_input = torch.cat([self.sos_token, dec_ids, dec_padding], dim=0)

        # tokens [EOS] [PAD]...  (the decoder input shifted left by one position)
        label = torch.cat([dec_ids, self.eos_token, dec_padding], dim=0)

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            # 1 where the position holds a real token, 0 where it is padding.
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            # Padding mask combined with the causal mask: (1, seq_len) & (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True where column <= row.

    Row ``i`` therefore allows positions ``0..i`` and blocks everything after it.
    """
    idx = torch.arange(size)
    rows, cols = idx.unsqueeze(1), idx.unsqueeze(0)
    return (cols <= rows).unsqueeze(0)

In [None]:
def get_ds(config):
    """Load the opus_books pair from ``config`` and build tokenizers and data loaders.

    The raw 'train' split is divided 90/10 at random into training and
    validation subsets. The longest tokenized source and target sentence
    lengths are printed so ``seq_len`` can be checked against them.

    Returns:
        (train_loader, val_loader, tokenizer_src, tokenizer_trgt)
    """
    src_lang, trgt_lang = config['lang_src'], config['lang_trgt']
    ds_raw = load_dataset('opus_books', f'{src_lang}-{trgt_lang}', split='train')

    # One tokenizer per language, trained on (or loaded for) the full raw split.
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
    tokenizer_trgt = get_or_build_tokenizer(config, ds_raw, trgt_lang)

    # 90 % of the pairs go to training, the remainder to validation.
    n_train = int(0.9 * len(ds_raw))
    n_val = len(ds_raw) - n_train
    train_ds_raw, val_ds_raw = random_split(ds_raw, [n_train, n_val])

    def wrap(split):
        return BilingualDataset(split, tokenizer_src, tokenizer_trgt, src_lang, trgt_lang, config['seq_len'])

    train_ds = wrap(train_ds_raw)
    val_ds = wrap(val_ds_raw)

    # Scan every pair once to report the longest tokenized sentence on each side.
    max_len_src = max_len_trgt = 0
    for item in ds_raw:
        pair = item['translation']
        n_src = len(tokenizer_src.encode(pair[src_lang]).ids)
        n_trgt = len(tokenizer_trgt.encode(pair[trgt_lang]).ids)
        if n_src > max_len_src:
            max_len_src = n_src
        if n_trgt > max_len_trgt:
            max_len_trgt = n_trgt

    print(max_len_src)
    print(max_len_trgt)

    train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    # Validation uses batch size 1 (run_validation asserts this).
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_loader, val_loader, tokenizer_src, tokenizer_trgt

# Building the model (Transformer)

In [None]:
class InputEmbed(nn.Module):
    """Token embedding whose output is multiplied by sqrt(d_model)."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.vocab = vocab_size
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Look up the ids in ``x`` and scale the vectors: (batch_size, seq_len, d_model)."""
        scale = math.sqrt(self.d_model)
        return scale * self.embed(x)

In [None]:
class PosEncod(nn.Module):
    """Adds a precomputed sinusoidal position table to the input, then applies dropout."""

    def __init__(self, d_model, seq_len, drop_rate):
        super().__init__()

        self.d_model = d_model
        self.seq_len = seq_len
        self.drop_rate = drop_rate
        self.dropout = nn.Dropout(p=drop_rate)

        # Angle for position p and even channel index c is p * exp(-c * ln(10000) / d_model).
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even channels hold the sine, odd channels the cosine of the same angle.
        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Stored as a buffer with a leading batch axis: part of the module state, not a parameter.
        self.register_buffer('pos_enc', table.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout."""
        positional = self.pos_enc[:, :x.size(1), :].requires_grad_(False)
        return self.dropout(x + positional)

In [None]:
class LayerNorm(nn.Module):
    """Normalises the last dimension with a learnable per-feature scale (gamma) and shift (beta)."""

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Return gamma * (x - mean) / (std + eps) + beta, computed over the last axis."""
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation itself to avoid dividing by zero.
        denom = x.std(-1, keepdim=True) + self.eps
        return self.gamma * centered / denom + self.beta

In [None]:
class FFBlock(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, drop_rate):
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff
        self.drop_rate = drop_rate

        self.linear1 = nn.Linear(d_model, d_ff)   # expand d_model -> d_ff
        self.dropout = nn.Dropout(p=drop_rate)
        self.linear2 = nn.Linear(d_ff, d_model)   # project d_ff -> d_model
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two linear layers with ReLU and dropout in between."""
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with bias-free Q/K/V/output projections.

    After each forward pass the post-dropout attention weights of shape
    (batch, h, query_len, key_len) are kept in ``attn_weights``;
    ``attention_scores`` is an alias of the same tensor.
    """

    def __init__(self, d_model, h, drop_rate):
        super().__init__()

        # The feature axis is split evenly across the heads, so it must divide cleanly.
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by the number of heads ({h})")

        self.d_model = d_model
        self.h = h
        self.d_k = d_model // self.h
        self.attn_weights = None
        self.attention_scores = None

        self.dropout = nn.Dropout(p=drop_rate)

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

        # Linear projection applied after the heads are concatenated
        self.w_o = nn.Linear(d_model, d_model, bias=False)

    @staticmethod
    def attention(query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention.

        Args:
            query, key, value: tensors whose last axis is the per-head feature size.
            mask: optional tensor broadcastable to the score matrix; positions
                where it equals 0 are filled with -1e9 before the softmax.
            dropout: optional module applied to the attention probabilities.

        Returns:
            (weighted values, attention probabilities after dropout).
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = scores.softmax(dim=-1)  # raw scores --> probabilities over the keys

        if dropout is not None:
            scores = dropout(scores)

        return (scores @ value), scores

    def forward(self, q, k, v, mask):
        """Project q/k/v, attend per head, merge the heads and apply the output projection."""
        batch_size = q.size(0)

        query = self.w_q(q)  # (Batch, seq_len, d_model)
        key = self.w_k(k)    # (Batch, seq_len, d_model)
        value = self.w_v(v)  # (Batch, seq_len, d_model)

        # Split into heads: (Batch, seq_len, d_model) --> (Batch, h, seq_len, d_k)
        Q = query.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        K = key.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        V = value.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        # Attention runs exactly once, on the per-head tensors. Running it on the
        # un-split projections as well would waste compute and an extra dropout draw.
        output, attn_weights = self.attention(Q, K, V, mask, self.dropout)
        self.attn_weights = attn_weights
        self.attention_scores = attn_weights

        # Merge the heads: (Batch, h, seq_len, d_k) --> (Batch, seq_len, d_model)
        out = output.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.w_o(out)

In [None]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, drop_rate, d_model):
        super().__init__()

        self.drop_rate = drop_rate
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, x, sublayer):
        """Normalise ``x``, run ``sublayer`` on it, apply dropout and add the result back to ``x``."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block, each in a residual wrapper."""

    def __init__(self, self_attention_block: MultiheadAttention, feed_forward: FFBlock, drop_rate, d_model):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward = feed_forward
        # Index 0 wraps the attention sub-layer, index 1 the feed-forward sub-layer.
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(drop_rate, d_model) for _ in range(2)]
        )

    def forward(self, x, src_mask):
        """Run both sub-layers on ``x``; ``src_mask`` is passed to the self-attention."""
        attn_residual, ff_residual = self.residual_connections

        def self_attend(normed):
            # Query, key and value all come from the same normalised input.
            return self.self_attention_block(normed, normed, normed, src_mask)

        x = attn_residual(x, self_attend)
        return ff_residual(x, self.feed_forward)

In [None]:
class Encoder(nn.Module):
    """Applies a stack of encoder layers in order, then a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList, d_model):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(d_model)

    def forward(self, x, mask):
        """Feed ``x`` through every layer with the same ``mask`` and normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, then feed-forward."""

    def __init__(self, self_attention_block: MultiheadAttention, cross_attention_block: MultiheadAttention, feed_forward: FFBlock, drop_rate, d_model):
        super().__init__()

        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward = feed_forward
        # One residual wrapper per sub-layer, in the order they are applied.
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(drop_rate, d_model) for _ in range(3)]
        )

    def forward(self, x, encoder_output, src_mask, trgt_mask):
        """Run the three sub-layers on ``x``.

        ``trgt_mask`` is used by the self-attention; ``src_mask`` by the
        cross-attention, whose keys and values are ``encoder_output``.
        """
        self_residual, cross_residual, ff_residual = self.residual_connections

        def self_attend(normed):
            return self.self_attention_block(normed, normed, normed, trgt_mask)

        def cross_attend(normed):
            # Queries come from the decoder, keys/values from the encoder.
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return ff_residual(x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    """Applies a stack of decoder layers in order, then a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList, d_model):
        super().__init__()

        self.layers = layers
        self.norm = LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask, trgt_mask):
        """Feed ``x`` through every layer with the same encoder output and masks, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, trgt_mask)
        return self.norm(hidden)

In [None]:
class Projection(nn.Module):
    """Linear layer mapping d_model features to one score per vocabulary entry."""

    def __init__(self, d_model, vocab_size):
        super().__init__()

        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return the output of the linear layer for ``x`` (no softmax is applied here)."""
        logits = self.proj(x)
        return logits

# Putting everything together

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder container exposing separate ``encode``, ``decode`` and ``projection`` steps.

    No ``forward`` is defined; callers invoke the three steps explicitly.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbed, trgt_embed: InputEmbed, src_pos: PosEncod, trgt_pos: PosEncod, proj: Projection, d_model):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trgt_embed = trgt_embed
        self.src_pos = src_pos
        self.trgt_pos = trgt_pos
        self.proj = proj
        self.attention_weights_map = []  # initialised empty; not written to anywhere in this class
        self.d_model = d_model

    def encode(self, src, src_mask):
        """Embed the source ids, add positional encodings and run the encoder stack."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, trgt, trgt_mask):
        """Embed the target ids, add positional encodings and run the decoder stack."""
        embedded = self.trgt_pos(self.trgt_embed(trgt))
        return self.decoder(embedded, encoder_output, src_mask, trgt_mask)

    def projection(self, x):
        """Apply the output projection layer to ``x``."""
        return self.proj(x)

In [None]:
def build_transformer(src_size, trgt_size, src_seq_len, trgt_seq_len, d_model=768, N=6, h=8, drop_rate=0.3, d_ff=1024):
    """Assemble a ``Transformer`` with ``N`` encoder and ``N`` decoder layers.

    Args:
        src_size, trgt_size: source / target vocabulary sizes.
        src_seq_len, trgt_seq_len: lengths of the positional-encoding tables.
        d_model: embedding and hidden width.
        N: number of layers in each stack.
        h: number of attention heads.
        drop_rate: dropout probability used throughout.
        d_ff: inner width of the feed-forward blocks.

    Returns:
        The model, with every parameter of more than one dimension
        re-initialised using Xavier-uniform.
    """
    # Embeddings and positional encodings for both languages.
    src_embed = InputEmbed(d_model, src_size)
    trgt_embed = InputEmbed(d_model, trgt_size)
    src_pos = PosEncod(d_model, src_seq_len, drop_rate)
    trgt_pos = PosEncod(d_model, trgt_seq_len, drop_rate)

    def new_attention():
        return MultiheadAttention(d_model, h, drop_rate)

    def new_feed_forward():
        return FFBlock(d_model, d_ff, drop_rate)

    # Arguments are evaluated left to right, so sub-modules are created in the
    # order: attention block(s) first, then the feed-forward block.
    encoder_blocks = [
        EncoderBlock(new_attention(), new_feed_forward(), drop_rate, d_model)
        for _ in range(N)
    ]
    decoder_blocks = [
        DecoderBlock(new_attention(), new_attention(), new_feed_forward(), drop_rate, d_model)
        for _ in range(N)
    ]

    encoder = Encoder(nn.ModuleList(encoder_blocks), d_model)
    decoder = Decoder(nn.ModuleList(decoder_blocks), d_model)

    proj_layer = Projection(d_model, trgt_size)
    transformer = Transformer(encoder, decoder, src_embed, trgt_embed, src_pos, trgt_pos, proj_layer, d_model)

    # Only weight matrices are re-initialised; 1-D parameters (biases, norm scales) keep their values.
    for param in transformer.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return transformer

# Initializing, training and validation

In [None]:
def get_model(config, vocab_src_len, vocab_trgt_len):
    """Build the transformer from ``config``, using 'seq_len' for both sides and 'd_model' as width."""
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_trgt_len, seq_len, seq_len, config['d_model'])

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src,  tokenizer_trgt, max_len, device):
    """Decode one source sequence by repeatedly appending the highest-scoring next token.

    The encoder runs once. The decoder input starts as a single [SOS] token
    and grows by one token per step until [EOS] is produced or the sequence
    reaches ``max_len`` tokens.

    Returns:
        1-D tensor of target token ids, starting with [SOS].
    """
    sos_idx = tokenizer_trgt.token_to_id('[SOS]')
    eos_idx = tokenizer_trgt.token_to_id('[EOS]')

    # The encoder output is computed once and reused at every decoding step.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while decoder_input.size(1) != max_len:
        step_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, step_mask)

        # Score only the last position and keep the best-scoring token id.
        scores = model.proj(out[:, -1])
        next_word = scores.max(dim=1).indices

        appended = torch.full((1, 1), next_word.item(), dtype=source.dtype, device=device)
        decoder_input = torch.cat([decoder_input, appended], dim=1)

        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)

In [None]:
def run_validation(model, val_ds, tokenizer_src, tokenizer_trgt, max_len, device, print_msg, global_state, writer, num_examples=2):
    """Greedy-decode up to ``num_examples`` validation batches and print source/target/prediction.

    Args:
        model: the transformer; it is switched to eval mode and left in it.
        val_ds: iterable of batches with batch size 1, each providing
            'encoder_input', 'encoder_mask', 'src_text' and 'tgt_text'.
        tokenizer_src, tokenizer_trgt: tokenizers; the target one decodes the prediction.
        max_len: maximum number of tokens generated per example.
        device: device the batch tensors are moved to.
        print_msg: callable receiving each output line.
        global_state, writer: accepted for interface compatibility; not used here.
        num_examples: number of examples to decode.

    Returns:
        None. No validation loss or metric is computed.
    """
    model.eval()
    console_width = 80

    with torch.inference_mode():
        for count, batch in enumerate(val_ds):
            # Stop after exactly `num_examples` examples (including zero).
            if count >= num_examples:
                break

            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_output = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_trgt, max_len, device)

            source_text = batch['src_text'][0]
            # The dataset emits the reference translation under 'tgt_text'.
            target_text = batch['tgt_text'][0]
            # Hand the ids over as a plain Python list of ints.
            model_output_text = tokenizer_trgt.decode(model_output.detach().cpu().tolist())

            print_msg('-'*console_width)
            print_msg(f'SOURCE: {source_text}')
            print_msg(f'TARGET: {target_text}')
            print_msg(f'PREDICTED: {model_output_text}')

In [None]:
def train_model(device, config):
    """Train the translation model, printing validation samples and saving a checkpoint every epoch.

    Args:
        device: torch device the model and batches are placed on.
        config: dict as returned by ``get_config``. If ``config['preload']`` is
            set, training resumes from that checkpoint: model weights,
            optimizer state, global step and epoch counter are all restored.
    """
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_loader, val_loader, tokenizer_src, tokenizer_trgt = get_ds(config)
    vocab_trgt_size = tokenizer_trgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), vocab_trgt_size).to(device)

    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], eps=1e-9, betas=(0.9, 0.999), weight_decay=0.01)

    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config["preload"])
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on one device be resumed on another.
        state = torch.load(model_filename, map_location=device)
        # Restore the weights as well as the optimizer; restoring only the optimizer
        # would pair stale moment estimates with freshly initialised parameters.
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # The checkpoint is written at the end of an epoch, so continue with the next one.
        initial_epoch = state['epoch'] + 1
        global_step = state['global_step']

    # Padding positions in the label are excluded from the loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_trgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            # run_validation leaves the model in eval mode, so re-enable training mode each epoch.
            model.train()
            batch_iterator = tqdm(train_loader, desc=f'Processing epoch {epoch:02d}')
            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)
                decoder_input = batch['decoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)
                decoder_mask = batch['decoder_mask'].to(device)
                label = batch['label'].to(device)

                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                proj_output = model.proj(decoder_output)

                # Flatten to (batch * seq_len, vocab) vs (batch * seq_len) for the loss.
                loss = loss_fn(proj_output.view(-1, vocab_trgt_size), label.view(-1))
                batch_iterator.set_postfix({'loss': f'{loss.item():6.3f}'})
                writer.add_scalar('train loss', loss.item(), global_step)

                loss.backward()

                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # Flush once per epoch instead of forcing a disk write on every step.
            writer.flush()

            # run_validation only prints sample translations; it returns nothing to report.
            run_validation(model, val_loader, tokenizer_src, tokenizer_trgt, config['seq_len'], device, batch_iterator.write, global_step, writer)

            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Close the event file even when training is interrupted.
        writer.close()

# Running

In [None]:
# Build the default configuration and start training on the selected device.
config = get_config()
train_model(device=device, config=config)

233
232


Processing epoch 00:  11%|█         | 221/1969 [01:37<12:48,  2.27it/s, loss=7.221]


KeyboardInterrupt: 