In [9]:
import torch
import os

checkpoint_path = '../saved_models_and_logs/pruning_unstructured_iterative/resnet50_pruned_90_iterative_l1_ft.pth'
if os.path.exists(checkpoint_path):
	checkpoint = torch.load(checkpoint_path, map_location='cpu')
	print(checkpoint.keys())
	# This will show you the top-level keys in the saved dictionary,
	# e.g., dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
else:
	print(f"File not found: {checkpoint_path}")

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

  checkpoint = torch.load(checkpoint_path, map_location='cpu')


In [11]:
import torch
import os
import glob

MODEL_STATE_DICT_KEY = 'model_state_dict' # Based on your dict_keys
TARGET_MODEL_FILENAME = "model_final.pth" # The desired final lean model name

def process_single_checkpoint_dir(dir_path):
    """
    Loads a checkpoint from dir_path, extracts the model_state_dict,
    saves it as TARGET_MODEL_FILENAME, and optionally deletes the original.
    """
    if not os.path.isdir(dir_path):
        print(f"Error: Directory not found: {dir_path}")
        return False

    print(f"--- Processing directory: {dir_path} ---")

    # Find existing model/checkpoint files (.pth or .pt)
    potential_checkpoints = glob.glob(os.path.join(dir_path, "*.pth")) + \
                            glob.glob(os.path.join(dir_path, "*.pt"))

    original_checkpoint_path = None

    if not potential_checkpoints:
        print(f"  Warning: No .pth or .pt file found in {dir_path}. Skipping.")
        return False

    # Logic to select the checkpoint to process:
    # If model_final.pth exists and is a checkpoint, process it.
    # Otherwise, if only one other .pth/.pt file, process it.
    # If multiple, ask user or pick one (for now, let's try to be smart or ask)
    target_path_if_exists = os.path.join(dir_path, TARGET_MODEL_FILENAME)

    if len(potential_checkpoints) == 1:
        original_checkpoint_path = potential_checkpoints[0]
        print(f"  Found one model file: {os.path.basename(original_checkpoint_path)}")
    elif target_path_if_exists in potential_checkpoints:
        original_checkpoint_path = target_path_if_exists
        print(f"  Found existing '{TARGET_MODEL_FILENAME}', will attempt to process it.")
    else:
        # Filter out the target filename if it exists as a non-checkpoint (e.g., already processed)
        candidates = [p for p in potential_checkpoints if os.path.basename(p) != TARGET_MODEL_FILENAME]
        if len(candidates) == 1:
            original_checkpoint_path = candidates[0]
            print(f"  Found one candidate checkpoint: {os.path.basename(original_checkpoint_path)}")
        elif len(candidates) > 1:
            print(f"  Multiple potential checkpoint files found (excluding '{TARGET_MODEL_FILENAME}'):")
            for i, p_path in enumerate(candidates):
                print(f"    {i}: {os.path.basename(p_path)}")
            while True:
                try:
                    choice = int(input(f"  Please enter the number of the checkpoint to process: "))
                    if 0 <= choice < len(candidates):
                        original_checkpoint_path = candidates[choice]
                        break
                    else:
                        print("  Invalid choice.")
                except ValueError:
                    print("  Invalid input. Please enter a number.")
        elif not candidates and target_path_if_exists in potential_checkpoints:
             # This case means only model_final.pth exists, and it was found in potential_checkpoints
             original_checkpoint_path = target_path_if_exists
             print(f"  Only '{TARGET_MODEL_FILENAME}' found. Will attempt to process it (in case it's a full checkpoint).")
        else: # No clear candidate
            print(f"  Warning: Could not automatically determine a unique checkpoint file to process among: {[os.path.basename(p) for p in potential_checkpoints]}.")
            print(f"  Skipping this directory. Please ensure there's a clear checkpoint file or clean up existing files.")
            return False


    if not original_checkpoint_path:
        print(f"  No checkpoint selected or found for processing. Skipping.")
        return False

    print(f"  Selected checkpoint for processing: {os.path.basename(original_checkpoint_path)}")

    try:
        checkpoint = torch.load(original_checkpoint_path, map_location='cpu')

        if not isinstance(checkpoint, dict) or MODEL_STATE_DICT_KEY not in checkpoint:
            # Check if it's already the target file and just a state_dict (already processed)
            if os.path.basename(original_checkpoint_path) == TARGET_MODEL_FILENAME and \
               not isinstance(checkpoint, dict) and \
               isinstance(checkpoint, (dict, torch.nn.Module)): # state_dict is a dict, or it could be a whole model
                print(f"  '{TARGET_MODEL_FILENAME}' seems to be already a lean model/state_dict. No action needed for this file.")
                return True # Indicate success as it's already in desired state
            else:
                print(f"  Error: Loaded file '{os.path.basename(original_checkpoint_path)}' is not a recognized checkpoint dictionary (missing '{MODEL_STATE_DICT_KEY}' or not a dict).")
                print(f"  Type of loaded data: {type(checkpoint)}")
                return False

        model_state_dict = checkpoint[MODEL_STATE_DICT_KEY]
        new_model_path = os.path.join(dir_path, TARGET_MODEL_FILENAME)

        print(f"  Saving extracted model_state_dict to: {new_model_path}")
        torch.save(model_state_dict, new_model_path)
        print(f"  Successfully saved lean model to {new_model_path}")

        # --- Decision to delete the original large checkpoint ---
        if os.path.abspath(original_checkpoint_path) != os.path.abspath(new_model_path):
            user_choice = input(f"  Delete original large checkpoint '{os.path.basename(original_checkpoint_path)}'? (yes/no): ").strip().lower()
            if user_choice == 'yes':
                os.remove(original_checkpoint_path)
                print(f"  Deleted original large checkpoint: {os.path.basename(original_checkpoint_path)}")
            else:
                print(f"  Original checkpoint '{os.path.basename(original_checkpoint_path)}' was NOT deleted.")
        elif os.path.abspath(original_checkpoint_path) == os.path.abspath(new_model_path) and isinstance(checkpoint, dict):
             print(f"  Overwrote '{TARGET_MODEL_FILENAME}' with its lean version.")
        return True

    except Exception as e:
        print(f"  Error processing {original_checkpoint_path}: {e}")
        return False

# --- How to use ---
if __name__ == "__main__":
    # **IMPORTANT**: Replace this with the actual path to the directory you want to process
    # Example paths (use only one at a time by uncommenting):

    # path_to_process = "saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage1_sp50_ft"
    # path_to_process = "saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage2_sp75_ft"
    # path_to_process = "saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage3_sp90_ft"

    # --- SET THE PATH HERE ---
    path_to_process = '../saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage3_sp90_ft' # <--- *** PASTE YOUR DIRECTORY PATH HERE ***

    if not path_to_process:
        print("Please set the 'path_to_process' variable in the script.")
    else:
        success = process_single_checkpoint_dir(path_to_process)
        if success:
            print(f"\nSuccessfully processed: {path_to_process}")
        else:
            print(f"\nProcessing encountered issues for: {path_to_process}")

    print("\n--- Script finished ---")

--- Processing directory: ../saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage3_sp90_ft ---
  Found one model file: model_final.pth
  Selected checkpoint for processing: model_final.pth


  checkpoint = torch.load(original_checkpoint_path, map_location='cpu')


  Saving extracted model_state_dict to: ../saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage3_sp90_ft\model_final.pth
  Successfully saved lean model to ../saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage3_sp90_ft\model_final.pth
  Overwrote 'model_final.pth' with its lean version.

Successfully processed: ../saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage3_sp90_ft

--- Script finished ---


In [2]:
# ============== Notebook Cell 1: Imports (No custom model definition needed here now) ==============
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import os
import glob
from collections import OrderedDict # To inspect keys easily
import shutil # For safely moving the original file
from torchvision import models # <--- IMPORT TORCHVISION MODELS

print(f"PyTorch version: {torch.__version__}")

PyTorch version: 2.5.1


In [18]:
# ============== Notebook Cell 2: Configuration and Helper Function ==============
MODEL_STATE_DICT_KEY_IN_CHECKPOINT = 'model_state_dict'
TARGET_LEAN_MODEL_FILENAME = "model_final.pth" # This will be the name of the cleaned model

def clean_pruned_model_in_directory(directory_path, model_architecture_instance_factory, num_classes_for_model): # num_classes_for_model is passed here
    """
    Processes a checkpoint in the given directory to make pruning permanent.
    Saves a lean model and backs up the original.
    model_architecture_instance_factory is a function that returns a new model instance e.g. lambda actual_num_classes: models.resnet50(num_classes=actual_num_classes)
    """
    if not os.path.isdir(directory_path):
        print(f"❌ Error: Directory not found: {directory_path}")
        return False
    print(f"\nProcessing Directory: {directory_path}")
    potential_files = glob.glob(os.path.join(directory_path, "*.pth")) + \
                      glob.glob(os.path.join(directory_path, "*.pt"))
    checkpoint_to_process_path = None
    if not potential_files:
        print(f"  🟡 Warning: No .pth or .pt files found in {directory_path}.")
        return False

    if len(potential_files) == 1:
        checkpoint_to_process_path = potential_files[0]
    else:
        print("  Multiple files found. Please select the checkpoint to process:")
        candidates = []
        for i, f_path in enumerate(potential_files):
            fname = os.path.basename(f_path)
            fsize_mb = os.path.getsize(f_path) / (1024 * 1024)
            print(f"    {i}: {fname} ({fsize_mb:.2f} MB)")
            candidates.append(f_path)
        while True:
            try:
                choice = int(input("  Enter the number of the file to process: "))
                if 0 <= choice < len(candidates):
                    checkpoint_to_process_path = candidates[choice]
                    break
                else: print("    Invalid choice.")
            except ValueError: print("    Invalid input.")
    if not checkpoint_to_process_path:
        print("  No file selected."); return False

    original_filename_to_backup = os.path.basename(checkpoint_to_process_path)
    print(f"  Selected for processing: {original_filename_to_backup}")
    lean_model_save_path = os.path.join(directory_path, TARGET_LEAN_MODEL_FILENAME)
    backup_file_path = os.path.join(directory_path, f"{os.path.splitext(original_filename_to_backup)[0]}_full_checkpoint_backup{os.path.splitext(original_filename_to_backup)[1]}")

    if os.path.exists(backup_file_path) and os.path.basename(checkpoint_to_process_path) != TARGET_LEAN_MODEL_FILENAME:
        print(f"  🟡 Backup file '{os.path.basename(backup_file_path)}' already exists for '{original_filename_to_backup}'.")
        if input(f"     Reprocess '{original_filename_to_backup}' anyway? (yes/no): ").strip().lower() != 'yes':
            print("     Skipping reprocessing."); return True
    try:
        print(f"  Loading checkpoint: {checkpoint_to_process_path}...")
        full_checkpoint_data = torch.load(checkpoint_to_process_path, map_location='cpu') # Load to CPU
        model_state_dict_from_checkpoint = None
        if isinstance(full_checkpoint_data, dict) and MODEL_STATE_DICT_KEY_IN_CHECKPOINT in full_checkpoint_data:
            model_state_dict_from_checkpoint = full_checkpoint_data[MODEL_STATE_DICT_KEY_IN_CHECKPOINT]
        elif isinstance(full_checkpoint_data, (OrderedDict, dict)): # If it's already a state_dict
            model_state_dict_from_checkpoint = full_checkpoint_data
        else:
            print(f"  ❌ Error: Loaded file type not recognized: {type(full_checkpoint_data)}"); return False

        # Create a new, clean model instance for this operation
        # Call the factory with the num_classes_for_model value positionally
        current_model_instance = model_architecture_instance_factory(num_classes_for_model) # <--- CORRECTED CALL
        print(f"  Created a new model instance (e.g., resnet50 with num_classes={num_classes_for_model}).")


        print("  Loading state_dict into model instance (strict=False)...")
        cleaned_sd_for_loading = OrderedDict()
        has_module_prefix = any(k.startswith('module.') for k in model_state_dict_from_checkpoint.keys())
        if has_module_prefix:
            print("  Detected 'module.' prefix in state_dict keys, removing it...")
            for k, v in model_state_dict_from_checkpoint.items():
                name = k[7:] if k.startswith('module.') else k
                cleaned_sd_for_loading[name] = v
        else:
            cleaned_sd_for_loading = model_state_dict_from_checkpoint
        
        current_model_instance.load_state_dict(cleaned_sd_for_loading, strict=False)


        print("  Making pruning permanent by applying prune.remove()...")
        for module_name, module_obj in current_model_instance.named_modules():
            for param_name_to_check in ['weight', 'bias']:
                try:
                    if hasattr(module_obj, param_name_to_check + "_orig"):
                        prune.remove(module_obj, param_name_to_check)
                except AttributeError:
                    pass
                except Exception as e_prune:
                    pass
        print("  Pruning removal process completed.")
        
        final_cleaned_state_dict = current_model_instance.state_dict()
        print(f"  Cleaned state_dict has {len(final_cleaned_state_dict.keys())} keys.")

        if os.path.exists(checkpoint_to_process_path):
            if checkpoint_to_process_path != backup_file_path:
                print(f"  Backing up original file '{os.path.basename(checkpoint_to_process_path)}' to '{os.path.basename(backup_file_path)}'...")
                shutil.move(checkpoint_to_process_path, backup_file_path)
        else:
            print(f"  Warning: Original file '{os.path.basename(checkpoint_to_process_path)}' not found for backup.")


        print(f"  Saving lean model to: {lean_model_save_path}")
        torch.save(final_cleaned_state_dict, lean_model_save_path)
        new_size_mb = os.path.getsize(lean_model_save_path) / (1024 * 1024)
        print(f"  ✅ Successfully saved lean model: {os.path.basename(lean_model_save_path)} ({new_size_mb:.2f} MB)")
        return True
    except Exception as e:
        print(f"  ❌ Error during processing of {checkpoint_to_process_path}: {e}")
        import traceback; traceback.print_exc()
        if os.path.exists(backup_file_path) and not os.path.exists(checkpoint_to_process_path):
            try: shutil.move(backup_file_path, checkpoint_to_process_path); print(f"  Restored '{os.path.basename(checkpoint_to_process_path)}'.")
            except Exception as e_restore: print(f"  Could not restore: {e_restore}")
        return False

In [4]:
# ============== Notebook Cell 3: Define Model Instantiation and Parameters (Run Once) ==============
# --- !!! ACTION REQUIRED: PART 1 - SET NUMBER OF CLASSES !!! ---
# This NUM_CLASSES should match the output dimension of the final fully connected layer
# of your ResNet50 models that were pruned.
# For full ImageNet, this is 1000. If you fine-tuned on a dataset with a
# different number of classes (e.g., ImageNet-mini if it has fewer, or CIFAR-10/100),
# set this value correctly.
NUM_CLASSES_FOR_PRUNED_MODELS = 1000  # <--- ADJUST IF YOUR PRUNED MODELS HAVE A DIFFERENT NUMBER OF CLASSES

# This lambda function will create a new instance of the standard ResNet50.
# We pass `weights=None` because we are going to load our own pruned weights.
# The `num_classes` will be passed from the `NUM_CLASSES_FOR_PRUNED_MODELS` variable.
resnet50_factory = lambda nc: models.resnet50(weights=None, num_classes=nc)

print(f"Will use standard torchvision.models.resnet50 with num_classes={NUM_CLASSES_FOR_PRUNED_MODELS} as the architecture.")



Will use standard torchvision.models.resnet50 with num_classes=1000 as the architecture.


In [21]:
# ============== Notebook Cell 4: Process a Single Directory (Run for each directory) ==============
# --- !!! ACTION REQUIRED: PART 2 - SET THE DIRECTORY PATH !!! ---
# Set `directory_to_process` to the full path of one of your
# "pruning_unstructured_iterative/..." subdirectories that contains a large checkpoint.

# Use a raw string (r"...") for Windows paths
directory_to_process = r"C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage3_sp90_ft" # <--- Path to the FOLDER

# OR use forward slashes:
# directory_to_process = "C:/Uni/deep_model_optimization/saved_models_and_logs/pruning_unstructured_iterative/resnet50_prune_unstruct_it_l1_stage1_sp50_ft" # <--- Path to the FOLDER


if not directory_to_process:
    print("⚠️ Please set 'directory_to_process' in Cell 4 with the path to the target directory.")
elif not os.path.exists(directory_to_process): # This checks if the DIRECTORY exists
    print(f"❌ Error: The DIRECTORY '{directory_to_process}' does not exist. Please check the path.")
elif not os.path.isdir(directory_to_process): # Add this check
    print(f"❌ Error: The path '{directory_to_process}' is not a DIRECTORY. Please provide the path to the folder containing the model file.")
else:
    print(f"\nAttempting to process directory: {directory_to_process}")
    success = clean_pruned_model_in_directory(directory_to_process, resnet50_factory, NUM_CLASSES_FOR_PRUNED_MODELS)
    if success:
        print(f"\n✅ Processing finished for {directory_to_process}")
    else:
        print(f"\n❌ Processing encountered issues for {directory_to_process}")


Attempting to process directory: C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage3_sp90_ft

Processing Directory: C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage3_sp90_ft
  Selected for processing: model_final.pth
  Loading checkpoint: C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage3_sp90_ft\model_final.pth...


  full_checkpoint_data = torch.load(checkpoint_to_process_path, map_location='cpu') # Load to CPU


  Created a new model instance (e.g., resnet50 with num_classes=1000).
  Loading state_dict into model instance (strict=False)...
  Making pruning permanent by applying prune.remove()...
  Pruning removal process completed.
  Cleaned state_dict has 320 keys.
  Backing up original file 'model_final.pth' to 'model_final_full_checkpoint_backup.pth'...
  Saving lean model to: C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage3_sp90_ft\model_final.pth
  ✅ Successfully saved lean model: model_final.pth (97.79 MB)

✅ Processing finished for C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage3_sp90_ft


In [22]:
# Quick check cell
import torch
from torchvision import models # For the model class

lean_model_path = r"C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage1_sp50_ft\model_final.pth"
num_classes = 1000 # Should match NUM_CLASSES_FOR_PRUNED_MODELS

model_check = models.resnet50(weights=None, num_classes=num_classes)
try:
    state_dict_check = torch.load(lean_model_path, map_location='cpu')
    model_check.load_state_dict(state_dict_check)
    print(f"Successfully loaded the new lean model: {lean_model_path}")
    print(f"Number of keys in its state_dict: {len(state_dict_check.keys())}")
    # print("First 10 keys:", list(state_dict_check.keys())[:10])
except Exception as e:
    print(f"Error loading or checking the new lean model: {e}")

Successfully loaded the new lean model: C:\Uni\deep_model_optimization\saved_models_and_logs\pruning_unstructured_iterative\resnet50_prune_unstruct_it_l1_stage1_sp50_ft\model_final.pth
Number of keys in its state_dict: 320


  state_dict_check = torch.load(lean_model_path, map_location='cpu')
