# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from tokenizers import Tokenizer
from torch.optim import AdamW
import numpy as np
import math
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os

# ======================
# 0. Configuration
# ======================
CONFIG = dict(
    # --- Model architecture ---
    d_model=512,             # embedding width and encoder d_model
    nhead=8,                 # attention heads per encoder layer
    num_layers=6,            # number of stacked encoder layers
    dim_feedforward=2048,    # inner width of each encoder feed-forward sublayer

    # --- Training ---
    batch_size=16,
    learning_rate=2e-5,
    epochs=24,
    mask_prob=0.15,          # probability an eligible token is selected for MLM

    # --- Data ---
    max_seq_len=768,
    train_ratio=0.9,         # share of examples used for training; the rest is validation
    tokenizer_path="models/movie_review_tokenizer.json",
    token_ids_path="data/IMDb/mlm_dataset/padded_token_ids.pt",
    attention_mask_path="data/IMDb/mlm_dataset/padded_attention_masks.pt",

    # --- Hardware / runtime ---
    device="cuda" if torch.cuda.is_available() else "cpu",
    gradient_accumulation_steps=4,
    mixed_precision=True,
    flash_attention=True,
    checkpoint_interval=2,   # save a checkpoint every N epochs
    checkpoint_dir="model_checkpoints",
)

print("Active device:", CONFIG["device"])
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

# The FlashAttention SDP backend and cuDNN autotuning are only touched when
# the flag is on AND a CUDA device is present.
if CONFIG["flash_attention"] and torch.cuda.is_available():
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cudnn.benchmark = True

# ======================
# 1. Load Tokenizer
# ======================
print("Loading tokenizer...")
tokenizer = Tokenizer.from_file(CONFIG["tokenizer_path"])

# Special tokens used by the masking collator, keyed by role.
SPECIAL_TOKENS = {
    "pad": "[PAD]",
    "mask": "[MASK]",
    "cls": "[CLS]",
    "sep": "[SEP]"
}

# Resolve each special token to its id; token_to_id yields None for a token
# the vocabulary does not contain.
SPECIAL_IDS = {name: tokenizer.token_to_id(tok) for name, tok in SPECIAL_TOKENS.items()}

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
_missing_special_tokens = [
    tok for name, tok in SPECIAL_TOKENS.items() if SPECIAL_IDS[name] is None
]
if _missing_special_tokens:
    raise ValueError(f"Missing special tokens! {_missing_special_tokens}")

# ======================
# 2. Dataset and Dynamic Masking
# ======================
class ReviewDataset(Dataset):
    """Map-style dataset over pre-tokenized, pre-padded reviews.

    Row ``idx`` of the two parallel arrays given to the constructor becomes one
    item: a dict holding ``input_ids`` and ``attention_mask`` tensors.
    """

    def __init__(self, token_ids, attention_masks):
        # Both parallel arrays are kept in one dict, keyed by field name.
        self.encodings = dict(input_ids=token_ids, attention_mask=attention_masks)

    def __len__(self):
        # Dataset size is the number of rows in the token-id array.
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        # torch.tensor copies the selected row of each field into a new tensor.
        return {
            field: torch.tensor(self.encodings[field][idx])
            for field in ('input_ids', 'attention_mask')
        }

class DynamicMaskingCollator:
    """Collate dataset items into an MLM batch with freshly drawn random masking.

    The mask is re-drawn on every call. Each eligible (non-special) position is
    selected with probability ``mask_prob``. Of the selected positions, about
    80% become ``[MASK]``, 10% become a random token id, and 10% keep their
    original id.

    Args:
        tokenizer: object exposing ``get_vocab()``; only used to size the
            random-token range when ``vocab_size`` is not given.
        mask_prob: selection probability for each eligible position.
        special_ids: mapping of special-token role -> id; must contain
            ``"pad"`` and ``"mask"``. Defaults to the module-level ``SPECIAL_IDS``.
        vocab_size: exclusive upper bound for random replacement ids.
            Defaults to ``len(tokenizer.get_vocab())``.
    """

    def __init__(self, tokenizer, mask_prob=0.15, special_ids=None, vocab_size=None):
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob
        self.special_ids = SPECIAL_IDS if special_ids is None else special_ids
        # Resolved once: get_vocab() returns the whole vocabulary as a dict,
        # which is wasteful to rebuild for every batch.
        self.vocab_size = len(tokenizer.get_vocab()) if vocab_size is None else vocab_size

    def __call__(self, batch):
        """Stack ``batch`` and apply dynamic MLM masking.

        Args:
            batch: list of dicts with equally shaped ``input_ids`` and
                ``attention_mask`` tensors (as returned by ``ReviewDataset``).

        Returns:
            dict with ``input_ids`` (masked copy), ``attention_mask``,
            ``labels`` (original ids at selected positions, -100 elsewhere) and
            ``pad_mask`` (True where the original id equals the pad id).
        """
        inputs = torch.stack([item['input_ids'] for item in batch])
        attention_masks = torch.stack([item['attention_mask'] for item in batch])
        labels = inputs.clone()
        # Taken before any replacement, so random ids cannot corrupt it.
        pad_mask = inputs == self.special_ids["pad"]

        # Select positions at random, never choosing a special token.
        mask = torch.rand(inputs.shape) < self.mask_prob
        for token_id in self.special_ids.values():
            mask &= inputs != token_id

        # Only selected positions contribute to the loss (-100 is ignored).
        labels[~mask] = -100

        rows, cols = mask.nonzero(as_tuple=True)
        rand_vals = torch.rand(rows.numel())

        # ~80% of selected positions -> [MASK]
        to_mask = rand_vals < 0.8
        inputs[rows[to_mask], cols[to_mask]] = self.special_ids["mask"]

        # ~10% -> random token id; the remaining ~10% keep the original id.
        to_random = (rand_vals >= 0.8) & (rand_vals < 0.9)
        # dtype must match `inputs`: indexed assignment rejects mismatched dtypes.
        random_tokens = torch.randint(
            0, self.vocab_size, (int(to_random.sum()),), dtype=inputs.dtype
        )
        inputs[rows[to_random], cols[to_random]] = random_tokens

        return {
            'input_ids': inputs,
            'attention_mask': attention_masks,
            'labels': labels,
            'pad_mask': pad_mask
        }

# Load data
# Each file is expected to hold a single tensor (`.numpy()` is called on the
# result), so weights_only=True is enough and avoids unpickling arbitrary objects.
token_ids = torch.load(CONFIG["token_ids_path"], weights_only=True).numpy()
attention_masks = torch.load(CONFIG["attention_mask_path"], weights_only=True).numpy()
full_dataset = ReviewDataset(token_ids, attention_masks)

# Split dataset
train_size = int(CONFIG["train_ratio"] * len(full_dataset))
val_size = len(full_dataset) - train_size
# A seeded generator makes the split identical on every run, so resuming from a
# checkpoint does not mix validation examples into training.
split_generator = torch.Generator().manual_seed(CONFIG.get("split_seed", 42))
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Create DataLoaders with dynamic masking.
# NOTE: validation uses the same random collator, so validation masks (and hence
# val metrics) differ from epoch to epoch.
mask_collator = DynamicMaskingCollator(tokenizer, mask_prob=CONFIG["mask_prob"])
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    collate_fn=mask_collator,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG["batch_size"],
    collate_fn=mask_collator
)

# ======================
# 3. Model Architecture
# ======================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to token embeddings.

    Table column ``2i`` holds ``sin(pos / 10000^(2i/d_model))`` and column
    ``2i+1`` the matching cosine.

    Args:
        d_model: embedding width. Odd widths are supported; the last sine
            column then has no cosine partner.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, max_len=768):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even columns
        # Odd columns: for odd d_model there is one fewer cosine than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: part of state_dict and moved by .to(), not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is expected to have the sequence on dim 1 and ``d_model`` last,
        e.g. (batch, seq_len, d_model).

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :]

class MLMTransformer(nn.Module):
    """Transformer encoder with a linear token-prediction head for masked LM.

    Layer sizes come from ``CONFIG`` (``d_model``, ``nhead``, ``num_layers``,
    ``dim_feedforward``, ``max_seq_len``).

    Args:
        vocab_size: number of token ids; sizes both the embedding table and
            the output logits.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, CONFIG["d_model"])
        # Size the position table from the configured sequence length rather
        # than relying on PositionalEncoding's default happening to match it.
        self.pos_encoder = PositionalEncoding(CONFIG["d_model"], max_len=CONFIG["max_seq_len"])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=CONFIG["d_model"],
            nhead=CONFIG["nhead"],
            dim_feedforward=CONFIG["dim_feedforward"],
            batch_first=True,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=CONFIG["num_layers"])
        self.fc = nn.Linear(CONFIG["d_model"], vocab_size)

    def forward(self, x, mask=None):
        """Return per-position vocabulary logits for token ids ``x``.

        With ``batch_first=True``, ``x`` is presumably (batch, seq_len) and the
        result (batch, seq_len, vocab_size). ``mask`` is passed as
        ``src_key_padding_mask``; True entries mark padding to ignore.
        """
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer(x, src_key_padding_mask=mask)
        return self.fc(x)

# ======================
# 4. Training Setup
# ======================
model = MLMTransformer(len(tokenizer.get_vocab())).to(CONFIG["device"])
optimizer = AdamW(model.parameters(), lr=CONFIG["learning_rate"])
# ignore_index=-100 matches the label value the masking collator writes for
# unselected positions (it is also CrossEntropyLoss's default).
criterion = nn.CrossEntropyLoss(ignore_index=-100)
# float16 autocast and gradient scaling target CUDA; on CPU both are switched
# off explicitly instead of depending on version-specific fallbacks/warnings.
use_amp = CONFIG["mixed_precision"] and CONFIG["device"] == "cuda"
scaler = torch.amp.GradScaler(enabled=use_amp)
# Reused as a context manager on every training step.
autocast = torch.amp.autocast(CONFIG["device"], dtype=torch.float16, enabled=use_amp)

def compute_metrics(preds, labels):
    """Return token accuracy over positions whose label is not -100.

    Args:
        preds: tensor of predicted token ids, same shape as ``labels``.
        labels: tensor of target ids; ``-100`` marks positions to ignore.

    Returns:
        Fraction of scored positions where ``preds`` equals ``labels``, or
        ``0.0`` when no position is scored.
    """
    scored = labels != -100
    n_scored = int(scored.sum())
    if n_scored == 0:
        return 0.0
    # Counted on the tensors' own device: no full-batch copy to CPU and no
    # per-batch sklearn call; the ratio equals accuracy_score's.
    n_correct = int((preds[scored] == labels[scored]).sum())
    return n_correct / n_scored

def save_checkpoint(epoch):
    """Save model, optimizer and grad-scaler state for ``epoch``.

    Writes ``<checkpoint_dir>/checkpoint_epoch_<epoch>.pt`` using the
    module-level ``model``, ``optimizer`` and ``scaler``.
    """
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        # Needed to resume mixed-precision training with the same loss scale.
        'scaler_state': scaler.state_dict(),
    }, os.path.join(CONFIG['checkpoint_dir'], f"checkpoint_epoch_{epoch}.pt"))

# ======================
# 5. Training Loop
# ======================
train_losses = []
val_losses = []
accuracies = []

accum_steps = CONFIG["gradient_accumulation_steps"]
num_train_batches = len(train_loader)

for epoch in range(CONFIG["epochs"]):
    # Training
    model.train()
    total_loss = 0.0
    # Start every epoch from clean gradients.
    optimizer.zero_grad()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")

    for step, batch in enumerate(progress_bar):
        inputs = batch['input_ids'].to(CONFIG["device"])
        labels = batch['labels'].to(CONFIG["device"])
        pad_mask = batch['pad_mask'].to(CONFIG["device"])

        with autocast:
            outputs = model(inputs, mask=pad_mask)
            # outputs.size(-1) is the vocab size; calling tokenizer.get_vocab()
            # here would rebuild the vocab dict on every step.
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))

        # Divide by the size of the current accumulation window (the last one
        # in an epoch may be shorter) so each optimizer step uses the mean
        # gradient of the batches it accumulated.
        window_start = (step // accum_steps) * accum_steps
        window_size = min(accum_steps, num_train_batches - window_start)
        scaler.scale(loss / window_size).backward()

        # Step at the end of every window, including a trailing partial one,
        # so no gradients carry over into the next epoch.
        if (step + 1) % accum_steps == 0 or (step + 1) == num_train_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Log the unscaled loss so train and validation losses are comparable.
        loss_value = loss.item()
        total_loss += loss_value
        progress_bar.set_postfix({'loss': loss_value})

    # Validation
    model.eval()
    val_loss_sum = 0.0        # per-batch mean loss * scored tokens, summed
    val_acc_weighted = 0.0    # per-batch accuracy * scored tokens, summed
    val_scored = 0            # total label positions != -100
    val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")

    with torch.no_grad():
        for batch in val_progress:
            inputs = batch['input_ids'].to(CONFIG["device"])
            labels = batch['labels'].to(CONFIG["device"])
            pad_mask = batch['pad_mask'].to(CONFIG["device"])

            outputs = model(inputs, mask=pad_mask)
            n_scored = int((labels != -100).sum())
            if n_scored == 0:
                # Mean cross-entropy over zero targets is NaN; nothing to score.
                continue
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            preds = torch.argmax(outputs, dim=-1)
            acc = compute_metrics(preds, labels)

            # Weight by scored-token count so epoch metrics are token-level
            # averages rather than averages of per-batch averages.
            loss_value = loss.item()
            val_loss_sum += loss_value * n_scored
            val_acc_weighted += acc * n_scored
            val_scored += n_scored
            val_progress.set_postfix({'val_loss': loss_value, 'acc': acc})

    # Save checkpoint
    if (epoch + 1) % CONFIG["checkpoint_interval"] == 0:
        save_checkpoint(epoch + 1)

    # Log metrics
    avg_train_loss = total_loss / num_train_batches if num_train_batches else float("nan")
    avg_val_loss = val_loss_sum / val_scored if val_scored else float("nan")
    avg_val_acc = val_acc_weighted / val_scored if val_scored else 0.0

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    accuracies.append(avg_val_acc)

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}")
    print(f"Val Accuracy: {avg_val_acc:.4f}\n")

print("Training complete!")