In this notebook, we compare Adam and L-BFGS for fitting LCEGP. We also vary
the number of iterations allowed to see how that affects the fit.

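How should we compare the fitting alternatives?
One option is the final loss reported by each optimizer (`fopt`). BoTorch minimizes the negative marginal log-likelihood, so the values printed below as "mll" are negative MLLs and lower is better. The MLL is a flawed metric, but it is better than nothing.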
Use `optimizer_kwargs={"options": {"maxiter": maxiter_}}` to limit the number of iterations.

In [1]:
from typing import Callable
import torch
from botorch import fit_gpytorch_model
from botorch.optim.fit import fit_gpytorch_scipy, fit_gpytorch_torch
from botorch.models import SingleTaskGP
from gpytorch import ExactMarginalLogLikelihood

from contextual_rs.models.lce_gp import LCEGP

# Tensor options shared by every tensor the helpers below create.
# NOTE(review): BoTorch generally recommends torch.double for GP fitting;
# float32 may make L-BFGS stop early — consider switching and re-running.
ckwargs = dict(device="cpu", dtype=torch.float)

def fit_baseline_gp(
    num_train: int,
    dim: int,
) -> SingleTaskGP:
    """Fit a SingleTaskGP to purely random data, giving a random test function.

    Inputs are drawn uniformly from the unit cube and targets from a standard
    normal, both from the global torch RNG (callers seed it beforehand).

    Args:
        num_train: Number of random training points.
        dim: Input dimension of the GP.

    Returns:
        The SingleTaskGP after its hyperparameters have been fit.
    """
    inputs = torch.rand(num_train, dim, **ckwargs)
    # Drawn after the inputs so the RNG consumption order stays fixed.
    targets = torch.randn(num_train, 1, **ckwargs)
    gp = SingleTaskGP(inputs, targets)
    fit_gpytorch_model(ExactMarginalLogLikelihood(gp.likelihood, gp))
    return gp


def fit_on_baselines(
    d_cont: int,
    num_cat: int,
    num_cont: int,
    num_baseline_train: int,
    replications: int,
    optimizer: Callable,
    optimizer_kwargs: dict,
):
    """Fit LCEGP on data generated by random baseline GPs and record fit stats.

    For each replication ``seed`` in ``range(replications)``:

    1. Seed the global torch RNG with ``seed``.
    2. Draw a random baseline function via ``fit_baseline_gp``.
    3. Build training inputs whose column 0 is a category label and whose
       remaining ``d_cont`` columns are uniform random.
    4. Label the inputs with the baseline's posterior mean.
    5. Fit an LCEGP with ``optimizer``.

    Args:
        d_cont: Number of continuous input dimensions.
        num_cat: Number of categories; labels are ``0, ..., num_cat - 1``.
        num_cont: Number of random continuous points per category.
        num_baseline_train: Number of training points for each baseline GP.
        replications: Number of independent replications.
        optimizer: Fit routine, called as
            ``optimizer(mll, track_iterations=True, **optimizer_kwargs)``.
            It must return ``(mll, info_dict)``, where ``info_dict`` has
            ``"fopt"`` and ``"iterations"`` entries.
        optimizer_kwargs: Extra keyword arguments forwarded to ``optimizer``.

    Returns:
        A tuple ``(neg_mlls, iters)`` of 1-D tensors of length
        ``replications``.

        - ``neg_mlls[seed]`` is ``info_dict["fopt"]``, the final value of the
          loss the optimizer minimized. BoTorch minimizes the negative MLL, so
          lower is presumably better. NOTE(review): confirm for the installed
          BoTorch version.
        - ``iters[seed]`` is the number of tracked optimizer iterations.

    Note:
        This mutates the global torch RNG state through ``torch.manual_seed``.
    """
    neg_mlls = torch.zeros(replications, **ckwargs)
    iters = torch.zeros(replications, **ckwargs)
    # Category labels shaped (num_cat, 1, 1) so they broadcast over points.
    # Built once: it is loop-invariant and consumes no randomness.
    cats = torch.arange(num_cat, **ckwargs).view(num_cat, 1, 1)
    for seed in range(replications):
        torch.manual_seed(seed)
        baseline = fit_baseline_gp(num_baseline_train, d_cont + 1)
        # Shape (num_cat, num_cont, 1 + d_cont); column 0 holds the category.
        train_X = torch.cat(
            [
                cats.expand(-1, num_cont, -1),
                torch.rand(num_cat, num_cont, d_cont, **ckwargs),
            ],
            dim=-1,
        )
        # The baseline was trained on the unit cube, so rescale the category
        # labels into [0, 1) before evaluating it.
        train_X_eval = train_X.clone()
        train_X_eval[..., 0] = train_X_eval[..., 0] / num_cat
        with torch.no_grad():
            train_Y = baseline.posterior(train_X_eval).mean
        # NOTE(review): train_X (and presumably train_Y) are 3-D here. LCEGP
        # therefore likely builds a batch of num_cat models, each seeing a
        # single category, which prevents sharing information across
        # categories. If one joint model over all num_cat * num_cont points is
        # intended, flatten with .view(-1, d_cont + 1) / .view(-1, 1).
        # TODO confirm against LCEGP.
        model = LCEGP(train_X, train_Y, [0])
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
        mll, info_dict = optimizer(mll, track_iterations=True, **optimizer_kwargs)
        neg_mlls[seed] = info_dict["fopt"]
        iters[seed] = len(info_dict["iterations"])
    return neg_mlls, iters

Let's start with a quick test run (3 replications per optimizer).

In [2]:
# Quick test: run both optimizers on the same small setup, 3 replications each.
quick_setup = dict(
    d_cont=2,
    num_cat=6,
    num_cont=10,
    num_baseline_train=20,
    replications=3,
)

lbfgs_mlls, lbfgs_iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy,
    optimizer_kwargs={},
    **quick_setup,
)
adam_mlls, adam_iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch,
    optimizer_kwargs={"options": {"disp": False}},
    **quick_setup,
)

print(f"LBFGS mlls: {lbfgs_mlls}")
print(f"Adam mlls: {adam_mlls}")
print(f"LBFGS iters: {lbfgs_iters}")
print(f"Adam iters: {adam_iters}")

LBFGS mlls: tensor([26.4275, 26.1508, 26.8268])
Adam mlls: tensor([26.9916, 28.1221, 28.3788])
LBFGS iters: tensor([58., 53., 60.])
Adam iters: tensor([50., 39., 39.])


In [3]:
# Experiment setup shared by the two fitting cells that follow.
kwargs = dict(
    d_cont=2,
    num_cat=6,
    num_cont=10,
    num_baseline_train=20,
    replications=30,
)

In [4]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 26.39801025390625, std: 0.5533416867256165
LBFGS iters, avg: 79.76667022705078, std: 32.117977142333984
CPU times: user 2min 46s, sys: 595 ms, total: 2min 47s
Wall time: 35.7 s


In [5]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 27.490205764770508, std: 0.6860774755477905
Adam iters, avg: 47.70000076293945, std: 14.960062026977539
CPU times: user 54.3 s, sys: 197 ms, total: 54.5 s
Wall time: 11.7 s


In [6]:
# Experiment setup shared by the two fitting cells that follow.
kwargs = dict(
    d_cont=2,
    num_cat=10,
    num_cont=10,
    num_baseline_train=30,
    replications=30,
)

In [7]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 66.84286499023438, std: 0.4117429256439209
LBFGS iters, avg: 30.233333587646484, std: 5.49409294128418
CPU times: user 1min 27s, sys: 295 ms, total: 1min 27s
Wall time: 18.8 s


In [8]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 67.77068328857422, std: 1.0745059251785278
Adam iters, avg: 91.86666870117188, std: 21.090499877929688
CPU times: user 1min 31s, sys: 271 ms, total: 1min 31s
Wall time: 19.8 s


In [9]:
# Experiment setup shared by the two fitting cells that follow.
kwargs = dict(
    d_cont=1,
    num_cat=10,
    num_cont=20,
    num_baseline_train=20,
    replications=30,
)

In [10]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 17.527326583862305, std: 3.4920310974121094
LBFGS iters, avg: 15.366666793823242, std: 5.162753582000732
CPU times: user 1min 6s, sys: 1.45 s, total: 1min 7s
Wall time: 16.8 s


In [11]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 17.032188415527344, std: 2.3136327266693115
Adam iters, avg: 65.5999984741211, std: 13.745469093322754
CPU times: user 1min 2s, sys: 1.29 s, total: 1min 3s
Wall time: 15.7 s


In [12]:
# Experiment setup shared by the two fitting cells that follow.
kwargs = dict(
    d_cont=4,
    num_cat=10,
    num_cont=20,
    num_baseline_train=20,
    replications=30,
)

In [13]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 35.880760192871094, std: 0.2730153799057007
LBFGS iters, avg: 47.266666412353516, std: 8.08546257019043
CPU times: user 1min 57s, sys: 2.34 s, total: 1min 59s
Wall time: 28.8 s


In [14]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 35.91973876953125, std: 0.2560601532459259
Adam iters, avg: 99.69999694824219, std: 1.6431676149368286
CPU times: user 1min 40s, sys: 2.08 s, total: 1min 42s
Wall time: 24.7 s


In [15]:
# Experiment setup shared by the two fitting cells that follow.
kwargs = dict(
    d_cont=4,
    num_cat=5,
    num_cont=20,
    num_baseline_train=40,
    replications=30,
)

In [16]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 9.431946754455566, std: 1.287776231765747
LBFGS iters, avg: 166.56666564941406, std: 156.10842895507812
CPU times: user 5min 27s, sys: 1.12 s, total: 5min 29s
Wall time: 1min 10s


In [17]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 11.157636642456055, std: 0.22094941139221191
Adam iters, avg: 52.766666412353516, std: 4.538595676422119
CPU times: user 1min 7s, sys: 240 ms, total: 1min 7s
Wall time: 14.6 s


In [18]:
# Experiment setup shared by the two fitting cells that follow.
kwargs = dict(
    d_cont=2,
    num_cat=5,
    num_cont=20,
    num_baseline_train=30,
    replications=30,
)

In [19]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 6.512121200561523, std: 2.445747137069702
LBFGS iters, avg: 55.0, std: 44.02742004394531
CPU times: user 2min 29s, sys: 436 ms, total: 2min 29s
Wall time: 32.4 s


In [20]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 10.719143867492676, std: 0.4961346983909607
Adam iters, avg: 44.70000076293945, std: 1.0221680402755737
CPU times: user 57 s, sys: 253 ms, total: 57.2 s
Wall time: 12.4 s


In [21]:
# Experiment setup shared by the fitting cells that follow (including the
# iteration-capped Adam runs).
kwargs = dict(
    d_cont=1,
    num_cat=5,
    num_cont=10,
    num_baseline_train=25,
    replications=30,
)

In [22]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_scipy, optimizer_kwargs=dict(), **kwargs
)
print(f"LBFGS mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"LBFGS iters, avg: {iters.mean()}, std: {iters.std()}")

LBFGS mll, avg: 13.640057563781738, std: 1.6636030673980713
LBFGS iters, avg: 24.96666717529297, std: 12.890743255615234
CPU times: user 1min 15s, sys: 229 ms, total: 1min 16s
Wall time: 17.3 s


In [23]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch, optimizer_kwargs={"options": {"disp": False}}, **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 14.872940063476562, std: 1.8696917295455933
Adam iters, avg: 56.66666793823242, std: 10.927135467529297
CPU times: user 55.3 s, sys: 123 ms, total: 55.4 s
Wall time: 12.3 s


In [25]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch,
    optimizer_kwargs={"options": {"disp": False, "maxiter": 50}},
    **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 14.93786907196045, std: 1.8196450471878052
Adam iters, avg: 49.13333511352539, std: 1.1665846109390259
CPU times: user 48.3 s, sys: 170 ms, total: 48.5 s
Wall time: 11.1 s


In [26]:
%%time
mlls, iters = fit_on_baselines(
    optimizer=fit_gpytorch_torch,
    optimizer_kwargs={"options": {"disp": False, "maxiter": 25}},
    **kwargs
)
print(f"Adam mll, avg: {mlls.mean()}, std: {mlls.std()}")
print(f"Adam iters, avg: {iters.mean()}, std: {iters.std()}")

Adam mll, avg: 21.671274185180664, std: 0.2634010314941406
Adam iters, avg: 25.0, std: 0.0
CPU times: user 33 s, sys: 95.2 ms, total: 33.1 s
Wall time: 7.45 s
