Cell 1: Imports and Global Configuration (Simplified for Accuracy Evaluation)

In [1]:
import os
import json
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import time # Still used by some loading/evaluation functions, though not for timing inference directly
from pathlib import Path
import gc
import torch_pruning as tp # Needed for structured pruning reconstruction
import re
import traceback # Keep for detailed error messages

print("--- Notebook Setup: Imports completed ---")

# ---------------------------------------------------------------------------
# Paths and class counts
# ---------------------------------------------------------------------------
# Parent directory that holds the experiment folders.
ROOT_DIR = "saved_models_and_logs"
# CSV file that receives the accuracy results.
OUTPUT_CSV_ACCURACY = "model_accuracy_summary.csv"
# Class count used when an experiment log does not specify one.
DEFAULT_NUM_CLASSES = 1000
# Class count used when rebuilding models for structured-pruning reconstruction.
FIXED_NUM_CLASSES = 1000

# ---------------------------------------------------------------------------
# Uniform evaluation settings
# ---------------------------------------------------------------------------
# MAKE SURE THIS PATH IS CORRECT
VALIDATION_DATA_PATH = "imagenet-mini/val"
BATCH_SIZE_EVAL = 32
# 0 for Windows or if issues appear; >0 on Linux if beneficial.
NUM_WORKERS_EVAL = 0
# Upper bound on evaluated batches (use float('inf') to evaluate everything).
MAX_EVAL_BATCHES = 125

# ---------------------------------------------------------------------------
# Device and example input
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Example input handed to torch_pruning while a pruned architecture is rebuilt.
INPUT_TENSOR_CPU = torch.randn(1, 3, 224, 224)

# Experiment IDs whose accuracy is evaluated on CPU because they are known to
# be unstable on GPU (e.g. certain JIT quantized models).
GPU_UNSTABLE_QUANTIZED_MODELS = [
    "resnet18pretrained_distilled_quant_ptq_int8_perchannel_post",
    "resnet18pretrained_distilled_quant_ptq_int8_pertensor_post",
    "resnet18pretrained_distilled_quant_qat_int8_epochs8",
    "resnet50_quant_ptq_int8_perchannel_post",
    "resnet50_quant_ptq_int8_pertensor_post",
    "resnet50_quant_qat_int8_epochs8",
]

# ---------------------------------------------------------------------------
# Result holders
# ---------------------------------------------------------------------------
# Placeholder; replaced by the discovery cell.
accuracy_results_df = pd.DataFrame()
# Experiment ID currently being evaluated (written by evaluate_model_accuracies_nb).
current_eval_experiment_id_nb = ""

--- Notebook Setup: Imports completed ---
Using device: cuda


Cell 2: Core Helper Functions (Path, Pruning Reconstruction Logic)

In [2]:
# --- Helper: Image Transforms ---
# Per-channel mean/std normalisation applied after the image becomes a tensor.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Evaluation pipeline: resize to 256, centre-crop to 224, convert to tensor, normalise.
eval_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

# --- Helper: Model File ---
def get_model_file_path_nb(experiment_path_str):
    """Pick the checkpoint file to load for an experiment directory.

    Selection priority:
      1. ``model_final.pth`` if it exists;
      2. a file named exactly ``model_quantized.pth``;
      3. a file whose name contains ``baseline_ft_imagenetmini_final.pth``;
      4. the alphabetically first ``*.pth`` file.

    Args:
        experiment_path_str: Path (str or path-like) of the experiment directory.

    Returns:
        The chosen file path as a string, or ``None`` when the directory
        contains no ``*.pth`` file.
    """
    experiment_path = Path(experiment_path_str)

    preferred_file = experiment_path / "model_final.pth"
    if preferred_file.exists():
        return str(preferred_file)

    # Path.glob yields entries in filesystem-dependent order; sorting makes the
    # last-resort choice below reproducible across machines and runs.
    pth_files = sorted(experiment_path.glob("*.pth"))
    if not pth_files:
        return None

    for p_file in pth_files:
        if p_file.name == "model_quantized.pth":
            return str(p_file)
    for p_file in pth_files:
        if "baseline_ft_imagenetmini_final.pth" in p_file.name:
            return str(p_file)
    return str(pth_files[0])

# --- Model Definition and Pruning Application (FROM SCRIPT 1, needed for structured pruning reconstruction) ---
def get_base_resnet50_model_for_reconstruction_nb():
    """Build a torchvision ResNet-50 without pretrained weights.

    The classifier is sized to ``FIXED_NUM_CLASSES`` outputs.
    """
    return models.resnet50(weights=None, num_classes=FIXED_NUM_CLASSES)

def apply_structured_pruning_to_model_for_reconstruction_nb(
    model_to_prune, example_inputs, target_pruning_rate_per_layer, device_obj
):
    """Run one torch_pruning magnitude-pruning step on ``model_to_prune``.

    Args:
        model_to_prune: Model to prune; it is moved to ``device_obj`` and
            pruned in place.
        example_inputs: Example input tensor passed to the pruner.
        target_pruning_rate_per_layer: Value passed as ``pruning_ratio``.
        device_obj: Device on which the model and example inputs are placed.

    Returns:
        The same model object after the pruning step, or ``None`` if building
        or stepping the pruner raised.
    """
    model_to_prune.to(device_obj)
    example_inputs = example_inputs.to(device_obj)

    # Linear layers with FIXED_NUM_CLASSES outputs are excluded from pruning.
    protected_layers = [
        layer
        for layer in model_to_prune.modules()
        if isinstance(layer, nn.Linear) and layer.out_features == FIXED_NUM_CLASSES
    ]

    try:
        l1_importance = tp.importance.MagnitudeImportance(p=1)  # L1 norm
        tp.pruner.MagnitudePruner(
            model=model_to_prune,
            example_inputs=example_inputs,
            importance=l1_importance,
            iterative_steps=1,
            pruning_ratio=target_pruning_rate_per_layer,
            global_pruning=False,
            ignored_layers=protected_layers,
        ).step()
    except Exception as e_prune:
        print(f"      ERROR during tp.pruner.MagnitudePruner step (rate {target_pruning_rate_per_layer}): {e_prune}")
        return None  # Indicate failure
    return model_to_prune

def get_pruning_config_from_log_for_reconstruction_nb(log_file_path_str):
    """Read the pruning rate recorded in an experiment's JSON log.

    The log's ``config_details`` mapping is searched for the following keys,
    in priority order; the first one holding a non-null value wins:

      * ``target_filter_pruning_rate_per_layer``          -> type ``'one-shot'``
      * ``applied_step_rate_for_this_stage``              -> type ``'iterative_step'``
      * ``target_overall_sparsity_approx_for_this_stage`` -> type ``'iterative_step'``

    Args:
        log_file_path_str: Path to the JSON log file (str or path-like).

    Returns:
        ``{'type': <str>, 'rate': <float>}`` or ``None`` when the path is
        unusable, the file is missing or unreadable, or no rate key is present.
    """
    # Path() rejects None / NaN with TypeError; treat that as "no log available"
    # instead of crashing the caller.
    try:
        log_file_path = Path(log_file_path_str)
    except TypeError:
        return None
    if not log_file_path.exists():
        return None

    # (key in config_details, reported pruning type), highest priority first.
    rate_keys_by_priority = (
        ('target_filter_pruning_rate_per_layer', 'one-shot'),
        ('applied_step_rate_for_this_stage', 'iterative_step'),
        ('target_overall_sparsity_approx_for_this_stage', 'iterative_step'),
    )

    try:
        with open(log_file_path, 'r') as f:
            log_data = json.load(f)
        config_details = log_data.get('config_details') if isinstance(log_data, dict) else None
        if not isinstance(config_details, dict):
            return None
        for rate_key, pruning_type in rate_keys_by_priority:
            rate = config_details.get(rate_key)
            if rate is not None:
                return {'type': pruning_type, 'rate': float(rate)}
    except Exception as e:
        print(f"    Error processing log {log_file_path} for pruning config: {e}")
    return None

def _reconstruct_model_arch_and_load_weights_nb(model_path_str, device_obj, pruning_config, exp_id_for_log=""):
    """Rebuild a structurally pruned ResNet-50 and load its saved weights.

    A fresh ResNet-50 is pruned with the rate(s) given in ``pruning_config``
    and the checkpoint at ``model_path_str`` is then loaded into it with
    ``load_state_dict``.

    Args:
        model_path_str: Path of the checkpoint to load.
        device_obj: Device used for pruning and for ``map_location``.
        pruning_config: ``{'type': 'one-shot', 'rate': r}`` or
            ``{'type': 'iterative', 'step_rates': [r1, r2, ...]}``.
        exp_id_for_log: Experiment ID used only in the error message.

    Returns:
        The reconstructed model in eval mode, or ``None`` when the config is
        missing/unsupported, a pruning step fails, or loading raises.
    """
    if not pruning_config:
        return None

    reconstructed_model = get_base_resnet50_model_for_reconstruction_nb()
    reconstructed_model.to(device_obj)
    # INPUT_TENSOR_CPU is defined globally
    example_inputs_local = INPUT_TENSOR_CPU.to(device_obj)

    try:
        config_type = pruning_config['type']
        if config_type == 'one-shot':
            step_rates = [pruning_config['rate']]
        elif config_type == 'iterative':
            step_rates = pruning_config.get('step_rates', [])
            if not step_rates:
                return None
        else:
            return None

        # A one-shot config is simply a single pruning step.
        for step_rate in step_rates:
            reconstructed_model = apply_structured_pruning_to_model_for_reconstruction_nb(
                reconstructed_model, example_inputs_local, step_rate, device_obj)
            if reconstructed_model is None:
                return None

        state_dict = torch.load(model_path_str, map_location=device_obj, weights_only=True)

        # Unwrap {'model': ...} / {'state_dict': ...} checkpoints FIRST, so the
        # prefix check below inspects the parameter names rather than the
        # wrapper's own keys.
        if isinstance(state_dict, dict):
            if isinstance(state_dict.get('model'), dict):
                state_dict = state_dict['model']
            elif isinstance(state_dict.get('state_dict'), dict):
                state_dict = state_dict['state_dict']

        # Drop a leading 'module.' only when every key carries it; slice the
        # prefix instead of str.replace so an inner 'module.' in a name survives.
        module_prefix = 'module.'
        if all(key.startswith(module_prefix) for key in state_dict.keys()):
            state_dict = {k[len(module_prefix):]: v for k, v in state_dict.items()}

        reconstructed_model.load_state_dict(state_dict)
        reconstructed_model.eval()
        return reconstructed_model
    except Exception as e:
        print(f"    ERROR in _reconstruct_model_arch_and_load_weights_nb for {model_path_str} ({exp_id_for_log}): {e}")
        return None

# Confirmation message printed once the definitions in this cell have executed.
print("--- Helper functions for model path and reconstruction defined ---")

--- Helper functions for model path and reconstruction defined ---


Cell 3: Central Model Loading Function

In [3]:
def load_model_for_experiment_nb(exp_info, all_experiments_df, target_device_str='cpu'):
    """Load the model of one experiment, trying three strategies in order.

    1. ``torch.jit.load`` (TorchScript archive).
    2. Structured-pruning experiments only: rebuild the pruned architecture
       from the pruning rate(s) found in the experiment log(s) and load the
       checkpoint into it.
    3. Fallback ``torch.load``: either a pickled ``nn.Module`` or a state dict
       loaded into a plain ResNet18 / ResNet50.

    Args:
        exp_info: Mapping with ``.get`` (dict or DataFrame row) providing
            'Model_File_Path', 'Base_Model_Arch', 'Num_Classes',
            'Is_Structured_Pruning', 'Experiment_ID', 'Log_Path' and the
            iterative fields 'Base_Exp_Name_Iterative' / 'Stage_Num_Iterative'.
        all_experiments_df: DataFrame of all experiments; used to collect the
            logs of earlier stages of the same iterative run.
        target_device_str: Device string the model is loaded onto.

    Returns:
        The model in eval mode, or ``None`` if every strategy failed.
    """
    model_path = exp_info.get('Model_File_Path')
    base_arch = exp_info.get('Base_Model_Arch')
    num_classes = exp_info.get('Num_Classes', DEFAULT_NUM_CLASSES)
    is_structured = exp_info.get('Is_Structured_Pruning', False)
    exp_id = exp_info.get('Experiment_ID', 'Unknown_Exp')

    if not model_path or not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
        print(f"      ERROR ({exp_id}): Model file invalid: {model_path}")
        return None

    device_to_load_on = torch.device(target_device_str)
    loaded_model = None

    # --- Strategy 1: TorchScript ---
    try:
        loaded_model = torch.jit.load(model_path, map_location=device_to_load_on)
        loaded_model.eval()
        return loaded_model
    except Exception:
        # Deliberate best effort: a file that is not a TorchScript archive
        # fails here and is handled by the strategies below.
        loaded_model = None

    # --- Strategy 2: structured-pruning reconstruction ---
    if is_structured:
        pruning_config_for_reconstruction = None
        base_exp_name_iter = exp_info.get('Base_Exp_Name_Iterative')
        stage_num_iter = exp_info.get('Stage_Num_Iterative')

        # pd.notna: a missing stage number read from a DataFrame row is NaN,
        # which would pass a plain "is not None" check.
        if base_exp_name_iter and stage_num_iter is not None and pd.notna(stage_num_iter):
            # Collect the step rate of every stage up to and including this one.
            cumulative_step_rates = []
            relevant_stages_info = all_experiments_df[
                (all_experiments_df['Base_Exp_Name_Iterative'] == base_exp_name_iter) &
                (all_experiments_df['Stage_Num_Iterative'] <= stage_num_iter) &
                (all_experiments_df['Stage_Num_Iterative'].notna())
            ].sort_values(by='Stage_Num_Iterative')
            for _, stage_row in relevant_stages_info.iterrows():
                stage_log_path = stage_row.get('Log_Path')
                stage_log_pruning_info = get_pruning_config_from_log_for_reconstruction_nb(stage_log_path)
                if stage_log_pruning_info and stage_log_pruning_info.get('type') == 'iterative_step':
                    cumulative_step_rates.append(stage_log_pruning_info['rate'])
                else:
                    # One unusable stage log invalidates the whole chain.
                    cumulative_step_rates = []
                    break
            if cumulative_step_rates:
                pruning_config_for_reconstruction = {'type': 'iterative', 'step_rates': cumulative_step_rates}
        else:
            log_path_current_exp = exp_info.get('Log_Path')
            one_shot_pruning_info = get_pruning_config_from_log_for_reconstruction_nb(log_path_current_exp)
            if one_shot_pruning_info and one_shot_pruning_info.get('type') == 'one-shot':
                pruning_config_for_reconstruction = one_shot_pruning_info
            elif one_shot_pruning_info and one_shot_pruning_info.get('type') == 'iterative_step':
                # No stage information available: apply the step rate as a single step.
                pruning_config_for_reconstruction = {'type': 'one-shot', 'rate': one_shot_pruning_info['rate']}

        if pruning_config_for_reconstruction:
            reconstructed = _reconstruct_model_arch_and_load_weights_nb(
                model_path, device_to_load_on, pruning_config_for_reconstruction, exp_id)
            # Compare with None: module truthiness can depend on __len__.
            if reconstructed is not None:
                reconstructed.eval()
                return reconstructed
            else:
                print(f"      WARNING ({exp_id}): Failed structured reconstruction. Fallback attempt...")
        else:
            print(f"      WARNING ({exp_id}): No valid pruning_config for structured. Fallback attempt...")

    # --- Strategy 3: plain torch.load ---
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # (needed when the whole model object was saved). Only use this on
        # checkpoint files from a trusted source.
        _raw_loaded_content = torch.load(model_path, map_location=device_to_load_on, weights_only=False)
        if isinstance(_raw_loaded_content, torch.nn.Module):
            loaded_model = _raw_loaded_content
        elif isinstance(_raw_loaded_content, dict):
            if base_arch == "ResNet18": model_instance = models.resnet18(weights=None, num_classes=num_classes)
            elif base_arch == "ResNet50": model_instance = models.resnet50(weights=None, num_classes=num_classes)
            else:
                print(f"      ERROR ({exp_id}): Unknown base_arch '{base_arch}' for fallback state_dict load.")
                return None

            state_dict = _raw_loaded_content
            # Unwrap {'model': ...} / {'state_dict': ...} BEFORE looking at the
            # key prefix, so the parameter names (not the wrapper keys) are checked.
            if isinstance(state_dict.get('model'), dict): state_dict = state_dict['model']
            elif isinstance(state_dict.get('state_dict'), dict): state_dict = state_dict['state_dict']

            # Strip only a leading 'module.'; str.replace would also remove the
            # substring from the middle of a parameter name.
            module_prefix = 'module.'
            if any(k.startswith(module_prefix) for k in state_dict.keys()):
                state_dict = {
                    (k[len(module_prefix):] if k.startswith(module_prefix) else k): v
                    for k, v in state_dict.items()
                }
            model_instance.load_state_dict(state_dict)
            loaded_model = model_instance
        # Any other content type leaves loaded_model as None and is reported below.

        if loaded_model is not None:
            loaded_model.eval()
            return loaded_model.to(device_to_load_on)
    except Exception as e_load:
        # An exception with an empty message has no first line to index.
        first_line = (str(e_load).splitlines() or [repr(e_load)])[0]
        print(f"      ERROR ({exp_id}): During fallback model loading: {first_line}")
        return None

    print(f"      ERROR ({exp_id}): Model could not be loaded by any method.")
    return None

# Confirmation message printed once load_model_for_experiment_nb has been defined.
print("--- Central model loader defined ---")

--- Central model loader defined ---


Cell 4: Experiment Discovery and DataFrame Initialization

In [4]:
def _read_log_config_nb(log_path, exp_name, warn_label="", use_student_config=False):
    """Read ``config_details`` and the class count from an experiment log.

    Args:
        log_path: Path of the JSON log; a missing file yields the defaults.
        exp_name: Experiment name, used only in the warning message.
        warn_label: Text inserted before the name in the warning message.
        use_student_config: When True, ``student_config['num_classes']``
            overrides the top-level ``num_classes`` if present.

    Returns:
        ``(num_classes, config_details)``; defaults are
        ``(DEFAULT_NUM_CLASSES, {})``.
    """
    num_classes, config_details = DEFAULT_NUM_CLASSES, {}
    if os.path.exists(log_path):
        try:
            with open(log_path, 'r') as f:
                log_data = json.load(f)
            config_details = log_data.get('config_details', {})
            num_classes = config_details.get('num_classes', DEFAULT_NUM_CLASSES)
            if use_student_config:
                student_config = config_details.get('student_config')
                if isinstance(student_config, dict):
                    num_classes = student_config.get('num_classes', num_classes)
        except Exception as e:
            print(f"  Warn: Log parsing error for {warn_label}{exp_name}: {e}")
    return num_classes, config_details


def discover_experiments_nb():
    """Scan ``ROOT_DIR`` and build one DataFrame row per experiment.

    Directories directly under ``ROOT_DIR`` whose name contains "baseline" are
    themselves experiments; every other directory is a category whose
    sub-directories are the experiments. Baselines are listed first.

    Returns:
        DataFrame indexed by 'Experiment_ID' (the column is kept as well) with
        the columns 'Experiment_Path', 'Log_Path', 'Model_File_Path',
        'Base_Model_Arch', 'Is_Structured_Pruning', 'Base_Exp_Name_Iterative',
        'Stage_Num_Iterative', 'Num_Classes' and 'Config_Details_From_Log'.
        An empty DataFrame is returned when ``ROOT_DIR`` does not exist.
    """
    print(f"--- Discovering experiments in: {ROOT_DIR} ---")
    discovered_experiments = []
    if not os.path.exists(ROOT_DIR):
        print(f"ERROR: ROOT_DIR '{ROOT_DIR}' does not exist!")
        return pd.DataFrame()

    # os.listdir returns names in arbitrary order; sorting keeps the row order
    # (and therefore previews and the saved CSV) reproducible.
    category_names = sorted(os.listdir(ROOT_DIR))

    # Process baselines first
    for cat_name_outer in category_names:
        cat_path_outer = os.path.join(ROOT_DIR, cat_name_outer)
        if not (os.path.isdir(cat_path_outer) and "baseline" in cat_name_outer.lower()):
            continue
        exp_name = cat_name_outer
        exp_path = cat_path_outer
        exp_name_lower = exp_name.lower()
        if "resnet18" in exp_name_lower:
            base_arch = "ResNet18"
        elif "resnet50" in exp_name_lower:
            base_arch = "ResNet50"
        else:
            base_arch = "Unknown"
        model_file = get_model_file_path_nb(exp_path)
        log_path = os.path.join(exp_path, "log.json")
        num_classes, config_details = _read_log_config_nb(log_path, exp_name, warn_label="baseline ")

        discovered_experiments.append({
            "Experiment_ID": exp_name, "Experiment_Path": exp_path, "Log_Path": log_path,
            "Model_File_Path": model_file, "Base_Model_Arch": base_arch,
            "Is_Structured_Pruning": False, "Base_Exp_Name_Iterative": None, "Stage_Num_Iterative": None,
            "Num_Classes": num_classes,
            "Config_Details_From_Log": config_details,  # Keep for num_classes, other potential debug
        })

    # Process other experiments
    for cat_name in category_names:
        cat_path = os.path.join(ROOT_DIR, cat_name)
        if not os.path.isdir(cat_path) or "baseline" in cat_name.lower():
            continue
        is_cat_structured = "pruning_structured" in cat_name.lower()

        for exp_name in sorted(os.listdir(cat_path)):
            exp_path_str = os.path.join(cat_path, exp_name)
            if not os.path.isdir(exp_path_str):
                continue
            exp_name_lower = exp_name.lower()

            # Any name mentioning resnet18 is treated as ResNet18, everything
            # else as ResNet50.
            base_arch = "ResNet18" if "resnet18" in exp_name_lower else "ResNet50"

            model_file = get_model_file_path_nb(exp_path_str)
            log_path_str_current = os.path.join(exp_path_str, "log.json")

            current_exp_is_structured = (
                is_cat_structured
                or "prune_struct_it" in exp_name_lower
                or "prune_struct_os" in exp_name_lower
                or "structured_l1_filter" in exp_name_lower
            )

            # Iterative runs are recognised by a "<base>_stage<N>" / "<base>_s<N>"
            # (or "-") suffix pattern in the experiment name.
            # NOTE(review): the '"it" in name' test is a loose substring match;
            # it is kept as-is because the regex below is the effective filter.
            base_exp_name_iterative, stage_num_iterative = None, None
            if current_exp_is_structured and ("iterative" in cat_name.lower() or "it" in exp_name_lower or "_stage" in exp_name_lower):
                match = re.search(r"(.+?)(?:_|-)(?:stage|s)(\d+)", exp_name_lower)
                if match:
                    base_exp_name_iterative = match.group(1)
                    stage_num_iterative = int(match.group(2))

            num_classes, config_details = _read_log_config_nb(
                log_path_str_current, exp_name, use_student_config=True)

            discovered_experiments.append({
                "Experiment_ID": exp_name, "Experiment_Path": exp_path_str, "Log_Path": log_path_str_current,
                "Model_File_Path": model_file, "Base_Model_Arch": base_arch,
                "Is_Structured_Pruning": current_exp_is_structured,
                "Base_Exp_Name_Iterative": base_exp_name_iterative, "Stage_Num_Iterative": stage_num_iterative,
                "Num_Classes": num_classes,
                "Config_Details_From_Log": config_details,  # Keep for num_classes and reconstruction logic
            })

    df = pd.DataFrame(discovered_experiments)
    if not df.empty:
        if 'Stage_Num_Iterative' in df.columns:
            df['Stage_Num_Iterative'] = pd.to_numeric(df['Stage_Num_Iterative'], errors='coerce')
        duplicated_ids = sorted(set(df.loc[df['Experiment_ID'].duplicated(), 'Experiment_ID']))
        if duplicated_ids:
            # Results are keyed by Experiment_ID later on, so duplicates would
            # silently share one result.
            print(f"  Warn: Duplicate Experiment_ID values found: {duplicated_ids}")
        df = df.set_index("Experiment_ID", drop=False)
    print(f"--- Discovery finished. Found {len(df)} experiments. ---")
    return df

# Run discovery and preview the columns that matter for loading / accuracy.
accuracy_results_df = discover_experiments_nb()
if accuracy_results_df.empty:
    print("No experiments discovered. Check ROOT_DIR and folder structure.")
else:
    print("\n--- Discovered Experiments Sample (columns relevant for loading/accuracy): ---")
    display_cols = ['Experiment_ID', 'Base_Model_Arch', 'Is_Structured_Pruning', 'Model_File_Path', 'Num_Classes']
    # Iterative-pruning columns are appended only when they exist.
    display_cols += [
        optional_col
        for optional_col in ('Base_Exp_Name_Iterative', 'Stage_Num_Iterative')
        if optional_col in accuracy_results_df.columns
    ]
    display(accuracy_results_df[display_cols].head())

--- Discovering experiments in: saved_models_and_logs ---
--- Discovery finished. Found 25 experiments. ---

--- Discovered Experiments Sample (columns relevant for loading/accuracy): ---


Unnamed: 0_level_0,Experiment_ID,Base_Model_Arch,Is_Structured_Pruning,Model_File_Path,Num_Classes,Base_Exp_Name_Iterative,Stage_Num_Iterative
Experiment_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
resnet18_baseline,resnet18_baseline,ResNet18,False,saved_models_and_logs\resnet18_baseline\resnet...,1000,,
resnet50_baseline,resnet50_baseline,ResNet50,False,saved_models_and_logs\resnet50_baseline\resnet...,1000,,
resnet18pretrained_distilled_quant_kmeans_256clusters_post,resnet18pretrained_distilled_quant_kmeans_256c...,ResNet18,False,saved_models_and_logs\combined_distilled_quant...,1000,,
resnet18pretrained_distilled_quant_ptq_int8_perchannel_post,resnet18pretrained_distilled_quant_ptq_int8_pe...,ResNet18,False,saved_models_and_logs\combined_distilled_quant...,1000,,
resnet18pretrained_distilled_quant_ptq_int8_pertensor_post,resnet18pretrained_distilled_quant_ptq_int8_pe...,ResNet18,False,saved_models_and_logs\combined_distilled_quant...,1000,,


Cell 5: Calculate Top-1 and Top-5 Accuracy

In [5]:
@torch.no_grad()
def evaluate_model_accuracies_nb(model, device_str_eval, num_classes_eval, max_batches_to_eval, exp_id_for_log):
    """Compute Top-1 and Top-5 accuracy (in percent) on the validation set.

    Args:
        model: Model to evaluate; it is moved to the evaluation device and put
            in eval mode.
        device_str_eval: Device string used for the model and the batches.
        num_classes_eval: Expected class count of the model (not used in the
            computation).
        max_batches_to_eval: Evaluation stops once this many batches are done.
        exp_id_for_log: Experiment ID used in messages; also stored in the
            global ``current_eval_experiment_id_nb``.

    Returns:
        ``(top1, top5)`` as floats, or a pair of "N/A (...)" strings when the
        validation data is missing / cannot be loaded or a batch fails.
        ``(0.0, 0.0)`` is returned for an empty dataset.
    """
    global current_eval_experiment_id_nb
    current_eval_experiment_id_nb = exp_id_for_log

    if not os.path.exists(VALIDATION_DATA_PATH):
        print(f"      ERROR ({current_eval_experiment_id_nb}): Val data path not found: {VALIDATION_DATA_PATH}")
        missing_msg = "N/A (Val Data Missing)"
        return missing_msg, missing_msg

    try:
        val_dataset = ImageFolder(VALIDATION_DATA_PATH, eval_transforms)
        if len(val_dataset) == 0:
            print(f"      WARNING ({current_eval_experiment_id_nb}): Validation dataset at '{VALIDATION_DATA_PATH}' is empty.")
            return 0.0, 0.0

        # Worker processes are used only when the notebook-wide device is CUDA.
        loader_workers = NUM_WORKERS_EVAL if DEVICE.type == 'cuda' else 0
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE_EVAL,
            shuffle=False,
            num_workers=loader_workers,
            pin_memory=(device_str_eval == 'cuda'),
        )
    except Exception as e:
        print(f"      ERROR ({current_eval_experiment_id_nb}): Could not load validation data: {e}")
        return "N/A (Val Data Load Error)", "N/A (Val Data Load Error)"

    device_obj_eval = torch.device(device_str_eval)
    model.to(device_obj_eval)
    model.eval()

    correct_top1 = 0
    correct_top5 = 0
    total = 0
    batches_processed = 0

    for images, labels in val_loader:
        try:
            images = images.to(device_obj_eval)
            labels = labels.to(device_obj_eval)
            outputs = model(images)

            # k is 5, or fewer when the model exposes fewer outputs (never below 1).
            top_k_val = max(1, min(5, outputs.size(1)))

            _, top_indices = outputs.topk(top_k_val, dim=1, largest=True, sorted=True)
            top_indices = top_indices.t()  # one row per rank, one column per sample
            hits = top_indices.eq(labels.view(1, -1).expand_as(top_indices))

            # Row 0 holds the best guess; all k rows together give the top-k hits.
            correct_top1 += hits[0].float().sum().item()
            correct_top5 += hits.float().sum().item()

            total += labels.size(0)
            batches_processed += 1
            if batches_processed >= max_batches_to_eval:
                break
        except Exception as e_batch:
            print(f"      ERROR ({current_eval_experiment_id_nb}) during batch {batches_processed} eval: {e_batch}")
            return "N/A (Batch Eval Error)", "N/A (Batch Eval Error)"

    if total > 0:
        accuracy_top1 = (correct_top1 / total) * 100.0
        accuracy_top5 = (correct_top5 / total) * 100.0
    else:
        accuracy_top1 = 0.0
        accuracy_top5 = 0.0
    return accuracy_top1, accuracy_top5

def calculate_and_store_accuracies(df_to_update):
    """Evaluate every experiment in ``df_to_update`` and store the accuracies.

    The columns 'Top1_Accuracy_Percent' and 'Top5_Accuracy_Percent' are added
    to ``df_to_update`` in place. Each cell holds a float, or an "N/A (...)"
    string when loading or evaluation failed.

    Args:
        df_to_update: Experiment DataFrame indexed by experiment ID, as built
            by the discovery cell.

    Returns:
        None.
    """
    if df_to_update.empty:
        print("Experiment DataFrame is empty. Run discovery cell first.")
        return
    if not os.path.exists(VALIDATION_DATA_PATH):
        print(f"ERROR: Validation data path '{VALIDATION_DATA_PATH}' not found. Cannot calculate accuracy.")
        df_to_update['Top1_Accuracy_Percent'] = "N/A (Val Data Missing)"
        df_to_update['Top5_Accuracy_Percent'] = "N/A (Val Data Missing)"
        return

    print("\n--- Calculating Model Accuracies (Top-1 and Top-5) ---")
    top1_accuracies = {}
    top5_accuracies = {}

    for exp_id, row in df_to_update.iterrows():
        print(f"  Processing Accuracy for: {exp_id}")

        is_gpu_unstable = exp_id in GPU_UNSTABLE_QUANTIZED_MODELS
        eval_device_str = 'cpu' if is_gpu_unstable else DEVICE.type
        if is_gpu_unstable and DEVICE.type == 'cuda':
            print(f"      INFO ({exp_id}): Known GPU unstable. Forcing CPU evaluation.")

        current_top1_acc = "N/A (Load Error)"
        current_top5_acc = "N/A (Load Error)"
        model_obj = None
        try:
            model_obj = load_model_for_experiment_nb(row, df_to_update, target_device_str=eval_device_str)
            # Compare with None: module truthiness can depend on __len__.
            if model_obj is not None:
                num_classes_for_eval = row.get('Num_Classes', DEFAULT_NUM_CLASSES)
                current_top1_acc, current_top5_acc = evaluate_model_accuracies_nb(
                    model_obj, eval_device_str, num_classes_for_eval, MAX_EVAL_BATCHES, exp_id
                )
        except Exception as e_exp:
            # One broken experiment must not abort the run and discard the
            # results already collected for the others.
            print(f"      ERROR ({exp_id}): Unexpected failure during load/evaluation: {e_exp}")
            current_top1_acc = "N/A (Eval Error)"
            current_top5_acc = "N/A (Eval Error)"
        finally:
            # Drop the reference and collect garbage first, so the cached GPU
            # blocks are actually free when empty_cache() runs.
            model_obj = None
            gc.collect()
            if DEVICE.type == 'cuda': torch.cuda.empty_cache()

        top1_accuracies[exp_id] = current_top1_acc
        top5_accuracies[exp_id] = current_top5_acc

    df_to_update['Top1_Accuracy_Percent'] = pd.Series(top1_accuracies)
    df_to_update['Top5_Accuracy_Percent'] = pd.Series(top5_accuracies)
    print("--- Accuracies calculated and stored. ---")
    display(df_to_update[['Experiment_ID', 'Top1_Accuracy_Percent', 'Top5_Accuracy_Percent']].head())

# Evaluate only when discovery produced at least one experiment.
if accuracy_results_df.empty:
    print("Skipping accuracy calculation as no experiments were discovered.")
else:
    calculate_and_store_accuracies(accuracy_results_df)


--- Calculating Model Accuracies (Top-1 and Top-5) ---
  Processing Accuracy for: resnet18_baseline
  Processing Accuracy for: resnet50_baseline
  Processing Accuracy for: resnet18pretrained_distilled_quant_kmeans_256clusters_post
  Processing Accuracy for: resnet18pretrained_distilled_quant_ptq_int8_perchannel_post
      INFO (resnet18pretrained_distilled_quant_ptq_int8_perchannel_post): Known GPU unstable. Forcing CPU evaluation.
  Processing Accuracy for: resnet18pretrained_distilled_quant_ptq_int8_pertensor_post
      INFO (resnet18pretrained_distilled_quant_ptq_int8_pertensor_post): Known GPU unstable. Forcing CPU evaluation.
  Processing Accuracy for: resnet18pretrained_distilled_quant_qat_int8_epochs8
      INFO (resnet18pretrained_distilled_quant_qat_int8_epochs8): Known GPU unstable. Forcing CPU evaluation.
  Processing Accuracy for: resnet50_to_resnet18pretrained_kd
  Processing Accuracy for: resnet50_to_resnet18scratch_kd
  Processing Accuracy for: resnet50_prune_nm24_ft
  

Unnamed: 0_level_0,Experiment_ID,Top1_Accuracy_Percent,Top5_Accuracy_Percent
Experiment_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
resnet18_baseline,resnet18_baseline,50.089217,76.319144
resnet50_baseline,resnet50_baseline,64.950293,87.484068
resnet18pretrained_distilled_quant_kmeans_256clusters_post,resnet18pretrained_distilled_quant_kmeans_256c...,53.683406,80.321183
resnet18pretrained_distilled_quant_ptq_int8_perchannel_post,resnet18pretrained_distilled_quant_ptq_int8_pe...,50.879429,78.078002
resnet18pretrained_distilled_quant_ptq_int8_pertensor_post,resnet18pretrained_distilled_quant_ptq_int8_pe...,53.301045,80.270201


Cell 6: Final Review and Save to CSV

In [6]:
if accuracy_results_df.empty:
    print("Accuracy DataFrame is empty. Nothing to save.")
else:
    print("\n--- Final Accuracy DataFrame Review (First 5 rows) ---")

    # Columns written to the final CSV.
    output_columns = ['Experiment_ID', 'Top1_Accuracy_Percent', 'Top5_Accuracy_Percent']

    # Make both accuracy columns numeric; non-numeric entries become NaN and
    # a missing column is created with missing values.
    for col in output_columns[1:]:
        if col not in accuracy_results_df.columns:
            accuracy_results_df[col] = pd.NA
        else:
            accuracy_results_df[col] = pd.to_numeric(accuracy_results_df[col], errors='coerce')

    # Display-only float formatting; the CSV precision is set by float_format below.
    pd.set_option('display.float_format', '{:.4f}'.format)

    display(accuracy_results_df[output_columns].head())

    try:
        accuracy_results_df_to_save = accuracy_results_df[output_columns]
        accuracy_results_df_to_save.to_csv(
            OUTPUT_CSV_ACCURACY,
            index=False,
            lineterminator='\n',
            float_format='%.5f',
        )
        print(f"\n--- Accuracy summary saved to {OUTPUT_CSV_ACCURACY} ---")
        print(f"Total experiments processed: {len(accuracy_results_df_to_save)}")
    except Exception as e_csv:
        print(f"Error saving CSV: {e_csv}")

print("\n--- Notebook processing for accuracy finished ---")


--- Final Accuracy DataFrame Review (First 5 rows) ---


Unnamed: 0_level_0,Experiment_ID,Top1_Accuracy_Percent,Top5_Accuracy_Percent
Experiment_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
resnet18_baseline,resnet18_baseline,50.0892,76.3191
resnet50_baseline,resnet50_baseline,64.9503,87.4841
resnet18pretrained_distilled_quant_kmeans_256clusters_post,resnet18pretrained_distilled_quant_kmeans_256c...,53.6834,80.3212
resnet18pretrained_distilled_quant_ptq_int8_perchannel_post,resnet18pretrained_distilled_quant_ptq_int8_pe...,50.8794,78.078
resnet18pretrained_distilled_quant_ptq_int8_pertensor_post,resnet18pretrained_distilled_quant_ptq_int8_pe...,53.301,80.2702



--- Accuracy summary saved to model_accuracy_summary.csv ---
Total experiments processed: 25

--- Notebook processing for accuracy finished ---
