In [None]:
import torch
import torchvision.models as models
from collections import OrderedDict
from pathlib import Path

# Pin PyTorch's RNG so that any parameters left at their random initialisation
# (e.g. model keys the checkpoint does not supply) are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

## Load Original Model

Load the original checkpoint's state dict and print a sample of its keys to see how they are named.

In [1]:
# Load original state dict.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine too.
# NOTE: torch.load unpickles the file -- only run this on checkpoints you trust.
original_path = 'Models/DenseNet121.pt'
state_dict = torch.load(original_path, map_location='cpu')

# The remapping cell iterates over key/value pairs, so fail early with a clear
# message if the file holds something other than a plain state dict.
if not isinstance(state_dict, dict):
    raise TypeError(
        f"Expected a state dict (mapping) in {original_path}, got {type(state_dict).__name__}"
    )

# Print original keys structure
print(f"Original state dict structure ({len(state_dict)} keys):")
for key in list(state_dict.keys())[:5]:
    print(f"  {key}")
print("  ...")

NameError: name 'torch' is not defined

## Create New Model

Build a fresh, untrained DenseNet121 (`weights=None`) and print a sample of its state-dict keys. These are the names the converted checkpoint must match.

In [None]:
# Build an untrained torchvision DenseNet121 (weights=None). It serves only as a
# template for the key names and tensor shapes the converted checkpoint must match.
new_model = models.densenet121(weights=None)

print("Expected model structure:")
expected_keys = list(new_model.state_dict())
for name in expected_keys[:5]:
    print(f"  {name}")
print("  ...")

Expected model structure:
  features.conv0.weight
  features.norm0.weight
  features.norm0.bias
  features.norm0.running_mean
  features.norm0.running_var
  ...


## Preprocess State Dict

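Rename the original keys to match the expected structure by mapping the `backbone.0.` prefix onto torchvision's `features.` prefix. Keys without that prefix are not carried over.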

In [None]:
# Create new state dict with correct mapping.
# The checkpoint stores the DenseNet feature extractor under 'backbone.0.';
# torchvision's DenseNet121 expects the same tensors under 'features.'.
SRC_PREFIX = 'backbone.0.'
DST_PREFIX = 'features.'

new_state_dict = OrderedDict()
skipped_keys = []  # keys without SRC_PREFIX -- reported below instead of dropped silently

for k, v in state_dict.items():
    if k.startswith(SRC_PREFIX):
        # Slice off the prefix rather than using str.replace(), which would also
        # rewrite any later occurrence of the substring inside the key.
        new_state_dict[DST_PREFIX + k[len(SRC_PREFIX):]] = v
    else:
        skipped_keys.append(k)

print(f"\nNew state dict structure ({len(new_state_dict)} of {len(state_dict)} keys remapped):")
for key in list(new_state_dict.keys())[:5]:
    print(f"  {key}")
print("  ...")

if skipped_keys:
    print(f"\nSkipped {len(skipped_keys)} keys without the '{SRC_PREFIX}' prefix, e.g.:")
    for key in skipped_keys[:5]:
        print(f"  {key}")

# Verify shapes match (sample)
print("\nVerifying tensor shapes:")
model_state = new_model.state_dict()
for key in list(new_state_dict.keys())[:5]:
    if key in model_state:
        print(f"  {key}: {new_state_dict[key].shape} -> {model_state[key].shape}")

# Check every remapped tensor, not just the sample above. load_state_dict raises
# on any size mismatch even with strict=False, so surface them all here first.
shape_mismatches = [
    (key, tuple(tensor.shape), tuple(model_state[key].shape))
    for key, tensor in new_state_dict.items()
    if key in model_state and tensor.shape != model_state[key].shape
]
print(f"\nShape mismatches: {len(shape_mismatches)} of {len(new_state_dict)} remapped keys")
for key, ckpt_shape, model_shape in shape_mismatches[:5]:
    print(f"  {key}: checkpoint {ckpt_shape} vs model {model_shape}")

## Validate and Save

Test loading the preprocessed state dict and save if successful.

In [None]:
# Load the remapped weights into the template model and save them only if they fit.
# strict=False makes load_state_dict return the missing/unexpected keys instead of
# raising on them, so they can be reported. A tensor size mismatch still raises
# RuntimeError. That error, and any torch.save failure, is left to propagate
# rather than being printed and swallowed. Run All then stops here instead of the
# next cell verifying a stale file from an earlier run.

# Set to True to save even when the load is incomplete, e.g. if the checkpoint
# intentionally carries no classifier weights.
ALLOW_PARTIAL_LOAD = False


def _report_keys(kind, keys, limit=5):
    """Print how many `kind` keys there are and list up to `limit` of them."""
    print(f"\n{kind.capitalize()} keys: {len(keys)}")
    if keys:
        print(f"First few {kind} keys:")
        for key in keys[:limit]:
            print(f"  {key}")


print("Attempting to load preprocessed state dict...")
missing_keys, unexpected_keys = new_model.load_state_dict(new_state_dict, strict=False)

_report_keys("missing", missing_keys)
_report_keys("unexpected", unexpected_keys)

load_is_clean = not missing_keys and not unexpected_keys
if load_is_clean:
    print("\nState dict loaded successfully!")

# Write the file only when the load is complete, or when a partial load was
# explicitly allowed, so an incomplete conversion is never saved by accident.
if load_is_clean or ALLOW_PARTIAL_LOAD:
    output_path = 'Models/DenseNet121_processed.pt'
    torch.save(new_state_dict, output_path)
    print(f"\nSaved preprocessed model to: {output_path}")
else:
    raise RuntimeError(
        f"Not saving: {len(missing_keys)} missing and {len(unexpected_keys)} unexpected keys. "
        "Fix the key mapping, or set ALLOW_PARTIAL_LOAD = True to save anyway."
    )

## Verify Saved Model

Reload the saved state dict from disk into a fresh DenseNet121 and confirm that the file works on its own and its weights load unchanged.

In [None]:
# Reload the saved file from disk into a fresh model, so this check covers the
# artefact on disk rather than the in-memory new_state_dict.
# Errors are left to propagate instead of being printed and swallowed.
print("Loading saved preprocessed model...")
# map_location='cpu' lets the file load on a machine without the GPU it may have
# been saved from.
processed_state_dict = torch.load('Models/DenseNet121_processed.pt', map_location='cpu')

# Create a fresh model and load the state dict. strict=False returns the key
# mismatches instead of raising; they are checked below rather than ignored.
test_model = models.densenet121(weights=None)
incompatible = test_model.load_state_dict(processed_state_dict, strict=False)

# Every tensor in the file must have landed in the model unchanged. Compare in the
# model's dtype in case the file stores tensors in a different one.
loaded_state = test_model.state_dict()
value_mismatches = [
    key
    for key, tensor in processed_state_dict.items()
    if key in loaded_state
    and not torch.equal(loaded_state[key], tensor.to(loaded_state[key].dtype))
]
assert not value_mismatches, f"Weights differ after reload: {value_mismatches[:5]}"

if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"Loaded with {len(incompatible.missing_keys)} missing and "
        f"{len(incompatible.unexpected_keys)} unexpected keys; "
        "missing parameters keep their random initialisation."
    )
else:
    print("Successfully loaded preprocessed model!")
print(f"Model parameters: {sum(p.numel() for p in test_model.parameters()):,}")