In [None]:
import torch
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import utilities as util

from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from custom_mean import LinearCalibration
from optimizer import BayesOpt

In [None]:
# load surrogate and define objective
# The surrogate is built with its default constructor arguments; its `.model`
# attribute is what the objective below wraps.
surrogate = util.Surrogate()
# The objective is bound to a name first so that a different objective from
# `utilities` can be swapped in by editing this single line.
Objective = util.NegativeTransverseBeamSize
# `ground_truth` is reused below both as the base of the mismatched prior mean
# and as the function handed to the optimizer.
ground_truth = Objective(surrogate.model)

In [None]:
# define custom mean
# Mismatch parameters handed to MismatchedGroundTruth: the inputs are left
# untouched (zero shift, unit scale per input dimension) and only the output
# carries an offset (shift 0.2, unit scale).
x_shift = torch.zeros(surrogate.x_dim)
x_scale = torch.ones(surrogate.x_dim)
y_shift = torch.tensor([0.2])
y_scale = torch.tensor([1.0])

# Deliberately mis-calibrated copy of the ground truth; it serves as the model
# wrapped by the calibration mean below, and the last cell of the notebook
# compares its shift/scale attributes against the learned calibration.
mismatched_gt = util.MismatchedGroundTruth(
    x_dim=surrogate.x_dim,
    ground_truth=ground_truth,
    x_shift=x_shift,
    x_scale=x_scale,
    y_shift=y_shift,
    y_scale=y_scale,
)

# Positional arguments are, in order: the wrapped model, an input transform and
# an outcome transform.
# NOTE(review): `surrogate.x_lim.T` is presumably transposed from (x_dim, 2) to
# the (2, x_dim) lower/upper layout that botorch's Normalize expects -- confirm
# against util.Surrogate.
custom_mean = LinearCalibration(
    mismatched_gt,
    Normalize(surrogate.x_dim, bounds=surrogate.x_lim.T),
    Standardize(1),
    x_dim=surrogate.x_dim,
    y_dim=1,
)

In [None]:
# run Bayesian optimization
# The optimizer is configured against the *true* objective (`ground_truth`);
# the mis-calibrated `custom_mean` is only supplied when the run is started.
# NOTE(review): n_init / n_step are presumably the number of initial samples
# and of optimization steps -- confirm in optimizer.BayesOpt.
# NOTE(review): no random seed is set in the visible cells; if BayesOpt draws
# random samples, results will differ between kernel restarts. Consider seeding
# (e.g. torch.manual_seed) before this cell.
bo = BayesOpt(surrogate, ground_truth, n_init=5, n_step=25)
bo.run(custom_mean)

In [None]:
# plot optimization sequence
# The trailing semicolon suppresses the notebook's echo of the return value.
bo.plot_running_max();

In [None]:
# plot sample distribution
# The trailing semicolon suppresses the notebook's echo of the return value.
bo.plot_sample_distribution();

In [None]:
def print_calibration_table(axis, true_model, learned_model, trailing_blank=False):
    """Print a true-vs-learned table of shift and scale for one calibration side.

    The table has one row per dimension and five columns: the dimension index,
    the shift stored on ``true_model``, the negated shift stored on
    ``learned_model``, the scale stored on ``true_model`` and the reciprocal of
    the scale stored on ``learned_model``.

    Nothing is printed when ``learned_model`` lacks either the
    ``<axis>_shift`` or the ``<axis>_scale`` attribute, so the same call works
    for mean modules that calibrate only inputs or only outputs.

    Args:
        axis: Attribute prefix, ``"x"`` for inputs or ``"y"`` for outputs. The
            attributes ``<axis>_shift``, ``<axis>_scale`` and ``<axis>_dim``
            are looked up from it.
        true_model: Object holding the reference ``<axis>_shift`` and
            ``<axis>_scale`` tensors.
        learned_model: Object holding the learned ``<axis>_shift`` and
            ``<axis>_scale`` tensors and the integer ``<axis>_dim``.
        trailing_blank: If True, print an empty line after the last row
            (only when at least one row was printed).
    """
    shift_name = f"{axis}_shift"
    scale_name = f"{axis}_scale"
    if not (hasattr(learned_model, shift_name) and hasattr(learned_model, scale_name)):
        return

    print("{:<5s} {:>10s} {:>10s} {:>10s} {:>10s}".format(
        f"{axis}_dim", shift_name, "learned", scale_name, "learned"))

    n_dim = getattr(learned_model, f"{axis}_dim")
    if n_dim <= 0:
        return

    # Detach once up front instead of on every loop iteration.
    true_shift = getattr(true_model, shift_name).detach()
    true_scale = getattr(true_model, scale_name).detach()
    learned_shift = getattr(learned_model, shift_name).detach()
    learned_scale = getattr(learned_model, scale_name).detach()

    row_format = "{:<5d} {:10.2f} {:10.2f} {:10.2f} {:10.2f}"
    for i in range(n_dim):
        # The learned parameters are displayed negated / inverted so they can
        # be read side by side with the reference values.
        # NOTE(review): presumably LinearCalibration learns the inverse of the
        # applied mismatch -- confirm against custom_mean.LinearCalibration.
        print(row_format.format(
            i,
            true_shift[i], -learned_shift[i],
            true_scale[i], 1 / learned_scale[i],
        ))

    if trailing_blank:
        print()


# evaluate input calibration (followed by a blank line separating the tables)
print_calibration_table("x", mismatched_gt, custom_mean, trailing_blank=True)

# evaluate output calibration
print_calibration_table("y", mismatched_gt, custom_mean)