# View PyTorch Checkpoint File

This notebook loads and displays information about the VAE checkpoint file.

In [4]:
import torch
import os
from pathlib import Path

# Checkpoint to inspect, given relative to the notebook's working directory
checkpoint_path = Path('../checkpoint/vae_epoch_5.pt')

# Report whether the file is present; this cell only checks, it does not load anything
if checkpoint_path.exists():
    print(f"Loading checkpoint from: {checkpoint_path}")
else:
    print(f"File not found: {checkpoint_path}")

Loading checkpoint from: ..\checkpoint\vae_epoch_5.pt


In [5]:
# Load the checkpoint, mapping every tensor onto the CPU.
# NOTE(security): torch.load unpickles the file. weights_only=True limits
# unpickling to tensors, primitive values and plain containers, so a tampered
# checkpoint cannot execute arbitrary code while it is being read.
try:
    checkpoint = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    print("Checkpoint loaded successfully!")
except Exception as e:
    # Report the failure instead of raising. `checkpoint` is left undefined in
    # that case, which is why later cells test for the name before using it.
    print(f"Error loading checkpoint: {e}")

Checkpoint loaded successfully!


In [6]:
# Display the checkpoint structure
def explore_checkpoint(checkpoint):
    """Print a checkpoint's top-level keys and the layout of its state dict(s).

    Every key of a dictionary checkpoint is listed. For each string key that
    contains ``'state_dict'`` the entries of the associated mapping are listed
    as well: entries exposing a ``shape`` attribute show that shape, all other
    entries show their Python type.

    Args:
        checkpoint: The loaded checkpoint object. Anything that is not a
            ``dict`` is only reported by its type.

    Returns:
        None. All information is written to stdout.
    """
    if not isinstance(checkpoint, dict):
        print(f"Checkpoint is not a dictionary. Type: {type(checkpoint)}")
        return

    print("Checkpoint contains the following keys:")
    for key, value in checkpoint.items():
        print(f"- {key}")

        # Dict keys need not be strings; a substring test on e.g. an int key
        # would raise TypeError, so check the key type first.
        if not (isinstance(key, str) and 'state_dict' in key):
            continue

        # A matching key can still hold something without .items();
        # report that instead of failing with AttributeError.
        if not hasattr(value, 'items'):
            print(f"  (value is not a mapping: {type(value)})")
            continue

        print("\nModel state dictionary structure:")
        for param_name, param in value.items():
            if hasattr(param, 'shape'):
                print(f"  {param_name}: {param.shape}")
            else:
                print(f"  {param_name}: {type(param)}")


# Run the exploration function
# `checkpoint` only exists if an earlier cell assigned it, hence the namespace check.
if 'checkpoint' in locals():
    explore_checkpoint(checkpoint)

Checkpoint contains the following keys:
- epoch
- model_state_dict

Model state dictionary structure:
  encoder.conv_in.conv.weight: torch.Size([128, 3, 3, 3, 3])
  encoder.conv_in.conv.bias: torch.Size([128])
  encoder.down.0.block.0.norm1.weight: torch.Size([128])
  encoder.down.0.block.0.norm1.bias: torch.Size([128])
  encoder.down.0.block.0.conv1.conv.weight: torch.Size([128, 128, 3, 3, 3])
  encoder.down.0.block.0.conv1.conv.bias: torch.Size([128])
  encoder.down.0.block.0.norm2.weight: torch.Size([128])
  encoder.down.0.block.0.norm2.bias: torch.Size([128])
  encoder.down.0.block.0.conv2.conv.weight: torch.Size([128, 128, 3, 3, 3])
  encoder.down.0.block.0.conv2.conv.bias: torch.Size([128])
  encoder.down.0.block.1.norm1.weight: torch.Size([128])
  encoder.down.0.block.1.norm1.bias: torch.Size([128])
  encoder.down.0.block.1.conv1.conv.weight: torch.Size([128, 128, 3, 3, 3])
  encoder.down.0.block.1.conv1.conv.bias: torch.Size([128])
  encoder.down.0.block.1.norm2.weight: torch.S

In [None]:
# Additional detailed exploration of specific parts (if needed)
# For example, if you want to see specific model parameters

def preview_parameters(checkpoint, max_params=5, max_values=5):
    """Print name, shape and leading values for the first few state-dict entries.

    The state dict is looked up under the first of the keys ``'state_dict'``,
    ``'model_state_dict'`` and ``'model'`` that is present. Only the first
    ``max_params`` entries are visited; entries without a ``shape`` attribute
    are skipped but still count towards that limit and keep their index.

    Args:
        checkpoint: The loaded checkpoint object. Nothing is printed unless it
            is a ``dict`` containing one of the keys above.
        max_params: Maximum number of state-dict entries to visit.
        max_values: Maximum number of flattened values to print per entry.

    Returns:
        None. All information is written to stdout.
    """
    if not isinstance(checkpoint, dict):
        return

    # Try common state dict keys, in priority order
    state_dict_key = next(
        (key for key in ('state_dict', 'model_state_dict', 'model') if key in checkpoint),
        None,
    )
    if state_dict_key is None:
        return

    state_dict = checkpoint[state_dict_key]
    # The value under a candidate key may not offer .items(); say so instead of
    # failing with AttributeError.
    if not hasattr(state_dict, 'items'):
        print(f"'{state_dict_key}' is not a mapping (type: {type(state_dict)}); nothing to preview.")
        return

    print(f"First few parameters from '{state_dict_key}':")
    # zip against range() stops after max_params entries without first copying
    # the whole mapping into a list.
    for i, (name, param) in zip(range(max_params), state_dict.items()):
        if not hasattr(param, 'shape'):
            continue
        print(f"Parameter {i}: {name}, Shape: {param.shape}")
        # Print a sample of values from the tensor
        flat_tensor = param.flatten()
        if len(flat_tensor) > 0:
            sample_size = min(max_values, len(flat_tensor))
            print(f"Sample values: {flat_tensor[:sample_size]}")
        print()


# `checkpoint` only exists if an earlier cell assigned it, hence the namespace check.
if 'checkpoint' in locals():
    preview_parameters(checkpoint)