This notebook demonstrates that the embedding learned by `LCEMGP` differs from
run to run when the model is trained with `fit_gpytorch_model`. The run-to-run
variation comes from the random initialization of the embedding weights.

This variance in the embedding results in significant variance in the posterior
mean and covariance.

Below, we fit 5 instances of LCEMGP on the same training data to demonstrate this.

In [1]:
import torch
from botorch import fit_gpytorch_model
from botorch.models.contextual_multioutput import LCEMGP
from gpytorch import ExactMarginalLogLikelihood
from torch import Tensor

def test_function(X: Tensor) -> Tensor:
    sine = torch.sin(X)
    linear = X * 0.05
    noise = torch.randn_like(X) * 0.25
    return (sine + linear + noise).sum(dim=-1, keepdim=True)


# Seed torch's global RNG so that Restart & Run All reproduces the same
# training data. Anything later in the notebook that draws from the global RNG
# becomes repeatable across runs as well, while successive draws within one
# run still differ from each other.
torch.manual_seed(0)

num_alternatives = 5
# Number of observations per alternative; train_X has
# num_alternatives * num_train rows in total.
num_train = 5
# Column 0: alternative index, cycling 0, 1, ..., num_alternatives - 1.
train_X_cat = torch.arange(
    num_alternatives, dtype=torch.float
).repeat(num_train).view(-1, 1)
# Column 1: a continuous input drawn uniformly from [0, 1).
train_X = torch.cat(
    [train_X_cat, torch.rand_like(train_X_cat)], dim=-1
)
train_Y = test_function(train_X)

num_models = 5
emb_dim = 2

# Per-model storage for the quantities compared in the cells below.
pre_train_embs = torch.zeros(num_models, num_alternatives, emb_dim)
post_train_embs = torch.zeros_like(pre_train_embs)
post_train_covar = torch.zeros(num_models, num_alternatives, num_alternatives)
post_train_mean = torch.zeros(num_models, num_alternatives)

# Single evaluation point X = 0.5; it is the same for every model, so it is
# built once outside the loop.
eval_X = torch.full((1, 1), 0.5)

for i in range(num_models):
    # A fresh model on identical training data; column 0 of train_X is passed
    # as the task feature.
    model = LCEMGP(train_X, train_Y, task_feature=0, embs_dim_list=[emb_dim])

    # Embedding weights before fitting (the assignment copies the values).
    pre_train_embs[i] = model.emb_layers[0].weight.detach()

    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)

    # Embedding weights after fitting.
    post_train_embs[i] = model.emb_layers[0].weight.detach()

    # Posterior at eval_X; record its mean and covariance for this model.
    posterior_mvn = model.posterior(eval_X).mvn
    post_train_mean[i] = posterior_mvn.mean.detach()
    post_train_covar[i] = posterior_mvn.covariance_matrix.detach()

The random initialization of the embedding (pre-train)

In [2]:
# Axes: [model, alternative, embedding dimension].
pre_train_embs

tensor([[[-1.9477, -0.9121],
         [-0.8366, -1.3235],
         [ 1.3058,  1.9682],
         [ 1.1083,  0.3799],
         [ 0.0707,  1.8554]],

        [[-1.0116, -0.7351],
         [ 0.8207, -0.4530],
         [-1.0873,  0.3601],
         [ 0.3244,  1.5122],
         [-1.0946, -1.1705]],

        [[ 0.0567,  0.1575],
         [-0.3365,  0.0523],
         [ 0.3112, -0.5188],
         [ 0.7687,  1.1599],
         [-0.2045,  0.8756]],

        [[ 1.6935,  0.0215],
         [ 0.9371,  0.1164],
         [-1.4184, -1.5826],
         [ 1.0568, -0.1078],
         [-0.2043, -2.5470]],

        [[ 1.5344,  0.5198],
         [-1.5981,  2.2358],
         [-0.4685, -0.0651],
         [-1.1122, -0.9644],
         [ 0.5034,  1.9493]]])

The fitted embedding

In [3]:
# Axes: [model, alternative, embedding dimension].
post_train_embs

tensor([[[-0.5859, -0.1187],
         [-0.6807, -1.7474],
         [ 1.6677,  1.5944],
         [ 1.9523,  0.0944],
         [-2.6530,  2.1454]],

        [[-0.1066,  0.1133],
         [ 0.8881,  1.0305],
         [ 0.9482,  1.0858],
         [ 0.1192,  0.3120],
         [-3.8973, -3.0280]],

        [[ 0.1004,  0.2506],
         [ 0.4040, -0.4959],
         [ 0.4061, -0.5003],
         [ 0.1939,  0.0218],
         [-0.5089,  2.4504]],

        [[ 0.9000, -0.4527],
         [ 0.8141,  0.2153],
         [-1.8824, -0.9803],
         [ 0.8719, -0.2278],
         [ 1.3612, -2.6539]],

        [[ 1.3934,  0.5379],
         [-1.5487,  0.4246],
         [-1.1437,  0.3254],
         [-0.4918,  0.5832],
         [ 0.6499,  1.8043]]])

Posterior covariance (evaluated at X=0.5)

In [4]:
# Axes: [model, alternative, alternative]; one covariance matrix per fitted model.
post_train_covar

tensor([[[ 6.3067e-02,  1.3132e-02,  1.0758e-03, -1.5726e-04,  2.0660e-02],
         [ 1.3132e-02,  3.6272e-02,  5.6101e-04,  1.9624e-04,  9.3919e-04],
         [ 1.0758e-03,  5.6101e-04,  6.3291e-02,  1.9678e-02,  2.2500e-04],
         [-1.5726e-04,  1.9624e-04,  1.9678e-02,  5.8053e-02, -3.3465e-06],
         [ 2.0660e-02,  9.3922e-04,  2.2499e-04, -3.3460e-06,  4.7394e-02]],

        [[ 6.1665e-02,  4.4771e-03,  4.5211e-03,  4.5334e-02,  1.4315e-02],
         [ 4.4771e-03,  2.4659e-02,  2.4630e-02,  9.7251e-03, -7.5728e-05],
         [ 4.5210e-03,  2.4631e-02,  2.4651e-02,  9.7865e-03, -7.2449e-05],
         [ 4.5335e-02,  9.7250e-03,  9.7866e-03,  4.5851e-02,  5.1273e-03],
         [ 1.4315e-02, -7.5802e-05, -7.2479e-05,  5.1274e-03,  5.2308e-02]],

        [[ 6.0611e-02,  4.4552e-03,  4.2512e-03,  4.5031e-02,  1.5132e-02],
         [ 4.4552e-03,  2.4687e-02,  2.4786e-02,  9.8573e-03, -6.7830e-05],
         [ 4.2512e-03,  2.4786e-02,  2.4898e-02,  9.4204e-03, -4.2126e-05],
        

Posterior mean (evaluated at X=0.5)

In [5]:
# Axes: [model, alternative]; one row of posterior means per fitted model.
post_train_mean

tensor([[ 0.5603,  1.4518,  1.4640,  0.7985, -0.1114],
        [ 0.3183,  1.4759,  1.4752,  0.6964, -0.1422],
        [ 0.3207,  1.4744,  1.4787,  0.7013, -0.1402],
        [ 0.3730,  1.4577,  1.6185,  0.7435, -0.1573],
        [ 0.4247,  1.4884,  1.4833,  0.7414, -0.1271]])