# Model Saving

PyTorch provides several options for saving models. Here we look at both the native `torch.save` and `torch.load` functions (which store models as pickle files) and the Open Neural Network Exchange (ONNX) format, which provides a standardized and more secure format for saving and loading models.

```python
import torch
import torchvision.models as models
```

## Saving Model Weights Only

To save model weights, we can call `torch.save` on the dictionary returned by the model's `.state_dict()` method. This dictionary maps each parameter (and buffer) name to its tensor. The result is a pickle file containing the model's weights only.

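```python
import os

os.makedirs("models", exist_ok=True)  # ensure the output directory exists
model = models.vgg16(weights="IMAGENET1K_V1")  # ImageNet-pretrained VGG16
torch.save(model.state_dict(), "models/model_weights.pth")
```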

Once the weights are saved, we can load them with `.load_state_dict` into a newly constructed model that has the same architecture.

```python
model = models.vgg16()  # no ``weights`` argument, i.e. an untrained model with the same architecture
model.load_state_dict(torch.load("models/model_weights.pth", weights_only=True))
model.eval()  # switch dropout and similar layers to inference behavior
```

```text
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1
```
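The following minimal sketch shows what a `state_dict` actually holds. It uses a hypothetical small `nn.Sequential` model in place of VGG16 and an in-memory `io.BytesIO` buffer instead of a file. The state dict maps parameter names to tensors. Loading it into a freshly built model with the same architecture reproduces the original model's outputs.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model():
    # Hypothetical small model standing in for VGG16
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


trained = make_model()

# The state dict maps parameter names to tensors (ReLU has no parameters)
state = trained.state_dict()
print(list(state.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Save to an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# Rebuild the architecture, then copy the saved tensors into it
restored = make_model()
restored.load_state_dict(torch.load(buffer, weights_only=True))

trained.eval()
restored.eval()
sample = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(sample), restored(sample))
print(same)  # True
```

The architecture itself is not stored, so `make_model()` (or, in the notebook, `models.vgg16()`) must be called again before loading.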

## Saving Model Architecture and Weights

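If we want to save the *entire* model (architecture and weights), we can pass the model object itself to `torch.save`, rather than only its `state_dict()`. Pickle stores a reference to the model's class, not its source code. The class definition must therefore still be importable when the file is loaded.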

```python
torch.save(model, "models/model.pth")
```

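When we load models saved this way, we set `weights_only=False`. This lets `torch.load` unpickle arbitrary Python objects (here the full `nn.Module`) rather than only tensors and other basic types. Only do this for files from a trusted source.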

```python
model = torch.load("models/model.pth", weights_only=False)  # returns the nn.Module itself
```
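The following sketch shows why `weights_only=False` is needed here. It uses a hypothetical small `nn.Sequential` model saved whole to an in-memory buffer. Loading with `weights_only=True` is rejected, because the module's class is not among the types the weights-only unpickler allows. In the PyTorch versions we are aware of, the error raised is `pickle.UnpicklingError`. With `weights_only=False`, the module object is reconstructed.

```python
import io
import pickle

import torch
import torch.nn as nn

# Hypothetical small model, saved whole (architecture + weights)
tiny = nn.Sequential(nn.Linear(4, 2))
buffer = io.BytesIO()
torch.save(tiny, buffer)

# The weights-only unpickler refuses arbitrary classes such as nn.Sequential
buffer.seek(0)
try:
    torch.load(buffer, weights_only=True)
    rejected = False
except pickle.UnpicklingError:
    rejected = True

# With weights_only=False, the full module object is rebuilt
buffer.seek(0)
loaded = torch.load(buffer, weights_only=False)
print(rejected, type(loaded).__name__)  # True Sequential
```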

## Open Neural Network Exchange (ONNX)

When we save models for prototyping, the pickle format is convenient. When we save them for deployment to production, we want a format that is portable across languages and does not execute code when loaded.

Pickle comes with challenges both in interoperability and security.

Note the following warning. `torch.load` emits it in PyTorch versions where `weights_only` still defaults to `False`:

```text
/tmp/ipykernel_11004/1873844565.py:27: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model.load_state_dict(torch.load('models/mnist_cnn.pt'))
```

Pickle files are specific to the Python ecosystem, which makes them difficult to use outside of Python. Loading a pickle file can also execute arbitrary code, so loading an untrusted file is a security vulnerability.
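The standard-library `pickle` module is enough to demonstrate the risk. In this minimal sketch, a hypothetical `Payload` class uses `__reduce__` to tell pickle which callable to invoke when the data is loaded. Here it names the built-in `eval` with a harmless expression. A malicious file could name any callable with any arguments instead.

```python
import pickle


class Payload:
    def __reduce__(self):
        # Tells pickle: "to rebuild this object, call eval(<string>)".
        # An attacker could substitute any callable, e.g. one that runs a shell command.
        return (eval, ("sum(range(10))",))


data = pickle.dumps(Payload())

# Merely loading the bytes runs eval(); what comes back is not a Payload at all
result = pickle.loads(data)
print(result)  # 45
```

Nothing in the loading code refers to `eval`. The bytes alone decide what runs, which is why `torch.load` warns about untrusted files.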

As an alternative, we can use ONNX, an open, framework-neutral model format that is widely supported across frameworks and runtimes.

You can read more on the [ONNX website](https://onnx.ai/) and in the PyTorch tutorial on [exporting a model to ONNX and running it with ONNX Runtime](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html).

Here we stick with our MNIST example and show how to save our model with ONNX. We then load it with ONNX Runtime, which does not depend on PyTorch.

### Imports

```python
# ONNX
import torch.onnx

# Model building
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Data loading
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

# Basics
import matplotlib.pyplot as plt
import numpy as np

# Seed and device configuration
torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

### Load Data

```python
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
dataset = datasets.MNIST(
    root="example_data", download=True, transform=transform
)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
```
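`random_split` partitions the dataset's indices into non-overlapping subsets of the requested lengths. MNIST's default `train=True` split has 60,000 images, so the 80/20 split above yields 48,000 training and 12,000 held-out examples. The minimal sketch below uses a hypothetical 10-item `TensorDataset` and a seeded `generator` (not used in the cell above) so that the split is reproducible.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset of 10 examples
toy = TensorDataset(torch.arange(10))

toy_train_size = int(0.8 * len(toy))       # 8
toy_test_size = len(toy) - toy_train_size  # 2
train_part, test_part = random_split(
    toy, [toy_train_size, toy_test_size], generator=torch.Generator().manual_seed(0)
)

# Each returned Subset keeps a list of indices into the original dataset
print(len(train_part), len(test_part))  # 8 2
print(sorted(train_part.indices + test_part.indices) == list(range(10)))  # True
```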

### Define a Model

Here we define the model architecture and load previously saved weights from `models/mnist_cnn.pt`. We don't train it here, since we just want to show how to save and load it.

```python
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        y = F.log_softmax(x, dim=1)
        return y


model = SimpleCNN()
model.load_state_dict(torch.load("models/mnist_cnn.pt", weights_only=True))
```

```text
<All keys matched successfully>
```
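The input size of `fc1` (9216) follows from the shapes produced by the layers before it:

- Each 3×3 convolution with stride 1 and no padding shrinks a side by 2 (28 → 26 → 24).
- The 2×2 max pooling halves it (24 → 12).
- `conv2` has 64 output channels, so flattening gives 64 × 12 × 12 = 9216.

The sketch below checks this arithmetic in two ways. It uses a hypothetical helper, `conv_out`, which implements the standard output-size formula for undilated convolutions. It also runs a zero tensor through layers configured like `SimpleCNN`'s `conv1` and `conv2`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


side = conv_out(28, 3)    # conv1: 28 -> 26
side = conv_out(side, 3)  # conv2: 26 -> 24
side = side // 2          # 2x2 max pool: 24 -> 12
flat_features = 64 * side * side
print(flat_features)  # 9216

# Cross-check with layers configured like SimpleCNN's conv1 and conv2
conv1 = nn.Conv2d(1, 32, 3, 1)
conv2 = nn.Conv2d(32, 64, 3, 1)
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)
    h = F.max_pool2d(F.relu(conv2(F.relu(conv1(dummy)))), 2)
print(tuple(h.shape), torch.flatten(h, 1).shape[1])  # (1, 64, 12, 12) 9216
```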

### Save the Model

Here we set the model to evaluation mode and export it to ONNX.

To export to ONNX, we pass:
* the model itself,
* a dummy input, which the exporter runs through the model to record its operations and input shape,
* a path to store the exported file,
* a flag indicating whether to store the trained parameters in the file or just the architecture,
* other options, including the ONNX opset version, names for the model's inputs and outputs, and which axes (here the batch dimension) may vary in size.

```python
batch_size = 64
model.eval()  # disable dropout so the exported graph reflects inference behavior

# Dummy input to the model: a batch of single-channel 28x28 images
x = torch.randn(batch_size, 1, 28, 28, requires_grad=True)
torch_out = model(x)  # PyTorch output, kept for comparison with ONNX Runtime

# Export the model
torch.onnx.export(
    model,  # model being run
    x,  # model input (or a tuple for multiple inputs)
    "models/super_resolution.onnx",  # where to save the model (can be a file or file-like object)
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=10,  # the ONNX operator set (opset) version to target
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=["input"],  # the model's input names
    output_names=["output"],  # the model's output names
    dynamic_axes={
        "input": {0: "batch_size"},  # axis 0 (batch) may vary in size
        "output": {0: "batch_size"},
    },
)
```
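Calling `model.eval()` before export matters because `SimpleCNN` contains dropout layers. In training mode these randomly zero activations; in evaluation mode they pass activations through unchanged. The sketch below illustrates the difference with a standalone `nn.Dropout(0.5)`. In training mode, two passes over the same input disagree. In evaluation mode, the layer is the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(0.5)
inputs = torch.ones(1, 1000)

# Training mode: roughly half the entries are zeroed, the rest scaled by 1 / (1 - p) = 2
dropout.train()
a, b = dropout(inputs), dropout(inputs)
print(torch.equal(a, b))  # False (different random masks)

# Evaluation mode: dropout leaves the input untouched
dropout.eval()
print(torch.equal(dropout(inputs), inputs))  # True
```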

### Load the Model

Here we load the model, check that it is a valid ONNX model, and use ONNX Runtime to run inference on it. This inference has no dependency on PyTorch and is often faster than running the model in PyTorch. Its results should match PyTorch's up to small floating-point differences, which the comparison below checks within a tolerance.

Even within the Python ecosystem, loading models with ONNX Runtime in our production environment separates our training and inference code. It also minimizes dependencies in our inference environment, which makes our software supply chain easier to manage.

```python
import onnx

onnx_model = onnx.load("models/super_resolution.onnx")
onnx.checker.check_model(onnx_model)  # raises an exception if the model is invalid
```

```python
import onnxruntime

ort_session = onnxruntime.InferenceSession(
    "models/super_resolution.onnx", providers=["CPUExecutionProvider"]
)


def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )


# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(
    to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05
)

print(
    "Exported model has been tested with ONNXRuntime, and the result looks good!"
)
```

```text
Exported model has been tested with ONNXRuntime, and the result looks good!
```
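`np.testing.assert_allclose` passes when every element satisfies `|actual - desired| <= atol + rtol * |desired|`. With `rtol=1e-03` and `atol=1e-05`, an output near 1.0 may therefore differ by about 1e-3. That is enough to absorb the small floating-point differences one would expect between PyTorch and ONNX Runtime kernels. The following sketch uses made-up numbers to show one comparison that passes and one that fails.

```python
import numpy as np

desired = np.array([1.0, -2.0, 0.0])
rtol, atol = 1e-03, 1e-05

# Allowed deviation per element: atol + rtol * |desired| = [0.00101, 0.00201, 0.00001]
close = desired + np.array([5e-4, -1e-3, 5e-6])
np.testing.assert_allclose(close, desired, rtol=rtol, atol=atol)  # passes

# A 5e-3 error on the first element exceeds its 0.00101 allowance
far = desired + np.array([5e-3, 0.0, 0.0])
try:
    np.testing.assert_allclose(far, desired, rtol=rtol, atol=atol)
    far_passed = True
except AssertionError:
    far_passed = False
print(far_passed)  # False
```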


```python
import time

x = torch.randn(batch_size, 1, 28, 28, requires_grad=True)

start = time.time()
torch_out = model(x)
end = time.time()
d_pt = end - start
print(f"Inference of Pytorch model used {d_pt} seconds")

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.time()
ort_outs = ort_session.run(None, ort_inputs)
end = time.time()
d_onnx = end - start
print(f"Inference of ONNX model used {d_onnx} seconds")

print(f"Improvement = {d_pt/d_onnx:.2}x")
```

```text
Inference of Pytorch model used 0.0158383846282959 seconds
Inference of ONNX model used 0.0056188106536865234 seconds
Improvement = 2.8x
```
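A single timed call like the one above is noisy. It includes one-off costs, and the PyTorch call also records an autograd graph because it runs outside `torch.no_grad()`. A somewhat fairer measurement disables gradient tracking, performs a few warm-up calls, and takes the median of several runs timed with `time.perf_counter`. The sketch below shows the pattern for PyTorch. The helper `median_runtime` is hypothetical, and a small `nn.Linear` stands in for the model. The same helper could wrap `lambda: ort_session.run(None, ort_inputs)` to time ONNX Runtime. Results will vary with hardware, thread settings, and batch size.

```python
import statistics
import time

import torch
import torch.nn as nn


def median_runtime(fn, warmup=3, repeats=20):
    """Median wall-clock time of fn() in seconds, after a few warm-up calls."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Hypothetical stand-in model and batch; substitute the real model and inputs
net = nn.Linear(784, 10).eval()
batch = torch.randn(64, 784)


def run_pytorch():
    with torch.inference_mode():  # no autograd bookkeeping while timing
        return net(batch)


t_pt = median_runtime(run_pytorch)
print(f"PyTorch median: {t_pt * 1e3:.3f} ms")
```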
