The goal of this notebook is to get a rough idea of how much improvement could be
gained by developing a Li-type approximation to PCS.

Li takes advantage of the IID normal structure to easily obtain an estimate based on
the volume of a sphere. If we were to attempt a similar idea here, we would be dealing
with ellipsoids and Cholesky factors instead.

We need to take a more careful look at what is necessary to get such an approximation,
but let's start by comparing the time to draw a single sample vs. multiple samples
using the empirical PCS code. It is possible that, by avoiding the sorting of many
samples, we would gain significant computational savings.

In [1]:
import resource
import time

import torch
from botorch import fit_gpytorch_model
from gpytorch import ExactMarginalLogLikelihood
from torch import Tensor
from contextual_rs.models.lce_gp import LCEGP
from contextual_rs.generalized_pcs import estimate_current_generalized_pcs

ckwargs = {"dtype": torch.double, "device": "cpu"}


def clock2():
    """
    clock2() -> (t_user, t_system)
    Report the CPU time used so far by the current process as a
    (user-mode, system-mode) tuple.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    # ru_utime and ru_stime are the first two fields of the rusage struct.
    return usage.ru_utime, usage.ru_stime


def sine_test(X: Tensor) -> Tensor:
    """Toy test function: sum of sin(10 * x) over the last dimension of X.

    The last dimension is reduced but kept with size 1, so the output has the
    same number of dimensions as the input.
    """
    scaled = X * 10.0
    return scaled.sin().sum(dim=-1, keepdim=True)

# Problem setup: a small test problem for LCEGP, trained on every
# (arm, context) pair, with the full grid repeated `num_full_train` times.
dim_x = 1
context_dim = 2
num_arms = 6
num_contexts = 10
num_full_train = 3
# arms are labelled 0, ..., num_arms - 1 and stored as a column
arm_set = torch.arange(num_arms, **ckwargs).unsqueeze(-1)
# contexts are drawn uniformly at random from the unit cube
context_map = torch.rand(num_contexts, context_dim, **ckwargs)
# broadcast arms and contexts to a num_arms x num_contexts grid, join them
# along the feature dimension, flatten the grid and replicate it
train_X = torch.cat(
    [
        arm_set.unsqueeze(-2).expand(-1, num_contexts, -1),
        context_map.unsqueeze(0).expand(num_arms, -1, -1),
    ],
    dim=-1,
).view(-1, 1 + context_dim).repeat(num_full_train, 1)
# build the model (column 0, the arm label, is categorical) and fit it
model = LCEGP(train_X, sine_test(train_X), categorical_cols=[0])
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)

%load_ext memory_profiler


def estimate_w_samples(
    num_samples: int, replications: int,
):
    """
    Time `estimate_current_generalized_pcs` on the module-level `model`.

    Each replication toggles the model through `train()` / `eval()`, seeds
    torch with the replication index, and times one call to the estimator
    over the module-level `arm_set` and `context_map`.

    Args:
        num_samples: Forwarded as `num_samples` to
            `estimate_current_generalized_pcs`.
        replications: Number of timed repetitions to average over.

    Returns:
        A 4-tuple of scalar tensors holding the per-replication means of
        the wall time, the user CPU time, the system CPU time, and the
        total (user + system) CPU time.
    """
    wall_times = torch.zeros(replications, **ckwargs)
    user_cts = torch.zeros(replications, **ckwargs)
    sys_cts = torch.zeros(replications, **ckwargs)
    for seed in range(replications):
        # clean caches
        model.train()
        model.eval()
        torch.manual_seed(seed)
        # start times; perf_counter is monotonic and high resolution, so it
        # is better suited to timing short intervals than time.time()
        wt_start = time.perf_counter()
        ct_start = clock2()
        # run
        estimate_current_generalized_pcs(
            model=model,
            arm_set=arm_set,
            context_set=context_map,
            num_samples=num_samples,
            base_samples=None,
            func_I=lambda X: (X > 0).to(**ckwargs),
            rho=lambda X: X.mean(dim=-2),
        )
        # record times
        ct_end = clock2()
        wt_end = time.perf_counter()
        wall_times[seed] = wt_end - wt_start
        user_cts[seed] = ct_end[0] - ct_start[0]
        sys_cts[seed] = ct_end[1] - ct_start[1]
    return (
        wall_times.mean(),
        user_cts.mean(),
        sys_cts.mean(),
        user_cts.mean() + sys_cts.mean()
    )

In [2]:
for num_samples in [1, 16, 64, 256, 1024, 2**14 ]:
    print(f"Num samples {num_samples}")
    %memit times = estimate_w_samples(num_samples, 1000)
    print(
        f"Wall time {'{:.3f}'.format(float(times[0]))}, "
        f"user time {'{:.3f}'.format(float(times[1]))}, "
        f"sys time {'{:.3f}'.format(float(times[2]))}, "
        f"total cpu time {'{:.3f}'.format(float(times[3]))}"
    )

Num samples 1
peak memory: 280.56 MiB, increment: 16.58 MiB
Wall time 0.008, user time 0.038, sys time 0.001, total cpu time 0.039
Num samples 16
peak memory: 290.91 MiB, increment: 11.72 MiB
Wall time 0.008, user time 0.036, sys time 0.000, total cpu time 0.036
Num samples 64
peak memory: 300.08 MiB, increment: 9.16 MiB
Wall time 0.008, user time 0.038, sys time 0.001, total cpu time 0.039
Num samples 256
peak memory: 310.19 MiB, increment: 11.34 MiB
Wall time 0.008, user time 0.045, sys time 0.001, total cpu time 0.046
Num samples 1024
peak memory: 323.76 MiB, increment: 14.73 MiB
Wall time 0.010, user time 0.052, sys time 0.002, total cpu time 0.054
Num samples 16384
peak memory: 396.96 MiB, increment: 73.20 MiB
Wall time 0.041, user time 0.109, sys time 0.003, total cpu time 0.112


These timings are a little suspicious. I'd expect the difference between drawing a single sample and drawing many samples to be larger.

So, let's bisect the method and measure purely the sampling time.

In [3]:
# Build every (arm, context) combination: broadcast the arms and the contexts
# to a common arms x contexts grid, join them along the last dimension, and
# flatten the grid into one row per pair.
arm_context_pairs = torch.cat(
    [
        arm_set[:, None, :].expand(-1, context_map.shape[0], -1),
        context_map[None].expand(arm_set.shape[0], -1, -1),
    ],
    dim=-1,
).reshape(num_arms * num_contexts, -1)


def estimate_w_samples_only(
    num_samples: int, replications: int,
):
    """
    Time only the posterior computation and sampling on the module-level
    `model`, evaluated at the module-level `arm_context_pairs`.

    Each replication toggles the model through `train()` / `eval()`, seeds
    torch with the replication index, and times one `model.posterior` call
    followed by one `rsample` call.

    Args:
        num_samples: Number of posterior samples drawn per replication
            (used as the `sample_shape` of `rsample`).
        replications: Number of timed repetitions to average over.

    Returns:
        A 4-tuple of scalar tensors holding the per-replication means of
        the wall time, the user CPU time, the system CPU time, and the
        total (user + system) CPU time.
    """
    wall_times = torch.zeros(replications, **ckwargs)
    user_cts = torch.zeros(replications, **ckwargs)
    sys_cts = torch.zeros(replications, **ckwargs)
    for seed in range(replications):
        # clean caches
        model.train()
        model.eval()
        torch.manual_seed(seed)
        # start times; perf_counter is monotonic and high resolution, so it
        # is better suited to timing short intervals than time.time()
        wt_start = time.perf_counter()
        ct_start = clock2()
        # run; only the cost of drawing matters, the samples are not used
        posterior = model.posterior(arm_context_pairs)
        y_samples = posterior.rsample(
            sample_shape=torch.Size([num_samples])
        )
        # record times
        ct_end = clock2()
        wt_end = time.perf_counter()
        wall_times[seed] = wt_end - wt_start
        user_cts[seed] = ct_end[0] - ct_start[0]
        sys_cts[seed] = ct_end[1] - ct_start[1]
    return (
        wall_times.mean(),
        user_cts.mean(),
        sys_cts.mean(),
        user_cts.mean() + sys_cts.mean()
    )

In [4]:
for num_samples in [1, 16, 64, 256, 1024, 2**14 ]:
    print(f"Num samples {num_samples}")
    %memit times = estimate_w_samples_only(num_samples, 1000)
    print(
        f"Wall time {'{:.3f}'.format(float(times[0]))}, "
        f"user time {'{:.3f}'.format(float(times[1]))}, "
        f"sys time {'{:.3f}'.format(float(times[2]))}, "
        f"total cpu time {'{:.3f}'.format(float(times[3]))}"
    )

Num samples 1
peak memory: 390.88 MiB, increment: 7.62 MiB
Wall time 0.007, user time 0.034, sys time 0.001, total cpu time 0.034
Num samples 16
peak memory: 399.79 MiB, increment: 8.77 MiB
Wall time 0.007, user time 0.035, sys time 0.001, total cpu time 0.035
Num samples 64
peak memory: 408.81 MiB, increment: 8.93 MiB
Wall time 0.007, user time 0.034, sys time 0.001, total cpu time 0.034
Num samples 256
peak memory: 417.70 MiB, increment: 8.89 MiB
Wall time 0.008, user time 0.042, sys time 0.001, total cpu time 0.043
Num samples 1024
peak memory: 426.74 MiB, increment: 8.98 MiB
Wall time 0.009, user time 0.048, sys time 0.001, total cpu time 0.048
Num samples 16384
peak memory: 500.28 MiB, increment: 73.54 MiB
Wall time 0.038, user time 0.099, sys time 0.003, total cpu time 0.103
