The decoupled samplers have unexpected behavior with this implementation. The goal here
 is to compare it in detail with the original implementation, and debug the problem.

The bug has been identified and fixed.

In [1]:
import torch
from botorch.models import SingleTaskGP
%matplotlib
import matplotlib.pyplot as plt
from gp_sampling import decoupled_sampler
import pandas as pd
import numpy as np


Using matplotlib backend: TkAgg


In [2]:
# generate the train data
dim = 1
train_X = torch.arange(0, 1, 0.1).reshape(-1, 1)
train_Y = train_X * torch.sin(20*train_X)

# Initialize the GP model
model = SingleTaskGP(train_X, train_Y)

# Set GP hyper-parameters
# model.covar_module.outputscale = 0.1  # This is a different scale
model.covar_module.base_kernel.lengthscale = 0.1
model.likelihood.noise = 5e-3

# test points
test_X = torch.arange(0, 1, 0.01).reshape(-1, 1)

# clear previous plots
plt.close()

# posterior summary
with torch.no_grad():
    post_mean = model.posterior(test_X).mean.reshape(-1)
    post_std = torch.sqrt(model.posterior(test_X).variance).reshape(-1)
plt.plot(test_X.reshape(-1), post_mean, label="post_mean", color="green")
plt.fill_between(
    test_X.reshape(-1),
    post_mean - 2 * post_std,
    post_mean + 2 * post_std,
    alpha=0.5
)


<matplotlib.collections.PolyCollection at 0x7fb3dc2ada60>

These two code blocks add samples from exact posterior and the decoupled sampler
respectively.

In [3]:
# plot several posterior samples
num_samples = 50
sample_shape = torch.Size([num_samples])
with torch.no_grad():
    exact_samples = model.posterior(test_X).rsample(sample_shape=sample_shape).squeeze(-1)
plt.plot(
    test_X.reshape(-1),
    exact_samples.t(),
    label="exact",
    color="blue"
)


[<matplotlib.lines.Line2D at 0x7fb3dfd4e430>,
 <matplotlib.lines.Line2D at 0x7fb3dfd4eaf0>,
 <matplotlib.lines.Line2D at 0x7fb3dfd4ebe0>,
 <matplotlib.lines.Line2D at 0x7fb4a0492550>,
 <matplotlib.lines.Line2D at 0x7fb4a048c490>,
 <matplotlib.lines.Line2D at 0x7fb4a048c0d0>,
 <matplotlib.lines.Line2D at 0x7fb4a048cd30>,
 <matplotlib.lines.Line2D at 0x7fb4a048c8e0>,
 <matplotlib.lines.Line2D at 0x7fb4a048cac0>,
 <matplotlib.lines.Line2D at 0x7fb3dfd69280>,
 <matplotlib.lines.Line2D at 0x7fb3dfd69220>,
 <matplotlib.lines.Line2D at 0x7fb3dfd690d0>,
 <matplotlib.lines.Line2D at 0x7fb3dfd69100>,
 <matplotlib.lines.Line2D at 0x7fb3dc2add90>,
 <matplotlib.lines.Line2D at 0x7fb3dc2addf0>,
 <matplotlib.lines.Line2D at 0x7fb3dc2ade80>,
 <matplotlib.lines.Line2D at 0x7fb3dc2adf40>,
 <matplotlib.lines.Line2D at 0x7fb3dc2ad970>,
 <matplotlib.lines.Line2D at 0x7fb4a04864f0>,
 <matplotlib.lines.Line2D at 0x7fb4a0486370>,
 <matplotlib.lines.Line2D at 0x7fb4a04864c0>,
 <matplotlib.lines.Line2D at 0x7fb

In [4]:
# initialize the decoupled sampler
ds = decoupled_sampler(
    model=model,
    sample_shape=sample_shape,
    num_basis=256
)

# sample from decoupled
with torch.no_grad():
    ds_samples = ds(test_X)

# plot the samples
plt.plot(
    test_X.reshape(-1),
    ds_samples.squeeze(-1).t(),
    label="decoupled",
    color="red"
)


[<matplotlib.lines.Line2D at 0x7fb3dc238dc0>,
 <matplotlib.lines.Line2D at 0x7fb3dc238be0>,
 <matplotlib.lines.Line2D at 0x7fb3dc238ee0>,
 <matplotlib.lines.Line2D at 0x7fb3dc238fa0>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0040>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0100>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d01c0>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0280>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0340>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0400>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d04c0>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0580>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0640>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0700>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d07c0>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0880>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0940>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0a00>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0ac0>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0b80>,
 <matplotlib.lines.Line2D at 0x7fb3dc1d0c40>,
 <matplotlib.lines.Line2D at 0x7fb

In [5]:
#plot
plt.grid(True)
plt.legend()
plt.title("torch")
plt.show()

We will now extend this and compare the MAE between the empirical mean and variance
predicted by the decoupler sampler along with the exact posterior samples and the true
posterior statistics.

In [6]:
# generate the data
sample_count = 1000
sample_shape = torch.Size([sample_count])
ds = decoupled_sampler(
    model=model,
    sample_shape=sample_shape,
    num_basis=1024,
)

with torch.no_grad():
    if dim == 1:
        test_X = torch.arange(0, 1, 0.01).reshape(-1, 1)
    else:
        test_X = torch.rand(100, dim)
    ds_samples = ds(test_X).reshape(sample_count, 100).detach()
    exact_samples = model.posterior(test_X).rsample(
        sample_shape=sample_shape).reshape(sample_count, 100).detach()
    post_var = model.posterior(test_X).variance
    post_mean = model.posterior(test_X).mean

df = pd.DataFrame()
df["ds_var"] = torch.var(ds_samples, dim=0).numpy()
df["exact_var"] = torch.var(exact_samples, dim=0).numpy()
df["true_var"] = post_var.numpy()
df["ds_mean"] = torch.mean(ds_samples, dim=0)
df["exact_mean"] = torch.mean(exact_samples, dim=0)
df["true_mean"] = post_mean.numpy()

In [7]:
# plot
plt.close("all")
df.plot()
plt.grid(True)
plt.show()

In [8]:
# report MAE
mae_ds_var = np.abs(df["ds_var"] - df["true_var"]).mean()
mae_exact_var = np.abs(df["exact_var"] - df["true_var"]).mean()
mae_ds_mean = np.abs(df["ds_mean"] - df["true_mean"]).mean()
mae_exact_mean = np.abs(df["exact_mean"] - df["true_mean"]).mean()

print(f"MAE Variance: ds {mae_ds_var}, exact {mae_exact_var}")
print(f"MAE Mean: ds {mae_ds_mean}, exact {mae_exact_mean}")

MAE Variance: ds 0.004140158649533987, exact 0.0015273370081558824
MAE Mean: ds 0.002719623502343893, exact 0.005040072835981846
