# DenseNet121 Model Preprocessing

This notebook converts a DenseNet121 checkpoint whose keys carry a `backbone.0.` prefix into a state dict that torchvision's `densenet121` can load, by renaming those keys to the `features.` prefix torchvision expects.

In [9]:
import torch
import torchvision.models as models
from collections import OrderedDict
from pathlib import Path

# Set random seeds for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    
# Ensure Models directory exists
Path('Models').mkdir(exist_ok=True)

## Load Original Model

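Load the original state dict and examine its structure. The totals printed in this notebook sum `numel()` over every state-dict entry, so they count BatchNorm buffers (`running_mean`, `running_var`, `num_batches_tracked`) as well as learnable parameters.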
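Printing the first five keys only reveals one prefix. A minimal sketch for summarizing every key by its leading components, written in plain Python so it runs without a checkpoint; `summarize_prefixes` is a hypothetical helper, and the two `head.fc.*` keys are invented to show how a second prefix would appear. On the real checkpoint it would be called as `summarize_prefixes(state_dict.keys())`.

```python
from collections import Counter
from typing import Iterable


def summarize_prefixes(keys: Iterable[str], depth: int = 2) -> Counter:
    """Count state-dict keys by their first `depth` dot-separated components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


# The first five keys printed for the checkpoint in this notebook, plus two
# hypothetical head keys (not in the real file) to show a second prefix.
keys = [
    "backbone.0.conv0.weight",
    "backbone.0.norm0.weight",
    "backbone.0.norm0.bias",
    "backbone.0.norm0.running_mean",
    "backbone.0.norm0.running_var",
    "head.fc.weight",
    "head.fc.bias",
]

summary = summarize_prefixes(keys)
for prefix, count in summary.most_common():
    print(f"{prefix}: {count} keys")
```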

In [10]:
# Load original state dict
original_path = Path('Models/DenseNet121.pt')
if not original_path.exists():
    raise FileNotFoundError(f"Original model file not found at: {original_path}")

print(f"Loading model from: {original_path}")
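# map_location="cpu" lets a GPU-saved checkpoint load on any machine;
# weights_only=True restricts unpickling to tensors and plain containers
state_dict = torch.load(original_path, map_location="cpu", weights_only=True)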

# Print original keys structure
print("\nOriginal state dict structure:")
for key in list(state_dict.keys())[:5]:
    print(f"  {key}")
print("  ...")

# Print the total element count (learnable parameters plus BatchNorm buffers)
print(f"\nTotal parameters: {sum(v.numel() for v in state_dict.values()):,}")

Loading model from: Models/DenseNet121.pt

Original state dict structure:
  backbone.0.conv0.weight
  backbone.0.norm0.weight
  backbone.0.norm0.bias
  backbone.0.norm0.running_mean
  backbone.0.norm0.running_var
  ...

Total parameters: 7,037,625


## Create New Model

Initialize a fresh DenseNet121 model and examine its expected structure.

In [11]:
# Initialize new DenseNet121
print("Initializing new DenseNet121...")
new_model = models.densenet121(weights=None)

# Print expected keys structure
print("\nExpected model structure:")
model_state = new_model.state_dict()
for key in list(model_state.keys())[:5]:
    print(f"  {key}")
print("  ...")

# Print the total element count (learnable parameters plus BatchNorm buffers)
print(f"\nTotal parameters: {sum(v.numel() for v in model_state.values()):,}")

Initializing new DenseNet121...

Expected model structure:
  features.conv0.weight
  features.norm0.weight
  features.norm0.bias
  features.norm0.running_mean
  features.norm0.running_var
  ...

Total parameters: 8,062,625
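The two totals differ by 1,025,000 elements, which matches the size of torchvision's default classifier head, a `Linear(1024, 1000)` layer. That suggests the checkpoint holds the feature extractor only. The arithmetic, with both totals copied from the outputs above:

```python
# torchvision's densenet121 ends in nn.Linear(1024, 1000): a 1000x1024 weight
# matrix plus a 1000-element bias.
in_features, num_classes = 1024, 1000
classifier_numel = in_features * num_classes + num_classes

original_total = 7_037_625  # elements in the loaded checkpoint
expected_total = 8_062_625  # elements in a fresh densenet121 state dict

print(f"Classifier elements: {classifier_numel:,}")
print(f"Gap between totals:  {expected_total - original_total:,}")
```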


## Preprocess State Dict

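Convert the original state dict to match the expected structure by replacing the `backbone.0.` prefix with `features.`. Every original key sits under `backbone.0.`, so all 7,037,625 elements carry over; the checkpoint has no counterpart for the classifier.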
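The conversion is a prefix rename. A sketch of a reusable version, assuming only that the state dict is a mapping from string keys to values (a torch state dict qualifies); `remap_prefix` is a hypothetical helper, not a torch or torchvision API. Unlike `str.replace`, it rewrites only the leading prefix, and it returns skipped keys instead of dropping them silently.

```python
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Mapping


def remap_prefix(state: Mapping[str, Any], old: str, new: str) -> tuple[OrderedDict, list[str]]:
    """Rewrite keys starting with `old` to start with `new`; report the rest."""
    remapped: OrderedDict = OrderedDict()
    skipped: list[str] = []
    for key, value in state.items():
        if key.startswith(old):
            # Only the leading prefix changes; later occurrences are left alone
            remapped[new + key[len(old):]] = value
        else:
            skipped.append(key)
    return remapped, skipped


# String values stand in for tensors; "head.fc.weight" is a hypothetical key
demo = OrderedDict([
    ("backbone.0.conv0.weight", "w0"),
    ("backbone.0.norm0.bias", "b0"),
    ("head.fc.weight", "w1"),
])
remapped, skipped = remap_prefix(demo, "backbone.0.", "features.")
print(list(remapped))  # ['features.conv0.weight', 'features.norm0.bias']
print(skipped)         # ['head.fc.weight']
```

Applied to this notebook it would read `new_state_dict, skipped = remap_prefix(state_dict, 'backbone.0.', 'features.')`; an empty `skipped` list confirms that every key was carried over.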

In [12]:
# Create new state dict with correct mapping
print("Creating new state dict with correct mapping...")
new_state_dict = OrderedDict()

# Track the element count to verify that no key was dropped
param_count = 0
prefix = 'backbone.0.'

for k, v in state_dict.items():
    # Keys outside backbone.0. are skipped; a param_count below the original total would reveal this
    if k.startswith(prefix):
        # Swap only the leading backbone.0. prefix for features.
        new_key = 'features.' + k[len(prefix):]
        new_state_dict[new_key] = v
        param_count += v.numel()

print("\nNew state dict structure:")
for key in list(new_state_dict.keys())[:5]:
    print(f"  {key}")
print("  ...")

print(f"\nMapped parameters: {param_count:,}")

# Spot-check shapes for the first five keys (load_state_dict checks every key)
print("\nVerifying tensor shapes:")
for key in list(new_state_dict.keys())[:5]:
    if key in model_state:
        print(f"  {key}: {new_state_dict[key].shape} -> {model_state[key].shape}")

Creating new state dict with correct mapping...

New state dict structure:
  features.conv0.weight
  features.norm0.weight
  features.norm0.bias
  features.norm0.running_mean
  features.norm0.running_var
  ...

Mapped parameters: 7,037,625

Verifying tensor shapes:
  features.conv0.weight: torch.Size([64, 3, 7, 7]) -> torch.Size([64, 3, 7, 7])
  features.norm0.weight: torch.Size([64]) -> torch.Size([64])
  features.norm0.bias: torch.Size([64]) -> torch.Size([64])
  features.norm0.running_mean: torch.Size([64]) -> torch.Size([64])
  features.norm0.running_var: torch.Size([64]) -> torch.Size([64])
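The shape check above looks at five keys only. A sketch that compares every key, assuming shapes are first extracted as tuples, for example `{k: tuple(v.shape) for k, v in new_state_dict.items()}`; `diff_shapes` is a hypothetical helper, and the demo shapes are a small subset of the real ones.

```python
from __future__ import annotations


def diff_shapes(
    source: dict[str, tuple[int, ...]], target: dict[str, tuple[int, ...]]
) -> dict[str, list]:
    """Compare two name -> shape mappings over every key, not just the first few."""
    missing = sorted(target.keys() - source.keys())     # expected but absent
    unexpected = sorted(source.keys() - target.keys())  # present but not expected
    mismatched = sorted(
        (k, source[k], target[k])
        for k in source.keys() & target.keys()
        if source[k] != target[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# A small subset of the real shapes, for illustration
source = {"features.conv0.weight": (64, 3, 7, 7), "features.norm0.weight": (64,)}
target = {
    "features.conv0.weight": (64, 3, 7, 7),
    "features.norm0.weight": (64,),
    "classifier.weight": (1000, 1024),
    "classifier.bias": (1000,),
}
report = diff_shapes(source, target)
print(report)
```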


## Validate and Save

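Load the preprocessed state dict into the new model with `strict=False`, report missing and unexpected keys, and save it. `strict=False` only relaxes key matching; a shape mismatch still raises an error. The cell saves the file even when keys are missing: here the only missing keys are `classifier.weight` and `classifier.bias`, so the classifier keeps its random initialization.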
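A sketch of a stricter gate that accepts a non-strict load only when every missing key falls under an allowed prefix, assuming the classifier is the only part meant to stay uninitialized; `load_is_acceptable` is hypothetical. The value returned by `load_state_dict` exposes `missing_keys` and `unexpected_keys`, so the gate could wrap the save as `if load_is_acceptable(result.missing_keys, result.unexpected_keys): torch.save(...)`.

```python
from __future__ import annotations


def load_is_acceptable(
    missing: list[str],
    unexpected: list[str],
    allowed_missing_prefixes: tuple[str, ...] = ("classifier.",),
) -> bool:
    """Accept a non-strict load only if every missing key is under an allowed prefix."""
    if unexpected:
        return False
    # str.startswith accepts a tuple and matches any of its prefixes
    return all(key.startswith(allowed_missing_prefixes) for key in missing)


print(load_is_acceptable(["classifier.weight", "classifier.bias"], []))
print(load_is_acceptable(["features.norm5.weight"], []))
```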

In [20]:
# Try loading the state dict
print("Attempting to load preprocessed state dict...")
try:
    missing_keys, unexpected_keys = new_model.load_state_dict(new_state_dict, strict=False)
    
    print(f"\nMissing keys: {len(missing_keys)}")
    if missing_keys:
        print("First few missing keys:")
        for key in missing_keys[:5]:
            print(f"  {key}")
    
    print(f"\nUnexpected keys: {len(unexpected_keys)}")
    if unexpected_keys:
        print("First few unexpected keys:")
        for key in unexpected_keys[:5]:
            print(f"  {key}")
    
    if not missing_keys and not unexpected_keys:
        print("\nState dict loaded successfully!")
        
    # Save preprocessed state dict
    output_path = 'DenseNet121_processed.pt'
    torch.save(new_state_dict, output_path)
    print(f"\nSaved preprocessed model to: {output_path}")
    
except Exception as e:
    print(f"Error processing state dict: {str(e)}")
    raise

Attempting to load preprocessed state dict...

Missing keys: 2
First few missing keys:
  classifier.weight
  classifier.bias

Unexpected keys: 0

Saved preprocessed model to: DenseNet121_processed.pt


## Verify Saved Model

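Reload the saved file into a fresh DenseNet121 to confirm that it round-trips. The count printed here sums `test_model.parameters()`, which excludes buffers, so it is smaller than the state-dict totals above.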

In [21]:
# Try loading the saved preprocessed model
try:
    output_path = Path('DenseNet121_processed.pt')
    if not output_path.exists():
        raise FileNotFoundError(f"Processed model file not found at: {output_path}")
        
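    # Load the saved state dict on the CPU
    print("Loading saved preprocessed model...")
    processed_state_dict = torch.load(output_path, map_location="cpu", weights_only=True)
    
    # Create a fresh model and load the state dict; only the classifier may be missing
    test_model = models.densenet121(weights=None)
    result = test_model.load_state_dict(processed_state_dict, strict=False)
    assert set(result.missing_keys) == {'classifier.weight', 'classifier.bias'}, result.missing_keys
    assert not result.unexpected_keys, result.unexpected_keys
    
    print("Successfully loaded preprocessed model!")
    # parameters() excludes BatchNorm buffers, so this is smaller than the state-dict totals
    print(f"Model parameters: {sum(p.numel() for p in test_model.parameters()):,}")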
    
except Exception as e:
    print(f"Error verifying saved model: {str(e)}")
    raise

Loading saved preprocessed model...
Successfully loaded preprocessed model!
Model parameters: 7,978,856
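The reloaded model reports 7,978,856 parameters, 83,769 fewer than the 8,062,625 state-dict elements of a fresh model. That gap is the BatchNorm buffers: each `BatchNorm2d` stores `running_mean` and `running_var` (one value per channel) plus a scalar `num_batches_tracked`. The sketch below recomputes the buffer count from the DenseNet-121 layout, assuming torchvision's default hyperparameters (64 initial features, growth rate 32, bottleneck factor 4, blocks of 6, 12, 24 and 16 layers).

```python
# Hyperparameters assumed to match torchvision's densenet121 defaults
num_init_features = 64
growth_rate = 32
bn_size = 4
block_config = (6, 12, 24, 16)

bn_channels = [num_init_features]  # norm0
channels = num_init_features
for i, num_layers in enumerate(block_config):
    for _ in range(num_layers):
        bn_channels.append(channels)               # norm1 sees the concatenated input
        bn_channels.append(bn_size * growth_rate)  # norm2 follows the 1x1 bottleneck
        channels += growth_rate
    if i != len(block_config) - 1:
        bn_channels.append(channels)               # transition norm
        channels //= 2                             # transition halves the channels
bn_channels.append(channels)                       # norm5

# running_mean and running_var hold C elements each; num_batches_tracked holds 1
buffer_numel = sum(2 * c for c in bn_channels) + len(bn_channels)

state_dict_total = 8_062_625  # fresh model state dict, from the earlier cell
param_total = 7_978_856       # sum over test_model.parameters()
print(f"BatchNorm layers: {len(bn_channels)}")
print(f"Buffer elements:  {buffer_numel:,}")
print(f"Observed gap:     {state_dict_total - param_total:,}")
```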
