# In [None]:
import os
import glob
import torch
from torch.utils.data import DataLoader

# ===== User must define or import their model class and validation dataset =====
# from your_model_file import ModelClass, val_dataset

# Example:
# model = ModelClass()
# val_dataset = YourValidationDataset(...)

# Validation batches are served in a fixed order (no shuffling)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Evaluation function: compute classification accuracy
def evaluate(model, loader, device):
    """Compute the top-1 classification accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode (and left in that mode), and
    the forward passes run with gradient tracking disabled.

    Args:
        model: A callable ``torch.nn.Module``. ``model(inputs)`` is assumed
            to return per-class scores with the class axis at dim 1.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        float: Fraction of samples whose highest-scoring class equals the
        target. ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            # The predicted class is the index of the highest score.
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    # Return a float on every path so callers always get the same type,
    # including when the loader was empty.
    return correct / total if total > 0 else 0.0

# Locate every .pth checkpoint inside the checkpoint directory, sorted by path
checkpoint_dir = '/mnt/data/checkpoints'
checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
checkpoint_paths.sort()

# Abort early when there is nothing to evaluate
if len(checkpoint_paths) == 0:
    raise FileNotFoundError(f"No .pth files found in {checkpoint_dir}")

# Main evaluation loop: score every checkpoint and report the most accurate one
if __name__ == '__main__':
    # A single model instance is reused; each checkpoint overwrites its weights.
    model = ModelClass().to(device)

    # Start below any attainable accuracy so that a checkpoint scoring exactly
    # 0.0 is still recorded; otherwise best_ckpt could remain None and the
    # final summary would fail in os.path.basename.
    best_acc = float('-inf')
    best_ckpt = None

    for ckpt_path in checkpoint_paths:
        # NOTE(security): depending on the PyTorch version and its
        # weights_only default, torch.load may unpickle arbitrary objects.
        # Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device)
        # Accept either a bare state_dict or a dict wrapping it under the
        # 'state_dict' key.
        state_dict = checkpoint.get('state_dict', checkpoint)
        model.load_state_dict(state_dict)

        # Evaluate accuracy on the validation set
        acc = evaluate(model, val_loader, device)
        print(f"Checkpoint {os.path.basename(ckpt_path)}: Accuracy = {acc:.4f}")

        # Track best; strict '>' keeps the earliest checkpoint on ties
        if acc > best_acc:
            best_acc = acc
            best_ckpt = ckpt_path

    if best_ckpt is None:
        # Only reachable when there were no checkpoints to iterate over.
        print("\nNo checkpoints were evaluated.")
    else:
        print(f"\nBest checkpoint: {os.path.basename(best_ckpt)} with Accuracy = {best_acc:.4f}")
