## Overview

In this notebook, we compare GPyTorch's fast posterior computations (LOVE) with the decoupled samplers
(from https://arxiv.org/abs/2002.09309), using the BoTorch-based implementation
available in this repo.

This is partly based on the GPyTorch notebook "GP Regression with LOVE for Fast
Predictive Variances and Sampling".

In [18]:
from math import ceil
from typing import Any, Union, Optional, List
import torch
import gpytorch
from botorch.posteriors import GPyTorchPosterior
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
from torch import Tensor
from gp_sampling import decoupled_sampler
from botorch.test_functions import Hartmann, Ackley
from botorch.models import SingleTaskGP
from botorch import fit_gpytorch_model
from gpytorch.settings import fast_pred_var
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

### Loading Data

We use `num_train` noisy evaluations of the 6D Hartmann function at points drawn uniformly at random from [0, 1]^6.

Play around with `num_train` and `num_test` to see how the runtime changes.

In [19]:
num_train = 100
num_test = 5000
dim = 6
train_x = torch.rand(num_train, dim).contiguous()
function = Hartmann(dim=dim, noise_std=0.1)
# Ackley was used for debugging.
# function = Ackley(dim=dim, noise_std=1)
train_y = function(train_x).unsqueeze(-1).contiguous()
# It appears that the implementation is not fully compatible with the outcome
# transforms. Thus, we will manually standardize train_y.
standardize_mean = train_y.mean()
standardize_std = train_y.std()
train_y = (train_y - standardize_mean) / standardize_std

test_x = torch.rand(num_test, dim).contiguous()

if torch.cuda.is_available() and False:  # disable cuda for now
    train_x, train_y, test_x = train_x.cuda(), train_y.cuda(), test_x.cuda()
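Because `train_y` is standardized by hand rather than through an outcome transform, every posterior mean, variance, and sample computed below is on the standardized scale. If values on the original Hartmann scale are needed, the transform has to be inverted manually: means and samples map back through `y * std + mean`, while variances are unaffected by the shift and only scale by `std**2`. The NumPy sketch below illustrates this; the helper names and the numbers standing in for `standardize_mean` and `standardize_std` are made up for the example.

```python
import numpy as np


def unstandardize(values, mean, std):
    """Hypothetical helper: map standardized means or samples back to the original scale."""
    return values * std + mean


def unstandardize_variance(variance, std):
    """Hypothetical helper: variances ignore the shift and scale with std**2."""
    return variance * std**2


# Made-up stand-ins for standardize_mean / standardize_std and a few standardized values.
example_mean, example_std = 0.3, 0.5
standardized_samples = np.array([[-1.0, 0.0, 2.0]])
standardized_variance = np.array([1.0, 0.46])

original_samples = unstandardize(standardized_samples, example_mean, example_std)
original_variance = unstandardize_variance(standardized_variance, example_std)
print(original_samples, original_variance)
```

The same arithmetic applies unchanged to the PyTorch tensors used in this notebook.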

## The GP Model

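We now define the GP model. Here, we use a stripped-down version of the `SingleTaskGP` model
from BoTorch. The `StrippedGP` model removes all the built-in GPyTorch context managers
that `SingleTaskGP` applies to posterior computations (depending on the BoTorch version, these
can enable settings such as `fast_pred_var`), so that only the settings we enable explicitly take effect.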

In [20]:
class StrippedPosterior(GPyTorchPosterior):
    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        if sample_shape is None:
            sample_shape = torch.Size([1])
        if base_samples is not None:
            if base_samples.shape[: len(sample_shape)] != sample_shape:
                raise RuntimeError("sample_shape disagrees with shape of base_samples.")
            # get base_samples to the correct shape
            base_samples = base_samples.expand(sample_shape + self.event_shape)
            # remove output dimension in single output case
            if not self._is_mt:
                base_samples = base_samples.squeeze(-1)
        samples = self.mvn.rsample(
            sample_shape=sample_shape, base_samples=base_samples
        )
        # make sure there always is an output dimension
        if not self._is_mt:
            samples = samples.unsqueeze(-1)
        return samples

class StrippedGP(SingleTaskGP):
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> StrippedPosterior:
        self.eval()  # make sure model is in eval mode
        # only single-output models are supported
        if self._num_outputs > 1:
            raise NotImplementedError
        mvn = self(X)
        if observation_noise is not False:
            if torch.is_tensor(observation_noise):
                # TODO: Validate noise shape
                # make observation_noise `batch_shape x q x n`
                obs_noise = observation_noise.transpose(-1, -2)
                mvn = self.likelihood(mvn, X, noise=obs_noise)
            elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                # Use the mean of the previous noise values (TODO: be smarter here).
                noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                mvn = self.likelihood(mvn, X, noise=noise)
            else:
                mvn = self.likelihood(mvn, X)
        posterior = StrippedPosterior(mvn=mvn)
        if hasattr(self, "outcome_transform"):
            posterior = self.outcome_transform.untransform_posterior(posterior)
        return posterior


In [21]:
model = StrippedGP(train_x, train_y)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_model(mll)  # ignore the output

## Computing predictive variances

### Using standard computations

The next cell gets the predictive mean and covariance for the test set
with no acceleration or precomputation.
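The timing cells in this notebook all follow the same pattern: put the model and likelihood back into training mode to drop cached results, switch to eval mode, and time a block with `time.time()`. The sketch below factors this into two small helpers. Both names are hypothetical (they are not part of GPyTorch or BoTorch), and the sketch swaps in `time.perf_counter()`, which is designed for measuring short intervals, for `time.time()`.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Hypothetical helper: time the enclosed block and store the seconds in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


def reset_caches(model):
    """Hypothetical helper mirroring the notebook: train mode drops caches, then back to eval."""
    model.train()
    model.likelihood.train()
    model.eval()
    model.likelihood.eval()


# Toy usage with a pure-Python workload.
timings = {}
with timed("toy workload", timings):
    total = sum(i * i for i in range(100_000))
print(f"toy workload took {timings['toy workload']:.4f}s")
```

In a cell below, this would read as `reset_caches(model)` followed by `with torch.no_grad(), timed("exact", timings): ...`.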

In [22]:
import time

# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad():
    start_time = time.time()
    posterior = model.posterior(test_x)
    exact_mean = posterior.mean
    exact_covar = posterior.mvn.covariance_matrix
    exact_covar_time = time.time() - start_time
    
print(f"Time to compute exact mean + covariances: {exact_covar_time:.2f}s")

Time to compute exact mean + covariances: 0.34s


### Using fast predictive variances

Next, we compute the predictive mean and covariances using the `fast_pred_var` context
manager.

In [23]:
# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad(), fast_pred_var():
    start_time = time.time()
    posterior = model.posterior(test_x)
    _ = posterior.mean
    _ = posterior.mvn.covariance_matrix
    fast_time_no_cache = time.time() - start_time

The above cell additionally computed the caches required to get fast predictions. From this point onwards, unless we put the model back in training mode, predictions should be extremely fast. The cell below re-runs the above code, but takes full advantage of both the mean cache and the LOVE cache for variances.
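The cache idea can be illustrated with a small NumPy sketch. In the predictive variance $k(x_*, x_*) - k_*^\top (K + \sigma^2 I)^{-1} k_*$, the expensive $O(n^3)$ part depends only on the training data, so it can be factored once into a matrix $R$ with $R R^\top = (K + \sigma^2 I)^{-1}$; each test variance then reduces to $k(x_*, x_*) - \lVert R^\top k_* \rVert^2$, a single matrix-vector product. LOVE builds a low-rank $R$ with the Lanczos algorithm; to keep the example short, the sketch uses an exact Cholesky factor instead, so it shows the caching pattern rather than LOVE's approximation. The RBF kernel, lengthscale, and noise level are assumptions.

```python
import numpy as np


def rbf(a, b, lengthscale=0.5):
    """RBF kernel matrix (unit outputscale) between the rows of a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)


rng = np.random.default_rng(0)
train_x, test_x = rng.random((50, 2)), rng.random((200, 2))
noise = 1e-2
K_noisy = rbf(train_x, train_x) + noise * np.eye(len(train_x))

# One-time precomputation ("the cache"): K_noisy = L L^T, so R = L^{-T} gives R R^T = K_noisy^{-1}.
# (Exact Cholesky here; LOVE would instead keep a low-rank factor from Lanczos.)
L = np.linalg.cholesky(K_noisy)
R = np.linalg.solve(L.T, np.eye(len(train_x)))

# Cheap per-query step: k(x*, x*) - ||R^T k*||^2, with k(x*, x*) = 1 for this kernel.
k_star = rbf(train_x, test_x)  # shape (n_train, n_test)
cached_var = 1.0 - ((R.T @ k_star) ** 2).sum(axis=0)

# Reference computation without the cache: solve against K_noisy directly.
direct_var = 1.0 - np.einsum("ij,ij->j", k_star, np.linalg.solve(K_noisy, k_star))
print(np.abs(cached_var - direct_var).max())
```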

In [24]:
with torch.no_grad(), fast_pred_var():
    start_time = time.time()
    posterior = model.posterior(test_x)
    fast_mean = posterior.mean
    fast_covar = posterior.mvn.covariance_matrix
    fast_time_with_cache = time.time() - start_time

print('Time to compute mean + covariances (no cache): {:.2f}s'.format(fast_time_no_cache))
print('Time to compute mean + covariances (cache): {:.2f}s'.format(fast_time_with_cache))

Time to compute mean + covariances (no cache): 0.27s
Time to compute mean + covariances (cache): 0.29s


### Compute Error between Exact and Fast Variances

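Finally, we compute the scaled mean absolute error (SMAE), i.e. the mean of the elementwise relative errors, between the fast covariances computed by LOVE (stored in `fast_covar`) and the exact covariances computed previously. Note that this ratio is undefined wherever an entry of `exact_covar` is exactly zero (0/0 gives `nan`), which is the likely reason for the `nan` below.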

In [25]:
smae = ((exact_covar - fast_covar).abs() / exact_covar.abs()).mean()
print(f"SMAE between exact covar matrix and fast covar matrix: {smae:.6f}")

SMAE between exact covar matrix and fast covar matrix: nan


## Computing posterior samples

With KISS-GP models, LOVE can also be used to draw fast posterior samples. (The same does not apply to exact GP models.)

Since this part does not use KISS-GP, we cannot use the `gpytorch.settings.fast_pred_samples()`
context manager here. We will get back to KISS-GP later.

### Drawing samples the standard way (without LOVE)

We now draw samples from the posterior distribution. Without LOVE, we accomplish this by computing a root decomposition (such as a Cholesky factor) of the posterior covariance matrix. This can be slow for large covariance matrices.
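Concretely, if the posterior covariance factors as $\Sigma = L L^\top$, a sample is $f = \mu + L z$ with $z \sim \mathcal{N}(0, I)$. The NumPy sketch below shows this on a small, made-up RBF covariance over 30 points. Computing $L$ costs $O(m^3)$ for $m$ test points, which is what becomes slow for large `num_test`; the small diagonal jitter is an assumption added for numerical stability.

```python
import numpy as np


def sample_mvn_cholesky(mean, covar, num_samples, rng, jitter=1e-6):
    """Draw num_samples samples from N(mean, covar) via a Cholesky root."""
    m = mean.shape[0]
    L = np.linalg.cholesky(covar + jitter * np.eye(m))  # O(m^3): the expensive step
    z = rng.standard_normal((num_samples, m))           # base samples z ~ N(0, I)
    return mean + z @ L.T                                # each row is mean + L z


rng = np.random.default_rng(0)
x = np.linspace(0, 1, 30)
covar = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.2**2)
mean = np.sin(2 * np.pi * x)

samples = sample_mvn_cholesky(mean, covar, num_samples=20000, rng=rng)
empirical_covar = np.cov(samples, rowvar=False)
print(np.abs(empirical_covar - covar).max())
```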

In [26]:
import time
num_samples = 2000

# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad():
    start_time = time.time()
    exact_samples = model.posterior(test_x).rsample(torch.Size([num_samples]))
    exact_sample_time = time.time() - start_time
    
print(f"Time to compute exact samples: {exact_sample_time:.2f}s")

Time to compute exact samples: 1.07s


### Using `fast_pred_var()`

Since we do not use KISS-GP here, we cannot use `gpytorch.settings.fast_pred_samples()`
to speed up posterior sampling.

We still use the `gpytorch.settings.fast_pred_var()` context manager to speed up the
computation of the covariance matrix.

We can also use the `gpytorch.settings.max_root_decomposition_size(10)` setting, which
caps the number of Lanczos iterations, and hence the rank of the root decomposition used
for sampling, at 10. This speeds up computations with large matrices at the cost of accuracy.
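The effect of capping the rank can be illustrated with a small NumPy sketch. A rank-$k$ root $R$ with $R R^\top \approx \Sigma$ only captures part of the covariance, and samples $R z$ inherit the variances $\mathrm{diag}(R R^\top)$. For simplicity, the sketch uses a truncated eigendecomposition as a stand-in for Lanczos (an assumption: GPyTorch uses Lanczos, which yields a different root of the same limited rank). For a truncated eigendecomposition, the discarded part is positive semi-definite, so the implied variances can only be too small.

```python
import numpy as np


def truncated_root(covar, rank):
    """Rank-`rank` root R (m x rank) from the top eigenpairs of a PSD matrix (Lanczos stand-in)."""
    eigvals, eigvecs = np.linalg.eigh(covar)        # eigenvalues in ascending order
    top_vals = np.clip(eigvals[-rank:], 0.0, None)  # guard against tiny negative values
    return eigvecs[:, -rank:] * np.sqrt(top_vals)   # R @ R.T keeps only the top terms


x = np.linspace(0, 1, 200)
covar = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.05**2)

# Average amount of variance missed by samples drawn with a rank-limited root.
variance_error = {}
for rank in (5, 10, 50):
    R = truncated_root(covar, rank)
    implied_var = (R**2).sum(axis=1)                # diag(R @ R.T) without forming the matrix
    variance_error[rank] = np.mean(covar.diagonal() - implied_var)
print(variance_error)
```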

In [27]:
# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

# First run: the LOVE cache has not been computed yet
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    start_time = time.time()
    _ = model.posterior(test_x).rsample(torch.Size([num_samples]))
    fast_sample_time_no_cache = time.time() - start_time

# Repeat the timing now that the cache is computed
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    start_time = time.time()
    _ = model.posterior(test_x).rsample(torch.Size([num_samples]))
    fast_sample_time_cache = time.time() - start_time

# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var(),\
     gpytorch.settings.max_root_decomposition_size(10):
    start_time = time.time()
    _ = model.posterior(test_x).rsample(torch.Size([num_samples]))
    small_root_time_no_cache = time.time() - start_time
    
# Repeat the timing now that the cache is computed
with torch.no_grad(), gpytorch.settings.fast_pred_var(),\
     gpytorch.settings.max_root_decomposition_size(10):
    start_time = time.time()
    love_samples = model.posterior(test_x).rsample(torch.Size([num_samples]))
    small_root_time_cache = time.time() - start_time
    
print('Time to compute fast samples (no cache) {:.2f}s'.format(fast_sample_time_no_cache))
print('Time to compute fast samples (cache) {:.2f}s'.format(fast_sample_time_cache))
print('Time to compute fast samples (small root, no cache) {:.2f}s'.format(
    small_root_time_no_cache))
print('Time to compute fast samples (small root, cache) {:.2f}s'.format(
    small_root_time_cache))


Time to compute fast samples (no cache) 0.88s
Time to compute fast samples (cache) 0.85s
Time to compute fast samples (small root, no cache) 0.29s
Time to compute fast samples (small root, cache) 0.28s


### Draw samples using the decoupled samplers

Here, we use the decoupled samplers to draw samples from the posterior.

The loop over `num_basis` shows how increasing the number of basis functions affects
the sampler construction and sampling times.

Note that the time to construct the sampler does not depend on the number of test points
(`num_test`): once constructed, the sampler can be evaluated at additional inputs along the
same sample paths without having to be rebuilt. Evaluating the sampler requires little more
than a few `matmul` operations.

Constructing the sampler requires the Cholesky decomposition of the `num_train x num_train`
kernel matrix, which explains the increase in initialization time as `num_train` grows.
Wilson et al. also introduce a sparse GP version of the decoupled samplers, which would help
with this issue. The sparse GP version is not implemented here (yet).
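The idea behind the decoupled sampler is the pathwise update from the paper: draw a function from an approximate prior built from $\ell$ random features, then correct it with an exact data-dependent term, $f_*(x) = \sum_{i=1}^{\ell} w_i \phi_i(x) + \sum_{j=1}^{n} v_j k(x, x_j)$ with $v = (K + \sigma^2 I)^{-1}(y - \Phi w - \varepsilon)$, where $\Phi$ stacks the features at the training inputs, $w \sim \mathcal{N}(0, I)$ and $\varepsilon \sim \mathcal{N}(0, \sigma^2 I)$. The NumPy sketch below implements this for an RBF kernel with random Fourier features. It is a minimal illustration under those assumptions, not the `gp_sampling.decoupled_sampler` code used in this notebook, and the kernel hyperparameters are made up. Only the Cholesky of the $n \times n$ matrix scales cubically; evaluating the returned function is a few matrix products.

```python
import numpy as np


def rbf(a, b, lengthscale, outputscale):
    """RBF kernel matrix between the rows of a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return outputscale * np.exp(-0.5 * sq_dists / lengthscale**2)


def decoupled_sampler_sketch(train_x, train_y, noise, lengthscale, outputscale,
                             num_samples, num_basis, rng):
    """Hypothetical sketch (not gp_sampling.decoupled_sampler): returns test_x -> samples."""
    n, d = train_x.shape
    # Random Fourier features for the RBF kernel: frequencies ~ N(0, I / lengthscale^2).
    freqs = rng.standard_normal((d, num_basis)) / lengthscale
    phases = rng.uniform(0.0, 2.0 * np.pi, num_basis)

    def features(x):
        return np.sqrt(2.0 * outputscale / num_basis) * np.cos(x @ freqs + phases)

    # Prior part: one weight vector per sample path, w ~ N(0, I).
    weights = rng.standard_normal((num_samples, num_basis))
    eps = np.sqrt(noise) * rng.standard_normal((num_samples, n))
    prior_at_train = weights @ features(train_x).T                 # (num_samples, n)

    # Update part: v = (K + noise I)^{-1} (y - Phi w - eps); the only O(n^3) step.
    K_noisy = rbf(train_x, train_x, lengthscale, outputscale) + noise * np.eye(n)
    chol = np.linalg.cholesky(K_noisy)
    resid = train_y[None, :] - prior_at_train - eps                 # (num_samples, n)
    v = np.linalg.solve(chol.T, np.linalg.solve(chol, resid.T)).T   # (num_samples, n)

    def sampler(test_x):
        prior = weights @ features(test_x).T                        # (num_samples, m)
        update = v @ rbf(train_x, test_x, lengthscale, outputscale) # (num_samples, m)
        return prior + update

    return sampler


rng = np.random.default_rng(0)
train_x = rng.random((20, 2))
train_y = np.sin(3.0 * train_x).sum(axis=1)
sampler = decoupled_sampler_sketch(train_x, train_y, noise=5e-2, lengthscale=0.3,
                                   outputscale=1.0, num_samples=2000, num_basis=1024, rng=rng)
test_x = rng.random((50, 2))
samples = sampler(test_x)  # shape (2000, 50)
```

Calling `sampler` again, or on new inputs, evaluates the same sample paths, which is what makes the mini-batched evaluation in the next cell possible.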

In [28]:
mini_batch_size = 2000
# sampling from too many points at once causes memory issues. We can draw them in mini
# batches instead.
num_batches = ceil(num_test / float(mini_batch_size))

for num_basis in [64, 256, 1024, 4096]:
    start_time = time.time()
    sampler = decoupled_sampler(model, sample_shape=[num_samples], num_basis=num_basis)
    print("Initializing decoupled sampler with %d basis functions took %.2f s" %
          (num_basis, time.time() - start_time))
    start_time = time.time()
    decoupled_samples = torch.empty(num_samples, num_test, 1)
    for i in range(num_batches):
        l_idx = i * mini_batch_size
        r_idx = (i+1) * mini_batch_size if i != num_batches - 1 else num_test
        decoupled_samples[:, l_idx: r_idx] = sampler(test_x[l_idx: r_idx])
    print("Drawing decoupled samples with %d basis functions took %.2f s" %
          (num_basis, time.time() - start_time))

Initializing decoupled sampler with 64 basis functions took 0.04 s
Drawing decoupled samples with 64 basis functions took 0.06 s
Initializing decoupled sampler with 256 basis functions took 0.04 s
Drawing decoupled samples with 256 basis functions took 0.07 s
Initializing decoupled sampler with 1024 basis functions took 0.07 s
Drawing decoupled samples with 1024 basis functions took 0.17 s
Initializing decoupled sampler with 4096 basis functions took 0.14 s
Drawing decoupled samples with 4096 basis functions took 0.71 s


### Compute the empirical covariance matrices

Let's see how well the samples recover the exact posterior mean and covariance
matrix.

In [29]:
# get rid of the final dimension
exact_samples = exact_samples.squeeze(-1)
love_samples = love_samples.squeeze(-1)
decoupled_samples = decoupled_samples.squeeze(-1)

# Compute exact posterior covar
with torch.no_grad():
    posterior = model.posterior(test_x)
    mean, covar = posterior.mean, posterior.mvn.covariance_matrix
    variance = posterior.variance
    mean, variance = mean.squeeze(-1), variance.squeeze(-1)

exact_empirical_covar = ((exact_samples - mean).t() @ (exact_samples - mean)) / num_samples
love_empirical_covar = ((love_samples - mean).t() @ (love_samples - mean)) / num_samples
decoupled_empirical_covar = ((decoupled_samples - mean).t() @ (decoupled_samples -
                                                               mean)) / num_samples

exact_empirical_error = ((exact_empirical_covar - covar).abs()).mean()
love_empirical_error = ((love_empirical_covar - covar).abs()).mean()
decoupled_empirical_error = ((decoupled_empirical_covar - covar).abs()).mean()

print(f"Empirical covariance MAE (Exact samples): {exact_empirical_error}")
print(f"Empirical covariance MAE (LOVE samples): {love_empirical_error}")
print(f"Empirical covariance MAE (Decoupled samples): {decoupled_empirical_error}")

# compare the posterior mean as well
exact_empirical_mean = exact_samples.mean(dim=0)
love_empirical_mean = love_samples.mean(dim=0)
decoupled_empirical_mean = decoupled_samples.mean(dim=0)

exact_mean_error = (exact_empirical_mean - mean).abs().mean()
love_mean_error = (love_empirical_mean - mean).abs().mean()
decoupled_mean_error = (decoupled_empirical_mean - mean).abs().mean()

print(f"Empirical mean MAE (Exact samples): {exact_mean_error}")
print(f"Empirical mean MAE (LOVE samples): {love_mean_error}")
print(f"Empirical mean MAE (Decoupled samples): {decoupled_mean_error}")

Empirical covariance MAE (Exact samples): 0.007614322006702423
Empirical covariance MAE (LOVE samples): 0.01323997974395752
Empirical covariance MAE (Decoupled samples): 0.011483925394713879
Empirical mean MAE (Exact samples): 0.009403395466506481
Empirical mean MAE (LOVE samples): 0.003239282174035907
Empirical mean MAE (Decoupled samples): 0.01099607814103365


Since many entries of the posterior covariance matrix are close to zero, the absolute errors
above are hard to interpret on their own, so we will also look at the scaled mean absolute error.

In [30]:
exact_smae = ((exact_empirical_covar - covar).abs() / covar.abs()).mean()
love_smae = ((love_empirical_covar - covar).abs() / covar.abs()).mean()
decoupled_smae = ((decoupled_empirical_covar - covar).abs() / covar.abs()).mean()

print(f"Empirical covariance SMAE (Exact samples): {exact_smae}")
print(f"Empirical covariance SMAE (LOVE samples): {love_smae}")
print(f"Empirical covariance SMAE (Decoupled samples): {decoupled_smae}")

Empirical covariance SMAE (Exact samples): inf
Empirical covariance SMAE (LOVE samples): inf
Empirical covariance SMAE (Decoupled samples): inf


Because some entries of the exact covariance matrix are zero or vanishingly small, the
elementwise division returns `inf` (or some other very large number). To put things into
perspective, let's check the average posterior variance at the test points.
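One way to get a finite, scale-aware summary is to normalize by the size of the whole matrix instead of dividing entry by entry, for example with the relative Frobenius error $\lVert \hat\Sigma - \Sigma \rVert_F / \lVert \Sigma \rVert_F$. The NumPy sketch below contrasts the two on a made-up 2 x 2 example with zero off-diagonal entries; the choice of the Frobenius norm is ours and is not used elsewhere in this notebook.

```python
import numpy as np


def elementwise_smae(estimate, exact):
    """Mean of |estimate - exact| / |exact|; inf or nan wherever exact has zeros."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.mean(np.abs(estimate - exact) / np.abs(exact))


def relative_frobenius_error(estimate, exact):
    """||estimate - exact||_F / ||exact||_F; finite unless exact is all zeros."""
    return np.linalg.norm(estimate - exact) / np.linalg.norm(exact)


exact = np.array([[1.0, 0.0], [0.0, 0.5]])        # zero covariance between two test points
estimate = np.array([[1.1, 0.01], [0.01, 0.45]])  # e.g. an empirical estimate

print(elementwise_smae(estimate, exact))          # inf, from 0.01 / 0 off the diagonal
print(relative_frobenius_error(estimate, exact))  # about 0.1
```

For the PyTorch tensors in this notebook, `torch.linalg.norm` computes the same norm for a 2-D input.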


In [31]:
print(f"Average posterior variance: {variance.mean()}")
print("Mean absolute deviation (Exact samples): "
      f"{(exact_samples.var(dim=0) - variance).abs().mean()}")
print("Mean absolute deviation (Love samples): "
      f"{(love_samples.var(dim=0) - variance).abs().mean()}")
print("Mean absolute deviation (Decoupled samples): "
      f"{(decoupled_samples.var(dim=0) - variance).abs().mean()}")

Average posterior variance: 0.46045684814453125
Mean absolute deviation (Exact samples): 0.226485013961792
Mean absolute deviation (Love samples): 0.4081191420555115
Mean absolute deviation (Decoupled samples): 0.014375735074281693


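Interestingly, the decoupled samples approximate the posterior variance much better than both
the LOVE samples and the exact samples. A likely explanation is that, for a covariance matrix
this large (larger than `gpytorch.settings.max_cholesky_size`), GPyTorch draws samples from a
low-rank Lanczos root decomposition rather than a full Cholesky factor, so even the "exact"
samples miss part of the variance; for the LOVE samples, the root was additionally capped at
rank 10 by `max_root_decomposition_size(10)`. The decoupled samples do have a somewhat larger
mean error than the alternatives, but with 2000 samples and an average posterior variance of
about 0.46, a Monte Carlo error of roughly 0.01 in the sample mean is to be expected, and
samples with too little variance will naturally show a smaller mean error.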
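A rough yardstick for the empirical mean errors reported earlier: with $N$ independent samples and posterior variance $\sigma^2$ at a test point, the sample mean has standard deviation $\sigma / \sqrt{N}$, and for a normal error with standard deviation $s$ the expected absolute value is $s\sqrt{2/\pi}$. The sketch plugs in the average posterior variance printed above; using a single average variance instead of the per-point variances is a simplification.

```python
import math

n_draws = 2000
avg_posterior_variance = 0.46  # average posterior variance printed above

# Standard deviation of a sample mean of n_draws independent draws.
std_of_sample_mean = math.sqrt(avg_posterior_variance / n_draws)
# E|e| = s * sqrt(2 / pi) for e ~ N(0, s^2).
expected_mean_mae = std_of_sample_mean * math.sqrt(2.0 / math.pi)
print(f"Expected MAE of the sample mean: {expected_mean_mae:.4f}")  # 0.0121
```

This is close to the decoupled samples' empirical mean MAE of about 0.011; noticeably smaller errors are what one would expect from samples whose variance is too small.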

These values are the empirical variances of the samples, computed about the samples' own
empirical mean. Let's see how things change if we instead compute the empirical variances
about the exact posterior mean.
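The two ways of computing the variance differ only in how the samples are centered. With the exact mean $\mu$ known, $\frac{1}{N}\sum_i (f_i - \mu)^2$ is an unbiased estimate of the variance; when the sample mean $\bar f$ is used instead, the unbiased version divides by $N - 1$, which is what PyTorch's `Tensor.var` does by default. For $N = 2000$ the two are very close, as the NumPy sketch below shows on synthetic normal draws (the mean and variance values are made up to mimic the numbers above).

```python
import numpy as np

rng = np.random.default_rng(0)
true_mean, true_var, n_draws = 0.0, 0.46, 2000
draws = true_mean + np.sqrt(true_var) * rng.standard_normal(n_draws)

var_about_true_mean = np.mean((draws - true_mean) ** 2)  # centered at the known mean, 1/N
var_about_sample_mean = draws.var(ddof=1)                # centered at the sample mean, 1/(N - 1)
print(var_about_true_mean, var_about_sample_mean)
```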

In [32]:
print(f"Average posterior variance: {variance.mean()}")
print("Mean absolute deviation (Exact samples): "
      f"{(exact_empirical_covar.diag() - variance).abs().mean()}")
print("Mean absolute deviation (Love samples): "
      f"{(love_empirical_covar.diag() - variance).abs().mean()}")
print("Mean absolute deviation (Decoupled samples): "
      f"{(decoupled_empirical_covar.diag() - variance).abs().mean()}")

Average posterior variance: 0.46045684814453125
Mean absolute deviation (Exact samples): 0.22644329071044922
Mean absolute deviation (Love samples): 0.40812256932258606
Mean absolute deviation (Decoupled samples): 0.014377166517078876


## The KISS-GP model

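KISS-GP can achieve significant speed-ups over exact GP models, particularly on large training
sets. It is not an exact comparison with the decoupled samplers (since it is a different GP
model), but it is still useful as a benchmark.

KISS-GP places its inducing points on a grid whose size grows exponentially with the input
dimension, so it does not work well in higher dimensions. For `dim > 4`, we therefore use an
additive variant instead: a sum of one-dimensional KISS-GP kernels, one per input dimension,
built with `AdditiveStructureKernel`. (This is related to, but not the same as, SKIP, which
combines one-dimensional kernels multiplicatively via `ProductStructureKernel`.)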
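Structured kernel interpolation, the idea behind KISS-GP, approximates the kernel matrix as $K \approx W K_{UU} W^\top$, where $K_{UU}$ is the kernel evaluated on a regular grid of inducing points and each row of the sparse matrix $W$ holds the interpolation weights of one input. A full grid in $d$ dimensions has `grid_size**d` points; with, say, 128 points per dimension and $d = 6$, that is $128^6 \approx 4.4 \times 10^{12}$, which is why the code below switches to one-dimensional grids when `dim > 4`. The 1D NumPy sketch uses linear interpolation and a dense $W$ for simplicity; GPyTorch's `GridInterpolationKernel` uses cubic interpolation and exploits the grid structure, so this only illustrates the approximation itself. The kernel and grid settings are assumptions.

```python
import numpy as np


def rbf_1d(a, b, lengthscale=0.2):
    """1D RBF kernel matrix with unit outputscale."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def linear_interp_weights(x, grid):
    """Dense (len(x) x len(grid)) linear interpolation weights; each row has 2 nonzeros."""
    spacing = grid[1] - grid[0]
    left = np.clip(np.floor((x - grid[0]) / spacing).astype(int), 0, len(grid) - 2)
    frac = (x - grid[left]) / spacing
    W = np.zeros((len(x), len(grid)))
    rows = np.arange(len(x))
    W[rows, left] = 1.0 - frac
    W[rows, left + 1] = frac
    return W


rng = np.random.default_rng(0)
x = rng.random(300)
grid = np.linspace(0.0, 1.0, 128)  # inducing points on a regular grid

W = linear_interp_weights(x, grid)
K_ski = W @ rbf_1d(grid, grid) @ W.T  # structured kernel interpolation approximation
K_exact = rbf_1d(x, x)
print(np.abs(K_ski - K_exact).max())
```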

In [33]:
if dim <= 4:
    # this is the KISS-GP kernel. Doesn't work with dim > 4
    grid_size = gpytorch.utils.grid.choose_grid_size(train_x, 1.0)
    covar_module = gpytorch.kernels.ScaleKernel(
        gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.RBFKernel(), grid_size=grid_size, num_dims=dim
        )
    )
else:
    # Additive KISS-GP kernel: a sum of 1D grid kernels, one per input dimension,
    # so the grid cost grows linearly rather than exponentially with dim.
    covar_module = gpytorch.kernels.AdditiveStructureKernel(
        gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.RBFKernel(), grid_size=128, num_dims=1
            )
        ), num_dims=dim
    )

model = StrippedGP(train_x, train_y, covar_module=covar_module)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_model(mll)  # ignore the output


As before, let's compare the time to compute the predictive mean and covariances.

In [34]:
# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad():
    start_time = time.time()
    posterior = model.posterior(test_x)
    exact_mean = posterior.mean
    exact_covar = posterior.mvn.covariance_matrix
    exact_covar_time = time.time() - start_time

print(f"Time to compute exact mean + covariances: {exact_covar_time:.2f}s")

# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad(), fast_pred_var():
    start_time = time.time()
    posterior = model.posterior(test_x)
    _ = posterior.mean
    _ = posterior.mvn.covariance_matrix
    fast_time_no_cache = time.time() - start_time

with torch.no_grad(), fast_pred_var():
    start_time = time.time()
    posterior = model.posterior(test_x)
    fast_mean = posterior.mean
    fast_covar = posterior.mvn.covariance_matrix
    fast_time_with_cache = time.time() - start_time

# Clear the cache from the previous computations
model.train()
model.likelihood.train()

# Set into eval mode
model.eval()
model.likelihood.eval()

with torch.no_grad(), fast_pred_var(), gpytorch.settings.max_root_decomposition_size(10):
    start_time = time.time()
    posterior = model.posterior(test_x)
    _ = posterior.mean
    _ = posterior.mvn.covariance_matrix
    small_root_no_cache = time.time() - start_time

with torch.no_grad(), fast_pred_var(), gpytorch.settings.max_root_decomposition_size(10):
    start_time = time.time()
    posterior = model.posterior(test_x)
    small_root_mean = posterior.mean
    small_root_covar = posterior.mvn.covariance_matrix
    small_root_with_cache = time.time() - start_time


print('Time to compute mean + covariances (no cache): {:.2f}s'.format(fast_time_no_cache))
print('Time to compute mean + covariances (cache): {:.2f}s'.format(fast_time_with_cache))
print('Time to compute mean + covariances (small root, no cache): '
      '{:.2f}s'.format(small_root_no_cache))
print('Time to compute mean + covariances (small root, cache): '
      '{:.2f}s'.format(small_root_with_cache))

smae = ((exact_covar - fast_covar).abs() / exact_covar.abs()).mean()
print(f"SMAE between exact covar matrix and fast covar matrix: {smae:.6f}")

smae = ((exact_covar - small_root_covar).abs() / exact_covar.abs()).mean()
print(f"SMAE between exact covar matrix and small root, fast covar matrix: {smae:.6f}")


Time to compute exact mean + covariances: 1.71s
Time to compute mean + covariances (no cache): 1.64s
Time to compute mean + covariances (cache): 1.64s
Time to compute mean + covariances (small root, no cache): 1.70s
Time to compute mean + covariances (small root, cache): 1.73s
SMAE between exact covar matrix and fast covar matrix: inf
SMAE between exact covar matrix and small root, fast covar matrix: inf
