# SBERT Encoder w/ MLP MoE for Multiclass Film Genre Classification

Short test poses a unique challenge for many NLP models due to limited context length while retaining the difficulties of the nuances of structured languages. This model aims to use several methods to address each of these concerns individually. First, a trained autoencoder, SBERT, will be used to create a new hidden representation for the short text input sequences. Then for the diverse classification tasks several expert models will be trained to selectively specialize in specific genre embeddings. Simple MLP models are sufficient for mapping the smaller embeddings to output classes while keeping the full model simple to train.

In [4]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning.pytorch as pl
from lightning.pytorch.callbacks import BaseFinetuning
from torchmetrics.classification import MultilabelF1Score
from sentence_transformers import SentenceTransformer

## Models

In [5]:
class Expert(nn.Module):
    """
    A simple Feed-Forward Network acting as a single 'Expert'.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

class TopKRouter(nn.Module):
    """
    Gating Network that selects the top-k experts for each input.
    """
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts)
        self.top_k = top_k

    def forward(self, x):
        logits = self.gate(x)
        top_k_vals, top_k_indices = torch.topk(logits, self.top_k, dim=1)
        router_probs = F.softmax(top_k_vals, dim=1)
        return router_probs, top_k_indices, logits

class MoEClassifier(nn.Module):
    """
    The Mixture of Experts Classification Head.
    """
    def __init__(self, input_dim, num_classes, num_experts=4, top_k=2, expert_hidden_dim=128):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        self.router = TopKRouter(input_dim, num_experts, top_k=top_k)
        self.experts = nn.ModuleList([
            Expert(input_dim, expert_hidden_dim, num_classes) 
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size = x.size(0)
        router_probs, expert_indices, router_logits = self.router(x)
        
        final_output = torch.zeros(batch_size, self.experts[0].net[-1].out_features).to(x.device)
        
        for k in range(self.top_k):
            selected_experts = expert_indices[:, k]
            gate_weight = router_probs[:, k].unsqueeze(1)
            
            for expert_idx in range(self.num_experts):
                mask = (selected_experts == expert_idx)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[expert_idx](expert_input)
                    final_output[mask] += gate_weight[mask] * expert_output
                    
        return final_output, router_logits

class SBERT_MoE_Model(nn.Module):
    """
    Full End-to-End Model: SBERT Encoder + MoE Classifier
    """
    def __init__(self, model_name='all-MiniLM-L6-v2', num_classes=5, num_experts=8, top_k=2):
        super().__init__()
        # 1. Initialize SBERT
        self.sbert = SentenceTransformer(model_name)
        
        # 2. Capture device choice
        target_device = self.sbert.device 
        
        # 3. Initialize MoE head
        embedding_dim = self.sbert.get_sentence_embedding_dimension()
        self.moe_head = MoEClassifier(
            input_dim=embedding_dim,
            num_classes=num_classes,
            num_experts=num_experts,
            top_k=top_k
        )
        
        # 4. Force MoE head to same device
        self.moe_head.to(target_device)

    def forward(self, text_input):
        # Encode and Detach to treat embeddings as fixed features initially
        features = self.sbert.encode(text_input, convert_to_tensor=True)
        
        # Important: Clone and detach to avoid "Inference Tensor" errors during backprop
        # If unfreezing later, the optimizer handles the graph connection, 
        # but for the initial forward pass logic, this is safe.
        features = features.clone().detach() 
        
        logits, router_logits = self.moe_head(features)
        return logits, router_logits

## Training

In [6]:
from torchmetrics.classification import MulticlassAccuracy

class MoE_LightningModule(pl.LightningModule):
    def __init__(self, model, num_classes, num_experts, learning_rate=1e-3, aux_loss_weight=0.01):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        
        self.backbone = model.sbert
        self.head = model.moe_head
        self.learning_rate = learning_rate
        
        # --- FIX 1: Store these for the loss calculation ---
        self.num_experts = num_experts
        self.aux_loss_weight = aux_loss_weight
        
        # Loss & Metric
        self.criterion = nn.CrossEntropyLoss()
        self.val_acc = MulticlassAccuracy(num_classes=num_classes, average='micro')

    def forward(self, x):
        # Standard forward pass
        features = self.backbone.tokenize(x)
        features = {k: v.to(self.device) for k, v in features.items()}
        out = self.backbone(features)
        embeddings = out['sentence_embedding']
        return self.head(embeddings)

    # --- FIX 2: Re-add the Load Balancing Logic ---
    def _compute_load_balancing_loss(self, router_logits):
        """
        Encourages the router to send equal traffic to all experts.
        """
        # 1. Get probability of selecting each expert
        probs = F.softmax(router_logits, dim=1) # [batch_size, num_experts]
        
        # 2. Mean probability per expert across the batch
        mean_probs = probs.mean(dim=0) # [num_experts]
        
        # 3. Variance-like penalty (sum of squares)
        # If traffic is perfectly even, this value is minimized.
        aux_loss = (mean_probs ** 2).sum() * self.num_experts
        
        return aux_loss

    def training_step(self, batch, batch_idx):
        texts, targets = batch
        logits, router_logits = self(texts)
        
        # 1. Main Classification Loss
        cls_loss = self.criterion(logits, targets)
        
        # --- FIX 3: Add Aux Loss to Total Loss ---
        aux_loss = self._compute_load_balancing_loss(router_logits)
        
        total_loss = cls_loss + (self.aux_loss_weight * aux_loss)
        
        # Log both for debugging
        self.log("train_loss", total_loss)
        self.log("train_cls_loss", cls_loss)
        self.log("train_aux_loss", aux_loss)
        
        return total_loss

    def validation_step(self, batch, batch_idx):
        texts, targets = batch
        logits, _ = self(texts)
        loss = self.criterion(logits, targets)
        self.val_acc(logits, targets)
        
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

## Movie Genre Execution

In [7]:
import torch
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader, Dataset

class IMDBDataset(Dataset):
    def __init__(self, data_dir_path, filename, class_names, text_col='description', label_col='csv_genre'):
        """
        Args:
            label_col (str): Now defaults to 'csv_genre' (single label)
        """
        self.data_path = Path(data_dir_path) / filename
        
        if not self.data_path.exists():
            raise FileNotFoundError(f"File not found at: {self.data_path.resolve()}")
            
        print(f"Loading data from {self.data_path.name}...")
        self.df = pd.read_csv(self.data_path)
        
        # 1. Process Text
        self.texts = self.df[text_col].fillna("").astype(str).tolist()
        
        # 2. Process Labels (Single Integer Encoding)
        self.class_to_idx = {cls: i for i, cls in enumerate(class_names)}
        self.labels = []
        
        unseen_genres = set()
        
        # Iterate directly over the column
        for genre_raw in self.df[label_col]:
            # Clean whitespace
            genre_str = str(genre_raw).strip()
            
            if genre_str in self.class_to_idx:
                # Store just the INTEGER index (e.g., 5)
                self.labels.append(self.class_to_idx[genre_str])
            else:
                # Handle unknown classes (optional: map to -1 or a generic 'Other')
                # For now, we'll just warn and append a dummy index (e.g., 0) 
                # or raise an error depending on your preference.
                unseen_genres.add(genre_str)
                self.labels.append(0) # Defaulting to class 0 (Risky, better to filter data first)

        if unseen_genres:
            print(f"⚠️  WARNING: Found {len(unseen_genres)} labels not in class list.")
            print(f"   Examples: {list(unseen_genres)[:5]}")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # RETURN TYPE CHANGE:
        # Torch expects LongTensor (int64) for CrossEntropyLoss targets
        # We do NOT wrap it in a list. Just the scalar tensor.
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.long)
        
        return self.texts[idx], label_tensor

    @staticmethod
    def discover_classes(data_dir_path, filename, label_col='csv_genre'):
        """
        Scans for unique SINGLE categories.
        """
        path = Path(data_dir_path) / filename
        df = pd.read_csv(path)
        
        # No splitting or exploding needed for single-label
        genres = df[label_col].dropna().astype(str).str.strip().unique()
        
        return sorted(list(genres))


DATA_DIR = Path('../data/imdb_arh_trimmed')
CLASS_NAMES = IMDBDataset.discover_classes(DATA_DIR, 'imdb_arh_train.csv')
NUM_WORKERS = 4

tr_ds = IMDBDataset(data_dir_path=DATA_DIR, filename='imdb_arh_train.csv', class_names=CLASS_NAMES)
va_ds = IMDBDataset(data_dir_path=DATA_DIR, filename='imdb_arh_val.csv', class_names=CLASS_NAMES)
te_ds = IMDBDataset(data_dir_path=DATA_DIR, filename='imdb_arh_test.csv', class_names=CLASS_NAMES)

tr_loader = DataLoader(tr_ds, batch_size=32, num_workers=NUM_WORKERS, shuffle=True)
va_loader = DataLoader(va_ds, batch_size=64, num_workers=NUM_WORKERS, shuffle=False)
te_loader = DataLoader(te_ds, batch_size=64, num_workers=NUM_WORKERS, shuffle=False)

print("--- Sanity Checks ---")
# 1. Check if classes were discovered
print(f"Num classes: {len(CLASS_NAMES)}")
print(f"First 5 classes: {CLASS_NAMES[:5]}")

# 2. Check Dataset lengths
print(f"Train size: {len(tr_ds)}")
print(f"Val size:   {len(va_ds)}")

# 3. Inspect a single sample
sample_text, sample_label = tr_ds[0]

print(f"\nSample Text Type: {type(sample_text)}") # Should be <class 'str'>
print(f"Sample Label Type: {type(sample_label)}") # Should be <class 'torch.Tensor'>
print(f"Sample Label Shape: {sample_label.shape}") # Should be torch.Size([num_classes])
print(f"Sample Label Dtype: {sample_label.dtype}") # Should be torch.float32

print("\n--- DataLoader Batch Check ---")

# Get one batch from the training loader
batch_texts, batch_labels = next(iter(tr_loader))

print(f"Batch Text Length: {len(batch_texts)}") # Should match batch_size (32)
print(f"Batch Label Shape: {batch_labels.shape}") # Should be [32, num_classes]

# Check if shuffling works (Compare first item of two different iterator calls)
batch_texts_2, _ = next(iter(tr_loader))
if batch_texts[0] != batch_texts_2[0]:
    print("Shuffle is working (First elements differ).")
else:
    print("Warning: Shuffle might not be working or dataset is very small.")

Loading data from imdb_arh_train.csv...
Loading data from imdb_arh_val.csv...
Loading data from imdb_arh_test.csv...
--- Sanity Checks ---
Num classes: 3
First 5 classes: ['action', 'horror', 'romance']
Train size: 67772
Val size:   14523

Sample Text Type: <class 'str'>
Sample Label Type: <class 'torch.Tensor'>
Sample Label Shape: torch.Size([])
Sample Label Dtype: torch.int64

--- DataLoader Batch Check ---
Batch Text Length: 32
Batch Label Shape: torch.Size([32])
Shuffle is working (First elements differ).


In [None]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from dotenv import load_dotenv
import wandb

load_dotenv()
wandb.login(key=os.getenv("WANDB_API_KEY"))

# Initialize Logger
wandb_logger = WandbLogger(
    project="MovieGenreMulticlassMoE", 
    name="testrun",
    log_model=False,
)
wandb_logger.experiment.config.update({"class_names": CLASS_NAMES})

# --- Initialize Models ---
num_experts = 10
num_classes = len(CLASS_NAMES)
backbone = SBERT_MoE_Model(
    num_classes=num_classes, 
    num_experts=num_experts, 
    top_k=2
)

# Wrap in Lightning
pl_module = MoE_LightningModule(
    model=backbone, 
    num_classes=num_classes, 
    num_experts=num_experts,
)

# Callback to stop training early
early_stop_cb = EarlyStopping(
    monitor="val_loss",  # Watch the validation loss
    min_delta=0.00,      # Minimum change to qualify as an improvement
    patience=3,          # Stop if no improvement for 3 epochs in a row
    verbose=True,
    mode="min"
)

# Callback to save best model based on validation loss
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="moe-film-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True
)

# --- Train ---
print("--- Starting Training ---")
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    callbacks=[early_stop_cb, checkpoint_cb],
    enable_progress_bar=True,
    log_every_n_steps=1,
    logger=wandb_logger
)


trainer.fit(
    pl_module, 
    train_dataloaders=tr_loader, 
    val_dataloaders=va_loader,
)

wandb.finish()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/dodogama/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mconradli90[0m ([33mconradli90-duke-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/dodogama/anaconda3/envs/nlp-py310/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /home/dodogama/code/FilmGenreClassification/01-bert-moe/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | backbone  | SentenceTransformer | 22.7 M | train
1 | head      | MoEClassifier       | 500 K  | train
2 | criterion | CrossEntropyLoss    | 0      | train
3 |

--- Starting Training ---


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 1.093


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.092


Validation: |          | 0/? [00:00<?, ?it/s]

### Inference

In [None]:
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall, MulticlassAccuracy


def map_logits_to_labels_multiclass(logits, class_names):
    """
    Adapted for Multi-Class: Picks the single highest probability class.
    """
    # 1. Softmax to get probabilities that sum to 1
    probs = torch.softmax(logits, dim=1).cpu()
    
    # 2. Argmax to get the index of the winner
    pred_indices = torch.argmax(probs, dim=1)
    
    # 3. Map indices to names
    batch_results = [class_names[idx.item()] for idx in pred_indices]
    
    return batch_results, probs

# A. Define Path
CHECKPOINT_PATH = "checkpoints/moe-film-epoch=04-val_loss=0.14.ckpt" # Update this!

# B. Re-initialize Architecture (Ensure num_classes matches your new dataset)
num_experts = 10
num_classes = len(CLASS_NAMES)
backbone_skeleton = SBERT_MoE_Model(
    num_classes=num_classes, 
    num_experts=num_experts, 
    top_k=2
)

# C. Load Weights
print(f"Loading model from: {CHECKPOINT_PATH}")
# Note: If you had to do the manual state_dict filtering discussed earlier, 
# load that specific model object here instead of using load_from_checkpoint.
loaded_model = MoE_LightningModule.load_from_checkpoint(
    CHECKPOINT_PATH,
    model=backbone_skeleton,
    num_classes=num_classes,
    num_experts=num_experts,
)

# D. Prepare
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.to(device)
loaded_model.eval()

# ==========================================
# 2. INITIALIZE GLOBAL METRICS (MULTICLASS)
# ==========================================
metrics = MetricCollection({
    'accuracy': MulticlassAccuracy(num_classes=num_classes, average='micro'),
    'f1': MulticlassF1Score(num_classes=num_classes, average='micro'),
    'precision': MulticlassPrecision(num_classes=num_classes, average='micro'),
    'recall': MulticlassRecall(num_classes=num_classes, average='micro')
}).to(device)

# ==========================================
# 3. INFERENCE LOOP
# ==========================================
print("\n--- Running Inference on Full Test Loader (Multi-Class) ---")

with torch.no_grad():
    for batch_idx, batch in enumerate(te_loader):
        
        # Unpack: Labels are now 1D Integer Tensors (e.g., [5, 2, 0...])
        texts, true_label_indices = batch
        
        # Move to device
        true_label_indices = true_label_indices.to(device)
        
        # Forward Pass
        logits, _ = loaded_model(texts)
        
        # --- A. ACCUMULATE METRICS ---
        # Torchmetrics handles Logits vs Int Indices automatically
        metrics.update(logits, true_label_indices)

        # --- B. VISUALIZE ---
        if batch_idx < 5: 
            # Map outputs
            predicted_labels_batch, probs_batch = map_logits_to_labels_multiclass(logits, CLASS_NAMES)
            
            for i, (text, pred_lbl, probs) in enumerate(zip(texts, predicted_labels_batch, probs_batch)):
                
                # Decode Ground Truth (Simpler now: just look up the index)
                true_idx = true_label_indices[i].item()
                actual_lbl = CLASS_NAMES[true_idx]

                # Print
                print(f"\n[Batch {batch_idx} - Sample {i}]")
                print(f"Input Text:    {text[:80]}...") 
                print(f"PREDICTED:     {pred_lbl}")
                print(f"GROUND TRUTH:  {actual_lbl}")
                
                # Show "High Probs" to see if the model was confused
                # (e.g., showing if it was 51% Action vs 49% Adventure)
                significant_probs = {CLASS_NAMES[j]: round(p, 3) for j, p in enumerate(probs.tolist()) if p > 0.1}
                print(f"High Probs:    {significant_probs}")

# ==========================================
# 4. RESULTS
# ==========================================
print("\n" + "="*30)
print("FINAL EVALUATION REPORT (MULTICLASS)")
print("="*30)

final_results = metrics.compute()

for metric_name, value in final_results.items():
    print(f"Global {metric_name.capitalize()}: {value.item():.4f}")

metrics.reset()