# EPASS Implementation - RESUME Training

This notebook **resumes** the EPASS training from a previously saved checkpoint (`epass_tinyimagenet_best.pth`).

**Goal:** Continue training for an additional number of epochs.

**Assumptions:**
*   The `epass_tinyimagenet_best.pth` file exists in the current directory and contains the state dictionary of the best student model from the previous run.
*   The configuration parameters (architecture, batch sizes, learning rate, etc.) are kept consistent with the initial run.
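*   The data has already been downloaded and the validation set reorganized into `tiny-imagenet-200/val/organized` (one sub-folder per class, as `ImageFolder` expects).
*   Only the student weights are restored from the checkpoint. The EMA teacher is re-initialized from those weights, the SGD optimizer (including its momentum state) is created fresh, and the LR scheduler is fast-forwarded to `start_epoch`.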

## 1. Setup and Imports

In [None]:
# Import dependencies
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime as dt
import os
import random
import math
import copy
import zipfile
from tqdm.notebook import tqdm
import gc # Garbage collector

import torch
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.utils import make_grid
from torchvision import models, datasets
from torchvision import transforms as T

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: seed every RNG this notebook draws from (Python, NumPy, PyTorch).
# The exact batch order is still not expected to replay the original run.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng

# Ask cuDNN for deterministic kernels and turn off its benchmark-based algorithm search.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## 2. Configuration (MUST match initial run, except for epochs)

In [None]:
# Device configuration: use the GPU whenever PyTorch can see one.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# --- Parameters from the initial run (KEEP THESE THE SAME) ---
# Data
DATA_DIR = 'tiny-imagenet-200'
NUM_CLASSES = 200
IMG_SIZE = 64
INPUT_SIZE = 64
labeled_ratio = 0.1            # fraction of the training set treated as labeled

# Batch composition: every step mixes 1 part labeled data with `mu` parts unlabeled data.
batch_size = 128
mu = 3
labeled_bs = batch_size // (mu + 1)
unlabeled_bs = batch_size - labeled_bs

# EPASS / semi-supervised hyper-parameters
confidence_threshold = 0.95    # minimum teacher softmax confidence to accept a pseudo-label
num_projectors = 3             # projection heads averaged into the ensembled embedding
projection_dim = 128
ema_decay = 0.999              # upper bound on the teacher EMA decay
lambda_u = 1.0                 # weight of the pseudo-label loss
lambda_c = 0.1                 # weight of the contrastive loss
contrastive_temp = 0.1

# Optimizer
learning_rate = 0.03           # Initial LR (scheduler will adjust)
momentum = 0.9
weight_decay = 5e-4

# --- Parameters for Resumed Run ---
start_epoch = 35                                 # epoch number the previous run finished at
additional_epochs = 35                           # how many MORE epochs to run
total_epochs = start_epoch + additional_epochs   # final target epoch number
checkpoint_path = 'epass_tinyimagenet_best.pth'  # saved student state_dict
previous_best_accuracy = 37.16                   # best validation accuracy (%) of the previous run

# Run summary, printed one item per line in a fixed order.
print(
    f"Using device: {device}",
    f"Resuming training from epoch {start_epoch}",
    f"Running for {additional_epochs} additional epochs (up to epoch {total_epochs})",
    f"Loading checkpoint from: {checkpoint_path}",
    f"Previous best accuracy: {previous_best_accuracy}%",
    f"Total Batch Size: {batch_size}",
    f"Labeled Batch Size: {labeled_bs}",
    f"Unlabeled Batch Size: {unlabeled_bs}",
    sep="\n",
)

## 3. Data Download and Preparation
(This should already have been done; run this cell to confirm that the expected paths exist.)

In [None]:
# Ensure data directory exists (it is created by the initial notebook's download cell).
if not os.path.exists(DATA_DIR):
    raise FileNotFoundError(f"Data directory '{DATA_DIR}' not found. Please run the initial data download cell first.")
print("Tiny ImageNet directory exists.")

# Define training and validation data paths
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VALID_DIR = os.path.join(DATA_DIR, 'val')
reorganized_val_dir = os.path.join(VALID_DIR, 'organized')

# The validation images must already be in ImageFolder layout (one sub-directory per
# class), as produced by the initial notebook. This notebook does not contain the
# reorganization code, so a missing directory is reported as a hard error up front
# instead of pretending to reorganize and then failing.
if not os.path.isdir(reorganized_val_dir):
    raise FileNotFoundError(
        f"Reorganized validation directory '{reorganized_val_dir}' not found. "
        "Run the validation-set reorganization code from the initial notebook first."
    )
print("Using previously reorganized validation set:", reorganized_val_dir)
VALID_DIR_LOADER = reorganized_val_dir

# Class names are not needed to resume training: labels are the ImageFolder
# sub-directory indices used by the datasets below.

## 4. Augmentations and Datasets
(Recreate datasets and dataloaders)

In [None]:
# ImageNet normalization values (per-channel mean / std).
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def _random_crop_flip():
    """Return a fresh [RandomResizedCrop, RandomHorizontalFlip] pair for the training views."""
    return [
        T.RandomResizedCrop(INPUT_SIZE, scale=(0.2, 1.0)),
        T.RandomHorizontalFlip(),
    ]


# Weak view: geometric augmentation only.
transform_weak = T.Compose(_random_crop_flip() + [T.ToTensor(), normalize])

# Strong view: the same geometry followed by two RandAugment ops at magnitude 10.
transform_strong = T.Compose(
    _random_crop_flip() + [T.RandAugment(num_ops=2, magnitude=10), T.ToTensor(), normalize]
)

# Validation: no augmentation, only tensor conversion and normalization.
transform_val = T.Compose([T.ToTensor(), normalize])

# --- Custom Datasets --- (Same classes as before)
class TinyImageNetLabeled(Dataset):
    """Labeled view of an ImageFolder tree restricted to a subset of sample indices.

    Args:
        root: ImageFolder-style root directory (one sub-directory per class).
        indices: positions into ``ImageFolder(root).samples`` to keep.
        transform: callable applied to each loaded image, or ``None`` for no transform.

    Each item is an ``(image, class_index)`` tuple.
    """

    def __init__(self, root, indices, transform):
        folder = datasets.ImageFolder(root)
        self.base_path = root
        self.samples = []
        for sample_idx in indices:
            image_path = folder.samples[sample_idx][0]
            self.samples.append((image_path, folder.targets[sample_idx]))
        self.transform = transform
        self.loader = folder.loader

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, label = self.samples[idx]
        image = self.loader(image_path)
        if self.transform is None:
            return image, label
        return self.transform(image), label

class TinyImageNetUnlabeled(Dataset):
    """Unlabeled view of an ImageFolder tree that yields a (weak, strong) augmentation pair.

    Args:
        root: ImageFolder-style root directory.
        indices: positions into ``ImageFolder(root).samples`` to keep (labels are discarded).
        transform_weak: callable producing the weak view.
        transform_strong: callable producing the strong view.
    """

    def __init__(self, root, indices, transform_weak, transform_strong):
        folder = datasets.ImageFolder(root)
        self.base_path = root
        self.samples = [folder.samples[sample_idx][0] for sample_idx in indices]
        self.transform_weak = transform_weak
        self.transform_strong = transform_strong
        self.loader = folder.loader

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image = self.loader(self.samples[idx])
        # Both views come from the same decoded image; the weak transform runs first.
        return self.transform_weak(image), self.transform_strong(image)

# --- Recreate Labeled/Unlabeled Split ---
# The per-class shuffles below are preceded by np.random.seed(seed), so the split is
# deterministic for a given seed, NUM_CLASSES and ImageFolder ordering. It matches the
# initial run's split only if that run used this same procedure and seed.
# NOTE(review): confirm against the initial notebook, or persist the indices instead.
print("Recreating train/validation datasets...")
base_train_dataset_info = datasets.ImageFolder(TRAIN_DIR)
val_dataset = datasets.ImageFolder(VALID_DIR_LOADER, transform=transform_val)

num_train_samples = len(base_train_dataset_info)
num_labeled_samples = int(labeled_ratio * num_train_samples)
print(f"Total training samples: {num_train_samples}")
print(f"Using {num_labeled_samples} labeled samples ({labeled_ratio*100}%)")

targets = np.array(base_train_dataset_info.targets)
labeled_indices, unlabeled_indices = [], []

# Stratified split: the same number of labeled images from every class.
samples_per_class = num_labeled_samples // NUM_CLASSES
if samples_per_class == 0:
    raise ValueError(f"labeled_ratio ({labeled_ratio}) is too small for {NUM_CLASSES} classes.")
print(f"Aiming for {samples_per_class} labeled samples per class.")

np.random.seed(seed)  # Ensure same split if seed is same
for class_id in range(NUM_CLASSES):
    members = np.flatnonzero(targets == class_id)
    np.random.shuffle(members)
    labeled_indices.extend(members[:samples_per_class])
    unlabeled_indices.extend(members[samples_per_class:])

num_labeled_actual, num_unlabeled_actual = len(labeled_indices), len(unlabeled_indices)
print(f"Actual labeled samples: {num_labeled_actual}")
print(f"Actual unlabeled samples: {num_unlabeled_actual}")

print("Creating labeled/unlabeled datasets...")
labeled_dataset = TinyImageNetLabeled(TRAIN_DIR, labeled_indices, transform=transform_weak)
unlabeled_dataset = TinyImageNetUnlabeled(
    TRAIN_DIR,
    unlabeled_indices,
    transform_weak=transform_weak,
    transform_strong=transform_strong,
)

# The full-folder scan is no longer needed once the subsets hold their own paths.
del base_train_dataset_info
gc.collect()

# --- Create DataLoaders ---
print("Creating DataLoaders...")
# One epoch = one pass over the unlabeled set in full unlabeled batches.
num_iterations = len(unlabeled_dataset) // unlabeled_bs
if num_iterations == 0:
    raise ValueError(f"Unlabeled batch size ({unlabeled_bs}) too large.")

# Worker processes and pinned memory are only used when a GPU is available.
dataloader_num_workers = 2 if use_cuda else 0
shared_loader_opts = {'num_workers': dataloader_num_workers, 'pin_memory': use_cuda}

# Training loaders reshuffle every epoch and drop the ragged final batch.
labeled_loader = DataLoader(labeled_dataset, batch_size=labeled_bs,
                            shuffle=True, drop_last=True, **shared_loader_opts)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=unlabeled_bs,
                              shuffle=True, drop_last=True, **shared_loader_opts)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, **shared_loader_opts)

print(f"Labeled loader batches per epoch: {len(labeled_loader)}")
print(f"Unlabeled loader batches per epoch: {len(unlabeled_loader)}")
print(f"Validation loader batches: {len(val_loader)}")
print(f"Number of iterations per epoch: {num_iterations}")

## 5. Model Definition (Encoder + EPASS Projector + Classifier)
(Identical definition as before)

In [None]:
class MLPProjector(nn.Module):
    """Two-layer MLP projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    The layers live in ``self.block`` (an ``nn.Sequential``) in exactly this order, so the
    state_dict keys (``block.0.*``, ``block.1.*``, ``block.3.*``) stay compatible with
    saved checkpoints.

    Args:
        input_dim: size of the incoming feature vector.
        hidden_dim: width of the hidden layer.
        output_dim: size of the projected embedding.
    """

    def __init__(self, input_dim, hidden_dim=512, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.block(x)
        return projected

def create_modified_resnet18(pretrained=False):
    """Build a torchvision ResNet-18 with a small-image stem.

    The stem convolution is replaced by a 3x3, stride-1, padding-1 conv (no bias) and the
    initial max-pool is removed. The final ``fc`` layer is left in place;
    ``EPASS_Model`` replaces it.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet18`` (it was previously
            ignored and always treated as False). ``conv1`` is always re-created with
            random weights below, so pretrained weights only survive after the stem.

    Returns:
        The modified ResNet-18 module.
    """
    model = models.resnet18(pretrained=pretrained)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model

class EPASS_Model(nn.Module):
    """Backbone + linear classifier + an ensemble of MLP projection heads (EPASS).

    Args:
        backbone: feature extractor. If it has an ``fc`` attribute, ``fc.in_features`` is
            used as the feature dimension and ``fc`` is replaced by ``nn.Identity`` so the
            backbone returns the features that previously fed ``fc``. Otherwise a
            feature dimension of 512 is assumed.
        num_classes: classifier output size.
        num_projectors: number of ``MLPProjector`` heads in the ensemble.
        projection_dim: output size of each projector.
    """

    def __init__(self, backbone, num_classes, num_projectors, projection_dim):
        super().__init__()
        self.backbone = backbone
        has_fc = hasattr(backbone, 'fc')
        self.feature_dim = backbone.fc.in_features if has_fc else 512
        if has_fc:
            self.backbone.fc = nn.Identity()
        print(f"Backbone feature dimension: {self.feature_dim}")
        self.classifier = nn.Linear(self.feature_dim, num_classes)
        self.num_projectors = num_projectors
        self.projectors = nn.ModuleList(
            [MLPProjector(self.feature_dim, output_dim=projection_dim) for _ in range(num_projectors)]
        )

    def forward(self, x, return_features=False, return_projection=False):
        """Return logits, optionally followed by backbone features and/or the embedding.

        The embedding is the L2-normalised (along dim 1) mean of all projector outputs.
        Output order is ``logits[, features][, embedding]``; a bare logits tensor is
        returned when neither flag is set, otherwise a tuple.
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        outputs = [logits]
        if return_features:
            outputs.append(features)
        if return_projection:
            stacked = torch.stack([head(features) for head in self.projectors], dim=0)
            outputs.append(F.normalize(stacked.mean(dim=0), dim=1))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

## 6. Model Instantiation and Loading State

In [None]:
print("Creating student and teacher models...")
backbone_student = create_modified_resnet18(pretrained=False)  # architecture only
student_model = EPASS_Model(backbone_student, NUM_CLASSES, num_projectors, projection_dim).to(device)

backbone_teacher = create_modified_resnet18(pretrained=False)  # architecture only
teacher_model = EPASS_Model(backbone_teacher, NUM_CLASSES, num_projectors, projection_dim).to(device)

# --- Load the saved state dictionary ---
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}. Cannot resume training.")

print(f"Loading state from {checkpoint_path}...")
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# PyTorch >= 1.13 also accepts weights_only=True for plain state_dicts.
checkpoint = torch.load(checkpoint_path, map_location=device)
student_model.load_state_dict(checkpoint)  # Assumes saved file is just the state_dict
del checkpoint  # the weights now live in student_model; free the duplicate copy
print("Student model state loaded.")

# Initialize the teacher from the student's FULL state_dict, not just parameters().
# The teacher runs in eval() mode during training, so its BatchNorm running_mean /
# running_var buffers determine its outputs and must match the loaded weights.
print("Initializing teacher model from loaded student state...")
teacher_model.load_state_dict(student_model.state_dict())
teacher_model.requires_grad_(False)  # the teacher is updated only via EMA
print("Teacher model initialized.")

print("Models ready for resumed training.")

## 7. Loss Functions, Optimizer, and Scheduler (Re-initialize)

In [None]:
# Loss functions (same as the initial run).
# Supervised cross-entropy, averaged over the labeled batch.
criterion_s = nn.CrossEntropyLoss()
# Per-sample cross-entropy so the unlabeled loss can be masked by pseudo-label confidence.
criterion_u = nn.CrossEntropyLoss(reduction='none')


def contrastive_loss(emb_student_strong, emb_teacher_weak, temperature):
    """Mean negative, temperature-scaled row-wise dot product of paired embeddings.

    Only positive pairs are used (no negatives). When both inputs are L2-normalised
    (as ``EPASS_Model(..., return_projection=True)`` produces), the dot product is
    a cosine similarity.

    Returns:
        Scalar tensor: ``mean_i(-<s_i, t_i> / temperature)``.
    """
    agreement = (emb_student_strong * emb_teacher_weak).sum(dim=1)
    return (-agreement / temperature).mean()

# Optimizer, re-created for the student with its loaded weights. The checkpoint holds only
# model weights, so SGD's momentum state is rebuilt from scratch.
optimizer = optim.SGD(student_model.parameters(),
                      lr=learning_rate,
                      momentum=momentum,
                      weight_decay=weight_decay,
                      nesterov=True)

# Cosine schedule over the *whole* run (original + additional epochs), stepped once per
# iteration. NOTE(review): if the initial run used a different T_max (e.g. only its own
# epochs), the LR reached below is this longer schedule's value at start_epoch, not the
# LR the previous run ended with.
total_train_steps = num_iterations * total_epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_train_steps)

# --- Advance the scheduler ---
# Replay the per-iteration steps taken before start_epoch. Each scheduler.step() writes
# the new LR into optimizer.param_groups. Stepping without optimizer.step() only raises a
# call-order warning, which is silenced by the global warnings filter in the imports cell.
steps_to_advance = start_epoch * num_iterations
print(f"Advancing scheduler by {steps_to_advance} steps...")
for _ in tqdm(range(steps_to_advance), desc="Fast-forwarding scheduler"):
    scheduler.step()

# This LR will be used by the first optimizer.step() of the resumed run.
current_lr = optimizer.param_groups[0]['lr']
print(f"Scheduler advanced. Current LR for next step: {current_lr:.1e}")

# EMA update function for the teacher (parameters AND buffers).
@torch.no_grad()
def update_ema_variables(model, ema_model, alpha, global_step):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    The effective decay is ``min(1 - 1/(global_step + 1), alpha)``. Early steps therefore
    follow the student closely, and the decay ramps up to ``alpha``.

    Floating-point buffers (e.g. BatchNorm ``running_mean`` / ``running_var``) are averaged
    with the same decay. Non-floating buffers (e.g. ``num_batches_tracked``) are copied.
    Previously only parameters were averaged. A teacher used in ``eval()`` mode then kept
    its initial normalisation statistics for the entire run.

    Args:
        model: source (student) module.
        ema_model: target (teacher) module with the same architecture.
        alpha: maximum EMA decay.
        global_step: optimisation step counter, 0-based.
    """
    decay = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1 - decay)
    for ema_buf, buf in zip(ema_model.buffers(), model.buffers()):
        if ema_buf.is_floating_point():
            ema_buf.mul_(decay).add_(buf, alpha=1 - decay)
        else:
            ema_buf.copy_(buf)

## 8. Training Loop Function (Same as before)

In [None]:
def train_one_epoch(student_model, teacher_model, labeled_loader, unlabeled_loader, optimizer, scheduler, epoch, num_iterations, total_epochs):
    """Run one epoch of EPASS semi-supervised training.

    Each iteration draws one labeled batch (restarting the labeled loader when it runs
    out) and one unlabeled (weak, strong) batch. The teacher's softmax on the weak view
    gives pseudo-labels, which are kept where the max probability reaches
    ``confidence_threshold``. The student minimises
    ``loss_s + lambda_u * loss_u + lambda_c * loss_c``. The scheduler is stepped and the
    teacher EMA-updated after every optimizer step.

    Relies on module-level ``device``, ``criterion_s``, ``criterion_u``, ``contrastive_loss``,
    ``confidence_threshold``, ``contrastive_temp``, ``lambda_u``, ``lambda_c``, ``ema_decay``
    and ``update_ema_variables``.

    Args:
        epoch: 0-based epoch index; also used to derive the EMA global step.
        num_iterations: iterations per epoch.
        total_epochs: used only for progress-bar / summary display.

    Returns:
        (avg_loss, avg_mask_ratio). ``avg_loss`` is the unweighted sum of the per-sample
        loss totals divided by all processed samples. ``avg_mask_ratio`` is the mean
        confident-pseudo-label fraction over the iterations actually completed.
    """
    student_model.train()
    teacher_model.eval()

    total_loss_s, total_loss_u, total_loss_c = 0.0, 0.0, 0.0
    mask_ratio_acc = 0.0
    labeled_processed, unlabeled_processed = 0, 0
    completed_iterations = 0  # < num_iterations if the unlabeled loader runs out early

    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)

    pbar = tqdm(range(num_iterations), desc=f"Epoch {epoch+1}/{total_epochs}")

    for i in pbar:
        # Restart the labeled loader whenever it is exhausted so every iteration
        # still gets a labeled batch.
        try:
            images_l, targets_l = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled_loader)
            images_l, targets_l = next(labeled_iter)

        try:
            images_uw, images_us = next(unlabeled_iter)
        except StopIteration:
            print("Warning: Unlabeled loader exhausted unexpectedly.")
            break

        images_l, targets_l = images_l.to(device), targets_l.to(device)
        images_uw, images_us = images_uw.to(device), images_us.to(device)
        bs_l = images_l.shape[0]
        bs_u = images_uw.shape[0]

        # Supervised loss on the labeled batch.
        logits_l = student_model(images_l)
        loss_s = criterion_s(logits_l, targets_l)

        # Student sees the strong view; the teacher (no grad) sees the weak view.
        logits_us, emb_us = student_model(images_us, return_projection=True)
        with torch.no_grad():
            logits_uw, emb_uw = teacher_model(images_uw, return_projection=True)

        # Confidence-masked pseudo-label loss; the epsilon avoids 0/0 when no sample
        # passes the threshold.
        pseudo_label = torch.softmax(logits_uw, dim=-1)
        max_probs, targets_u = torch.max(pseudo_label, dim=-1)
        mask = max_probs.ge(confidence_threshold).float()
        loss_u_all = criterion_u(logits_us, targets_u)
        loss_u = torch.sum(loss_u_all * mask) / (mask.sum() + 1e-6)
        mask_ratio = mask.mean().item()

        loss_c = contrastive_loss(emb_us, emb_uw, contrastive_temp)

        # Leave a non-finite unlabeled loss out of the objective instead of propagating
        # NaN/Inf into the weights. A finite 0.0 adds nothing to the value, so including
        # it gives the same objective.
        if torch.isfinite(loss_u):
            loss_u_item = loss_u.item()
            total_batch_loss = loss_s + lambda_u * loss_u + lambda_c * loss_c
        else:
            loss_u_item = 0.0
            total_batch_loss = loss_s + lambda_c * loss_c

        optimizer.zero_grad()
        total_batch_loss.backward()
        optimizer.step()
        scheduler.step()  # per-iteration LR schedule

        # Global step counted from epoch 0, so the EMA decay ramp continues across resumes.
        global_step = epoch * num_iterations + i
        update_ema_variables(student_model, teacher_model, ema_decay, global_step)

        loss_s_item = loss_s.item()
        loss_c_item = loss_c.item()
        total_loss_s += loss_s_item * bs_l
        total_loss_u += loss_u_item * bs_u
        total_loss_c += loss_c_item * bs_u
        mask_ratio_acc += mask_ratio
        labeled_processed += bs_l
        unlabeled_processed += bs_u
        completed_iterations += 1

        pbar.set_postfix({
            'Ls': f'{loss_s_item:.4f}', 'Lu': f'{loss_u_item:.4f}', 'Lc': f'{loss_c_item:.4f}',
            'Mask': f'{mask_ratio:.2f}', 'LR': f'{optimizer.param_groups[0]["lr"]:.1e}'
        })

    all_processed = labeled_processed + unlabeled_processed
    avg_loss_s = total_loss_s / labeled_processed if labeled_processed > 0 else 0
    avg_loss_u = total_loss_u / unlabeled_processed if unlabeled_processed > 0 else 0
    avg_loss_c = total_loss_c / unlabeled_processed if unlabeled_processed > 0 else 0
    avg_loss = (total_loss_s + total_loss_u + total_loss_c) / all_processed if all_processed > 0 else 0
    # Average over iterations that actually ran, not the nominal num_iterations.
    avg_mask_ratio = mask_ratio_acc / completed_iterations if completed_iterations > 0 else 0

    print(f"Epoch {epoch+1}/{total_epochs} Summary: Avg Loss: {avg_loss:.4f}, Ls: {avg_loss_s:.4f}, Lu: {avg_loss_u:.4f}, Lc: {avg_loss_c:.4f}, Mask Ratio: {avg_mask_ratio:.4f}")
    return avg_loss, avg_mask_ratio

## 9. Evaluation Function (Same as before)

In [None]:
def evaluate(model, val_loader):
    """Compute mean cross-entropy and top-1 accuracy (%) of ``model`` on ``val_loader``.

    Switches the model to eval mode and runs without gradients; batches are moved to
    the module-level ``device``.

    Returns:
        (avg_loss, accuracy); both are 0 when the loader yields no samples.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        progress = tqdm(val_loader, desc="Evaluating")
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images, return_projection=False)
            n_batch = images.size(0)
            loss_sum += loss_fn(logits, labels).item() * n_batch
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
            progress.set_postfix({'Val Acc': f'{(100 * n_correct / n_seen):.2f}%'})

    avg_loss = loss_sum / n_seen if n_seen > 0 else 0
    accuracy = 100 * n_correct / n_seen if n_seen > 0 else 0
    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
    return avg_loss, accuracy

## 10. Run Resumed Training and Evaluation

In [None]:
# Per-epoch history for this resumed session only (epochs start_epoch+1 .. total_epochs).
train_losses, val_losses, val_accuracies, mask_ratios = [], [], [], []
best_val_accuracy = previous_best_accuracy  # must be beaten before the checkpoint is overwritten
best_epoch = start_epoch                    # the loaded checkpoint counts as the current best

start_time_resume = dt.now()
print(f"Resuming training at {start_time_resume}...")

if use_cuda:
    torch.cuda.empty_cache()
gc.collect()

for epoch in range(start_epoch, total_epochs):
    # total_epochs is only used by train_one_epoch for display; the EMA global step is
    # derived there from `epoch` and `num_iterations`.
    train_loss, mask_ratio = train_one_epoch(
        student_model, teacher_model, labeled_loader, unlabeled_loader,
        optimizer, scheduler, epoch, num_iterations, total_epochs,
    )
    train_losses.append(train_loss)
    mask_ratios.append(mask_ratio)

    val_loss, val_accuracy = evaluate(student_model, val_loader)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    if val_accuracy > best_val_accuracy:
        best_val_accuracy, best_epoch = val_accuracy, epoch + 1  # 1-based epoch number
        # Overwrites the checkpoint this notebook resumed from. Only the student's
        # weights are saved, not the EMA teacher or the optimizer.
        torch.save(student_model.state_dict(), checkpoint_path)
        print(f"*** Best model saved with accuracy: {best_val_accuracy:.2f}% (Epoch {best_epoch}) ***")

    # Release cached GPU memory every 5 epochs.
    if use_cuda and (epoch + 1) % 5 == 0:
        torch.cuda.empty_cache()
        gc.collect()

end_time_resume = dt.now()
print(f"\nResumed training finished at {end_time_resume}. Duration: {end_time_resume - start_time_resume}")
print(f"Overall Best Validation Accuracy: {best_val_accuracy:.2f}% (achieved at Epoch {best_epoch})")

## 11. Plotting Results (for the resumed part)

In [None]:
# Plot the resumed epochs only. The x-range comes from how many complete epochs were
# actually recorded (the shortest history list), so this cell still works if the
# training loop above was interrupted part-way.
n_recorded = min(len(train_losses), len(val_losses), len(val_accuracies), len(mask_ratios))
last_recorded_epoch = start_epoch + n_recorded
epochs_resumed_range = range(start_epoch + 1, last_recorded_epoch + 1)
epoch_span = f'Epochs {start_epoch+1}-{last_recorded_epoch}'

fig, (ax_loss, ax_acc, ax_mask) = plt.subplots(1, 3, figsize=(18, 5))

ax_loss.plot(epochs_resumed_range, train_losses[:n_recorded], label='Avg Training Loss (Resumed)')
ax_loss.plot(epochs_resumed_range, val_losses[:n_recorded], label='Validation Loss (Resumed)')
ax_loss.set(title=f'Losses ({epoch_span})', xlabel='Epoch', ylabel='Loss')

ax_acc.plot(epochs_resumed_range, val_accuracies[:n_recorded],
            label='Validation Accuracy (Resumed)', color='green')
ax_acc.axhline(y=previous_best_accuracy, color='r', linestyle='--',
               label=f'Previous Best ({previous_best_accuracy:.2f}%)')
ax_acc.set(title=f'Validation Accuracy ({epoch_span})', xlabel='Epoch', ylabel='Accuracy (%)')

ax_mask.plot(epochs_resumed_range, mask_ratios[:n_recorded], label='Mask Ratio (Resumed)', color='orange')
ax_mask.set(title=f'Mask Ratio ({epoch_span})', xlabel='Epoch', ylabel='Ratio', ylim=(0, 1))

for ax in (ax_loss, ax_acc, ax_mask):
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()