In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
import numpy

### Use these versions to run the following code

torch==2.0.1

torchvision==0.15.2

torchaudio==2.0.2

torchtext==0.15.2

datasets==2.15.0

tokenizers==0.13.3

torchmetrics==1.0.3

tensorboard==2.13.0

altair==5.1.1

wandb==0.15.9

numpy==1.24.2

In [None]:
# class TokenEmbedding(nn.Module):

#     def __init__(self):
#         super(TokenEmbedding, self).__init__()
#         self.d_model =512
#         self.vocab_size =1000
#         self.embedding = nn.Embedding(num_embeddings = self.vocab_size,
#                                      embedding_dim=self.d_model)

#     def forward(self,x):
#         token_embedding = self.embedding(x)

#         return token_embedding

In [None]:
# a = TokenEmbedding()

In [None]:
# input_tensor = torch.LongTensor([1, 2, 3, 4])  # Example input tensor
# output = a(input_tensor)
# print(output)

In [None]:
# class PositionalEncoding(nn.Module):
#     def __init__(self):
#         super(PositionalEncoding,self).__init__()
#         self.d_model = 512
#         self.seq_length = 512
#         self.dropout_p = 0.2

#         self.dropout = nn.Dropout(self.dropout_p)
#         position_encoding = torch.zeros(self.seq_length,self.d_model)
#         position = torch.arange(0,self.seq_length,dtype= torch.float).unsqueeze(1)
#         even_odd_position = torch.arange(0,self.d_model,2).float()
#         div_term = torch.pow(10000,even_odd_position/self.d_model)
#         position_encoding[:,0::2] = torch.sin(position*div_term)
#         position_encoding[:,1::2] = torch.cos(position*div_term)
#         position_encoding = position_encoding.unsqueeze(0)
#         self.register_buffer('position_encoding', position_encoding)

#     def forward(self,x):
#         x =x +(self.positional_encoding[:,:x.shape[1],:]).requires_grad_(False)
#         x =self.dropout(x)
#         return x

In [None]:
# b = PositionalEncoding()

In [None]:
# class InputEmbedding(nn.Module):
#     def __init__(self,positional_encoding,token_embeddings):
#         super(InputEmbedding,self).__init__()
#         self.d_model = 512
#         self.vocab_size = 1000
#         self.seq_len = 512
#         self.dropout_p = 0.2
#         self.token_embedding = token_embeddings
#         self.positional_encoding = positional_encoding

#     def forward(self,x):
#         x = self.token_embedding(x)
#         x = self.positional_encoding(x)

#         return x

In [None]:
# token_embeddings = TokenEmbedding()
# positional_encoding = PositionalEncoding()
# InputEmbedding(positional_encoding,token_embeddings)

### Input

In [None]:
class InputEmbeddings(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(d_model), as in the paper.

    Args:
        d_model: Size of each embedding vector.
        vocab_size: Number of distinct token ids the table can look up.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Map token ids (batch, seq_len) to scaled vectors (batch, seq_len, d_model)."""
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    For position ``pos`` and feature-pair index ``i``:
        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Embedding size. May be odd; the last (even-indexed) column then only
            gets a sine term.
        seq_len: Maximum sequence length covered by the precomputed table.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000^(2i / d_model), computed as exp(-2i * ln(10000) / d_model): (ceil(d_model / 2),)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns: sin(pos / 10000^(2i / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns: cos(pos / 10000^(2i / d_model)). There are only floor(d_model / 2) odd
        # columns, so div_term is trimmed to match; without this an odd d_model fails here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model): batch dimension for broadcasting
        # Registered as a buffer: moves with .to(device) and is saved in state_dict,
        # but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.shape[1]`` positions to ``x`` and apply dropout.

        ``x`` is (batch, seq_len, d_model); its seq_len must not exceed ``self.seq_len``.
        """
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)


In [None]:
# class LayerNormalization(nn.Module):
#     def __init__(self,features,eps=10**-6):
#         super().__init__()
#         self.d_model = 512
#         self.eps = 10**-6
#         self.alph = nn.Parameter(torch.ones(self.d_model))
#         self.beta = nn.Parameter(torch.ones(self.d_model))

#     def forward(self,x):
#         mean = x.mean(dim = -1,keepdim = True)
#         std = x.std(dim = -1 , keepdim = True)
#         norm = (x-mean)/(std + self.eps)
#         layer_norm = self.alpha * norm + self.beta
#         return layer_norm

### Layer Normalization

In [None]:
class LayerNormalization(nn.Module):
    """Normalise over the last dimension, then apply a learnable scale and shift.

    Uses the unbiased standard deviation (the ``Tensor.std`` default) and adds ``eps`` to
    the std rather than to the variance. Results therefore differ slightly from ``nn.LayerNorm``.

    Args:
        features: Size of the last dimension (number of scale/shift entries).
        eps: Added to the std to avoid dividing by zero or by a tiny value.
    """

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # learnable scale
        self.bias = nn.Parameter(torch.zeros(features))  # learnable shift

    def forward(self, x):
        # x: (..., features); statistics keep their last dim for broadcasting
        centred = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centred / spread + self.bias

### Feed Forward

In [None]:
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear(d_model->d_ff) -> ReLU -> Dropout -> Linear(d_ff->d_model).

    NOTE: a later cell in this notebook defines another ``FeedForwardBlock`` with a
    ``dropout`` argument. Once that cell runs, the name refers to that version.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
        dropout: Dropout probability between the two layers (defaults to 0.2).
    """

    def __init__(self, d_model, d_ff, dropout=0.2):
        # nn.Module.__init__ takes no positional arguments; passing ``self`` raises TypeError.
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff

        self.first_layer = nn.Linear(d_model, d_ff)
        self.second_layer = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (..., d_model) -> (..., d_ff) -> (..., d_model)
        x = self.first_layer(x)
        # Functional ReLU; ``nn.ReLU(x)`` would construct a module instead of applying it.
        x = torch.relu(x)
        x = self.dropout(x)
        return self.second_layer(x)


In [None]:
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network from the paper: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input/output feature size.
        d_ff: Inner (hidden) feature size.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # W1, b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # W2, b2

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, d_ff) -> (batch, seq_len, d_model)
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)


### Multihead Attention

In [None]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with biased Q/K/V/O projections.

    NOTE: a later cell redefines ``MultiHeadAttentionBlock`` with the signature
    ``(d_model, h, dropout)`` and bias-free projections. After that cell runs, the
    name refers to that version.

    Args:
        d_model: Embedding size; must be divisible by ``head``.
        head: Number of attention heads.
        dropout: Dropout probability on the attention weights (defaults to 0.2).
    """

    def __init__(self, d_model, head, dropout=0.2):
        # nn.Module.__init__ takes no positional arguments; passing ``self`` raises TypeError.
        super().__init__()
        assert d_model % head == 0, "d_model is not divisible by head"
        self.d_model = d_model
        self.head = head
        self.h = head  # name used by forward()
        self.d_k = d_model // head  # features seen by each head
        self.head_dim = self.d_k
        self.wq = nn.Linear(d_model, d_model, bias=True)
        self.wk = nn.Linear(d_model, d_model, bias=True)
        self.wv = nn.Linear(d_model, d_model, bias=True)
        self.wo = nn.Linear(d_model, d_model, bias=True)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Scaled dot-product attention on (batch, h, seq_len, d_k) tensors.

        Scores where ``mask == 0`` are set to -1e9 before the softmax.
        Returns ``(weights @ value, weights)``.
        """
        d_k = query.shape[-1]
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return attention_scores @ value, attention_scores

    def forward(self, q, k, v, mask):
        """Attend from ``q`` to ``(k, v)``; inputs and output are (batch, seq_len, d_model)."""
        query = self.wq(q)
        key = self.wk(k)
        value = self.wv(v)

        # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # self.attention (not the class name) so a later class with the same name cannot hijack the call.
        x, self.attention_scores = self.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        return self.wo(x)


In [None]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    After every forward call the attention weights are stored in ``self.attention_scores``
    (batch, h, q_len, k_len) so they can be visualised.

    Args:
        d_model: Embedding size; must be divisible by ``h``.
        h: Number of heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h  # features per head

        # Creation order q, k, v, o is kept so seeded initialisation stays reproducible.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """softmax(Q K^T / sqrt(d_k)) V for tensors shaped (batch, h, seq_len, d_k).

        Scores where ``mask == 0`` are overwritten with -1e9 before the softmax.
        Returns the attended values and the (possibly dropped-out) weights.
        """
        scale = math.sqrt(query.shape[-1])
        weights = (query @ key.transpose(-2, -1)) / scale
        if mask is not None:
            weights.masked_fill_(mask == 0, -1e9)
        weights = weights.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return (weights @ value), weights

    def forward(self, q, k, v, mask):
        """Attend from ``q`` to ``(k, v)``; inputs and output are (batch, seq_len, d_model)."""

        def split_heads(t):
            # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
            return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

        query = split_heads(self.w_q(q))
        key = split_heads(self.w_k(k))
        value = split_heads(self.w_v(v))

        heads, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, d_model)
        merged = heads.transpose(1, 2).contiguous()
        merged = merged.view(merged.shape[0], -1, self.h * self.d_k)
        return self.w_o(merged)

### Res block

In [None]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper that returns ``x + dropout(sublayer(norm(x)))``.

    Args:
        features: Size of the last dimension, passed to LayerNormalization.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, features, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        """Run ``sublayer`` (a tensor -> tensor callable) on the normalised input and add the skip path."""
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update

### Encoder block

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sublayer, then a feed-forward sublayer.

    Each sublayer is wrapped in its own pre-norm ResidualConnection.
    """

    def __init__(self, features, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout), ResidualConnection(features, dropout)]
        )

    def forward(self, x, src_mask):
        """``src_mask`` is forwarded to self-attention to hide padded source positions."""
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(t):
            return self.self_attention_block(t, t, t, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)


In [None]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final LayerNormalization.

    Args:
        features: Feature size for the final normalisation.
        layers: Encoder layers, each called as ``layer(x, mask)``.
    """

    def __init__(self, features, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)


### Decoder

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer made of three sublayers, each in its own pre-norm ResidualConnection.

    The sublayers are self-attention (with ``tgt_mask``), cross-attention over the
    encoder output (with ``src_mask``), then a feed-forward network.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout) for _ in range(3)]
        )

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        self_residual, cross_residual, feed_forward_residual = self.residual_connections

        def attend_to_self(t):
            return self.self_attention_block(t, t, t, tgt_mask)

        def attend_to_source(t):
            # queries come from the decoder, keys/values from the encoder output
            return self.cross_attention_block(t, encoder_output, encoder_output, src_mask)

        x = self_residual(x, attend_to_self)
        x = cross_residual(x, attend_to_source)
        return feed_forward_residual(x, self.feed_forward_block)


In [None]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final LayerNormalization.

    Args:
        features: Feature size for the final normalisation.
        layers: Decoder layers, each called as ``layer(x, encoder_output, src_mask, tgt_mask)``.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)


### Linear Layer

In [None]:
class ProjectionLayer(nn.Module):
    """Final linear map from ``d_model`` features to vocabulary logits (no softmax applied)."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        logits = self.proj(x)
        return logits

### Combine all

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built parts.

    ``encode`` and ``decode`` are separate so inference can encode the source once and
    decode step by step. ``project`` maps decoder features to vocabulary logits.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        # Registration order determines parameters() order, which build_transformer iterates
        # when initialising weights; keep it unchanged.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed and position-encode ``src``, then run the encoder -> (batch, seq_len, d_model)."""
        return self.encoder(self.src_pos(self.src_embed(src)), src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed and position-encode ``tgt``, then run the decoder -> (batch, seq_len, d_model)."""
        return self.decoder(self.tgt_pos(self.tgt_embed(tgt)), encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to logits -> (batch, seq_len, vocab_size)."""
        return self.projection_layer(x)

### Original implementation with all the hyperparameters from the paper


# def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
#     # Create the embedding layers
#     src_embed = InputEmbeddings(d_model, src_vocab_size)
#     tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

#     # Create the positional encoding layers
#     src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
#     tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

#     # Create the encoder blocks
#     encoder_blocks = []
#     for _ in range(N):
#         encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
#         feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
#         encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
#         encoder_blocks.append(encoder_block)

#     # Create the decoder blocks
#     decoder_blocks = []
#     for _ in range(N):
#         decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
#         decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
#         feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
#         decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
#         decoder_blocks.append(decoder_block)

#     # Create the encoder and decoder
#     encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
#     decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

#     # Create the projection layer
#     projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

#     # Create the transformer
#     transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

#     # Initialize the parameters
#     for p in transformer.parameters():
#         if p.dim() > 1:
#             nn.init.xavier_uniform_(p)

#     return transformer

In [None]:
### Hyperparameters changed so that training is lighter and faster

def build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, d_model=512, N=3, h=4, dropout=0.1, d_ff=1024):
    """Assemble an encoder-decoder Transformer and Xavier-initialise its weight matrices.

    The defaults are lighter than the paper: N=3 layers, h=4 heads and d_ff=1024,
    instead of 6 / 8 / 2048.

    Args:
        src_vocab_size, tgt_vocab_size: Vocabulary sizes for the embeddings / output projection.
        src_seq_len, tgt_seq_len: Maximum lengths for the positional-encoding tables.
        d_model: Model width.
        N: Number of encoder layers and of decoder layers.
        h: Attention heads per attention block.
        dropout: Dropout probability used throughout.
        d_ff: Hidden size of the feed-forward blocks.

    Returns:
        The assembled Transformer.
    """
    # Sub-modules are created in a fixed order (embeddings, positional encodings, encoder
    # blocks, decoder blocks, projection) so seeded default initialisations are reproducible.
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    def new_encoder_block():
        attention = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward = FeedForwardBlock(d_model, d_ff, dropout)
        return EncoderBlock(d_model, attention, feed_forward, dropout)

    def new_decoder_block():
        self_attention = MultiHeadAttentionBlock(d_model, h, dropout)
        cross_attention = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward = FeedForwardBlock(d_model, d_ff, dropout)
        return DecoderBlock(d_model, self_attention, cross_attention, feed_forward, dropout)

    encoder = Encoder(d_model, nn.ModuleList([new_encoder_block() for _ in range(N)]))
    decoder = Decoder(d_model, nn.ModuleList([new_decoder_block() for _ in range(N)]))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Xavier-uniform for every parameter with more than one dimension (weight matrices and
    # embedding tables); 1-D parameters such as biases and norm scales keep their defaults.
    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

### Config

In [None]:
from pathlib import Path

def get_config():
    """Return a fresh dict with the training hyper-parameters and file-naming settings.

    ``tokenizer_file`` is a format string filled with a language code. ``preload`` set to
    "latest" makes training resume from the newest checkpoint.
    """
    # NOTE(review): 'datasource' names the weights folder, but get_ds loads a hard-coded
    # dataset; confirm the two are meant to differ.
    config = dict(
        batch_size=8,
        num_epochs=15,
        lr=10**-4,
        seq_len=350,
        d_model=512,
        datasource='opus_books',
        lang_src="en",
        lang_tgt="mr",
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )
    return config

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<datasource>_<model_folder>/<model_basename><epoch>.pt`` as a string."""
    folder = Path('.') / f"{config['datasource']}_{config['model_folder']}"
    return str(folder / f"{config['model_basename']}{epoch}.pt")

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the path (as a string) of the newest checkpoint in the weights folder, or None.

    Candidates are the files matching ``<model_basename>*`` inside
    ``<datasource>_<model_folder>``. They are ranked with a natural sort, where digit runs
    compare as integers. So ``tmodel_100.pt`` ranks after ``tmodel_99.pt`` and
    ``tmodel_10.pt`` after ``tmodel_9.pt``; a plain string sort gets both of these wrong.
    """
    import re

    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if not weights_files:
        return None

    def natural_key(path):
        # re.split with a capturing group yields text at even indices and digit runs at odd
        # indices, so ints are only ever compared with ints. The full name breaks ties.
        parts = re.split(r'(\d+)', path.name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)], path.name

    return str(max(weights_files, key=natural_key))

### Building Dataset

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):
    """Turn translation pairs into fixed-length encoder/decoder inputs, labels and masks.

    ``ds[idx]['translation']`` must map ``src_lang`` and ``tgt_lang`` to strings.
    Every returned id tensor has length ``seq_len``:

    * ``encoder_input``: ``[SOS] src [EOS] [PAD]...`` (source-tokenizer ids)
    * ``decoder_input``: ``[SOS] tgt [PAD]...``       (target-tokenizer ids)
    * ``label``:         ``tgt [EOS] [PAD]...``       (decoder_input shifted left by one)

    Raises:
        ValueError: from ``__getitem__`` when a sentence does not fit in ``seq_len``, and
            from ``__init__`` when the target tokenizer lacks ``[SOS]``, ``[EOS]`` or ``[PAD]``.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Target-side special tokens as 1-element int64 tensors (ready for torch.cat).
        self.sos_token = self._special_token(tokenizer_tgt, "[SOS]")
        self.eos_token = self._special_token(tokenizer_tgt, "[EOS]")
        self.pad_token = self._special_token(tokenizer_tgt, "[PAD]")
        # The encoder input is tokenized with tokenizer_src, so it is framed and padded with
        # tokenizer_src's special ids. Fall back to the target ids if the source tokenizer lacks them.
        self.src_sos_token = self._special_token(tokenizer_src, "[SOS]", default=self.sos_token)
        self.src_eos_token = self._special_token(tokenizer_src, "[EOS]", default=self.eos_token)
        self.src_pad_token = self._special_token(tokenizer_src, "[PAD]", default=self.pad_token)

    @staticmethod
    def _special_token(tokenizer, token, default=None):
        """Return ``token``'s id as a 1-element int64 tensor, or ``default`` if the tokenizer lacks it."""
        token_id = tokenizer.token_to_id(token)
        if token_id is not None:
            return torch.tensor([token_id], dtype=torch.int64)
        if default is None:
            raise ValueError(f"Tokenizer has no {token} token")
        return default

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into token ids
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder input gets both [SOS] and [EOS]; decoder input and label get one each.
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # A negative padding count means the sentence does not fit in seq_len.
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        enc_tokens = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_tokens = torch.tensor(dec_input_tokens, dtype=torch.int64)
        # Tensor.repeat builds padding directly instead of converting a list of 1-element tensors.
        enc_padding = self.src_pad_token.repeat(enc_num_padding_tokens)
        dec_padding = self.pad_token.repeat(dec_num_padding_tokens)

        encoder_input = torch.cat([self.src_sos_token, enc_tokens, self.src_eos_token, enc_padding], dim=0)
        decoder_input = torch.cat([self.sos_token, dec_tokens, dec_padding], dim=0)
        label = torch.cat([dec_tokens, self.eos_token, dec_padding], dim=0)

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            # 1 for real tokens, 0 for padding: (1, 1, seq_len)
            "encoder_mask": (encoder_input != self.src_pad_token).unsqueeze(0).unsqueeze(0).int(),
            # padding mask (1, seq_len) & causal mask (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """Boolean (1, size, size) mask that is True where column <= row.

    Each position may attend to itself and to earlier positions, never to later ones.
    """
    allowed = torch.tril(torch.ones((1, size, size), dtype=torch.int))
    return allowed == 1

### Inference

In [None]:
import torchtext.datasets as datasets
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
from torch.utils.tensorboard import SummaryWriter

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily decode one target sequence for a single source sentence.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``project`` (e.g. Transformer).
        source: Encoder input ids for exactly one sentence (batch size 1).
        source_mask: Encoder padding mask.
        tokenizer_src: Unused; kept for call-site compatibility.
        tokenizer_tgt: Provides the ``[SOS]`` / ``[EOS]`` ids.
        max_len: Maximum length of the returned sequence, including ``[SOS]``.
            Values below 1 behave like 1 (only ``[SOS]`` is returned).
        device: Device on which the decoder input and mask are created.

    Returns:
        1-D tensor of ids starting with ``[SOS]``. It ends with ``[EOS]`` unless ``max_len``
        was reached first.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Encode once; every decoding step reuses the same encoder output.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.full((1, 1), sos_idx, dtype=source.dtype, device=device)
    # `<` rather than `==`: with max_len < 1 an equality check never fires and the loop
    # would only stop on [EOS].
    while decoder_input.size(1) < max_len:
        # Causal mask so each position only sees earlier target tokens.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Logits for the last position only: (1, vocab_size); pick the most likely token.
        prob = model.project(out[:, -1])
        next_id = prob.argmax(dim=1).item()
        next_token = torch.full((1, 1), next_id, dtype=source.dtype, device=device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_id == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode up to ``num_examples`` validation batches, print them and log text metrics.

    Args:
        model: Model to evaluate; it is switched to eval mode (and left there).
        validation_ds: Iterable of batch dicts with the keys produced by BilingualDataset;
            each batch must hold exactly one example.
        tokenizer_src: Passed through to greedy_decode.
        tokenizer_tgt: Decodes predicted ids back to text.
        max_len: Maximum decoded length.
        device: Device the batch tensors are moved to.
        print_msg: Callable that receives each output line.
        global_step: Step recorded with the logged scalars.
        writer: TensorBoard writer, or None to skip the metrics.
        num_examples: Number of batches to decode before stopping.
    """
    import shutil

    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    # Separator width: the terminal width when one is attached, otherwise 80 columns
    # (e.g. inside a notebook). Avoids shelling out to `stty` behind a bare `except:`,
    # which would also have swallowed KeyboardInterrupt.
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # greedy_decode handles exactly one sentence at a time
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            # Tokenizer.decode takes a list of ids; pass plain Python ints.
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().tolist())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-'*console_width)
                break

    # Skip metrics when nothing was decoded (empty validation set).
    if writer and predicted:
        # Character error rate
        metric = torchmetrics.CharErrorRate()
        cer = metric(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()

        # Word error rate
        metric = torchmetrics.WordErrorRate()
        wer = metric(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()

        # BLEU: the target is one list of reference translations per prediction.
        metric = torchmetrics.BLEUScore()
        bleu = metric(predicted, [[text] for text in expected])
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()

  from .autonotebook import tqdm as notebook_tqdm


### Tokenizer

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield ``item['translation'][lang]`` for every item in ``ds``."""
    yield from (item['translation'][lang] for item in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang``, training and saving it first if the file is missing.

    The file path is ``config['tokenizer_file'].format(lang)``. A newly trained tokenizer:
    - splits on whitespace;
    - maps unknown words to ``[UNK]``;
    - reserves ``[UNK]``, ``[PAD]``, ``[SOS]`` and ``[EOS]`` in that order;
    - keeps only words seen at least twice.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Recipe follows https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    """Load the en-mr parallel corpus, build both tokenizers and return the data loaders.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``. The validation
        loader uses batch size 1, which run_validation requires.
    """
    # Only the 'train' split is loaded; it is divided into train/validation below.
    # Config-driven alternative (for datasets whose pair name is "<lang_src>-<lang_tgt>"):
    # ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')
    ds_raw = load_dataset("opus100", "en-mr", split='train')
    print("Dataset fetched.")
    # NOTE(review): tokenizers are trained on the full corpus, including the validation portion.
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training and 10% for validation. The split generator is seeded so the split
    # is identical on every run. Otherwise, resuming from a checkpoint (config['preload'])
    # could move examples the model already trained on into the validation set.
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    split_generator = torch.Generator().manual_seed(config.get('split_seed', 42))
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], generator=split_generator)

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Report the longest tokenized sentence on each side, to check that config['seq_len'] fits.
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

### Training loop

In [None]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a Transformer whose source and target use ``config['seq_len']`` and ``config['d_model']``."""
    seq_len = config["seq_len"]
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, d_model=config['d_model'])

def _select_device():
    """Pick the training device (CUDA, then Apple MPS, then CPU) and print a short summary.

    Returns:
        torch.device: the selected device. For CUDA the index is left unset, i.e. the current device.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        # Ask whether MPS is usable at runtime. `torch.has_mps` only reports that
        # PyTorch was compiled with MPS support and is deprecated in newer releases.
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("Using device:", device.type)
    if device.type == 'cuda':
        # A torch.device with index None resolves to the current CUDA device.
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    return device


def train_model(config):
    """Train the translation Transformer described by ``config``.

    Steps:
      1. Build the data loaders and tokenizers (``get_ds``) and the model (``get_model``).
      2. Optionally resume from a checkpoint. ``config['preload']`` is ``'latest'``,
         an epoch tag passed to ``get_weights_file_path``, or falsy to start fresh.
      3. Train for the epochs ``initial_epoch .. config['num_epochs'] - 1``, logging
         the training loss to TensorBoard and running validation after every epoch.
      4. Save a checkpoint every ``config.get('save_every', 10)`` epochs and after
         the final epoch.

    Config keys read: 'datasource', 'model_folder', 'experiment_name', 'lr',
    'preload', 'num_epochs', 'seq_len' (plus those read by get_ds/get_model),
    and optionally 'save_every'.
    """
    device = _select_device()

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint written on a GPU be resumed on CPU/MPS (and vice versa).
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # The labels are target-side token ids, so the ignored padding id must be the
    # *target* tokenizer's [PAD], not the source tokenizer's.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    num_epochs = config['num_epochs']
    save_every = config.get('save_every', 10)

    for epoch in range(initial_epoch, num_epochs):
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
            proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device) # (B, seq_len)

            # Cross entropy over every position: (B, seq_len, tgt_vocab_size) -> (B*seq_len, tgt_vocab_size)
            loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
            loss_value = loss.item()
            batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})

            # Log the loss; the writer is flushed once per epoch below, not per step.
            writer.add_scalar('train loss', loss_value, global_step)

            # Backpropagate the loss and update the weights
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        writer.flush()

        # Run validation at the end of every epoch
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save periodically, and always after the last epoch so the final weights are not lost.
        is_last_epoch = epoch == num_epochs - 1
        if is_last_epoch or (save_every and (epoch + 1) % save_every == 0):
            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
            print(f"Model saved at epoch {epoch + 1}")

    writer.close()


if __name__ == '__main__':
    # Silence all warnings so the tqdm progress bars and validation samples stay readable.
    warnings.filterwarnings("ignore")
    train_model(get_config())

Using device: cuda
Device name: NVIDIA GeForce RTX 4050 Laptop GPU
Device memory: 5.99658203125 GB
Dataset fetched.
Max length of source sentence: 289
Max length of target sentence: 261
No model to preload, starting from scratch


Processing Epoch 00: 100%|██████████| 3039/3039 [05:46<00:00,  8.77it/s, loss=4.429]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Abort.
    TARGET: वििच्छेद करा
 PREDICTED: ते .
--------------------------------------------------------------------------------
    SOURCE: If he should be busy, help him.
    TARGET: जर तो व्यग्र असला तर त्याची मदत करा.
 PREDICTED: तो , तो .
--------------------------------------------------------------------------------


Processing Epoch 01: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=5.769]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Konqueror
    TARGET: कॉन्करर
 PREDICTED: @ info
--------------------------------------------------------------------------------
    SOURCE: Crash report
    TARGET: क्रॅश अहवाल
 PREDICTED: @ info
--------------------------------------------------------------------------------


Processing Epoch 02: 100%|██████████| 3039/3039 [05:45<00:00,  8.79it/s, loss=3.953]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: In India, the cow is a sacred animal.
    TARGET: भारतात गाय ही एक पवित्र पशु आहे.
 PREDICTED: , .
--------------------------------------------------------------------------------
    SOURCE: X DVI
    TARGET: X DVIGenericName
 PREDICTED: एक्स
--------------------------------------------------------------------------------


Processing Epoch 03: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=3.588]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Forget Tom.
    TARGET: टॉमला विसरून जा.
 PREDICTED: टॉमला बघ .
--------------------------------------------------------------------------------
    SOURCE: Patrick Dowler
    TARGET: Patrick Dowler
 PREDICTED: 
--------------------------------------------------------------------------------


Processing Epoch 04: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=2.982]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Why don't you take a taxi?
    TARGET: तुम्ही टॅक्सी का नाही करत?
 PREDICTED: तू टॅक्सी का नाही ?
--------------------------------------------------------------------------------
    SOURCE: Everybody's getting rich but me.
    TARGET: मला सोडल्यास सर्वच श्रीमंत होत आहेत.
 PREDICTED: तो श्रीमंत .
--------------------------------------------------------------------------------


Processing Epoch 05: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=1.975]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Button Text
    TARGET: बटण पाठ्य
 PREDICTED: पाठ्य पाठ्य
--------------------------------------------------------------------------------
    SOURCE: No service matching the requirements was found.
    TARGET: आवश्यकतानुरूप सेवा जुळवणी नंतर आढळले नाही.
 PREDICTED: सेवा % 1 आढळले नाही .
--------------------------------------------------------------------------------


Processing Epoch 06: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=2.115]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Geert Jansen
    TARGET: गीर्त जानसन
 PREDICTED: गीर्त जानसन
--------------------------------------------------------------------------------
    SOURCE: I am 19 years old.
    TARGET: मी १९ वर्षांचा आहे.
 PREDICTED: मी वर्ष आहे .
--------------------------------------------------------------------------------


Processing Epoch 07: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=2.134]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Which is more important, me or your job?
    TARGET: जास्त महत्त्वाचं काय आहे, मी की तुझी नोकरी?
 PREDICTED: तुझं आहे , आता की तुझं ?
--------------------------------------------------------------------------------
    SOURCE: We don't have any sugar.
    TARGET: आपल्याकडे साखर अजिबात नाहीये.
 PREDICTED: आपल्याकडे साखर नाही आहे .
--------------------------------------------------------------------------------


Processing Epoch 08: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=2.058]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Signaled target abort
    TARGET: लक्ष्य नष्ट करण्याकरीता संकेत पाठविले
 PREDICTED: लक्ष्य नष्ट करण्याकरीता संकेत पाठविले
--------------------------------------------------------------------------------
    SOURCE: Multi-Channel
    TARGET: बहु- मार्ग
 PREDICTED: -
--------------------------------------------------------------------------------


Processing Epoch 09: 100%|██████████| 3039/3039 [05:45<00:00,  8.79it/s, loss=1.683]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Build Index
    TARGET: इन्डेक्स बिल्ड करा
 PREDICTED: इनडेक्स बिल्ड करा
--------------------------------------------------------------------------------
    SOURCE: We are his sons.
    TARGET: आम्ही त्याची मुलं आहोत.
 PREDICTED: आपण त्याची मुलं आहोत .
--------------------------------------------------------------------------------
Model saved at epoch 10


Processing Epoch 10: 100%|██████████| 3039/3039 [05:45<00:00,  8.79it/s, loss=1.748]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Where do you watch television?
    TARGET: तू टीव्ही कुठे बघतोस?
 PREDICTED: तू टीव्ही कुठे ?
--------------------------------------------------------------------------------
    SOURCE: She looks like her aunt.
    TARGET: ती तिच्या मावशीसारखी दिसते.
 PREDICTED: ती तिच्या दिसते .
--------------------------------------------------------------------------------


Processing Epoch 11: 100%|██████████| 3039/3039 [05:45<00:00,  8.79it/s, loss=2.315]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: URI
    TARGET: URI
 PREDICTED: URI URI @ item :: intable
--------------------------------------------------------------------------------
    SOURCE: I am downloading books.
    TARGET: मी पुस्तकं डाउनलोड करतेय.
 PREDICTED: मी पुस्तकं डाउनलोड करतोय .
--------------------------------------------------------------------------------


Processing Epoch 12: 100%|██████████| 3039/3039 [05:45<00:00,  8.79it/s, loss=2.602]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Button
    TARGET: बटण
 PREDICTED: बटण
--------------------------------------------------------------------------------
    SOURCE: Product: %1 (%2)
    TARGET: @ info bug report label and value
 PREDICTED: :% 1 (% 2 )
--------------------------------------------------------------------------------


Processing Epoch 13: 100%|██████████| 3039/3039 [05:45<00:00,  8.80it/s, loss=2.024]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Show progress information (even if'silent 'mode is on)
    TARGET: प्रगतीविषयक माहिती दर्शवा (जरी ते 'शांत' पध्दतीत कार्यक्षम आहे)
 PREDICTED: प्रगती माहिती दर्शवा ( , दाखलन पासून प्रगती दर्शवा )
--------------------------------------------------------------------------------
    SOURCE: Ricky, this is my friend Suzuki.
    TARGET: रिकी, ही आहे माझी मैत्रिण, सुजुकी.
 PREDICTED: रिकी , हा आहे माझा मित्र , सुजुकी .
--------------------------------------------------------------------------------


Processing Epoch 14: 100%|██████████| 3039/3039 [05:45<00:00,  8.79it/s, loss=1.725]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: That's Tom's girl.
    TARGET: ती टॉमची मुलगी आहे.
 PREDICTED: ती टॉमची मुलगी आहे .
--------------------------------------------------------------------------------
    SOURCE: UADescription (IE 6.0 on Win XP)
    TARGET: UAडिस्क्रिप्शन (IE 6. 0 Win XP वर) Name
 PREDICTED: UAडिस्क्रिप्शन ( IE 6 . 0 Win XP वर ) Name
--------------------------------------------------------------------------------
