# EPASS on Tiny ImageNet (SimMatch-based)

Demonstrates semi-supervised learning (SSL) with an ensemble of projection heads
(EPASS) on the Tiny ImageNet dataset.

**Changes from previous version:**
- Increased Epochs (50 -> 300) for better SSL convergence.
- Lowered Learning Rate (0.03 -> 0.02) for stability over more epochs.
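- Adjusted Contrastive Loss: compares student projections of strongly augmented views with teacher projections of weakly augmented views.
- Lowered Confidence Threshold (0.95 -> 0.90) so that more unlabeled samples pass the pseudo-label confidence mask earlier in training.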
- Added Mask Ratio monitoring.

## 1. Setup: Imports and Configuration

In [None]:
import os
import time
import math
import random
import shutil
import subprocess
import glob
from tqdm.notebook import tqdm
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, datasets, models
from sklearn.model_selection import train_test_split

# Configuration / Hyperparameters
# --- Dataset ---
DATA_DIR = './tiny-imagenet-200'
TINY_IMAGENET_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
NUM_CLASSES = 200

# --- Training ---
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 300          # raised from 50 for better SSL convergence
BATCH_SIZE = 64       # labeled batch size
MU = 7                # unlabeled batch size = MU * BATCH_SIZE
LR = 0.02             # lowered from 0.03 for stability over more epochs
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EMA_DECAY = 0.999     # EMA decay factor for the teacher model
THRESHOLD = 0.90      # pseudo-label confidence threshold (lowered from 0.95)
LAMBDA_U = 1.0        # weight of the unsupervised consistency loss
LAMBDA_C = 1.0        # weight of the contrastive loss
TEMPERATURE = 0.2     # temperature for contrastive similarity scaling

# --- EPASS ---
NUM_PROJECTORS = 3    # number of ensemble projectors (P in the paper)
PROJECTION_DIM = 128  # output dimension of each projector

# --- Labeled Data ---
LABELED_RATIO = 0.2   # fraction of the training set used as labeled data (e.g. 0.1 = 10%)

print("\n".join([
    f"Using Device: {DEVICE}",
    f"Labeled Ratio: {LABELED_RATIO}",
    # 500 = training images per Tiny ImageNet class
    f"Labeled samples per class: {int(500 * LABELED_RATIO)}",
    f"Unlabeled batch size: {MU * BATCH_SIZE}",
    f"Number of projectors: {NUM_PROJECTORS}",
    f"Epochs: {EPOCHS}",
    f"Initial LR: {LR}",
    f"Confidence Threshold: {THRESHOLD}",
]))


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning,
    trading some speed for reproducibility.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(SEED)

## 2. Dataset Download and Preparation

In [None]:
def download_and_extract_tiny_imagenet(url, target_dir):
    """Download and unpack the dataset archive unless *target_dir* already exists.

    The archive is downloaded into the parent directory of *target_dir* and
    extracted there. It must contain a top-level folder named
    ``os.path.basename(target_dir)``. The archive file is deleted afterwards,
    whether or not extraction succeeded. Only the standard library is used
    (no ``wget``/``unzip`` binaries).

    Args:
        url: URL of the zip archive (any scheme ``urllib`` supports).
        target_dir: directory the extracted dataset is expected to live in.

    Returns:
        True if *target_dir* exists afterwards (already present or freshly
        extracted), False if downloading/extracting failed.
    """
    import urllib.request
    import zipfile

    if os.path.exists(target_dir):
        print("Tiny ImageNet dataset found.")
        return True

    parent_dir = os.path.dirname(os.path.abspath(target_dir))
    zip_file = os.path.join(parent_dir, os.path.basename(url))
    print("Tiny ImageNet dataset not found. Downloading...")
    try:
        os.makedirs(parent_dir, exist_ok=True)
        urllib.request.urlretrieve(url, zip_file)
        print("Download complete. Extracting...")
        with zipfile.ZipFile(zip_file) as archive:
            archive.extractall(parent_dir)
        # The archive's top-level folder must match target_dir, otherwise the
        # caller would look in the wrong place.
        if not os.path.isdir(target_dir):
            raise FileNotFoundError(
                f"archive did not contain a '{os.path.basename(target_dir)}' folder"
            )
        print(f"Extraction complete. Dataset in {target_dir}")
    except Exception as e:
        print(f"Error during download/extraction: {e}")
        print("Please download and extract Tiny ImageNet manually from the URL:")
        print(url)
        print(f"And place the '{os.path.basename(target_dir)}' folder in {parent_dir}.")
        return False
    finally:
        if os.path.exists(zip_file):
            os.remove(zip_file)  # Clean up zip file
    return True

# Function to create validation folder structure for ImageFolder
def create_val_folder_structure(data_dir):
    """Move Tiny ImageNet's flat ``val/images`` files into per-class sub-folders.

    ``ImageFolder`` expects ``val/<class_id>/<image>``. The annotations file
    ``val/val_annotations.txt`` is read as tab-separated lines, with the image file
    name in column 0 and the class id in column 1. Each listed image still
    present in ``val/images`` is moved into ``val/<class_id>/``. ``val/images`` is
    removed once empty.

    Safe to re-run: it does nothing once ``val/images`` is gone. If a previous
    run was interrupted part-way, the images still left are moved. Existing
    class folders are not treated as "already done", because any file left in
    ``val/images`` would otherwise become a spurious ``images`` class for
    ``ImageFolder``.

    Args:
        data_dir: dataset root containing the ``val`` directory.
    """
    val_dir = os.path.join(data_dir, 'val')
    val_img_dir = os.path.join(val_dir, 'images')
    val_annotations_file = os.path.join(val_dir, 'val_annotations.txt')

    # Nothing left to move: already restructured (or the source is missing).
    if not os.path.isdir(val_img_dir):
        print("Validation images folder doesn't exist, structure likely already created.")
        return

    if not os.path.exists(val_annotations_file):
        print(f"Validation annotations file not found at {val_annotations_file}. Cannot restructure validation set.")
        return

    print("Restructuring validation folder...")
    # class_id -> list of image file names
    val_data = {}
    with open(val_annotations_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                img_file, class_id = parts[0], parts[1]
                val_data.setdefault(class_id, []).append(img_file)

    # Create class directories and move images
    for class_id, img_files in tqdm(val_data.items(), desc="Moving val images"):
        class_dir = os.path.join(val_dir, class_id)
        os.makedirs(class_dir, exist_ok=True)
        for img_file in img_files:
            src_path = os.path.join(val_img_dir, img_file)
            # Files already moved (e.g. by an interrupted earlier run) are skipped.
            if os.path.exists(src_path):
                shutil.move(src_path, os.path.join(class_dir, img_file))

    leftovers = os.listdir(val_img_dir)
    if leftovers:
        print(f"Warning: {len(leftovers)} file(s) in {val_img_dir} are not listed in "
              f"{val_annotations_file}; ImageFolder will treat 'images' as an extra class.")
    else:
        try:
            os.rmdir(val_img_dir)
        except OSError as e:
            print(f"Error removing {val_img_dir}: {e}")
    # val_annotations.txt is kept for reference.
    print("Validation folder restructuring complete.")


# --- Download and Prepare ---
# Failures stop the script with `raise SystemExit(1)`. The builtin `exit()` is an
# interactive helper installed by the `site` module, and when called without an
# argument it reports success (status 0).
if not download_and_extract_tiny_imagenet(TINY_IMAGENET_URL, DATA_DIR):
    print("Exiting due to dataset issues.")
    raise SystemExit(1)

create_val_folder_structure(DATA_DIR)
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VAL_DIR = os.path.join(DATA_DIR, 'val')  # Now structured as val/<class_id>/<image>
# Check that both directories exist after setup
if not os.path.exists(TRAIN_DIR) or not os.path.exists(VAL_DIR):
    print(f"Error: Training ({TRAIN_DIR}) or Validation ({VAL_DIR}) directory not found after setup.")
    raise SystemExit(1)


## 3. Data Augmentations and Datasets

In [None]:
from PIL import Image # Import PIL Image

# ImageNet per-channel mean / std used to normalise every input tensor.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def _random_crop_and_flip():
    """Return fresh geometric augmentation ops shared by the weak and strong pipelines."""
    return [
        transforms.RandomResizedCrop(64, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]


# Weak view: crop + flip only (supervised loss, teacher pseudo-labels, contrastive keys).
transform_weak = transforms.Compose(
    _random_crop_and_flip() + [transforms.ToTensor(), normalize]
)

# Strong view: the same geometry followed by RandAugment (consistency targets,
# contrastive queries). RandAugment's num_ops / magnitude may need tuning.
transform_strong = transforms.Compose(
    _random_crop_and_flip()
    + [transforms.RandAugment(num_ops=2, magnitude=10), transforms.ToTensor(), normalize]
)

# Evaluation: deterministic resize to 70 px (slightly larger than the crop),
# then a 64 px center crop.
transform_val = transforms.Compose([
    transforms.Resize(70),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    normalize,
])


# Custom Dataset wrapper for SSL (applies two transforms)
class SSLDataset(Dataset):
    """Unlabeled-data wrapper that returns a weak and a strong view of each image.

    Each item is ``(transform_weak(img), transform_strong(img), target)``, where
    ``(img, target)`` comes from ``base_dataset``.

    If fetching or transforming an item fails, ``None`` is returned so that a
    None-filtering collate function (such as ``safe_collate``) drops it. A
    placeholder image would instead enter the unsupervised and contrastive
    losses as if it were real data. This matches how ``SafeImageFolder``
    reports failures.
    """

    def __init__(self, base_dataset, transform_weak, transform_strong):
        self.base_dataset = base_dataset
        self.transform_weak = transform_weak
        self.transform_strong = transform_strong

    def __getitem__(self, index):
        try:
            # Unpacking also raises (TypeError) when the base dataset yields None.
            img, target = self.base_dataset[index]
            return self.transform_weak(img), self.transform_strong(img), target
        except Exception:
            # Best-effort loading: signal the failure to the collate function.
            return None

    def __len__(self):
        return len(self.base_dataset)

# Dataset for labeled data (only weak transform needed for supervised loss)
class LabeledDataset(Dataset):
    """Labeled-data wrapper that applies only the weak transform.

    Each item is ``(transform_weak(img), target)``, where ``(img, target)`` comes
    from ``base_dataset``.

    If fetching or transforming an item fails, ``None`` is returned so that a
    None-filtering collate function (such as ``safe_collate``) drops it. The
    alternative, a zero image labelled as class 0, would add wrong supervised
    labels. This matches how ``SafeImageFolder`` reports failures.
    """

    def __init__(self, base_dataset, transform_weak):
        self.base_dataset = base_dataset
        self.transform_weak = transform_weak

    def __getitem__(self, index):
        try:
            # Unpacking also raises (TypeError) when the base dataset yields None.
            img, target = self.base_dataset[index]
            return self.transform_weak(img), target
        except Exception:
            # Best-effort loading: signal the failure to the collate function.
            return None

    def __len__(self):
        return len(self.base_dataset)


# --- Create Base Datasets ---
# Wrap ImageFolder to handle potential loading errors (e.g., corrupted files)
class SafeImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that yields ``None`` for a sample it cannot load instead of raising.

    The ``None`` is meant to be discarded downstream, e.g. by a None-filtering
    collate function.
    """

    def __getitem__(self, index):
        try:
            item = super().__getitem__(index)
        except Exception:
            item = None
        return item

# Collate function to filter out None values returned by SafeImageFolder
def safe_collate(batch):
    """Collate a batch after dropping samples that failed to load (``None``).

    Args:
        batch: list of dataset items, some of which may be ``None``.

    Returns:
        ``default_collate`` applied to the remaining samples, or ``None`` when every
        sample is ``None``. The training and validation loops in this file skip
        ``None`` batches. Unlike a fixed-shape dummy, ``None`` works for any item
        structure: ``(img, target)`` for labeled/validation data and
        ``(img_w, img_s, target)`` for unlabeled data.
    """
    valid_items = [item for item in batch if item is not None]
    if not valid_items:
        print("Warning: Entire batch failed to load. Returning None.")
        return None
    return torch.utils.data.dataloader.default_collate(valid_items)

base_train_dataset = SafeImageFolder(TRAIN_DIR) # Basic, before SSL transforms
val_dataset = SafeImageFolder(VAL_DIR, transform=transform_val)

# ImageFolder assigns class indices per root directory from its sorted sub-folder
# names. Validation labels therefore only line up with training labels when both
# roots contain exactly the same class folders. For example, a leftover non-empty
# 'val/images' folder would become an extra class and shift every index. Fail
# loudly rather than report a meaningless validation accuracy.
if val_dataset.class_to_idx != base_train_dataset.class_to_idx:
    raise ValueError(
        "Validation class folders do not match training class folders "
        f"({len(val_dataset.class_to_idx)} vs {len(base_train_dataset.class_to_idx)} classes); "
        "validation labels would not line up with training labels."
    )

# --- Split Train into Labeled and Unlabeled ---
train_indices = list(range(len(base_train_dataset)))
train_targets = [s[1] for s in base_train_dataset.samples]

# Stratified split: each class keeps (approximately) the same labeled fraction
labeled_indices, unlabeled_indices = train_test_split(
    train_indices,
    test_size=1.0 - LABELED_RATIO,
    stratify=train_targets,
    random_state=SEED
)

print(f"Total training samples: {len(base_train_dataset)}")
print(f"Labeled samples: {len(labeled_indices)}")
print(f"Unlabeled samples: {len(unlabeled_indices)}")

# Verify labeled split distribution (optional)
labeled_targets = [train_targets[i] for i in labeled_indices]
labeled_counts = Counter(labeled_targets)
print(f"Labeled samples per class (sample): {list(labeled_counts.items())[:5]}")
unlabeled_targets = [train_targets[i] for i in unlabeled_indices]
unlabeled_counts = Counter(unlabeled_targets)


# --- Create Final SSL Datasets ---
labeled_subset = Subset(base_train_dataset, labeled_indices)
unlabeled_subset = Subset(base_train_dataset, unlabeled_indices)

labeled_train_dataset = LabeledDataset(labeled_subset, transform_weak)
unlabeled_train_dataset = SSLDataset(unlabeled_subset, transform_weak, transform_strong)

# --- Data Loaders ---
num_workers = 2 # Reduce workers if memory issues arise or errors persist
# Pinned host memory only speeds up host-to-GPU copies; without CUDA it just
# triggers a warning.
use_pin_memory = DEVICE.type == 'cuda'

labeled_loader = DataLoader(
    labeled_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=use_pin_memory,
    drop_last=True, # Drop last incomplete batch
    collate_fn=safe_collate # Drops samples that failed to load
)

unlabeled_loader = DataLoader(
    unlabeled_train_dataset,
    batch_size=BATCH_SIZE * MU,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=use_pin_memory,
    drop_last=True, # Crucial for consistent batch sizes
    collate_fn=safe_collate
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE * 2,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=use_pin_memory,
    collate_fn=safe_collate # SafeImageFolder returns None for unreadable images
)

# Estimate labeled class distribution prior (for simplified DA)
labeled_class_counts = np.array([labeled_counts.get(i, 0) for i in range(NUM_CLASSES)])
# Handle cases where a class might have 0 labeled samples (avoid division by zero)
total_labeled = labeled_class_counts.sum()
if total_labeled > 0:
    p_target = torch.tensor(labeled_class_counts / total_labeled, dtype=torch.float).to(DEVICE)
else:
    # If no labeled samples, use uniform distribution as a fallback
    print("Warning: No labeled samples found for prior estimation, using uniform.")
    p_target = torch.ones(NUM_CLASSES, dtype=torch.float).to(DEVICE) / NUM_CLASSES

p_target = p_target.detach() # Ensure it's not part of grad computation
print("Labeled class prior (p_target shape):", p_target.shape)

## 4. Model Definition (ResNet + EPASS Head)

In [None]:
class EPASSModel(nn.Module):
    """ResNet-18 backbone with a linear classifier and an ensemble of projection heads.

    The backbone's own ``fc`` layer is replaced by ``nn.Identity`` so that it
    outputs pooled features. ``self.fc`` maps those features to class logits.
    Each of the ``num_projectors`` MLP heads maps the same features to a
    ``projection_dim`` embedding, and the heads' outputs are averaged.

    Args:
        num_classes: number of classifier outputs.
        num_projectors: number of projection heads in the ensemble.
        projection_dim: output size of each projection head.
        pretrained: load torchvision's ImageNet-1k (V1) ResNet-18 weights.
    """

    def __init__(self, num_classes=NUM_CLASSES, num_projectors=NUM_PROJECTORS, projection_dim=PROJECTION_DIM, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        self.encoder = models.resnet18(weights=weights)
        feat_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()  # backbone now returns features, not logits

        # Classifier Head
        self.fc = nn.Linear(feat_dim, num_classes)

        # EPASS Projector Heads
        self.projectors = nn.ModuleList(
            self._make_projector(feat_dim, projection_dim) for _ in range(num_projectors)
        )
        self.num_projectors = num_projectors

    @staticmethod
    def _make_projector(in_dim, out_dim):
        """Build one head: Linear -> BatchNorm1d -> ReLU -> Linear (no BN on the output)."""
        return nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.BatchNorm1d(in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
        )

    def _ensemble_projection(self, features):
        """Average the outputs of all projection heads for *features*."""
        outputs = [head(features) for head in self.projectors]
        return torch.stack(outputs, dim=0).mean(dim=0)

    def forward(self, x, return_features=False, return_proj_only=False):
        """Run the model on a batch *x*.

        Returns:
            * ``return_proj_only=True``: the averaged projection only.
            * ``return_features=True``: ``(logits, features, avg_projection)``.
            * otherwise: ``logits``.
        """
        features = self.encoder(x)
        if return_proj_only:
            return self._ensemble_projection(features)

        logits = self.fc(features)
        if not return_features:
            return logits
        return logits, features, self._ensemble_projection(features)

# Function to update teacher model using EMA
@torch.no_grad()
def update_ema_variables(model, ema_model, alpha, global_step):
    """Move *ema_model*'s parameters towards *model*'s by exponential moving average.

    For every parameter pair: ``ema = alpha * ema + (1 - alpha) * param``.
    Only parameters are averaged. Buffers of *ema_model* (e.g. BatchNorm running
    statistics) are left untouched.

    Args:
        model: source network (student).
        ema_model: network updated in place (teacher).
        alpha: decay factor; larger values keep more of the previous EMA value.
            Previously this argument was ignored and overwritten with the global
            ``EMA_DECAY``.
        global_step: current optimisation step. Accepted for API compatibility
            (e.g. an EMA warm-up schedule) but currently unused.
    """
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(alpha).add_(param, alpha=1 - alpha)

# --- Instantiate Student and Teacher Models ---
student_model = EPASSModel(num_classes=NUM_CLASSES, num_projectors=NUM_PROJECTORS, projection_dim=PROJECTION_DIM, pretrained=True).to(DEVICE)
# The teacher's state is overwritten from the student right below, so there is no
# need to load the pretrained checkpoint a second time.
teacher_model = EPASSModel(num_classes=NUM_CLASSES, num_projectors=NUM_PROJECTORS, projection_dim=PROJECTION_DIM, pretrained=False).to(DEVICE)

# Start the teacher as an exact copy of the student. load_state_dict copies the
# parameters AND the buffers (BatchNorm running statistics); a parameter-only copy
# would leave the buffers out of sync.
teacher_model.load_state_dict(student_model.state_dict())
# The teacher is updated only via EMA, never by the optimizer.
teacher_model.requires_grad_(False)

## 5. Loss Functions

In [None]:
# Supervised Loss
criterion_s = nn.CrossEntropyLoss()

# Unsupervised Loss (Consistency Regularization); per-sample ('none') so a
# confidence mask can be applied before averaging.
criterion_u = nn.CrossEntropyLoss(reduction='none')

# Contrastive Loss (InfoNCE style, student-strong vs teacher-weak)
def contrastive_loss(proj_q_strong, proj_k_weak_detached, temp=TEMPERATURE):
    """In-batch InfoNCE loss between query and key projections.

    Row ``i`` of the queries forms the positive pair with row ``i`` of the keys.
    Every other key in the batch serves as a negative.

    Args:
        proj_q_strong: student projections of the strongly augmented views, shape (N, dim).
        proj_k_weak_detached: teacher projections of the weakly augmented views,
            shape (N, dim). They are detached here as well, so no gradient can
            reach the key branch.
        temp: temperature dividing the cosine similarities.

    Returns:
        Scalar cross-entropy loss averaged over the N queries.
    """
    proj_q = F.normalize(proj_q_strong, dim=1)
    proj_k = F.normalize(proj_k_weak_detached.detach(), dim=1)

    # Dot products of L2-normalised vectors = cosine similarities, shape (N, N).
    logits = torch.mm(proj_q, proj_k.T) / temp

    # Positives lie on the diagonal. Create the labels on the inputs' device
    # rather than a global one.
    labels = torch.arange(proj_q.size(0), device=proj_q.device)
    return F.cross_entropy(logits, labels)


## 6. Training Loop

In [None]:
optimizer = optim.SGD(
    student_model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

# Cosine LR schedule stepped once per optimisation step, so T_max counts iterations.
steps_per_epoch = min(len(labeled_loader), len(unlabeled_loader))
if steps_per_epoch:
    total_steps = EPOCHS * steps_per_epoch
else:
    # Loaders may be empty if data loading failed. T_max is kept non-zero to avoid
    # division by zero, but training will not proceed correctly in this case.
    print("Error: DataLoaders have zero length. Cannot calculate total_steps.")
    total_steps = 1

print(f"Scheduler total steps (T_max): {total_steps}")
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0)

# Per-epoch metric history
train_losses_s, train_losses_u, train_losses_c, train_losses_total = [], [], [], []
val_accuracies = []
mask_ratios = []  # fraction of unlabeled samples passing the confidence threshold
best_val_acc = 0.0
start_epoch = 0
global_step = 0


print("Starting Training...")

for epoch in range(start_epoch, EPOCHS):
    start_time = time.time()
    student_model.train()
    # The teacher also stays in train mode. Its BatchNorm layers therefore normalise
    # with the current batch's statistics and update their own running statistics
    # during the gradient-free forward pass. Its parameters change only via EMA.
    teacher_model.train()

    running_loss_s = 0.0
    running_loss_u = 0.0
    running_loss_c = 0.0
    running_loss_total = 0.0
    running_mask_ratio = 0.0
    actual_batches_processed = 0  # Count batches that didn't fail

    # Check steps_per_epoch again in case loaders are empty now
    current_steps_per_epoch = min(len(labeled_loader), len(unlabeled_loader))
    if current_steps_per_epoch == 0:
        print(f"Epoch {epoch+1}: Loaders are empty, skipping training epoch.")
        continue

    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)

    pbar = tqdm(range(current_steps_per_epoch), desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch", leave=False)

    for i in pbar:
        # --- Get Batches ---
        try:
            labeled_batch = next(labeled_iter)
            unlabeled_batch = next(unlabeled_iter)
        except StopIteration:
            # Should not happen when iterating range(current_steps_per_epoch); safeguard only.
            print(f"Warning: StopIteration caught unexpectedly at step {i}.")
            break

        # The collate function may return None when a whole batch failed to load.
        if labeled_batch is None or unlabeled_batch is None:
            print(f"Skipping step {i} due to None batch from loader.")
            continue

        img_w_l, target_l = labeled_batch
        img_w_u, img_s_u, _ = unlabeled_batch

        img_w_l, target_l = img_w_l.to(DEVICE), target_l.to(DEVICE)
        img_w_u, img_s_u = img_w_u.to(DEVICE), img_s_u.to(DEVICE)

        # Skip if batches somehow ended up empty after filtering
        if img_w_l.shape[0] == 0 or img_w_u.shape[0] == 0:
            continue

        optimizer.zero_grad()

        # --- Supervised Loss ---
        logits_l = student_model(img_w_l)
        loss_s = criterion_s(logits_l, target_l)

        # --- Teacher on weak views: pseudo-labels AND contrastive keys ---
        # A single forward pass yields both the logits and the averaged projection.
        # Two passes over the same batch would double the compute and update the
        # BatchNorm running statistics twice.
        with torch.no_grad():
            logits_u_w, _, proj_t_w_avg = teacher_model(img_w_u, return_features=True)
            probs_u_w = torch.softmax(logits_u_w, dim=1)
            max_probs, pseudo_labels = torch.max(probs_u_w, dim=1)
            mask = max_probs.ge(THRESHOLD).float()
            current_mask_ratio = mask.mean().item()

        # --- Student on strong views: consistency logits AND contrastive queries ---
        # One pass; both losses backpropagate through the same graph.
        logits_u_s, _, proj_s_s_avg = student_model(img_s_u, return_features=True)

        # --- Unsupervised Loss ---
        # Mean over confident samples only. The mask is 0/1, so clamping its sum to
        # at least 1 gives exactly 0 when no sample passes the threshold.
        loss_u_unmasked = criterion_u(logits_u_s, pseudo_labels)
        loss_u = (loss_u_unmasked * mask).sum() / mask.sum().clamp(min=1.0)

        # --- Contrastive Loss ---
        loss_c = contrastive_loss(proj_s_s_avg, proj_t_w_avg, temp=TEMPERATURE)

        # --- Total Loss ---
        total_loss = loss_s + LAMBDA_U * loss_u + LAMBDA_C * loss_c

        # --- Backward and Optimize ---
        total_loss.backward()
        optimizer.step()
        scheduler.step()  # per-iteration cosine schedule

        # --- Update EMA Teacher ---
        global_step += 1
        update_ema_variables(student_model, teacher_model, EMA_DECAY, global_step)

        # --- Record Losses ---
        loss_s_value = loss_s.item()
        loss_u_value = loss_u.item()
        loss_c_value = loss_c.item()
        running_loss_s += loss_s_value
        running_loss_u += loss_u_value
        running_loss_c += loss_c_value
        running_loss_total += total_loss.item()
        running_mask_ratio += current_mask_ratio
        actual_batches_processed += 1

        pbar.set_description(f"Epoch {epoch+1}/{EPOCHS} | Ls: {loss_s_value:.2f} | Lu: {loss_u_value:.2f} | Lc: {loss_c_value:.2f} | Mask: {current_mask_ratio:.2f}")

    pbar.close()

    # --- End of Epoch ---
    epoch_time = time.time() - start_time
    if actual_batches_processed == 0:
        print(f"Epoch [{epoch+1}/{EPOCHS}] Time: {epoch_time:.2f}s - No batches processed.")
        # NaN placeholders keep the per-epoch lists aligned with val_accuracies.
        train_losses_s.append(float('nan'))
        train_losses_u.append(float('nan'))
        train_losses_c.append(float('nan'))
        train_losses_total.append(float('nan'))
        mask_ratios.append(float('nan'))
    else:
        avg_loss_s = running_loss_s / actual_batches_processed
        avg_loss_u = running_loss_u / actual_batches_processed
        avg_loss_c = running_loss_c / actual_batches_processed
        avg_loss_total = running_loss_total / actual_batches_processed
        avg_mask_ratio = running_mask_ratio / actual_batches_processed

        train_losses_s.append(avg_loss_s)
        train_losses_u.append(avg_loss_u)
        train_losses_c.append(avg_loss_c)
        train_losses_total.append(avg_loss_total)
        mask_ratios.append(avg_mask_ratio)
        print(f"Epoch [{epoch+1}/{EPOCHS}] Time: {epoch_time:.2f}s")
        print(f"  Train Loss: Total={avg_loss_total:.4f} (S={avg_loss_s:.4f}, U={avg_loss_u:.4f}, C={avg_loss_c:.4f}) | Mask Ratio: {avg_mask_ratio:.2f}")

    # --- Validation (teacher, eval mode) ---
    teacher_model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    actual_val_batches = 0
    with torch.no_grad():
        for val_batch in tqdm(val_loader, desc="Validation", leave=False):
            if val_batch is None:
                continue
            images, labels = val_batch
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            if images.nelement() == 0:
                continue

            outputs = teacher_model(images)
            val_loss += criterion_s(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
            actual_val_batches += 1

    if val_total == 0:
        val_accuracy = 0
        avg_val_loss = float('inf')
        print("Warning: Validation set processing resulted in zero valid samples.")
    else:
        val_accuracy = 100 * val_correct / val_total
        # Average loss over batches that were actually processed
        avg_val_loss = val_loss / actual_val_batches if actual_val_batches > 0 else float('inf')

    val_accuracies.append(val_accuracy)
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

    # Update the best accuracy BEFORE building the checkpoint, so that
    # 'best_val_acc' in latest_checkpoint.pth includes the current epoch.
    is_new_best = val_accuracy > best_val_acc
    if is_new_best:
        best_val_acc = val_accuracy

    # Save checkpoint
    checkpoint = {
        'epoch': epoch + 1,
        'student_state_dict': student_model.state_dict(),
        'teacher_state_dict': teacher_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_acc': best_val_acc,
        'val_accuracies': val_accuracies,
        'train_losses_s': train_losses_s,
        'train_losses_u': train_losses_u,
        'train_losses_c': train_losses_c,
        'train_losses_total': train_losses_total,
        'mask_ratios': mask_ratios,
        'global_step': global_step
    }
    torch.save(checkpoint, 'latest_checkpoint.pth')

    # Save the best performing model
    if is_new_best:
        print(f"  *** New Best Validation Accuracy: {best_val_acc:.2f}% - Saving model... ***")
        torch.save(teacher_model.state_dict(), 'best_teacher_model.pth')

print("Training Finished!")
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")


## 7. Visualization (Losses and Accuracy)

In [None]:
def plot_metrics(epochs_range, train_losses_s, train_losses_u, train_losses_c, train_losses_total, val_accuracies, mask_ratios, best_acc):
    """Plot per-epoch training losses, validation accuracy and pseudo-label mask ratio.

    All metric lists are indexed in parallel with *epochs_range*. Epochs whose
    validation accuracy is NaN are dropped from every plot. Loss and mask
    entries may be NaN, which marks epochs without processed batches. Axis
    limits are computed from finite values only, so an all-NaN series no longer
    makes ``min``/``max`` fail on an empty sequence.

    Args:
        epochs_range: epoch numbers (indexable).
        train_losses_s, train_losses_u, train_losses_c, train_losses_total:
            per-epoch average supervised / unsupervised / contrastive / total loss.
        val_accuracies: per-epoch validation accuracy in percent.
        mask_ratios: per-epoch fraction of unlabeled samples above the threshold.
        best_acc: best validation accuracy, drawn as a reference line.
    """

    def _finite(values):
        """Return the non-NaN entries of *values* (used for axis limits)."""
        return [v for v in values if not np.isnan(v)]

    valid_indices = [i for i, acc in enumerate(val_accuracies) if not np.isnan(acc)]
    if not valid_indices:
        print("No valid data points to plot.")
        return

    epochs_plot = [epochs_range[i] for i in valid_indices]
    train_s_plot = [train_losses_s[i] for i in valid_indices]
    train_u_plot = [train_losses_u[i] for i in valid_indices]
    train_c_plot = [train_losses_c[i] for i in valid_indices]
    train_total_plot = [train_losses_total[i] for i in valid_indices]
    val_acc_plot = [val_accuracies[i] for i in valid_indices]
    mask_ratios_plot = [mask_ratios[i] for i in valid_indices]

    # 'seaborn-v0_8-*' style names exist only from Matplotlib 3.6 on; otherwise
    # keep the default style instead of raising.
    style_name = 'seaborn-v0_8-whitegrid'
    if style_name in plt.style.available:
        plt.style.use(style_name)
    _, axes = plt.subplots(1, 3, figsize=(22, 6))

    # Plot Losses
    axes[0].plot(epochs_plot, train_s_plot, label='Supervised Loss (Ls)', marker='.', alpha=0.7)
    axes[0].plot(epochs_plot, train_u_plot, label='Unsupervised Loss (Lu)', marker='.', alpha=0.7)
    axes[0].plot(epochs_plot, train_c_plot, label='Contrastive Loss (Lc)', marker='.', alpha=0.7)
    axes[0].plot(epochs_plot, train_total_plot, label='Total Training Loss', marker='o', linewidth=2)
    axes[0].set_title('Training Losses per Epoch')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    # Zoom the y-axis around the total loss (capped at 10) when finite values exist.
    finite_losses = _finite(train_total_plot)
    if finite_losses:
        min_loss_display = max(0, min(finite_losses) - 0.5)
        max_loss_display = min(max(finite_losses) + 0.5, 10)
        if max_loss_display > min_loss_display:
            axes[0].set_ylim([min_loss_display, max_loss_display])
    axes[0].legend()
    axes[0].grid(True)

    # Plot Accuracy
    axes[1].plot(epochs_plot, val_acc_plot, label='Validation Accuracy', marker='o', color='crimson')
    axes[1].axhline(y=best_acc, color='green', linestyle='--', label=f'Best Accuracy ({best_acc:.2f}%)')
    axes[1].set_title('Validation Accuracy per Epoch')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].legend()
    axes[1].grid(True)
    finite_accs = _finite(val_acc_plot)
    if finite_accs:
        min_acc_display = max(0, min(finite_accs) - 5)
        max_acc_display = min(100, max(finite_accs) + 5)
        if max_acc_display > min_acc_display:
            axes[1].set_ylim([min_acc_display, max_acc_display])

    # Plot Mask Ratio
    axes[2].plot(epochs_plot, mask_ratios_plot, label='Mask Ratio (Avg % Used)', marker='x', color='purple')
    axes[2].set_title('Avg. Mask Ratio per Epoch')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Ratio')
    axes[2].set_ylim([0, 1])  # Ratio is between 0 and 1
    axes[2].legend()
    axes[2].grid(True)

    plt.tight_layout()
    plt.show()

    # Interpretation for Overfitting/Underfitting:
    print("\n--- Overfitting/Underfitting Analysis ---")
    print("Observe the plots:")
    print("- Underfitting: Training Loss high, Validation Accuracy low/plateaued early.")
    print("- Good Fit: Training Loss decreases, Validation Accuracy increases and plateaus high.")
    print("- Overfitting: Training Loss decreasing, Validation Accuracy plateaued/decreasing.")
    print("- Mask Ratio: Shows the fraction of unlabeled data meeting the confidence threshold. Should ideally increase over time.")
    print("-----------------------------------------")

# --- Plot the results ---
# Plot only when training produced at least one validation entry.
if EPOCHS <= 0 or not val_accuracies:
    print("Not enough data to plot metrics. Did training run?")
else:
    full_epochs_range = range(1, len(val_accuracies) + 1)
    plot_metrics(
        full_epochs_range,
        train_losses_s,
        train_losses_u,
        train_losses_c,
        train_losses_total,
        val_accuracies,
        mask_ratios,
        best_val_acc,
    )
