# JointMatch Reimagined: Tiny ImageNet SSL (Kaggle Optimized)

This notebook implements the JointMatch semi-supervised learning algorithm for image classification on the Tiny ImageNet dataset using a custom PyTorch training loop, optimized for Kaggle.

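**Core Idea:** Train two models that teach each other. Each model's confident pseudo-labels, selected with class-wise adaptive thresholds, supervise the *other* model. Samples on which the two models disagree are weighted by a separate `disagreement_weight`. Everything runs in a plain PyTorch training loop rather than an external training framework such as Ignite.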

**References:**
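*   **JointMatch Paper:** Zou, H. P., & Caragea, C. (2023). JointMatch: A Unified Approach for Diverse and Collaborative Pseudo-Labeling to Semi-Supervised Text Classification. *EMNLP 2023*. (The original method targets text classification; this notebook adapts it to image classification.)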

## 1. Setup & Installs

In [None]:
# Install necessary packages (EfficientNet and dataset downloader)
print("Installing dependencies...")
!pip install -q opendatasets efficientnet-pytorch tqdm
print("Dependencies installed.")

# --- Verification ---
try:
    import torch
    import torchvision
    print(f"PyTorch Version: {torch.__version__}")
    print(f"Torchvision Version: {torchvision.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA Version detected by PyTorch: {torch.version.cuda}")
        print(f"Device Name: {torch.cuda.get_device_name(0)}")
except ImportError as e:
    print(f"Error importing torch/torchvision: {e}. Ensure PyTorch is installed.")
except Exception as e:
     print(f"An error occurred during verification: {e}")

## 2. Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime as dt
import os
import random
import zipfile
import shutil
import warnings
import copy
from math import floor, ceil
from sklearn.model_selection import train_test_split # Still useful for potential splits
from itertools import cycle
import time
import glob
from tqdm.notebook import tqdm # Progress bar

import torch
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
from torchvision.utils import make_grid
from torchvision import models, datasets
from torchvision import transforms as T
from efficientnet_pytorch import EfficientNet

# Silence every Python warning for the rest of the session to keep notebook
# output readable. NOTE(review): this also hides deprecation warnings from
# torch/torchvision/efficientnet-pytorch; remove when debugging.
warnings.filterwarnings('ignore')

## 3. Configuration

In [None]:
# Basic Setup
SEED = 42
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# --- Data Parameters ---
DATA_DIR_KAGGLE = '/kaggle/input/tiny-imagenet-200/tiny-imagenet-200'
DATA_DIR_MANUAL = './tiny-imagenet-200'
MANUAL_DOWNLOAD = True # Set True to download, False if using Kaggle dataset input
DATA_DIR = DATA_DIR_MANUAL if MANUAL_DOWNLOAD else DATA_DIR_KAGGLE

NUM_CLASSES = 200
IMG_SIZE = 64
VAL_IMAGES_PER_CLASS = 50

# --- Semi-Supervised Learning Parameters ---
NUM_LABELED_PER_CLASS = 10
TOTAL_TRAIN_IMAGES_PER_CLASS = 500

# --- JointMatch Specific Hyperparameters ---
ema_decay = 0.999            # decay of the running class-distribution estimate
base_threshold = 0.95        # confidence threshold for the most frequent class
disagreement_weight = 0.7    # loss weight where the models disagree (agreement gets 1 - this)
unlabeled_loss_weight = 1.0  # multiplier on the cross-model consistency loss

# --- Training Hyperparameters ---
# Each step draws one labeled batch and one unlabeled batch. Together they make
# up `batch_size`, with unlabeled_ratio unlabeled samples per labeled sample.
batch_size = 64
unlabeled_ratio = 7
labeled_batch_size = batch_size // (unlabeled_ratio + 1)
unlabeled_batch_size = batch_size - labeled_batch_size
# Fail fast with a clear message instead of a later ZeroDivisionError
# (steps_per_epoch) or a DataLoader "batch_size should be positive" error.
if labeled_batch_size < 1 or unlabeled_batch_size < 1:
    raise ValueError(
        f"batch_size={batch_size} with unlabeled_ratio={unlabeled_ratio} gives "
        f"{labeled_batch_size} labeled / {unlabeled_batch_size} unlabeled samples "
        "per batch; both must be at least 1."
    )
print(f"Total Batch Size: {batch_size}")
print(f"  Labeled Batch Size: {labeled_batch_size}")
print(f"  Unlabeled Batch Size: {unlabeled_batch_size}")

lr = 3e-4
num_epochs = 50 # Increased epochs for better convergence potential
num_labeled_total = NUM_LABELED_PER_CLASS * NUM_CLASSES
# Define steps per epoch based on labeled data
steps_per_epoch = ceil(num_labeled_total / labeled_batch_size)
print(f"Steps per epoch (based on labeled data): {steps_per_epoch}")
# NOTE: not currently read by the training loop; kept for a future
# gradient-accumulation implementation.
gradient_accumulation_steps = 1

# --- Saving/Loading Parameters ---
KAGGLE_WORKING_DIR = "/kaggle/working/"
output_dir = os.path.join(KAGGLE_WORKING_DIR, "jointmatch_output")
checkpoint_dir = os.path.join(output_dir, "checkpoints")
best_model_path = os.path.join(output_dir, "best_model.pth")
save_every_epochs = 5 # How often to save periodic checkpoints

# --- Helper Functions ---
def set_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    Args:
        seed: Integer seed applied to every generator.
        deterministic: If True, additionally ask cuDNN for deterministic
            algorithms (can slow training). The default keeps the previous
            behaviour of only disabling cuDNN autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Query CUDA directly so the helper does not depend on a notebook global.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # covers the current device as well
        # Autotuning picks kernels by timing, which makes runs non-repeatable.
        torch.backends.cudnn.benchmark = False
        if deterministic:
            torch.backends.cudnn.deterministic = True

set_seed(SEED)

# Make sure the output folder and its checkpoint sub-folder exist (idempotent).
for _required_dir in (output_dir, checkpoint_dir):
    os.makedirs(_required_dir, exist_ok=True)

## 4. Data Loading & Preprocessing

In [None]:
# --- Download and Extract Data (if MANUAL_DOWNLOAD is True) ---
if MANUAL_DOWNLOAD:
    zip_file_path = './tiny-imagenet-200.zip'
    if not os.path.exists(DATA_DIR):
        if not os.path.exists(zip_file_path):
            print("Downloading Tiny ImageNet...")
            !wget -q http://cs231n.stanford.edu/tiny-imagenet-200.zip -O {zip_file_path}
            print("Download complete.")
        else:
            print(f"Zip file {zip_file_path} already exists.")

        print("Extracting Tiny ImageNet...")
        try:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall('.')
            print("Data extracted.")
        except zipfile.BadZipFile:
             print(f"Error: {zip_file_path} is corrupted or not a zip file. Please re-download.")
             raise
        except Exception as e:
            print(f"An error occurred during extraction: {e}")
            raise
    else:
        print(f"Tiny ImageNet directory '{DATA_DIR}' already exists.")
elif not os.path.exists(DATA_DIR):
     raise FileNotFoundError(f"Kaggle input directory '{DATA_DIR}' not found. Check dataset path.")
else:
    print(f"Using Tiny ImageNet data from Kaggle input: {DATA_DIR}")

TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VAL_DIR = os.path.join(DATA_DIR, 'val')

# --- Organize Validation Folder (Important for ImageFolder) ---
# The raw validation split keeps every image in val/images/ with its class
# listed in val/val_annotations.txt, while ImageFolder needs one sub-folder per
# class. When VAL_DIR is writable the files are moved in place. When it is not
# (e.g. a read-only dataset mount), moving would fail and ImageFolder would
# later see a single "images" class. In that case an organized copy is built
# under the working directory and VAL_DIR is pointed at it.
val_img_dir = os.path.join(VAL_DIR, 'images')
annotations_file = os.path.join(VAL_DIR, 'val_annotations.txt')
val_copy_dir = os.path.join(KAGGLE_WORKING_DIR, 'tiny-imagenet-200-val')


def _is_organized_val_dir(path):
    """Return True if *path* holds NUM_CLASSES class folders and no raw images/ or annotations file."""
    if not os.path.isdir(path):
        return False
    if os.path.exists(os.path.join(path, 'images')) or os.path.exists(os.path.join(path, 'val_annotations.txt')):
        return False
    subdirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
    return len(subdirs) == NUM_CLASSES


organized_val_dir_exists = _is_organized_val_dir(VAL_DIR)
if not organized_val_dir_exists and _is_organized_val_dir(val_copy_dir):
    # Reuse the organized copy built by a previous run.
    VAL_DIR = val_copy_dir
    organized_val_dir_exists = True

if organized_val_dir_exists:
    print("Validation folder appears already organized.")
elif os.path.exists(val_img_dir) and os.path.exists(annotations_file):
    in_place = os.access(VAL_DIR, os.W_OK)
    target_val_dir = VAL_DIR if in_place else val_copy_dir
    transfer = shutil.move if in_place else shutil.copy2
    if in_place:
        print("Organizing validation folder...")
    else:
        print(f"Validation folder is read-only; organizing a copy in {target_val_dir}...")
    try:
        val_data = pd.read_csv(annotations_file, sep='\t', header=None, names=['File', 'Class', 'X', 'Y', 'H', 'W'])
        for img_file, img_class in tqdm(zip(val_data['File'], val_data['Class']), total=len(val_data), desc="Organizing Val"):
            class_dir = os.path.join(target_val_dir, img_class)
            os.makedirs(class_dir, exist_ok=True)
            source_path = os.path.join(val_img_dir, img_file)
            if os.path.exists(source_path):
                transfer(source_path, os.path.join(class_dir, img_file))
        if in_place:
            # Drop the emptied raw folder and the annotations so the check
            # above recognizes the organized layout on re-runs.
            if os.path.exists(val_img_dir) and not os.listdir(val_img_dir):
                os.rmdir(val_img_dir)
            os.remove(annotations_file)
        VAL_DIR = target_val_dir
        print("Validation folder organized.")
    except Exception as e:
        print(f"Error organizing validation folder: {e}")
else:
    print("Warning: Validation folder not organized and source files missing. Cannot organize.")

# --- Define Data Augmentations ---
# ImageNet channel statistics used for normalization.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _compose_normalized(*steps):
    """Compose *steps* followed by tensor conversion and mean/std normalization."""
    return T.Compose([*steps, T.ToTensor(), T.Normalize(mean=mean, std=std)])


# Weak view: random resized crop + horizontal flip.
weak_transform = _compose_normalized(
    T.RandomResizedCrop(IMG_SIZE, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(),
)

# Strong view: the weak pipeline plus two RandAugment ops at magnitude 10.
strong_transform = _compose_normalized(
    T.RandomResizedCrop(IMG_SIZE, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(),
    T.RandAugment(num_ops=2, magnitude=10),
)

# Deterministic evaluation pipeline: resize slightly larger, then center-crop.
val_transform = _compose_normalized(
    T.Resize(IMG_SIZE + 8),
    T.CenterCrop(IMG_SIZE),
)

# --- Create Datasets ---
try:
    full_train_dataset = datasets.ImageFolder(TRAIN_DIR)
except FileNotFoundError:
    print(f"ERROR: Training directory '{TRAIN_DIR}' not found.")
    raise
except Exception as e:
    print(f"Error loading training dataset: {e}")
    raise

# Per class: shuffle the sample indices, keep the first NUM_LABELED_PER_CLASS
# as labeled and treat the remainder as unlabeled. Classes that are too small
# contribute all of their samples to the labeled set.
targets = np.array(full_train_dataset.targets)
labeled_indices, unlabeled_indices = [], []
for class_id in range(NUM_CLASSES):
    members = np.where(targets == class_id)[0]
    if len(members) >= NUM_LABELED_PER_CLASS:
        np.random.shuffle(members)
        labeled_indices.extend(members[:NUM_LABELED_PER_CLASS])
        unlabeled_indices.extend(members[NUM_LABELED_PER_CLASS:])
        continue
    print(f"Warning: Class {class_id} has {len(members)} samples < {NUM_LABELED_PER_CLASS}. Using all as labeled.")
    labeled_indices.extend(members)

for caption, count in (
    ("Total training samples", len(full_train_dataset)),
    ("Labeled samples", len(labeled_indices)),
    ("Unlabeled samples", len(unlabeled_indices)),
):
    print(f"{caption}: {count}")

# Custom Dataset Wrappers
class LabeledDataset(Dataset):
    """Wrap a dataset of ``(x, y)`` pairs, transforming ``x`` on access.

    Args:
        subset: Indexable dataset whose items are ``(x, y)`` pairs.
        transform: Callable applied to ``x``; ``y`` is returned unchanged.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label = self.subset[index]
        return self.transform(sample), label

class UnlabeledDataset(Dataset):
    """Wrap a dataset of ``(x, y)`` pairs, returning two augmented views of ``x``.

    The label is discarded. Each item is ``(transform_weak(x), transform_strong(x))``,
    with the weak transform applied first.
    """

    def __init__(self, subset, transform_weak, transform_strong):
        self.subset = subset
        self.transform_weak, self.transform_strong = transform_weak, transform_strong

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        image, _label = self.subset[index]
        weak_view = self.transform_weak(image)
        strong_view = self.transform_strong(image)
        return weak_view, strong_view

# Restrict the full training set to each index split, then attach the
# augmentations: labeled images get the weak view only, unlabeled images yield
# a (weak, strong) view pair.
labeled_subset, unlabeled_subset = (
    Subset(full_train_dataset, labeled_indices),
    Subset(full_train_dataset, unlabeled_indices),
)

labeled_dataset = LabeledDataset(labeled_subset, transform=weak_transform)
unlabeled_dataset = UnlabeledDataset(
    unlabeled_subset,
    transform_weak=weak_transform,
    transform_strong=strong_transform,
)

# Validation Dataset
# ImageFolder assigns class indices from the sorted sub-folder names, so val
# labels only line up with the training labels if both folders list exactly
# the same classes. If they do not (e.g. the val folder was never organized and
# only contains "images/"), validation is skipped rather than reporting
# accuracy against misaligned labels.
val_loader = None
try:
    val_dataset = datasets.ImageFolder(VAL_DIR, transform=val_transform)
    if len(val_dataset) == 0:
        print("Warning: Validation dataset is empty.")
    elif val_dataset.class_to_idx != full_train_dataset.class_to_idx:
        print(
            f"Warning: Validation classes ({len(val_dataset.classes)}) do not match training classes "
            f"({len(full_train_dataset.classes)}); skipping validation to avoid misaligned labels."
        )
    else:
        num_workers = 2 # Kaggle default
        pin_memory = True if use_cuda else False
        val_loader = DataLoader(val_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
        print(f"Validation dataset loaded with {len(val_dataset)} samples.")
except FileNotFoundError:
    print(f"ERROR: Validation directory '{VAL_DIR}' not found or not organized.")
except Exception as e:
    print(f"Error loading validation dataset: {e}")

# --- Create Training DataLoaders ---
num_workers = 2
pin_memory = True if use_cuda else False
labeled_loader = DataLoader(labeled_dataset, batch_size=labeled_batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=unlabeled_batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)


def _infinite_batches(loader):
    """Yield batches from *loader* forever, starting a new pass whenever one ends.

    ``itertools.cycle`` would store every batch of the first pass and then
    replay those exact tensors. That freezes shuffling and the random weak and
    strong augmentations after one pass, and keeps the whole augmented
    unlabeled set in host memory. Re-iterating the loader avoids both.

    Raises:
        ValueError: if a full pass yields no batches (e.g. ``drop_last=True``
            with fewer samples than one batch), which would otherwise loop forever.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("DataLoader produced no batches; check dataset size and batch size.")


# Create iterators
labeled_iter = _infinite_batches(labeled_loader)
unlabeled_iter = _infinite_batches(unlabeled_loader)
print("DataLoaders created.")

## 5. Model Definition

In [None]:
# Build the two peer networks (f and g). Each is a separate EfficientNet
# instance created via `from_pretrained` with a NUM_CLASSES-way output; they
# later supervise each other with pseudo-labels.
MODEL_ARCH = 'efficientnet-b0'
print(f"Loading model architecture: {MODEL_ARCH}")
try:
    model_f, model_g = [
        EfficientNet.from_pretrained(MODEL_ARCH, num_classes=NUM_CLASSES) for _ in range(2)
    ]
    print(f"Loaded pretrained weights for {MODEL_ARCH}")
except Exception as e:
    print(f"Error loading pretrained model: {e}")
    raise

model_f, model_g = (net.to(device) for net in (model_f, model_g))

## 6. Optimizer & Loss

In [None]:
# One Adam optimizer per network, so each model is stepped only by its own loss.
optimizer_f, optimizer_g = (optim.Adam(net.parameters(), lr=lr) for net in (model_f, model_g))

# The supervised loss is averaged over the labeled batch. The unlabeled loss
# keeps per-sample values (reduction='none') so they can be masked and
# weighted before being reduced by hand in the training loop.
criterion_s = nn.CrossEntropyLoss().to(device)
criterion_u = nn.CrossEntropyLoss(reduction='none').to(device)

## 7. Evaluation Function

In [None]:
def evaluate(model, dataloader, criterion, device, desc="Evaluating"):
    """Return ``(avg_loss, accuracy)`` of *model* over every batch in *dataloader*.

    The model is put into eval mode (and left there) and gradients are disabled.
    The per-batch value of *criterion* is weighted by the batch size before
    averaging, which assumes *criterion* returns a batch mean (true for the
    default ``nn.CrossEntropyLoss``). Accuracy is top-1. Both results are 0
    when the loader yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=desc, leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item() * inputs.size(0)

            # argmax returns the first maximal index, matching torch.max(..., 1).
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    if n_seen == 0:
        return 0, 0
    return loss_sum / n_seen, n_correct / n_seen

## 8. Training Loop

In [None]:
print("Starting JointMatch Training...")
start_time = time.time()
best_val_acc = -1.0
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

# Exponential moving average of both models' mean class probabilities on the
# weakly augmented unlabeled data. It starts uniform, persists across steps and
# epochs, and drives the class-wise adaptive thresholds.
ema_p = torch.ones(NUM_CLASSES).to(device) / NUM_CLASSES


def _next_batch(batch_iter, loader):
    """Return ``(batch, batch_iter)``, opening a fresh pass over *loader* if *batch_iter* is exhausted.

    Each loader is handled independently, so one running out never discards a
    batch already drawn from the other. A fresh ``iter(loader)`` also gets a
    new shuffle and new random augmentations.
    """
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


for epoch in range(num_epochs):
    epoch_start_time = time.time()
    model_f.train()
    model_g.train()

    # Track epoch metrics
    epoch_loss_s_f, epoch_loss_u_f, epoch_loss_s_g, epoch_loss_u_g = 0.0, 0.0, 0.0, 0.0
    total_batches = 0
    conf_f_count, conf_g_count, disagree_count = 0, 0, 0
    total_unlabeled_processed = 0

    # Use tqdm for the main training loop progress
    pbar = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{num_epochs}")

    for step in pbar:
        # === Get Batches ===
        (inputs_l, targets_l), labeled_iter = _next_batch(labeled_iter, labeled_loader)
        (inputs_u_w, inputs_u_s), unlabeled_iter = _next_batch(unlabeled_iter, unlabeled_loader)

        inputs_l, targets_l = inputs_l.to(device), targets_l.to(device)
        inputs_u_w, inputs_u_s = inputs_u_w.to(device), inputs_u_s.to(device)

        # === Supervised Loss (both models on the same labeled batch) ===
        logits_l_f = model_f(inputs_l)
        logits_l_g = model_g(inputs_l)
        loss_s_f = criterion_s(logits_l_f, targets_l)
        loss_s_g = criterion_s(logits_l_g, targets_l)

        # === Pseudo-labels from the weak view (no gradients) ===
        with torch.no_grad():
            logits_u_w_f = model_f(inputs_u_w)
            logits_u_w_g = model_g(inputs_u_w)
            probs_u_w_f = F.softmax(logits_u_w_f, dim=1)
            probs_u_w_g = F.softmax(logits_u_w_g, dim=1)

            # Update the EMA with this batch's mean class distribution (averaged over both models)
            avg_probs_u_w = (probs_u_w_f + probs_u_w_g) / 2.0
            batch_p = avg_probs_u_w.mean(dim=0)
            ema_p = ema_decay * ema_p + (1 - ema_decay) * batch_p
            ema_p = ema_p.detach()

            # Class-wise adaptive thresholds: base_threshold scaled by each
            # class's EMA probability relative to the largest one, floored at
            # 1/NUM_CLASSES. Rarely predicted classes get lower thresholds.
            max_ema_p = torch.max(ema_p)
            normalized_ema_p = ema_p / (max_ema_p + 1e-8) # epsilon avoids division by zero
            adaptive_thresholds = (base_threshold * normalized_ema_p).clamp(min=1.0 / NUM_CLASSES)

            # Hard pseudo-labels and confidence masks, each against the threshold of its predicted class
            max_probs_f, pseudo_labels_hard_f = torch.max(probs_u_w_f, dim=1)
            max_probs_g, pseudo_labels_hard_g = torch.max(probs_u_w_g, dim=1)
            thresholds_f = adaptive_thresholds.gather(0, pseudo_labels_hard_f)
            thresholds_g = adaptive_thresholds.gather(0, pseudo_labels_hard_g)
            mask_f = max_probs_f.ge(thresholds_f).float()
            mask_g = max_probs_g.ge(thresholds_g).float()

            # Samples where the models disagree get disagreement_weight, agreed ones get 1 - disagreement_weight
            disagree_mask = (pseudo_labels_hard_f != pseudo_labels_hard_g).float()
            sample_weights = disagreement_weight * disagree_mask + (1.0 - disagreement_weight) * (1.0 - disagree_mask)

            # Track stats
            num_conf_f = mask_f.sum().item()
            num_conf_g = mask_g.sum().item()
            num_disagree = disagree_mask.sum().item()
            conf_f_count += num_conf_f
            conf_g_count += num_conf_g
            disagree_count += num_disagree
            total_unlabeled_processed += inputs_u_w.size(0)

        # === Cross-model Consistency Loss on the strong view ===
        # f learns from g's confident pseudo-labels and vice versa. Each loss is
        # normalized by the number of confident teacher samples (0 if none).
        logits_u_s_f = model_f(inputs_u_s)
        logits_u_s_g = model_g(inputs_u_s)

        loss_u_f_all = criterion_u(logits_u_s_f, pseudo_labels_hard_g)
        loss_u_f = (loss_u_f_all * mask_g * sample_weights).sum() / (num_conf_g + 1e-8)

        loss_u_g_all = criterion_u(logits_u_s_g, pseudo_labels_hard_f)
        loss_u_g = (loss_u_g_all * mask_f * sample_weights).sum() / (num_conf_f + 1e-8)

        loss_u_f = torch.nan_to_num(loss_u_f)
        loss_u_g = torch.nan_to_num(loss_u_g)

        # === Total Loss & Backpropagation ===
        total_loss_f = loss_s_f + unlabeled_loss_weight * loss_u_f
        total_loss_g = loss_s_g + unlabeled_loss_weight * loss_u_g

        # The two graphs are disjoint (pseudo-labels come from no_grad), so each
        # model is updated only by its own loss.
        # Optional: gradient clipping via torch.nn.utils.clip_grad_norm_(..., max_norm=1.0) before each step.
        optimizer_f.zero_grad()
        total_loss_f.backward()
        optimizer_f.step()

        optimizer_g.zero_grad()
        total_loss_g.backward()
        optimizer_g.step()

        # --- Update Epoch Metrics ---
        epoch_loss_s_f += loss_s_f.item()
        epoch_loss_u_f += loss_u_f.item()
        epoch_loss_s_g += loss_s_g.item()
        epoch_loss_u_g += loss_u_g.item()
        total_batches += 1

        # Update progress bar description
        avg_loss = (total_loss_f.item() + total_loss_g.item()) / 2.0
        pbar.set_postfix(loss=f"{avg_loss:.4f}",
                         u_f=f"{loss_u_f.item():.4f}", u_g=f"{loss_u_g.item():.4f}")

    # --- End of Epoch ---
    avg_epoch_loss_s_f = epoch_loss_s_f / total_batches
    avg_epoch_loss_u_f = epoch_loss_u_f / total_batches
    avg_epoch_loss_s_g = epoch_loss_s_g / total_batches
    avg_epoch_loss_u_g = epoch_loss_u_g / total_batches
    avg_epoch_conf_f = conf_f_count / total_unlabeled_processed if total_unlabeled_processed > 0 else 0
    avg_epoch_conf_g = conf_g_count / total_unlabeled_processed if total_unlabeled_processed > 0 else 0
    avg_epoch_disagree = disagree_count / total_unlabeled_processed if total_unlabeled_processed > 0 else 0

    # Evaluate model_f on the validation set (if available)
    val_loss, val_acc = 0.0, 0.0
    if val_loader:
        val_loss, val_acc = evaluate(model_f, val_loader, criterion_s, device, desc="Validating")
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
    else:
        history['val_loss'].append(None)
        history['val_acc'].append(None)

    # Training metrics are the mean batch supervised losses; train accuracy is not evaluated.
    history['train_loss'].append((avg_epoch_loss_s_f + avg_epoch_loss_s_g)/2)
    history['train_acc'].append(None)

    epoch_duration = time.time() - epoch_start_time

    print(f"\nEpoch {epoch+1}/{num_epochs} Summary ({epoch_duration:.2f}s):")
    print(f"  Avg Train Loss: S_f={avg_epoch_loss_s_f:.4f}, U_f={avg_epoch_loss_u_f:.4f} | S_g={avg_epoch_loss_s_g:.4f}, U_g={avg_epoch_loss_u_g:.4f}")
    print(f"  Avg Unlabeled Stats: Conf_F={avg_epoch_conf_f:.3f}, Conf_G={avg_epoch_conf_g:.3f}, Disagree={avg_epoch_disagree:.3f}")
    if val_loader:
        print(f"  Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.4f}")
    else:
        print("  Validation: Skipped (no validation data)")

    # --- Checkpointing ---
    # Save periodic checkpoint. ema_p is included because it is part of the
    # training state: without it a resumed run restarts the adaptive thresholds
    # from a uniform prior.
    if (epoch + 1) % save_every_epochs == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_f_state_dict': model_f.state_dict(),
            'model_g_state_dict': model_g.state_dict(),
            'optimizer_f_state_dict': optimizer_f.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'ema_p': ema_p.detach().cpu(),
            'val_loss': val_loss, # Save last validation loss for info
            'val_acc': val_acc
        }, checkpoint_path)
        print(f"  Checkpoint saved to {checkpoint_path}")

    # Save best model (model F only) based on validation accuracy
    if val_loader and val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_f_state_dict': model_f.state_dict(),
            'val_acc': best_val_acc
        }, best_model_path)
        print(f"  New best model saved to {best_model_path} (Val Acc: {best_val_acc:.4f})")

    print("-" * 50)

total_training_time = time.time() - start_time
print(f"\nTraining finished. Total duration: {total_training_time/60:.2f} minutes.")

## 9. Final Evaluation

In [None]:
print("\n--- Final Evaluation --- ")

# Evaluate the best saved model_f checkpoint on the validation set, falling
# back to the in-memory model_f if no best checkpoint was written.
if not val_loader:
    print("Skipping final evaluation (no validation loader).")
elif not os.path.exists(best_model_path):
    print(f"Best model file not found at {best_model_path}. Evaluating last state of model_f instead.")
    final_loss, final_acc = evaluate(model_f, val_loader, criterion_s, device, desc="Final Eval (Last State)")
    print(f"\nFinal Results (Last Epoch Model F) - Loss: {final_loss:.4f} Acc: {final_acc:.4f}")
else:
    print(f"Loading best model from: {best_model_path}")
    try:
        checkpoint = torch.load(best_model_path, map_location=device)

        # Build the architecture only. Every weight is overwritten by the
        # checkpoint, so fetching pretrained weights first is wasted work.
        eval_model = EfficientNet.from_name(MODEL_ARCH, num_classes=NUM_CLASSES).to(device)

        if 'model_f_state_dict' not in checkpoint:
            raise KeyError("Checkpoint does not contain 'model_f_state_dict'.")
        eval_model.load_state_dict(checkpoint['model_f_state_dict'])

        # Format the stored accuracy only when it is numeric; applying ':.4f'
        # to the 'N/A' fallback would raise ValueError.
        saved_acc = checkpoint.get('val_acc')
        saved_acc_text = f"{saved_acc:.4f}" if isinstance(saved_acc, (int, float)) else "N/A"
        print(f"Loaded best model state dict (Epoch {checkpoint.get('epoch', 'N/A')}, Val Acc: {saved_acc_text})")

        final_loss, final_acc = evaluate(eval_model, val_loader, criterion_s, device, desc="Final Eval (Best)")
        print(f"\nFinal Results (Best Model) - Loss: {final_loss:.4f} Acc: {final_acc:.4f}")

    except FileNotFoundError:
        print(f"Error: Best model file not found at {best_model_path}.")
    except Exception as e:
        print(f"Error loading or evaluating best model: {e}")
        import traceback
        traceback.print_exc()

## 10. Plotting Results (Optional)

In [None]:
# Simple plotting using matplotlib if history was collected


def _epoch_series(values):
    """Pair *values* with 1-based epoch numbers, mapping None to NaN (drawn as a gap).

    The x-axis is taken from the number of recorded entries rather than
    num_epochs, so the plot still works when training stopped early.
    """
    return list(range(1, len(values) + 1)), [np.nan if v is None else v for v in values]


if history['val_acc'] and any(v is not None for v in history['val_acc']):
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(*_epoch_series(history['train_loss']), 'bo-', label='Training Loss (Avg Batch)')
    plt.plot(*_epoch_series(history['val_loss']), 'ro-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    # Training accuracy is not evaluated during training, so only validation accuracy is plotted.
    plt.plot(*_epoch_series(history['val_acc']), 'ro-', label='Validation Accuracy')
    plt.title('Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    # Save the plot
    plot_path = os.path.join(output_dir, "training_history.png")
    plt.savefig(plot_path)
    print(f"\nTraining history plot saved to {plot_path}")
    plt.show()
else:
    print("\nSkipping plotting (no validation data or history)." )