In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Global plotting theme: enlarged fonts for paper-ready figures and the
# "mako" palette. `sns.set` is only a backwards-compatible alias of
# `sns.set_theme`, which is the preferred interface.
sns.set_theme(font_scale=2.0)
sns.set_palette("mako")

# Axis-label font size. NOTE(review): not used by any cell visible here.
label_fs = 40

In [None]:
import time
import torch

from botorch.acquisition.objective import ConstrainedMCObjective
from botorch.acquisition.monte_carlo import qExpectedImprovement, qNoisyExpectedImprovement
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.optim.fit import fit_gpytorch_torch
from botorch.sampling.samplers import SobolQMCNormalSampler

from utils import (
    generate_initial_data,
    parse,
    optimize_acqf_and_get_observation,
    update_random_observations,
    get_var_model,
    get_exact_model,
)

In [None]:
from volatilitygp.mlls import PatchedVariationalELBO as VariationalELBO
from gpytorch.mlls import ExactMarginalLogLikelihood

In [None]:
import matplotlib.pyplot as plt
from botorch.test_functions import Beale

In [None]:
# Fix torch's global RNG so the random training data and candidates are reproducible.
SEED = 20
torch.manual_seed(SEED)

In [None]:
# Objective: the 2-D Beale test function, negated so that BO maximizes it.
# NOTE: the name `neg_hartmann6` is a leftover from a Hartmann6 template; it is
# kept because later cells refer to it.
neg_hartmann6 = Beale(negate=True)

# Search domain as a (2, d) tensor: row 0 = lower bounds, row 1 = upper bounds.
bounds = neg_hartmann6.bounds

dtype = torch.double
# Preferred GPU index. A nonexistent "cuda:5" only fails later, at the first
# tensor allocation, so fall back to GPU 0 when this host has fewer GPUs, and
# to the CPU when CUDA is unavailable.
CUDA_DEVICE_INDEX = 5
if torch.cuda.is_available():
    device = torch.device(
        "cuda", CUDA_DEVICE_INDEX if CUDA_DEVICE_INDEX < torch.cuda.device_count() else 0
    )
else:
    device = torch.device("cpu")

In [None]:
# Diagnostic only: print GPU status via the NVIDIA command-line tool (IPython
# shell escape; needs the NVIDIA driver utilities on PATH, yields no Python value).
!nvidia-smi

In [None]:
def generate_initial_data(n, fn, NOISE_SE, device, dtype, is_poisson=False):
    """Draw ``n`` random inputs in ``fn``'s domain and observe ``fn`` with noise.

    NOTE: this local definition shadows the ``generate_initial_data`` imported
    from ``utils`` earlier in the notebook.

    Args:
        n: number of training points.
        fn: objective exposing ``dim``; called on an ``n x dim`` tensor and
            expected to return a length-``n`` tensor. If it has a ``bounds``
            attribute (``2 x dim``: lower row, upper row), inputs are drawn
            uniformly inside those bounds, otherwise inside the unit cube.
        NOISE_SE: standard deviation of the additive Gaussian observation noise.
        device: device for the generated tensors.
        dtype: dtype for the generated tensors.
        is_poisson: if True, replace the observations by Poisson counts with
            rate ``exp(noisy objective)``.

    Returns:
        ``(train_x, train_obj, best_observed_value)``: ``train_x`` is
        ``n x dim``, ``train_obj`` is ``n x 1``, and ``best_observed_value`` is
        the largest entry of ``train_obj`` as a Python float.
    """
    # Imported locally: the notebook never imports Poisson, so the
    # is_poisson=True path previously raised NameError.
    from torch.distributions import Poisson

    unit_x = torch.rand(n, fn.dim, device=device, dtype=dtype)
    fn_bounds = getattr(fn, "bounds", None)
    if fn_bounds is None:
        train_x = unit_x
    else:
        # Rescale from [0, 1]^d into fn's domain so the training data covers the
        # same box as the candidate set and plotting grid built from `bounds`.
        lower, upper = fn_bounds.to(device=device, dtype=dtype)
        train_x = lower + (upper - lower) * unit_x
    exact_obj = fn(train_x).unsqueeze(-1)  # add output dimension
    train_obj = exact_obj + NOISE_SE * torch.randn_like(exact_obj)
    if is_poisson:
        train_obj = Poisson(train_obj.exp()).sample()

    best_observed_value = train_obj.max().item()
    return train_x, train_obj, best_observed_value

In [None]:
    # Generate the initial noisy training data shared by both surrogate models.
    N_INIT = 50
    # NOISE_SE is the observation-noise *standard deviation*; `train_yvar` is a
    # noise *variance*, so it must be NOISE_SE ** 2 (the std 0.2 was passed before).
    NOISE_SE = 0.2
    train_x_ei, train_obj_ei, best_observed_value_ei = generate_initial_data(
        n=N_INIT, fn=neg_hartmann6, NOISE_SE=NOISE_SE, device=device, dtype=dtype, is_poisson=False
    )
    # Match the training data's device/dtype (torch.ones defaults to float32 on CPU).
    train_yvar = NOISE_SE ** 2 * torch.ones(1, device=device, dtype=dtype)

In [None]:
# Build both surrogates on the same noisy training data: the variational GP
# ("svgp") and the exact GP, constructed in that order.
svgp, exact = (
    get_var_model(train_x_ei, train_obj_ei, train_yvar, is_poisson=False),
    get_exact_model(train_x_ei, train_obj_ei, train_yvar),
)

In [None]:
# Training objectives for both models. The variational ELBO is told the total
# number of training points; take it from the data instead of hard-coding 50,
# which would silently go stale if the initial design size changes.
svgp_mll = VariationalELBO(svgp.likelihood, svgp, num_data=train_x_ei.shape[0])
exact_mll = ExactMarginalLogLikelihood(exact.likelihood, exact)

In [None]:
# Optimize each model's marginal-likelihood objective in place, SVGP first,
# with botorch's torch-optimizer loop; the returned info is discarded.
for _mll in (svgp_mll, exact_mll):
    fit_gpytorch_torch(_mll)

In [None]:
# Put the problem bounds on the same device *and* dtype as the data, so later
# arithmetic with `candidate_set` does not rely on implicit dtype promotion.
bounds = bounds.to(device=device, dtype=dtype)

In [None]:
from botorch.acquisition import qMaxValueEntropy, qKnowledgeGradient

In [None]:
# Random candidate set for the KG evaluation: 250 points drawn uniformly in
# [0, 1]^d, then affinely rescaled into the box [bounds[0], bounds[1]].
candidate_set = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(
    250, bounds.size(1), device=device, dtype=dtype
)

In [None]:
# 15 x 15 evaluation/plotting grid spanning the problem's 2-D box. The edges
# come from `bounds` instead of a hard-coded +/-4.5, so the grid stays correct
# if the objective or its domain changes.
assert bounds.size(-1) == 2, "the KG grid plots assume a 2-D problem"
GRID_SIZE = 15
_grid_lo, _grid_hi = bounds[0].tolist(), bounds[1].tolist()
x_grid = torch.meshgrid(
    *(torch.linspace(lo, hi, GRID_SIZE, dtype=dtype) for lo, hi in zip(_grid_lo, _grid_hi))
)
# Flatten to (GRID_SIZE**2, 2) points on the compute device with the data's
# dtype; `x_grid` itself stays on the CPU for plotting.
x_grid_vals = torch.stack([axis.reshape(-1) for axis in x_grid], dim=-1).to(device)

In [None]:
# Knowledge-gradient acquisition for each surrogate (SVGP first, then exact),
# sharing the incumbent value and using 8 fantasies.
kg_svgp, kg_exact = (
    qKnowledgeGradient(model, current_value=best_observed_value_ei, num_fantasies=8)
    for model in (svgp, exact)
)

In [None]:
# Give every grid point its own copy of the candidate set: (n_grid, 250, d).
# The batch size follows the grid instead of a hard-coded 225 (= 15 * 15).
expanded_candidates = candidate_set.unsqueeze(0).repeat(x_grid_vals.shape[0], 1, 1)

In [None]:
# Prepend each grid point to its copy of the candidates along the point axis:
# (n_grid, 1, d) and (n_grid, n_cand, d) -> (n_grid, 1 + n_cand, d).
# NOTE(review): botorch's qKnowledgeGradient.forward splits dim -2 into a
# q-batch (all but the last `num_fantasies` points) and the fantasy solutions
# (the last `num_fantasies` = 8 points). With 1 + 250 points this evaluates a
# q = 243 batch, not KG at the single grid point -- confirm this is intended.
grid_points = x_grid_vals.unsqueeze(-2)
catted_x = torch.cat([grid_points, expanded_candidates], dim=1)

In [None]:
# Evaluate exact-GP KG on every grid batch without building an autograd
# graph, then move the values to the CPU for plotting.
with torch.no_grad():
    kg_exact_val = kg_exact(catted_x)
kg_exact_val = kg_exact_val.cpu()

In [None]:
%pdb

In [None]:
# SVGP KG on the same grid batches; gradients are not needed for plotting.
with torch.set_grad_enabled(False):
    kg_svgp_val = kg_svgp(catted_x)

In [None]:
# Contour plot of exact-GP KG values over the 2-D grid, saved as a PDF.
fig, ax = plt.subplots(figsize=(7, 6))
# Reshape to the grid's own shape rather than a hard-coded (15, 15).
f = ax.contourf(*x_grid, kg_exact_val.reshape(x_grid[0].shape).cpu(), cmap="mako")
fig.colorbar(f, ax=ax)
ax.set_xlabel("x1")
ax.set_ylabel("x2")
fig.savefig("./kg_exact.pdf", bbox_inches="tight")
# Title left off the saved figure: ax.set_title("KG - Exact")

In [None]:
# Contour plot of SVGP KG values (before `condition_into_exact` is switched
# off further down) over the 2-D grid, saved as a PDF.
fig, ax = plt.subplots(figsize=(7, 6))
# Reshape to the grid's own shape rather than a hard-coded (15, 15).
f = ax.contourf(*x_grid, kg_svgp_val.reshape(x_grid[0].shape).cpu(), cmap="mako")
fig.colorbar(f, ax=ax)

ax.set_xlabel("x1")
ax.set_ylabel("x2")
fig.savefig("./kg_svgp_exact.pdf", bbox_inches="tight")
# Title left off the saved figure: ax.set_title("KG - SVGP")

In [None]:
# Switch the SVGP's `condition_into_exact` flag off before recomputing KG
# (judging by the output filenames, kg_svgp_exact.pdf vs kg_svgp_sgpr.pdf, this
# selects a different conditioning scheme -- confirm against volatilitygp).
# NOTE(review): `_memoize_cache` is a private cache on the model; it is
# presumably reset so values cached under the previous setting are not reused.
svgp._memoize_cache = {}
svgp.condition_into_exact = False

In [None]:
# Re-create the SVGP KG acquisition after the model's conditioning flag changed.
kg_svgp = qKnowledgeGradient(
    model=svgp,
    num_fantasies=8,
    current_value=best_observed_value_ei,
)

In [None]:
# Re-evaluate SVGP KG on the grid batches under the new conditioning mode;
# autograd is disabled since only the values are plotted.
with torch.set_grad_enabled(False):
    kg_svgp_val = kg_svgp(catted_x)

In [None]:
# Contour plot of SVGP KG values with `condition_into_exact = False` over the
# 2-D grid, saved as a PDF.
fig, ax = plt.subplots(figsize=(7, 6))
# Reshape to the grid's own shape rather than a hard-coded (15, 15).
f = ax.contourf(*x_grid, kg_svgp_val.reshape(x_grid[0].shape).cpu(), cmap="mako")
fig.colorbar(f, ax=ax)

ax.set_xlabel("x1")
ax.set_ylabel("x2")
fig.savefig("./kg_svgp_sgpr.pdf", bbox_inches="tight")