In [None]:
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tokenizers import Tokenizer
from torch.nn.functional import scaled_dot_product_attention  # For flash attention
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

# =======================
# Load CONFIG
# =======================
CONFIG = dict(
    # Encoder architecture
    d_model=512,
    nhead=8,
    num_layers=6,
    dim_feedforward=2048,
    # Optimisation settings
    batch_size=16,
    learning_rate=2e-5,
    epochs=24,
    max_seq_len=768,  # also the size of the learned position-embedding table
    device="cuda" if torch.cuda.is_available() else "cpu",
    gradient_accumulation_steps=4,
    mixed_precision=True,
    flash_attention=True,
    # Checkpointing: a snapshot is written every `checkpoint_interval` epochs
    checkpoint_interval=2,
    checkpoint_dir="model_FT_checkpoints",
    # Fraction of the samples used for training; the remainder is validation
    train_ratio=0.9,
    # Input artefacts
    tokenizer_path="movie_review_tokenizer.json",
    token_ids_path="padded_token_ids_FT.pt",
    attention_mask_path="padded_attention_masks_FT.pt",
    sentiment_labels_path="sentiment_labels.pt",
    mlm_checkpoint_path="MLM_checkpoint_epoch_24.pt",
)

# Create the checkpoint directory up front so later torch.save calls cannot
# fail on a missing folder; exist_ok makes re-running this cell harmless.
import os
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

# =======================
# 1. Dataset Loading (Updated for your tokenizer)
# =======================
class SentimentDataset(Dataset):
    """Map-style dataset over pre-tokenized reviews and their sentiment labels.

    Args:
        token_ids: indexable collection of token-id rows, one per sample.
        attention_mask: indexable collection of mask rows aligned with ``token_ids``.
        labels: indexable collection of per-sample label tensors; its length
            defines the dataset length.
    """

    def __init__(self, token_ids, attention_mask, labels):
        self.token_ids, self.attention_mask, self.labels = token_ids, attention_mask, labels

    def __len__(self):
        # The label container is the source of truth for the sample count.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one sample as a dict with keys input_ids / attention_mask / labels."""
        ids_row = self.token_ids[idx]
        mask_row = self.attention_mask[idx]
        # Copy the label out of the backing storage and cast it to int64.
        label = self.labels[idx].clone().detach().long()
        sample = {
            "input_ids": ids_row,
            "attention_mask": mask_row,
            "labels": label,
        }
        return sample

# Load pre-tokenized data (update paths as needed).
# Everything is kept on the CPU here: the training / evaluation loops move each
# batch to CONFIG["device"] themselves, so parking the full dataset on the GPU
# would only cost accelerator memory.
token_ids = torch.load(CONFIG["token_ids_path"], map_location="cpu")
attention_mask = torch.load(CONFIG["attention_mask_path"], map_location="cpu")
labels = torch.as_tensor(torch.load(CONFIG["sentiment_labels_path"], map_location="cpu"))

# Fail early on misaligned inputs instead of producing mismatched samples later.
if not (len(token_ids) == len(attention_mask) == len(labels)):
    raise ValueError(
        "token ids, attention masks and labels must contain the same number of samples: "
        f"{len(token_ids)} / {len(attention_mask)} / {len(labels)}"
    )

# Split into train/val.
# A seeded permutation is used rather than a prefix/suffix slice so that a file
# stored in label order cannot yield a validation set dominated by one class,
# while the split remains reproducible across kernel restarts.
split_generator = torch.Generator().manual_seed(CONFIG.get("split_seed", 42))
shuffled_idx = torch.randperm(len(labels), generator=split_generator)
train_size = int(len(labels) * CONFIG["train_ratio"])
train_idx, val_idx = shuffled_idx[:train_size], shuffled_idx[train_size:]

train_dataset = SentimentDataset(token_ids[train_idx], attention_mask[train_idx], labels[train_idx])
val_dataset = SentimentDataset(token_ids[val_idx], attention_mask[val_idx], labels[val_idx])
train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"])

# =======================
# 2. Hybrid Head (Adjusted for d_model=512)
# =======================
class HybridClassificationHead(nn.Module):
    """Classification head combining mean pooling with a CNN + max-pool branch.

    The per-token embeddings are summarised twice: a mean over the sequence
    axis, and a max over the sequence axis of a ReLU-activated 1D convolution
    (kernel size 3). Both summaries are concatenated and fed to a two-layer MLP
    with dropout.

    Args:
        hidden_size: width of the incoming token embeddings.
        num_classes: number of output logits.
    """

    def __init__(self, hidden_size=CONFIG["d_model"], num_classes=2):
        super().__init__()
        self.cnn = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, token_embeddings, attention_mask=None):
        """Compute class logits from token embeddings.

        Args:
            token_embeddings: tensor of shape [B, T, H].
            attention_mask: optional [B, T] tensor, non-zero for real tokens and
                zero for padding. When given, padding positions are excluded
                from both pooled summaries. When omitted, every position is
                pooled, which is the behaviour existing callers rely on.

        Returns:
            Tensor of shape [B, num_classes] with unnormalised logits.
        """
        if attention_mask is None:
            keep = None
            conv_in = token_embeddings
            mean_pool = token_embeddings.mean(dim=1)  # [B, H]
        else:
            keep = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # [B, T, 1]
            # Zero padded rows so they neither enter the mean nor bleed into
            # neighbouring real tokens through the width-3 convolution.
            conv_in = token_embeddings * keep
            # clamp guards against a fully padded row (division by zero).
            mean_pool = conv_in.sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)  # [B, H]

        cnn_out = self.relu(self.cnn(conv_in.transpose(1, 2)))  # [B, H, T]
        if keep is not None:
            # Post-ReLU activations are >= 0, so forcing padded positions to 0
            # can never make them win the max over the real positions.
            cnn_out = cnn_out * keep.transpose(1, 2)
        max_pool = self.pool(cnn_out).squeeze(-1)  # [B, H]

        concat = torch.cat([mean_pool, max_pool], dim=-1)  # [B, 2*H]
        return self.classifier(concat)

# =======================
# 3. Custom MLM Model (Updated for your architecture)
# =======================
class CustomMLM(nn.Module):
    """Transformer encoder backbone: token + learned position embeddings
    followed by a stack of ``nn.TransformerEncoderLayer`` blocks.

    Args:
        config: dict providing "d_model", "nhead", "num_layers",
            "dim_feedforward" and "max_seq_len", plus either "vocab_size" or a
            "tokenizer_path" from which the vocabulary size is read.
    """

    def __init__(self, config):
        super().__init__()
        # Load tokenizer if vocab_size not in config
        if "vocab_size" not in config:
            tokenizer = Tokenizer.from_file(config["tokenizer_path"])
            # Written back into the caller's dict, so the resolved vocabulary
            # size is visible to anything else that reads the same config.
            config["vocab_size"] = tokenizer.get_vocab_size()

        self.embedding = nn.Embedding(config["vocab_size"], config["d_model"])
        self.pos_encoder = nn.Embedding(config["max_seq_len"], config["d_model"])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config["d_model"],
            nhead=config["nhead"],
            dim_feedforward=config["dim_feedforward"],
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config["num_layers"])
        self.config = config

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids.

        Args:
            input_ids: integer tensor of shape [B, T].
            attention_mask: [B, T] tensor, truthy for real tokens and falsy for
                padding.

        Returns:
            Tensor of shape [B, T, d_model] with the encoder outputs.

        Raises:
            ValueError: if T exceeds the size of the position-embedding table.
        """
        seq_len = input_ids.size(1)
        max_len = self.pos_encoder.num_embeddings
        if seq_len > max_len:
            # Without this guard the position lookup fails with an opaque
            # index error (or a device-side assert on CUDA).
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len} of the position embedding"
            )

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        embeddings = self.embedding(input_ids) + self.pos_encoder(positions)

        # src_key_padding_mask expects True at positions to ignore, i.e. the
        # inverse of the attention mask. The "flash_attention" config flag does
        # not alter this call, so the module no longer reads the global CONFIG.
        padding_mask = ~attention_mask.bool()
        return self.encoder(embeddings, src_key_padding_mask=padding_mask)

# Build the encoder backbone and place it on the configured device
mlm_model = CustomMLM(CONFIG)
mlm_model = mlm_model.to(CONFIG["device"])

# Read the pretraining checkpoint, mapping its tensors onto the same device
checkpoint = torch.load(CONFIG["mlm_checkpoint_path"], map_location=CONFIG["device"])
# Two layouts are accepted: a dict that wraps the weights under "model_state",
# or the bare state dict itself.
state_dict = checkpoint["model_state"] if "model_state" in checkpoint else checkpoint

# Key renaming function
def fix_key_names(state_dict):
    """Return a copy of ``state_dict`` with keys renamed for ``CustomMLM``.

    Renaming rules:
      * keys starting with "transformer.layers" get that text replaced by
        "encoder.layers";
      * "pos_encoder.pe" becomes "pos_encoder.weight";
      * "fc.weight" and "fc.bias" are dropped;
      * all other keys are copied through unchanged.

    The input mapping is not modified; values are passed through as-is.
    """
    dropped = ("fc.weight", "fc.bias")
    old_prefix, new_prefix = "transformer.layers", "encoder.layers"

    renamed = {}
    for name, value in state_dict.items():
        if name in dropped:
            # Skip classifier weights if not needed
            continue
        if name.startswith(old_prefix):
            name = name.replace(old_prefix, new_prefix)
        elif name == "pos_encoder.pe":
            name = "pos_encoder.weight"  # Positional encoding
        renamed[name] = value
    return renamed

# Load with renamed keys.
# strict=False tolerates key differences between the checkpoint and this model,
# which also means a naming mismatch would otherwise go unnoticed and leave the
# affected weights at their random initialisation. Report what did not line up.
load_result = mlm_model.load_state_dict(fix_key_names(state_dict), strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model parameter(s) were NOT found in the checkpoint "
          f"and keep their initial values: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint entr(y/ies) were ignored: "
          f"{load_result.unexpected_keys}")

# =======================
# 4. Model Wrapper (Unfreeze last 2 layers)
# =======================
class SentimentClassifier(nn.Module):
    """Encoder backbone plus ``HybridClassificationHead`` for sentiment logits.

    Only the last two encoder layers and the head are trainable; every other
    backbone parameter is frozen.

    Args:
        mlm_model: encoder module exposing ``encoder.layers`` and callable as
            ``mlm_model(input_ids, attention_mask)``.
    """

    def __init__(self, mlm_model):
        super().__init__()
        self.mlm = mlm_model
        # Size the head from the wrapped encoder's own config when it has one,
        # so the two cannot disagree; fall back to the global CONFIG otherwise.
        hidden_size = getattr(mlm_model, "config", CONFIG)["d_model"]
        self.head = HybridClassificationHead(hidden_size=hidden_size)

        # Freeze all except last 2 layers and classifier
        self.mlm.requires_grad_(False)
        for layer in self.mlm.encoder.layers[-2:]:
            layer.requires_grad_(True)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for a batch of token ids and masks."""
        embeddings = self.mlm(input_ids, attention_mask)
        return self.head(embeddings)

model = SentimentClassifier(mlm_model)
model = model.to(CONFIG["device"])

# =======================
# 5. Training Setup (Optimizer, AMP, Gradient Clipping)
# =======================
# Only parameters left trainable by the classifier wrapper are optimised.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=CONFIG["learning_rate"])
criterion = nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler(enabled=CONFIG["mixed_precision"])

# The scheduler counts optimizer updates, i.e. batches divided by the
# gradient-accumulation factor.
total_batches = len(train_loader) * CONFIG["epochs"]
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=total_batches // CONFIG["gradient_accumulation_steps"],
)

# =======================
# 6. Training Loop (Track Validation Accuracy)
# =======================
best_val_acc = 0.0
accum_steps = CONFIG["gradient_accumulation_steps"]

for epoch in range(CONFIG["epochs"]):
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)
    # Start each epoch from clean gradients.
    optimizer.zero_grad()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Train]")

    for step, batch in enumerate(progress_bar):
        # Batch-local names: the module-level `attention_mask` / `labels`
        # tensors from the data-loading section must not be overwritten,
        # otherwise re-running earlier cells operates on a single batch.
        batch_ids = batch["input_ids"].to(CONFIG["device"])
        batch_mask = batch["attention_mask"].to(CONFIG["device"])
        batch_labels = batch["labels"].to(CONFIG["device"])

        with torch.amp.autocast(device_type=CONFIG["device"], enabled=CONFIG["mixed_precision"]):
            logits = model(batch_ids, batch_mask)
            batch_loss = criterion(logits, batch_labels)

        # Divide so the accumulated gradient is an average over the group.
        scaler.scale(batch_loss / accum_steps).backward()

        # Update after every full accumulation group, and also on the last
        # batch so a trailing partial group is applied instead of leaking into
        # the next epoch (or being lost after the final one).
        if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so max_norm applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        # Track and display the un-divided per-batch loss.
        total_loss += batch_loss.item()
        progress_bar.set_postfix({"loss": batch_loss.item()})

    print(f"Epoch {epoch+1} — Train Loss: {total_loss / max(num_batches, 1):.4f}")

    # Validation
    model.eval()
    val_preds, val_labels = [], []
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Val]"):
            batch_ids = batch["input_ids"].to(CONFIG["device"])
            batch_mask = batch["attention_mask"].to(CONFIG["device"])
            batch_labels = batch["labels"].to(CONFIG["device"])

            logits = model(batch_ids, batch_mask)
            val_loss += criterion(logits, batch_labels).item()

            val_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
            val_labels.extend(batch_labels.cpu().tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    print(f"Epoch {epoch+1} — Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.4f}")

    # Save best model and periodic checkpoints
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), f"{CONFIG['checkpoint_dir']}/best_model.pt")
    if (epoch + 1) % CONFIG["checkpoint_interval"] == 0:
        torch.save(model.state_dict(), f"{CONFIG['checkpoint_dir']}/epoch_{epoch+1}.pt")



In [None]:
# Test evaluation

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that
            returns per-class logits; it is switched to eval mode.
        test_loader: iterable of batches, each a dict with the keys
            "input_ids", "attention_mask" and "labels".

    Returns:
        dict with "accuracy", "precision", "recall" and "f1". Precision, recall
        and F1 use scikit-learn's ``average='weighted'``.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(CONFIG["device"])
            attention_mask = batch["attention_mask"].to(CONFIG["device"])
            labels = batch["labels"].to(CONFIG["device"])

            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=-1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # precision_score / recall_score / f1_score come from sklearn.metrics and
    # are imported in the notebook's import cell alongside accuracy_score.
    return {
        "accuracy": accuracy_score(all_labels, all_preds),
        "precision": precision_score(all_labels, all_preds, average='weighted'),
        "recall": recall_score(all_labels, all_preds, average='weighted'),
        "f1": f1_score(all_labels, all_preds, average='weighted'),
    }

# Assuming you have a test dataset (modify as needed)
test_dataset = SentimentDataset(
    torch.load("padded_token_ids_test.pt"),  # Replace with your test data paths
    torch.load("padded_attention_masks_test.pt"),
    torch.load("sentiment_labels_test.pt")
)
test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"])

# Load the best checkpoint written by the training loop. Checkpoints are saved
# inside CONFIG["checkpoint_dir"], and "best_model.pt" is the one selected by
# validation accuracy; point this at an "epoch_N.pt" file in the same folder to
# evaluate a specific epoch instead.
model_path = f"{CONFIG['checkpoint_dir']}/best_model.pt"
model = SentimentClassifier(mlm_model).to(CONFIG["device"])
model.load_state_dict(torch.load(model_path, map_location=torch.device(CONFIG["device"])))
model.eval()

# Evaluate
metrics = evaluate_model(model, test_loader)
print("\nTest Metrics:")
for display_name, metric_key in (
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1 Score", "f1"),
):
    print(f"{display_name}: {metrics[metric_key]:.4f}")