## Read a .pth File in PyTorch

In [1]:
import torch

# Path to your .pth file
path = './checkpoints/miniimagenet/wrn/pt_map_bpa/run_5shot_augFalse_metrics/max_acc.pth'

# Load the checkpoint
checkpoint = torch.load(path, map_location='cpu')

# Show all keys in the state_dict
print("Keys in the checkpoint:")
for k in checkpoint.keys():
    print(k)

# Optional: Inspect one layer's weights
first_key = list(checkpoint.keys())[0]
print(f"\nFirst key: {first_key}")
print("Weights:", checkpoint[first_key])


Keys in the checkpoint:
conv1.weight
block1.layer.0.bn1.weight
block1.layer.0.bn1.bias
block1.layer.0.bn1.running_mean
block1.layer.0.bn1.running_var
block1.layer.0.bn1.num_batches_tracked
block1.layer.0.conv1.weight
block1.layer.0.bn2.weight
block1.layer.0.bn2.bias
block1.layer.0.bn2.running_mean
block1.layer.0.bn2.running_var
block1.layer.0.bn2.num_batches_tracked
block1.layer.0.conv2.weight
block1.layer.0.convShortcut.weight
block1.layer.1.bn1.weight
block1.layer.1.bn1.bias
block1.layer.1.bn1.running_mean
block1.layer.1.bn1.running_var
block1.layer.1.bn1.num_batches_tracked
block1.layer.1.conv1.weight
block1.layer.1.bn2.weight
block1.layer.1.bn2.bias
block1.layer.1.bn2.running_mean
block1.layer.1.bn2.running_var
block1.layer.1.bn2.num_batches_tracked
block1.layer.1.conv2.weight
block1.layer.2.bn1.weight
block1.layer.2.bn1.bias
block1.layer.2.bn1.running_mean
block1.layer.2.bn1.running_var
block1.layer.2.bn1.num_batches_tracked
block1.layer.2.conv1.weight
block1.layer.2.bn2.weight
bl

## All Layers and Their Shapes

In [2]:
for key, value in checkpoint.items():
    print(f"{key} -> shape: {value.shape}")

conv1.weight -> shape: torch.Size([16, 3, 3, 3])
block1.layer.0.bn1.weight -> shape: torch.Size([16])
block1.layer.0.bn1.bias -> shape: torch.Size([16])
block1.layer.0.bn1.running_mean -> shape: torch.Size([16])
block1.layer.0.bn1.running_var -> shape: torch.Size([16])
block1.layer.0.bn1.num_batches_tracked -> shape: torch.Size([])
block1.layer.0.conv1.weight -> shape: torch.Size([160, 16, 3, 3])
block1.layer.0.bn2.weight -> shape: torch.Size([160])
block1.layer.0.bn2.bias -> shape: torch.Size([160])
block1.layer.0.bn2.running_mean -> shape: torch.Size([160])
block1.layer.0.bn2.running_var -> shape: torch.Size([160])
block1.layer.0.bn2.num_batches_tracked -> shape: torch.Size([])
block1.layer.0.conv2.weight -> shape: torch.Size([160, 160, 3, 3])
block1.layer.0.convShortcut.weight -> shape: torch.Size([160, 16, 1, 1])
block1.layer.1.bn1.weight -> shape: torch.Size([160])
block1.layer.1.bn1.bias -> shape: torch.Size([160])
block1.layer.1.bn1.running_mean -> shape: torch.Size([160])
block

## How many parameters?

In [3]:
import torch

# Load checkpoint
path = './checkpoints/miniimagenet/wrn/pt_map_bpa/run_5shot_augFalse_metrics/max_acc.pth'
checkpoint = torch.load(path, map_location='cpu')

# Sometimes weights are under 'state_dict', sometimes not
state_dict = checkpoint.get('state_dict', checkpoint)

# Count total parameters
total_params = sum(v.numel() for v in state_dict.values())
print(f"Total parameters: {total_params:,}")


Total parameters: 36,490,761
