In [1]:
import os
import pandas as pd
import json
import numpy as np
import re
import unicodedata
import string
import torch
import copy
# from datasets import Dataset
from collections import Counter
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
%matplotlib inline

from torch.utils.data import Dataset
import h5py
from pathlib import Path
import gzip
import pickle
import unicodedata
from torch.cuda.amp import GradScaler, autocast

# Seed the RNGs the notebook relies on. numpy is seeded as well so any
# np.random usage is reproducible across Restart & Run All.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including PyTorch warnings that can point at real misconfiguration.
# Consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

In [2]:
class NepaliTokenizer:
    """Rule-based word tokenizer for Nepali text written in Devanagari.

    Runs of Devanagari characters become word tokens. Punctuation and every
    other non-Devanagari, non-space character become single-character
    tokens. At most one known postposition/suffix is then split off the end
    of each Devanagari token.
    """

    def __init__(self):
        # Devanagari Unicode block, inclusive code-point bounds
        self.NEPALI_DEVANAGARI_RANGE = (0x0900, 0x097F)

        # Characters emitted as standalone punctuation tokens. Membership is
        # tested with `in` on this raw string, so the backslashes count too
        # (harmless: any non-Devanagari character is emitted on its own).
        self.NEPALI_PUNCTUATION = r'।॥,\.;:!?\(\)\[\]\{\}'

        # Common Nepali suffixes and postpositions to separate from word
        # endings; checked in list order, first match wins.
        self.NEPALI_SUFFIXES = [
            'ले', 'को', 'का', 'की', 'के', 
            'मा', 'बाट', 'सँग', 'देखि', 
            'सम्म', 'पछि', 'अघि',
            'हरू'  # Plural marker
        ]
    
    def is_nepali_character(self, char):
        """
        Check if a character is in the Devanagari script range used for Nepali
        
        Args:
        - char: Single character to check (empty string returns False)
        
        Returns:
        - Boolean indicating if character is Nepali
        """
        if not char:
            return False
        low, high = self.NEPALI_DEVANAGARI_RANGE
        return low <= ord(char) <= high
    
    def normalize_nepali_text(self, text):
        """
        Normalize Nepali text: Unicode NFC composition, then collapse every
        whitespace run to a single space and strip the ends.
        
        Args:
        - text: Input text to normalize (must be a str)
        
        Returns:
        - Normalized text
        """
        text = unicodedata.normalize('NFC', text)
        return re.sub(r'\s+', ' ', text).strip()

    def _split_suffix(self, token):
        """Return [base, suffix] for the first matching suffix that leaves a non-empty base, else [token]."""
        for suffix in self.NEPALI_SUFFIXES:
            if token.endswith(suffix) and len(token) > len(suffix):
                return [token[:-len(suffix)], suffix]
        return [token]
    
    def tokenize(self, text):
        """
        Tokenize Nepali text into words, punctuation and split-off suffixes
        
        Args:
        - text: Input text to tokenize
        
        Returns:
        - List of tokens
        """
        text = self.normalize_nepali_text(text)
        
        tokens = []
        current_token = []

        def flush():
            # Emit the Devanagari word being built, if any
            if current_token:
                tokens.append(''.join(current_token))
                current_token.clear()
        
        for char in text:
            if char in self.NEPALI_PUNCTUATION:
                # Must be checked before the Devanagari test: the danda
                # (U+0964) and double danda (U+0965) lie inside the
                # Devanagari block and would otherwise be glued onto the
                # preceding word.
                flush()
                tokens.append(char)
            elif self.is_nepali_character(char):
                current_token.append(char)
            elif char.isspace():
                flush()
            else:
                # Non-Nepali characters (like digits, Latin script) are
                # emitted one character per token.
                flush()
                tokens.append(char)
        flush()
        
        # Suffix and postposition handling for Devanagari-final tokens
        final_tokens = []
        for token in tokens:
            if self.is_nepali_character(token[-1]):
                final_tokens.extend(self._split_suffix(token))
            else:
                final_tokens.append(token)
        
        return final_tokens

In [3]:
class CustomBERTTokenizer:
    """BERT-style word-level tokenizer built on NepaliTokenizer.

    Handles vocabulary building, encoding to fixed-length id sequences
    ([CLS] ... [SEP] + [PAD]) and masked-language-model masking.
    """

    def __init__(self, 
                 max_vocab_size=30000, 
                 max_length=512, 
                 mask_probability=0.15):
        """
        Custom BERT-style tokenizer
        
        Args:
        - max_vocab_size: Maximum number of tokens in vocabulary (special tokens included)
        - max_length: Fixed length of every encoded sequence
        - mask_probability: Probability of selecting a token for masking
        """
        self.max_length = max_length
        self.mask_probability = mask_probability
        
        # Special tokens
        self.special_tokens = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[CLS]': 2,
            '[SEP]': 3,
            '[MASK]': 4
        }
        self.special_token_ids = {token: idx for token, idx in self.special_tokens.items()}
        
        # Vocabulary is grown by build_vocab()
        self.vocab = self.special_tokens.copy()
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        
        # Token frequencies, accumulated across build_vocab() calls
        self.token_freq = Counter()
        
        self.max_vocab_size = max_vocab_size
        self.tokenizer = NepaliTokenizer()

    def _tokenize(self, text):
        """Split text into word-level tokens with the Nepali tokenizer."""
        return self.tokenizer.tokenize(text)

    def build_vocab(self, texts):
        """
        Count tokens in `texts` and grow the vocabulary with the most
        frequent tokens not yet in it.

        Frequencies accumulate in `self.token_freq` across calls, so this can
        be called once per chunk of a corpus. Existing tokens keep their ids.
        New tokens get ids after the current largest id, so ids never
        collide. Growth stops once `max_vocab_size` entries (special tokens
        included) exist.
        
        Args:
        - texts: Iterable of texts. Non-string entries (e.g. the NaN pandas
          yields for empty TSV cells) are skipped.
        """
        for text in texts:
            if not isinstance(text, str):
                continue
            self.token_freq.update(self._tokenize(text))

        if len(self.vocab) >= self.max_vocab_size:
            return

        sorted_tokens = sorted(
            self.token_freq.items(), 
            key=lambda x: x[1], 
            reverse=True
        )
        
        # Continue after the largest id already assigned, not after the
        # special tokens, so repeated calls never reuse an id.
        next_idx = max(self.vocab.values(), default=-1) + 1
        for token, _ in sorted_tokens:
            # Check the limit *before* inserting so the vocab never exceeds it
            if len(self.vocab) >= self.max_vocab_size:
                break
            if token not in self.vocab:
                self.vocab[token] = next_idx
                self.reverse_vocab[next_idx] = token
                next_idx += 1
    
    def encode(self, text):
        """
        Encode text to exactly `max_length` token ids:
        [CLS] tokens... [SEP] followed by [PAD].
        
        Content tokens are truncated first, so [CLS] and [SEP] survive
        truncation whenever max_length >= 2. Out-of-vocabulary tokens map
        to [UNK].
        
        Args:
        - text: Input text
        
        Returns:
        - List of token ids of length max_length
        """
        ids = self.special_token_ids
        tokens = self._tokenize(text)

        # Reserve two slots for [CLS] and [SEP]
        body_len = max(self.max_length - 2, 0)
        token_ids = [self.vocab.get(token, ids['[UNK]']) for token in tokens[:body_len]]

        token_ids = [ids['[CLS]']] + token_ids + [ids['[SEP]']]
    
        # Only clips anything when max_length < 2
        token_ids = token_ids[:self.max_length]
        token_ids += [ids['[PAD]']] * (self.max_length - len(token_ids))
    
        return token_ids

    def mask_tokens(self, input_ids):
        """
        Apply BERT-style token masking.

        Each token other than [CLS]/[SEP]/[PAD] is selected with probability
        `mask_probability`. If nothing is selected, one selectable position
        is forced. Of the selected tokens, ~80% become [MASK], ~10% become a
        random non-special vocabulary token, and ~10% keep their original id.
        
        Args:
        - input_ids: Original token sequence (list or tensor)
        
        Returns:
        - masked_input_ids: Copy of input_ids with the selected tokens replaced
        - mask_labels: Original ids at selected positions, 0 ([PAD]) elsewhere
        """
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids)
        
        masked_input_ids = input_ids.clone()
        device = masked_input_ids.device
        ids = self.special_token_ids
        
        # Positions eligible for masking (special tokens excluded)
        maskable = (masked_input_ids != ids['[CLS]']) & \
                   (masked_input_ids != ids['[SEP]']) & \
                   (masked_input_ids != ids['[PAD]'])

        probabilities = torch.full(masked_input_ids.shape, self.mask_probability, device=device)
        mask = torch.bernoulli(probabilities).bool() & maskable
        
        # If no tokens are masked, force one eligible position (chosen over
        # the flattened tensor so this also works for batched input)
        if not mask.any() and maskable.any():
            forced = torch.multinomial(maskable.flatten().float(), 1)
            mask.view(-1)[forced] = True
        
        # Labels must be captured before any replacement happens
        mask_labels = torch.zeros_like(masked_input_ids)
        mask_labels[mask] = masked_input_ids[mask]
        
        # 80% of selected tokens -> [MASK]
        to_mask_token = mask & (torch.rand(masked_input_ids.shape, device=device) < 0.8)
        masked_input_ids[to_mask_token] = ids['[MASK]']
        
        # Half of the remaining 20% -> random token (10% overall); drawing only
        # among positions NOT already replaced keeps the split at 80/10/10.
        to_random_token = mask & ~to_mask_token & \
            (torch.rand(masked_input_ids.shape, device=device) < 0.5)
        if to_random_token.any():
            # Prefer non-special ids; fall back to the full range for tiny vocabs
            low = max(ids.values()) + 1
            high = len(self.vocab)
            if high <= low:
                low = 0
            random_tokens = torch.randint(
                low, high, masked_input_ids.shape,
                dtype=masked_input_ids.dtype, device=device
            )
            masked_input_ids[to_random_token] = random_tokens[to_random_token]
        
        return masked_input_ids, mask_labels

    def prepare_bert_pretraining_data(self, df, text_column):
        """
        Encode and mask every row of `df[text_column]` for MLM pretraining.

        Uses the existing vocabulary (build_vocab is not called here). Rows
        that raise during encoding/masking are reported and skipped.
        
        Args:
        - df: Input DataFrame
        - text_column: Name of the text column
        
        Returns:
        - input_sequences: (n_rows, max_length) *masked* token ids fed to the model
        - segment_ids: (n_rows, max_length) zeros (single-segment inputs)
        - masked_tokens: (n_rows, max_length) MLM labels, 0 where not masked
        """
        input_sequences = []
        segment_ids = []
        masked_tokens = []

        texts = df[text_column]
        for i in range(len(df)):
            try:
                input_ids = self.encode(texts.iloc[i])

                # The model must receive the masked ids; feeding the original
                # ids would let it copy the labels straight from its input.
                masked_input_ids, mask_label = self.mask_tokens(input_ids)

                input_sequences.append(masked_input_ids)
                segment_ids.append(torch.zeros(len(input_ids), dtype=torch.long))
                masked_tokens.append(mask_label)

            except Exception as e:
                print(f"Error processing row {i}: {e}")
                continue

        # torch.stack rejects empty lists; return well-shaped empty tensors
        if not input_sequences:
            empty = torch.empty((0, self.max_length), dtype=torch.long)
            return empty, empty.clone(), empty.clone()
        
        return (
            torch.stack(input_sequences),
            torch.stack(segment_ids),
            torch.stack(masked_tokens),
        )

In [4]:
class PretrainingDataset(Dataset):
    """In-memory map-style dataset over pre-computed MLM pretraining data.

    Each item is a dict with keys 'input_sequences', 'segment_ids' and
    'masked_tokens', each returned as a fresh torch.long tensor.
    """

    def __init__(self, input_sequences, segment_ids, masked_tokens):
        self.input_sequences = input_sequences
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens

    def __len__(self):
        return len(self.input_sequences)

    @staticmethod
    def _as_long_tensor(value):
        """Return a new, detached torch.long copy of `value` (tensor, array or list)."""
        # torch.tensor() on an existing tensor emits a copy-construct
        # UserWarning; Tensor.to(copy=True) makes the same detached copy.
        if isinstance(value, torch.Tensor):
            return value.detach().to(dtype=torch.long, copy=True)
        return torch.tensor(value, dtype=torch.long)
    
    def __getitem__(self, idx):
        return {
            'input_sequences': self._as_long_tensor(self.input_sequences[idx]),
            'segment_ids': self._as_long_tensor(self.segment_ids[idx]),
            'masked_tokens': self._as_long_tensor(self.masked_tokens[idx]),
        }

In [5]:
def build_total_vocabulary(file_path, tokenizer, chunk_size=1000, max_chunks=None):
    """
    Build vocabulary from entire dataset before processing chunks
    
    Reads the TSV (one text per line, no header) in chunks and feeds each
    chunk to `tokenizer.build_vocab`. A chunk that raises is reported and
    skipped.

    Args:
        file_path: Path to TSV file
        tokenizer: CustomBERTTokenizer instance (mutated in place)
        chunk_size: Number of rows to process at once
        max_chunks: Maximum number of chunks to process (None for all; 0 processes none)

    Returns:
        The same tokenizer instance.
    """
    # The chunked reader holds an open file handle; using it as a context
    # manager closes it even when the loop breaks early or an error escapes.
    with pd.read_csv(
        file_path, 
        sep='\t', 
        header=None, 
        names=['text'], 
        chunksize=chunk_size,
        encoding='utf-8'
    ) as chunks:
        print("Building vocabulary from all chunks...")
        for chunk_idx, chunk in enumerate(tqdm(chunks, desc="Processing chunks for vocabulary")):
            # `is not None` so that max_chunks=0 means "no chunks", not "all"
            if max_chunks is not None and chunk_idx >= max_chunks:
                break
                
            try:
                tokenizer.build_vocab(chunk['text'])
            except Exception as e:
                print(f"Error processing chunk {chunk_idx} for vocabulary: {e}")
                continue
    
    print(f"Built vocabulary with {len(tokenizer.vocab)} tokens")
    return tokenizer

def save_vocabulary(tokenizer, output_path):
    """
    Write the tokenizer's vocabulary and configuration to ``vocabulary.json``.

    The file goes in the parent directory of ``output_path``. It holds the
    token->id map, the special tokens and the size/length/masking settings
    needed to rebuild the tokenizer. Non-ASCII tokens are written as-is
    (UTF-8).

    Returns:
        Path of the written JSON file.
    """
    target = Path(output_path).parent.joinpath("vocabulary.json")

    payload = dict(
        vocab=tokenizer.vocab,
        special_tokens=tokenizer.special_tokens,
        max_vocab_size=tokenizer.max_vocab_size,
        max_length=tokenizer.max_length,
        mask_probability=tokenizer.mask_probability,
    )

    with open(target, 'w', encoding='utf-8') as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)

    return target

def load_vocabulary(vocab_path):
    """
    Load a vocabulary JSON (as written by save_vocabulary) and rebuild a
    CustomBERTTokenizer from it.

    Returns:
        CustomBERTTokenizer with vocab, special tokens and reverse vocab
        restored. Token frequencies are not saved, so `token_freq` is empty.
    """
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    
    tokenizer = CustomBERTTokenizer(
        max_vocab_size=vocab_data['max_vocab_size'],
        max_length=vocab_data['max_length'],
        mask_probability=vocab_data['mask_probability']
    )
    
    tokenizer.vocab = vocab_data['vocab']
    tokenizer.special_tokens = vocab_data['special_tokens']
    # encode()/mask_tokens() read special_token_ids, so keep it in sync with
    # the loaded special tokens instead of the constructor defaults.
    tokenizer.special_token_ids = dict(tokenizer.special_tokens)
    tokenizer.reverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
    
    return tokenizer

In [6]:
# class ChunkDataset(Dataset):
#     """Dataset that loads chunks on-demand"""
    
#     def __init__(self, metadata_path):
#         """
#         Initialize dataset using metadata file
        
#         Args:
#             metadata_path: Path to metadata JSON file
#         """
#         with open(metadata_path) as f:
#             self.metadata = json.load(f)
        
#         self.chunks = self.metadata['chunks']
#         # self.format = self.metadata['format']
        
#         # Load vocabulary
#         self.tokenizer = load_vocabulary(self.metadata['vocab_path'])
        
#         # Calculate cumulative sizes for chunk lookup
#         self.cumulative_sizes = np.cumsum([chunk['num_sequences'] for chunk in self.chunks])
        
#     def __len__(self):
#         return self.cumulative_sizes[-1]
    
#     def _find_chunk(self, idx):
#         """Find which chunk contains the given index"""
#         chunk_idx = np.searchsorted(self.cumulative_sizes, idx, side='right')
#         local_idx = idx - (self.cumulative_sizes[chunk_idx-1] if chunk_idx > 0 else 0)
#         return chunk_idx, local_idx
    
#     def __getitem__(self, idx):
#         chunk_idx, local_idx = self._find_chunk(idx)
#         chunk_path = self.chunks[chunk_idx]['path']
        
#         # Load appropriate slice from chunk file
#         if self.metadata['format'] == 'h5':
#             with h5py.File(chunk_path, 'r') as f:
#                 return {
#                     'input_sequences': torch.tensor(f['input_sequences'][local_idx], dtype=torch.long),
#                     'segment_ids': torch.tensor(f['segment_ids'][local_idx], dtype=torch.long),
#                     'masked_tokens': torch.tensor(f['masked_tokens'][local_idx], dtype=torch.long)
#                 }
#         else:  # npz
#             data = np.load(chunk_path)
#             return {
#                 'input_sequences': torch.tensor(data['input_sequences'][local_idx], dtype=torch.long),
#                 'segment_ids': torch.tensor(data['segment_ids'][local_idx], dtype=torch.long),
#                 'masked_tokens': torch.tensor(data['masked_tokens'][local_idx], dtype=torch.long)
#             }
    # def split(self, train_ratio=0.8, valid_ratio=0.1, seed=42):
    #     """
    #     Split dataset into train, validation, and test sets
        
    #     Args:
    #         train_ratio: Proportion of data for training
    #         valid_ratio: Proportion of data for validation 
    #         seed: Random seed for reproducibility
        
    #     Returns:
    #         Tuple of (train_dataset, valid_dataset, test_dataset)
    #     """
    #     # Total dataset size
    #     total_size = len(self)
        
    #     # Set random seed
    #     np.random.seed(seed)
        
    #     # Generate random indices
    #     indices = np.random.permutation(total_size)
        
    #     # Calculate split points
    #     train_end = int(total_size * train_ratio)
    #     valid_end = train_end + int(total_size * valid_ratio)
        
    #     # Create train split datasets
    #     train_indices = indices[:train_end]
    #     valid_indices = indices[train_end:valid_end]
    #     test_indices = indices[valid_end:]
        
    #     def create_split_dataset(subset_indices):
    #         """Create a subset dataset from given indices"""
    #         split_dataset = copy.deepcopy(self)
    #         split_dataset.indices = subset_indices
    #         split_dataset.__getitem__ = lambda idx: self.__getitem__(subset_indices[idx])
    #         split_dataset.__len__ = lambda: len(subset_indices)
    #         return split_dataset
        
    #     return (
    #         create_split_dataset(train_indices),
    #         create_split_dataset(valid_indices),
    #         create_split_dataset(test_indices)
    #     )

class ChunkDataset(Dataset):
    """Dataset that loads pre-processed chunk files on demand.

    The metadata JSON provides `chunks` (each with `path` and
    `num_sequences`), `vocab_path` and `format`. With `'h5'` the chunks are
    read through h5py; any other value is read as an `.npz` archive via
    `np.load`. Every access reads from disk.

    Supports int (including numpy integers and negative indices), slice,
    and list/ndarray indexing.
    """

    _FIELDS = ('input_sequences', 'segment_ids', 'masked_tokens')
    
    def __init__(self, metadata_path):
        """
        Initialize dataset using metadata file
        
        Args:
            metadata_path: Path to metadata JSON file
        """
        with open(metadata_path, encoding='utf-8') as f:
            self.metadata = json.load(f)
        
        self.chunks = self.metadata['chunks']
        
        # Load vocabulary
        self.tokenizer = load_vocabulary(self.metadata['vocab_path'])
        
        # cumulative_sizes[k] = total sequences in chunks 0..k
        self.cumulative_sizes = np.cumsum([chunk['num_sequences'] for chunk in self.chunks])
        
    def __len__(self):
        # Plain int (not np.int64), and 0 when there are no chunks
        if len(self.cumulative_sizes) == 0:
            return 0
        return int(self.cumulative_sizes[-1])
    
    def _find_chunk(self, idx):
        """Map a global index to (chunk index, index within that chunk)."""
        chunk_idx = np.searchsorted(self.cumulative_sizes, idx, side='right')
        local_idx = idx - (self.cumulative_sizes[chunk_idx-1] if chunk_idx > 0 else 0)
        return chunk_idx, local_idx

    def _load_item(self, chunk_path, local_idx):
        """Read one row of every field from a chunk file as torch.long tensors."""
        if self.metadata['format'] == 'h5':
            with h5py.File(chunk_path, 'r') as f:
                return {name: torch.tensor(f[name][local_idx], dtype=torch.long)
                        for name in self._FIELDS}
        # npz: the context manager closes the archive's file handle
        with np.load(chunk_path) as data:
            return {name: torch.tensor(data[name][local_idx], dtype=torch.long)
                    for name in self._FIELDS}
    
    def __getitem__(self, key):
        # Slices follow normal Python semantics (negative bounds, empty ranges)
        if isinstance(key, slice):
            return [self[idx] for idx in range(*key.indices(len(self)))]
        
        # Single index; numpy integers (e.g. elements of an index array) included
        if isinstance(key, (int, np.integer)):
            key = int(key)
            if key < 0:
                key += len(self)
            if key < 0 or key >= len(self):
                raise IndexError("Index out of range")
            
            chunk_idx, local_idx = self._find_chunk(key)
            return self._load_item(self.chunks[chunk_idx]['path'], local_idx)
        
        # List/array of indices
        if isinstance(key, (list, np.ndarray)):
            return [self[idx] for idx in key]
        
        raise TypeError("Invalid argument type")
    
    def split(self, train_ratio=0.8, valid_ratio=0.1, seed=42):
        """
        Split dataset into train, validation, and test sets
        
        Args:
            train_ratio: Proportion of data for training
            valid_ratio: Proportion of data for validation (test gets the rest)
            seed: Random seed for reproducibility
        
        Returns:
            Tuple of (train_dataset, valid_dataset, test_dataset) as
            torch.utils.data.Subset views over this dataset.
        """
        from torch.utils.data import Subset

        total_size = len(self)
        
        # A private RandomState gives the same permutation as seeding numpy's
        # global RNG, without clobbering global state for the rest of the notebook.
        indices = np.random.RandomState(seed).permutation(total_size)
        
        train_end = int(total_size * train_ratio)
        valid_end = train_end + int(total_size * valid_ratio)
        
        # Subset overrides __len__/__getitem__ on its class. Assigning those
        # methods on an instance (as done before) has no effect, because
        # len()/[] and DataLoader look special methods up on the type.
        return (
            Subset(self, indices[:train_end].tolist()),
            Subset(self, indices[train_end:valid_end].tolist()),
            Subset(self, indices[valid_end:].tolist()),
        )

In [7]:
METADATA_PATH = '/home/ubuntu/dataset/processed_chunks_metadata.json'
BATCH_SIZE = 64
NUM_WORKERS = 16

dataset = ChunkDataset(METADATA_PATH)

# Reproducible 80% / 10% / 10% train / validation / test split
train_dataset, valid_dataset, test_dataset = dataset.split(
    train_ratio=0.8,
    valid_ratio=0.1,
    seed=42,
)

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

In [8]:
# train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True,pin_memory=True)
# test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True,pin_memory=True)
# valid_dataloader = DataLoader(valid_dataset,batch_size=8,shuffle=True,pin_memory=True)

# Model Architecture

In [9]:
class BERTEmbedding(nn.Module):
    """Input embedding: token + segment + learned absolute position, then dropout.

    Positions are numbered along dimension 1 of `seq`, so inputs are
    expected as 2-D id tensors laid out (batch, seq_len).
    """

    def __init__(self, vocab_size, n_segments, max_len, embed_dim, dropout):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, embed_dim)
        self.seg_embed = nn.Embedding(n_segments, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.drop = nn.Dropout(dropout)
        # Largest supported sequence length (size of the position table)
        self.max_len = max_len

    def forward(self, seq, seg):
        seq_len = seq.size(1)
        # Position ids 0..seq_len-1, created on seq's device, repeated per row
        positions = torch.arange(seq_len, device=seq.device)
        positions = positions.unsqueeze(0).expand_as(seq)

        summed = self.tok_embed(seq) + self.seg_embed(seg) + self.pos_embed(positions)
        return self.drop(summed)



In [10]:
class BERT(nn.Module):
    """Transformer encoder stack over BERTEmbedding outputs.

    Args:
        vocab_size, n_segments, max_len, embed_dim, dropout: forwarded to
            BERTEmbedding (dropout is also used inside the encoder layers).
        n_layers: number of encoder layers.
        attn_heads: attention heads per layer (must divide embed_dim).
        pad_idx: token id treated as padding. Positions holding it are
            excluded as attention keys. Pass None to disable masking.

    forward(seq, seg) returns hidden states shaped (batch, seq_len, embed_dim).
    """

    def __init__(self,
                 vocab_size,
                 n_segments,
                 max_len,
                 embed_dim,
                 n_layers,
                 attn_heads,
                 dropout,
                 pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = BERTEmbedding(vocab_size, n_segments, max_len, embed_dim, dropout)
        # batch_first=True: BERTEmbedding emits (batch, seq_len, dim). With the
        # default batch_first=False, dim 0 would be read as the sequence, and
        # attention would mix tokens across different examples in the batch.
        self.encoder_layer = nn.TransformerEncoderLayer(
            embed_dim, attn_heads, embed_dim * 4,
            dropout=dropout, batch_first=True,
        )
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, n_layers)

    def forward(self, seq, seg):
        out = self.embedding(seq, seg)
        # True marks [PAD] keys that attention must ignore
        padding_mask = None if self.pad_idx is None else seq.eq(self.pad_idx)
        return self.encoder_block(out, src_key_padding_mask=padding_mask)

In [11]:
class BERTPretrainingModel(nn.Module):
    """Encoder plus a linear masked-language-model head.

    The head maps every hidden state to logits over the vocabulary. Its
    input width is taken from the encoder's token-embedding size.
    """

    def __init__(self, bert_model, vocab_size):
        super().__init__()
        self.bert = bert_model
        hidden_size = bert_model.embedding.tok_embed.embedding_dim
        self.mlm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, seq, seg):
        # encoder hidden states -> per-position vocabulary logits
        return self.mlm_head(self.bert(seq, seg))

# Training

In [12]:
# def train_bert(model, train_dataloader, optimizer, mlm_criterion, device):
#     model.train()
#     total_train_loss = 0

#     for batch in tqdm(train_dataloader):
#         seq = batch['input_sequences'].to(device)  # [batch_size, seq_len]
#         seg = batch['segment_ids'].to(device)      # [batch_size, seq_len]
#         masked_tokens = batch['masked_tokens'].to(device)  # [batch_size, seq_len]

#         # Zero gradients
#         optimizer.zero_grad()

#         # Forward pass
#         mlm_predictions = model(seq, seg)  # [batch_size, seq_len, vocab_size]

#         # Flatten predictions and targets
#         mlm_predictions = mlm_predictions.view(-1, mlm_predictions.size(-1))  # [batch_size * seq_len, vocab_size]
#         masked_tokens = masked_tokens.view(-1)  # [batch_size * seq_len]

#         # Compute MLM loss
#         mlm_loss = mlm_criterion(mlm_predictions, masked_tokens)

#         # Backward pass
#         mlm_loss.backward()

#         # Optimizer step
#         optimizer.step()

#         # Accumulate loss
#         total_train_loss += mlm_loss.item()

#     return total_train_loss / len(train_dataloader)

from torch.cuda.amp import GradScaler, autocast

# Module-level AMP loss scaler; train_bert resolves `scaler` at call time.
scaler = GradScaler()

def train_bert(model, train_dataloader, optimizer, mlm_criterion, device, train_batch_loss, patience=50):
    """
    Run one epoch of mixed-precision masked-language-model training.

    Args:
        model: model mapping (seq, seg) to per-position vocabulary logits.
        train_dataloader: yields dicts with 'input_sequences', 'segment_ids', 'masked_tokens'.
        optimizer: optimizer stepped through the module-level `scaler`.
        mlm_criterion: loss over flattened logits/labels.
        device: device the batch tensors are moved to.
        train_batch_loss: list; every batch's loss is appended to it.
        patience: consecutive non-improving batches before the flag is raised.

    Returns:
        (mean batch loss for the epoch, flag). flag is True once `patience`
        consecutive batches failed to beat the lowest batch loss seen in this
        epoch. It is only reported; the epoch always runs to completion.
    """
    model.train()
    total_train_loss = 0.0
    flag = False
    best_loss = float("inf")
    batches_no_improve = 0
    train_bar = tqdm(train_dataloader)

    for batch in train_bar:
        # non_blocking pairs with pin_memory=True loaders to overlap copies
        seq = batch['input_sequences'].to(device, non_blocking=True)
        seg = batch['segment_ids'].to(device, non_blocking=True)
        masked_tokens = batch['masked_tokens'].to(device, non_blocking=True)

        optimizer.zero_grad()

        with autocast():
            mlm_predictions = model(seq, seg)
            mlm_loss = mlm_criterion(
                mlm_predictions.view(-1, mlm_predictions.size(-1)),
                masked_tokens.view(-1),
            )

        scaler.scale(mlm_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Single device->host sync per step (previously .item() was called 4x)
        loss_value = mlm_loss.item()
        train_batch_loss.append(loss_value)
        train_bar.set_postfix({"loss": loss_value})
        total_train_loss += loss_value

        if loss_value < best_loss:
            best_loss = loss_value
            batches_no_improve = 0
        else:
            batches_no_improve += 1

        if batches_no_improve >= patience:
            flag = True

    return total_train_loss / len(train_dataloader), flag


In [13]:
def test_bert(model, test_dataloader, mlm_criterion, device, valid_batch_loss=None):
    """
    Evaluate the MLM loss over a dataloader without updating the model.

    Args:
        model: model mapping (seq, seg) to per-position vocabulary logits.
        test_dataloader: yields dicts with 'input_sequences', 'segment_ids', 'masked_tokens'.
        mlm_criterion: loss over flattened logits/labels.
        device: device the batch tensors are moved to.
        valid_batch_loss: optional list; every batch's loss is appended to it.

    Returns:
        Mean batch loss. The model is left in eval mode.
    """
    model.eval()
    total_test_loss = 0.0
    test_bar = tqdm(test_dataloader)

    # No autograd graph during evaluation: saves activation memory and time
    with torch.no_grad():
        for batch in test_bar:
            seq = batch['input_sequences'].to(device, non_blocking=True)
            seg = batch['segment_ids'].to(device, non_blocking=True)
            masked_tokens = batch['masked_tokens'].to(device, non_blocking=True)

            with autocast():
                mlm_predictions = model(seq, seg)
                mlm_loss = mlm_criterion(
                    mlm_predictions.view(-1, mlm_predictions.size(-1)),
                    masked_tokens.view(-1),
                )

            loss_value = mlm_loss.item()
            if valid_batch_loss is not None:
                valid_batch_loss.append(loss_value)
            total_test_loss += loss_value
            test_bar.set_postfix({"loss": loss_value})

    return total_test_loss / len(test_dataloader)

In [14]:
# Checkpointing
# Re-creating the scaler here replaces the instance from the training cell.
# train_bert looks up the module-level `scaler` at call time, so training and
# checkpointing share this instance.
scaler = GradScaler()
checkpoint_path = "checkpoint.pth"

def save_checkpoint(model, optimizer, epoch, path=checkpoint_path):
    """Save epoch number plus model, optimizer and AMP-scaler state to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict()
    }, path)

def load_checkpoint(model, optimizer, path=checkpoint_path, map_location="cpu"):
    """
    Restore model/optimizer/scaler state from `path` if it exists.

    Tensors are loaded onto `map_location` first (CPU by default), so a
    GPU-saved checkpoint also loads on a CPU-only machine. load_state_dict
    then copies them into the existing parameters/optimizer state on their
    devices.

    Returns:
        The epoch to resume from (saved epoch + 1), or 0 if there is no checkpoint.
    """
    if not os.path.exists(path):
        return 0  # Start fresh if no checkpoint
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    return checkpoint['epoch'] + 1  # Resume from next epoch

In [15]:
# Hyperparameters
vocab_size = 30000 
n_segments = 2
max_len = 512
embed_dim = 768
n_layers = 6
attn_heads = 6
dropout = 0.1

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Token ids index the embedding table and the MLM head (both sized by
# vocab_size), and sequence positions index the position table (max_len).
# Out-of-range ids only surface as an opaque CUDA device-side assert, so
# check the loaded vocabulary up front.
max_token_id = max(dataset.tokenizer.vocab.values())
assert max_token_id < vocab_size, (
    f"vocabulary ids go up to {max_token_id}, but the model has only {vocab_size} embeddings"
)
assert dataset.tokenizer.max_length <= max_len, (
    f"sequences are {dataset.tokenizer.max_length} long, but max_len is {max_len}"
)

# NOTE: the encoder is compiled and then wrapped and compiled again. Compiled
# modules prefix state_dict keys, so checkpoint.pth depends on this nesting;
# keep it unless existing checkpoints can be discarded.
bert_base = torch.compile(BERT(
    vocab_size=vocab_size, 
    n_segments=n_segments, 
    max_len=max_len, 
    embed_dim=embed_dim, 
    n_layers=n_layers, 
    attn_heads=attn_heads, 
    dropout=dropout
).to(device))

model = torch.compile(BERTPretrainingModel(bert_base, vocab_size).to(device))

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ignore_index=0: mask_tokens() labels every non-masked position with 0 ([PAD])
mlm_criterion = nn.CrossEntropyLoss(ignore_index=0)

num_epochs = 5
train_loss = []
valid_loss = []
train_batch_losses = []
valid_batch_losses = []

start_epoch = load_checkpoint(model, optimizer)

for epoch in range(start_epoch, num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    train_mlm_loss, flag = train_bert(model, train_dataloader, optimizer, mlm_criterion, device, train_batch_losses, patience=50)
    print(f"Train MLM Loss: {train_mlm_loss}")
    train_loss.append(train_mlm_loss)
    
    valid_mlm_loss = test_bert(model, valid_dataloader, mlm_criterion, device, valid_batch_losses)
    print(f"Valid MLM Loss: {valid_mlm_loss}")
    valid_loss.append(valid_mlm_loss)
    
    save_checkpoint(model, optimizer, epoch)

    if flag:
        # NOTE(review): informational only; training continues to the next
        # epoch. The per-batch patience criterion is very noisy, so decide
        # on an epoch-level validation criterion before adding a `break`.
        print(f"Epoch:{epoch} early stopping triggered")
    
# Save the model
torch.save(model, 'model.pth')
torch.save(model.state_dict(), 'model_state_dict.pth')


Epoch 1/10


  0%|          | 0/69913 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Per-batch test losses are collected separately from the validation history
test_batch_losses = []
test_mlm_loss = test_bert(
    model,
    test_dataloader,
    mlm_criterion,
    device,
    test_batch_losses,
)
print(f'Test MLM Loss: {test_mlm_loss}')

In [None]:
# Persist the per-epoch loss histories as their Python list repr
for loss_file, loss_history in (('train_loss.txt', train_loss),
                                ('valid_loss.txt', valid_loss)):
    with open(loss_file, 'w') as handle:
        handle.write(str(loss_history))


In [None]:
# Persist the per-batch loss histories as their Python list repr
batch_loss_outputs = {
    'train_batch_loss.txt': train_batch_losses,
    'valid_batch_loss.txt': valid_batch_losses,
}
for out_name, batch_history in batch_loss_outputs.items():
    with open(out_name, 'w') as out_file:
        out_file.write(str(batch_history))

In [None]:
# Training and validation batches are counted on different scales (many
# training steps per validation pass), so draw them on separate axes rather
# than overlaying them on one misleading "iteration" axis.
fig, (ax_train, ax_valid) = plt.subplots(1, 2, figsize=(12, 4))

ax_train.plot(train_batch_losses, label='train_loss', color='blue')
ax_train.set(xlabel='Training batch', ylabel='Loss', title='Training loss per batch')
ax_train.legend()

ax_valid.plot(valid_batch_losses, label='valid_loss', color='red')
ax_valid.set(xlabel='Validation batch', ylabel='Loss', title='Validation loss per batch')
ax_valid.legend()

fig.tight_layout()
plt.show()