LCEGP shows strange behavior with respect to the probability of correct selection
(PCS). In theory, adding more training samples should make the model more accurate
and thus increase the PCS. Although this does not translate directly to the Bayesian
PCS we use, we would still expect to see similar behavior. This does not always
happen with LCEGP, which makes me suspect that the cause lies in the model fitting.

The goal of this notebook is to analyze the stability of fitting LCEGP using
`fit_gpytorch_model` and, if the fitting turns out to be unstable, to come up with
better ways of fitting these models.

One important question is how the initial embedding affects the resulting fitted
embedding. I suspect that the MLL is highly non-convex in the embedding, which may
be why we observe these inconsistencies.
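
To illustrate why a non-convex objective makes the fit depend on the starting point, here is a toy sketch. It does not use LCEGP or the actual MLL: the objective `toy_loss`, the helper `gradient_descent`, and the step size are all hypothetical stand-ins chosen for illustration. Plain gradient descent from two initial values ends up in two different local minima, and keeping the best of several restarts is one simple way to reduce this sensitivity.

```python
# Toy illustration only: a 1-D non-convex function standing in for the
# negative MLL. None of these names come from LCEGP or BoTorch.

def toy_loss(x: float) -> float:
    # Quartic with two local minima of different depth.
    return x**4 - 3 * x**2 + x


def toy_grad(x: float) -> float:
    # Analytic derivative of toy_loss.
    return 4 * x**3 - 6 * x + 1


def gradient_descent(x0: float, lr: float = 0.01, steps: int = 2000) -> float:
    # Plain fixed-step gradient descent from the initial point x0.
    x = x0
    for _ in range(steps):
        x -= lr * toy_grad(x)
    return x


# Two different initial values converge to two different local minima.
initial_points = [-2.0, 2.0]
solutions = [gradient_descent(x0) for x0 in initial_points]
for x0, x in zip(initial_points, solutions):
    print(f"start {x0:+.1f} -> x = {x:+.4f}, loss = {toy_loss(x):+.4f}")

# Multi-start: keep the restart with the lowest loss.
best = min(solutions, key=toy_loss)
print(f"best of {len(solutions)} restarts: x = {best:+.4f}")
```

The same idea would carry over to model fitting by refitting from several random initial embeddings and keeping the fit with the highest MLL, though whether that is sufficient for LCEGP is exactly what this notebook investigates.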

### Implemented `custom_fit_gpytorch_model` to get around this issue.

In [1]:
import torch
from botorch import fit_gpytorch_model
from gpytorch import ExactMarginalLogLikelihood
from torch import Tensor

from contextual_rs.lce_gp import LCEGP


Let's train several models on the same data and see how the results compare,
starting with the purely categorical setting.

In [2]:
def test_function(X: Tensor) -> Tensor:
    sine = torch.sin(X)
    linear = X * 0.05
    noise = torch.randn_like(X) * 0.25
    return sine + linear + noise


num_alternatives = 5
num_train = 10
train_X = torch.tensor(
    range(num_alternatives), dtype=torch.float
).repeat(num_train).view(-1, 1)
train_Y = test_function(train_X)

all_alternatives = train_X[:num_alternatives].clone()

num_models = 5
emb_dim = 2
pre_train_embs = torch.zeros(num_models, num_alternatives, emb_dim)
post_train_embs = torch.zeros(num_models, num_alternatives, emb_dim)
post_train_mean = torch.zeros(num_models, num_alternatives)
post_train_covar = torch.zeros(num_models, num_alternatives, num_alternatives)
mll_vals = torch.zeros(num_models)
for i in range(num_models):
    # Each model starts from its own randomly initialized embedding.
    model = LCEGP(
        train_X,
        train_Y,
        categorical_cols=[0],
        embs_dim_list=[emb_dim],
    )
    # Record the embedding before and after fitting.
    pre_train_embs[i] = model.emb_layers[0].weight.detach().clone()
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)
    post_train_embs[i] = model.emb_layers[0].weight.detach()
    # Compute the posterior over all alternatives once and reuse it.
    mvn = model.posterior(all_alternatives).mvn
    post_train_covar[i] = mvn.covariance_matrix.detach()
    post_train_mean[i] = mvn.mean.detach()
    mll_vals[i] = mll(model(train_X), train_Y.squeeze())



In [3]:
pre_train_embs

tensor([[[-0.3565,  1.5856],
         [-0.9723,  0.9848],
         [-0.3298, -1.0075],
         [ 0.1390, -0.3247],
         [-0.3806,  0.2045]],

        [[-0.2174, -1.1883],
         [-2.0870, -0.7230],
         [ 0.7595,  1.5391],
         [ 0.5531, -0.5249],
         [ 0.2064, -1.0177]],

        [[-0.8223,  0.1493],
         [-0.1904, -0.5873],
         [ 0.8683,  0.3399],
         [ 0.9867, -0.5936],
         [ 0.3958,  1.2221]],

        [[-0.7546,  0.3082],
         [-0.0671, -1.1361],
         [-0.0031,  0.5324],
         [-1.5114,  0.1558],
         [ 1.1944, -1.4808]],

        [[ 0.7357, -0.9264],
         [-0.1551, -0.4375],
         [-0.3579,  2.2555],
         [ 0.6632,  1.8686],
         [ 0.7138, -0.3822]]])

In [4]:
post_train_embs

tensor([[[-0.0660,  0.1782],
         [-0.9904,  0.4270],
         [-1.0597,  0.5120],
         [-0.2218,  0.0702],
         [ 0.4376,  0.2553]],

        [[-0.0799, -0.2900],
         [-0.8112,  0.5199],
         [-1.1827,  0.7715],
         [-0.1585, -0.2056],
         [ 1.4468, -2.7107]],

        [[-0.1267,  0.6344],
         [ 0.2378,  0.1385],
         [ 0.2541,  0.1161],
         [ 1.3652, -1.6436],
         [-0.4923,  1.2851]],

        [[-1.1714,  0.2884],
         [-0.5324, -0.2406],
         [-0.5187, -0.2507],
         [-0.9312,  0.2107],
         [ 2.0118, -1.6284]],

        [[ 0.4180,  0.3470],
         [-0.3673,  0.8043],
         [-0.3925,  0.8191],
         [ 0.3346,  0.3955],
         [ 1.6068,  0.0122]]])
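
Raw embedding coordinates are hard to compare across fits. Assuming the kernel over the latent embedding depends only on distances between embedded points (an assumption about LCEGP that this notebook does not verify), a rotated or shifted embedding would describe the same model, so pairwise distances are a more meaningful thing to compare. The sketch below does this for the first and third fitted embeddings, with values copied from the output above; the helper `pairwise_distances` is a hypothetical name, not part of `contextual_rs`.

```python
import math

# Fitted embeddings of the first and third model, copied from the
# post_train_embs output above (one 2-D point per alternative).
emb_model_0 = [
    (-0.0660, 0.1782), (-0.9904, 0.4270), (-1.0597, 0.5120),
    (-0.2218, 0.0702), (0.4376, 0.2553),
]
emb_model_2 = [
    (-0.1267, 0.6344), (0.2378, 0.1385), (0.2541, 0.1161),
    (1.3652, -1.6436), (-0.4923, 1.2851),
]


def pairwise_distances(emb):
    # Euclidean distance between every pair of embedded alternatives.
    return [[math.dist(a, b) for b in emb] for a in emb]


dist_0 = pairwise_distances(emb_model_0)
dist_2 = pairwise_distances(emb_model_2)

# Print both distance matrices for a side-by-side comparison.
for name, dist in [("model 0", dist_0), ("model 2", dist_2)]:
    print(name)
    for row in dist:
        print("  " + "  ".join(f"{d:6.3f}" for d in row))
```

In this comparison, alternatives 1 and 2 sit close together in both fits, while alternatives 0 and 3 are close in the first fit but far apart in the third.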

In [5]:
post_train_covar

tensor([[[ 3.6470e-03, -5.1081e-05,  3.1012e-04,  3.0254e-03,  1.4051e-03],
         [-5.1081e-05,  4.0724e-03,  3.8931e-03,  4.7308e-04, -2.3741e-04],
         [ 3.1012e-04,  3.8931e-03,  4.1052e-03,  2.3544e-05, -1.4371e-04],
         [ 3.0254e-03,  4.7308e-04,  2.3663e-05,  4.9225e-03, -2.0510e-04],
         [ 1.4051e-03, -2.3735e-04, -1.4371e-04, -2.0504e-04,  7.2671e-03]],

        [[ 3.5106e-03,  1.8239e-04, -2.5272e-05,  3.5221e-03,  7.6634e-04],
         [ 1.8245e-04,  3.6743e-03,  3.8066e-03,  4.3255e-04, -2.5892e-04],
         [-2.5272e-05,  3.8068e-03,  3.9825e-03,  1.9294e-04, -1.3566e-04],
         [ 3.5221e-03,  4.3255e-04,  1.9288e-04,  3.6368e-03,  1.8078e-04],
         [ 7.6646e-04, -2.5904e-04, -1.3572e-04,  1.8066e-04,  7.2054e-03]],

        [[ 6.6063e-03,  5.9146e-04,  9.4950e-05, -1.5658e-04,  9.4891e-04],
         [ 5.9140e-04,  3.8016e-03,  3.8866e-03,  1.9610e-05, -2.7776e-04],
         [ 9.4950e-05,  3.8866e-03,  4.0587e-03,  1.0818e-04, -1.1563e-04],
        

In [6]:
post_train_mean

tensor([[ 0.0337,  0.9436,  0.9352,  0.2040, -0.5264],
        [ 0.0760,  0.9292,  0.9505,  0.1760, -0.5417],
        [ 0.0741,  0.9307,  0.9472,  0.1814, -0.5421],
        [ 0.0910,  0.9344,  0.9485,  0.1685, -0.5505],
        [ 0.0756,  0.9298,  0.9503,  0.1756, -0.5405]])

In [7]:
mll_vals

tensor([-0.2153, -0.2085, -0.2156, -0.2120, -0.2085], grad_fn=<CopySlices>)

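We indeed observe significant differences between the model fits even though all
five models use the same training data. The differences are most visible in the
fitted embeddings and in parts of the posterior covariance, while the posterior
means and the final MLL values are much closer to each other.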
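
As a rough way to put numbers on this, the sketch below computes the range (max minus min) across the five fits for each alternative's posterior mean and for the MLL. The values are copied from the printed outputs above; the names `printed_means`, `printed_mlls`, and `value_range` are hypothetical and only used for this illustration.

```python
# Posterior means (one row per model) and MLL values, copied from the
# post_train_mean and mll_vals outputs above.
printed_means = [
    [0.0337, 0.9436, 0.9352, 0.2040, -0.5264],
    [0.0760, 0.9292, 0.9505, 0.1760, -0.5417],
    [0.0741, 0.9307, 0.9472, 0.1814, -0.5421],
    [0.0910, 0.9344, 0.9485, 0.1685, -0.5505],
    [0.0756, 0.9298, 0.9503, 0.1756, -0.5405],
]
printed_mlls = [-0.2153, -0.2085, -0.2156, -0.2120, -0.2085]


def value_range(values):
    # Spread of a collection of numbers: largest minus smallest.
    values = list(values)
    return max(values) - min(values)


# zip(*rows) iterates over columns, i.e. over alternatives.
mean_spread = [value_range(column) for column in zip(*printed_means)]
mll_spread = value_range(printed_mlls)

for alt, spread in enumerate(mean_spread):
    print(f"alternative {alt}: posterior mean range = {spread:.4f}")
print(f"MLL range = {mll_spread:.4f}")
```

The same computation can be applied to the covariance entries or to pairwise embedding distances to see where the fits disagree the most.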