# EPASS for Semi-Supervised Learning (Text Classification)

This notebook implements the EPASS approach for Semi-Supervised Learning on a text classification task, adapting concepts from the paper (originally for images) to text data. It incorporates:
- BERT as the base encoder.
- Ensemble Projectors for generating embeddings.
- An Exponential Moving Average (EMA) teacher model.
- A Memory Bank for contrastive learning.
- SSL loss components: Supervised (Ls), Unsupervised Consistency (Lu), and Contrastive (Lc).
- Techniques to mitigate CUDA OOM errors: Automatic Mixed Precision (AMP) and Gradient Accumulation.

## 1. Imports and Environment Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm # For progress bars
import traceback
import os

# Set environment variable for synchronous CUDA execution (debugging).
# With CUDA_LAUNCH_BLOCKING=1 a failing kernel is reported at the call that
# launched it, which makes device-side errors (e.g. out-of-range labels)
# traceable. It is a debugging aid and typically slows GPU training, so remove
# it for real runs.
# NOTE(review): this is set after `import torch`; presumably it still takes
# effect because it only has to precede CUDA initialisation - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

## 2. Configuration Parameters

In [None]:
# Notebook-wide hyperparameters. Later cells read these module-level names
# directly, so they must be defined before any other code cell runs.

# Model Params
PRETRAINED_MODEL = "bert-base-uncased"  # Checkpoint name given to from_pretrained() for tokenizer and encoders
BERT_OUTPUT_DIM = 768  # Input width of the projector MLPs and the classifier head
PROJECTOR_DIM = 128 # As mentioned in paper; hidden width of each projector MLP
EMBEDDING_DIM = 128 # Output dim of projector MLP's final layer (also the memory bank row width)
NUM_PROJECTORS = 3 # As used in paper's ablation
NUM_CLASSES = -1 # Initialize as -1, will be set during data loading (by load_data_csv)

# SSL Params (Example values, tune as needed)
LABELED_RATIO = 0.1 # Use 10% of training data as labeled
CONFIDENCE_THRESHOLD = 0.95 # Tau (τ) for pseudo-labeling: minimum teacher probability to keep a pseudo-label
TEMPERATURE = 0.1 # T for contrastive loss (similarities are divided by it)
MOMENTUM = 0.999 # m for EMA teacher update
LAMBDA_U = 1.0 # Weight for unsupervised classification loss (Lu)
LAMBDA_C = 1.0 # Weight for contrastive loss (Lc)
MEM_BANK_SIZE = 4096 # K, size of the memory bank (adjust based on resources/dataset)

# Training Params
MAX_LEN = 128  # Tokenizer max_length; sequences are truncated/padded to this length
BATCH_SIZE = 4 # <<< REDUCED for OOM issues
UNLABELED_BATCH_SIZE_MULTIPLIER = 2 # mu << Use a fixed multiplier (unlabeled batch = mu * BATCH_SIZE)
LR = 1e-5 # Learning rate for BERT fine-tuning (often lower)
EPOCHS = 3 # Adjust as needed
GRADIENT_ACCUMULATION_STEPS = 4 # <<< ADDED for OOM issues
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU when available, else CPU

print(f"Using device: {DEVICE}")

## 3. Load Tokenizer and Base Model

In [None]:
# Load the tokenizer once; it is shared by every dataset built later.
print("Loading tokenizer...")
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL)
# NOTE: despite the message below, no BERT encoder is loaded in this cell.
print("Loading base BERT model...")
# We load the base model later inside the main execution block to ensure fresh instances
# (one for the student and one for the EMA teacher), e.g.:
# bert_model = BertModel.from_pretrained(PRETRAINED_MODEL)
print("Tokenizer loaded. Base models will be loaded later.")

## 4. Data Loading Function (Corrected for Header)

In [None]:
def _load_split_csv(path, split_name):
    """Read one CSV split and return it with clean, 0-based integer labels.

    Args:
        path: Location of the CSV file; its first row must be the header.
        split_name: Lower-case split name ("train" or "test"), used only in
            log and error messages.

    Returns:
        DataFrame with an integer 'label' column (CSV value minus one) and a
        'content' column (copied from 'title' when 'content' is absent).

    Raises:
        ValueError: If required columns are missing or no row has a numeric
            label. Any error raised while reading/parsing is logged and
            re-raised unchanged.
    """
    title = split_name.capitalize()
    print(f"Attempting to load {split_name} data from: {path} (assuming header in row 0)")
    try:
        # Explicitly use the first row as header
        df = pd.read_csv(path, header=0, low_memory=False)
        print(f"{title} data columns loaded:", df.columns.tolist())

        if 'label' not in df.columns or ('content' not in df.columns and 'title' not in df.columns):
            raise ValueError(f"{title} CSV must contain 'label' and ('content' or 'title') columns in the header.")

        # Coerce unparsable labels (e.g. a repeated header row) to NaN, then drop them.
        df['label'] = pd.to_numeric(df['label'], errors='coerce')
        initial_count = len(df)
        df.dropna(subset=['label'], inplace=True)
        print(f"Dropped {initial_count - len(df)} {split_name} rows due to non-numeric labels.")
        if len(df) == 0:
            raise ValueError(f"No valid numeric labels found in {split_name} data after cleaning.")
        # The CSV labels are treated as 1-based; shift to 0-based class indices.
        df['label'] = df['label'].astype(int) - 1
    except Exception as e:
        print(f"Error loading or processing {split_name} data: {e}")
        raise  # Re-raise the exception after printing

    # The column check above guarantees 'title' exists whenever 'content' does not.
    if 'content' not in df.columns:
        print(f"Warning: 'content' column missing in {split_name}_df, using 'title' instead.")
        df['content'] = df['title']
    return df


def load_data_csv(train_path, test_path):
    """Loads train/test CSVs, explicitly using the first row as the header.

    Both files go through the same cleaning pipeline (see `_load_split_csv`).
    As a side effect the module-level `NUM_CLASSES` is set to the largest
    0-based label found in either split, plus one, stored as a plain `int`.

    Args:
        train_path: Path of the training CSV.
        test_path: Path of the test CSV.

    Returns:
        Tuple `(train_df, test_df)` of cleaned DataFrames.

    Raises:
        ValueError: If a file lacks the required columns, has no numeric
            labels, or no valid class count can be derived.
    """
    global NUM_CLASSES  # Ensure we modify the global variable

    train_df = _load_split_csv(train_path, "train")
    test_df = _load_split_csv(test_path, "test")

    # --- Crucial: Verify and Set NUM_CLASSES ---
    train_min_label = train_df['label'].min()
    train_max_label = train_df['label'].max()
    test_min_label = test_df['label'].min()
    test_max_label = test_df['label'].max()

    print(f"Train labels after processing: Min={train_min_label}, Max={train_max_label}")
    print(f"Test labels after processing: Min={test_min_label}, Max={test_max_label}")

    if train_min_label < 0 or test_min_label < 0:
        # Only warn: this happens when the CSV labels were already 0-based.
        print("Warning: Negative labels detected after processing. Check label conversion logic.")

    # Use the highest label index seen in EITHER split so that classes present
    # only in the test set still get an output unit. Cast to a builtin int so
    # the value is safe to hand to layer constructors and torch.randint.
    determined_num_classes = int(max(train_max_label, test_max_label)) + 1
    if determined_num_classes <= 0:
        raise ValueError("Could not determine a valid number of classes (max label < 0).")

    NUM_CLASSES = determined_num_classes
    print(f"Data loaded successfully. Num classes set to: {NUM_CLASSES} (based on max label {NUM_CLASSES - 1})")
    return train_df, test_df

## 5. Dataset Class

In [None]:
class SSLTextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on demand.

    Each item is a dict holding the raw text, the tokenizer's `input_ids` and
    `attention_mask` (leading batch axis removed), the label as a long tensor
    and a boolean `is_labeled` flag shared by every item of the dataset.
    """

    def __init__(self, texts, labels, tokenizer, max_len, is_labeled=True):
        self.texts = texts
        # Callers pass a placeholder such as -1 for unlabeled samples.
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.is_labeled = is_labeled

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        raw_text = self.texts[index]
        # Missing values become an empty string; anything else is stringified.
        text = "" if pd.isna(raw_text) else str(raw_text)

        encoded = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )

        sample = {'text': text}
        # squeeze(0) drops the batch axis added by return_tensors="pt".
        for key in ('input_ids', 'attention_mask'):
            sample[key] = encoded[key].squeeze(0)
        sample['label'] = torch.tensor(self.labels[index], dtype=torch.long)
        sample['is_labeled'] = torch.tensor(self.is_labeled, dtype=torch.bool)
        return sample

## 6. EPASS Model Architecture

In [None]:
class EPASSModel(nn.Module):
    """Encoder with a linear classifier and an ensemble of projector MLPs.

    The feature vector is the encoder's `last_hidden_state` at sequence
    position 0. The classifier maps it to class logits; each projector maps it
    to an embedding, and the projector outputs are averaged and L2-normalised.
    """

    def __init__(self, base_model, num_classes, bert_output_dim, projector_dim, embedding_dim, num_projectors):
        super().__init__()
        self.encoder = base_model
        self.num_projectors = num_projectors

        if not num_classes > 0:
            raise ValueError(f"EPASSModel received invalid num_classes: {num_classes}")

        # Build the projector heads first, then the classifier.
        heads = nn.ModuleList()
        for _ in range(num_projectors):
            heads.append(
                nn.Sequential(
                    nn.Linear(bert_output_dim, projector_dim),
                    nn.ReLU(),
                    nn.Linear(projector_dim, embedding_dim),
                )
            )
        self.projectors = heads
        self.classifier = nn.Linear(bert_output_dim, num_classes)

    def forward(self, input_ids, attention_mask, return_features=False):
        """Return `(logits, embedding)`, plus the raw features when requested."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token of every sequence.
        features = encoded.last_hidden_state[:, 0, :]

        per_head = torch.stack([head(features) for head in self.projectors], dim=0)
        ensemble_embedding = per_head.mean(dim=0)
        normalized_embedding = F.normalize(ensemble_embedding, p=2, dim=1)

        logits = self.classifier(features)
        if not return_features:
            return logits, normalized_embedding
        return logits, normalized_embedding, features

## 7. EMA Update Function

In [None]:
@torch.no_grad()
def update_ema_variables(model, ema_model, alpha, global_step):
    """Blend the parameters of `model` into `ema_model` in place.

    The effective decay is the smaller of `alpha` and `1 - 1/(global_step + 1)`,
    so early steps follow `model` closely (at step 0 the parameters are simply
    copied). Parameters are paired by iteration order; buffers are not touched.
    """
    warmup_decay = 1 - 1 / (global_step + 1)
    decay = min(warmup_decay, alpha)
    student_weight = 1 - decay
    for teacher_param, student_param in zip(ema_model.parameters(), model.parameters()):
        # teacher <- decay * teacher + (1 - decay) * student
        teacher_param.data.mul_(decay)
        teacher_param.data.add_(student_param.data, alpha=student_weight)

## 8. Memory Bank Class

In [None]:
class MemoryBank:
    """Fixed-size ring buffer of L2-normalised embeddings and their labels.

    The bank starts filled with random unit vectors and random labels; calls
    to `update` overwrite slots in insertion order, wrapping at the end.
    """

    def __init__(self, size, embedding_dim, num_classes, device):
        """Allocate the bank on `device`.

        Raises:
            ValueError: If `size` or `num_classes` is not positive.
        """
        if size <= 0:
            raise ValueError(f"MemoryBank received invalid size: {size}")
        if num_classes <= 0:
            raise ValueError(f"MemoryBank received invalid num_classes: {num_classes}")
        self.size = size
        self.device = device
        self.embeddings = F.normalize(torch.randn(size, embedding_dim).to(device), dim=1)
        self.labels = torch.randint(0, num_classes, (size,)).to(device)
        self.ptr = 0  # next slot to overwrite
        print(f"Memory Bank initialized with size {size}, embedding_dim {embedding_dim}, num_classes {num_classes}")

    @torch.no_grad()
    def update(self, embeddings, labels):
        """Write a batch into the bank, overwriting the oldest slots.

        Args:
            embeddings: 2-D tensor with one row per sample; rows are
                re-normalised before being stored.
            labels: 1-D tensor with one label per row of `embeddings`.

        The inputs are moved to the bank's device and dtype first, so
        half-precision embeddings (e.g. produced under autocast) and
        non-long label tensors are accepted.
        """
        batch_size = embeddings.size(0)
        if batch_size == 0:
            return

        start = self.ptr
        if batch_size > self.size:
            # Only the newest `size` rows can survive a full wrap-around.
            # Dropping the older ones up front keeps the write indices unique
            # (index_copy_ is not deterministic for duplicate indices).
            skipped = batch_size - self.size
            embeddings = embeddings[skipped:]
            labels = labels[skipped:]
            start = (self.ptr + skipped) % self.size

        bank = self.embeddings
        rows = F.normalize(embeddings.detach().to(device=bank.device, dtype=bank.dtype), dim=1)
        row_labels = labels.detach().to(device=self.labels.device, dtype=self.labels.dtype)

        # index_copy_ needs the index tensor on the same device as the bank.
        indices = (torch.arange(rows.size(0), device=bank.device) + start) % self.size
        self.embeddings.index_copy_(0, indices, rows)
        self.labels.index_copy_(0, indices, row_labels)
        self.ptr = (self.ptr + batch_size) % self.size

    def get_all(self):
        """Return the `(embeddings, labels)` tensors backing the bank (not copies)."""
        return self.embeddings, self.labels

## 9. Loss Calculation Function (with Checks)

In [None]:
# Module-level cross-entropy criteria: one with the default mean reduction and
# one returning per-sample losses (reduction='none') so they can be masked.
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss_fn_reduction_none = nn.CrossEntropyLoss(reduction='none')

def calculate_losses(student_model, ema_model, memory_bank, labeled_batch, unlabeled_batch, criterion_ce, criterion_ce_noreduction, temp, threshold, lambda_u, lambda_c, device):
    """Calculates Ls, Lu, Lc and the combined loss with label range checks.

    Args:
        student_model: Trainable model; called as `model(input_ids, attention_mask)`
            and expected to return `(logits, embedding)`.
        ema_model: Teacher model with the same call signature; only used under
            `torch.no_grad()`.
        memory_bank: Object exposing `get_all()` -> `(embeddings, labels)` and
            `update(embeddings, labels)`.
        labeled_batch: Dict with 'input_ids', 'attention_mask' and 'label', or None.
        unlabeled_batch: Dict with 'input_ids' and 'attention_mask', or None.
        criterion_ce: Mean-reduced cross-entropy criterion for Ls.
        criterion_ce_noreduction: Per-sample cross-entropy criterion for Lu.
        temp: Temperature dividing the contrastive similarities.
        threshold: Minimum teacher probability for a pseudo-label to count.
        lambda_u: Weight of Lu in the combined loss.
        lambda_c: Weight of Lc in the combined loss.
        device: Device the batches are moved to.

    Returns:
        Tuple `(loss, Ls, Lu, Lc)`. `loss` is `Ls + lambda_u*Lu + lambda_c*Lc`
        over whichever batches were present; it is a zero tensor that requires
        grad when there is no data or the combined value is NaN/Inf.

    Notes:
        Labels are validated against the module-level `NUM_CLASSES`. A term
        whose labels are out of range, or whose value is NaN/Inf, is logged
        and left at zero rather than raising. The teacher and the student see
        the same unlabeled inputs here; no weak/strong augmentation is applied
        inside this function. Confident teacher embeddings are written to the
        memory bank after Lc has been computed.
    """
    # Decide once which batches carry data, and use the same answer both when
    # computing the terms and when combining them. A non-None batch with zero
    # rows counts as absent.
    has_labeled = labeled_batch is not None and len(labeled_batch['input_ids']) > 0
    has_unlabeled = unlabeled_batch is not None and len(unlabeled_batch['input_ids']) > 0

    Ls = torch.tensor(0.0, device=device, requires_grad=True)  # Ensure Ls requires grad if it's the only term
    Lu = torch.tensor(0.0, device=device)
    Lc = torch.tensor(0.0, device=device)

    # 1. Supervised Loss (Ls)
    if has_labeled:
        l_input_ids = labeled_batch['input_ids'].to(device)
        l_attn_mask = labeled_batch['attention_mask'].to(device)
        l_labels = labeled_batch['label'].to(device)

        if l_labels.min() < 0 or l_labels.max() >= NUM_CLASSES:
            print(f"!!! Invalid label detected in labeled batch: min={l_labels.min()}, max={l_labels.max()}, NUM_CLASSES={NUM_CLASSES}")
            # Ls remains 0.0 with requires_grad=True
        else:
            student_logits_l, _ = student_model(l_input_ids, l_attn_mask)
            Ls = criterion_ce(student_logits_l, l_labels)
            if not torch.isfinite(Ls):
                print("NaN/Inf Ls detected!")
                Ls = torch.tensor(0.0, device=device, requires_grad=True)  # Reset if bad

    # 2. Unsupervised Losses (Lu and Lc)
    if has_unlabeled:
        u_input_ids = unlabeled_batch['input_ids'].to(device)
        u_attn_mask = unlabeled_batch['attention_mask'].to(device)

        with torch.no_grad():
            ema_logits_weak, ema_embed_weak = ema_model(u_input_ids, u_attn_mask)
            ema_probs_weak = torch.softmax(ema_logits_weak, dim=1)
            pseudo_labels = torch.argmax(ema_probs_weak, dim=1)
            max_probs = torch.max(ema_probs_weak, dim=1)[0]
            mask = max_probs.ge(threshold).float()

            if pseudo_labels.min() < 0 or pseudo_labels.max() >= NUM_CLASSES:
                print(f"!!! Invalid pseudo_label detected OUTSIDE forward pass: min={pseudo_labels.min()}, max={pseudo_labels.max()}, NUM_CLASSES={NUM_CLASSES}")

        student_logits_strong, student_embed_strong = student_model(u_input_ids, u_attn_mask)

        # Confident subset and its validity, computed once for Lu, Lc and the
        # memory bank update.
        confident = mask.bool()
        has_confident = bool(confident.any())
        confident_labels = pseudo_labels[confident]
        confident_labels_valid = has_confident and bool(
            confident_labels.min() >= 0 and confident_labels.max() < NUM_CLASSES
        )

        # 2a. Lu
        if has_confident:
            if not confident_labels_valid:
                print(f"!!! Invalid pseudo_label detected *among confident samples* for Lu: min={confident_labels.min()}, max={confident_labels.max()}, NUM_CLASSES={NUM_CLASSES}")
            else:
                # Per-sample loss over the whole batch, zeroed for non-confident
                # samples, then averaged over the full unlabeled batch size.
                loss_u_per_sample = criterion_ce_noreduction(student_logits_strong, pseudo_labels)
                Lu = (loss_u_per_sample * mask).mean()
                if not torch.isfinite(Lu):
                    print("NaN/Inf Lu detected!")
                    Lu = torch.tensor(0.0, device=device)

        # 2b. Lc
        mem_embeddings, mem_labels = memory_bank.get_all()
        if has_confident and mem_embeddings is not None and mem_embeddings.size(0) > 0:
            if not confident_labels_valid:
                print(f"!!! Invalid pseudo_label detected *among confident samples* for Lc: min={confident_labels.min()}, max={confident_labels.max()}, NUM_CLASSES={NUM_CLASSES}")
            elif mem_labels.min() < 0 or mem_labels.max() >= NUM_CLASSES:
                print(f"!!! Invalid label detected in memory bank for Lc: min={mem_labels.min()}, max={mem_labels.max()}, NUM_CLASSES={NUM_CLASSES}")
            else:
                # Similarity of each confident student embedding to every bank entry.
                anchors = student_embed_strong[confident]
                sim_matrix = torch.mm(anchors, mem_embeddings.t()) / temp
                # Bank entries sharing the anchor's pseudo-label are positives.
                positive_mask = (confident_labels.unsqueeze(1) == mem_labels.unsqueeze(0)).float()

                log_probs_contrastive = F.log_softmax(sim_matrix, dim=1)
                num_positives = positive_mask.sum(dim=1, keepdim=True)
                # Uniform target over the positives; the epsilon makes anchors
                # without any positive contribute zero instead of NaN.
                target_dist = positive_mask / (num_positives + 1e-9)

                Lc_per_sample = -(target_dist * log_probs_contrastive).sum(dim=1)
                Lc = Lc_per_sample.mean()  # Average over the confident samples
                if not torch.isfinite(Lc):
                    print("NaN/Inf Lc detected!")
                    Lc = torch.tensor(0.0, device=device)

        # Update memory bank with the teacher embeddings of confident samples.
        if has_confident:
            if confident_labels_valid:
                memory_bank.update(ema_embed_weak[confident], confident_labels)
            else:
                print(f"!!! Skipping memory bank update due to invalid pseudo-labels: min={confident_labels.min()}, max={confident_labels.max()}")

    # 3. Combine the terms that were actually computed.
    if has_labeled:
        loss = Ls + lambda_u * Lu + lambda_c * Lc
    elif has_unlabeled:
        loss = lambda_u * Lu + lambda_c * Lc
        # With no confident sample both terms are constants; attach a zero
        # contribution from a trainable parameter so backward() still works.
        if not loss.requires_grad and student_model.training:
            grad_param = next((p for p in student_model.parameters() if p.requires_grad), None)
            if grad_param is not None:
                loss = loss + 0.0 * grad_param.sum()
            else:
                print("Warning: Could not make unsupervised loss require gradients.")
    else:
        loss = torch.tensor(0.0, device=device, requires_grad=True)  # Default requires grad if no data

    # Final check for safety
    if not torch.isfinite(loss):
        print("NaN/Inf combined loss detected!")
        return torch.tensor(0.0, device=device, requires_grad=True), Ls, Lu, Lc  # Return zero loss requiring grad

    return loss, Ls, Lu, Lc

## 10. Training Loop Function (AMP + Accumulation)

In [None]:
def train_ssl_model(student_model, ema_model, memory_bank, labeled_loader, unlabeled_loader, optimizer, criterion_ce, criterion_ce_noreduction, device, epochs, momentum, temp, threshold, lambda_u, lambda_c, accumulation_steps=1):
    """Train the student with AMP and gradient accumulation, updating the EMA teacher.

    Each epoch iterates the labeled loader (or, when there is no labeled data,
    runs `len(unlabeled_loader)` steps) and draws one unlabeled batch per step
    from an iterator that restarts whenever it is exhausted. Gradients are
    accumulated over `accumulation_steps` batches before each optimizer step,
    and the teacher is updated after every optimizer step.

    Args:
        student_model: Model being optimised.
        ema_model: Teacher model; kept in eval mode and updated by EMA only.
        memory_bank: Memory bank handed to `calculate_losses`.
        labeled_loader: DataLoader of labeled batches, or None.
        unlabeled_loader: DataLoader of unlabeled batches, or None.
        optimizer: Optimizer over the student's parameters.
        criterion_ce: Mean-reduced cross-entropy criterion.
        criterion_ce_noreduction: Per-sample cross-entropy criterion.
        device: Device for models and batches.
        epochs: Number of epochs.
        momentum: EMA decay for the teacher update.
        temp: Contrastive temperature.
        threshold: Pseudo-label confidence threshold.
        lambda_u: Weight of the unsupervised classification loss.
        lambda_c: Weight of the contrastive loss.
        accumulation_steps: Batches per optimizer step (values < 1 are treated as 1).

    Returns:
        The string "EPASS SSL Training Completed".
    """
    accumulation_steps = max(1, int(accumulation_steps))
    # AMP only makes sense when the run actually targets a CUDA device.
    use_amp = torch.cuda.is_available() and torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    student_model.to(device)
    ema_model.to(device)
    ema_model.eval()

    # A loader that yields no batches (None, empty dataset, or everything
    # dropped by drop_last) is treated as absent.
    has_labeled = labeled_loader is not None and len(labeled_loader) > 0
    has_unlabeled = unlabeled_loader is not None and len(unlabeled_loader) > 0
    unlabeled_iter = iter(unlabeled_loader) if has_unlabeled else None

    global_step = 0
    print(f"Starting training with AMP={'Enabled' if use_amp else 'Disabled'} and gradient accumulation (steps={accumulation_steps})...")
    for epoch in range(epochs):
        total_loss, total_Ls, total_Lu, total_Lc = 0, 0, 0, 0
        num_optimizer_steps_in_epoch = 0
        num_batches_processed = 0

        if not has_labeled and not has_unlabeled:
            print(f"Epoch {epoch+1}: No data to process. Skipping.")
            continue

        optimizer.zero_grad()  # Zero gradients at the start of epoch accumulation cycle

        # The labeled loader drives the epoch when present; otherwise run one
        # step per unlabeled batch.
        if has_labeled:
            loop_iterator = labeled_loader
            loop_length = len(labeled_loader)
        else:
            loop_iterator = range(len(unlabeled_loader))
            loop_length = len(unlabeled_loader)

        progress_bar = tqdm(loop_iterator, desc=f"Epoch {epoch+1}/{epochs}", total=loop_length)

        pending_backward = 0  # backward passes accumulated since the last optimizer step
        for i, data_batch in enumerate(progress_bar):
            student_model.train()

            current_labeled_batch = data_batch if has_labeled else None
            current_unlabeled_batch = None
            if unlabeled_iter is not None:
                try:
                    current_unlabeled_batch = next(unlabeled_iter)
                except StopIteration:
                    # Unlabeled data is cycled independently of the epoch boundary.
                    unlabeled_iter = iter(unlabeled_loader)
                    current_unlabeled_batch = next(unlabeled_iter)

            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, Ls, Lu, Lc = calculate_losses(
                    student_model, ema_model, memory_bank,
                    current_labeled_batch, current_unlabeled_batch,
                    criterion_ce, criterion_ce_noreduction,
                    temp, threshold, lambda_u, lambda_c, device
                )

            batch_ok = bool(torch.isfinite(loss))
            if batch_ok:
                # Scale down so the accumulated gradient is an average over the window.
                scaler.scale(loss / accumulation_steps).backward()
                pending_backward += 1
                num_batches_processed += 1
                total_loss += loss.item()
                total_Ls += Ls.item() if current_labeled_batch is not None else 0
                total_Lu += Lu.item() if current_unlabeled_batch is not None else 0
                total_Lc += Lc.item() if current_unlabeled_batch is not None else 0
            else:
                print(f"NaN/Inf loss detected at step {global_step}, epoch {epoch+1}, batch {i}. Skipping update.")

            at_boundary = (i + 1) % accumulation_steps == 0 or (i + 1) == loop_length
            if at_boundary:
                # Step only if something was accumulated; a skipped batch no
                # longer throws away the gradients of its neighbours.
                if pending_backward > 0:
                    scaler.step(optimizer)
                    scaler.update()
                    update_ema_variables(student_model, ema_model, momentum, global_step)
                    num_optimizer_steps_in_epoch += 1
                optimizer.zero_grad()
                pending_backward = 0

            if batch_ok:
                global_step += 1
                progress_bar.set_postfix({
                    'Loss': f"{loss.item():.4f}",
                    'Ls': f"{Ls.item():.4f}",
                    'Lu': f"{Lu.item():.4f}",
                    'Lc': f"{Lc.item():.4f}"
                })

        # --- Epoch Summary ---
        if num_optimizer_steps_in_epoch > 0:
            # Every total is summed per batch, so every average is per batch too.
            avg_loss = total_loss / num_batches_processed
            avg_Ls = total_Ls / num_batches_processed
            avg_Lu = total_Lu / num_batches_processed
            avg_Lc = total_Lc / num_batches_processed
            print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f} | Avg Ls: {avg_Ls:.4f} | Avg Lu: {avg_Lu:.4f} | Avg Lc: {avg_Lc:.4f}")
        else:
            print(f"Epoch {epoch+1} completed (no optimizer steps performed).")

    return "EPASS SSL Training Completed"

## 11. Evaluation Function

In [None]:
def evaluate_model(model, dataloader, device):
    """Compute classification accuracy of `model` over `dataloader`.

    Args:
        model: Model called as `model(input_ids, attention_mask)` and returning
            `(logits, embedding)`; it is switched to eval mode and moved to
            `device`.
        dataloader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'label' tensors.
        device: Device used for the model and the batches.

    Returns:
        Tuple `(accuracy, all_preds, all_labels)`: accuracy is a fraction in
        [0, 1] (0 when no sample was seen); the two lists hold the predicted
        and true class indices in iteration order.
    """
    model.eval()
    # Move the model once up front rather than on every batch.
    model.to(device)

    total_correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits, _ = model(input_ids, attention_mask)
            predicted = logits.argmax(dim=1)

            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    return accuracy, all_preds, all_labels

## 12. Main Execution Block

In [None]:
# End-to-end run: load data, split labeled/unlabeled, build loaders and models,
# train with EPASS, evaluate the EMA teacher on the test set and save it.
train_path = '/kaggle/input/dbpedia-ontology/train.csv'
test_path = '/kaggle/input/dbpedia-ontology/test.csv'
accuracy_percent = 0.0 # Initialize in case of errors before evaluation

try:
    # --- Load Data ---
    train_df_full, test_df = load_data_csv(train_path, test_path)

    if NUM_CLASSES <= 0:
        raise ValueError("Number of classes not determined correctly during data loading.")

    # --- Data Splitting ---
    labels_for_split = train_df_full['label'].values
    n_total = len(train_df_full)
    indices = np.arange(n_total)
    # At least one labeled sample, never more than the data available.
    n_labeled = min(n_total, max(1, int(n_total * LABELED_RATIO)))
    n_unlabeled = n_total - n_labeled
    unique_labels, counts = np.unique(labels_for_split[labels_for_split >= 0], return_counts=True)
    # Stratification needs every class present at least twice and both sides
    # of the split large enough to hold one sample per class.
    can_stratify = (
        len(unique_labels) == NUM_CLASSES
        and np.all(counts >= 2)
        and n_labeled >= NUM_CLASSES
        and n_unlabeled >= NUM_CLASSES
    )

    if n_unlabeled == 0:
        # train_test_split cannot produce an empty side, so handle it directly.
        print("Warning: LABELED_RATIO leaves no unlabeled samples. Using all training data as labeled.")
        labeled_idx, unlabeled_idx = indices, indices[:0]
    elif can_stratify:
         print("Stratifying labeled/unlabeled split...")
         labeled_idx, unlabeled_idx = train_test_split(
             indices, train_size=n_labeled, stratify=labels_for_split, random_state=42
         )
    else:
         print("Warning: Cannot stratify. Performing random split...")
         labeled_idx, unlabeled_idx = train_test_split(
             indices, train_size=n_labeled, random_state=42
         )

    labeled_df = train_df_full.iloc[labeled_idx].reset_index(drop=True)
    unlabeled_df = train_df_full.iloc[unlabeled_idx].reset_index(drop=True)

    print(f"Total training samples: {len(train_df_full)}")
    print(f"Labeled samples: {len(labeled_df)}")
    print(f"Unlabeled samples: {len(unlabeled_df)}")
    print(f"Test samples: {len(test_df)}")
    print(f"Number of classes: {NUM_CLASSES}")

    # --- Create Datasets and Dataloaders ---
    print("Creating datasets...")
    labeled_dataset = SSLTextDataset(labeled_df['content'].tolist(), labeled_df['label'].tolist(), tokenizer, MAX_LEN, is_labeled=True)
    # Unlabeled samples carry the placeholder label -1; their true labels are never used.
    unlabeled_dataset = SSLTextDataset(unlabeled_df['content'].tolist(), [-1]*len(unlabeled_df), tokenizer, MAX_LEN, is_labeled=False)
    test_dataset = SSLTextDataset(test_df['content'].tolist(), test_df['label'].tolist(), tokenizer, MAX_LEN, is_labeled=True)

    effective_labeled_batch_size = min(BATCH_SIZE, len(labeled_dataset))
    if effective_labeled_batch_size == 0:
         print("Warning: Labeled dataset is empty. Ls will be 0. Setting labeled loader to None.")

    # The unlabeled batch is a fixed multiple (mu) of the labeled batch size.
    # Deriving it from the dataset ratio instead would make it ~9x larger at
    # LABELED_RATIO=0.1 and defeat the reduced BATCH_SIZE chosen to avoid OOM.
    if len(unlabeled_dataset) > 0:
         effective_unlabeled_batch_size = min(BATCH_SIZE * UNLABELED_BATCH_SIZE_MULTIPLIER, len(unlabeled_dataset))
    else:
        effective_unlabeled_batch_size = 0
        print("Warning: Unlabeled dataset is empty. Lu, Lc will be 0.")

    print(f"Using Labeled Batch Size: {effective_labeled_batch_size}")
    print(f"Using Unlabeled Batch Size: {effective_unlabeled_batch_size}")

    # Create loaders only if batch size is > 0
    labeled_loader = DataLoader(labeled_dataset, batch_size=effective_labeled_batch_size, shuffle=True, drop_last=True) if effective_labeled_batch_size > 0 else None
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=effective_unlabeled_batch_size, shuffle=True, drop_last=True) if effective_unlabeled_batch_size > 0 else None
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE*2, shuffle=False)
    print("Dataloaders created.")

    # --- Initialize Models, Optimizer, Memory Bank ---
    print("Initializing models...")
    # Student and teacher each get their own freshly loaded encoder.
    base_bert_model_student = BertModel.from_pretrained(PRETRAINED_MODEL)
    student_model = EPASSModel(base_bert_model_student, NUM_CLASSES, BERT_OUTPUT_DIM, PROJECTOR_DIM, EMBEDDING_DIM, NUM_PROJECTORS).to(DEVICE)

    ema_base_bert_model = BertModel.from_pretrained(PRETRAINED_MODEL)
    ema_model = EPASSModel(ema_base_bert_model, NUM_CLASSES, BERT_OUTPUT_DIM, PROJECTOR_DIM, EMBEDDING_DIM, NUM_PROJECTORS).to(DEVICE)

    # Start the teacher as an exact copy of the student (parameters and
    # buffers); it is only ever changed by the EMA update, never by gradients.
    ema_model.load_state_dict(student_model.state_dict())
    for ema_param in ema_model.parameters():
        ema_param.requires_grad = False

    optimizer = optim.AdamW(student_model.parameters(), lr=LR)
    criterion_ce = nn.CrossEntropyLoss().to(DEVICE)
    criterion_ce_noreduction = nn.CrossEntropyLoss(reduction='none').to(DEVICE)

    # <<< Ensure Memory Bank is initialized AFTER NUM_CLASSES is determined >>>
    memory_bank = MemoryBank(MEM_BANK_SIZE, EMBEDDING_DIM, NUM_CLASSES, DEVICE)
    print("Models and optimizer initialized.")

    # --- Start Training ---
    print("Starting EPASS SSL Training...")
    epass_results = train_ssl_model(
        student_model, ema_model, memory_bank,
        labeled_loader, unlabeled_loader,
        optimizer, criterion_ce, criterion_ce_noreduction,
        DEVICE, EPOCHS, MOMENTUM, TEMPERATURE, CONFIDENCE_THRESHOLD,
        LAMBDA_U, LAMBDA_C, accumulation_steps=GRADIENT_ACCUMULATION_STEPS
    )
    print(epass_results)

    # --- Evaluate ---
    print("Evaluating EMA model on test set...")
    accuracy, _, _ = evaluate_model(ema_model, test_loader, DEVICE)
    accuracy_percent = accuracy * 100
    print(f"Final EMA Model Test Accuracy: {accuracy_percent:.2f}%")

    # --- Save Model ---
    torch.save(ema_model.state_dict(), "/kaggle/working/epass_ema_model_state_dict.pth")
    print("EMA Model state dict saved.")

except Exception as main_e:
    # Top-level boundary of the notebook cell: report the failure with a full
    # traceback instead of propagating; accuracy_percent keeps its last value.
    print(f"An error occurred during main execution: {main_e}")
    traceback.print_exc()
finally:
    # Clean up CUDA cache if needed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
## 13. Display Final Accuracy

In [None]:
# Re-print the headline metric in its own cell. accuracy_percent is still 0.0
# if the main cell failed before evaluation finished.
print(f"Final EMA Model Test Accuracy: {accuracy_percent:.2f}%")