# SBERT Encoder w/ MLP MoE for Multilabel Film Genre Classification

Short test poses a unique challenge for many NLP models due to limited context length while retaining the difficulties of the nuances of structured languages. This model aims to use several methods to address each of these concerns individually. First, a trained autoencoder, SBERT, will be used to create a new hidden representation for the short text input sequences. Then for the diverse classification tasks several expert models will be trained to selectively specialize in specific genre embeddings. Simple MLP models are sufficient for mapping the smaller embeddings to output classes while keeping the full model simple to train.

In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning.pytorch as pl
from lightning.pytorch.callbacks import BaseFinetuning
from torchmetrics.classification import MultilabelF1Score
from sentence_transformers import SentenceTransformer

## Models

In [2]:
class Expert(nn.Module):
    """
    A simple Feed-Forward Network acting as a single 'Expert'.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

class TopKRouter(nn.Module):
    """
    Gating Network that selects the top-k experts for each input.
    """
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts)
        self.top_k = top_k

    def forward(self, x):
        logits = self.gate(x)
        top_k_vals, top_k_indices = torch.topk(logits, self.top_k, dim=1)
        router_probs = F.softmax(top_k_vals, dim=1)
        return router_probs, top_k_indices, logits

class MoEClassifier(nn.Module):
    """
    The Mixture of Experts Classification Head.
    """
    def __init__(self, input_dim, num_classes, num_experts=4, top_k=2, expert_hidden_dim=128):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        self.router = TopKRouter(input_dim, num_experts, top_k=top_k)
        self.experts = nn.ModuleList([
            Expert(input_dim, expert_hidden_dim, num_classes) 
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size = x.size(0)
        router_probs, expert_indices, router_logits = self.router(x)
        
        final_output = torch.zeros(batch_size, self.experts[0].net[-1].out_features).to(x.device)
        
        for k in range(self.top_k):
            selected_experts = expert_indices[:, k]
            gate_weight = router_probs[:, k].unsqueeze(1)
            
            for expert_idx in range(self.num_experts):
                mask = (selected_experts == expert_idx)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[expert_idx](expert_input)
                    final_output[mask] += gate_weight[mask] * expert_output
                    
        return final_output, router_logits

class SBERT_MoE_Model(nn.Module):
    """
    Full End-to-End Model: SBERT Encoder + MoE Classifier
    """
    def __init__(self, model_name='all-MiniLM-L6-v2', num_classes=5, num_experts=8, top_k=2):
        super().__init__()
        # 1. Initialize SBERT
        self.sbert = SentenceTransformer(model_name)
        
        # 2. Capture device choice
        target_device = self.sbert.device 
        
        # 3. Initialize MoE head
        embedding_dim = self.sbert.get_sentence_embedding_dimension()
        self.moe_head = MoEClassifier(
            input_dim=embedding_dim,
            num_classes=num_classes,
            num_experts=num_experts,
            top_k=top_k
        )
        
        # 4. Force MoE head to same device
        self.moe_head.to(target_device)

    def forward(self, text_input):
        # Encode and Detach to treat embeddings as fixed features initially
        features = self.sbert.encode(text_input, convert_to_tensor=True)
        
        # Important: Clone and detach to avoid "Inference Tensor" errors during backprop
        # If unfreezing later, the optimizer handles the graph connection, 
        # but for the initial forward pass logic, this is safe.
        features = features.clone().detach() 
        
        logits, router_logits = self.moe_head(features)
        return logits, router_logits

## Training

In [3]:
class MoE_LightningModule(pl.LightningModule):
    def __init__(self, model, num_classes, num_experts, learning_rate=1e-3, aux_loss_weight=0.1):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        
        # Organize modules
        self.backbone = model.sbert
        self.head = model.moe_head
        self.learning_rate = learning_rate
        self.num_experts = num_experts
        self.aux_loss_weight = aux_loss_weight
        
        # Multi-Label Loss & Metrics
        self.criterion = nn.BCEWithLogitsLoss()
        self.val_f1 = MultilabelF1Score(num_labels=num_classes, threshold=0.5, average='micro')

    def forward(self, x):
        # 1. Tokenize the raw text
        # This converts ["Hello world"] into {'input_ids': ..., 'attention_mask': ...}
        features = self.backbone.tokenize(x)
        
        # 2. Move inputs to the correct device (GPU/MPS)
        # LightningModule provides self.device
        features = {key: value.to(self.device) for key, value in features.items()}
        
        # 3. Pass through SBERT backbone
        # calling .forward() or __call__() allows gradients to flow (unlike .encode())
        out = self.backbone(features)
        
        # 4. Extract the sentence embedding
        embeddings = out['sentence_embedding']
        
        # 5. Pass through MoE Head
        return self.head(embeddings)

    def _compute_load_balancing_loss(self, router_logits):
        probs = F.softmax(router_logits, dim=1)
        mean_probs = probs.mean(dim=0)
        aux_loss = (mean_probs ** 2).sum() * self.num_experts
        return aux_loss

    def training_step(self, batch, batch_idx):
        texts, targets = batch
        logits, router_logits = self(texts)
        
        cls_loss = self.criterion(logits, targets)
        aux_loss = self._compute_load_balancing_loss(router_logits)
        
        total_loss = cls_loss + (self.aux_loss_weight * aux_loss)
        
        self.log("train_loss", total_loss)
        return total_loss

    def validation_step(self, batch, batch_idx):
        texts, targets = batch
        
        # Forward pass (we don't need router logits for validation metrics)
        logits, _ = self(texts)
        
        # Calculate Loss
        val_loss = self.criterion(logits, targets)
        
        # Update F1 Score
        self.val_f1(logits, targets)
        
        # Log metrics so the Scheduler can find 'val_loss'
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_f1", self.val_f1, prog_bar=True)

    # def configure_optimizers(self):
    #     # Filter ensures we don't crash trying to optimize frozen params
    #     return torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
    
    def configure_optimizers(self):
        # 1. Initialize Optimizer with ONLY the Head (Backbone added later by Callback)
        optimizer = torch.optim.AdamW(
            self.head.parameters(), 
            lr=1e-3, 
            weight_decay=0.01
        )
        
        # 2. Use ReduceLROnPlateau (Safe for dynamic unfreezing)
        # It waits for 'val_loss' to stop improving, then lowers LR
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=2
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss", # Required for ReduceLROnPlateau
                "interval": "epoch",
                "frequency": 1
            }
        }
        
class ProductionFinetuning(BaseFinetuning):
    def __init__(self, unfreeze_at_epoch=2, backbone_lr=2e-5):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch
        self._backbone_lr = backbone_lr

    def freeze_before_training(self, pl_module):
        # Start with SBERT frozen
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # When we hit the target epoch...
        if current_epoch == self._unfreeze_at_epoch:
            print(f"\n‚ùÑÔ∏è -> üî• Unfreezing SBERT Backbone at Epoch {current_epoch}")
            
            # Unfreeze AND add to the optimizer with the Low LR
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                lr=self._backbone_lr 
            )

## Data

In [4]:
from torch.utils.data import Dataset


class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels 
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

def map_logits_to_labels(logits, class_names, threshold=0.5):
    probs = torch.sigmoid(logits).cpu()
    predictions_mask = (probs > threshold).numpy()
    batch_results = []
    for is_active in predictions_mask:
        active_labels = [class_names[idx] for idx, val in enumerate(is_active) if val]
        batch_results.append(active_labels)
    return batch_results, probs

## Demo

In [5]:
from torch.utils.data import DataLoader


CLASS_NAMES = ["Finance", "Sports", "Tech", "Urgent"]

texts = [
    "Inflation data released this morning caused immediate volatility across global markets. The S&P 500 dropped 2% within minutes of the opening bell as traders reacted to the Federal Reserve's signaling of potentially higher interest rates for the remainder of Q4.",
    "The new smart-jersey utilizes embedded bio-sensors to track player fatigue in real-time during the match. Coaches can now monitor heart rate variability and sprint speed directly from the sidelines, allowing for data-driven substitution decisions in the fourth quarter.",
    "After months of negotiation, the star quarterback has agreed to a record-breaking 50 million dollar extension. This deal makes him the highest-paid player in league history and significantly impacts the team's salary cap for the upcoming trading season.",
    "Benchmark tests for the new M3 processor show a 15% drop in multi-core performance compared to the previous generation. Thermal throttling appears to be the primary bottleneck, as the chip reaches 90 degrees Celsius under sustained heavy workloads like video rendering.",
    "CRITICAL ALERT: Primary database cluster US-East-1 is currently unresponsive due to a cascading failure in the load balancer. Response times have spiked to 5000ms. Immediate engineering intervention is required to prevent a total service outage for enterprise customers.",
    "Emergency board meeting required: The Q3 projections show a cash flow deficit that will impact payroll by next Friday. We need to approve the emergency line of credit immediately to ensure operations continue without interruption.",
]

# Multi-hot labels (Float Tensors)
labels = torch.tensor([
    [1.0, 0.0, 0.0, 0.0], # Finance
    [0.0, 1.0, 1.0, 0.0], # Tech + Sports
    [1.0, 1.0, 0.0, 0.0], # Finance + Sports
    [0.0, 0.0, 1.0, 0.0], # Tech
    [0.0, 0.0, 1.0, 1.0], # Tech + Urgent
    [1.0, 0.0, 0.0, 1.0], # Finance + Urgent
])

dataset = TextDataset(texts, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# --- Initialize Models ---
num_classes = len(CLASS_NAMES)
backbone = SBERT_MoE_Model(num_classes=num_classes, num_experts=4, top_k=2)

# Wrap in Lightning
pl_module = MoE_LightningModule(
    model=backbone, 
    num_classes=num_classes, 
    num_experts=4,
    aux_loss_weight=0.1
)

# Callback to unfreeze SBERT after 1 epoch
finetune_cb = ProductionFinetuning(unfreeze_at_epoch=3, backbone_lr=2e-5)

# --- Train ---
print("--- Starting Training ---")
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    callbacks=[finetune_cb],
    enable_progress_bar=True,
    log_every_n_steps=1 
)

# trainer.fit(pl_module, dataloader, dataloader)
trainer.fit(
    pl_module, 
    train_dataloaders=dataloader, 
    val_dataloaders=dataloader # Using same data for demo purposes
)

# --- Inference ---
print("\n--- Running Inference ---")
pl_module.eval()

test_texts = [
    "The tech company bought a sports team.", 
    "Urgent: The bank is collapsing."
]

with torch.no_grad():
    # Run forward pass
    logits, _ = pl_module(test_texts)
    
    # Map outputs
    predicted_labels, probs = map_logits_to_labels(logits, CLASS_NAMES, threshold=0.5)
    
    for text, lbls, raw_probs in zip(test_texts, predicted_labels, probs):
        print(f"\nInput: {text}")
        print(f"Predicted Labels: {lbls}")
        print(f"Raw Probabilities: {[round(p, 2) for p in raw_probs.tolist()]}")

üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/home/dodogama/anaconda3/envs/nlp-py310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision(

--- Starting Training ---


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/dodogama/anaconda3/envs/nlp-py310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:484: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/dodogama/anaconda3/envs/nlp-py310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dodogama/anaconda3/envs/nlp-py310/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


‚ùÑÔ∏è -> üî• Unfreezing SBERT Backbone at Epoch 3


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.



--- Running Inference ---

Input: The tech company bought a sports team.
Predicted Labels: ['Finance', 'Sports', 'Tech']
Raw Probabilities: [0.5, 0.5, 0.5, 0.49]

Input: Urgent: The bank is collapsing.
Predicted Labels: ['Finance', 'Urgent']
Raw Probabilities: [0.5, 0.41, 0.49, 0.53]


## Movie Genre Execution

In [6]:
import torch
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader, Dataset


class IMDBDataset(Dataset):
    def __init__(self, data_dir_path, filename, class_names, text_col='description', label_col='genre_list'):
        """
        Args:
            data_dir_path (Path or str): Relative path to data directory.
            filename (str): The CSV filename (e.g., 'train.csv').
            class_names (list): List of valid classes (ORDER MATTERS).
            text_col (str): Column name for text.
            label_col (str): Column name for labels.
        """
        # 1. Setup Path safely
        self.data_path = Path(data_dir_path) / filename
        
        if not self.data_path.exists():
            raise FileNotFoundError(f"File not found at: {self.data_path.resolve()}")
            
        print(f"Loading data from {self.data_path.name}...")
        self.df = pd.read_csv(self.data_path)
        
        # 2. Process Text (Handle NaNs, ensure strings)
        self.texts = self.df[text_col].fillna("").astype(str).tolist()
        
        # 3. Process Labels (Multi-hot Encoding)
        # Create map: "Action" -> 0, "Drama" -> 1, etc.
        self.class_to_idx = {cls: i for i, cls in enumerate(class_names)}
        self.num_classes = len(class_names)
        self.labels = []
        
        # Track unseen genres for safety warning
        unseen_genres = set()
        
        for genre_str in self.df[label_col]:
            # Initialize zero vector [0.0, 0.0, ...]
            label_vec = torch.zeros(self.num_classes, dtype=torch.float)
            
            if pd.notna(genre_str):
                # Split "Action, Drama" -> ["Action", "Drama"]
                # .strip() removes whitespace around words
                current_genres = [g.strip() for g in str(genre_str).split(',')]
                
                for genre in current_genres:
                    if genre in self.class_to_idx:
                        idx = self.class_to_idx[genre]
                        label_vec[idx] = 1.0
                    else:
                        # Track genres that don't match our class list
                        unseen_genres.add(genre)
            
            self.labels.append(label_vec)

        # 4. Warning System
        # If this is the test set, and it has weird genres not in train, warn the user.
        if unseen_genres:
            print(f"‚ö†Ô∏è  WARNING in {filename}: Found {len(unseen_genres)} genres not in the provided class list.")
            print(f"   Examples of ignored genres: {list(unseen_genres)[:5]}")
            print(f"   (These were ignored during label creation)")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

    @staticmethod
    def discover_classes(data_dir_path, filename, label_col='genre_list'):
        """
        Static utility to scan a CSV and return sorted unique class names.
        Use this ONCE on your TRAINING set only.
        """
        path = Path(data_dir_path) / filename
        if not path.exists():
             raise FileNotFoundError(f"File not found at: {path.resolve()}")

        df = pd.read_csv(path)
        
        # Split by comma, explode list to rows, strip whitespace, find unique
        genres = df[label_col].dropna().astype(str).str.split(',').explode().str.strip().unique()
        
        return sorted(list(genres))


DATA_DIR = Path('../data/imdb_arh_trimmed')
CLASS_NAMES = IMDBDataset.discover_classes(DATA_DIR, 'imdb_arh_train.csv')
NUM_WORKERS = 4

tr_ds = IMDBDataset(data_dir_path=DATA_DIR, filename='imdb_arh_train.csv', class_names=CLASS_NAMES)
va_ds = IMDBDataset(data_dir_path=DATA_DIR, filename='imdb_arh_val.csv', class_names=CLASS_NAMES)
te_ds = IMDBDataset(data_dir_path=DATA_DIR, filename='imdb_arh_test.csv', class_names=CLASS_NAMES)

tr_loader = DataLoader(tr_ds, batch_size=32, num_workers=NUM_WORKERS, shuffle=True)
va_loader = DataLoader(va_ds, batch_size=64, num_workers=NUM_WORKERS, shuffle=False)
te_loader = DataLoader(te_ds, batch_size=64, num_workers=NUM_WORKERS, shuffle=False)

print("--- Sanity Checks ---")
# 1. Check if classes were discovered
print(f"Num classes: {len(CLASS_NAMES)}")
print(f"First 5 classes: {CLASS_NAMES[:5]}")

# 2. Check Dataset lengths
print(f"Train size: {len(tr_ds)}")
print(f"Val size:   {len(va_ds)}")

# 3. Inspect a single sample
sample_text, sample_label = tr_ds[0]

print(f"\nSample Text Type: {type(sample_text)}") # Should be <class 'str'>
print(f"Sample Label Type: {type(sample_label)}") # Should be <class 'torch.Tensor'>
print(f"Sample Label Shape: {sample_label.shape}") # Should be torch.Size([num_classes])
print(f"Sample Label Dtype: {sample_label.dtype}") # Should be torch.float32

print("\n--- DataLoader Batch Check ---")

# Get one batch from the training loader
batch_texts, batch_labels = next(iter(tr_loader))

print(f"Batch Text Length: {len(batch_texts)}") # Should match batch_size (32)
print(f"Batch Label Shape: {batch_labels.shape}") # Should be [32, num_classes]

# Check if shuffling works (Compare first item of two different iterator calls)
batch_texts_2, _ = next(iter(tr_loader))
if batch_texts[0] != batch_texts_2[0]:
    print("Shuffle is working (First elements differ).")
else:
    print("Warning: Shuffle might not be working or dataset is very small.")

Loading data from imdb_arh_train.csv...
Loading data from imdb_arh_val.csv...
Loading data from imdb_arh_test.csv...
--- Sanity Checks ---
Num classes: 26
First 5 classes: ['Action', 'Adult', 'Adventure', 'Animation', 'Biography']
Train size: 67772
Val size:   14523

Sample Text Type: <class 'str'>
Sample Label Type: <class 'torch.Tensor'>
Sample Label Shape: torch.Size([26])
Sample Label Dtype: torch.float32

--- DataLoader Batch Check ---
Batch Text Length: 32
Batch Label Shape: torch.Size([32, 26])
Shuffle is working (First elements differ).


In [None]:
# --- Initialize Models ---
num_classes = len(CLASS_NAMES)
backbone = SBERT_MoE_Model(num_classes=num_classes, num_experts=4, top_k=2)

# Wrap in Lightning
pl_module = MoE_LightningModule(
    model=backbone, 
    num_classes=num_classes, 
    num_experts=10,
    aux_loss_weight=0.1
)

# Callback to unfreeze SBERT after 1 epoch
finetune_cb = ProductionFinetuning(unfreeze_at_epoch=3, backbone_lr=2e-5)

# --- Train ---
print("--- Starting Training ---")
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    callbacks=[finetune_cb],
    enable_progress_bar=True,
    log_every_n_steps=1 
)

trainer.fit(
    pl_module, 
    train_dataloaders=tr_loader, 
    val_dataloaders=va_loader,
)

üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | backbone  | SentenceTransformer | 22.7 M | train
1 | head      | MoEClassifier       | 212 K  | train
2 | criterion | BCEWithLogitsLoss   | 0      | train
3 | val_f1    | MultilabelF1Score   | 0      | train
----------------------------------------------------------
212 K     Trainable params
22.7 M    Non-trainable params
22.9 M    Total params
91.701    Total estimated model params size (MB)
34        Modules in train mode
120       Modules in eval mode


--- Starting Training ---


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
print("\n--- Running Inference on Test Loader ---")
pl_module.eval()

# Optional: Move to CPU for inference loop if your GPU memory is tight, 
# though keeping it on device is faster if you have space.
# pl_module.cpu() 

with torch.no_grad():
    # enumerate to keep track of counts, limit to first 5 batches for clean output
    for batch_idx, batch in enumerate(te_loader):
        if batch_idx >= 5: break  # Stop after 5 batches to avoid spamming console

        # Unpack batch (Dataset returns: text_string, label_tensor)
        texts, true_label_tensors = batch
        
        # Run forward pass
        # Note: inputs to SBERT model are usually just the list of strings
        logits, _ = pl_module(texts)
        
        # Map outputs to class names
        predicted_labels_batch, probs_batch = map_logits_to_labels(logits, CLASS_NAMES, threshold=0.5)
        
        # Loop through items in the batch (handles batch_size=1 or >1)
        for i, (text, pred_lbls, probs) in enumerate(zip(texts, predicted_labels_batch, probs_batch)):
            
            # --- Helper: Decode the actual Ground Truth for comparison ---
            true_indices = (true_label_tensors[i] == 1.0).nonzero(as_tuple=False).squeeze()
            # Handle scalar/0-dim tensors if only 1 class exists
            if true_indices.ndim == 0 and true_indices.numel() == 1:
                true_indices = [true_indices.item()]
            elif true_indices.ndim > 0:
                true_indices = true_indices.tolist()
            else: 
                true_indices = [] # No labels
                
            actual_lbls = [CLASS_NAMES[idx] for idx in true_indices]
            # -------------------------------------------------------------

            print(f"\n[Batch {batch_idx} - Sample {i}]")
            print(f"Input Text:    {text[:80]}...") # Truncate for readability
            print(f"PREDICTED:     {pred_lbls}")
            print(f"GROUND TRUTH:  {actual_lbls}")
            
            # Print only probabilities that are somewhat significant
            significant_probs = {CLASS_NAMES[j]: round(p, 3) for j, p in enumerate(probs.tolist()) if p > 0.1}
            print(f"High Probs:    {significant_probs}")