# Lab5B - Saving and Loading Models

While training a model, you may need to stop temporarily and resume later. You may also want to keep the best model, which is not necessarily the one produced in the last iteration. Most importantly, once training is complete, you will want to deploy the model in the field. All of these require saving and loading the model.

#### Objectives:
In this practical, students learn how to:
1. Save and load models
2. Resume previous training

#### References:
1. [Saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html)

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
cd "./gdrive/MyDrive/UCCD3074_Labs/UCCD3074_Lab5"

Import the required libraries.

In [None]:
# imports
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import torch.optim as optim
import os

from cifar10 import CIFAR10

In [None]:
# create the folder for saved models if it does not exist yet
os.makedirs("models", exist_ok=True)

# 1. Introduction

When it comes to saving and loading models, there are three core functions to be familiar with:

1. **`torch.save`**
<br> Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
2. **`torch.load`** 
<br> Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also lets you choose the device onto which the data is loaded, through its `map_location` argument (see "Saving & Loading Model Across Devices" in the PyTorch tutorial listed under References).
3. **`torch.nn.Module.load_state_dict`**
<br> Loads a model’s parameter dictionary using a deserialized state_dict. 
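
As a minimal sketch of the first two functions, the snippet below saves a small tensor and loads it back onto the CPU. The file name `tensor.pt` and the temporary folder are assumptions made for this illustration only; they are not part of the lab files.

```python
import os
import tempfile

import torch

# Work in a temporary folder so the illustration leaves no files behind
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tensor.pt")  # hypothetical file name

    original = torch.arange(6, dtype=torch.float32).reshape(2, 3)

    # Serialize the tensor to disk
    torch.save(original, path)

    # Deserialize it, asking for the data to be placed on the CPU
    restored = torch.load(path, map_location="cpu")

print(restored)
print(restored.device)
```

Passing `map_location="cpu"` is useful when a file was saved on a GPU machine but is loaded on a machine without one.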

#### What is a `state_dict()`?

* Each <u>model</u> has a `state_dict`. The model state_dict is simply a Python dictionary object that maps each layer to its parameter tensors stored in `model.parameters()`. `state_dict` stores the following tensors:
  * learnable parameters (convolutional layers, linear layers, etc.)
  * registered buffers (e.g., batchnorm's running mean and running variance).

* The <u>optimizer object</u> (`torch.optim`) also has a `state_dict`, which contains information about 
  * the optimizer's state
  * the hyperparameters used.

Because `state_dict` objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
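
To make the dictionary layout concrete, here is a small sketch that prints the entries of a `state_dict`. It uses a hypothetical two-layer `nn.Sequential` model rather than the `Net` class of this lab, so the entry names differ from what you will see in the exercise.

```python
import torch.nn as nn

# Hypothetical two-layer model, used only to illustrate the dictionary layout
tiny = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),
)

# Each entry maps a name to a tensor: learnable parameters and buffers alike
for name, tensor in tiny.state_dict().items():
    print(f"{name:25s} {tuple(tensor.shape)}")

state_keys = list(tiny.state_dict().keys())
```

The convolution contributes a weight and a bias, while the batch norm layer contributes its weight and bias plus three buffers (`running_mean`, `running_var`, `num_batches_tracked`).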

First, let's build our model.

In [None]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)   
        self.bn1 = nn.BatchNorm2d(8)
        
        self.conv2 = nn.Conv2d(8, 16, 3)  
        self.bn2   = nn.BatchNorm2d(16)
        
        self.fc1 = nn.Linear(16*30*30, 256) 
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        
        x = x.view(x.size(0), -1) # flat
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x

The following shows the `state_dict` of the model. Note that `state_dict` stores not only the *parameters* (weight and bias) of the trainable layers but also the buffers of the batch norm layers (*running mean*, *running variance*, and the number of batches tracked).

In [None]:
model = Net()

In [None]:
#... your code here ...

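The following code shows the `state_dict` of the optimizer. It has two entries: `param_groups`, which stores the *hyperparameter* settings (e.g., `lr`, `momentum`, `dampening`, `weight_decay`, `nesterov`) together with the indices of the parameters being optimized (`params`), and `state`, which stores the *optimizer* state (e.g., the momentum buffers, which are created once the optimizer has taken a step).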
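
The sketch below inspects the optimizer's `state_dict` for a hypothetical `nn.Linear` model (not the lab's `Net`), simply to show where the hyperparameters and the state live.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy model and optimizer, for illustration only
toy_model = nn.Linear(4, 2)
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

opt_state = toy_optimizer.state_dict()
print(opt_state.keys())

# Hyperparameters are stored per parameter group
group = opt_state["param_groups"][0]
print("lr:", group["lr"], "momentum:", group["momentum"])

# 'params' lists integer indices of the parameters, not the tensors themselves
print("params:", group["params"])

# No optimizer step has been taken yet, so there are no momentum buffers
print("entries in state:", len(opt_state["state"]))
```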

In [None]:
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

In [None]:
#... your code here ...

---

## 1.1 Saving & Loading Model Parameters Only

**`torch.save(model.state_dict(), PATH)`**

When saving a model for inference, it is only necessary to save the trained model’s learned parameters. We do not need to save the network structure itself. To do that, use the command `torch.save()`. A common PyTorch convention is to save models using either a `.pt` or `.pth` file extension.


In [None]:
#... your code here ...

**`model.load_state_dict(torch.load(PATH))`**

To load the model parameters, use the model's function `load_state_dict()`. `load_state_dict()` takes a dictionary object, NOT a path to a saved object. So, you must deserialize the saved state_dict first (`torch.load(PATH)`) before you pass it to the `load_state_dict()` function. 
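
A minimal round trip might look like the sketch below. It assumes a hypothetical `nn.Linear` model and a temporary file named `toy_params.pt`; in the lab you would use `Net` and a path under `./models` instead.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Two hypothetical models with the same architecture but different random weights
source_model = nn.Linear(4, 2)
target_model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_params.pt")  # hypothetical file name

    # Save only the parameters
    torch.save(source_model.state_dict(), path)

    # Deserialize the file into a dictionary first ...
    state = torch.load(path)

    # ... then hand the dictionary (not the path) to load_state_dict
    result = target_model.load_state_dict(state)

# After loading, both models hold identical parameters
weights_match = torch.equal(source_model.weight, target_model.weight)
print(result)
print("weights match:", weights_match)
```

The value returned by `load_state_dict` reports any missing or unexpected keys, which is a quick way to check that the file matches the architecture.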

In [None]:
#... your code here ...

---
## 1.2 Saving the Entire Model

The previous method only saves the model *parameters* but not the *network* itself. As a result, the saved parameters must be accompanied by the *model class*, i.e., the class `Net`, so that we can create the *network* first before loading the parameters.


**`torch.save(model, PATH)`**

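You may save the whole model and use it for inference by providing `model` rather than `model.state_dict()` as the argument for `torch.save`. The saved file then holds the module object together with its parameters, so you do not need to create the network before loading. Note, however, that pickle does not store the class definition itself: it stores a reference to the class and the file that defines it, so the class `Net` must still be importable when the file is loaded. Because of this, your code can break in various ways when used in other projects or after refactors.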

In [None]:
#... your code here ...

**`model = torch.load(PATH)`**

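When we load, we get both the network and its parameters back. There is no need to create the model first (e.g., `new_model2 = Net()`), although the definition of the class `Net` must still be available. In recent PyTorch releases, where `torch.load` defaults to `weights_only=True`, you may also need to pass `weights_only=False` to load a whole pickled model; do this only for files you trust.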
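
The following sketch saves and reloads a whole module. It uses a hypothetical `nn.Sequential` model built from standard layers, so that the classes can always be imported at load time, and it passes `weights_only=False` on the assumption that the installed PyTorch version accepts this argument and that the file is trusted.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model built only from standard layers
whole_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_model.pt")  # hypothetical file name

    # Pickle the module object itself, not just its state_dict
    torch.save(whole_model, path)

    # weights_only=False allows arbitrary pickled objects: use only with trusted files
    loaded_model = torch.load(path, weights_only=False)

# No model was constructed before loading; the module comes back ready to use
loaded_model.eval()
sample = torch.ones(1, 4)
same_output = torch.equal(whole_model(sample), loaded_model(sample))
print(loaded_model)
print("same output:", same_output)
```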

In [None]:
#... your code here ...

**Caution**: 

* If you are doing inference, remember that you must call `model.eval()` to set *dropout* and *batch normalization* layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. 

* If you wish to resume training, call `model.train()` to ensure these layers are in training mode.
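
To see why the mode matters, the sketch below feeds the same samples through a batch norm layer in both modes. The layer size and the random batch are arbitrary choices for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
batch = torch.randn(8, 3)

# Training mode: outputs are normalized with the statistics of the current batch,
# so the result for a sample depends on which other samples are in the batch
bn.train()
train_full = bn(batch)[:4]
train_part = bn(batch[:4])

# Evaluation mode: outputs are normalized with the stored running statistics,
# so the result for a sample no longer depends on the rest of the batch
bn.eval()
with torch.no_grad():
    eval_full = bn(batch)[:4]
    eval_part = bn(batch[:4])

train_consistent = torch.allclose(train_full, train_part)
eval_consistent = torch.allclose(eval_full, eval_part)
print("train mode consistent:", train_consistent)
print("eval mode consistent: ", eval_consistent)
```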

---
## 1.3 Saving the Model Parameters and Optimizer State

It is common to train your model over multiple sessions, where you stop the training temporarily and resume it at a later time. To do this, you need to save **checkpoints**. 

When saving a checkpoint, to be used for either inference or resuming training, you must save more than just the model’s state_dict. It is important to also save:
1. optimizer's state_dict 
2. model's state_dict 
3. current epoch number
4. training loss
5. any other information needed to resume training

Assume the following as the current state of training.

In [None]:
epoch = 0
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss = np.inf

To save multiple components, you can organize them into a dictionary and use `torch.save()` to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the `.tar` file extension.
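
One possible layout for such a dictionary is sketched below. The key names (`epoch`, `model_state_dict`, `optimizer_state_dict`, `loss`), the values, and the file name are assumptions chosen for this illustration; any consistent naming works as long as the loading code uses the same keys.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy model and optimizer standing in for the real training state
toy_model = nn.Linear(4, 2)
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

# Collect everything needed to resume into one dictionary
checkpoint = {
    "epoch": 3,      # illustrative value
    "model_state_dict": toy_model.state_dict(),
    "optimizer_state_dict": toy_optimizer.state_dict(),
    "loss": 0.42,    # illustrative value
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_checkpoint.tar")  # hypothetical file name
    torch.save(checkpoint, path)

    # Reading the file back returns the same dictionary structure
    reloaded = torch.load(path)

print(list(reloaded.keys()))
```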

In [None]:
#... your code here ...

In [None]:
#... your code here ...

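First, create the *network* and the *optimizer*, then load the saved *network's parameters* and *optimizer's state* into them. The optimizer has to be constructed before its state can be loaded; whatever learning rate (`lr`) you pass at that point is overwritten when we load the saved optimizer's state.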
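
The sketch below illustrates this with hypothetical toy objects: the new optimizer is deliberately created with a different learning rate, and loading the saved state restores the original one. The key names and the file name are assumptions made for this illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# State at the time of saving (hypothetical toy model, lr=0.1)
toy_model = nn.Linear(4, 2)
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_checkpoint.tar")  # hypothetical file name
    torch.save({
        "epoch": 3,
        "model_state_dict": toy_model.state_dict(),
        "optimizer_state_dict": toy_optimizer.state_dict(),
        "loss": 0.42,
    }, path)

    # Later: build fresh objects first, using a placeholder learning rate
    new_model = nn.Linear(4, 2)
    new_optimizer = optim.SGD(new_model.parameters(), lr=0.5)

    # Then restore the saved state into them
    checkpoint = torch.load(path)
    new_model.load_state_dict(checkpoint["model_state_dict"])
    new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# The placeholder hyperparameters have been replaced by the saved ones
restored_lr = new_optimizer.param_groups[0]["lr"]
restored_momentum = new_optimizer.param_groups[0]["momentum"]
start_epoch = checkpoint["epoch"] + 1
print("lr:", restored_lr, "momentum:", restored_momentum, "next epoch:", start_epoch)
```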

In [None]:
#... your code here ...

Since you wish to resume training, remember to call `model.train()` to ensure that the dropout and batch normalization layers are in training mode.

In [None]:
#... your code here ...

Now you are ready to resume your training.

---
# 2. Example

## Load the dataset
We will use the CIFAR10 dataset as an example.

In [None]:
# transform applied to each image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# dataset
trainset = CIFAR10(train=True,  transform=transform, num_samples=10000)
validset  = CIFAR10(train=False,  transform=transform, num_samples=2000)

# dataloaders
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
validloader  = DataLoader(validset, batch_size=128, shuffle=True, num_workers=2)

## Define training function

First, we define our training function. To allow the model to resume training, we do the following:
1. Define the `model` and `optimizer` outside the `train` function
2. Save the model at the end of each epoch (the `# save the model` block at the end of the function)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def train(model, optimizer, start_epoch=0, max_epochs=10):
    
    # report the averaged training loss about 3 times in each epoch
    loss_iterations = int(np.ceil(len(trainloader)/3))
    
    # transfer model to GPU
    model = model.to(device)
    
    # the optimizer (e.g., SGD with momentum) is created by the caller and passed in
    
    # set to training mode
    model.train()
        
    # train the network
    #... your code here ... 

        running_loss = 0
        running_count = 0

        for i, (inputs, labels) in enumerate(trainloader):

            # Clear all the gradient to 0
            optimizer.zero_grad()

            # transfer data to GPU
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward propagation to get h
            outs = model(inputs)

            # compute loss 
            loss = F.cross_entropy(outs, labels)

            # backpropagation to get gradients of all parameters
            loss.backward()

            # update parameters
            optimizer.step()

            # get the loss
            running_loss += loss.item()
            running_count += 1

            # display the averaged loss value
            if i % loss_iterations == loss_iterations-1 or i == len(trainloader) - 1:    
                # compute training loss
                train_loss = running_loss / running_count
                running_loss = 0. 
                running_count = 0.
               
                print(f'[Epoch {e+1:2d} Iter {i+1:5d}/{len(trainloader)}]: train_loss = {train_loss:.4f}')       
            
        
        # save the model 
        #... your code here ...

## Train model 

Train the model for 2 epochs

In [None]:
lr=0.01; momentum=0.9

model = Net()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

train(model, optimizer, max_epochs=2)

## Resume training

Resume training and train for another 2 epochs. To do that, we load the *previous* model's and optimizer's `state_dict`, the last epoch and the training loss value.
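
The overall flow can be sketched on a toy problem as follows. This is not the lab's solution: `toy_train`, the random data, the checkpoint keys, and the file name are all hypothetical, and the sketch assumes that `max_epochs` denotes the epoch index at which training stops rather than the number of additional epochs.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def toy_train(model, optimizer, inputs, targets, path, start_epoch=0, max_epochs=2):
    """Run epochs start_epoch .. max_epochs-1, saving a checkpoint after each one."""
    model.train()
    for epoch in range(start_epoch, max_epochs):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Save everything needed to resume after this epoch
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss.item(),
        }, path)


torch.manual_seed(0)
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_checkpoint.tar")  # hypothetical file name

    # First session: epochs 0 and 1
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    toy_train(model, optimizer, inputs, targets, path, max_epochs=2)

    # Second session: rebuild the objects, restore their state, then continue
    resumed_model = nn.Linear(4, 1)
    resumed_optimizer = optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)
    checkpoint = torch.load(path)
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    previous_epoch = checkpoint["epoch"]
    previous_loss = checkpoint["loss"]

    toy_train(resumed_model, resumed_optimizer, inputs, targets, path,
              start_epoch=previous_epoch + 1, max_epochs=previous_epoch + 3)
    final_epoch = torch.load(path)["epoch"]

print(f"Last epoch of first session: {previous_epoch}, loss: {previous_loss:.4f}")
print(f"Last epoch after resuming:   {final_epoch}")
```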

In [None]:
# define a new model
#... your code here ...

# define a new optimizer
#... your code here ...

# load the checkpoint file
#... your code here ...

# resume training
print(f'Resuming previous epoch. Last run epoch: {previous_epoch+1}, last run loss: {previous_loss:.4f}')
#... your code here ...

<center> --- END OF LAB --- </center>