# Storing and Loading Models

https://pytorch.org/tutorials/beginner/saving_loading_models.html

## Init, helpers, utils, ...

In [1]:
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

# `state_dict()`

## `nn.Module.state_dict()`
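An `nn.Module` has a state dict that maps each parameter name (e.g. `conv1.weight`) to its learnable tensor. Layers without learnable parameters, such as the max-pooling layer `pool`, have no entries.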

In [4]:
class Net(nn.Module):
    """Small CNN: two conv -> ReLU -> max-pool stages, then three linear layers.

    The attribute names (``conv1``, ``pool``, ``conv2``, ``fc1``, ``fc2``,
    ``fc3``) become the key prefixes of ``state_dict()``. Renaming them would
    stop previously saved state dicts from loading.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: parameter initialisation draws from the
        # global RNG in exactly this sequence.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels x 5 x 5 positions. That is the size a 32x32 input has
        # after two conv(kernel 5) / pool(2) stages: 32->28->14->10->5.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a 3-channel image batch to 10 unnormalised scores per sample."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to fc1's input width; -1 lets `view` infer the batch size.
        x = x.view(-1, self.fc1.in_features)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [5]:
model = Net()

In [6]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.0214, -0.0616, -0.0166, -0.0885,  0.0492],
                        [-0.0103,  0.0270,  0.0554, -0.0256, -0.0176],
                        [-0.0469, -0.1120,  0.1011, -0.0898,  0.0791],
                        [-0.0761,  0.0606,  0.0247,  0.0163, -0.0270],
                        [-0.0633,  0.0729, -0.0075,  0.0278, -0.0547]],
              
                       [[-0.0757, -0.0462, -0.0290,  0.0887,  0.0608],
                        [-0.0527,  0.0375,  0.0631, -0.0958, -0.0470],
                        [-0.0900, -0.0750, -0.0244,  0.1147,  0.0856],
                        [ 0.0461,  0.0735, -0.0351,  0.0280, -0.0574],
                        [ 0.0125, -0.0666,  0.0314, -0.0634,  0.0962]],
              
                       [[-0.0227, -0.0003,  0.0399,  0.1068,  0.0646],
                        [ 0.0405, -0.0156, -0.0906, -0.0675,  0.0462],
                        [ 0.0656,  0.0022, -0.0086, -0.0271, -0.0277],
               

In [7]:
def state_dict_info(obj):
    """Print a two-column overview of ``obj.state_dict()``.

    Each entry is listed under its key. Values that have a ``.shape`` attribute
    (such as tensors) are shown by that shape. Anything else, like an
    optimizer's ``state`` dict or ``param_groups`` list, is printed as-is.

    Args:
        obj: any object with a ``state_dict()`` method, e.g. an ``nn.Module``
            or a ``torch.optim`` optimizer.
    """
    header = f"{'layer':25} shape"
    print(header)
    print("===================================================")
    for name, value in obj.state_dict().items():
        summary = getattr(value, "shape", value)
        print(f"{name:25} {summary}")

In [8]:
state_dict_info(model)

layer                     shape
conv1.weight              torch.Size([6, 3, 5, 5])
conv1.bias                torch.Size([6])
conv2.weight              torch.Size([16, 6, 5, 5])
conv2.bias                torch.Size([16])
fc1.weight                torch.Size([120, 400])
fc1.bias                  torch.Size([120])
fc2.weight                torch.Size([84, 120])
fc2.bias                  torch.Size([84])
fc3.weight                torch.Size([10, 84])
fc3.bias                  torch.Size([10])


## `torch.optim.Optimizer`

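Optimizers also have a `state_dict`. It contains the optimizer's `state` and its `param_groups`: the hyperparameters plus identifiers of the parameters each group updates.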

In [9]:
optimizer = optim.Adadelta(model.parameters())

In [10]:
state_dict_info(optimizer)

layer                     shape
state                     {}
param_groups              [{'lr': 1.0, 'rho': 0.9, 'eps': 1e-06, 'weight_decay': 0, 'params': [140517483101424, 140517483101352, 140517483101280, 140517483100488, 140517483099840, 140517483822032, 140517483821240, 140517483822968, 140517483821528, 140519793752320]}]


In [11]:
optimizer.state_dict()["state"]

{}

In [12]:
optimizer.state_dict()["param_groups"]

[{'lr': 1.0,
  'rho': 0.9,
  'eps': 1e-06,
  'weight_decay': 0,
  'params': [140517483101424,
   140517483101352,
   140517483101280,
   140517483100488,
   140517483099840,
   140517483822032,
   140517483821240,
   140517483822968,
   140517483821528,
   140519793752320]}]

## Storing and loading `state_dict`

In [13]:
# Save only the state dict (the parameter tensors), not the model object itself.
model_file = "model_state_dict.pt"
torch.save(model.state_dict(), model_file)

In [14]:
# A state dict holds only tensors, so first build the architecture from the
# `Net` class definition.
model = Net()
# Copy the saved tensors into the matching parameters. The returned value,
# displayed below, summarises how the saved keys matched the model's keys.
model.load_state_dict(torch.load(model_file))

<All keys matched successfully>

## Storing and loading the full model

In [15]:
# Save the whole module object rather than just its state dict. Loading it
# back later requires the `Net` class definition to be available.
model_file = "model_123.pt"
torch.save(model, model_file)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [16]:
# Only works if code for `Net` is available right now
# NOTE(review): this unpickles arbitrary Python objects, so load trusted files only.
# Recent PyTorch releases default `torch.load` to `weights_only=True`, which
# rejects full-module files like this one; there, pass `weights_only=False`.
# Confirm against the installed torch version.
model = torch.load(model_file)

# Example: Checkpointing
You can store the model, the optimizer, and arbitrary extra information (such as the epoch and the loss) in a single dictionary, and reload it later.

Example:
```python
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    },
    PATH,
)
```

# Exercise
- Find out what the `state` entry of an optimizer's `state_dict` contains, e.g. after running a few optimization steps.
- Write your own checkpoint functionality.