# Storing and Loading Models

https://pytorch.org/tutorials/beginner/saving_loading_models.html

## Init, helpers, utils, ...

In [1]:
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG so the randomly initialised weights shown below are reproducible.
SEED = 0
torch.manual_seed(SEED);

In [3]:
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

# `state_dict()`

## `nn.Module.state_dict()`
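Every `nn.Module` has a state dict: an ordered dictionary that maps each parameter name (e.g. `conv1.weight`) to its tensor. Only layers with learnable parameters (and registered buffers) get entries; `pool`, for example, does not appear below.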

In [4]:
class Net(nn.Module):
    """Small convolutional network: two conv+pool stages followed by three linear layers.

    The `16 * 5 * 5` input size of `fc1` implies inputs of shape (batch, 3, 32, 32):
    32 -> 5x5 conv -> 28 -> 2x2 pool -> 14 -> 5x5 conv -> 10 -> 2x2 pool -> 5.
    Each sample is mapped to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # parameter-free, so it has no state_dict entries
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to a (batch, 10) tensor."""
        x = self.pool(F.relu(self.conv1(x)))  # (batch, 6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # (batch, 16, 5, 5)
        # Flatten everything except the batch dimension. `x.view(-1, 16 * 5 * 5)`
        # would silently regroup values across samples whenever the spatial size is
        # wrong but the total element count happens to be divisible by 400;
        # `flatten` keeps the batch size, so such a mismatch fails loudly in `fc1`.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [5]:
# A freshly created network; its weights are randomly initialised at this point.
model = Net()

In [6]:
# Show the raw state dict: an OrderedDict from parameter name to tensor (long output).
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.0578, -0.0653,  0.0616,  0.0788, -0.0898],
                        [ 0.0827,  0.0199,  0.0880,  0.0648, -0.0648],
                        [-0.0666, -0.0491, -0.0371, -0.0871, -0.0411],
                        [-0.1079,  0.0122, -0.0199,  0.0675,  0.1119],
                        [ 0.0651,  0.0783, -0.1097, -0.0713,  0.1117]],
              
                       [[ 0.0551,  0.0709,  0.0744, -0.0970,  0.0294],
                        [ 0.0730,  0.0558,  0.1112,  0.0641, -0.0597],
                        [ 0.0637, -0.0153,  0.0798, -0.0198,  0.0298],
                        [-0.0264, -0.0163,  0.0942, -0.0682, -0.0921],
                        [ 0.0118, -0.0533, -0.0815, -0.1112, -0.0123]],
              
                       [[ 0.0511, -0.0618, -0.0692,  0.0200,  0.0716],
                        [ 0.0850, -0.0018, -0.0734,  0.0281, -0.0406],
                        [-0.0960,  0.0325,  0.0582, -0.0778, -0.0274],
               

In [7]:
def state_dict_info(obj, width=25):
    """Print one row per entry of a state dict: its key and the value's shape.

    Args:
        obj: Either an object exposing ``state_dict()`` (e.g. an ``nn.Module`` or an
            optimizer) or an already materialised mapping, such as a state dict or
            checkpoint dict returned by ``torch.load``.
        width: Width of the key column.

    Values that have a ``.shape`` (tensors) are shown by their shape; any other
    value (numbers, lists, nested dicts) is printed as-is.
    """
    print(f"{'layer':{width}} shape")
    print("===================================================")
    # Objects that can produce a state dict are asked for it; anything else is
    # assumed to already be a mapping.
    entries = obj.state_dict() if hasattr(obj, "state_dict") else obj
    for key, value in entries.items():
        try:
            shown = value.shape
        except AttributeError:
            shown = value
        print(f"{key:{width}} {shown}")

In [8]:
# One row per parameter; `pool` has no parameters and therefore no entry.
state_dict_info(model)

layer                     shape
conv1.weight              torch.Size([6, 3, 5, 5])
conv1.bias                torch.Size([6])
conv2.weight              torch.Size([16, 6, 5, 5])
conv2.bias                torch.Size([16])
fc1.weight                torch.Size([120, 400])
fc1.bias                  torch.Size([120])
fc2.weight                torch.Size([84, 120])
fc2.bias                  torch.Size([84])
fc3.weight                torch.Size([10, 84])
fc3.bias                  torch.Size([10])


## `torch.optim.Optimizer`

Optimizers also have a `state_dict`. It holds the optimizer's internal `state` and its hyperparameters (`param_groups`).

In [9]:
# Adadelta with its default hyperparameters (visible in `param_groups` below).
optimizer = optim.Adadelta(model.parameters())

In [10]:
# An optimizer's state dict has two entries: `state` and `param_groups`.
state_dict_info(optimizer)

layer                     shape
state                     {}
param_groups              [{'lr': 1.0, 'rho': 0.9, 'eps': 1e-06, 'weight_decay': 0, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]


In [11]:
# Empty here: no optimizer step has been taken in this notebook (see the exercise at the end).
optimizer.state_dict()["state"]

{}

In [12]:
# Hyperparameters per group; `params` lists parameter indices, not the tensors themselves.
optimizer.state_dict()["param_groups"]

[{'lr': 1.0,
  'rho': 0.9,
  'eps': 1e-06,
  'weight_decay': 0,
  'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

## Storing and loading `state_dict`

In [13]:
# Save only the parameter tensors (not the class); restoring them needs a `Net` instance.
model_file = "model_state_dict.pt"
torch.save(model.state_dict(), model_file)

In [14]:
# Restore saved parameters into a fresh instance of the same architecture.
model = Net()
# `map_location="cpu"` matches the freshly created (CPU) `Net()`, so this also works
# when the file was written from a GPU model and is loaded on a machine without CUDA.
model.load_state_dict(torch.load(model_file, map_location="cpu"))

<All keys matched successfully>

## Storing and loading the full model

In [15]:
# Save the whole module object rather than just its state dict.
model_file = "model_123.pt"
torch.save(model, model_file)

In [16]:
# Only works if code for `Net` is available right now: the file stores a reference
# to the class (by module and name), not the class's source code.
# Unpickling can execute arbitrary code, so only load full-model files you trust.
# PyTorch >= 2.6 defaults to `weights_only=True`, which refuses to unpickle whole
# modules, so opt out explicitly (the argument exists since PyTorch 1.13).
model = torch.load(model_file, map_location="cpu", weights_only=False)

# Example Checkpointing
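You can store the model's and the optimizer's state dicts together with arbitrary extra information (epoch, loss, ...) in a single file, and reload all of it later, e.g. to resume training.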

Example:
```python
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    },
    PATH,
)
```

# Exercise
- Find out what ends up in the `state` entry of an optimizer's `state_dict` (it is empty above).
- Write your own checkpoint functionality.