# Saving and Loading Models

In this bite-sized notebook, we'll go over how to save and load models. In general, the process is the same as for any PyTorch module.

In [1]:
import math
import torch
import qpytorch
from matplotlib import pyplot as plt





## Saving a Simple Model

First, we define a QEP model that we'd like to save. The model used below is the same as the model from our
<a href="../01_Exact_QEPs/Simple_QEP_Regression.ipynb">Simple QEP Regression</a> tutorial.

In [2]:
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

In [4]:
# We will use the simplest form of QEP model, exact inference
POWER = 1.0
class ExactQEPModel(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactQEPModel, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = ExactQEPModel(train_x, train_y, likelihood)

### Change Model State

To demonstrate model saving, we change the hyperparameters from the default values below. For more information on what is happening here, see our tutorial notebook on <a href="Hyperparameters.ipynb">Initializing Hyperparameters</a>.

In [5]:
model.covar_module.outputscale = 1.2
model.covar_module.base_kernel.lengthscale = 2.2

### Getting Model State

To get the full state of a QPyTorch model, simply call `state_dict` as you would on any PyTorch model. Note that the state dict contains **raw** parameter values, because the raw values are the actual `torch.nn.Parameter` objects that are learned during training. Again, see our notebook on hyperparameters for more information on this.

In [6]:
model.state_dict()

OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf))])
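The raw values printed above can be related back to the values we set earlier (`outputscale = 1.2`, `lengthscale = 2.2`). The following is a minimal sketch in plain Python, assuming the default positivity constraint uses a softplus transform as it does in GPyTorch; the helper names `softplus` and `inverse_softplus` are ours and are not part of the library.

```python
import math

def softplus(raw):
    # Assumed forward transform: raw (unconstrained) value -> positive value.
    return math.log1p(math.exp(raw))

def inverse_softplus(value):
    # Inverse transform: positive value -> raw (unconstrained) value.
    return math.log(math.expm1(value))

# Raw values copied from the state dict printed above.
outputscale = softplus(0.8416)
lengthscale = softplus(2.0826)
print(f"outputscale ~ {outputscale:.3f}, lengthscale ~ {lengthscale:.3f}")

# Going the other way: the raw value that corresponds to outputscale = 1.2.
print(f"raw outputscale ~ {inverse_softplus(1.2):.4f}")
```

Under that assumption, the transformed values agree with the hyperparameters set in the previous section to within rounding, which is why the state dict shows `0.8416` and `2.0826` rather than `1.2` and `2.2`.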

### Saving Model State

The state dictionary above contains all trainable parameters of the model, along with the bounds of their constraints. Therefore, we can save this to a file as follows:

In [7]:
torch.save(model.state_dict(), 'model_state.pth')

### Loading Model State

Next, we load this state into a new model and demonstrate that the parameters were restored correctly.

In [8]:
state_dict = torch.load('model_state.pth')
model = ExactQEPModel(train_x, train_y, likelihood)  # Create a new QEP model

model.load_state_dict(state_dict)

<All keys matched successfully>

In [9]:
model.state_dict()

OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf))])
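Rather than comparing the two printed dictionaries by eye, the round trip can also be checked programmatically. The sketch below uses a plain `torch.nn.Linear` as a stand-in for the QEP model and an in-memory buffer instead of a file; the helper `state_dicts_equal` is a hypothetical name introduced here for illustration.

```python
import io
import torch

def state_dicts_equal(a, b):
    """Return True if both state dicts have the same keys and identical tensors."""
    if set(a) != set(b):
        return False
    return all(torch.equal(a[key], b[key]) for key in a)

# Stand-in module; the same calls apply to any torch.nn.Module.
source = torch.nn.Linear(3, 2)

# Save to an in-memory buffer so the sketch leaves no file behind.
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# A freshly constructed module starts with its own random weights ...
target = torch.nn.Linear(3, 2)
matches_before = state_dicts_equal(source.state_dict(), target.state_dict())

# ... and matches the source once the saved state has been loaded.
target.load_state_dict(torch.load(buffer, map_location="cpu"))
matches_after = state_dicts_equal(source.state_dict(), target.state_dict())

print(matches_before, matches_after)
```

Passing `map_location="cpu"` is optional here; it is included because it lets a state dict saved on a GPU be loaded on a machine without one.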

## A More Complex Example

Next, we demonstrate the same principle on a more complex exact QEP that includes a simple feedforward neural network feature extractor as part of the model.


In [11]:
class QEPWithNNFeatureExtractor(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(QEPWithNNFeatureExtractor, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
        
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Linear(1, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
            torch.nn.Linear(2, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.feature_extractor(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = QEPWithNNFeatureExtractor(train_x, train_y, likelihood)

### Getting Model State

In the next cell, we once again print the model state via `model.state_dict()`. As you can see, the state is substantially more complex, as the model now includes our neural network parameters. Nevertheless, saving and loading is straightforward.

In [12]:
model.state_dict()

OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf)),
             ('feature_extractor.0.weight',
              tensor([[-0.1177],
                      [ 0.6034]])),
             (

In [13]:
torch.save(model.state_dict(), 'my_qep_with_nn_model.pth')
state_dict = torch.load('my_qep_with_nn_model.pth')
model = QEPWithNNFeatureExtractor(train_x, train_y, likelihood)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [14]:
model.state_dict()

OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf)),
             ('feature_extractor.0.weight',
              tensor([[-0.1177],
                      [ 0.6034]])),
             (
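One detail worth knowing about the feature extractor: `BatchNorm1d` layers store running statistics as buffers, and these are saved in the state dict alongside the learnable weights. The sketch below builds a shortened stand-in for the extractor (it is not the model above) and lists which state dict entries are buffers rather than parameters.

```python
import torch

# Shortened stand-in for the feature extractor defined earlier.
feature_extractor = torch.nn.Sequential(
    torch.nn.Linear(1, 2),
    torch.nn.BatchNorm1d(2),
    torch.nn.ReLU(),
)

# Everything that torch.save(model.state_dict(), ...) would write out.
state_keys = list(feature_extractor.state_dict().keys())

# Learnable parameters only (what an optimizer would update).
param_keys = [name for name, _ in feature_extractor.named_parameters()]

# The remainder are buffers, such as the batch norm running statistics.
buffer_keys = [key for key in state_keys if key not in param_keys]
print(buffer_keys)
```

Because these running statistics are restored by `load_state_dict`, a reloaded model should be switched to evaluation mode with `model.eval()` before making predictions, so that batch normalization uses the saved statistics instead of per-batch ones.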