The decoupled samplers have unexpected behavior with this implementation. The goal here
 is to compare it in detail with the original implementation, and debug the problem.

The bug has been identified and fixed.

In [1]:
import torch
from botorch.models import SingleTaskGP
%matplotlib
import matplotlib.pyplot as plt
from gp_sampling import decoupled_sampler


Using matplotlib backend: TkAgg


In [2]:
# generate the train data
train_X = torch.arange(0, 1, 0.1).reshape(-1, 1)
train_Y = train_X * torch.sin(20*train_X)

# Initialize the GP model
model = SingleTaskGP(train_X, train_Y)

# Set GP hyper-parameters
# model.covar_module.outputscale = 0.1  # This is a different scale
model.covar_module.base_kernel.lengthscale = 0.1
model.likelihood.noise = 5e-3

# test points
test_X = torch.arange(0, 1, 0.01).reshape(-1, 1)

# clear previous plots
plt.close()

# posterior summary
with torch.no_grad():
    post_mean = model.posterior(test_X).mean.reshape(-1)
    post_std = torch.sqrt(model.posterior(test_X).variance).reshape(-1)
plt.plot(test_X.reshape(-1), post_mean, label="post_mean", color="green")
plt.fill_between(
    test_X.reshape(-1),
    post_mean - 2 * post_std,
    post_mean + 2 * post_std,
    alpha=0.5
)

# plot several posterior samples
num_samples = 50
sample_shape = torch.Size([num_samples])
with torch.no_grad():
    exact_samples = model.posterior(test_X).rsample(sample_shape=sample_shape).squeeze(-1)
plt.plot(
    test_X.reshape(-1),
    exact_samples.t(),
    label="exact",
    color="blue"
)


[<matplotlib.lines.Line2D at 0x7f2300570b80>,
 <matplotlib.lines.Line2D at 0x7f2300570bb0>,
 <matplotlib.lines.Line2D at 0x7f2300570af0>,
 <matplotlib.lines.Line2D at 0x7f23005702e0>,
 <matplotlib.lines.Line2D at 0x7f2300570dc0>,
 <matplotlib.lines.Line2D at 0x7f2300570eb0>,
 <matplotlib.lines.Line2D at 0x7f2300570f10>,
 <matplotlib.lines.Line2D at 0x7f2300570fd0>,
 <matplotlib.lines.Line2D at 0x7f23005890d0>,
 <matplotlib.lines.Line2D at 0x7f2300589190>,
 <matplotlib.lines.Line2D at 0x7f2300589250>,
 <matplotlib.lines.Line2D at 0x7f2300589310>,
 <matplotlib.lines.Line2D at 0x7f23005893d0>,
 <matplotlib.lines.Line2D at 0x7f2300589490>,
 <matplotlib.lines.Line2D at 0x7f2300589550>,
 <matplotlib.lines.Line2D at 0x7f2300589610>,
 <matplotlib.lines.Line2D at 0x7f23005896d0>,
 <matplotlib.lines.Line2D at 0x7f2300589790>,
 <matplotlib.lines.Line2D at 0x7f2300589850>,
 <matplotlib.lines.Line2D at 0x7f2300589910>,
 <matplotlib.lines.Line2D at 0x7f23005899d0>,
 <matplotlib.lines.Line2D at 0x7f2

In [3]:
# initialize the decoupled sampler
ds = decoupled_sampler(
    model=model,
    sample_shape=sample_shape,
    num_basis=256
)

# sample from decoupled
with torch.no_grad():
    ds_samples = ds(test_X)

# plot the samples
plt.plot(
    test_X.reshape(-1),
    ds_samples.squeeze(-1).t(),
    label="decoupled",
    color="red"
)


[<matplotlib.lines.Line2D at 0x7f2304023820>,
 <matplotlib.lines.Line2D at 0x7f2304023880>,
 <matplotlib.lines.Line2D at 0x7f2304023cd0>,
 <matplotlib.lines.Line2D at 0x7f2304023d90>,
 <matplotlib.lines.Line2D at 0x7f2304023b80>,
 <matplotlib.lines.Line2D at 0x7f2304023be0>,
 <matplotlib.lines.Line2D at 0x7f2304023af0>,
 <matplotlib.lines.Line2D at 0x7f2304023040>,
 <matplotlib.lines.Line2D at 0x7f23004f2670>,
 <matplotlib.lines.Line2D at 0x7f23004f2850>,
 <matplotlib.lines.Line2D at 0x7f23004f27c0>,
 <matplotlib.lines.Line2D at 0x7f23004f2940>,
 <matplotlib.lines.Line2D at 0x7f23004f27f0>,
 <matplotlib.lines.Line2D at 0x7f23004f29a0>,
 <matplotlib.lines.Line2D at 0x7f23004f2a60>,
 <matplotlib.lines.Line2D at 0x7f23004f2b20>,
 <matplotlib.lines.Line2D at 0x7f23004f2be0>,
 <matplotlib.lines.Line2D at 0x7f23004f2ca0>,
 <matplotlib.lines.Line2D at 0x7f23004f2d60>,
 <matplotlib.lines.Line2D at 0x7f23004f2eb0>,
 <matplotlib.lines.Line2D at 0x7f23004f2fd0>,
 <matplotlib.lines.Line2D at 0x7f2

In [4]:
#plot
plt.grid(True)
plt.legend()
plt.title("torch")
plt.show()