In [None]:
import argparse
import os
import math
import time
import numpy as np
import matplotlib.pyplot as plt

from dataclasses import dataclass

import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_model
from botorch.optim.fit import fit_gpytorch_torch
from botorch.generation import MaxPosteriorSampling
from botorch.models import FixedNoiseGP, SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.utils.transforms import unnormalize
from torch.quasirandom import SobolEngine

import gpytorch
from gpytorch.constraints import Interval
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import HorseshoePrior

In [None]:
import sys
sys.path.append("../std_bayesopt/")
from utils import initialize_model_unconstrained as initialize_model

sys.path.append("./")
from trbo import TurboState, update_state, generate_batch

In [None]:
# Seed torch's global RNG so the random initial design and posterior samples are reproducible.
torch.manual_seed(0)

In [None]:
# Number of rover waypoints passed to the domain below; the optimizer's input
# dimension is recomputed from the bounds once the domain is built.
dim = 50
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

In [None]:
from rover_function import create_large_domain
def l2cost(x, point):
    """Penalty for missing a target point: ten times the distance from ``x`` to ``point``.

    NOTE: despite the name, the distance is ``np.linalg.norm(..., ord=1)`` (the L1
    norm for 1-D inputs), not the L2 norm. The behavior is kept unchanged because it
    defines the benchmark's cost.
    """
    offset = x - point
    distance = np.linalg.norm(offset, ord=1)
    return 10 * distance

domain = create_large_domain(
    force_start=False,
    force_goal=False,
    start_miss_cost=l2cost,
    goal_miss_cost=l2cost,
    n_points=dim,
)
n_points = domain.traj.npoints

# Raw box bounds of the rover domain, repeated once per trajectory point.
raw_x_range = np.repeat(domain.s_range, n_points, axis=1)


def bounded_fn_callable(X):
    """Evaluate the rover domain on each row of ``X`` (given in the domain's raw box).

    Each reward is shifted by +5 (per the original note, 5 is the max achievable,
    so this is the offset). The result is cast back to ``X``'s dtype and device.
    """
    values = [torch.tensor(domain(x.cpu().numpy()) + 5.0) for x in X]
    return torch.stack(values).to(X)


def fn_callable(X):
    """Map ``X`` from the unit cube [0, 1]^d to [-0.1, 1.1]^d and evaluate the domain."""
    return bounded_fn_callable(X * 1.2 - 0.1)


# The optimizer works in the unit cube; fn_callable does the rescaling.
bounds = torch.zeros(2, raw_x_range.shape[-1], dtype=dtype, device=device)
bounds[1] = 1.0
dim = bounds.shape[-1]

In [None]:
# NOTE(review): not referenced elsewhere in this notebook; rollouts use `rollout_candidates`.
n_candidates = 5_000

In [None]:
# Initial design: 800 uniform random points scaled into `bounds`, then evaluated.
next_x = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(
    800, bounds.shape[-1], device=device, dtype=dtype
)
next_obj = fn_callable(next_x).unsqueeze(-1)

In [None]:
            # generate a new model
            mll_gibbon, model = initialize_model(
                next_x, 
                next_obj, 
                None,
                method="exact",
                use_input_transform=False,
                use_outcome_transform=True,
                num_inducing=500,
                loss="pll",
            )
            # fit the models
            optimizer_kwargs = {"maxiter": 1000}
            fit_gpytorch_torch(mll_gibbon, options=optimizer_kwargs)

In [None]:
            # generate a new model
            mll_gibbon, svgp = initialize_model(
                next_x, 
                next_obj, 
                None,
                method="variational",
                use_input_transform=False,
                use_outcome_transform=True,
                num_inducing=500,
                loss="pll",
            )
            # fit the models
            optimizer_kwargs = {"maxiter": 1000}
            fit_gpytorch_torch(mll_gibbon, options=optimizer_kwargs);

In [None]:
from trbo import generate_candidates

In [None]:
def rollout_path(candidate_sets, model, tree_depth, orig_topk):
    """Roll out greedy fantasy paths of length ``tree_depth`` from ``model``.

    At the root, a posterior sample (``rsample()``) is drawn over ``candidate_sets[0]``
    and the ``orig_topk`` highest-sampled candidates are fantasized with their sampled
    values. At each further depth, the fantasy model samples over
    ``candidate_sets[depth + 1]`` and is conditioned on its single best candidate.

    Args:
        candidate_sets: sequence with at least ``tree_depth`` candidate tensors, one
            per depth. NOTE(review): only the root step has a branch for 3-D
            (batched) sets; the loop indexes ``candidate_sets[d][inds]`` directly,
            which presumably assumes 2-D ``(n, d)`` sets. Confirm before passing
            batched candidates.
        model: a botorch-style model exposing ``posterior`` and
            ``condition_on_observations``.
        tree_depth: number of conditioning steps (root step included).
        orig_topk: number of top root candidates to keep.

    Returns:
        ``(inputs, targets, pred_root_list)``: the last ``tree_depth`` training inputs
        and targets of the final fantasy model (presumably the points appended by
        this rollout), and a list of ``tree_depth`` root decompositions of the fantasy
        model's ``lik_train_train_covar``, one recorded after each conditioning step.
    """
    # Sample the posterior over the root candidates and keep the ``orig_topk`` best.
    post_samples = model.posterior(candidate_sets[0]).rsample().squeeze(0)
    samples, inds = torch.topk(post_samples, dim=-2, k=orig_topk)
    
    # Gather the selected root candidates; the ``else`` branch handles a 3-D candidate
    # set by indexing each leading-dimension row with its own top-k indices.
    if candidate_sets[0].ndim < 3:
        cs = candidate_sets[0][inds]
    else:
        cs = torch.stack([candidate_sets[0][i, inds[i]] for i in range(inds.shape[0])])

    # Fantasize the selected candidates using their sampled values as observations.
    fantasy = model.condition_on_observations(
        cs,
        samples.unsqueeze(-1)
    )
    # Root decomposition of the fantasy train covariance after each conditioning step.
    pred_root_list = []
    pred_root_list.append(fantasy.prediction_strategy.lik_train_train_covar.root_decomposition())
    for depth in range(tree_depth - 1):
        # Greedy extension: sample this depth's candidates and condition on the best one.
        post_samples = fantasy.posterior(candidate_sets[depth+1]).rsample().squeeze(0)
        samples, inds = torch.topk(post_samples, dim=-2, k=1)
        fantasy = fantasy.condition_on_observations(
            candidate_sets[depth+1][inds].squeeze(-2),
            samples,
        )
        pred_root_list.append(fantasy.prediction_strategy.lik_train_train_covar.root_decomposition())
        
    # Slice off the trailing ``tree_depth`` training points of the final fantasy model.
    targets = fantasy.train_targets
    return fantasy.train_inputs[0][..., -tree_depth:, :], targets[..., -tree_depth:], pred_root_list

In [None]:
# Rollout configuration: number of root branches kept, number of conditioning steps,
# and the base trust-region side length.
orig_topk, tree_depth, length = 10, 16, 0.2

In [None]:
# Trust region centred on the best observed point, with per-dimension side lengths
# proportional to the exact model's lengthscales, clipped to the unit cube.
x_center = next_x[next_obj.argmax(), :].clone()
weights = model.covar_module.base_kernel.lengthscale.squeeze().detach()
weights = weights / weights.mean()
# Rescale so the weights have geometric mean 1.
weights = weights / torch.prod(weights ** (1.0 / len(weights)))
tr_half_width = weights * length / 2.0
tr_lb = (x_center - tr_half_width).clamp(0.0, 1.0)
tr_ub = (x_center + tr_half_width).clamp(0.0, 1.0)

dim = next_x.shape[-1]

In [None]:
# Draw one candidate set inside the trust region; the same tensor is reused at every depth.
rollout_candidates = 1000
tr_bounds = torch.stack([tr_lb, tr_ub])
candidate_set = generate_candidates(
    x_center, dim, rollout_candidates, tr_bounds, device=device, dtype=dtype
)
candidates = [candidate_set for _ in range(tree_depth)]

In [None]:
# Roll out the exact GP without tracking gradients.
with torch.no_grad():
    inputs, targets, roots = rollout_path(
        candidates, model, tree_depth, orig_topk
    )

In [None]:
# Eigenvalues of the exact model's kernel matrix over all rollout points (flattened).
with torch.no_grad():
    rollout_points = inputs.reshape(-1, dim)
    exact_kernel = model.covar_module(rollout_points)
    exact_train_evals = exact_kernel.symeig()[0]

In [None]:
# Ascending eigenvalues (float64) of each densified root decomposition from the exact rollout.
# torch.linalg.eigvalsh replaces Tensor.symeig, which is deprecated and removed in recent PyTorch.
evals = [
    torch.linalg.eigvalsh(root.evaluate().double()).cpu().detach().numpy().T
    for root in roots
]

In [None]:
# Release cached GPU memory before running the variational rollout.
torch.cuda.empty_cache()

In [None]:
# Roll out the variational model over the same candidate sets, without gradients.
with torch.no_grad():
    var_inputs, var_targets, var_roots = rollout_path(
        candidates, svgp, tree_depth, orig_topk
    )

In [None]:
# Eigenvalues of the variational model's kernel matrix over the rollout points.
# The flatten uses `dim` (the input dimension) instead of a hard-coded 60 that did
# not match the inputs.
# NOTE(review): this evaluates the SVGP kernel on `inputs` (the exact-GP rollout),
# not `var_inputs`; keep it if the intent is to compare both kernels on identical points.
with torch.no_grad():
    var_train_evals = svgp.covar_module(inputs.reshape(-1, dim)).symeig()[0]

In [None]:
# Ascending eigenvalues (float64) of each densified root decomposition from the variational rollout.
# torch.linalg.eigvalsh replaces Tensor.symeig, which is deprecated and removed in recent PyTorch.
var_evals = [
    torch.linalg.eigvalsh(root.evaluate().double()).cpu().detach().numpy().T
    for root in var_roots
]

In [None]:
# Condition number (largest / smallest eigenvalue) at each rollout depth, per column.
# np.stack requires a sequence; a generator is deprecated and rejected by recent NumPy.
exact_cond = np.stack([x[-1] / x[0] for x in evals])
var_cond = np.stack([x[-1] / x[0] for x in var_evals])

In [None]:
# Mean and standard deviation of the condition number across axis 1 at each depth.
exact_mean = np.mean(exact_cond, 1)
exact_std = np.std(exact_cond, 1)

var_mean = np.mean(var_cond, 1)
var_std = np.std(var_cond, 1)

# Band multiplier: std * sem gives 2 standard errors over the axis-1 samples
# (previously a hard-coded 10, which must match that axis length).
sem = 2 / (exact_cond.shape[1] ** 0.5)

In [None]:
# Condition number vs. rollout depth: mean with a band of +/- 2 standard errors.
# The x-axis is derived from the data rather than a hard-coded depth of 16.
depths = np.arange(len(exact_mean))
plt.plot(depths, exact_mean, color="orange", linewidth=4, marker="x", markersize=10, label="Exact")
plt.fill_between(depths, exact_mean + exact_std * sem, exact_mean - exact_std * sem,
                 color="orange", alpha=0.1)
plt.plot(depths, var_mean, color="purple", linewidth=4, marker="x", markersize=10, label="OVC")
plt.fill_between(depths, var_mean + var_std * sem, var_mean - var_std * sem,
                 color="purple", alpha=0.1)

plt.legend()
plt.grid()
plt.xlabel("Rollout Depth")
plt.ylabel("Condition Number");

In [None]:
# Smallest eigenvalue vs. rollout depth (presumably one line per root branch).
# Colors match the condition-number figure (orange = Exact, purple = OVC), and only the
# first line of each method is labelled so the legend shows one entry per method.
exact_lines = plt.semilogy([x[0] for x in evals], color="orange", marker="x")
var_lines = plt.semilogy([x[0] for x in var_evals], color="purple", marker="x")
exact_lines[0].set_label("Exact")
var_lines[0].set_label("OVC")
plt.legend()
plt.grid()
plt.ylabel("Smallest Eigenvalue")
plt.xlabel("Rollout Depth");

In [None]:
# Eigenvalue spectrum of each model's kernel matrix over the rollout points,
# with colors matching the other figures (orange = exact, purple = variational).
plt.semilogy(exact_train_evals.cpu(), color="orange", label="Exact")
plt.semilogy(var_train_evals.cpu(), color="purple", label="SVGP")
plt.legend()
plt.xlabel("Eigenvalue Index")
plt.ylabel("Eigenvalue")
plt.grid();

In [None]:
# Persist the eigenvalue results from both models (training-point spectra moved to CPU).
results = {
    "evals": evals,
    "var_evals": var_evals,
    "exact_evals_train": exact_train_evals.cpu(),
    "var_evals_train": var_train_evals.cpu(),
}
torch.save(results, "conditioning_experiment_100.pt")