In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import botorch

# Notebook-wide seaborn theme: 'whitegrid' style with fonts scaled by 1.75.
sns.set(style='whitegrid', font_scale=1.75)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
import torchsort

import sys
sys.path.append('..')
from experiments.std_bayesopt.helpers import *

In [None]:
def obj_fn(x):
    """Noise-free toy objective: ``max(-0.125 * x**2 + 16 * sin(x), 0)``.

    Works element-wise on NumPy arrays (and on plain floats); the result is
    clipped from below at zero by ``np.maximum``.
    """
    return np.maximum(-0.125 * x ** 2 + 16 * np.sin(x), 0)

In [None]:
# Sample the objective on a 64-point grid between the bounds and add
# Gaussian noise with standard deviation `noise_scale`.
noise_scale = 1.
x_bounds = torch.tensor([[-16.], [16.]])

x = np.linspace(*x_bounds, 64)
f = obj_fn(x)
y = f + noise_scale * np.random.randn(*f.shape)

# Noisy observations (hollow markers) drawn over the noise-free curve (dashed).
fig = plt.figure(figsize=(8, 5))
plt.scatter(
    x, y,
    s=64,
    facecolors='none',
    edgecolors='black',
    label='observations',
    zorder=3,
)
plt.plot(
    x, f,
    color='black',
    linestyle='--',
    linewidth=2,
    label='ground truth',
    zorder=2,
)

plt.xlabel('x')
plt.ylabel('y')
plt.ylim(-4, 24)
plt.legend(loc='upper left', ncol=2)

In [None]:
import botorch
from botorch.fit import fit_gpytorch_model
from botorch.acquisition.analytic import UpperConfidenceBound, ExpectedImprovement
from botorch.optim.optimize import optimize_acqf
from botorch.sampling import IIDNormalSampler

from gpytorch.mlls import ExactMarginalLogLikelihood

In [None]:
from lambo.utils import DataSplit, update_splits
from lambo.optimizers.pymoo import Normalizer

# Largest index returned by np.where for grid points with x < -8; only the
# data before this index is handed to the optimiser as its starting set.
cutoff = np.max(np.where(x < -8))

# Normalizers centred on the mid-range with the half-range as scale.
# NOTE(review): presumably this maps inputs/targets into [-1, 1], matching
# `input_bounds` below -- confirm against lambo's Normalizer.
x_min = x.min(0)
y_min = y.min(0)
x_range = x.max(0) - x_min
y_range = y.max(0) - y_min

x_norm = Normalizer(loc=x_min + 0.5 * x_range, scale=x_range / 2.)
y_norm = Normalizer(loc=y_min + 0.5 * y_range, scale=y_range / 2.)

train_x, train_y = x[:cutoff], y[:cutoff]

# Normalised inputs/targets as column tensors on DEVICE.
all_inputs = torch.tensor(x_norm(x), device=DEVICE).view(-1, 1)
all_targets = torch.tensor(y_norm(y), device=DEVICE).view(-1, 1)
target_dim = all_targets.shape[-1]

# Regression data: the first `cutoff` normalised points, distributed over
# train/val/test by update_splits with holdout_ratio=0.2.
new_split = DataSplit(
    all_inputs[:cutoff].cpu().numpy(),
    all_targets[:cutoff].cpu().numpy(),
)
train_split, val_split, test_split = update_splits(
    DataSplit(), DataSplit(), DataSplit(), new_split, holdout_ratio=0.2
)

# Classifier (ratio-estimator) data: built from empty splits only.
cls_train_split, cls_val_split, cls_test_split = update_splits(
    DataSplit(), DataSplit(), DataSplit(), DataSplit(), holdout_ratio=0.2
)

# Lower/upper search bounds in normalised input space, shape (2, 1).
input_bounds = torch.tensor([[-1.], [1.]], device=DEVICE)

In [None]:
def run_datashift_opt(splits, cls_splits=None, acqf="ei"):
    """Run a Bayesian-optimisation loop on the noisy 1-D toy objective and plot its progress.

    Each of the 32 rounds fits a ``ConformalSingleTaskGP`` to the current training
    split, builds the requested acquisition function, optimises it over
    ``input_bounds`` for a single query, evaluates the noisy objective there and
    feeds the new observation back through ``update_splits``. Every 8th round the
    posterior, acquisition values, query and training data are drawn on a new figure.

    Relies on notebook-level names: ``all_inputs``, ``f``, ``x_norm``, ``y_norm``,
    ``obj_fn``, ``noise_scale``, ``input_bounds`` and ``DEVICE``.

    Args:
        splits: ``(train_split, val_split, test_split)``; ``train_split[0]`` and
            ``train_split[1]`` are used as GP training inputs and targets.
        cls_splits: ``(train, val, test)`` splits holding the ratio-classifier
            data. Required when ``acqf`` starts with ``"conformal"``; unused otherwise.
        acqf: one of ``"ei"``, ``"nei"``, ``"conformal_ei"``, ``"conformal_nei"``.

    Returns:
        ``torch.tensor`` built from the per-round query targets (normalised with ``y_norm``).

    Raises:
        ValueError: if ``acqf`` is not one of the supported names, or a conformal
            acquisition is requested without ``cls_splits``.
    """
    def draw_plot(ax):
        """Draw the current round on ``ax`` using the enclosing loop's variables."""
        # plot p(f | x, D)
        ax.plot(all_inputs.cpu(), f_hat_mean, color='blue', linewidth=2, zorder=4, label='p(f | x, D)')
        ax.fill_between(all_inputs.view(-1).cpu(), f_hat_mean - 1.96 * f_hat_std, f_hat_mean + 1.96 * f_hat_std,
                        color='blue', alpha=0.25)

        # plot a(x)
        ax.plot(all_inputs.cpu(), acq_vals, color='green', zorder=5, linewidth=2, label='a(x)')
        ax.scatter(input_query.cpu(), target_query.cpu(), marker='x', color='red', label='x*', zorder=5,
                   s=32, linewidth=2)

        # plot observed
        ax.scatter(train_inputs.cpu(), train_targets.cpu(), edgecolors='black', facecolors='black',
                   label='D', s=32, zorder=3)

        # plot true function
        ax.plot(all_inputs.cpu(), y_norm(f), color='black', linestyle='--', label='f')

        ax.set_ylim((-2., 2.))

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        return ax

    train_split, val_split, test_split = splits

    # Every conformal acquisition updates the classifier splits in the main loop,
    # so unpack them up front instead of only in the 'conformal_ei' branch.
    is_conformal = acqf.startswith('conformal')
    if is_conformal:
        if cls_splits is None:
            raise ValueError(f"acqf={acqf!r} requires cls_splits=(train, val, test)")
        cls_train_split, cls_val_split, cls_test_split = cls_splits

    optimize_callback = None
    rx_estimator = None
    # NOTE: every acqf_init ignores its `best_f` argument and reads `train_targets`
    # / `train_inputs` from the enclosing scope at call time.
    if acqf == 'ei':
        acqf_init = lambda gp, best_f: qExpectedImprovement(
            gp,
            best_f=train_targets.max(0)[0],
            sampler=IIDNormalSampler(64)
        )
    elif acqf == "nei":
        acqf_init = lambda gp, best_f: qNoisyExpectedImprovement(gp, X_baseline=train_inputs, sampler=IIDNormalSampler(64))
    elif acqf == 'conformal_ei':
        def acqf_init(gp, best_f):
            gp.conformal()
            return qConformalExpectedImprovement(
                gp,
                best_f=train_targets.max(0)[0],
                sampler=PassSampler(32),
                cache_root=False
            )

        ######### Ratio Estimator ###########

        ## Remains uniform, when untrained.
        classifier = nn.Sequential(
            nn.Linear(train_split.inputs.shape[-1], 1),
        ).to(DEVICE)
        for p in classifier.parameters():
            p.data.fill_(0)
        optim = torch.optim.Adam(classifier.parameters(), lr=1e-3)
        criterion = torch.nn.BCEWithLogitsLoss()

        class _RatioEstimator(nn.Module):
            """Turns the classifier probability p into the odds p / (1 - p)."""

            def forward(self, inputs):
                _p = classifier(inputs).squeeze(-1).sigmoid()
                ## FIXME: adjust for class priors?
                return _p / (1 - _p + 1e-6)
        rx_estimator = _RatioEstimator()

        class RatioDataset(torch.utils.data.Dataset):
            """Live view of the enclosing function's current ``cls_train_split``."""

            def __len__(self):
                return len(cls_train_split.inputs)

            def __getitem__(self, index):
                return cls_train_split.inputs[index], cls_train_split.targets[index]
        rx_dataset = RatioDataset()

        def optimize_callback(xk):
            """Record the optimizer iterate as a class-0 sample and take one classifier step."""
            # `nonlocal`, not `global`: the splits are locals of run_datashift_opt and
            # RatioDataset reads them from there. Rebinding the notebook globals would
            # hide these samples from the classifier and leak them into later runs.
            nonlocal cls_train_split, cls_val_split, cls_test_split

            ## FIXME: handle too many -ve samples.
            xk = torch.from_numpy(xk).reshape(-1, 1)
            # Out-of-place jitter (std 0.1): from_numpy shares memory with the
            # optimizer's array, so an in-place add_ would also perturb its iterate.
            xk = xk + .1 * torch.randn_like(xk)
            yk = torch.zeros(len(xk), 1)
            cls_train_split, cls_val_split, cls_test_split = update_splits(
                train_split=cls_train_split,
                val_split=cls_val_split,
                test_split=cls_test_split,
                new_split=DataSplit(xk, yk),
                holdout_ratio=0.2
            )

            ## One stochastic gradient step of the classifier.
            if len(rx_dataset):
                loader = torch.utils.data.DataLoader(rx_dataset, shuffle=True, batch_size=64)
                X, y = next(iter(loader))
                X, y = X.to(DEVICE).float(), y.to(DEVICE)
                optim.zero_grad()
                loss = criterion(classifier(X), y)
                loss.backward()
                optim.step()

        ######### Ratio Estimator ###########
    elif acqf == "conformal_nei":
        def acqf_init(gp, best_f):
            gp.conformal()
            return qConformalNoisyExpectedImprovement(
                gp,
                X_baseline=train_inputs,
                sampler=PassSampler(32),
                cache_root=False
            )
    else:
        raise ValueError(
            f"unknown acqf {acqf!r}; expected 'ei', 'nei', 'conformal_ei' or 'conformal_nei'"
        )

    num_rounds = 32
    plot_interval = 8

    queried_targets = []
    for round_idx in range(num_rounds):
        train_inputs = torch.tensor(train_split[0], device=DEVICE)
        train_targets = torch.tensor(train_split[1], device=DEVICE)

        # Refit the surrogate from scratch on the current training split.
        matern_gp = ConformalSingleTaskGP(
            train_X=train_inputs,
            train_Y=train_targets,
            alpha=0.1,
            tgt_grid_res=32,
            conformal_bounds=torch.tensor([[-2., 2.]]).t(),
            ratio_estimator=rx_estimator,
        ).to(DEVICE)
        mll = ExactMarginalLogLikelihood(matern_gp.likelihood, matern_gp)
        fit_gpytorch_model(mll)
        acq_fn = acqf_init(matern_gp, train_targets.max())

        # Freeze the GP and evaluate posterior / acquisition on the full grid for plotting.
        matern_gp.requires_grad_(False)
        matern_gp.eval()
        with torch.no_grad():
            f_hat_dist = matern_gp(all_inputs)
            f_hat_mean = f_hat_dist.mean.cpu()
            f_hat_std = f_hat_dist.variance.sqrt().cpu()
            acq_vals = acq_fn(all_inputs[:, None]).cpu()

        input_query = optimize_acqf(acq_fn, input_bounds, 1, num_restarts=4, raw_samples=16,
                                    options=dict(callback=optimize_callback))[0]
        # Evaluate the noisy objective in original units, then normalise the target.
        x_query = x_norm.inv_transform(input_query.cpu().numpy())
        f_query = obj_fn(x_query)
        y_query = f_query + noise_scale * np.random.randn(*f_query.shape)
        target_query = torch.tensor(y_norm(y_query), device=DEVICE)

        new_split = DataSplit(
            input_query.reshape(-1, 1).cpu(),
            target_query.reshape(-1, 1).cpu(),
        )
        train_split, val_split, test_split = update_splits(
            train_split, val_split, test_split, new_split, holdout_ratio=0.2
        )
        if is_conformal:
            # The actual query enters the classifier data as a class-1 sample.
            cls_train_split, cls_val_split, cls_test_split = update_splits(
                train_split=cls_train_split,
                val_split=cls_val_split,
                test_split=cls_test_split,
                new_split=DataSplit(
                    input_query.reshape(-1, 1).cpu(),
                    torch.ones(1, 1),
                ),
                holdout_ratio=0.2
            )

        queried_targets.append(target_query)

        if round_idx % plot_interval == 0:
            print(f'{train_split[0].shape[0]} train, {val_split[0].shape[0]} val, {test_split[0].shape[0]} test')
            if is_conformal:
                print(f'[Iterates] {cls_train_split[0].shape[0]} train, {cls_val_split[0].shape[0]} val, {cls_test_split[0].shape[0]} test')
            fig = plt.figure(figsize=(8, 5))
            ax = fig.add_subplot(1, 1, 1)
            draw_plot(ax)

    plt.show()
    return torch.tensor(queried_targets)

In [None]:
# Two independent repetitions of the optimisation loop with the conformal qEI acquisition.
conformal_pts = [
    run_datashift_opt(
        splits=[train_split, val_split, test_split],
        cls_splits=[cls_train_split, cls_val_split, cls_test_split],
        acqf="conformal_ei",
    )
    for _ in range(2)
]

In [None]:
# Two independent repetitions of the optimisation loop with the standard qEI acquisition.
std_pts = [
    run_datashift_opt(
        splits=[train_split, val_split, test_split],
        acqf="ei",
    )
    for _ in range(2)
]

In [None]:
# Stack the runs (one row per repetition) and take the running maximum along
# dim 1, i.e. the best queried target seen up to each round.
conformal = torch.stack(conformal_pts).cummax(dim=1).values
std = torch.stack(std_pts).cummax(dim=1).values

In [None]:
# Mean best-so-far curve per method with a +/- 2 standard-error band across runs.
# The number of runs and rounds is read from the tensors (rows = runs,
# columns = rounds) rather than hard-coded, so the band stays correct when the
# number of repetitions changes.
for best_so_far, curve_label in ((conformal, "conformal qEI"), (std, "qEI")):
    n_runs, n_rounds = best_so_far.shape
    curve_mean = best_so_far.mean(0)
    half_width = 2. / n_runs ** 0.5 * best_so_far.std(0)
    plt.plot(curve_mean, label=curve_label)
    plt.fill_between(torch.arange(n_rounds),
                     curve_mean - half_width,
                     curve_mean + half_width,
                     alpha=0.3)
plt.legend()
plt.xlabel("Function Evaluations")
plt.ylabel("Best Achieved")