In [None]:
import math
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import ExpectedImprovement
from botorch.optim import optimize_acqf

from custom_mean import LinearCalibration

# Ground Truth Functions

In [None]:
def unimodal(x, max_pos=(0.5, 0.5), max_val=0.5):
    """
    Product of two downward-opening parabolas, usable as unimodal ground truth.

    Each factor equals sqrt(max_val) at its coordinate of max_pos, so the
    product evaluates to max_val at max_pos.

    Args:
        x (torch.Tensor): A tensor of shape (n_batch, n_samples, n_dim). Only
            the first two entries of the last dimension are read.
        max_pos (tuple): The x-coordinates at the maximum.
        max_val (float): The y-value at the maximum.

    Returns:
        torch.Tensor: The function values, with the last dimension of x removed.
    """
    peak = math.sqrt(max_val)
    # one parabola per input dimension, each centered on its max_pos entry
    factors = [peak - (x[..., dim] - max_pos[dim]) ** 2 for dim in range(2)]
    return factors[0] * factors[1]


def multimodal(x, f=1.6):
    """
    Product of two amplitude-modulated sinusoids, usable as multimodal ground truth.

    Each input coordinate c contributes a factor c * sin(2 * pi * f * c).

    Args:
        x (torch.Tensor): A tensor of shape (n_batch, n_samples, n_dim). Only
            the first two entries of the last dimension are read.
        f (float): The frequency of oscillation in both dimensions.

    Returns:
        torch.Tensor: The function values, with the last dimension of x removed.
    """
    omega = 2 * math.pi * f

    def _wave(coord):
        # sinusoid whose amplitude grows linearly with the coordinate
        return coord * torch.sin(omega * coord)

    return _wave(x[..., 0]) * _wave(x[..., 1])

In [None]:
# plotting configuration shared by every figure in this notebook
mode = "3d"  # "2d": pseudocolor map with colorbar, "3d": surface plot
cbar_size = 5  # in percent
cbar_pad = 0.15  # in inches


def plot_function(fig, ax, x, y, **kwargs):
    """
    Plot function values given on a flattened square 2d grid onto ax.

    The module-level ``mode`` selects the representation: "2d" draws a
    pseudocolor map with a colorbar appended to the right of ax, "3d" draws a
    surface (ax must then be a 3d axes).

    Args:
        fig: The figure that owns ax. Accepted for call compatibility; it is
            not used by this function.
        ax: The matplotlib axes to draw on.
        x (torch.Tensor): Grid points of shape (n * n, 2).
        y (torch.Tensor): Function values with n * n elements.

    Keyword Args:
        x_lim (tuple): Axis limits ((x0_min, x0_max), (x1_min, x1_max)).
            Defaults to the unit square.
        y_lim (tuple): Value limits (min, max) used for the color scale (2d)
            or the z-axis (3d). Defaults to the range of y.
        surface_color (str): Surface color in 3d mode. Defaults to "C0".
        title (str): Axes title. Defaults to "".

    Raises:
        ValueError: If the number of grid points is not a perfect square or
            if ``mode`` is neither "2d" nor "3d".
    """
    # keyword arguments
    x_lim = kwargs.get("x_lim", ((0.0, 1.0), (0.0, 1.0)))
    y_lim = kwargs.get("y_lim")
    if y_lim is None:
        # only reduce over y when the caller did not provide limits
        y_lim = (torch.min(y), torch.max(y))
    # hand plain Python floats (not 0-d tensors) to matplotlib
    y_lim = (float(y_lim[0]), float(y_lim[1]))
    surface_color = kwargs.get("surface_color", "C0")
    title = kwargs.get("title", "")

    # reshape data
    n_points = x.shape[0]
    mesh_dim = round(math.sqrt(n_points))
    if mesh_dim * mesh_dim != n_points:
        # a truncated square root would otherwise surface as an opaque reshape error
        raise ValueError(
            f"Expected a square grid of points, but got {n_points} points."
        )
    x_mesh = [x[:, 0].reshape(mesh_dim, mesh_dim),
              x[:, 1].reshape(mesh_dim, mesh_dim)]
    y_mesh = y.reshape(mesh_dim, mesh_dim)

    # plot data
    if mode == "2d":
        im = ax.pcolormesh(x_mesh[0], x_mesh[1], y_mesh, vmin=y_lim[0], vmax=y_lim[1],
                           shading='auto')
        # carve the colorbar out of ax so that all panels keep the same layout
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size=f"{cbar_size}%", pad=cbar_pad)
        plt.colorbar(im, cax=cax)
    elif mode == "3d":
        ax.plot_surface(*x_mesh, y_mesh, color=surface_color)
        ax.set_zlim(y_lim)
    else:
        raise ValueError(f"Unknown plotting mode: {mode}.")

    # axis limits and labels
    ax.set_xlim(x_lim[0])
    ax.set_ylim(x_lim[1])
    ax.set_title(title)

In [None]:
# evaluation grid on the unit square, flattened to shape (100 * 100, 2)
x_lim = ((0.0, 1.0), (0.0, 1.0))
x_0 = torch.linspace(*x_lim[0], 100)
x_1 = torch.linspace(*x_lim[1], 100)
test_x = torch.cartesian_prod(x_0, x_1)

# one panel per exemplary ground truth function, drawn left to right
panels = (("unimodal", unimodal), ("multimodal", multimodal))

if mode == "2d":
    # widen each panel by the space taken up by its colorbar
    ax_size = (5 * (1.0 + cbar_size / 100) + cbar_pad, 5)
    fig, ax = plt.subplots(1, 2, sharex="all", sharey="all",
                           figsize=(2 * ax_size[0], ax_size[1]))
    for panel_ax, (label, func) in zip(ax, panels):
        plot_function(fig, panel_ax, test_x, func(test_x), title=label)
elif mode == "3d":
    ax_size = (5, 5)
    fig = plt.figure(figsize=(2 * ax_size[0], ax_size[1]))
    for position, (label, func) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, position, projection="3d")
        plot_function(fig, ax, test_x, func(test_x), title=label)

# Use Mismatched Ground Truth as Prior Mean

In [None]:
class MismatchedGT(torch.nn.Module):
    def __init__(self, ground_truth, **kwargs):
        """
        Prediction of the ground truth function with optional linear mismatches in x and y.

        The shifts and scales are stored as plain tensor attributes, exactly as
        passed in.

        Args:
            ground_truth (Callable): Ground truth function.

        Keyword Args:
            x_dim (int): The input dimension. Defaults to 2.
            x_shift (torch.Tensor): A tensor of shape (x_dim). Defaults to zeros.
            x_scale (torch.Tensor): A tensor of shape (x_dim). Defaults to ones.
            y_shift (torch.Tensor): A tensor of shape (1). Defaults to zero.
            y_scale (torch.Tensor): A tensor of shape (1). Defaults to one.

        Raises:
            TypeError: If ground_truth is not callable.
            ValueError: If a shift or scale does not have the expected shape.
        """
        super(MismatchedGT, self).__init__()
        # explicit raises instead of asserts so the checks survive `python -O`
        if not callable(ground_truth):
            raise TypeError("Expected ground_truth to be callable")
        self.ground_truth = ground_truth

        # parameters for mismatch in x
        x_shape = torch.Size([kwargs.get("x_dim", 2)])
        self.x_shift = self._mismatch_param(kwargs, "x_shift", torch.zeros(x_shape))
        self.x_scale = self._mismatch_param(kwargs, "x_scale", torch.ones(x_shape))

        # parameters for mismatch in y
        y_shape = torch.Size([1])
        self.y_shift = self._mismatch_param(kwargs, "y_shift", torch.zeros(y_shape))
        self.y_scale = self._mismatch_param(kwargs, "y_scale", torch.ones(y_shape))

    @staticmethod
    def _mismatch_param(kwargs, name, default):
        """Returns kwargs[name] as a tensor shaped like default, or default if absent."""
        value = kwargs.get(name)
        if value is None:
            return default
        # as_tensor returns an existing tensor unchanged and converts lists/tuples
        value = torch.as_tensor(value)
        if value.shape != default.shape:
            raise ValueError(
                f"Expected tensor with {default.shape}, but got: {value.shape}"
            )
        return value

    def forward(self, x):
        """
        Evaluates the ground truth at linearly distorted inputs and distorts the result.

        Computes y_scale * ground_truth(x_scale * x + x_shift) + y_shift.

        Expects tensor of shape (n_batch, n_samples, n_dim)
        """
        mismatched_x = self.x_scale * x + self.x_shift
        y = self.y_scale * self.ground_truth(mismatched_x) + self.y_shift
        return y

In [None]:
# linear distortions applied on top of the ground truth to obtain the prior mean
mismatch = {
    "x_shift": torch.tensor([0.05, 0.05]),
    "x_scale": torch.tensor([0.95, 1.05]),
    "y_shift": torch.tensor([-0.05]),
    "y_scale": torch.tensor([1.1]),
}
mismatched_gt = MismatchedGT(
    ground_truth=lambda x: unimodal(x, max_pos=(0.4, 0.6)),
    **mismatch,
)

# left panel: undistorted ground truth, right panel: distorted prediction
panels = (("ground truth", mismatched_gt.ground_truth), ("prior mean", mismatched_gt))

if mode == "2d":
    # widen each panel by the space taken up by its colorbar
    ax_size = (5 * (1.0 + cbar_size / 100) + cbar_pad, 5)
    fig, ax = plt.subplots(1, 2, sharex="all", sharey="all",
                           figsize=(2 * ax_size[0], ax_size[1]))
    for panel_ax, (label, func) in zip(ax, panels):
        plot_function(fig, panel_ax, test_x, func(test_x), title=label)
elif mode == "3d":
    ax_size = (5, 5)
    fig = plt.figure(figsize=(2 * ax_size[0], ax_size[1]))
    for position, (label, func) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, position, projection="3d")
        plot_function(fig, ax, test_x, func(test_x), title=label)

# Definition of Custom Mean

In [None]:
# Wrap the mismatched prediction in the project's LinearCalibration module.
# NOTE(review): the meaning of the positional transformer arguments is taken
# from how create_gp reads `input_transformer` / `outcome_transformer` back —
# confirm against custom_mean.LinearCalibration.
custom_mean = LinearCalibration(
    mismatched_gt,
    # build the (2, d) bounds directly in double precision instead of
    # round-tripping through the legacy float32 `torch.FloatTensor` constructor
    Normalize(2, bounds=torch.tensor(x_lim, dtype=torch.double).T),
    Standardize(1),
    x_dim=2,
    y_dim=1,
)

In [None]:
def create_gp(train_x, train_y, mean_module, noise=1e-4):
    """
    Builds a SingleTaskGP that uses mean_module as prior mean and has fixed noise.

    The input and outcome transforms are taken from mean_module, so the GP and
    its mean module share the same transform objects.

    Args:
        train_x (torch.Tensor): Training inputs passed to SingleTaskGP.
        train_y (torch.Tensor): Training targets passed to SingleTaskGP.
        mean_module: Mean module exposing `input_transformer` and
            `outcome_transformer` attributes.
        noise (float): Value assigned to the likelihood noise. Defaults to 1e-4.

    Returns:
        SingleTaskGP: The model with its likelihood noise set and excluded
        from gradient-based fitting.
    """
    gp = SingleTaskGP(
        train_x, train_y,
        mean_module=mean_module,
        input_transform=mean_module.input_transformer,
        outcome_transform=mean_module.outcome_transformer,
    )
    gp.likelihood.noise = torch.as_tensor(noise)
    # Freeze the underlying raw parameter. Setting `requires_grad` on
    # `likelihood.noise` only works while that property happens to return the
    # raw leaf parameter (i.e. when the noise constraint applies no transform).
    gp.likelihood.noise_covar.raw_noise.requires_grad_(False)
    return gp

In [None]:
# number of steps for BO
n_steps = 10

# optimize_acqf starts from randomly drawn raw samples; seed for reproducible runs
torch.manual_seed(0)

# create plotting figure
if mode == "2d":
    ax_size = (5 * (1.0 + cbar_size / 100) + cbar_pad, 5)
    fig, ax = plt.subplots(n_steps, 3, sharex="all", sharey="all", figsize=(3 * ax_size[0], n_steps * ax_size[1]))
elif mode == "3d":
    ax_size = (5, 5)
    fig = plt.figure(figsize=(3 * ax_size[0], n_steps * ax_size[1]))

# generate initial data (created directly in double precision)
train_x = torch.tensor([[0.55, 0.6], [0.6, 0.55]], dtype=torch.double)
train_y = mismatched_gt.ground_truth(train_x).reshape(2, 1)

# search bounds of shape (2, d); loop-invariant, so built only once
bounds = torch.tensor(x_lim, dtype=torch.double).T

for i in range(n_steps):
    # create GP model
    gp = create_gp(train_x, train_y, custom_mean)

    # maximum likelihood fits
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_mll(mll)

    # create and optimize acquisition function using the GP
    acq = ExpectedImprovement(gp, train_y.max())
    candidate, _ = optimize_acqf(
        acq_function=acq,
        bounds=bounds,
        q=1,
        num_restarts=5,
        raw_samples=20
    )

    # posterior and acquisition function
    with torch.no_grad():
        # GP posterior mean on the evaluation grid (the confidence region is
        # not plotted, so it is not computed)
        mean = gp.posterior(test_x).mean.flatten()

        # get acquisition function values, normalized to a maximum of one
        acq_val = acq(test_x.reshape(-1, 1, 2))
        max_acq = torch.max(acq_val)
        # an acquisition that is zero everywhere would otherwise turn into NaNs
        norm_acq = acq_val / max_acq if max_acq > 0 else acq_val

        # plain floats for the marker position
        candidate_pos = candidate[0].tolist()

        # plot GP model, data and ground truth
        if mode == "2d":
            plot_function(fig, ax[i, 0], test_x, mismatched_gt.ground_truth(test_x), title="ground truth")
            plot_function(fig, ax[i, 1], test_x, mean, title="posterior mean")
            ax[i, 1].plot(train_x[:, 0].numpy(), train_x[:, 1].numpy(), "C3x", markersize=10, markeredgewidth=3.0)
            plot_function(fig, ax[i, 2], test_x, norm_acq, title="acquisition function")
            ax[i, 2].plot(*candidate_pos, "C3x", markersize=10, markeredgewidth=3.0)
        elif mode == "3d":
            ax = fig.add_subplot(n_steps, 3, 3 * i + 1, projection="3d")
            plot_function(fig, ax, test_x, mismatched_gt.ground_truth(test_x), surface_color="C3",
                          title="ground truth")
            ax = fig.add_subplot(n_steps, 3, 3 * i + 2, projection="3d")
            plot_function(fig, ax, test_x, mean, surface_color="C0", title="posterior mean")
            ax.plot(train_x[:, 0].numpy(), train_x[:, 1].numpy(), train_y[:, 0].numpy(),
                    "C1v", markeredgecolor="k", markersize=10, zorder=10)
            ax = fig.add_subplot(n_steps, 3, 3 * i + 3, projection="3d")
            plot_function(fig, ax, test_x, norm_acq, surface_color="C1", title="acquisition function")
            # marker drawn at the top of the normalized acquisition surface
            ax.plot(*candidate_pos, 1.0, "C0v", markeredgecolor="k", markersize=10, zorder=10)

    # add maximization points and resulting observations to training data
    train_x = torch.cat([train_x, candidate], dim=0)
    train_y = mismatched_gt.ground_truth(train_x).unsqueeze(dim=-1)

fig.tight_layout()

# Learned Hyperparameters

In [None]:
# One row per calibration attribute: (name, conversion applied to the learned
# value before printing, output template). The value in parentheses is the
# mismatch that was put into mismatched_gt.
# NOTE(review): the negation / reciprocal presumably maps the learned
# calibration onto the convention of the mismatch — confirm against
# custom_mean.LinearCalibration.
reports = (
    ("x_shift", lambda value: -value, "learned x_shift: [{:.2f}, {:.2f}] ([{:.2f}, {:.2f}])"),
    ("x_scale", lambda value: 1 / value, "learned x_scale: [{:.2f}, {:.2f}] ([{:.2f}, {:.2f}])"),
    ("y_shift", lambda value: -value, "learned y_shift: {:.2f} ({:.2f})"),
    ("y_scale", lambda value: 1 / value, "learned y_scale: {:.2f} ({:.2f})"),
)

for attr, convert, template in reports:
    # skip anything this mean module does not expose
    if not hasattr(custom_mean, attr):
        continue
    learned = convert(getattr(custom_mean, attr))
    reference = getattr(mismatched_gt, attr)
    print(template.format(*learned, *reference))