__Saving and loading models__

1. [Load tools](#Load-tools)
1. [Saving and loading models](#Saving-and-loading-models)
1. [What is a state_dict?](#What-is-a-state_dict?)
1. [Saving & loading model for inference](#Saving-&-loading-model-for-inference)
1. [Saving & loading a general checkpoint](#Saving-&-loading-a-general-checkpoint)

# Load tools

<a id = 'Load-tools'></a>

In [2]:
# Standard library and settings
import os
import sys
import warnings

warnings.simplefilter("ignore")
from IPython.display import display, HTML

display(HTML("<style>.container { width:95% !important; }</style>"))

# Data extensions and settings
import numpy as np

np.set_printoptions(threshold=np.inf, suppress=True)
import pandas as pd

pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)
pd.options.display.float_format = "{:,.6f}".format

# import PyTorch
import torch
from torch.utils.data import Dataset, DataLoader
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.jit import script, trace
import torchvision
import torchvision.transforms as transforms

# Visualization extensions and settings
import seaborn as sns
import matplotlib.pyplot as plt

# Magic functions
%matplotlib inline

# Saving and loading models

There are three core functions:

1. torch.save - saves a serialized object to disk, using pickle for serialization. Models, tensors and dictionaries can all be saved with this function.

2. torch.load - uses pickle to deserialize pickled object files back into memory.

3. torch.nn.Module.load_state_dict - loads a model's parameter dictionary (its state_dict) into the model.
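The three functions work together as a round trip. The sketch below is a minimal illustration. It uses a throwaway `nn.Linear(2, 1)` model in place of a real network, and an in-memory `io.BytesIO` buffer in place of a file on disk. `torch.save` and `torch.load` accept either a path or a file-like object.

```python
import io

import torch
import torch.nn as nn

# Throwaway example model (an assumption; any nn.Module works the same way)
model = nn.Linear(2, 1)

# 1. torch.save: pickle the state_dict into a buffer (a file path works the same way)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# 2. torch.load: unpickle the buffer back into a dictionary of tensors
buffer.seek(0)
state_dict = torch.load(buffer)

# 3. load_state_dict: copy those tensors into a freshly constructed model
restored = nn.Linear(2, 1)
restored.load_state_dict(state_dict)

print(torch.equal(model.weight, restored.weight))  # True
```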


<a id = 'Saving-and-loading-models'></a>

# What is a state_dict?

The learnable parameters (weights and biases) of a torch.nn.Module are contained in the model's parameters (accessed via model.parameters()). A state_dict is a simple Python dictionary that maps each layer's parameter names to their tensors. Only layers with learnable parameters (such as convolutional and linear layers) and registered buffers (such as batch normalization's running_mean) have entries in the model's state_dict.

Optimizer objects (torch.optim) also have a state_dict that contains information about the optimizer's state and the hyperparameters used.

Because state_dicts are just Python dictionaries, they can easily be saved, restored, updated and changed.
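The sketch below shows which layers get entries and what the optimizer's state holds once training has started. It uses a hypothetical toy `nn.Sequential` model rather than TheModelClass. It also assumes SGD with momentum, which keeps a per-parameter momentum buffer; other optimizers store different state.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy model: ReLU has no parameters, while BatchNorm1d has
# learnable parameters plus non-learnable buffers (running statistics)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.BatchNorm1d(3))
keys = list(model.state_dict().keys())
print(keys)
# ['0.weight', '0.bias', '2.weight', '2.bias', '2.running_mean', '2.running_var', '2.num_batches_tracked']

# The optimizer's per-parameter state stays empty until the first step...
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
print(optimizer.state_dict()["state"])  # {}

# ...after one step, SGD with momentum stores a momentum buffer per parameter
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
opt_state = optimizer.state_dict()["state"]
print(len(opt_state))  # 4 (two Linear parameters + two BatchNorm parameters)

# A state_dict is a dictionary, so entries can be edited and loaded back
sd = model.state_dict()
sd["0.bias"] = torch.zeros(3)
model.load_state_dict(sd)
print(model[0].bias)  # the bias is now all zeros
```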

<a id = 'What-is-a-state_dict?'></a>

In [26]:
# A small LeNet-style CNN for 3-channel 32x32 images (32 -> 28 -> 14 -> 10 -> 5 after conv/pool)
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("model's state_dict:\n")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("\noptimizer's state_dict:\n")
for var_tensor in optimizer.state_dict():
    print(var_tensor, "\t", optimizer.state_dict()[var_tensor])

model's state_dict:

conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

optimizer's state_dict:

state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2106751666000, 2106751665640, 2106755451640, 2106751664920, 2106751665208, 2106751666072, 2106751665856, 2106751665928, 2106751665280, 2106755268680]}]


# Saving & loading model for inference

A common PyTorch convention is to save models using either a .pt or .pth file extension. When saving a model for inference, only the model's learned parameters need to be saved. Before running inference with the restored parameters, call model.eval() to set dropout and batch normalization layers to evaluation mode. Otherwise they keep their training behavior and inference results will be inconsistent.

The load_state_dict() method takes a dictionary, not a path to a saved file. The saved state_dict must therefore be deserialized with torch.load() before it is passed to load_state_dict().
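Here is a self-contained version of the inference workflow. It is a sketch that assumes a hypothetical `Net` class with a dropout layer and a temporary file path. The file is deserialized with `torch.load()` before the result goes to `load_state_dict()`. The returned missing/unexpected keys are then checked, and the model is switched to evaluation mode so that dropout is disabled.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model with dropout, so train/eval mode matters."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.drop(self.fc(x))


model = Net()
path = os.path.join(tempfile.mkdtemp(), "net_params.pth")
torch.save(model.state_dict(), path)  # save only the learned parameters

# Passing `path` directly to load_state_dict would fail: it expects a dict.
# map_location="cpu" remaps tensors that were saved on a GPU onto the CPU.
restored = Net()
result = restored.load_state_dict(torch.load(path, map_location="cpu"))
print(result.missing_keys, result.unexpected_keys)  # [] []

# Evaluation mode disables dropout, so both models now give identical outputs
model.eval()
restored.eval()
x = torch.randn(2, 4)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)  # True
```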


<a id = 'Saving-&-loading-model-for-inference'></a>

In [27]:
# save
torch.save(model.state_dict(), "saved_model_params_1.pth")

In [28]:
# load
model = TheModelClass()
model.load_state_dict(torch.load("saved_model_params_1.pth"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [29]:
# save entire model
torch.save(model, "saved_model_1.pth")



In [None]:
# load entire model
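# The class definition (TheModelClass) must be available when unpickling.
# PyTorch 2.6+ defaults to weights_only=True, so loading a full model needs weights_only=False
model = torch.load("saved_model_1.pth", weights_only=False)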
model.eval()

# Saving & loading a general checkpoint

This process can be used to save a checkpoint for inference and/or resuming training. In this case, it is important to save the optimizer's state_dict in addition to the model's state_dict. Other objects to save may include the last epoch completed, the latest recorded training loss, any torch.nn.Embedding layers, and so on.

To save multiple components, collect them in a dictionary and pass it to torch.save(). A common PyTorch convention is to save such checkpoints with a .tar file extension.

```python
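# assumes model, optimizer, epoch and loss already exist from the training loop
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),  # key must match the one used when loading
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'saved_ckpt_1.tar')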
```

To load the saved items, first initialize the model and optimizer. Then load the dictionary with torch.load() and retrieve each item by its key.

```python
model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # same settings as during training

checkpoint = torch.load('saved_ckpt_1.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
```

For inference mode, run model.eval() as before. To resume training, simply call model.train() to put all layers in training mode.

```python
model.eval()
# or
model.train()
```
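The checkpoint steps above can be combined into one runnable round trip. This is a sketch under stated assumptions: a hypothetical `nn.Linear` model, random data and a hypothetical `train_one_epoch` helper stand in for a real training setup. The sketch saves a checkpoint and restores it into fresh model and optimizer objects. It then checks that one further epoch gives the same weights as the uninterrupted run. That only holds because the optimizer's momentum buffers were restored along with the parameters.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy setup standing in for TheModelClass and a real dataset
torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
x, y = torch.randn(16, 3), torch.randn(16, 1)


def train_one_epoch(model, optimizer):
    """Hypothetical helper: one full-batch gradient step on (x, y)."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


for epoch in range(3):
    loss = train_one_epoch(model, optimizer)

# Save everything needed to resume training
ckpt_path = os.path.join(tempfile.mkdtemp(), "saved_ckpt_1.tar")
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
}, ckpt_path)

# Resume: build fresh objects, then restore their state from the checkpoint
model2 = nn.Linear(3, 1)
optimizer2 = optim.SGD(model2.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(ckpt_path)
model2.load_state_dict(checkpoint["model_state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
model2.train()

# One more epoch on both: matching weights show the optimizer state was restored
train_one_epoch(model, optimizer)
train_one_epoch(model2, optimizer2)
resumed_ok = torch.allclose(model.weight, model2.weight)
print(start_epoch, resumed_ok)  # 3 True
```

If the optimizer state were not restored, `optimizer2` would start with empty momentum buffers. Its next step would then differ from the uninterrupted run even though the weights were identical.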

<a id = 'Saving-&-loading-a-general-checkpoint'></a>