In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/wiki-articles-2/cleaned_file.csv
/kaggle/input/tokenizer/tokenizer.json
/kaggle/input/tokenizer/tokenizer_config.json
/kaggle/input/tokenizer/special_tokens_map.json


In [2]:
# Configure warning filters early, before the training-related imports below.
import warnings
# Disabled filter: would ignore RuntimeWarnings whose message matches '.*os.fork.*'.
# warnings.filterwarnings('ignore', category=RuntimeWarning, message='.*os.fork.*')
# Ignore FutureWarnings whose message matches '.*autocast.*' (presumably the
# deprecation notice for the torch.cuda.amp imports used in a later cell).
warnings.filterwarnings('ignore', category=FutureWarning, message='.*autocast.*')

In [3]:
import torch
print(f"Available GPUs: {torch.cuda.device_count()}")

Available GPUs: 1


In [4]:
!pip install sacrebleu

Collecting sacrebleu
  Downloading sacrebleu-2.4.3-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)
Downloading sacrebleu-2.4.3-py3-none-any.whl (103 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.0/104.0 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-2.10.1 sacrebleu-2.4.3


In [5]:
# Seed NumPy's and PyTorch's global random number generators with the same fixed value.
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x799082842e30>

## Env Variables

## Important Imports

In [7]:
import os
import time
import math
import wandb
import torch
import logging

import torch.nn as nn
import torch.optim as optim


from pathlib import Path
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
from typing import Dict, Any, Optional

from tensorboardX import SummaryWriter
# from .metrics import TranslationMetrics
# from .lr_finder import LRFinder
# from .distributed import DistributedManager

## Tokenizer

In [8]:
from transformers import PreTrainedTokenizerFast

# Load the tokenizer from the saved directory
tokenizer = PreTrainedTokenizerFast.from_pretrained("/kaggle/input/tokenizer")

## Configurations

### Model Configurations

In [9]:
# Hyperparameters consumed by the Transformer class defined further down.
model_params = {
    # NOTE(review): hard-coded; confirm it matches the tokenizer loaded above.
    'vocab_size': 30000,  # rows of the embedding table and width of the output projection
    'd_model': 512,  # embedding / hidden size; must be divisible by num_heads
    'num_heads': 8,  # attention heads per attention block
    'num_layers': 6,  # number of encoder layers and of decoder layers
    'd_ff': 1024,  # hidden width of the position-wise feed-forward block
    'max_seq_length': 256,  # number of positions in the positional-encoding table
    'dropout': 0.1,  # dropout probability
    'log_attention_weights': False,  # when True, forward() also returns attention weights
}


### Data Configurations

In [None]:
# NOTE(review): this cell has no execution count and sits above the cell that
# defines FlexibleDataPipeline, so it cannot run as-is in a top-to-bottom run.
# It sketches three alternative setups (basic, single GPU, multi GPU); each
# one overwrites data_dict, pipeline and the three loaders of the previous one.

# Basic configuration
data_dict = {
    'file_path': "/kaggle/input/wiki-articles-2/cleaned_file.csv",
    'batch_size': 16,
    'test_size': 0.1,
    'val_size': 0.1,
    'max_seq_length': 256,
    'num_workers': 2,  # Set to 0 for debugging
    'pin_memory': False  # No need for CPU training
}

# Create pipeline
pipeline = FlexibleDataPipeline(config=data_dict, tokenizer=tokenizer)
train_loader, val_loader, test_loader = pipeline.create_dataloaders()



# For Single GPU
data_dict = {
    'file_path': "/kaggle/input/wiki-articles-2/cleaned_file.csv",
    'batch_size': 16,
    'test_size': 0.1,
    'val_size': 0.1,
    'max_seq_length': 256,
    'num_workers': 2,
    'pin_memory': True
}

pipeline = FlexibleDataPipeline(config=data_dict, tokenizer=tokenizer)
train_loader, val_loader, test_loader = pipeline.create_dataloaders()


# For Multi GPU
import os
os.environ['WORLD_SIZE'] = '2'  # Number of GPUs
os.environ['RANK'] = '0'  # Local rank of this process
os.environ['LOCAL_RANK'] = '0'  # Local rank of this process
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

# Setup distributed training
# NOTE(review): setup_distributed is not defined in the visible part of this
# notebook — confirm where it is supposed to come from before running.
is_distributed = setup_distributed()

data_dict = {
    'file_path': "/kaggle/input/wiki-articles-2/cleaned_file.csv",
    'batch_size': 32,  # Will be divided by number of GPUs
    'test_size': 0.1,
    'val_size': 0.1,
    'max_seq_length': 256,
    'num_workers': 2,
    'pin_memory': True
}

pipeline = FlexibleDataPipeline(config=data_dict, tokenizer=tokenizer)
# NOTE(review): create_dataloaders() as defined later in this notebook takes no
# `distributed` argument, so this call would raise a TypeError.
train_loader, val_loader, test_loader = pipeline.create_dataloaders(distributed=is_distributed)

In [13]:
# Save the first 16 rows of the cleaned corpus as a small CSV for quick experiments.
# The file is written relative to the current working directory; a later cell
# reads "/kaggle/working/subset.csv", so this presumably assumes the working
# directory is /kaggle/working — confirm when running elsewhere.
dataset = pd.read_csv("/kaggle/input/wiki-articles-2/cleaned_file.csv")
subset = dataset.head(16)
subset.to_csv("./subset.csv", index=False)

## Dataset Curator

In [14]:
import torch
import pandas as pd
from typing import Tuple, Dict
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class TranslationDataset(Dataset):
    """Parallel source/target sentences that are tokenized lazily, on access.

    Every item is a dict with the keys 'src' and 'tgt'. Each value is the
    'input_ids' entry returned by the tokenizer (called with
    padding='max_length', truncation=True and return_tensors="pt") with its
    leading dimension squeezed away.
    """

    def __init__(self, src_sentences, tgt_sentences, tokenizer, max_length=256):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # The source side determines the dataset size.
        return len(self.src_sentences)

    def _encode(self, text):
        """Tokenize one sentence and return its id tensor without the batch axis."""
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        # Look up both sentences before tokenizing either of them.
        src_text, tgt_text = self.src_sentences[idx], self.tgt_sentences[idx]
        src_ids = self._encode(src_text)
        tgt_ids = self._encode(tgt_text)
        return {'src': src_ids, 'tgt': tgt_ids}

    @staticmethod
    def collate_fn(batch):
        """Stack the per-item 'src' and 'tgt' tensors into two batch tensors."""
        stacked = {}
        for key in ('src', 'tgt'):
            stacked[key] = torch.stack([sample[key] for sample in batch])
        return stacked

class FlexibleDataPipeline:
    """Load a CSV of sentence pairs and build train/val/test DataLoaders.

    The first CSV column is used as the source sentences and the second as
    the target sentences. Default settings are derived from whether CUDA is
    available and can be overridden either with keyword arguments or with a
    single ``config`` dictionary.
    """

    def __init__(self, file_path: str = None, tokenizer=None, **kwargs):
        """
        Initialize the data pipeline with automatic environment detection
        Args:
            file_path: Path to the CSV file. May be omitted when the path is
                supplied as ``config['file_path']``.
            tokenizer: Tokenizer instance for processing text (required)
            **kwargs: Optional configuration overrides. A ``config`` keyword
                holding a dict is flattened into the overrides; individual
                keyword arguments take precedence over entries of that dict.
        Raises:
            TypeError: If no tokenizer or no file path is provided.
        """
        if tokenizer is None:
            raise TypeError("FlexibleDataPipeline requires a 'tokenizer' argument")

        # Flatten an optional settings dict so that `config=...` behaves like
        # passing its entries as keyword arguments instead of being stored
        # as one nested "config" entry.
        overrides = dict(kwargs.pop('config', None) or {})
        overrides.update(kwargs)

        # The path may travel inside the settings dict; an explicit argument wins.
        config_path = overrides.pop('file_path', None)
        if file_path is None:
            file_path = config_path
        if file_path is None:
            raise TypeError(
                "FlexibleDataPipeline requires 'file_path' (argument or config['file_path'])"
            )

        # Detect computing environment
        self.device_info = self._detect_environment()

        # Set default configuration based on environment
        self.config = self._get_default_config()

        # Override defaults with any provided settings
        self.config.update(overrides)

        # Store essential components
        self.file_path = file_path
        self.tokenizer = tokenizer

        # Load data
        self.src_sentences, self.tgt_sentences = self._load_sentences()

        # Print configuration for transparency
        self._print_config()

    def _detect_environment(self) -> Dict:
        """Detect the computing environment and return relevant information"""
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            gpu_names = [torch.cuda.get_device_name(i) for i in range(num_gpus)]
            return {
                'device': 'cuda',
                'num_gpus': num_gpus,
                'gpu_names': gpu_names
            }
        return {
            'device': 'cpu',
            'num_gpus': 0,
            'gpu_names': []
        }

    def _get_default_config(self) -> Dict:
        """Get default configuration based on detected environment"""
        is_cpu = self.device_info['device'] == 'cpu'
        num_gpus = self.device_info['num_gpus']

        config = {
            # 8 on CPU, otherwise 16 per detected GPU
            'batch_size': 8 if is_cpu else (16 * num_gpus if num_gpus > 0 else 16),
            'test_size': 0.1,
            'val_size': 0.1,
            'max_seq_length': 256,
            'num_workers': 0,   # load batches in the main process by default
            'pin_memory': not is_cpu,  # Only use pin_memory for GPU
            'device': self.device_info['device']
        }
        return config

    def _print_config(self):
        """Print the current configuration and environment details"""
        print("\n=== Environment Detection ===")
        print(f"Device: {self.device_info['device'].upper()}")
        if self.device_info['num_gpus'] > 0:
            print(f"Number of GPUs: {self.device_info['num_gpus']}")
            for i, gpu in enumerate(self.device_info['gpu_names']):
                print(f"GPU {i}: {gpu}")

        print("\n=== Pipeline Configuration ===")
        print(f"Batch Size: {self.config['batch_size']}")
        print(f"Number of Workers: {self.config['num_workers']}")
        print(f"Pin Memory: {self.config['pin_memory']}")
        print(f"Max Sequence Length: {self.config['max_seq_length']}")
        print("===========================\n")

    def _load_sentences(self) -> Tuple[list, list]:
        """Load source and target sentences from CSV file"""
        df = pd.read_csv(self.file_path)
        return df.iloc[:, 0].tolist(), df.iloc[:, 1].tolist()

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build one DataLoader using the pipeline's worker / pin-memory settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.config['num_workers'],
            pin_memory=self.config['pin_memory'],
            collate_fn=TranslationDataset.collate_fn
        )

    def create_dataloaders(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Create data loaders optimized for the current environment
        Returns:
            Tuple of (train_loader, val_loader, test_loader)
        """
        test_size = self.config['test_size']
        val_size = self.config['val_size']
        holdout_size = test_size + val_size

        # First split off the combined validation + test portion ...
        src_train, src_temp, tgt_train, tgt_temp = train_test_split(
            self.src_sentences, self.tgt_sentences,
            test_size=holdout_size,
            random_state=42
        )
        # ... then divide that portion between validation and test.
        src_val, src_test, tgt_val, tgt_test = train_test_split(
            src_temp, tgt_temp,
            test_size=test_size / holdout_size,
            random_state=42
        )

        # Create datasets
        max_length = self.config['max_seq_length']
        train_dataset = TranslationDataset(src_train, tgt_train, self.tokenizer, max_length)
        val_dataset = TranslationDataset(src_val, tgt_val, self.tokenizer, max_length)
        test_dataset = TranslationDataset(src_test, tgt_test, self.tokenizer, max_length)

        # Only the training loader shuffles; evaluation uses twice the batch size.
        train_batch_size = self.config['batch_size']
        eval_batch_size = train_batch_size * 2

        train_loader = self._make_loader(train_dataset, train_batch_size, shuffle=True)
        val_loader = self._make_loader(val_dataset, eval_batch_size, shuffle=False)
        test_loader = self._make_loader(test_dataset, eval_batch_size, shuffle=False)

        return train_loader, val_loader, test_loader

In [15]:
# Create the pipeline on the small subset CSV (presumably the file written by
# the subset cell above); switch to the commented path for the full dataset.
pipeline = FlexibleDataPipeline(
#     file_path="/kaggle/input/wiki-articles-2/cleaned_file.csv",
    file_path = "/kaggle/working/subset.csv",
    tokenizer=tokenizer
)

# Split the data and get the train / validation / test dataloaders
train_loader, val_loader, test_loader = pipeline.create_dataloaders()

# NOTE(review): the placeholder below is stale — FlexibleDataPipeline has no
# prepare_model method; a module-level prepare_model(device_info, model) is
# defined and used later in this notebook instead.
# model = YourModel()
# model = pipeline.prepare_model(model)


=== Environment Detection ===
Device: CUDA
Number of GPUs: 1
GPU 0: Tesla P100-PCIE-16GB

=== Pipeline Configuration ===
Batch Size: 16
Number of Workers: 0
Pin Memory: True
Max Sequence Length: 256



In [16]:
# Take only the first training batch for inspection, then stop iterating.
for batch in train_loader:
    src_data = batch['src']
    tgt_data = batch['tgt']
    
    break

In [17]:
print(tokenizer.decode(src_data[0], skip_special_tokens=True))
print(tokenizer.decode(tgt_data[0], skip_special_tokens=True))

Brienne Sidonie Dessaulniers، جو پیشہ ورانہ طور پر Brie Larson کے نام سے مشہور ہیں، ایک امریکی اداکارہ اور فلم ساز ہیں۔
Brienne Sidonie Dessaulniers, known professionally as Brie Larson, is an American actress and filmmaker.


## Model Architecture

### Model-related constants

In [None]:
# Let's create some constants to make the code a bit cleaner

# Architecture related constants taken from the paper
BASELINE_MODEL_NUMBER_OF_LAYERS = 6
BASELINE_MODEL_DIMENSION = 512
BASELINE_MODEL_NUMBER_OF_HEADS = 8
BASELINE_MODEL_DROPOUT_PROB = 0.1
BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1

CHECKPOINTS_PATH = os.path.join(os.getcwd(), 'models', 'checkpoints') # semi-trained models during training will be dumped here
BINARIES_PATH = os.path.join(os.getcwd(), 'models', 'binaries') # location where trained models are located
DATA_DIR_PATH = os.path.join(os.getcwd(), 'data') # training data will be stored here

os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
os.makedirs(BINARIES_PATH, exist_ok=True)
os.makedirs(DATA_DIR_PATH, exist_ok=True)

# Special token symbols used later in the data section
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
PAD_TOKEN = "<pad>"

### Multi Head Attention Layer

In [18]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of width ``d_model // num_heads``, attended per head, merged back
    and passed through an output projection. When ``log_attention_weights``
    is True the attention probabilities are returned alongside the output.
    """

    def __init__(self, d_model, num_heads, log_attention_weights=False):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of a single head
        self.log_attention_weights = log_attention_weights

        # Query / key / value projections followed by the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Attend over already-split heads; mask entries equal to 0 are suppressed."""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # A large negative score gives masked positions ~0 probability.
            scores = scores.masked_fill(mask == 0, -1e4)

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, V)

        return (context, weights) if self.log_attention_weights else context

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        per_head = x.view(batch_size, seq_length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """(batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        queries = self.split_heads(self.W_q(Q))
        keys = self.split_heads(self.W_k(K))
        values = self.split_heads(self.W_v(V))

        attended = self.scaled_dot_product_attention(queries, keys, values, mask)

        if self.log_attention_weights:
            context, weights = attended
            return self.W_o(self.combine_heads(context)), weights
        return self.W_o(self.combine_heads(attended))

### Position-wise Feed-Forward Networks

In [19]:
class PositionWiseFeedForward(nn.Module):
    """Feed-forward block: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Expand to the hidden width, apply the non-linearity, project back.
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

### Positional Encoding

In [20]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    A table of shape (1, max_seq_length, d_model) is precomputed and stored
    as the non-trainable buffer ``pe``; even feature columns hold sines and
    odd feature columns hold cosines.

    Args:
        d_model: Feature width of the embeddings (even or odd).
        max_seq_length: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: follows the module across devices but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

### Encoder Layer

In [21]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each followed by
    dropout, a residual connection and layer normalization.

    When ``log_attention_weights`` is True, ``forward`` returns
    ``(output, attention_weights)`` instead of just ``output``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout, log_attention_weights=False):
        super().__init__()
        self.log_attention_weights = log_attention_weights
        self.self_attn = MultiHeadAttention(d_model, num_heads, log_attention_weights)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attended = self.self_attn(x, x, x, mask)
        # The attention module returns a (output, weights) pair only when logging.
        if self.log_attention_weights:
            attn_output, attn_weights = attended
        else:
            attn_output, attn_weights = attended, None

        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        if self.log_attention_weights:
            return x, attn_weights
        return x


### Decoder Layer

In [22]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention over the
    encoder output, then feed-forward; each sub-layer is followed by dropout,
    a residual connection and layer normalization.

    Args:
        d_model: Feature width of the layer.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability.
        log_attention_weights: When True, ``forward`` returns
            ``(output, self_attn_weights, cross_attn_weights)`` instead of
            just ``output``. Defaults to False, matching EncoderLayer and
            MultiHeadAttention.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout, log_attention_weights=False):
        super(DecoderLayer, self).__init__()
        self.log_attention_weights = log_attention_weights
        self.self_attn = MultiHeadAttention(d_model, num_heads, log_attention_weights)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, log_attention_weights)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _unpack(self, attn_result):
        """Normalize an attention call's result to an (output, weights) pair."""
        if self.log_attention_weights:
            return attn_result
        return attn_result, None

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # Self-attention over the target sequence.
        attn_output, self_attn_weights = self._unpack(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + self.dropout(attn_output))

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        attn_output, cross_attn_weights = self._unpack(
            self.cross_attn(x, enc_output, enc_output, src_mask)
        )
        x = self.norm2(x + self.dropout(attn_output))

        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))

        if self.log_attention_weights:
            return x, self_attn_weights, cross_attn_weights
        return x

### Mask Generator

In [23]:
def generate_mask(src, tgt, pad_token_id=1):
    """Build the padding mask for the source and the padding + causal mask for the target.

    Args:
        src: Source token ids; assumed to be shaped (batch, src_len).
        tgt: Target token ids; assumed to be shaped (batch, tgt_len).
        pad_token_id: Token id treated as padding. Defaults to 1, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(src_mask, tgt_mask)`` of boolean tensors:
        ``src_mask`` has two singleton axes inserted after the batch axis and
        is False at padded source tokens; ``tgt_mask`` combines a
        lower-triangular (no look-ahead) mask with the target padding flags,
        which are placed on the second-to-last axis, so every entry in the
        row of a padded target token is False.
    """
    src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)
    tgt_padding = (tgt != pad_token_id).unsqueeze(1).unsqueeze(3)

    seq_length = tgt.size(1)
    # Lower triangle (diagonal included): position i may look at positions <= i.
    causal_mask = torch.tril(
        torch.ones(1, seq_length, seq_length, device=src.device)
    ).bool()

    return src_mask, tgt_padding & causal_mask

### Transformer Module

In [24]:
import torch
import torch.nn as nn

# Use model_params dictionary in your Transformer model
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a single embedding table used for
    both the source and the target tokens.

    Args:
        config: Dictionary with the keys 'vocab_size', 'd_model',
            'num_heads', 'num_layers', 'd_ff', 'max_seq_length' and
            'dropout'. Optional keys: 'log_attention_weights' (default
            False) and 'pad_token_id' (default 1, the previously hard-coded
            value), the token id masked out as padding.
    """

    def __init__(self, config: dict):
        super(Transformer, self).__init__()
        self.vocab_size = config['vocab_size']
        self.d_model = config['d_model']
        self.num_heads = config['num_heads']
        self.num_layers = config['num_layers']
        self.d_ff = config['d_ff']
        self.max_seq_length = config['max_seq_length']
        self.dropout = config['dropout']
        self.log_attention_weights = config.get('log_attention_weights', False)
        # Token id treated as padding when the attention masks are built.
        self.pad_token_id = config.get('pad_token_id', 1)

        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        self.positional_encoding = PositionalEncoding(self.d_model, self.max_seq_length)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(self.d_model, self.num_heads, self.d_ff, self.dropout, self.log_attention_weights)
            for _ in range(self.num_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            DecoderLayer(self.d_model, self.num_heads, self.d_ff, self.dropout, self.log_attention_weights)
            for _ in range(self.num_layers)
        ])

        # Projects decoder features to one logit per vocabulary entry.
        self.fc = nn.Linear(self.d_model, self.vocab_size)
        self.dropout_layer = nn.Dropout(self.dropout)

    def generate_mask(self, src, tgt):
        """Return ``(src_mask, tgt_mask)`` boolean masks for one batch.

        ``src_mask`` is False at padded source tokens. ``tgt_mask`` combines
        a lower-triangular (no look-ahead) mask with the target padding
        flags, which are placed on the second-to-last axis.
        """
        device = src.device
        src_mask = (src != self.pad_token_id).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != self.pad_token_id).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Lower triangle (diagonal included): position i may look at positions <= i.
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Run the encoder on ``src`` and the decoder on ``tgt``.

        Returns:
            The output of the final linear layer (last dimension of size
            ``vocab_size``). When attention logging is enabled, a tuple
            ``(logits, (enc_attentions, dec_self_attentions,
            dec_cross_attentions))`` where each inner element is a list with
            one entry per layer.
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout_layer(self.positional_encoding(self.embedding(src)))
        tgt_embedded = self.dropout_layer(self.positional_encoding(self.embedding(tgt)))

        # Initialize lists to store attention weights if logging is enabled
        enc_attentions = [] if self.log_attention_weights else None
        dec_self_attentions = [] if self.log_attention_weights else None
        dec_cross_attentions = [] if self.log_attention_weights else None

        # Encoder
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            if self.log_attention_weights:
                enc_output, enc_attn = enc_layer(enc_output, src_mask)
                enc_attentions.append(enc_attn)
            else:
                enc_output = enc_layer(enc_output, src_mask)

        # Decoder
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            if self.log_attention_weights:
                dec_output, dec_self_attn, dec_cross_attn = dec_layer(
                    dec_output, enc_output, src_mask, tgt_mask
                )
                dec_self_attentions.append(dec_self_attn)
                dec_cross_attentions.append(dec_cross_attn)
            else:
                dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)

        if self.log_attention_weights:
            return output, (enc_attentions, dec_self_attentions, dec_cross_attentions)
        return output


In [25]:
# Define the configuration for the model
device_info = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_gpus': torch.cuda.device_count() if torch.cuda.is_available() else 0
}


In [26]:
# The prepare_model function
def prepare_model(device_info: dict, model: torch.nn.Module) -> torch.nn.Module:
    """
    Move a model onto the detected hardware.

    Args:
        device_info: Dictionary with a 'device' entry; 'num_gpus' is read
            only when the device is 'cuda'
        model: PyTorch model to prepare
    Returns:
        The model unchanged when the device is not 'cuda'; otherwise the
        model moved to CUDA, wrapped in DataParallel first when more than
        one GPU is reported
    """
    # Nothing to do unless CUDA was detected.
    if device_info['device'] != 'cuda':
        return model

    multi_gpu = device_info['num_gpus'] > 1
    if multi_gpu:
        model = torch.nn.DataParallel(model)
    return model.cuda()

# Initialize and prepare the model
model = Transformer(config=model_params)
model = prepare_model(device_info, model)

In [27]:
model

Transformer(
  (embedding): Embedding(30000, 512)
  (positional_encoding): PositionalEncoding()
  (encoder_layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=True)
        (W_k): Linear(in_features=512, out_features=512, bias=True)
        (W_v): Linear(in_features=512, out_features=512, bias=True)
        (W_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward): PositionWiseFeedForward(
        (fc1): Linear(in_features=512, out_features=1024, bias=True)
        (fc2): Linear(in_features=1024, out_features=512, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q)

## Loop 1

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.tensorboard import SummaryWriter
from nltk.translate.bleu_score import sentence_bleu
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.amp import GradScaler, autocast
from rich.console import Console
from rich.table import Table
from rich.live import Live
from datetime import datetime
import multiprocessing  # Import multiprocessing for setting start method



class Trainer:
    """Training / validation / test loop with mixed precision and logging.

    Uses Adam, cross-entropy loss with an ignored index, a gradient scaler
    with autocast, TensorBoard logging and optional Weights & Biases logging.

    Args:
        model: Model called as ``model(src, tgt_input)``.
        train_loader, val_loader, test_loader: Iterables yielding dicts with
            'src' and 'tgt' tensors.
        training_params: Dict with the required keys 'num_epochs', 'device',
            'ignore_index' and 'learning_rate', and the optional keys
            'clip_grad_norm', 'use_wandb' and 'project_name'.
        tokenizer: Object exposing ``vocab_size``.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, training_params, tokenizer):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.num_epochs = training_params['num_epochs']
        self.device = training_params['device']
        self.vocab_size = tokenizer.vocab_size
        self.criterion = nn.CrossEntropyLoss(ignore_index=training_params['ignore_index'])
        self.optimizer = optim.Adam(self.model.parameters(), lr=training_params['learning_rate'])
        self.scaler = GradScaler()
        self.clip_grad_norm = training_params.get('clip_grad_norm', None)
        self.device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.console = Console()

        # Initialize wandb
        if training_params.get('use_wandb', False):
            wandb.init(
                project=training_params.get('project_name', 'transformer-training'),
                config=training_params
            )
            self.use_wandb = True
        else:
            self.use_wandb = False

        # Initialize TensorBoard
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.log_dir = os.path.join('runs', f'training_{current_time}')
        os.makedirs(self.log_dir, exist_ok=True)
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.console.print(f"[bold blue]TensorBoard logs will be saved to: {self.log_dir}[/bold blue]")

    def create_metrics_table(self, metrics):
        """Render a name -> value mapping as a two-column rich table."""
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green")
        for name, value in metrics.items():
            table.add_row(name, f"{value:.4f}" if isinstance(value, float) else str(value))
        return table

    def _batch_loss(self, batch):
        """Forward one batch under autocast and return its loss tensor.

        The decoder input is the target without its last token and the
        labels are the target without its first token (teacher forcing).
        """
        src_data = batch['src'].to(self.device)
        tgt_data = batch['tgt'].to(self.device)

        with autocast(device_type=self.device_type):
            output = self.model(src_data, tgt_data[:, :-1])
            # With attention logging enabled the model returns (logits, attentions).
            if isinstance(output, tuple):
                output = output[0]
            # Flatten using the model's own output width, so a tokenizer
            # whose vocab_size differs from the model's cannot silently
            # misalign logits and labels.
            loss = self.criterion(output.contiguous().view(-1, output.size(-1)),
                                  tgt_data[:, 1:].contiguous().view(-1))
        return loss

    def _evaluate(self, loader, desc):
        """Return the mean per-batch loss over ``loader`` (NaN if it is empty)."""
        self.model.eval()
        total_loss = 0
        avg_loss = float('nan')
        progress_bar = tqdm(
            loader,
            desc=desc,
            bar_format="{desc}: {percentage:3.0f}%|{bar:30}{r_bar}",
            ncols=100
        )

        with torch.no_grad():
            for batches_seen, batch in enumerate(progress_bar, start=1):
                total_loss += self._batch_loss(batch).item()
                # Running mean over the batches processed so far.
                avg_loss = total_loss / batches_seen
                progress_bar.set_postfix_str(f"Loss: {avg_loss:.4f}")

        return avg_loss

    def train(self):
        """Train for ``num_epochs`` epochs, validating after each one."""
        best_loss = float('inf')
        global_step = 0

        for epoch in range(self.num_epochs):
            # validate() leaves the model in eval mode, so training mode has
            # to be restored at the start of every epoch, not just the first.
            self.model.train()
            total_loss = 0
            avg_loss = float('nan')
            progress_bar = tqdm(
                enumerate(self.train_loader),
                total=len(self.train_loader),
                desc=f"[Epoch {epoch + 1}/{self.num_epochs}]",
                bar_format="{desc}: {percentage:3.0f}%|{bar:30}{r_bar}",
                ncols=100
            )

            for batch_idx, batch in progress_bar:
                self.optimizer.zero_grad()
                loss = self._batch_loss(batch)

                self.scaler.scale(loss).backward()
                if self.clip_grad_norm:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm)

                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_loss += loss.item()
                avg_loss = total_loss / (batch_idx + 1)
                global_step += 1

                # Update progress bar with colored metrics
                progress_bar.set_postfix_str(
                    f"Loss: {avg_loss:.4f} | LR: {self.optimizer.param_groups[0]['lr']:.6f}"
                )

                # Log to TensorBoard
                self.tb_writer.add_scalar('Loss/train', avg_loss, global_step)
                self.tb_writer.add_scalar('Learning_rate', self.optimizer.param_groups[0]['lr'], global_step)

                # Log to WandB if enabled
                if self.use_wandb:
                    wandb.log({
                        'Loss/train': avg_loss,
                        'learning_rate': self.optimizer.param_groups[0]['lr']
                    }, step=global_step)

            # Validate after each epoch
            val_loss = self.validate(global_step)

            # Create and display metrics table
            metrics = {
                "Epoch": f"{epoch + 1}/{self.num_epochs}",
                "Training Loss": avg_loss,
                "Validation Loss": val_loss,
                "Best Loss": min(best_loss, val_loss),
                "Learning Rate": self.optimizer.param_groups[0]['lr']
            }
            self.console.print(self.create_metrics_table(metrics))

            # Log epoch metrics to TensorBoard
            self.tb_writer.add_scalars('Loss/epoch', {
                'train': avg_loss,
                'val': val_loss
            }, epoch)

            best_loss = min(best_loss, val_loss)

    def validate(self, global_step):
        """Evaluate on the validation loader, log at ``global_step`` and return the mean loss."""
        avg_loss = self._evaluate(self.val_loader, "Validating")

        # Log to TensorBoard
        self.tb_writer.add_scalar('Loss/val', avg_loss, global_step)

        # Log to WandB if enabled
        if self.use_wandb:
            wandb.log({'Loss/val': avg_loss}, step=global_step)

        return avg_loss

    def test(self):
        """Evaluate on the test loader, report the result and close the TensorBoard writer."""
        avg_loss = self._evaluate(self.test_loader, "Testing")

        # Log to TensorBoard
        self.tb_writer.add_scalar('Loss/test', avg_loss, 0)

        # Log to WandB if enabled
        if self.use_wandb:
            wandb.log({'Loss/test': avg_loss})

        # Display final test results
        metrics = {
            "Final Test Loss": avg_loss
        }
        self.console.print("\n[bold green]Test Results:[/bold green]")
        self.console.print(self.create_metrics_table(metrics))

        # Close TensorBoard writer
        self.tb_writer.close()

In [None]:
# Settings passed to Trainer. Trainer reads 'num_epochs', 'device',
# 'ignore_index', 'learning_rate', 'clip_grad_norm', 'use_wandb' and
# 'project_name'; the whole dict is also handed to wandb.init as the run config.
training_params = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'ignore_index': tokenizer.pad_token_id,  # label id excluded from the cross-entropy loss
    'vocab_size': tokenizer.vocab_size,      # not read by Trainer, which takes tokenizer.vocab_size itself
    'learning_rate': 0.001,                   # Adam learning rate
    'num_epochs': 20,                         # Total epochs
    'clip_grad_norm': 1.0,                    # max gradient norm used for clipping
    'save_dir': './checkpoints',              # not read by Trainer
    'log_interval': 100,                      # not read by Trainer
    'eval_interval': 1,                       # not read by Trainer (it validates after every epoch)
    'warmup_steps': 500,                      # not read by Trainer; EnhancedTrainer passes it to its scheduler
    'project_name': 'your_wandb_project_name', # Replace with actual project name for W&B
    'use_wandb': True                         # True makes Trainer call wandb.init and log to W&B
}

# Build the trainer and run the training loop
trainer = Trainer(model, train_loader, val_loader, test_loader, training_params, tokenizer)
trainer.train()

## Loop 2

In [33]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.tensorboard import SummaryWriter
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.amp import GradScaler, autocast
from rich.console import Console
from rich.table import Table
from rich.live import Live
from datetime import datetime
import numpy as np
from pathlib import Path

class EnhancedTrainer:
    def __init__(self, model, train_loader, val_loader, test_loader, training_params, tokenizer):
        """
        Initialize the trainer with model and training parameters.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.training_params = training_params
        self.tokenizer = tokenizer
        
        # Basic training setup
        self.num_epochs = training_params['num_epochs']
        self.device = training_params['device']
        self.vocab_size = tokenizer.vocab_size
        self.criterion = nn.CrossEntropyLoss(ignore_index=training_params['ignore_index'])
        self.optimizer = optim.Adam(self.model.parameters(), lr=training_params['learning_rate'])
        self.scaler = GradScaler()
        self.clip_grad_norm = training_params.get('clip_grad_norm', None)
        self.device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        # Rich console setup
        self.console = Console()
        
        # Training tracking
        self.best_loss = float('inf')
        self.best_bleu = 0.0
        self.best_model_path = None
        self.global_step = 0
        
        # Initialize learning rate scheduler
        num_training_steps = len(self.train_loader) * self.num_epochs
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=training_params['warmup_steps'],
            num_training_steps=num_training_steps
        )
        
        # Setup logging directories
        self.setup_logging()
        
        # Initialize metrics tracking
        self.reset_metrics()

    def setup_logging(self):
        """Setup logging directories and initialize loggers."""
        # Create checkpoint directory
        self.save_dir = Path(self.training_params['save_dir'])
        self.save_dir.mkdir(parents=True, exist_ok=True)
        
        # Initialize wandb
        if self.training_params.get('use_wandb', False):
            wandb.init(
                project=self.training_params.get('project_name', 'transformer-training'),
                config=self.training_params
            )
            self.use_wandb = True
        else:
            self.use_wandb = False
        
        # Initialize TensorBoard
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.log_dir = Path('runs') / f'training_{current_time}'
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.tb_writer = SummaryWriter(log_dir=str(self.log_dir))
        self.console.print(f"[bold blue]TensorBoard logs will be saved to: {self.log_dir}[/bold blue]")

    def reset_metrics(self):
        """Reset metrics for new epoch."""
        self.metrics = {
            'train_loss': 0.0,
            'train_acc': 0.0,
            'train_ppl': 0.0,
            'val_loss': 0.0,
            'val_acc': 0.0,
            'val_ppl': 0.0,
            'bleu': 0.0
        }

    def calculate_metrics(self, output, target):
        """Calculate accuracy and perplexity for the batch."""
        with torch.no_grad():
            # Calculate accuracy
            pred = output.argmax(dim=-1)
            correct = (pred == target).float()
            mask = (target != self.training_params['ignore_index']).float()
            accuracy = (correct * mask).sum() / mask.sum()
            
            # Calculate perplexity
            loss = self.criterion(output.view(-1, self.vocab_size), target.view(-1))
            perplexity = torch.exp(loss)
            
            return {
                'accuracy': accuracy.item(),
                'perplexity': perplexity.item(),
                'loss': loss.item()
            }

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """Save model checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'metrics': metrics,
            'best_loss': self.best_loss,
            'best_bleu': self.best_bleu
        }
        
        # Save regular checkpoint
        checkpoint_path = self.save_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, checkpoint_path)
        
        # Save best model if applicable
        if is_best:
            best_path = self.save_dir / 'best_model.pt'
            torch.save(checkpoint, best_path)
            self.best_model_path = best_path

    def create_metrics_table(self, metrics):
        """Create a rich table for displaying metrics."""
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green")
        
        for name, value in metrics.items():
            table.add_row(
                name,
                f"{value:.4f}" if isinstance(value, float) else str(value)
            )
        return table

    def log_metrics(self, metrics, step):
        """Log metrics to both TensorBoard and W&B."""
        # Log to TensorBoard
        for name, value in metrics.items():
            self.tb_writer.add_scalar(name, value, step)
        
        # Log to WandB if enabled
        if self.use_wandb:
            wandb.log(metrics, step=step)

    def train(self):
        """Main training loop."""
        self.model.train()
        
        for epoch in range(self.num_epochs):
            self.reset_metrics()
            epoch_start_time = datetime.now()
            
            # Training loop
            progress_bar = tqdm(
                enumerate(self.train_loader),
                total=len(self.train_loader),
                desc=f"[Epoch {epoch + 1}/{self.num_epochs}]",
                bar_format="{desc}: {percentage:3.0f}%|{bar:30}{r_bar}",
                ncols=100
            )
            
            for batch_idx, batch in progress_bar:
                src_data = batch['src'].to(self.device)
                tgt_data = batch['tgt'].to(self.device)

                self.optimizer.zero_grad()
                
                with autocast(device_type=self.device_type):
                    output = self.model(src_data, tgt_data[:, :-1])
                    loss = self.criterion(
                        output.contiguous().view(-1, self.vocab_size),
                        tgt_data[:, 1:].contiguous().view(-1)
                    )

                # Calculate batch metrics
                batch_metrics = self.calculate_metrics(
                    output.contiguous().view(-1, self.vocab_size),
                    tgt_data[:, 1:].contiguous().view(-1)
                )
                
                # Update running metrics
                self.metrics['train_loss'] += batch_metrics['loss']
                self.metrics['train_acc'] += batch_metrics['accuracy']
                self.metrics['train_ppl'] += batch_metrics['perplexity']

                # Backward pass with gradient scaling
                self.scaler.scale(loss).backward()
                
                # Gradient clipping if enabled
                if self.clip_grad_norm:
                    self.scaler.unscale_(self.optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.clip_grad_norm
                    )

                # Optimizer and scheduler steps
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                
                # Update progress bar
                avg_loss = self.metrics['train_loss'] / (batch_idx + 1)
                avg_acc = self.metrics['train_acc'] / (batch_idx + 1)
                avg_ppl = self.metrics['train_ppl'] / (batch_idx + 1)
                lr = self.optimizer.param_groups[0]['lr']
                
                progress_bar.set_postfix_str(
                    f"Loss: {avg_loss:.4f} | Acc: {avg_acc:.4f} | "
                    f"PPL: {avg_ppl:.4f} | LR: {lr:.6f}"
                )
                
                # Log step metrics
                if self.global_step % self.training_params['log_interval'] == 0:
                    step_metrics = {
                        'Loss/train': avg_loss,
                        'Accuracy/train': avg_acc,
                        'Perplexity/train': avg_ppl,
                        'Learning_rate': lr
                    }
                    if self.clip_grad_norm:
                        step_metrics['Gradient_norm'] = grad_norm.item()
                    
                    self.log_metrics(step_metrics, self.global_step)
                
                self.global_step += 1
                
        
            
            # Calculate epoch averages
            num_batches = len(self.train_loader)
            self.metrics['train_loss'] /= num_batches
            self.metrics['train_acc'] /= num_batches
            self.metrics['train_ppl'] /= num_batches
            
            # Validation
            if (epoch + 1) % self.training_params['eval_interval'] == 0:
                val_metrics = self.validate()
                self.metrics.update(val_metrics)
                
                # Check for best model
                is_best = val_metrics['val_loss'] < self.best_loss
                if is_best:
                    self.best_loss = val_metrics['val_loss']
                    self.best_bleu = val_metrics['bleu']
                
                # Save checkpoint
                self.save_checkpoint(epoch + 1, self.metrics, is_best)
                
                # Display epoch metrics
                epoch_time = datetime.now() - epoch_start_time
                display_metrics = {
                    "Epoch": f"{epoch + 1}/{self.num_epochs}",
                    "Time": str(epoch_time).split('.')[0],
                    "Train Loss": self.metrics['train_loss'],
                    "Val Loss": self.metrics['val_loss'],
                    "Train Acc": self.metrics['train_acc'],
                    "Val Acc": self.metrics['val_acc'],
                    "Train PPL": self.metrics['train_ppl'],
                    "Val PPL": self.metrics['val_ppl'],
                    "BLEU": self.metrics['bleu'],
                    "Learning Rate": lr
                }
                self.console.print(self.create_metrics_table(display_metrics))
        
        # Final test evaluation
        self.test()

    def validate(self):
        """Validation loop."""
        self.model.eval()
        val_metrics = {
            'val_loss': 0.0,
            'val_acc': 0.0,
            'val_ppl': 0.0,
            'bleu': 0.0
        }
        all_references = []
        all_hypotheses = []
        
        progress_bar = tqdm(
            self.val_loader,
            desc="Validating",
            bar_format="{desc}: {percentage:3.0f}%|{bar:30}{r_bar}",
            ncols=100
        )
        
        with torch.no_grad():
            for batch in progress_bar:
                src_data = batch['src'].to(self.device)
                tgt_data = batch['tgt'].to(self.device)

                with autocast(device_type=self.device_type):
                    output = self.model(src_data, tgt_data[:, :-1])
                    
                    # Calculate metrics
                    batch_metrics = self.calculate_metrics(
                        output.contiguous().view(-1, self.vocab_size),
                        tgt_data[:, 1:].contiguous().view(-1)
                    )
                    
                    # Update running metrics
                    val_metrics['val_loss'] += batch_metrics['loss']
                    val_metrics['val_acc'] += batch_metrics['accuracy']
                    val_metrics['val_ppl'] += batch_metrics['perplexity']
                    
                    # Collect translations for BLEU score
                    predictions = output.argmax(dim=-1)
                    for pred, target in zip(predictions, tgt_data[:, 1:]):
                        pred_tokens = [token.item() for token in pred if token.item() not in [self.tokenizer.pad_token_id, self.tokenizer.eos_token_id]]
                        target_tokens = [token.item() for token in target if token.item() not in [self.tokenizer.pad_token_id, self.tokenizer.eos_token_id]]
                        all_hypotheses.append(pred_tokens)
                        all_references.append([target_tokens])
        
        # Calculate averages
        num_batches = len(self.val_loader)
        val_metrics['val_loss'] /= num_batches
        val_metrics['val_acc'] /= num_batches
        val_metrics['val_ppl'] /= num_batches
        
        # Calculate BLEU score
        val_metrics['bleu'] = corpus_bleu(all_references, all_hypotheses) * 100
        
        return val_metrics

    def test(self):
        """Test loop."""
        self.model.eval()
        test_metrics = {
            'test_loss': 0.0,
            'test_acc': 0.0,
            'test_ppl': 0.0,
            'test_bleu': 0.0
        }
        
        progress_bar = tqdm(
            self.test_loader,
            desc="Testing",
            bar_format="{desc}: {percentage:3.0f}%|{bar:30}{r_bar}",
            ncols=100
        )
        
        with torch.no_grad():
            for batch in progress_bar:
                src_data = batch['src'].to(self.device)
                tgt_data = batch['tgt'].to(self.device)

                with autocast(device_type=self.device_type):
                    output = self.model(src_data, tgt_data[:, :-1])
                    batch_metrics = self.calculate_metrics(
                        output.contiguous().view(-1, self.vocab_size),
                        tgt_data[:, 1:].contiguous().view(-1)
                    )
                    
                    test_metrics['test_loss'] += batch_metrics['loss']
                    test_metrics['test_acc'] += batch_metrics['accuracy']
                    test_metrics['test_ppl'] += batch_metrics['perplexity']

                    # Collect translations for BLEU score
                    predictions = output.argmax(dim=-1)
                    references = []
                    hypotheses = []
                    for pred, target in zip(predictions, tgt_data[:, 1:]):
                        pred_tokens = [token.item() for token in pred if token.item() not in [self.tokenizer.pad_token_id, self.tokenizer.eos_token_id]]
                        target_tokens = [token.item() for token in target if token.item() not in [self.tokenizer.pad_token_id, self.tokenizer.eos_token_id]]
                        hypotheses.append(pred_tokens)
                        references.append([target_tokens])

        # Calculate averages
        num_batches = len(self.test_loader)
        for key in test_metrics:
            test_metrics[key] /= num_batches
        
        # Calculate BLEU score
        test_metrics['test_bleu'] = corpus_bleu(references, hypotheses) * 100

        # Log final test metrics
        self.log_metrics({
            'Loss/test': test_metrics['test_loss'],
            'Accuracy/test': test_metrics['test_acc'],
            'Perplexity/test': test_metrics['test_ppl'],
            'BLEU/test': test_metrics['test_bleu']
        }, self.global_step)

        # Display final test results
        final_metrics = {
            "Final Test Loss": test_metrics['test_loss'],
            "Final Test Accuracy": test_metrics['test_acc'],
            "Final Test Perplexity": test_metrics['test_ppl'],
            "Final Test BLEU": test_metrics['test_bleu']
        }
        
        self.console.print("\n[bold green]Test Results:[/bold green]")
        self.console.print(self.create_metrics_table(final_metrics))
        
        # Close TensorBoard writer
        self.tb_writer.close()

    def generate_translation(self, src_text, max_length=100):
        """
        Generate translation for a given source text.
        """
        self.model.eval()
        
        # Tokenize input text
        src_tokens = self.tokenizer(src_text, return_tensors="pt", padding=True)
        src_tokens = src_tokens['input_ids'].to(self.device)
        
        with torch.no_grad():
            with autocast(device_type=self.device_type):
                # Initialize target sequence with start token
                tgt = torch.tensor([[self.tokenizer.bos_token_id]]).to(self.device)
                
                for _ in range(max_length):
                    # Generate next token
                    output = self.model(src_tokens, tgt)
                    next_token = output[:, -1:].argmax(dim=-1)
                    tgt = torch.cat([tgt, next_token], dim=1)
                    
                    # Stop if end token is generated
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
        
        # Convert tokens to text
        generated_tokens = tgt[0].tolist()
        translated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        
        return translated_text

    def load_checkpoint(self, checkpoint_path):
        """
        Load a saved checkpoint.
        """
        self.console.print(f"[bold blue]Loading checkpoint from {checkpoint_path}[/bold blue]")
        
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        # Load model and training states
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        
        # Load best metrics
        self.best_loss = checkpoint.get('best_loss', float('inf'))
        self.best_bleu = checkpoint.get('best_bleu', 0.0)
        
        return checkpoint['epoch']

    def cleanup(self):
        """
        Cleanup resources and save final artifacts.
        """
        # Close TensorBoard writer if it exists
        if hasattr(self, 'tb_writer'):
            self.tb_writer.close()
        
        # Finish wandb run if it was used
        if self.use_wandb:
            wandb.finish()
        
        # Save final model state
        final_checkpoint = {
            'epoch': self.num_epochs,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'best_loss': self.best_loss,
            'best_bleu': self.best_bleu,
            'final_metrics': self.metrics
        }
        
        final_path = self.save_dir / 'final_model.pt'
        torch.save(final_checkpoint, final_path)
        self.console.print(f"[bold green]Final model saved to {final_path}[/bold green]")

In [35]:
# Sample training_params definition
training_params = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'ignore_index': tokenizer.pad_token_id,  # Replace with the pad token ID used in your tokenizer
    'vocab_size': tokenizer.vocab_size,      # Vocabulary size based on your tokenizer
    'learning_rate': 0.001,                   # Learning rate
    'num_epochs': 20,                         # Total epochs
    'clip_grad_norm': 1.0,                    # Gradient clipping norm
    'save_dir': './checkpoints',              # Directory to save checkpoints
    'log_interval': 100,                      # Interval for logging train loss
    'eval_interval': 1,                       # Epochs between each evaluation
    'warmup_steps': 500,                      # Number of warmup steps for scheduler
    'project_name': 'your_wandb_project_name', # Replace with actual project name for W&B
    'use_wandb': True                         # Set to True if using W&B for logging
}

# Initialize trainer
trainer = EnhancedTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    training_params=training_params,
    tokenizer=tokenizer
)

try:
    # Train the model
    trainer.train()

    # Generate translations
    # translation = trainer.generate_translation("Your source text here")

    # Load a checkpoint
    # trainer.load_checkpoint("path/to/checkpoint.pt")
finally:
    # Cleanup when done. Running it in `finally` means an interrupted or failed
    # run (e.g. stopping the cell by hand) still closes the TensorBoard writer,
    # finishes the W&B run and writes final_model.pt, which the cells below
    # try to load.
    trainer.cleanup()

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy/train,▁
Gradient_norm,▁
Learning_rate,▁
Loss/train,▁
Perplexity/train,▁

0,1
Accuracy/train,0.04142
Gradient_norm,2.47832
Learning_rate,0.0
Loss/train,8.70312
Perplexity/train,6020.0


[Epoch 1/20]: 100%|██████████████████████████████| 1/1 [00:00<00:00,  4.20it/s, Loss: 8.3672 | Acc: 
Validating: 100%|██████████████████████████████| 1/1 [00:00<00:00, 25.69it/s]


[Epoch 2/20]: 100%|██████████████████████████████| 1/1 [00:00<00:00,  4.28it/s, Loss: 8.2422 | Acc: 
Validating: 100%|██████████████████████████████| 1/1 [00:00<00:00, 25.04it/s]


[Epoch 3/20]: 100%|██████████████████████████████| 1/1 [00:00<00:00,  4.54it/s, Loss: 8.2344 | Acc: 
Validating: 100%|██████████████████████████████| 1/1 [00:00<00:00, 24.41it/s]


[Epoch 4/20]: 100%|██████████████████████████████| 1/1 [00:00<00:00,  4.54it/s, Loss: 8.2188 | Acc: 
Validating: 100%|██████████████████████████████| 1/1 [00:00<00:00, 25.02it/s]


KeyboardInterrupt: 

In [None]:
import os

# Report whether the final checkpoint written by the trainer is on disk.
checkpoint_path = '/kaggle/working/checkpoints/final_model.pt'
print(
    "The checkpoint file exists."
    if os.path.exists(checkpoint_path)
    else "The checkpoint file does not exist."
)


In [None]:
import os
import torch

# `model` must already exist in the notebook with the architecture that
# produced the checkpoint, e.g.:
# model = YourModelClass(...)

checkpoint_path = '/kaggle/working/checkpoints/final_model.pt'

if not os.path.exists(checkpoint_path):
    # Nothing to restore
    print(f"Checkpoint file does not exist at {checkpoint_path}.")
else:
    # weights_only=True limits what torch.load is allowed to unpickle
    checkpoint = torch.load(checkpoint_path, weights_only=True)

    # Restore the learned weights and switch to inference behaviour
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded successfully for inference from {checkpoint_path}.")


In [None]:
# Load TensorBoard
%load_ext tensorboard

# Specify the log directory
log_dir = '/kaggle/working/runs/training_20241030_122948/events.out.tfevents.1730291388.e073326608c4.30.0'  # This points to your log directory

# Start TensorBoard
%tensorboard --logdir $log_dir


## Visualizations

In [None]:
# Install the packages behind the `arabic_reshaper` and `bidi` imports used by the
# attention-visualisation cell below to reshape and reorder Urdu tick labels.
!pip install arabic-reshaper python-bidi

In [None]:
import matplotlib
# Make the generic sans-serif family resolve to the Urdu Nastaliq font.
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Noto Nastaliq Urdu'],
})

In [None]:
import matplotlib.font_manager as fm
# Rebuild matplotlib's font list from the fonts currently installed on the system.
fm.fontManager.ttflist = []  # Clear the list
for font in fm.findSystemFonts():
    try:
        fm.fontManager.addfont(font)
    except Exception:
        # Best effort: skip font files matplotlib cannot register. Catching
        # Exception (not a bare except) keeps Ctrl-C / kernel interrupts working.
        continue

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional
import matplotlib.gridspec as gridspec
import matplotlib.font_manager as fm
from bidi.algorithm import get_display
import arabic_reshaper

class AttentionVisualizer:
    """
    Heatmap visualizer for encoder, decoder self- and decoder cross-attention.

    Tick labels are the decoded tokens, reshaped and reordered for
    right-to-left display so Urdu text renders readably in matplotlib.
    """

    def __init__(self, tokenizer, font_path=None):
        """
        Initialize the visualizer with optional font path for Urdu text.

        Args:
            tokenizer: Provides ``pad_token_id`` and ``decode``.
            font_path: Path to a font file to use for tick labels. If None,
                the first installed font among 'Noto Naskh Arabic',
                'Arabic Typesetting' and 'Traditional Arabic' is used; if none
                is installed a warning is printed and no font override is applied.
        """
        self.tokenizer = tokenizer
        self.layer_names = {
            'encoder': 'Encoder Layer',
            'decoder_self': 'Decoder Self-Attention Layer',
            'decoder_cross': 'Decoder Cross-Attention Layer'
        }

        # Set up font for Urdu text (None means "no override")
        self.font_prop = None
        if font_path:
            self.font_prop = fm.FontProperties(fname=font_path)
        else:
            # Try to find a suitable Arabic/Urdu font among the registered ones
            font_names = ['Noto Naskh Arabic', 'Arabic Typesetting', 'Traditional Arabic']
            for font_name in font_names:
                if any(f.name == font_name for f in fm.fontManager.ttflist):
                    self.font_prop = fm.FontProperties(family=font_name)
                    break
            else:
                print("Warning: No suitable Arabic/Urdu font found. Text may not display correctly.")

    def _process_urdu_text(self, text: str) -> str:
        """Reshape Arabic-script text and reorder it for right-to-left display."""
        # Reshape Arabic/Urdu text
        reshaped_text = arabic_reshaper.reshape(text)
        # Handle right-to-left text
        bidi_text = get_display(reshaped_text)
        return bidi_text

    def _get_tokens_from_ids(self, token_ids: torch.Tensor, batch_idx: int = 0) -> List[str]:
        """
        Convert one row of a batch of token IDs to display-ready token strings.

        Padding positions (``tokenizer.pad_token_id``) are dropped; every
        remaining ID is decoded on its own, special tokens included.
        """
        # Remove padding tokens
        row = token_ids[batch_idx]
        valid_ids = row[row != self.tokenizer.pad_token_id]
        # Decode individual tokens and process for Urdu display
        tokens = [
            self._process_urdu_text(
                self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
            )
            for token_id in valid_ids
        ]
        return tokens

    def plot_attention_weights(self,
                             attention_weights: Tuple[List[torch.Tensor], ...],
                             src_ids: torch.Tensor,
                             tgt_ids: torch.Tensor,
                             batch_idx: int = 0,
                             layer_idx: int = 0,
                             head_idx: int = 0,
                             save_path: Optional[str] = None,
                             figsize=(20, 15)):
        """
        Plot attention weights with proper Urdu text support.

        Draws three heatmaps for one (layer, head) of one batch item: encoder
        self-attention, decoder self-attention, and decoder cross-attention.

        Args:
            attention_weights: Tuple ``(encoder, decoder_self, decoder_cross)``;
                each is indexed by layer and each entry is then indexed as
                ``[batch, head, query, key]``.
            src_ids: Source token IDs, indexed by batch item first.
            tgt_ids: Target token IDs, indexed by batch item first.
            batch_idx: Which batch item to visualize.
            layer_idx: Index of the layer to visualize.
            head_idx: Index of the attention head to visualize.
            save_path: If given, the figure is also saved there at 300 dpi.
            figsize: Figure size in inches.

        Note:
            Each matrix is cropped to the number of non-padding tokens, which
            assumes the padding sits at the end of each sequence (confirm
            against the collate function).
        """
        enc_attentions, dec_self_attentions, dec_cross_attentions = attention_weights

        # Get tokens from IDs for the specified batch
        src_tokens = self._get_tokens_from_ids(src_ids, batch_idx)
        tgt_tokens = self._get_tokens_from_ids(tgt_ids, batch_idx)
        n_src, n_tgt = len(src_tokens), len(tgt_tokens)

        # Create figure with larger size to accommodate Urdu text. The size is
        # passed to this figure only, so the global rcParams (and therefore
        # every later figure in the notebook) stay untouched.
        fig = plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(2, 2, figure=fig)

        # 1. Encoder Self-Attention
        ax1 = fig.add_subplot(gs[0, 0])
        self._plot_attention_map(
            attention_weights=enc_attentions[layer_idx][batch_idx, head_idx, :n_src, :n_src].detach().cpu(),
            x_labels=src_tokens,
            y_labels=src_tokens,
            title=f"{self.layer_names['encoder']} {layer_idx+1}\nHead {head_idx+1}",
            ax=ax1
        )

        # 2. Decoder Self-Attention
        ax2 = fig.add_subplot(gs[0, 1])
        self._plot_attention_map(
            attention_weights=dec_self_attentions[layer_idx][batch_idx, head_idx, :n_tgt, :n_tgt].detach().cpu(),
            x_labels=tgt_tokens,
            y_labels=tgt_tokens,
            title=f"{self.layer_names['decoder_self']} {layer_idx+1}\nHead {head_idx+1}",
            ax=ax2
        )

        # 3. Decoder Cross-Attention (spans the full bottom row)
        ax3 = fig.add_subplot(gs[1, :])
        self._plot_attention_map(
            attention_weights=dec_cross_attentions[layer_idx][batch_idx, head_idx, :n_tgt, :n_src].detach().cpu(),
            x_labels=src_tokens,
            y_labels=tgt_tokens,
            title=f"{self.layer_names['decoder_cross']} {layer_idx+1}\nHead {head_idx+1}",
            ax=ax3
        )

        fig.tight_layout(pad=3.0)  # Added extra padding for Urdu text
        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.show()

    def _plot_attention_map(self,
                           attention_weights: torch.Tensor,
                           x_labels: List[str],
                           y_labels: List[str],
                           title: str,
                           ax: plt.Axes):
        """
        Plot a single attention map with Urdu text support.

        Rows are labelled as query tokens (``y_labels``) and columns as key
        tokens (``x_labels``).
        """
        sns.heatmap(
            attention_weights,
            xticklabels=x_labels,
            yticklabels=y_labels,
            cmap='YlOrRd',
            ax=ax,
            cbar_kws={'label': 'Attention Weight'}
        )
        ax.set_title(title)
        ax.set_xlabel('Key Tokens')
        ax.set_ylabel('Query Tokens')

        # Adjust label properties for better Urdu text display
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontproperties=self.font_prop)
        plt.setp(ax.get_yticklabels(), rotation=0, fontproperties=self.font_prop)

        # Increase spacing for tick labels
        ax.tick_params(axis='x', pad=10)
        ax.tick_params(axis='y', pad=10)

In [None]:
import matplotlib
# Font lookup order for all text: the Urdu Nastaliq font first, then Arial.
matplotlib.rcParams.update({'font.family': ['Noto Nastaliq Urdu', 'Arial']})

In [None]:
# Build the matplotlib-based visualizer and draw the attention maps of the first
# batch item; change layer_idx / head_idx to inspect other layers or heads.
visualizer = AttentionVisualizer(tokenizer)

visualizer.plot_attention_weights(
    attention_weights,
    src_data,
    tgt_data,
    batch_idx=0,
    layer_idx=0,
    head_idx=0,
)

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from typing import List, Tuple, Optional

# Ensure proper display in the notebook: make 'notebook_connected' the default
# renderer for every Plotly figure shown after this point.
# NOTE(review): per Plotly's renderer docs this variant loads plotly.js from a CDN,
# so it presumably needs internet access in the notebook session — confirm on Kaggle.
pio.renderers.default = 'notebook_connected'

class AttentionVisualizerPlotly:
    """Interactive Plotly heatmaps of encoder / decoder attention weights.

    The tokenizer passed to the constructor is used to turn token IDs into the
    tick labels of the heatmaps. It must expose ``pad_token_id`` and
    ``decode(ids, skip_special_tokens=...)``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Human-readable prefixes used when building the panel titles.
        self.layer_names = {
            'encoder': 'Encoder Layer',
            'decoder_self': 'Decoder Self-Attention Layer',
            'decoder_cross': 'Decoder Cross-Attention Layer'
        }

    def _get_tokens_from_ids(self, token_ids: torch.Tensor, batch_idx: int = 0) -> List[str]:
        """Decode the non-padding token IDs of one batch row, one string per token.

        Args:
            token_ids: Batched token IDs; row ``batch_idx`` is decoded.
            batch_idx: Index of the row to decode.

        Returns:
            One decoded string per non-padding ID (special tokens are kept).
        """
        row = token_ids[batch_idx]
        valid_ids = row[row != self.tokenizer.pad_token_id]
        # Decode IDs one at a time so every attention row/column gets its own label.
        return [self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
                for token_id in valid_ids]

    def plot_attention_weights(self,
                               attention_weights: Tuple[List[torch.Tensor], ...],
                               src_ids: torch.Tensor,
                               tgt_ids: torch.Tensor,
                               batch_idx: int = 0,
                               layer_idx: int = 0,
                               head_idx: int = 0,
                               save_path: Optional[str] = None):
        """
        Plot encoder, decoder-self and decoder-cross attention side by side.

        The three heatmaps are drawn in a 1x3 subplot grid that shares a single
        colour scale. (Adding all three traces to one plain figure stacks them
        on the same axes, so only the last title/layout would be visible.)

        Args:
            attention_weights: Tuple of (encoder_attentions, decoder_self_attentions,
                decoder_cross_attentions); each is indexed as
                ``attn[layer_idx][batch_idx, head_idx, query, key]``.
            src_ids: Source token IDs [batch_size, src_len]
            tgt_ids: Target token IDs [batch_size, tgt_len]
            batch_idx: Which batch item to visualize
            layer_idx: Index of the layer to visualize
            head_idx: Index of the attention head to visualize
            save_path: Optional path; when given the figure is also written with
                ``fig.write_image``.
        """
        # Local import: the notebook only imports plotly.graph_objects / plotly.io.
        from plotly.subplots import make_subplots

        enc_attentions, dec_self_attentions, dec_cross_attentions = attention_weights

        # Get tokens from IDs for the specified batch
        src_tokens = self._get_tokens_from_ids(src_ids, batch_idx)
        tgt_tokens = self._get_tokens_from_ids(tgt_ids, batch_idx)
        n_src, n_tgt = len(src_tokens), len(tgt_tokens)
        # NOTE(review): slicing with [:n_src] / [:n_tgt] assumes padding sits at the
        # end of each sequence — confirm against the data pipeline.

        # Plotly renders titles as pseudo-HTML, so the line break must be <br>.
        suffix = f" {layer_idx+1}<br>Head {head_idx+1}"
        panels = [
            # (weights, key labels, query labels, title)
            (enc_attentions[layer_idx][batch_idx, head_idx, :n_src, :n_src],
             src_tokens, src_tokens, f"{self.layer_names['encoder']}{suffix}"),
            (dec_self_attentions[layer_idx][batch_idx, head_idx, :n_tgt, :n_tgt],
             tgt_tokens, tgt_tokens, f"{self.layer_names['decoder_self']}{suffix}"),
            (dec_cross_attentions[layer_idx][batch_idx, head_idx, :n_tgt, :n_src],
             src_tokens, tgt_tokens, f"{self.layer_names['decoder_cross']}{suffix}"),
        ]

        fig = make_subplots(
            rows=1,
            cols=len(panels),
            subplot_titles=[panel_title for _, _, _, panel_title in panels],
            horizontal_spacing=0.08
        )
        for col, (weights, x_labels, y_labels, panel_title) in enumerate(panels, start=1):
            fig = self._plot_attention_map(
                attention_weights=weights.detach().cpu(),
                x_labels=x_labels,
                y_labels=y_labels,
                title=panel_title,
                fig=fig,
                row=1,
                col=col
            )

        # One colour scale / colour bar shared by all three panels.
        fig.update_layout(
            coloraxis=dict(colorscale='YlOrRd', colorbar=dict(title='Attention Weight'))
        )

        if save_path:
            fig.write_image(save_path)
        fig.show()

    def _plot_attention_map(self,
                            attention_weights: torch.Tensor,
                            x_labels: List[str],
                            y_labels: List[str],
                            title: str,
                            fig: go.Figure,
                            row: Optional[int] = None,
                            col: Optional[int] = None) -> go.Figure:
        """Add one attention heatmap to ``fig`` and return the figure.

        Args:
            attention_weights: 2-D tensor of weights, rows = queries, cols = keys.
            x_labels: Key-token labels (one per column).
            y_labels: Query-token labels (one per row).
            title: Figure title; only applied when no subplot cell is given
                (in subplot mode the caller supplies titles via ``make_subplots``).
            fig: Figure to draw into.
            row, col: Optional 1-based subplot cell. When both are given the
                trace is placed in that cell and uses the figure's shared
                ``coloraxis``; when omitted the heatmap is added to the figure's
                default axes with its own colour bar, as before.

        Returns:
            The same figure, for chaining.
        """
        in_subplot = row is not None and col is not None

        # Use integer coordinates plus tick text: identical token strings used
        # directly as categorical coordinates would share one row/column.
        x_positions = list(range(len(x_labels)))
        y_positions = list(range(len(y_labels)))
        hover_text = [[f"query: {query}<br>key: {key}" for key in x_labels]
                      for query in y_labels]

        trace_options = dict(
            z=attention_weights.detach().cpu().float().numpy(),
            x=x_positions,
            y=y_positions,
            text=hover_text,
            hovertemplate='%{text}<br>weight: %{z:.4f}<extra></extra>'
        )
        if in_subplot:
            trace_options['coloraxis'] = 'coloraxis'
        else:
            trace_options['colorscale'] = 'YlOrRd'
            trace_options['colorbar'] = dict(title='Attention Weight')
        heatmap = go.Heatmap(**trace_options)

        if in_subplot:
            fig.add_trace(heatmap, row=row, col=col)
            axis_target = dict(row=row, col=col)
        else:
            fig.add_trace(heatmap)
            fig.update_layout(title=title)
            axis_target = {}

        fig.update_xaxes(
            title_text='Key Tokens',
            tickmode='array',
            tickvals=x_positions,
            ticktext=x_labels,
            tickangle=45,
            **axis_target
        )
        fig.update_yaxes(
            title_text='Query Tokens',
            tickmode='array',
            tickvals=y_positions,
            ticktext=y_labels,
            autorange='reversed',  # first query token at the top, like the seaborn plots
            **axis_target
        )
        return fig

# Example usage:
# Requires `attention_weights`, `src_data`, `tgt_data` and `tokenizer` from earlier cells.
# Note: this rebinds the notebook-level name `visualizer` to the Plotly-based class;
# batch_idx, layer_idx and head_idx are left at their defaults (0).
visualizer = AttentionVisualizerPlotly(tokenizer)
visualizer.plot_attention_weights(attention_weights, src_data, tgt_data)


In [None]:
# Rebind `visualizer` to AttentionVisualizer (the previous cell pointed it at
# AttentionVisualizerPlotly).
visualizer = AttentionVisualizer(tokenizer)

In [None]:
# Plot the attention maps for the first batch item, first layer and first head,
# using the `attention_weights` collected earlier.
# An earlier draft of this call passed `src_text` / `tgt_text` keywords and
# save_path="attention_plot.png"; the live call below passes token-ID tensors
# via `src_ids` / `tgt_ids` and does not save the figure.

visualizer.plot_attention_weights(
    attention_weights=attention_weights,
    src_ids=src_data,
    tgt_ids=tgt_data,
    batch_idx=0,
    layer_idx=0,  # Change to view different layers
    head_idx=0    # Change to view different heads
)

## Training Loop

### Training Loop Configurations

In [None]:
# Define the configuration class for training hyperparameters
class TrainConfig:
    """Bundle of basic training hyperparameters.

    Any of ``device``, ``criterion`` or ``optimizer_cls`` that is left falsy is
    replaced by a default: 'cuda' when available (otherwise 'cpu'), a
    cross-entropy loss with ``ignore_index=1``, and ``optim.Adam`` respectively.
    """

    def __init__(self, num_epochs=10, lr=1e-4, batch_size=32, device=None, criterion=None, optimizer_cls=None):
        # Resolve the optional arguments first, then store everything.
        if not device:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if not criterion:
            criterion = nn.CrossEntropyLoss(ignore_index=1)
        if not optimizer_cls:
            optimizer_cls = optim.Adam

        self.num_epochs, self.lr, self.batch_size = num_epochs, lr, batch_size
        self.device = device
        self.criterion = criterion
        self.optimizer_cls = optimizer_cls

In [None]:
import os
import time
import logging
from pathlib import Path
from datetime import datetime
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
import wandb
from sacrebleu.metrics import BLEU
import numpy as np
from typing import Dict, Any
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters and output location consumed by MTTrainer.

    NOTE(review): a later cell ("Config Class") defines another dataclass that is
    also named TrainingConfig but has different fields (no base_dir, no
    grad_accum_steps, ...). Whichever cell ran last owns the name.
    """
    base_dir: str = "mt_training"  # root folder; MTTrainer creates run_<timestamp>/ inside it
    model_name: str = "urdu_english_transformer"  # not read by MTTrainer in this notebook section
    num_epochs: int = 100  # epochs to run; also sizes the OneCycleLR schedule
    batch_size: int = 32  # not read by MTTrainer (data loaders are passed in ready-made)
    grad_accum_steps: int = 4  # batches accumulated per optimizer step
    max_grad_norm: float = 1.0  # clipping threshold passed to clip_grad_norm_
    warmup_steps: int = 4000  # not read by MTTrainer (the scheduler uses pct_start=0.1)
    eval_steps: int = 1000  # not read by MTTrainer (validation runs once per epoch)
    save_steps: int = 2000  # checkpoint interval, counted in batches across epochs
    early_stopping_patience: int = 5  # not read by MTTrainer
    max_seq_length: int = 128  # not read by MTTrainer
    learning_rate: float = 3e-4  # AdamW lr and OneCycleLR max_lr
    min_learning_rate: float = 1e-5  # used to derive OneCycleLR's final_div_factor
    weight_decay: float = 0.01  # AdamW weight decay
    log_interval: int = 100  # logging interval used by MTTrainer.train_epoch

In [None]:
class MTTrainer:
    """Training / evaluation driver for a sequence-to-sequence translation model.

    Construction creates the run directories, configures logging, optionally
    initialises distributed training (driven by the LOCAL_RANK / WORLD_SIZE
    environment variables), moves the model to the device, builds an AdamW
    optimizer and starts TensorBoard / wandb tracking on rank 0.

    The model is called as
    ``model(input_ids=..., decoder_input_ids=..., labels=..., return_dict=True,
    output_attentions=True)`` and its result must expose ``loss``, ``logits``
    and ``attentions``; ``generate(src_ids)`` is used during evaluation.
    Batches are dicts with ``'src'`` and ``'tgt'`` token-ID tensors.
    """

    # The annotation is a string so that defining this class does not depend on
    # which `TrainingConfig` happens to be bound when the cell is executed.
    def __init__(self, config: "TrainingConfig", model, tokenizer):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.setup_directories()
        self.setup_logging()
        self.setup_distributed()
        self.setup_model(model, tokenizer)
        self.setup_tracking()

    def setup_directories(self):
        """Create the per-run directory tree (checkpoints, logs, tensorboard, predictions)."""
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = Path(self.config.base_dir) / f"run_{self.timestamp}"

        self.dirs = {
            'checkpoints': self.run_dir / 'checkpoints',
            'logs': self.run_dir / 'logs',
            'tensorboard': self.run_dir / 'tensorboard',
            'predictions': self.run_dir / 'predictions'
        }

        for dir_path in self.dirs.values():
            dir_path.mkdir(parents=True, exist_ok=True)

    def setup_logging(self):
        """Configure logging to ``logs/training.log`` and the console.

        Note: ``logging.basicConfig`` does nothing when the root logger already
        has handlers, which is common inside notebooks.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(self.dirs['logs'] / 'training.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def setup_distributed(self):
        """Read LOCAL_RANK / WORLD_SIZE and start the NCCL process group if needed."""
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))

        if self.world_size > 1:
            torch.cuda.set_device(self.local_rank)
            dist.init_process_group(backend='nccl')
            self.logger.info(f"Initialized distributed training with world size {self.world_size}")

    def setup_model(self, model, tokenizer):
        """Move the model to the device and create optimizer / grad scaler.

        The LR scheduler is created later in ``train()`` because it needs the
        number of optimizer steps per epoch.
        """
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

        if self.world_size > 1:
            self.model = DDP(self.model, device_ids=[self.local_rank])

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        self.steps_per_epoch = None  # Set in train()
        self.scheduler = None  # Set in train() after steps_per_epoch is known
        self.scaler = torch.cuda.amp.GradScaler() if self.device.type == 'cuda' else None

    def setup_tracking(self):
        """Initialize wandb and tensorboard (rank 0 only)."""
        if self.local_rank == 0:
            self.writer = SummaryWriter(self.dirs['tensorboard'])
            wandb.init(
                project="urdu_english_mt",
                name=f"run_{self.timestamp}",
                config=self.config.__dict__,
                dir=str(self.run_dir)
            )

            # Log model architecture
            self.writer.add_text(
                'Model/Architecture',
                str(self.model),
                0
            )

    def get_scheduler(self):
        """Create the one-cycle LR schedule (10% warm-up, cosine annealing).

        OneCycleLR starts at ``max_lr / div_factor`` and ends at
        ``initial_lr / final_div_factor``. ``final_div_factor`` is therefore
        derived from the *initial* LR so the schedule finishes exactly at
        ``config.min_learning_rate``.
        """
        div_factor = 25.0  # OneCycleLR's default, spelled out because it is used below
        initial_lr = self.config.learning_rate / div_factor
        return torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.config.learning_rate,
            total_steps=self.config.num_epochs * self.steps_per_epoch,
            pct_start=0.1,
            anneal_strategy='cos',
            div_factor=div_factor,
            final_div_factor=initial_lr / self.config.min_learning_rate
        )

    def train_epoch(self, train_loader, epoch: int):
        """Train for one epoch with gradient accumulation.

        Args:
            train_loader: Iterable of batches (dicts of tensors) with ``len()``.
            epoch: Zero-based epoch index (used for logging / checkpoint naming).

        Returns:
            ``(mean batch loss, epoch_metrics)`` where ``epoch_metrics`` holds
            per-batch attention entropy and per-optimizer-step gradient norm
            and learning rate.
        """
        self.model.train()
        accum_steps = self.config.grad_accum_steps
        # Discard gradients left over from an incomplete accumulation group at
        # the end of the previous epoch.
        self.optimizer.zero_grad()

        epoch_loss = 0
        epoch_metrics = {
            'attention_entropy': [],
            'grad_norm': [],
            'learning_rate': []
        }
        last_logged_batch = 0

        for step, batch in enumerate(train_loader):
            # Move batch to the correct device
            for key in batch:
                batch[key] = batch[key].to(self.device)

            with torch.cuda.amp.autocast(enabled=self.device.type == 'cuda'):
                loss, metrics = self.training_step(batch)

            # Gradient accumulation: always back-propagate the loss divided by
            # the number of accumulated batches, with or without a grad scaler.
            loss_for_backward = loss / accum_steps
            if self.scaler:
                self.scaler.scale(loss_for_backward).backward()
            else:
                loss_for_backward.backward()

            epoch_metrics['attention_entropy'].append(metrics['attention_entropy'])

            if (step + 1) % accum_steps == 0:
                if self.scaler:
                    # Gradients must be unscaled before they are clipped.
                    self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                self.scheduler.step()  # Scheduler step after optimizer
                self.optimizer.zero_grad()

                # Update metrics
                epoch_metrics['grad_norm'].append(grad_norm.item())
                epoch_metrics['learning_rate'].append(self.scheduler.get_last_lr()[0])

                # Log at the first optimizer step once at least `log_interval`
                # batches have passed since the previous log line.
                if (step + 1) - last_logged_batch >= self.config.log_interval:
                    optimizer_step = (step + 1) // accum_steps
                    self.log_step(epoch, optimizer_step, loss.item(), metrics, epoch_metrics)
                    last_logged_batch = step + 1

            epoch_loss += loss.item()

            # Periodic checkpoint, counted in batches across epochs. Skip the
            # very first batch and let only rank 0 write.
            global_step = epoch * len(train_loader) + step
            if (self.local_rank == 0 and global_step > 0
                    and global_step % self.config.save_steps == 0):
                self.save_checkpoint(epoch, global_step, metrics)

        return epoch_loss / max(len(train_loader), 1), epoch_metrics

    def training_step(self, batch):
        """Run one teacher-forced forward pass and compute step metrics.

        The decoder sees ``tgt[:, :-1]`` and is trained to predict
        ``tgt[:, 1:]``; the same shifted labels are used for the accuracy metric.

        Returns:
            ``(loss, metrics)``.
        """
        source_ids = batch['src']
        target_ids = batch['tgt']
        decoder_input_ids = target_ids[:, :-1]
        labels = target_ids[:, 1:]

        outputs = self.model(
            input_ids=source_ids,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=True,
            output_attentions=True  # Ensure attention weights are returned
        )

        # assumes logits has one position per decoder input token, i.e. it lines
        # up with `labels` — TODO confirm against the model's forward().
        metrics = self.calculate_step_metrics(outputs, labels)

        return outputs.loss, metrics

    def calculate_step_metrics(self, outputs, target_ids):
        """Compute attention entropy, token accuracy and perplexity for one step.

        Args:
            outputs: Model output exposing ``attentions``, ``logits`` and ``loss``.
            target_ids: Label IDs with the same [batch, length] shape as the
                argmax of ``outputs.logits``.
        """
        with torch.no_grad():
            metrics = {
                'attention_entropy': self.calculate_attention_entropy(outputs.attentions),
                'accuracy': self.calculate_accuracy(outputs.logits, target_ids),
                'perplexity': torch.exp(outputs.loss).item()
            }
        return metrics

    def calculate_attention_entropy(self, attention_maps):
        """Mean entropy of the softmax of each attention map, averaged over maps.

        Returns 0.0 when ``attention_maps`` is None or empty.
        """
        if attention_maps is None or len(attention_maps) == 0:
            return 0.0

        # NOTE(review): softmax is applied to each map before the entropy is
        # taken; if the model already returns normalised attention
        # probabilities this re-normalises them — confirm what the model emits.
        total_entropy = 0.0
        for attn in attention_maps:
            probs = torch.softmax(attn, dim=-1)
            total_entropy -= torch.mean(torch.sum(probs * torch.log(probs + 1e-9), dim=-1)).item()
        return total_entropy / len(attention_maps)

    def calculate_accuracy(self, logits, labels):
        """Token-level accuracy over non-padding positions (0.0 if there are none)."""
        predictions = torch.argmax(logits, dim=-1)
        mask = labels != self.tokenizer.pad_token_id
        num_tokens = torch.sum(mask).item()
        if num_tokens == 0:
            return 0.0
        correct = (predictions == labels) & mask
        return torch.sum(correct).item() / num_tokens

    def evaluate(self, valid_loader):
        """Evaluate on the validation set.

        Returns:
            ``(mean validation loss, corpus BLEU score)``.
        """
        self.model.eval()
        # generate() lives on the wrapped module when the model is DDP-wrapped.
        generator = self.model.module if isinstance(self.model, DDP) else self.model
        val_loss = 0
        bleu = BLEU()
        predictions = []
        references = []

        with torch.no_grad():
            for batch in valid_loader:
                # Move batch to the correct device
                for key in batch:
                    batch[key] = batch[key].to(self.device)

                loss, _ = self.training_step(batch)
                val_loss += loss.item()

                # Collect detokenized hypothesis / reference strings for BLEU.
                preds = generator.generate(batch['src'])
                predictions.extend(self.tokenizer.batch_decode(preds, skip_special_tokens=True))
                references.extend(self.tokenizer.batch_decode(batch['tgt'], skip_special_tokens=True))

        avg_val_loss = val_loss / max(len(valid_loader), 1)
        # sacrebleu expects hypotheses as a list of strings and references as a
        # list of reference streams, each parallel to the hypotheses.
        bleu_score = bleu.corpus_score(predictions, [references]).score if predictions else 0.0

        return avg_val_loss, bleu_score

    def save_checkpoint(self, epoch, step, metrics):
        """Save model, optimizer and scheduler state plus ``metrics`` to the checkpoint dir."""
        checkpoint_path = self.dirs['checkpoints'] / f"checkpoint_epoch_{epoch}_step_{step}.pt"
        torch.save({
            'epoch': epoch,
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        self.logger.info(f"Saved checkpoint to {checkpoint_path}")

    def log_step(self, epoch, step, loss, metrics, epoch_metrics):
        """Log one training step to the logger and (rank 0) to TensorBoard.

        Args:
            epoch: Zero-based epoch index.
            step: Optimizer-step index within the epoch, so that
                ``epoch * steps_per_epoch + step`` increases monotonically.
            loss: Scalar loss value.
            metrics: Dict with 'attention_entropy', 'accuracy', 'perplexity'.
            epoch_metrics: Dict whose 'learning_rate' / 'grad_norm' lists are
                non-empty; their last entries are logged.
        """
        self.logger.info(f"Epoch {epoch} Step {step}: Loss {loss:.4f}, "
                         f"Attention Entropy: {metrics['attention_entropy']:.4f}, "
                         f"Accuracy: {metrics['accuracy']:.4f}, "
                         f"Perplexity: {metrics['perplexity']:.4f}")

        if self.local_rank == 0:
            global_step = epoch * self.steps_per_epoch + step
            self.writer.add_scalar('Loss/train', loss, global_step)
            self.writer.add_scalar('Attention Entropy', metrics['attention_entropy'], global_step)
            self.writer.add_scalar('Accuracy', metrics['accuracy'], global_step)
            self.writer.add_scalar('Perplexity', metrics['perplexity'], global_step)
            self.writer.add_scalar('Learning Rate', epoch_metrics['learning_rate'][-1], global_step)
            self.writer.add_scalar('Grad Norm', epoch_metrics['grad_norm'][-1], global_step)

    def train(self, train_loader, valid_loader=None):
        """Main training loop: train, optionally validate, checkpoint every epoch.

        Raises:
            ValueError: if ``train_loader`` yields fewer batches than
                ``config.grad_accum_steps`` (no optimizer step would ever run).
        """
        self.steps_per_epoch = len(train_loader) // self.config.grad_accum_steps
        if self.steps_per_epoch < 1:
            raise ValueError(
                f"train_loader has {len(train_loader)} batches, fewer than "
                f"grad_accum_steps={self.config.grad_accum_steps}"
            )
        self.scheduler = self.get_scheduler()

        try:
            for epoch in range(self.config.num_epochs):
                epoch_loss, epoch_metrics = self.train_epoch(train_loader, epoch)

                if self.local_rank == 0:
                    self.logger.info(f"Epoch {epoch}: Average Training Loss {epoch_loss:.4f}")
                    self.writer.add_scalar('Loss/train_epoch', epoch_loss, epoch)

                # Validate after each epoch if validation set is provided
                if valid_loader is not None:
                    val_loss, bleu_score = self.evaluate(valid_loader)

                    if self.local_rank == 0:
                        self.logger.info(f"Epoch {epoch}: Validation Loss {val_loss:.4f}, BLEU Score: {bleu_score:.4f}")
                        self.writer.add_scalar('Loss/val_epoch', val_loss, epoch)
                        self.writer.add_scalar('BLEU', bleu_score, epoch)

                # Save model checkpoint after each epoch
                if self.local_rank == 0:
                    self.save_checkpoint(epoch, (epoch + 1) * self.steps_per_epoch, epoch_metrics)
        finally:
            # Close the trackers even when training is interrupted or fails.
            if self.local_rank == 0:
                self.writer.close()
                wandb.finish()


In [None]:
if __name__ == "__main__":

    # Initialize the data pipeline.
    # NOTE(review): `tokenizer` here is whatever an earlier cell bound; it is only
    # reloaded from disk further down, after the pipeline has been built.
    pipeline = DataPipeline(data_dict, tokenizer)

    # Create DataLoaders
    # Set is_distributed=True if you are doing distributed training
    train_loader, val_loader, test_loader = pipeline.create_dataloaders(is_distributed=False)

    # Initialize the model and load the tokenizer from the Kaggle input directory
    model = Transformer(**model_params)
    tokenizer = PreTrainedTokenizerFast.from_pretrained("/kaggle/input/tokenizer")
    # Initialize and run the training loop
    config = TrainingConfig()
    trainer = MTTrainer(config, model, tokenizer)
    trainer.train(train_loader, val_loader)

In [None]:
!pip install sacrebleu

## Config Class

In [None]:
from dataclasses import dataclass
from typing import Optional, Dict, List, Union
import torch

@dataclass
class ModelConfig:
    """Transformer model configuration.

    Plain container of default hyperparameters. In this cell the only code
    that touches the fields is create_training_config, which overrides
    d_model, num_heads and num_layers for "large" / "base" model names.
    """
    vocab_size: int = 32000
    d_model: int = 1024  # overridden by create_training_config for "large"/"base"
    num_heads: int = 16  # overridden by create_training_config for "large"/"base"
    num_layers: int = 24  # overridden by create_training_config for "large"/"base"
    d_ff: int = 4096  # presumably the feed-forward inner width — TODO confirm against the model
    dropout: float = 0.1
    max_sequence_length: int = 512
    activation: str = "gelu"
    layer_norm_eps: float = 1e-5
    pre_norm: bool = True  # Pre-norm vs Post-norm
    attention_dropout: float = 0.1
    position_embedding_type: str = "rotary"  # rotary, alibi, or learned

@dataclass
class TrainingConfig:
    """Training hyperparameters and settings.

    NOTE(review): an earlier cell defines a different dataclass with this same
    name (the one MTTrainer reads: base_dir, grad_accum_steps, log_interval,
    ...). Running this cell rebinds ``TrainingConfig`` to the version below.
    """
    # Basic training params
    num_epochs: int = 100
    train_batch_size: int = 128
    eval_batch_size: int = 64
    gradient_accumulation_steps: int = 32
    max_grad_norm: float = 1.0

    # Optimization
    optimizer: str = "adamw"  # adamw, adafactor, lion
    learning_rate: float = 1e-4
    min_learning_rate: float = 1e-5
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.998
    adam_epsilon: float = 1e-8

    # Learning rate schedule
    lr_scheduler: str = "cosine_with_warmup"  # linear, cosine, polynomial
    warmup_steps: int = 2000
    warmup_ratio: float = 0.01

    # Regularization
    dropout: float = 0.1
    label_smoothing: float = 0.1
    gradient_checkpointing: bool = True

    # Mixed precision training
    fp16: bool = True
    bf16: bool = False  # bfloat16 support
    fp16_opt_level: str = "O2"

    # Distributed training
    distributed_strategy: str = "ddp"  # ddp, deepspeed, fsdp
    ddp_find_unused_parameters: bool = False

    # Checkpointing
    save_strategy: str = "epoch"  # steps, epoch
    save_steps: int = 500
    save_total_limit: int = 5
    save_safetensors: bool = True

    # Evaluation
    evaluation_strategy: str = "steps"
    eval_steps: int = 500
    eval_delay: int = 0
    eval_timeout: int = 3600

    # Early stopping
    early_stopping_patience: int = 5
    early_stopping_threshold: float = 0.01

    # Logging
    logging_strategy: str = "steps"
    logging_steps: int = 100
    log_level: str = "info"
    log_on_each_node: bool = True

    # Memory optimization
    # (gradient_checkpointing is declared once, under "Regularization" above;
    # the second, identical declaration that used to sit here was redundant.)
    torch_compile: bool = True  # PyTorch 2.0 compile
    flash_attention: bool = True  # Use flash attention if available

    # Tokenizer settings
    max_length: int = 512
    padding: str = "max_length"
    truncation: bool = True

@dataclass
class DataConfig:
    """Data processing configuration.

    create_training_config fills in source_lang / target_lang and rewrites the
    three *_file paths relative to its ``base_dir`` argument.
    """
    train_file: str = "train.json"
    validation_file: str = "validation.json"
    test_file: str = "test.json"
    source_lang: str = "en"
    target_lang: str = "fr"

    # Data processing
    preprocessing_num_workers: int = 8
    overwrite_cache: bool = False
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None
    max_test_samples: Optional[int] = None

    # Augmentation and preprocessing
    do_lowercase: bool = False
    remove_punctuation: bool = False
    character_coverage: float = 0.9995
    temperature_sampling: bool = True
    temperature: float = 5.0

    # Tokenizer settings
    tokenizer_type: str = "sentencepiece"
    vocab_size: int = 32000
    # Defaults to None, so the annotation must be Optional.
    special_tokens: Optional[Dict[str, str]] = None

@dataclass
class MetricsConfig:
    """Evaluation metrics configuration.

    ``metrics`` defaults to None; create_training_config passes an explicit
    list of metric names.
    """
    # Defaults to None, so the annotation must be Optional.
    metrics: Optional[List[str]] = None
    bleu_type: str = "sacrebleu"
    meteor: bool = True
    rouge: bool = True
    bleurt: bool = True
    comet: bool = True
    ter: bool = True
    chrf: bool = True

    # Confidence estimation
    use_confidence_estimation: bool = True
    confidence_threshold: float = 0.8

    # Quality estimation
    use_quality_estimation: bool = True
    quality_estimation_model: str = "OpenKiwi"

@dataclass
class InferenceConfig:
    """Inference configuration.

    Plain container of decoding defaults; create_training_config instantiates
    it unchanged and nothing else in this cell reads the fields.
    """
    # Search / sampling settings
    beam_size: int = 5
    length_penalty: float = 0.6
    repetition_penalty: float = 1.0
    no_repeat_ngram_size: int = 3
    top_k: int = 50
    top_p: float = 0.9
    temperature: float = 1.0
    diverse_beam_groups: int = 4
    diverse_beam_strength: float = 0.5
    
    # Constrained decoding
    force_words_ids: Optional[List[List[int]]] = None
    suppress_tokens: Optional[List[int]] = None
    
    # Length control
    min_length: int = 0
    max_length: int = 512
    
    # Batch processing
    batch_size: int = 32
    num_workers: int = 4

def create_training_config(
    model_name: str,
    source_lang: str,
    target_lang: str,
    base_dir: str
) -> Dict[str, Union[ModelConfig, TrainingConfig, DataConfig, MetricsConfig, InferenceConfig]]:
    """Assemble the five configuration objects for one training setup.

    Args:
        model_name: Name checked case-insensitively for "large" (tested first)
            or "base" to pick a model-size preset; any other name keeps the
            dataclass defaults.
        source_lang: Stored on the data config.
        target_lang: Stored on the data config.
        base_dir: Prefix for the train / validation / test JSON paths.

    Returns:
        Dict with the keys "model", "training", "data", "metrics", "inference".
    """
    requested_metrics = ["bleu", "meteor", "ter", "chrf", "comet", "bleurt", "rouge"]

    # Start from the dataclass defaults.
    configs = {
        "model": ModelConfig(),
        "training": TrainingConfig(),
        "data": DataConfig(source_lang=source_lang, target_lang=target_lang),
        "metrics": MetricsConfig(metrics=requested_metrics),
        "inference": InferenceConfig(),
    }

    # Size preset: (d_model, num_heads, num_layers, train_batch_size).
    lowered_name = model_name.lower()
    if "large" in lowered_name:
        size_preset = (1024, 16, 24, 64)
    elif "base" in lowered_name:
        size_preset = (768, 12, 12, 128)
    else:
        size_preset = None

    if size_preset is not None:
        d_model, num_heads, num_layers, train_batch_size = size_preset
        model_cfg = configs["model"]
        model_cfg.d_model = d_model
        model_cfg.num_heads = num_heads
        model_cfg.num_layers = num_layers
        configs["training"].train_batch_size = train_batch_size

    # Point the data files at <base_dir>/data/.
    data_cfg = configs["data"]
    data_cfg.train_file = f"{base_dir}/data/train.json"
    data_cfg.validation_file = f"{base_dir}/data/validation.json"
    data_cfg.test_file = f"{base_dir}/data/test.json"

    return configs

# Example usage:
"""
config = create_training_config(
    model_name="mt-large",
    source_lang="en",
    target_lang="fr",
    base_dir="/path/to/project"
)
"""