# Save and load the model 
In this unit we look at how to persist model state by saving and loading it, and how to run predictions with a restored model.

In [4]:
import torch
import torch.onnx as onnx
import torchvision.models as models

# Saving and loading model weights 
PyTorch models store their learned parameters in an internal state dictionary called state_dict. These can be persisted with the torch.save function:

In [5]:
import os
os.makedirs('data', exist_ok=True)  # torch.save does not create missing directories
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'data/model_weights.pth')

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
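
To make the idea of a state_dict more concrete, the following is a minimal sketch that inspects one. It uses a small hypothetical nn.Sequential network (the name tiny_model and its layer sizes are assumptions for illustration, not part of this unit) so that nothing has to be downloaded.

```python
import torch
from torch import nn

# A tiny stand-in network (hypothetical, not the VGG16 model used in this unit).
tiny_model = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
)

state = tiny_model.state_dict()

# Each entry maps a parameter name to a tensor.
# ReLU has no learnable parameters, so it contributes no keys.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```

The state_dict holds only tensors keyed by name, which is why the model class is still needed to rebuild the network around them.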

To load model weights, you need to create an instance of the same model first, and then load the parameters using the load_state_dict() method.

In [None]:
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('data/model_weights.pth'))
model.eval()
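
The save and load steps can be checked end to end with a small round trip. The sketch below assumes a hypothetical helper make_model and a temporary file instead of the data/ directory; it saves the weights of one instance, loads them into a fresh instance, and compares predictions. Calling eval() matters here because layers such as dropout behave differently during training and inference.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical small architecture; the same constructor is used for saving and loading.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 2))


trained = make_model()
trained.eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_weights.pth")
    torch.save(trained.state_dict(), path)

    restored = make_model()  # fresh instance with different random weights
    restored.load_state_dict(torch.load(path))
    restored.eval()  # switch dropout to inference behavior

sample = torch.randn(5, 4)
with torch.no_grad():
    outputs_match = torch.allclose(trained(sample), restored(sample))

print(outputs_match)
```

If eval() were skipped on either model, dropout would randomly zero activations and the two outputs would generally differ even though the weights are identical.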

# Saving and loading models with shapes 
When loading model weights, we had to instantiate the model class first, because the class defines the structure of the network. To save that structure together with the weights, we can pass model (rather than model.state_dict()) to the saving function:



In [None]:
torch.save(model, 'data/vgg_model.pth')

We can then load the model like this:

In [None]:
model = torch.load('data/vgg_model.pth')
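
Saving the whole model relies on Python's pickle module, so the class definition must still be importable when the file is loaded. The sketch below illustrates this with a hypothetical small model and a temporary file. It passes weights_only=False explicitly, on the assumption that the installed PyTorch version accepts that argument; recent releases default to the safer weights_only=True, which rejects a pickled model object.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical small model built only from torch.nn classes.
original = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_model.pth")
    torch.save(original, path)

    # weights_only=False allows unpickling arbitrary objects,
    # so use it only with files from a trusted source.
    loaded = torch.load(path, weights_only=False)

print(type(loaded).__name__)
# Sequential
```

No constructor call is needed on the loading side, which is convenient, but the file is tied to the code that defined the model. Saving the state_dict is usually the more portable choice.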

# Export the model to ONNX

In [None]:
input_image = torch.zeros((1,3,224,224))
onnx.export(model, input_image, 'data/model.onnx')