<a href="https://colab.research.google.com/github/yashpapa6969/data-venki/blob/main/transformer_SMOE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets torchmetrics wandb



In [None]:
import torch
from torch import einsum
import torch.nn as nn
import math
from dataclasses import dataclass
from typing import Optional
from typing import Dict,Tuple
from typing import List, Union
from collections.abc import MutableMapping
import torch.nn.functional as F

In [None]:
# Scratch calculation left from experimentation (768 is the default d_model);
# it has no effect on the model.
768/2

384.0

In [None]:
@dataclass
class ModelConfig:
  """Hyper-parameters read by the SwiGLU experts, RoPE and mixture-of-experts layers."""
  d_model :int = 768  # residual-stream / embedding width
  n_heads :int = 12  # attention heads; d_model // n_heads is the RoPE head dimension
  ffn_dim_multiplier: Optional[float] = None  # optional extra scale on the SwiGLU hidden width
  multiple_of: int = 256  # SwiGLU hidden width is rounded up to a multiple of this
  dropout: float = 0.1  # NOTE(review): not read by the layers visible here
  pad_token_id: int = 1  # NOTE(review): not read by code visible here; presumably the "<pad>" id
  hidden_size: int = 768  # NOTE(review): not read by code visible here (duplicates d_model?)
  num_attention_heads:int = 12  # NOTE(review): not read by code visible here (duplicates n_heads?)
  max_sequence_length: int = 768  # positions in the RoPE sin/cos table built at construction

  num_experts: int = 4  # experts per MoE layer (also the router's output width)
  num_experts_per_tok: int = 2  # top-k experts each token is routed to
  capacity_factor: float = 1.25  # head-room multiplier for per-expert token capacity


In [None]:
# class LayerNormalization(nn.Module):

#     def __init__(self, features: int, eps:float=10**-6) -> None:
#         super().__init__()
#         self.eps = eps
#         # Here features are the d_model
#         self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
#         self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter

#     def forward(self, x):
#         # x: (batch, seq_len, hidden_size)
#          # Keep the dimension for broadcasting
#         mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1)
#         # Keep the dimension for broadcasting
#         std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1)
#         # eps is to prevent dividing by zero or when std is very small
#         return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [None]:
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
    """Name -> tensor dictionary shared between modules.

    A single instance is handed to every attention layer's RotaryEmbedding so the
    RoPE sin/cos tables are computed once and then reused.
    """

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    y = weight * x / sqrt(mean(x ** 2 over the last dim) + eps). The statistic is
    computed in float32 and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain (gamma), one entry per feature, initialised to 1.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # Reciprocal RMS over the last axis, keepdim so it broadcasts over x.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor):
        # Normalise in float32 for stability, then return to the caller's dtype.
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed

In [None]:
# class FeedForwardBlock(nn.Module):

#     def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
#         super().__init__()
#         self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
#         self.dropout = nn.Dropout(dropout)
#         self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2

#     def forward(self, x):
#         # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
#         return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

In [None]:
class FeedForwardBlock(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x)), all projections bias-free.

    The hidden width is 2/3 of 4 * d_model, optionally scaled by
    ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        args: ModelConfig
    ):
        super().__init__()
        width = self._hidden_dim(args)

        # Creation order w1, w2, w3 is kept (parameter registration / RNG order).
        self.w1 = nn.Linear(args.d_model, width, bias=False)
        self.w2 = nn.Linear(width, args.d_model, bias=False)
        self.w3 = nn.Linear(args.d_model, width, bias=False)

    @staticmethod
    def _hidden_dim(args: ModelConfig) -> int:
        # Start from the usual 4 * d_model and shrink by 2/3 to keep the parameter
        # count comparable to a plain two-matrix FFN.
        width = int(2 * (4 * args.d_model) / 3)
        if args.ffn_dim_multiplier is not None:
            width = int(args.ffn_dim_multiplier * width)
        step = args.multiple_of
        # Ceil-divide then multiply back: round up to the next multiple of `step`.
        return -(-width // step) * step

    def forward(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) -> (B, Seq_Len, Hidden_Dim): gated activation.
        gate = F.silu(self.w1(x))
        # Element-wise gate on the linear branch, then back down to (B, Seq_Len, Dim).
        return self.w2(gate * self.w3(x))

In [None]:
class InputEmbeddings(nn.Module):
    """Token-id embedding lookup scaled by sqrt(d_model), as in "Attention Is All You Need"."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) ids -> (batch, seq_len, d_model), scaled up by sqrt(d_model).
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

In [None]:
# class PositionalEncoding(nn.Module):

#     def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
#         super().__init__()
#         self.d_model = d_model
#         self.seq_len = seq_len
#         self.dropout = nn.Dropout(dropout)
#         # Create a matrix of shape (seq_len, d_model)
#         pe = torch.zeros(seq_len, d_model)
#         # Create a vector of shape (seq_len)
#         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
#         # Create a vector of shape (d_model)
#         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
#         # Apply sine to even indices
#         pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model))
#         # Apply cosine to odd indices
#         pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model))
#         # Add a batch dimension to the positional encoding
#         pe = pe.unsqueeze(0) # (1, seq_len, d_model)
#         # Register the positional encoding as a buffer
#         self.register_buffer('pe', pe)

#     def forward(self, x):
#         x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
#         return self.dropout(x)

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) applied to attention queries and keys.

    The sin/cos tables live in the shared ``cache`` dict under the keys
    "rope_pos_sin" / "rope_pos_cos", so every module given the same cache reuses
    one table instead of rebuilding it.
    """

    def __init__(self, config: ModelConfig, cache: BufferCache, device):
        super().__init__()
        self.config = config
        self.__cache = cache

        # Warm up cache.
        self.get_rotary_embedding(config.max_sequence_length, device)

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (sin, cos) tables of shape (1, 1, seq_len, head_dim).

        If the cache already holds tables covering at least ``seq_len`` positions,
        they are moved to ``device`` when needed (the moved copy replaces the cached
        one) and sliced to ``seq_len``. Otherwise tables for exactly ``seq_len``
        positions are built on ``device`` and overwrite the cache.
        """
        if (
            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
            and pos_sin.shape[-2] >= seq_len
            and pos_cos.shape[-2] >= seq_len
        ):
            if pos_sin.device != device:
                pos_sin = pos_sin.to(device)
                self.__cache["rope_pos_sin"] = pos_sin
            if pos_cos.device != device:
                pos_cos = pos_cos.to(device)
                self.__cache["rope_pos_cos"] = pos_cos
            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]

        # Per-head width; one inverse frequency for every pair of channels.
        dim = self.config.d_model // self.config.n_heads
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
        seq = torch.arange(seq_len, device=device, dtype=torch.float)
        # Outer product: angle for (position i, frequency j) -> (seq_len, dim // 2).
        freqs = einsum("i , j -> i j", seq, inv_freq)
        # Duplicate so both halves used by rotate_half share the same angles.
        positions = torch.cat((freqs, freqs), dim=-1)
        # Leading singleton axes broadcast over (batch, heads).
        pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
        self.__cache["rope_pos_sin"] = pos_sin
        self.__cache["rope_pos_cos"] = pos_cos
        return pos_sin, pos_cos

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into halves (x1, x2) and return cat(-x2, x1).

        Expects a 4-D tensor (batch, heads, seq_len, head_dim) with even head_dim.
        """
        B, nh, T, hs = x.size()
        x = x.view(B, nh, T, 2, hs // 2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Standard RoPE rotation: t * cos + rotate_half(t) * sin, kept in t's dtype.
        return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` (batch, heads, seq_len, head_dim) by their positions.

        Keys use positions 0..key_len-1; queries use the last ``query_len`` of those
        positions. Outputs keep the input dtypes.
        """
        q_, k_ = q, k
        query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
        pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
        pos_sin = pos_sin.type_as(q_)
        pos_cos = pos_cos.type_as(q_)
        q_ = self.apply_rotary_pos_emb(
            pos_sin[:, :, key_len - query_len : key_len, :],
            pos_cos[:, :, key_len - query_len : key_len, :],
            q_,
        )
        k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return q_.type_as(q), k_.type_as(k)

In [None]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: x + dropout(sublayer(RMSNorm(x)))."""

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm(features)

    def forward(self, x, sublayer):
        # Normalise first, run the sub-layer, then add its (dropped-out) output back.
        update = sublayer(self.norm(x))
        return x + self.dropout(update)

In [None]:
class MoeLayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k token routing.

    ``gate`` maps each token to one logit per expert. Each token is sent to its
    ``num_experts_per_tok`` highest-scoring experts, and the expert outputs are
    mixed with a softmax over those k logits.

    While training, each expert processes at most ``_capacity(num_tokens)`` tokens;
    any surplus assignments for that expert are dropped at random. In eval mode
    nothing is dropped, so decoding with tiny batches is not degraded.

    After every forward pass ``self.balance_loss`` holds a differentiable
    load-balancing term (see ``forward``).
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: ModelConfig):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args
        # Scalar tensor from the latest forward pass; None before the first call.
        self.balance_loss = None

    def _capacity(self, num_tokens: int) -> int:
        """Return the maximum number of token assignments one expert may take.

        There are num_tokens * k assignments spread over num_experts experts, so a
        perfectly balanced load is num_tokens * k / num_experts; capacity_factor adds
        head-room. The result is rounded up and clamped to >= 1 so small batches are
        never routed to zero experts (truncation used to give capacity 0).
        """
        balanced_load = num_tokens * self.args.num_experts_per_tok / self.args.num_experts
        return max(1, math.ceil(self.args.capacity_factor * balanced_load))

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route (batch, seq_len, d_model) ``inputs`` through the experts; same shape out."""
        batch_size, seq_len, d_model = inputs.shape
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat_inputs = inputs.reshape(-1, d_model)
        num_tokens = flat_inputs.shape[0]
        num_experts = self.args.num_experts

        results = torch.zeros_like(flat_inputs)
        if num_tokens == 0:
            # Nothing to route; avoid dividing by zero in the balance loss.
            self.balance_loss = flat_inputs.new_zeros(())
            return results.view(batch_size, seq_len, d_model)

        # (num_tokens, num_experts) router logits.
        gate_logits = self.gate(flat_inputs)

        # Pick the top-k experts per token; softmax over those k logits only, so
        # each token's mixing weights sum to 1.
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok, dim=-1
        )
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(inputs.dtype)

        # Load-balancing loss: squared coefficient of variation of expert
        # "importance" (router probability mass per expert). Unlike hard top-k
        # counts this has a gradient, so it can actually steer the gate. It is 0
        # when the probability mass is spread evenly over the experts.
        router_probs = F.softmax(gate_logits, dim=-1, dtype=torch.float)
        importance = router_probs.sum(dim=0)  # sums to num_tokens across experts
        mean_importance = num_tokens / num_experts
        self.balance_loss = torch.sum((importance - mean_importance) ** 2) / (
            num_experts * mean_importance ** 2
        )

        capacity = self._capacity(num_tokens) if self.training else None

        for expert_idx, expert in enumerate(self.experts):
            # (num_tokens, k) boolean mask; row-major nonzero() and boolean indexing
            # enumerate the same entries in the same order, so indices and weights line up.
            expert_mask = selected_experts == expert_idx
            token_indices = expert_mask.nonzero(as_tuple=True)[0]
            expert_weights = weights[expert_mask]

            # Enforce capacity (training only) by keeping a random subset.
            if capacity is not None and token_indices.shape[0] > capacity:
                keep = torch.randperm(token_indices.shape[0], device=token_indices.device)[:capacity]
                token_indices = token_indices[keep]
                expert_weights = expert_weights[keep]

            if token_indices.shape[0] == 0:
                continue

            expert_outputs = expert(flat_inputs[token_indices])
            # Accumulate the gate-weighted expert output into each routed token's row.
            results.index_add_(0, token_indices, expert_weights.unsqueeze(-1) * expert_outputs)

        return results.view(batch_size, seq_len, d_model)

    def get_balance_loss(self):
        """Return the balance loss from the most recent forward pass (None before any)."""
        return self.balance_loss

In [None]:

class SelfAttentionBlock(nn.Module):
    """Multi-head attention with rotary position embeddings (RoPE) applied to q and k.

    All four projections are bias-free. After each forward pass the softmax
    attention map (after dropout) is stored in ``self.attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float, config: ModelConfig, cache: BufferCache, device) -> None:
        super().__init__()
        self.d_model = d_model  # embedding width
        self.h = h  # number of heads
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h  # width seen by each head

        # Created (and registered) in the order w_q, w_k, w_v, w_o.
        self.w_q, self.w_k, self.w_v, self.w_o = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(4)
        )
        self.dropout = nn.Dropout(dropout)

        self.__cache = cache
        self.rotary_emb = RotaryEmbedding(config, self.__cache, device)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Scaled dot-product attention over (batch, h, seq_len, d_k) tensors.

        Positions where ``mask == 0`` get the dtype's most negative value before the
        softmax. Returns (weighted values, attention weights after dropout).
        """
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        probs = scores.softmax(dim=-1)
        if dropout is not None:
            probs = dropout(probs)
        return probs @ value, probs

    def _split_heads(self, t):
        # (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
        batch, length = t.shape[0], t.shape[1]
        return t.view(batch, length, self.h, self.d_k).transpose(1, 2)

    def _merge_heads(self, t):
        # (batch, h, seq_len, d_k) -> (batch, seq_len, d_model)
        return t.transpose(1, 2).contiguous().view(t.shape[0], -1, self.h * self.d_k)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        # Encode positions by rotating queries and keys.
        query, key = self.rotary_emb(query, key)

        context, self.attention_scores = SelfAttentionBlock.attention(query, key, value, mask, self.dropout)
        return self.w_o(self._merge_heads(context))


In [None]:
class CrossAttentionBlock(nn.Module):
    """Multi-head attention without positional rotation (queries from ``q``, keys/values from ``k``/``v``).

    All four projections are bias-free. After each forward pass the softmax
    attention map (after dropout) is stored in ``self.attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # embedding width
        self.h = h  # number of heads
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h  # width seen by each head

        # Created (and registered) in the order w_q, w_k, w_v, w_o.
        self.w_q, self.w_k, self.w_v, self.w_o = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(4)
        )
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Scaled dot-product attention over (batch, h, seq_len, d_k) tensors.

        Positions where ``mask == 0`` get the dtype's most negative value before the
        softmax. Returns (weighted values, attention weights after dropout).
        """
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        probs = scores.softmax(dim=-1)
        if dropout is not None:
            probs = dropout(probs)
        return probs @ value, probs

    def _split_heads(self, t):
        # (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
        batch, length = t.shape[0], t.shape[1]
        return t.view(batch, length, self.h, self.d_k).transpose(1, 2)

    def _merge_heads(self, t):
        # (batch, h, seq_len, d_k) -> (batch, seq_len, d_model)
        return t.transpose(1, 2).contiguous().view(t.shape[0], -1, self.h * self.d_k)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        context, self.attention_scores = CrossAttentionBlock.attention(query, key, value, mask, self.dropout)
        return self.w_o(self._merge_heads(context))

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then the feed-forward/MoE sub-layer, each in a ResidualConnection."""

    def __init__(self, features: int, self_attention_block: SelfAttentionBlock, feed_forward_block: MoeLayer, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            ResidualConnection(features, dropout) for _ in range(2)
        )

    def forward(self, x, src_mask):
        attn_residual, ffn_residual = self.residual_connections

        def self_attend(hidden):
            # Queries, keys and values all come from the same (normalised) input.
            return self.self_attention_block(hidden, hidden, hidden, src_mask)

        x = attn_residual(x, self_attend)
        return ffn_residual(x, self.feed_forward_block)

In [None]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final RMSNorm."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = RMSNorm(features)

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the encoder
    output, then the feed-forward/MoE sub-layer, each in a ResidualConnection."""

    def __init__(self, features: int, self_attention_block: SelfAttentionBlock, cross_attention_block: CrossAttentionBlock, feed_forward_block: MoeLayer, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            ResidualConnection(features, dropout) for _ in range(3)
        )

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        self_residual, cross_residual, ffn_residual = self.residual_connections

        def self_attend(hidden):
            return self.self_attention_block(hidden, hidden, hidden, tgt_mask)

        def cross_attend(hidden):
            # Queries from the decoder, keys/values from the encoder output.
            return self.cross_attention_block(hidden, encoder_output, encoder_output, src_mask)

        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return ffn_residual(x, self.feed_forward_block)

In [None]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final RMSNorm."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = RMSNorm(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

In [None]:
class ProjectionLayer(nn.Module):
    """Linear map (with bias) from the model width to per-token vocabulary scores."""

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size) unnormalised scores.
        # The return annotation was previously `-> None`, which was wrong.
        return self.proj(x)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder transformer exposed as separate encode / decode / project steps."""

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings,  projection_layer: ProjectionLayer) -> None:
        super().__init__()
        # Registration order (and therefore state_dict key order) is unchanged.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed source ids and run the encoder stack -> (batch, seq_len, d_model)."""
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed target ids and run the decoder stack -> (batch, seq_len, d_model)."""
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder states to (batch, seq_len, vocab_size) scores."""
        return self.projection_layer(x)

In [None]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int,  d_model: int=768, N: int=6, h: int=12, dropout: float=0.1) -> Transformer:
    """Build the RoPE + mixture-of-experts encoder-decoder Transformer.

    Args:
        src_vocab_size: rows in the source embedding table.
        tgt_vocab_size: rows in the target embedding table and output projection width.
        d_model: model width used by every layer (embeddings, attention, experts, router).
        N: number of encoder layers and, separately, of decoder layers.
        h: number of attention heads (must divide d_model).
        dropout: dropout rate for attention weights and residual branches.

    Returns:
        A Transformer whose parameters with more than one dimension are
        Xavier-uniform initialised.
    """
    # Build one config instance from the arguments so the experts, router and RoPE
    # tables always agree with d_model / h. Passing the ModelConfig class itself
    # silently used its 768/12 defaults and broke any other d_model.
    config = ModelConfig(
        d_model=d_model,
        n_heads=h,
        dropout=dropout,
        hidden_size=d_model,
        num_attention_heads=h,
    )

    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Device on which the shared RoPE cache is first built. torch.has_mps is
    # deprecated and only reports build support, so query the MPS backend directly.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # One cache shared by every self-attention layer so the RoPE tables are built once.
    cache = BufferCache()

    def make_moe():
        # Fresh experts and router for one layer.
        return MoeLayer(
            experts=[FeedForwardBlock(config) for _ in range(config.num_experts)],
            gate=nn.Linear(config.d_model, config.num_experts, bias=False),
            moe_args=config,
        )

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = SelfAttentionBlock(d_model, h, dropout, config, cache, device)
        encoder_blocks.append(
            EncoderBlock(d_model, encoder_self_attention_block, make_moe(), dropout)
        )

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = SelfAttentionBlock(d_model, h, dropout, config, cache, device)
        decoder_cross_attention_block = CrossAttentionBlock(d_model, h, dropout)
        decoder_blocks.append(
            DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, make_moe(), dropout)
        )

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, projection_layer)

    # Xavier-uniform init for every matrix-shaped parameter (linear weights,
    # embeddings); 1-D parameters (norm gains, projection bias) keep their defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

In [None]:
from pathlib import Path

def get_config():
    """Return a fresh dict of default training, data and checkpoint settings."""
    return dict(
        batch_size=24,
        num_epochs=3,
        lr=2e-4,
        seq_len=320,
        d_model=768,
        datasource="ai4bharat/IndicParaphrase",
        lang_src="en",
        lang_tgt="hi",
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<datasource>_<model_folder>/<model_basename><epoch>.pt``."""
    folder = "{}_{}".format(config["datasource"], config["model_folder"])
    filename = "{}{}.pt".format(config["model_basename"], epoch)
    return str(Path(".").joinpath(folder, filename))

def latest_weights_file_path(config):
    """Return the path of the newest checkpoint in the weights folder, or None.

    Candidates match the glob ``<model_basename>*`` inside
    ``<datasource>_<model_folder>``. They are ranked by the integer epoch that
    follows the basename, so ``tmodel_10.pt`` beats ``tmodel_9.pt`` (the old plain
    string sort got this wrong). Files whose suffix is not an integer rank below
    every numbered checkpoint and are ordered by name among themselves.
    """
    basename = config['model_basename']
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    weights_files = list(Path(model_folder).glob(f"{basename}*"))
    if not weights_files:
        return None

    def epoch_key(path):
        suffix = path.stem[len(basename):]
        # isdecimal() accepts exactly the characters int() can parse.
        return (int(suffix), path.name) if suffix.isdecimal() else (-1, path.name)

    return str(max(weights_files, key=epoch_key))

In [None]:
from datasets import load_dataset, concatenate_datasets, Dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [None]:
from tokenizers import Tokenizer
from tokenizers import SentencePieceBPETokenizer
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

In [None]:
def format_dataset(example):
    """Map a raw row with "src"/"tgt" fields to a {"en": ..., "hi": ...} translation pair."""
    source, target = example["src"], example["tgt"]
    return {"en": source, "hi": target}

In [None]:
def preprocess_dataset(dataset):
    """Load a Hugging Face dataset and merge all of its splits into one translation dataset.

    Every split returned by ``load_dataset(dataset)`` is concatenated. Each row's
    ``src``/``tgt`` columns are folded into a single ``translation`` column of
    ``{"en": src, "hi": tgt}`` dicts, and the ``src``/``tgt`` columns are dropped.

    Returns:
        A ``datasets.Dataset`` with the remaining original columns plus ``translation``.
    """
    # Import locally: a later cell runs ``from torch.utils.data import Dataset``,
    # which shadows the Hugging Face ``Dataset`` at module level and would make
    # ``Dataset.from_pandas`` fail if this function is re-run afterwards.
    from datasets import Dataset as HFDataset, concatenate_datasets, load_dataset

    dataset_dict = load_dataset(dataset)
    all_dataset = concatenate_datasets([dataset_dict[split] for split in dataset_dict.keys()])
    df = all_dataset.to_pandas()
    # A comprehension over the two columns is much faster than a row-wise DataFrame.apply.
    df["translation"] = [
        format_dataset({"src": src, "tgt": tgt}) for src, tgt in zip(df["src"], df["tgt"])
    ]
    df = df.drop(columns=["src", "tgt"])
    return HFDataset.from_pandas(df)

In [None]:
# Hugging Face Hub dataset to train on; preprocess_dataset reshapes its src/tgt
# columns into en/hi translation pairs.
dataset = "Venkatesh4342/samanantar-hi-seq-len-300"

In [None]:
# Download every split and merge them into one dataset with a "translation" column.
ff_dataset = preprocess_dataset(dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Spot-check one merged translation pair.
ff_dataset["translation"][500009]

{'en': 'There is also considerable potential for Russia and India to reinforce each other, in executing energy and rail transportation projects in third countries, like Afghanistan and Vietnam.',
 'hi': 'भारत-रूस के बीच तीसरे देशों जैसे कि अफगानिस्तान और वियतनाम में ऊर्जा और रेल परिवहन परियोजनाओं में एक-दूसरे का सहयोग करते हुए कमाई करने की काफी संभावना है।'}

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every item's ``translation`` dict in ``ds``."""
    yield from (item['translation'][lang] for item in ds)

In [None]:
def get_or_build_tokenizer(config, ds, lang, vocab_size=None):
    """Load the tokenizer for ``lang`` from disk, or train and save a new one.

    The file is ``config['tokenizer_file'].format(lang)``. If it exists it is
    loaded with ``Tokenizer.from_file``. Otherwise a ``SentencePieceBPETokenizer``
    is trained on the ``lang`` side of ``ds`` (min_frequency=5, special tokens
    "<unk>", "<pad>", "<sos>", "<eos>") and saved to that file.

    Args:
        config: dict containing "tokenizer_file".
        ds: iterable of items with a ``translation`` dict.
        lang: language key into each ``translation`` dict.
        vocab_size: optional override used only when training. The defaults are
            32k for "en", 64k for "hi" and 32k for any other language (an unknown
            language used to crash with UnboundLocalError).
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    if vocab_size is None:
        vocab_size = {"en": 32_000, "hi": 64_000}.get(lang, 32_000)

    # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = SentencePieceBPETokenizer()
    special_tokens = ["<unk>", "<pad>", "<sos>", "<eos>"]
    tokenizer.train_from_iterator(
        get_all_sentences(ds, lang),
        vocab_size=vocab_size,
        min_frequency=5,
        special_tokens=special_tokens,
        show_progress=True,
    )
    # Make sure the target directory exists if tokenizer_file includes one.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [None]:
# Default training configuration used by the cells below.
config  = get_config()

In [None]:
# Is the English tokenizer already on disk? If so, get_or_build_tokenizer loads it
# instead of retraining.
Path.exists(Path(config['tokenizer_file'].format("en")))

True

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics

In [None]:
def get_ds(config):
    """Build tokenizers and train/validation DataLoaders from the global ``ff_dataset``.

    Args:
        config: dict with "lang_src", "lang_tgt", "seq_len", "batch_size",
            "tokenizer_file" and optionally "train_fraction" (default 0.001).

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt, stats).
        ``stats`` holds the lists "src_len", "tgt_len", "src" and "tgt". The
        per-sentence length scan is disabled (too slow on the full corpus), so
        these lists are empty.

    Raises:
        ValueError: if "train_fraction" is outside [0, 1].
    """
    # preprocess_dataset already merged every split, so train/val is split here.
    ds_raw = ff_dataset

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # NOTE: the default keeps only 0.1% of the data for training (99.9% goes to
    # validation). This is what the notebook has been running with; the old comment
    # wrongly claimed 90/10. Set config["train_fraction"] = 0.9 for a conventional split.
    train_fraction = config.get("train_fraction", 0.001)
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")
    train_ds_size = int(train_fraction * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    # Length statistics are not computed (see docstring); the keys are kept so
    # callers that unpack this dict keep working.
    length_stats = {"src_len": [], "tgt_len": [], "src": [], "tgt": []}
    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt, length_stats


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("<sos>")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("<eos>")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("<pad>")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into tokens
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # We will add <s> and </s>
        # We will only add <s>, and </s> only on the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input ,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """Build a boolean causal (look-back-only) attention mask.

    Args:
        size: sequence length.

    Returns:
        Bool tensor of shape (1, size, size). Entry [0, i, j] is True when
        position i may attend to position j, i.e. when j <= i.
    """
    # Keep the main diagonal and everything below it; future positions become 0.
    lower_triangle = torch.ones((1, size, size)).tril(diagonal=0)
    return lower_triangle.bool()

In [None]:
import copy
import warnings
from abc import abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast

import torch
import wandb


In [None]:
# Authenticate with Weights & Biases so the `wandb.log(...)` calls made during
# validation/training can report metrics.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mvenkatesh3132003[0m ([33mnlp_explorer[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Decoding state threaded between BeamSearch and the user's step function (name -> tensor).
StateType = Dict[str, torch.Tensor]
# Step functions accepted by BeamSearch.search:
#   (last_predictions, state, timestep) -> (class log-probs, new state)
StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
#   (last_predictions, state) -> (class log-probs, new state)
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]

# Either step-function flavour; BeamSearch.search inspects the arity to tell them apart.
StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
# Per-constraint state: one inner list per batch element, one dict per beam.
ConstraintStateType = List[List[Dict[str, Any]]]

In [None]:
class Sampler:
    """Strategy used by BeamSearch to pick candidate classes.

    ``sample_nodes`` picks candidates from each beam's distribution and
    ``sample_beams`` picks which candidates survive as the next beams.
    The base implementation is stateless.
    """

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        """Return the initial sampler state (empty for the base class)."""
        # The arguments exist for stateful subclasses; nothing is needed here.
        _ = (start_class_log_probabilities, batch_size, num_classes)
        return dict()

    @abstractmethod
    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Select ``per_node_beam_size`` candidates per row of ``log_probs``."""
        raise NotImplementedError

    def sample_beams(
        self, log_probs: torch.Tensor, beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Keep the ``beam_size`` highest log-probs along the last dim; state is reset."""
        best_values, best_positions = log_probs.topk(beam_size, dim=-1)
        return best_values, best_positions, dict()

In [None]:
class TopKSampler(Sampler):
    """Sample per-node candidates from the ``k`` most likely classes.

    The top-``k`` log-probs are optionally divided by ``temperature``,
    renormalised with a softmax, and ``per_node_beam_size`` of them are drawn
    with ``torch.multinomial``. The returned log-probs come from the
    unscaled input.
    """

    def __init__(
        self,
        k: int = 1,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        self.k = k
        # A falsy temperature (0 / None) falls back to 1.0, i.e. no scaling.
        self.temperature = temperature if temperature else 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Return (log-probs, class indices, state) for the sampled candidates of each row.

        Raises:
            ValueError: unless per_node_beam_size <= k <= log_probs.size()[1].
        """
        if per_node_beam_size > self.k or self.k > log_probs.size()[1]:
            raise ValueError(
                "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
            )

        # Restrict to the k best classes of each row: both (rows, k).
        candidate_log_probs, candidate_ids = log_probs.topk(self.k, dim=-1)

        if self.temperature != 1.0:
            candidate_log_probs = candidate_log_probs / self.temperature

        # Renormalise over the restricted subset and draw from it.
        candidate_probs = candidate_log_probs.softmax(dim=-1)
        picks = torch.multinomial(candidate_probs, per_node_beam_size, replacement=self.with_replacement)

        # `picks` index into the top-k subset; translate back to vocabulary ids.
        chosen_ids = torch.gather(candidate_ids, -1, picks)
        chosen_log_probs = torch.gather(log_probs, 1, chosen_ids)
        return chosen_log_probs, chosen_ids, state


In [None]:
class TopPSampler(Sampler):
    """Nucleus (top-p) sampler.

    Samples per-node candidates from the smallest prefix of classes, sorted by
    probability, whose cumulative probability reaches ``p``.

    Args:
        p: cumulative-probability threshold; must lie in [0, 1].
        temperature: log-probs are divided by this before re-normalising; a
            falsy value (0 / None) is replaced by 1.0.
        with_replacement: passed through to ``torch.multinomial``.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        # NOTE(review): the message says "positive" but p == 0.0 passes this check.
        if p < 0.0 or p > 1.0:
            raise ValueError("p must be a positive float no greater than 1.0")
        self.p = p
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Sample ``per_node_beam_size`` classes per row from the top-p nucleus.

        Args:
            log_probs: log-probabilities; dim 1 is treated as the vocabulary
                (assumed (batch_size, num_classes) — TODO confirm for all callers).
            per_node_beam_size: number of classes to draw per row.
            state: sampler state, returned unchanged.

        Returns:
            (log-probs of the chosen classes taken from the *unscaled*
            ``log_probs``, chosen class indices, ``state``).

        Raises:
            ValueError: if ``per_node_beam_size`` exceeds ``log_probs.size()[1]``.
        """
        if not per_node_beam_size <= log_probs.size()[1]:
            raise ValueError("per_node_beam_size cannot be greater than vocabulary size")

        # First apply temperature coefficient:
        # (re-normalised with log_softmax so the cumulative sum below is over a proper distribution)
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # Sort the probabilities in descending order to then find cumulative sum
        log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)

        # shape: (batch_size, num_classes)
        probabilities_descending = log_probs_descending.exp()
        probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)

        # Create a mask for filtering out probabilities that don't make the top `p`.
        # shape: (batch_size, num_classes)
        exclusion_mask = probabilities_summed >= self.p

        # We want to include the first index where probabilities_summed >= p, so we shift over one.
        # (.clone() avoids reading and writing overlapping memory in the same assignment)
        exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
        exclusion_mask[..., 0] = False

        # Make sure there's at least `per_node_beam_size` options to be selected.
        # (only needed without replacement; multinomial would otherwise lack enough non-zero classes)
        if not self.with_replacement:
            exclusion_mask[..., :per_node_beam_size] = False

        log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min

        # Now re-normalized the included log probs.
        # shape: (batch_size, num_classes)
        filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        selected_indices = sorting_indices.gather(-1, sampled_indices)

        # Return (selected log probabilities, selected classes)
        # shape: (batch_size, per_node_beam_size), (batch_size, per_node_beam_size)
        return torch.gather(log_probs, 1, selected_indices), selected_indices, state

In [None]:
class FinalSequenceScorer:
    """Interface for ranking finished beams; BeamSearch sorts beams by the returned scores."""

    # NOTE(review): @abstractmethod is not enforced here because the class does not
    # derive from abc.ABC, so the base class can still be instantiated.
    @abstractmethod
    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """Return one score per beam.

        Args:
            predictions: predicted class ids, presumably (batch_size, beam_size, num_steps).
            log_probabilities: accumulated log-probability per beam, presumably (batch_size, beam_size).
            end_index: class id of the end-of-sequence token.
        """
        raise NotImplementedError

In [None]:
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
    """Score beams by total log-probability divided by ``length ** length_penalty``.

    The length counts the non-end tokens, plus one when the beam finished with
    the end token, because that transition is included in the log-probability.
    ``length_penalty=0`` yields the raw summed log-probability.
    """

    def __init__(self, length_penalty: float = 1.0):
        super().__init__()
        self.length_penalty = length_penalty

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """Return length-normalised scores of shape (batch_size, beam_size)."""
        # Non-end tokens per beam: (batch_size, beam_size).
        token_counts = predictions.ne(end_index).long().sum(dim=2)
        # Beams that terminated also paid for the end-token transition.
        finished = predictions[:, :, -1].eq(end_index).long()
        effective_lengths = token_counts + finished
        return log_probabilities / effective_lengths.pow(self.length_penalty)


In [None]:
class Constraint:
    """Base class for constraints that edit beam-search log-probabilities.

    Subclasses keep one state object per (batch element, beam).
    ``update_state`` re-aligns those states with the surviving beams after
    every step, then lets the subclass fold in the newest prediction.
    """

    @abstractmethod
    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        """Return the initial per-batch, per-beam state."""
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``class_log_probabilities`` with disallowed classes suppressed."""
        raise NotImplementedError

    @staticmethod
    def _copy_state(
        state: ConstraintStateType,
        batch_size: int,
        beam_size: int,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """Deep-copy each beam's parent state so sibling beams never share one object."""

        def parent_index(row: int, col: int) -> int:
            # Without backpointers (first prediction) every beam descends from beam 0.
            if last_backpointer is None:
                return 0
            return last_backpointer[row, col].item()

        return [
            [copy.deepcopy(state[row][parent_index(row, col)]) for col in range(beam_size)]  # type: ignore
            for row in range(batch_size)
        ]

    def update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """Re-align ``state`` to the current beams, then record ``last_prediction``."""
        rows, cols = last_prediction.size()
        realigned = self._copy_state(state, rows, cols, last_backpointer)
        return self._update_state(realigned, last_prediction)

    @abstractmethod
    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        """Fold ``last_prediction`` (batch_size, beam_size) into the already re-aligned state."""
        raise NotImplementedError

In [None]:
class RepeatedNGramBlockingConstraint(Constraint):
    """Forbid any beam from generating an n-gram of size ``ngram_size`` twice.

    Each beam tracks the last ``ngram_size - 1`` tokens (``current_prefix``)
    and, for every prefix seen so far, the tokens that followed it
    (``seen_ngrams``). Those followers are blocked whenever the prefix recurs.
    """

    def __init__(self, ngram_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ngram_size = ngram_size

    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        """One beam per batch element, with no history yet."""
        return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]

    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        """Set the log-prob of every n-gram-completing token to the dtype minimum (in place)."""
        for batch_idx, beams in enumerate(state):
            for beam_idx, beam_state in enumerate(beams):
                blocked = beam_state["seen_ngrams"].get(tuple(beam_state["current_prefix"]))
                if blocked is None:
                    # Prefix never seen before: nothing to block for this beam.
                    continue
                class_log_probabilities[batch_idx, beam_idx, blocked] = torch.finfo(
                    class_log_probabilities.dtype
                ).min
        return class_log_probabilities

    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        """Record the n-gram just completed and slide each beam's prefix window."""
        for batch_idx, beams in enumerate(state):
            for beam_idx, beam_state in enumerate(beams):
                token = last_prediction[batch_idx, beam_idx].item()
                window = beam_state["current_prefix"]

                if len(window) == self.ngram_size - 1:
                    # window + token is a complete n-gram: remember the follower.
                    beam_state["seen_ngrams"].setdefault(tuple(window), []).append(token)

                window.append(token)
                if len(window) == self.ngram_size:
                    # Keep only the most recent ngram_size - 1 tokens.
                    del window[0]
        return state

In [None]:
class _DeterministicSampler(Sampler):
    """Default sampler for BeamSearch: keep the highest-scoring classes (plain beam search)."""

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # Deterministically take the `per_node_beam_size` most likely classes per row.
        selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return selected_log_probs, selected_indices, state


class BeamSearch:
    """Batched beam search, optionally with sampling, over a user-supplied step function.

    Args:
        end_index: class index of the end-of-sequence token.
        max_steps: maximum number of decoding steps (tokens generated per beam).
        beam_size: number of beams kept per batch element.
        per_node_beam_size: candidates drawn from each beam per step; defaults to ``beam_size``.
        sampler: candidate-selection strategy. Defaults to a deterministic top-k
            sampler, i.e. classic beam search.
        min_steps: the end token cannot be chosen for the first ``min_steps``
            generated tokens.
        final_sequence_scorer: ranks finished beams. Defaults to the summed
            log-probability (a length-normalised scorer with ``length_penalty=0.0``).
        constraints: a single ``Constraint`` or a list of them, applied in order each step.

    Raises:
        ValueError: on non-positive sizes/steps or ``min_steps`` outside [0, max_steps].
    """

    def __init__(
        self,
        end_index: int,
        *,
        max_steps: int = 50,
        beam_size: int = 10,
        per_node_beam_size: Optional[int] = None,
        sampler: Optional[Sampler] = None,
        min_steps: Optional[int] = None,
        final_sequence_scorer: Optional[FinalSequenceScorer] = None,
        constraints: Optional[Union[Constraint, List[Constraint]]] = None,
    ) -> None:
        if not max_steps > 0:
            raise ValueError("max_steps must be positive")
        if not beam_size > 0:
            raise ValueError("beam_size must be positive")
        if per_node_beam_size is not None and not per_node_beam_size > 0:
            raise ValueError("per_node_beam_size must be positive")
        if min_steps is not None:
            if not min_steps >= 0:
                raise ValueError("min_steps must be non-negative")
            if not min_steps <= max_steps:
                raise ValueError("min_steps must be less than or equal to max_steps")

        self._end_index = end_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.per_node_beam_size = per_node_beam_size or beam_size
        # Previously a None sampler/scorer crashed later in `_search`; use sane defaults.
        self.sampler = sampler if sampler is not None else _DeterministicSampler()
        self.min_steps = min_steps or 0
        self.final_sequence_scorer = (
            final_sequence_scorer
            if final_sequence_scorer is not None
            else LengthNormalizedSequenceLogProbabilityScorer(length_penalty=0.0)
        )
        # Accept None, a single constraint, or an iterable of constraints. The previous
        # `[constraints] or []` turned None into [None] and a list into [[...]].
        if constraints is None:
            self.constraints: List[Constraint] = []
        elif isinstance(constraints, Constraint):
            self.constraints = [constraints]
        else:
            self.constraints = list(constraints)

    @staticmethod
    def _reconstruct_sequences(predictions, backpointers):
        """Follow backpointers from the last step to rebuild each beam's token sequence.

        Returns a list of (batch_size, beam_size, 1) tensors in reverse time order.
        """
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]

        if not backpointers:
            return reconstructed_predictions

        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]

        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)

            reconstructed_predictions.append(cur_preds)

            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)

        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)

        reconstructed_predictions.append(final_preds)

        return reconstructed_predictions

    def search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionType,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the search.

        Args:
            start_predictions: initial inputs; dim 0 is the batch size.
            start_state: tensors with leading batch dim; they are expanded per beam and
                reordered by backpointer between steps.
            step: ``(last_predictions, state[, timestep]) -> (log_probs, state)``.

        Returns:
            (predictions of shape (batch_size, beam_size, steps), scores of shape
            (batch_size, beam_size)), sorted best-first along the beam dimension.
        """
        step_signature = signature(step)
        if len(step_signature.parameters) < 3:
            # If the step function we're given does not take the time step argument, wrap it
            # in one that does.
            old_step = cast(StepFunctionTypeNoTimestep, step)

            def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
                del time_step
                return old_step(last_predictions, state)

            return self._search(start_predictions, start_state, new_step)
        else:
            return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))

    def _search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionTypeWithTimestep,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = start_predictions.size()[0]

        # List of (batch_size, beam_size) tensors. One for each time step. Does not
        # include the start symbols, which are implicit.
        predictions: List[torch.Tensor] = []

        # List of (batch_size, beam_size) tensors. One for each time step. None for
        # the first.  Stores the index n for the parent prediction, i.e.
        # predictions[t-1][i][n], that it came from.
        backpointers: List[torch.Tensor] = []

        constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top `beam_size` decoder outputs. On the other hand,
        # within the main loop we are going from the `beam_size` elements of the
        # beam to `beam_size`^2 candidates from which we will select the top
        # `beam_size` elements for the next iteration.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)

        num_classes = start_class_log_probabilities.size()[1]

        # Make sure `per_node_beam_size` is not larger than `num_classes`.
        if self.per_node_beam_size > num_classes:
            raise ValueError(
                f"Vocab size ({num_classes:d}) too small "
                f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
                f"Please decrease beam_size or per_node_beam_size."
            )

        sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)

        # Apply all constraints.
        if self.constraints:
            # shape: (batch_size, 1, num_classes)
            expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
            for constraint, constraint_state in zip(self.constraints, constraint_states):
                expanded_start_class_log_probabilities = constraint.apply(
                    constraint_state, expanded_start_class_log_probabilities
                )
            start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)

        # Prevent selecting the end symbol if there is any min_steps constraint
        if self.min_steps >= 1:
            start_class_log_probabilities[:, self._end_index] = torch.finfo(
                start_class_log_probabilities.dtype
            ).min

        # Get the initial predicted classed and their log probabilities.
        # shape: (batch_size, beam_size), (batch_size, beam_size)
        (
            start_top_log_probabilities,
            start_predicted_classes,
            sampler_state,
        ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)

        if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # The log probabilities for the last time step.
        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities

        # shape: [(batch_size, beam_size)]
        predictions.append(start_predicted_classes)

        # Log probability tensor that mandates that the end token is selected.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes),
            torch.finfo(start_class_log_probabilities.dtype).min,
        )
        log_probs_after_end[:, self._end_index] = 0.0

        # Set the same state for each element in the beam.
        self._update_initial_state(state, batch_size)

        for i, constraint in enumerate(self.constraints):
            constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.beam_size)

            # If every predicted token from the last step is `self._end_index`,
            # then we can stop early.
            if (last_predictions == self._end_index).all():
                break
            # Take a step. This get the predicted log probs of the next classes
            # and updates the state.
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Apply all constraints.
            if self.constraints:
                # shape: (batch_size, beam_size, num_classes)
                reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
                for constraint, constraint_state in zip(self.constraints, constraint_states):
                    reshaped_class_log_probabilities = constraint.apply(
                        constraint_state, reshaped_class_log_probabilities
                    )
                # shape: (batch_size * beam_size, num_classes)
                class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)

            # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
            # of the sequence (because `timestep` is 0-indexed and we generated the first token
            # before the for loop). Here we block the end index if the search is not allowed to
            # terminate on this iteration.
            if timestep + 2 <= self.min_steps:
                class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes
            )

            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_log_probabilities = torch.where(
                last_predictions_expanded == self._end_index,
                log_probs_after_end,
                class_log_probabilities,
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
                cleaned_log_probabilities, self.per_node_beam_size, sampler_state
            )

            # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
            # so that we can add them to the current log probs for this timestep.
            # This lets us maintain the log probability of each element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )

            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # Keep only the top `beam_size` beam indices.
            # shape (both): (batch_size, beam_size)
            (
                restricted_beam_log_probs,
                restricted_beam_indices,
                sampler_state,
            ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)

            # Use the beam indices to extract the corresponding classes.
            # shape: (batch_size, beam_size)
            restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)

            predictions.append(restricted_predicted_classes)

            # shape: (batch_size, beam_size)
            last_log_probabilities = restricted_beam_log_probs

            # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
            # indices with a common ancestor are grouped together. Hence
            # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
            # division as the tensor is a LongTensor.)
            # shape: (batch_size, beam_size)
            backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
            backpointers.append(backpointer)

            # Keep only the pieces of the state tensors corresponding to the
            # ancestors created this iteration.
            self._update_state(state, backpointer)

            for i, constraint in enumerate(self.constraints):
                constraint_states[i] = constraint.update_state(
                    constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
                )

        # Warn about "-inf" log probabilities if not using any constraints (negligible
        # log probabilities are expected when using constraints).
        if not self.constraints and (
            not torch.isfinite(last_log_probabilities).all()
            or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
        ):
            warnings.warn(
                "Negligible log probabilities encountered ('-inf' or equivalent). "
                "Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)

        # shape: (batch_size, beam_size, max_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

        # Calculate the final sequence scores
        # shape: (batch_size, beam_size)
        final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)

        # Sort the sequences based on the final scores so the best scoring
        # sequence is at index 0
        sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
        sorted_all_predictions = torch.gather(
            all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
        )

        return sorted_all_predictions, sorted_final_scores

    def _update_initial_state(self, state: StateType, batch_size: int):
        """
        Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
        """
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            # shape: (batch_size * beam_size, *)
            _, *last_dims = state_tensor.size()
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, self.beam_size, *last_dims)
                .reshape(batch_size * self.beam_size, *last_dims)
            )

    def _update_state(self, state: StateType, backpointer: torch.Tensor):
        """Reorder every `(batch_size * beam_size, *)` state tensor so row i follows beam i's parent."""
        batch_size = backpointer.size()[0]

        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # shape: (batch_size, beam_size, *)
            expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
                batch_size, self.beam_size, *last_dims
            )
            # shape: (batch_size * beam_size, *)
            state[key] = (
                state_tensor.reshape(batch_size, self.beam_size, *last_dims)
                .gather(1, expanded_backpointer)
                .reshape(batch_size * self.beam_size, *last_dims)
            )

In [None]:
# def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
#     sos_idx = tokenizer_tgt.token_to_id('[SOS]')
#     eos_idx = tokenizer_tgt.token_to_id('[EOS]')

#     # Precompute the encoder output and reuse it for every step
#     encoder_output = model.encode(source, source_mask)
#     # Initialize the decoder input with the sos token
#     decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
#     while True:
#         if decoder_input.size(1) == max_len:
#             break

#         # build mask for target
#         decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

#         # calculate output
#         out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

#         # get next token
#         prob = model.project(out[:, -1])
#         _, next_word = torch.max(prob, dim=1)
#         decoder_input = torch.cat(
#             [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
#         )

#         if next_word == eos_idx:
#             break

#     return decoder_input.squeeze(0)

In [None]:
# Let PyTorch use faster reduced-precision internal kernels (e.g. TF32 on supporting
# GPUs) for float32 matmuls, trading a little precision for speed.
torch.set_float32_matmul_precision('high')

In [None]:

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, num_examples=2):
    """Beam-decode a few validation examples, print them, and log CER/WER/BLEU to wandb.

    Args:
        model: seq2seq model exposing ``encode(src, src_mask)``,
            ``decode(encoder_output, src_mask, tgt, tgt_mask)`` and ``project(hidden)``.
        validation_ds: iterable of batches with ``encoder_input``, ``encoder_mask``,
            ``src_text`` and ``tgt_text``. The batch size must be 1.
        tokenizer_src: unused; kept for signature compatibility.
        tokenizer_tgt: target tokenizer providing ``token_to_id`` and ``decode``.
        max_len: maximum number of decoding steps. It also bounds the decoder input length.
        device: device the inputs are moved to.
        print_msg: callable used for printing (e.g. ``tqdm.write``).
        global_step: training step recorded alongside the metrics.
        num_examples: number of validation examples to decode.
    """
    import os

    import torchmetrics

    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    sos_idx = tokenizer_tgt.token_to_id('<sos>')
    eos_idx = tokenizer_tgt.token_to_id('<eos>')

    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except (OSError, ValueError):
        # If we can't get the console width (e.g. no TTY in a notebook), use 80 as default
        console_width = 80

    beam_search = BeamSearch(
        eos_idx,
        # Bounded by max_len so the decoder input never grows beyond the training sequence length.
        max_steps=max_len,
        beam_size=3,
        per_node_beam_size=3,
        sampler=TopPSampler(),
        min_steps=min(5, max_len),
        final_sequence_scorer=LengthNormalizedSequenceLogProbabilityScorer(),
        constraints=RepeatedNGramBlockingConstraint(3),
    )

    def step(
        last_predictions: torch.Tensor, state: dict[str, torch.Tensor], time_step: int
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Return next-token log-probs for every beam, plus the updated decoding state.

        The model has no key/value cache, so the whole decoded history is kept in
        ``state["decoder_input"]`` and re-decoded every step. BeamSearch expands all
        state tensors per beam after step 0 and reorders them by backpointer after each
        later step, so row i always holds beam i's history and encoder context.
        """
        decoder_input = state["decoder_input"]
        if time_step > 0:
            # At step 0 the history is just <sos> (== last_predictions); afterwards append
            # the token each beam chose in the previous step.
            decoder_input = torch.cat((decoder_input, last_predictions.unsqueeze(1)), dim=1)

        decoder_mask = causal_mask(decoder_input.size(1)).type_as(state["encoder_mask"]).to(device)
        out = model.decode(state["encoder_output"], state["encoder_mask"], decoder_input, decoder_mask)
        # Only the last position predicts the next token.
        log_probs = F.log_softmax(model.project(out[:, -1]), dim=-1)

        state["decoder_input"] = decoder_input
        return log_probs, state

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(
                0) == 1, "Batch size must be 1 for validation"

            # Precompute the encoder output and reuse it for every step
            encoder_output = model.encode(encoder_input, encoder_mask)

            # Initialize the decoder input with the sos token
            initial_preds = torch.full((1, 1), sos_idx, dtype=encoder_input.dtype, device=device)

            # Everything the step function needs lives in the state so BeamSearch
            # expands and reorders it per beam.
            state: dict[str, torch.Tensor] = {
                "decoder_input": initial_preds,
                "encoder_output": encoder_output,
                "encoder_mask": encoder_mask,
            }

            # model_out: (1, beam_size, steps), scores: (1, beam_size), best beam first.
            model_out, scores = beam_search.search(initial_preds, state, step)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]

            best_ids = model_out[0, 0].tolist()
            # Finished beams are padded with repeated <eos>; cut at the first one.
            if eos_idx in best_ids:
                best_ids = best_ids[: best_ids.index(eos_idx)]
            model_out_text = tokenizer_tgt.decode(best_ids)
            best_prob = round(float(scores[0, 0].exp()), 4)

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")
            print_msg(f"{f'SCORE: ':>12}{best_prob}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    if not predicted:
        # Nothing was decoded (empty validation set), so there is nothing to score.
        return

    # Compute the char error rate
    metric = torchmetrics.CharErrorRate()
    cer = metric(predicted, expected)
    wandb.log({'validation/cer': cer, 'global_step': global_step})

    # Compute the word error rate
    metric = torchmetrics.WordErrorRate()
    wer = metric(predicted, expected)
    wandb.log({'validation/wer': wer, 'global_step': global_step})

    # Compute the BLEU metric (each prediction has a list of reference translations)
    metric = torchmetrics.BLEUScore()
    bleu = metric(predicted, [[text] for text in expected])
    wandb.log({'validation/BLEU': bleu, 'global_step': global_step})



In [None]:
import torch
import math

class CustomLRAdamOptimizer:
    def __init__(self, optimizer, model_dimension, num_of_warmup_steps):
        self.optimizer = optimizer
        self.model_size = model_dimension
        self.num_of_warmup_steps = num_of_warmup_steps
        self.current_step_number = 0

    def step(self, closure=None):
        self.current_step_number += 1
        current_learning_rate = self.get_current_learning_rate()

        for p in self.optimizer.param_groups:
            p['lr'] = current_learning_rate

        return self.optimizer.step(closure)

    def state_dict(self):
        return {
            'custom_state': {key: value for key, value in self.__dict__.items() if key != 'optimizer'},
            'optimizer_state_dict': self.optimizer.state_dict()
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict['custom_state'])
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])

    def get_current_learning_rate(self):
        step = self.current_step_number
        warmup = self.num_of_warmup_steps

        # Warmup phase
        if step <= warmup:
            return (self.model_size ** (-0.5)) * (step * (warmup ** (-1.5)))

        # Decay phase (inverse square root)
        return (self.model_size ** (-0.5)) * (warmup ** 0.5) * (step ** (-0.5))

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def zero_grad(self, set_to_none=False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    # Pass through methods to the underlying optimizer
    def __getattr__(self, name):
        return getattr(self.optimizer, name)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from pathlib import Path
from tqdm import tqdm
import wandb
import math

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the translation transformer for the given vocabulary sizes.

    Only ``config['d_model']`` is forwarded to ``build_transformer``. Any
    other hyperparameters presumably come from that function's own defaults;
    it is defined elsewhere in the notebook.
    """
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        d_model=config['d_model'],
    )

def _select_training_device():
    """Return the best available device (CUDA, then MPS, then CPU) after printing a summary."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("Using device:", device)
    if device.type == "cuda":
        # device.index is None for a bare "cuda" device; both helpers then
        # fall back to the current CUDA device.
        print(f"Device name: {torch.cuda.get_device_name(device.index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
    elif device.type == "mps":
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
    return device


def _moe_balance_loss(model, device):
    """Sum ``get_balance_loss()`` over every submodule that defines it (zero tensor if none)."""
    balance_losses = [
        module.get_balance_loss()
        for module in model.modules()
        if hasattr(module, 'get_balance_loss')
    ]
    if not balance_losses:
        return torch.tensor(0.0, device=device)
    return sum(balance_losses)


def train_model(config):
    """Train the MoE transformer described by ``config`` and checkpoint every epoch.

    For each epoch in ``[initial_epoch, config['num_epochs'])`` this does the
    following:

    - runs one mixed-precision pass over the training dataloader, using the
      Noam LR schedule (``CustomLRAdamOptimizer``) and gradient-norm clipping
      at 1.0;
    - adds ``balance_loss_weight`` times the summed MoE balance losses to the
      cross-entropy loss;
    - logs losses and the learning rate to wandb;
    - calls ``run_validation`` and saves a checkpoint.

    Config keys read here: ``datasource``, ``model_folder``, ``d_model``,
    ``preload``, ``num_epochs`` and ``seq_len``. ``get_ds`` and the
    weights-path helpers may read more.

    Expects an active wandb run (the notebook calls ``wandb.init`` first);
    calls ``wandb.finish()`` when done.

    Returns:
        The ``dict_plot`` object produced by ``get_ds``.
    """
    device = _select_training_device()

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt, dict_plot = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)
    model = torch.compile(model)  # Using torch.compile for potential speedup

    num_warmup_steps = 4000
    optimizer = AdamW(model.parameters(), lr=0.0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.01)
    # The schedule scales with d_model ** -0.5, so use the width the model was
    # actually built with rather than a hard-coded 768.
    custom_lr_optimizer = CustomLRAdamOptimizer(optimizer, config['d_model'], num_warmup_steps)

    # Labels are target-side token ids, so padding is identified by the
    # target tokenizer's <pad> id.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_tgt.token_to_id("<pad>"),
        label_smoothing=0.1,
    ).to(device)

    initial_epoch = 0
    global_step = 0
    scaler = GradScaler()

    # Preload model if specified
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on one device resume on another.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        custom_lr_optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # define our custom x axis metric
    wandb.define_metric("global_step")
    # define which metrics will be plotted against it
    wandb.define_metric("validation/*", step_metric="global_step")
    wandb.define_metric("train/*", step_metric="global_step")

    balance_loss_weight = 0.005  # Adjust this value as needed

    for epoch in range(initial_epoch, config['num_epochs']):
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)
            label = batch['label'].to(device)

            with autocast():
                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                proj_output = model.project(decoder_output)

                main_loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
                total_balance_loss = _moe_balance_loss(model, device)
                loss = main_loss + balance_loss_weight * total_balance_loss

            scaler.scale(loss).backward()
            # GradScaler tracks unscale/step state per optimizer object, by
            # identity. Unscale the same wrapper that is passed to
            # scaler.step(); unscaling the inner AdamW instead would make
            # step() unscale the gradients a second time.
            scaler.unscale_(custom_lr_optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(custom_lr_optimizer)
            scaler.update()
            custom_lr_optimizer.zero_grad()

            batch_iterator.set_postfix({
                "loss": f"{loss.item():6.3f}",
                "main_loss": f"{main_loss.item():6.3f}",
                "balance_loss": f"{total_balance_loss.item():6.3f}"
            })
            wandb.log({
                'train/loss': loss.item(),
                'train/main_loss': main_loss.item(),
                'train/balance_loss': total_balance_loss.item(),
                'train/learning_rate': custom_lr_optimizer.get_current_learning_rate(),
                'global_step': global_step
            })

            global_step += 1

        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)

        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': custom_lr_optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

    wandb.finish()
    # The notebook binds this to `dict_plot` for the length plots that follow.
    return dict_plot

In [None]:
# Silence every Python warning for the rest of this session.
warnings.filterwarnings(action="ignore")

config = get_config()

# Start the Weights & Biases run: it is logged under the
# "pytorch-transformer-smoe" project, with the whole config recorded as the
# run's hyperparameters.
wandb_run_kwargs = dict(project="pytorch-transformer-smoe", config=config)
wandb.init(**wandb_run_kwargs)

# Train; the return value is bound to `dict_plot` for the plotting cells below.
dict_plot = train_model(config)

In [None]:
# Hand cached-but-unused CUDA allocator blocks back to the driver
# (a no-op when CUDA was never initialised).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Per-example source/target lengths from `dict_plot` (presumably token counts
# produced by get_ds — confirm), drawn against the same 300 cut-off that
# `mapper` applies.
LENGTH_CUTOFF = 300

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,
                               figsize=(14, 7))

ax1.plot(dict_plot["src_len"], label='source length')
ax1.axhline(y=LENGTH_CUTOFF, color='r', linestyle='--', label=f'cut-off ({LENGTH_CUTOFF})')
ax1.set_title('Source length per example')
ax1.set_xlabel('example index')
ax1.set_ylabel('length')
ax1.legend(loc='upper right')

ax2.plot(dict_plot["tgt_len"], label='target length')
ax2.axhline(y=LENGTH_CUTOFF, color='r', linestyle='--', label=f'cut-off ({LENGTH_CUTOFF})')
# Both artists now carry labels, so this legend is no longer empty.
ax2.legend(loc=(0.65, 0.8))
ax2.set_title('Target length per example')
ax2.set_xlabel('example index')
ax2.yaxis.tick_right()

In [None]:
import pandas as pd

In [None]:
# One row per example, one column per key of dict_plot.
df = pd.DataFrame(dict_plot)

In [None]:
df.sample(n=8)  # eyeball 8 random rows

In [None]:
def mapper(example, max_len=300):
    """Return 1 if both sides of ``example`` are shorter than ``max_len``, else 0.

    Args:
        example: mapping-like row with ``"src_len"`` and ``"tgt_len"`` entries.
            Either a ``dict`` or the ``pandas.Series`` row passed by
            ``DataFrame.apply(..., axis=1)`` works.
        max_len: exclusive upper bound applied to both lengths. Defaults to
            300, the cut-off drawn on the length plots.

    Returns:
        ``int`` 1 or 0 (not ``bool``), so the ``filtered`` column can be
        compared with ``== 1``.
    """
    return int(example["src_len"] < max_len and example["tgt_len"] < max_len)


In [None]:
# Row-wise flag: 1 when both lengths are under mapper's cut-off, else 0.
df["filtered"] = df.apply(mapper, axis=1)

In [None]:
filtered_rows = df.loc[df["filtered"].eq(1)]  # keep only rows flagged by mapper

In [None]:
filtered_rows.head(n=5)

In [None]:
# Source/target lengths of the rows that survived the `mapper` filter.
# The x axis is the original DataFrame row index (the filter leaves gaps);
# every value should now sit below the 300 cut-off line.
LENGTH_CUTOFF = 300

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,
                               figsize=(14, 7))

ax1.plot(filtered_rows["src_len"], label='source length')
ax1.axhline(y=LENGTH_CUTOFF, color='r', linestyle='--', label=f'cut-off ({LENGTH_CUTOFF})')
ax1.set_title(f'Source length per example (filtered, < {LENGTH_CUTOFF})')
ax1.set_xlabel('row index')
ax1.set_ylabel('length')
ax1.legend(loc='upper right')

ax2.plot(filtered_rows["tgt_len"], label='target length')
ax2.axhline(y=LENGTH_CUTOFF, color='r', linestyle='--', label=f'cut-off ({LENGTH_CUTOFF})')
# Both artists now carry labels, so this legend is no longer empty.
ax2.legend(loc=(0.65, 0.8))
ax2.set_title(f'Target length per example (filtered, < {LENGTH_CUTOFF})')
ax2.set_xlabel('row index')
ax2.yaxis.tick_right()

In [None]:
filtered_rows.shape[0]  # number of rows kept

In [None]:
filtered_rows.loc[:, "filtered"].value_counts()  # sanity check: only the value 1 should remain

In [None]:
filtered_rows.head(n=5)

In [None]:
# Drop the bookkeeping columns; `columns=` already targets axis 1.
helper_columns = ["src_len", "tgt_len", "filtered"]
filtered_rows = filtered_rows.drop(columns=helper_columns)

In [None]:
filtered_rows.head(n=5)

In [None]:
# Persist the length-filtered pairs without the pandas index column.
filtered_rows.to_csv(path_or_buf="translation_dataset.csv", index=False, encoding="utf-8")

In [None]:
# Show one filtered source/target pair (fifth row by position).
example_row = filtered_rows.iloc[4]
for column in ("src", "tgt"):
    print(example_row[column])