In the classical RS setting, where all variables are categorical, there doesn't seem to be any
advantage to using the custom fit. This notebook considers the contextual case, where some
other experiments suggest that the custom fit may help noticeably.

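Below, we use a version of `custom_fit_gpytorch_model` that returns additional outputs for analysis: the index of the best retry and the improvement in the mll value over the first fit.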

In [4]:
from typing import Tuple

from gpytorch import ExactMarginalLogLikelihood
from torch import Tensor
from contextual_rs.lce_gp import LCEGP


import math
from collections import Callable
from copy import deepcopy
from typing import Any

import torch
from botorch.fit import _set_transformed_inputs
from botorch.optim.fit import fit_gpytorch_scipy
from botorch.optim.utils import sample_all_priors
from contextual_rs.models.lce_gp import LCEGP
from gpytorch.mlls import MarginalLogLikelihood


def custom_fit_gpytorch_model(
    mll: MarginalLogLikelihood, optimizer: Callable = fit_gpytorch_scipy, **kwargs: Any
) -> Tuple[MarginalLogLikelihood, int, float]:
    r"""
    This is a modified version of BoTorch `fit_gpytorch_model`. `fit_gpytorch_model`
    has some inconsistent behavior in fitting the embedding weights in LCEGP.
    The idea here is to get around this issue by aiming for a global fit: the model
    is fitted ``num_retries`` times from different initializations and the
    parameters achieving the highest mll value are kept.

    Args:
        mll: The marginal log-likelihood of the model. To be maximized.
        optimizer: The optimizer for optimizing the mll starting from an
            initialization of model parameters. It is called as
            ``optimizer(mll, track_iterations=False, **kwargs)`` and must return
            ``(mll, info_dict)`` with the final (negated mll) loss in
            ``info_dict["fopt"]``.
        **kwargs: Optional arguments. The following are consumed here; all other
            entries are forwarded to ``optimizer``.
            num_retries: Number of fits to record. Default 1.
            max_error_tries: Total number (across all retries) of NaN-valued fits
                that are re-attempted before a retry is recorded as failed with an
                mll value of ``-inf``. Default 10.
            randn_factor: If 0, each embedding layer is replaced by a freshly
                initialized ``torch.nn.Embedding`` on re-initialization; otherwise
                the embedding weights are resampled as ``randn_factor * N(0, 1)``.
                Default 0.1.

    Returns:
        A tuple ``(mll, best_idx, improvement)``:
        - the mll in eval mode, loaded with the best parameters found;
        - the index of the retry that achieved the highest mll value;
        - the best mll value minus the mll value of the first retry. This is
          ``inf`` if only the first retry failed and ``nan`` if every retry failed.
    """
    assert isinstance(mll.model, LCEGP), "Only supports LCEGP!"
    num_retries = kwargs.pop("num_retries", 1)
    max_error_tries = kwargs.pop("max_error_tries", 10)
    randn_factor = kwargs.pop("randn_factor", 0.1)
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    state_dict_list = list()
    mll_values = torch.zeros(num_retries)
    retry = 0
    error_count = 0
    while retry < num_retries:
        # Only the very first attempt uses the model's own initialization. Every
        # later attempt resets to the original state and re-randomizes. That
        # includes a re-run after a NaN fit on the first retry, which would
        # otherwise restart from the parameters that produced the NaN.
        if retry > 0 or error_count > 0:
            mll.model.load_state_dict(original_state_dict)
            # randomize the embedding as well, reinitializing here.
            # two alternatives for initialization, specified by passing randn_factor
            for i, emb_layer in enumerate(mll.model.emb_layers):
                if randn_factor == 0:
                    new_emb = torch.nn.Embedding(
                        emb_layer.num_embeddings,
                        emb_layer.embedding_dim,
                        max_norm=emb_layer.max_norm,
                    ).to(emb_layer.weight)
                    mll.model.emb_layers[i] = new_emb
                else:
                    new_weight = torch.randn_like(emb_layer.weight) * randn_factor
                    emb_layer.weight = torch.nn.Parameter(
                        new_weight, requires_grad=True
                    )
            sample_all_priors(mll.model)
        mll, info_dict = optimizer(mll, track_iterations=False, **kwargs)
        opt_val = info_dict["fopt"]
        if math.isnan(opt_val):
            if error_count < max_error_tries:
                # Re-attempt this retry slot from a fresh random initialization.
                error_count += 1
                continue
            # Error budget exhausted: record this retry as a failed fit.
            state_dict_list.append(original_state_dict)
            mll_values[retry] = float("-inf")
            retry += 1
            continue

        # record the fitted model and the corresponding mll value
        state_dict_list.append(deepcopy(mll.model.state_dict()))
        mll_values[retry] = -opt_val  # negate to get mll value
        retry += 1

    # pick the best among all trained models
    best_idx = int(mll_values.argmax())
    mll.model.load_state_dict(state_dict_list[best_idx])
    _set_transformed_inputs(mll=mll)
    improvement = (mll_values[best_idx] - mll_values[0]).item()
    return mll.eval(), best_idx, improvement


def test_func(X: Tensor) -> Tensor:
    assert X.dim() == 2
    context_dim = X.shape[-1] - 1
    part_1 = X[:, 0].view(-1, 1) * X[:, 1:].sum(dim=-1, keepdim=True)



def fit_on_random_model(
    randn_factor: float,
    fit_tries: int,
    num_arms: int,
    context_dim: int,
    num_train: int,
    seed: int,
) -> Tuple[int, float]:
    r"""
    Fit an LCEGP with `custom_fit_gpytorch_model` on random data and report the outcome.

    Args:
        randn_factor: Embedding re-initialization factor passed to the custom fit.
        fit_tries: Number of retries (``num_retries``) for the custom fit.
        num_arms: Number of categorical values (arms) sampled in column 0.
        context_dim: Number of continuous context columns, sampled from U[0, 1].
        num_train: Number of random training points.
        seed: Seed for ``torch.manual_seed``, making the run reproducible.

    Returns:
        ``(best_idx, improvement)`` as returned by `custom_fit_gpytorch_model`.
    """
    torch.manual_seed(seed)
    ckwargs = {"dtype": torch.double, "device": "cpu"}
    train_X = torch.cat(
        [
            # arm index as a double-valued column so it can be concatenated
            torch.randint(0, num_arms, (num_train, 1), **ckwargs),
            torch.rand(num_train, context_dim, **ckwargs),
        ],
        dim=-1,
    )
    # One standard-normal outcome per training point, matching train_X's dtype.
    # The original `torch.randn()` passed no size.
    train_Y = torch.randn(num_train, 1, **ckwargs)

    model = LCEGP(train_X, train_Y, categorical_cols=[0])
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    _, best_idx, improvement = custom_fit_gpytorch_model(
        mll, num_retries=fit_tries, randn_factor=randn_factor
    )
    return best_idx, improvement


def multi_fit(
    replications: int,
    **kwargs,
) -> Tuple[float, float]:
    r"""
    Runs fit_on_random_model multiple times and returns the fraction of time the
    best_idx is non-zero and the average improvement over simple fit.

    Args:
        replications: Number of runs; run ``i`` uses ``seed=i``.
        **kwargs: Forwarded to `fit_on_random_model` (must not include ``seed``).

    Returns:
        ``(fraction, mean_improvement)`` as Python floats. Both are ``nan`` when
        ``replications`` is 0.
    """
    best_idcs = torch.zeros(replications)
    improvements = torch.zeros(replications)
    for seed in range(replications):
        best_idcs[seed], improvements[seed] = fit_on_random_model(**kwargs, seed=seed)
    # A non-zero best_idx means some re-initialized retry beat the first fit.
    fraction = (best_idcs != 0).to(torch.float).mean().item()
    return fraction, improvements.mean().item()

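Now that the setup is ready, let's compare several alternatives under different settings.

The alternatives are values of `randn_factor`, which controls how the embedding weights are re-initialized on each retry: 0 (default `torch.nn.Embedding` initialization), 1.0, 0.5, 0.1, and 0.05.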

In [5]:
# Candidate `randn_factor` values for re-initializing the embedding weights.
alternatives = [
    0,  # replace each embedding layer with a freshly initialized torch.nn.Embedding
    1.0,
    0.5,
    0.1,
    0.05,
]

Alternative 0, fraction 0.25, improvement 0.004703551530838013
Alternative 1.0, fraction 0.25, improvement 0.004703551530838013
Alternative 0.5, fraction 0.25, improvement 0.004703551530838013
Alternative 0.1, fraction 0.25, improvement 0.004703551530838013
Alternative 0.05, fraction 0.25, improvement 0.004703551530838013


In [None]:
kwargs = {
    "fit_tries": 10,
    # fit_on_random_model takes `num_arms`; `num_alternatives` raised TypeError.
    "num_arms": 5,
    # NOTE(review): context_dim is required by fit_on_random_model but was not
    # set in this cell; 1 is a placeholder -- adjust to the intended setting.
    "context_dim": 1,
    "num_train": 10,
    "replications": 20,
}

for randn_factor in alternatives:
    fraction, improvement = multi_fit(**kwargs, randn_factor=randn_factor)
    print(f"Alternative {randn_factor}, fraction {fraction}, improvement {improvement}")


In [6]:
kwargs = {
    "fit_tries": 10,
    # fit_on_random_model takes `num_arms`; `num_alternatives` raised TypeError.
    "num_arms": 10,
    # NOTE(review): context_dim is required by fit_on_random_model but was not
    # set in this cell; 1 is a placeholder -- adjust to the intended setting.
    "context_dim": 1,
    "num_train": 10,
    "replications": 20,
}

for randn_factor in alternatives:
    fraction, improvement = multi_fit(**kwargs, randn_factor=randn_factor)
    print(f"Alternative {randn_factor}, fraction {fraction}, improvement {improvement}")


Alternative 0, fraction 0.699999988079071, improvement 0.007884478196501732
Alternative 1.0, fraction 0.699999988079071, improvement 0.007884478196501732
Alternative 0.5, fraction 0.699999988079071, improvement 0.007884478196501732
Alternative 0.1, fraction 0.699999988079071, improvement 0.007884478196501732
Alternative 0.05, fraction 0.699999988079071, improvement 0.007884478196501732


In [7]:
kwargs = {
    "fit_tries": 10,
    # fit_on_random_model takes `num_arms`; `num_alternatives` raised TypeError.
    "num_arms": 20,
    # NOTE(review): context_dim is required by fit_on_random_model but was not
    # set in this cell; 1 is a placeholder -- adjust to the intended setting.
    "context_dim": 1,
    "num_train": 5,
    "replications": 20,
}

for randn_factor in alternatives:
    fraction, improvement = multi_fit(**kwargs, randn_factor=randn_factor)
    print(f"Alternative {randn_factor}, fraction {fraction}, improvement {improvement}")


Alternative 0, fraction 0.8500000238418579, improvement 0.009538346901535988
Alternative 1.0, fraction 0.8999999761581421, improvement 0.012168830260634422
Alternative 0.5, fraction 0.949999988079071, improvement 0.012226665392518044
Alternative 0.1, fraction 0.949999988079071, improvement 0.012226665392518044
Alternative 0.05, fraction 0.949999988079071, improvement 0.012226665392518044


In [8]:
kwargs = {
    "fit_tries": 10,
    # fit_on_random_model takes `num_arms`; `num_alternatives` raised TypeError.
    "num_arms": 40,
    # NOTE(review): context_dim is required by fit_on_random_model but was not
    # set in this cell; 1 is a placeholder -- adjust to the intended setting.
    "context_dim": 1,
    "num_train": 5,
    "replications": 20,
}

for randn_factor in alternatives:
    fraction, improvement = multi_fit(**kwargs, randn_factor=randn_factor)
    print(f"Alternative {randn_factor}, fraction {fraction}, improvement {improvement}")

Alternative 0, fraction 0.6499999761581421, improvement 0.005162292625755072
Alternative 1.0, fraction 0.949999988079071, improvement 0.010405289940536022
Alternative 0.5, fraction 1.0, improvement 0.011099308729171753
Alternative 0.1, fraction 1.0, improvement 0.011149173602461815
Alternative 0.05, fraction 1.0, improvement 0.011149173602461815


In [9]:
kwargs = {
    "fit_tries": 10,
    # fit_on_random_model takes `num_arms`; `num_alternatives` raised TypeError.
    "num_arms": 10,
    # NOTE(review): context_dim is required by fit_on_random_model but was not
    # set in this cell; 1 is a placeholder -- adjust to the intended setting.
    "context_dim": 1,
    "num_train": 30,
    "replications": 20,
}

for randn_factor in alternatives:
    fraction, improvement = multi_fit(**kwargs, randn_factor=randn_factor)
    print(f"Alternative {randn_factor}, fraction {fraction}, improvement {improvement}")

Alternative 0, fraction 0.6499999761581421, improvement 0.0036636709701269865
Alternative 1.0, fraction 0.6499999761581421, improvement 0.0036636709701269865
Alternative 0.5, fraction 0.6499999761581421, improvement 0.0036636709701269865
Alternative 0.1, fraction 0.6499999761581421, improvement 0.0036636709701269865
Alternative 0.05, fraction 0.6499999761581421, improvement 0.0036636709701269865
