In [None]:
import math
import torch
import matplotlib.pyplot as plt

from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import ExpectedImprovement
from botorch.optim import optimize_acqf

from custom_mean import CustomMean, LinearCalibration

# Ground Truth Functions

In [None]:
def sinusoidal(x):
    """
    Sine wave with unit period, usable as a ground truth function.

    Args:
        x (torch.Tensor): Input locations. The notebook convention is a tensor
            of shape (n_batch, n_samples, n_dim); the computation itself is
            element-wise.

    Returns:
        torch.Tensor: ``sin(2 * pi * x)`` evaluated element-wise.
    """
    angular_frequency = 2 * torch.pi
    return torch.sin(angular_frequency * x)


def quadratic(x, max_pos=0.5, max_val=0.5, curvature=20):
    """
    Downward-opening parabola which can be used as ground truth.

    Args:
        x (torch.Tensor): A tensor of shape (n_batch, n_samples, n_dim). The
            computation is element-wise.
        max_pos (float): The x-coordinate at the maximum.
        max_val (float): The y-value at the maximum.
        curvature (float): Factor multiplying the squared distance from
            ``max_pos``; larger values give a narrower peak. Defaults to 20.

    Returns:
        torch.Tensor: ``max_val - curvature * (x - max_pos) ** 2``.
    """
    distance = x - max_pos
    return max_val - curvature * distance ** 2

In [None]:
# domain of the 1D problem and the evaluation grid reused by the cells below
x_lim = (0.0, 1.0)
test_x = torch.linspace(*x_lim, 100)

# one panel per ground truth function
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ground_truths = (("sinusoidal", sinusoidal), ("quadratic", quadratic))
for panel, (title, ground_truth) in zip(ax, ground_truths):
    panel.plot(test_x, ground_truth(test_x), "C0-")
    panel.set_title(title)
    panel.set_xlim(x_lim)
    panel.set_xlabel("x")
    panel.set_ylabel("f")
fig.tight_layout()

# Use Mismatched Ground Truth as Prior Mean

In [None]:
class MismatchedGT(torch.nn.Module):
    def __init__(self, ground_truth, **kwargs):
        """
        Wraps a ground truth function and distorts it by linear mismatches in x and y.

        The mismatch tensors are stored as plain attributes ``x_shift``,
        ``x_scale``, ``y_shift`` and ``y_scale``.

        Args:
            ground_truth (Callable): Ground truth function.

        Keyword Args:
            x_dim (int): The input dimension. Defaults to 1.
            x_shift (torch.Tensor): A tensor of shape (x_dim). Defaults to zeros.
            x_scale (torch.Tensor): A tensor of shape (x_dim). Defaults to ones.
            y_shift (torch.Tensor): A tensor of shape (1). Defaults to zero.
            y_scale (torch.Tensor): A tensor of shape (1). Defaults to one.

        Raises:
            AssertionError: If ``ground_truth`` is not callable or a mismatch
                tensor does not have the expected shape.
        """
        super().__init__()
        self.ground_truth = ground_truth
        assert callable(self.ground_truth), "Expected ground_truth to be callable"

        x_shape = torch.Size([kwargs.get("x_dim", 1)])
        y_shape = torch.Size([1])

        # (attribute name, factory for the neutral default, required shape)
        mismatch_specs = (
            ("x_shift", torch.zeros, x_shape),
            ("x_scale", torch.ones, x_shape),
            ("y_shift", torch.zeros, y_shape),
            ("y_scale", torch.ones, y_shape),
        )
        for name, make_default, shape in mismatch_specs:
            value = kwargs.get(name, make_default(shape))
            setattr(self, name, value)
            assert value.shape == shape, f"Expected tensor with {shape}, but got: {value.shape}"

    def forward(self, x):
        """
        Evaluates the ground truth at linearly distorted inputs and distorts the result.

        Expects tensor of shape (n_batch, n_samples, n_dim). A trailing
        dimension of size one in the ground truth output is squeezed away
        before the y-mismatch is applied.
        """
        distorted_x = self.x_scale * x + self.x_shift
        prediction = self.ground_truth(distorted_x).squeeze(-1)
        return self.y_scale * prediction + self.y_shift

In [None]:
# prior mean: the sinusoidal ground truth distorted by small linear mismatches
mismatch = dict(
    x_shift=torch.tensor([0.1]),
    x_scale=torch.tensor([1.1]),
    y_shift=torch.tensor([0.15]),
    y_scale=torch.tensor([1.05]),
)
mismatched_gt = MismatchedGT(ground_truth=sinusoidal, **mismatch)

# compare the mismatched prior mean with the undistorted ground truth
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
curves = (
    (mismatched_gt, "C0-", "prior mean"),
    (mismatched_gt.ground_truth, "C3--", "ground truth"),
)
for function, line_format, label in curves:
    ax.plot(test_x, function(test_x), line_format, label=label)
ax.set_xlim(x_lim)
ax.set_xlabel("x")
ax.set_ylabel("f")
ax.legend()
fig.tight_layout()

# Definition of Custom Means with and without Hyperparameters

In [None]:
# mean without additional learnable parameters; its input normalization is
# created without explicit bounds
cm = CustomMean(mismatched_gt, Normalize(1), Standardize(1))

# mean with learnable linear transformations; its input normalization uses
# the fixed domain bounds as a (2, 1) double tensor: row 0 lower, row 1 upper
x_bounds = torch.tensor(x_lim, dtype=torch.double).reshape(2, 1)
pcm = LinearCalibration(
    mismatched_gt,
    Normalize(1, bounds=x_bounds),
    Standardize(1),
    x_dim=1,
)

# order matters: later cells index models by their position in this list
custom_means = [cm, pcm]

In [None]:
def create_gp(train_x, train_y, mean_module, noise=1e-4):
    """
    Builds a SingleTaskGP that uses a custom prior mean and a fixed noise level.

    The input and outcome transforms of the GP are taken from the mean module
    (``mean_module.input_transformer`` / ``mean_module.outcome_transformer``),
    so model and prior mean share the same transform objects.

    Args:
        train_x (torch.Tensor): Training inputs, passed on to SingleTaskGP.
        train_y (torch.Tensor): Training targets, passed on to SingleTaskGP.
        mean_module: Prior mean module exposing ``input_transformer`` and
            ``outcome_transformer`` attributes.
        noise (float): Fixed likelihood noise level. Defaults to 1e-4.

    Returns:
        SingleTaskGP: The model with its likelihood noise excluded from training.
    """
    gp = SingleTaskGP(
        train_x,
        train_y,
        mean_module=mean_module,
        input_transform=mean_module.input_transformer,
        outcome_transform=mean_module.outcome_transformer,
    )
    # create the noise value in the data dtype so that it is not rounded
    # through float32 before being written into the model
    gp.likelihood.noise = torch.tensor(noise, dtype=train_x.dtype)
    # freeze the underlying raw parameter: `likelihood.noise` is a property
    # whose return value is only the parameter itself when the noise
    # constraint applies no transform, so toggling requires_grad on it is fragile
    gp.likelihood.noise_covar.raw_noise.requires_grad_(False)
    return gp

In [None]:
# number of steps for BO and number of models compared side by side
n_steps = 5
n_models = len(custom_means)

# seed torch's global RNG so that the randomized initialization of the
# acquisition function optimization is repeatable across runs
torch.manual_seed(0)

# create plotting axes: one row per model, one column per BO step
# (squeeze=False keeps `ax` two-dimensional for any grid size)
fig, ax = plt.subplots(n_models, n_steps, sharex="all", sharey="all", squeeze=False)
fig.set_size_inches(12, 5)

# generate initial data: every model starts from the same single observation,
# row j of train_x / train_y holds the data set of model j
train_x = torch.tensor([0.6]).repeat(n_models, 1).double()
train_y = mismatched_gt.ground_truth(train_x)

# search bounds and evaluation grid in the dtype of the training data
bounds = torch.tensor(x_lim, dtype=train_x.dtype).reshape(2, 1)
grid_x = test_x.to(train_x.dtype)

for i in range(n_steps):
    # create GP models
    gps = []
    for j, custom_mean in enumerate(custom_means):
        gp = create_gp(train_x[j].unsqueeze(-1), train_y[j].unsqueeze(-1), custom_mean)
        gps.append(gp)

    # maximum likelihood fits
    for gp in gps:
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_mll(mll)

    candidates = []
    for j, gp in enumerate(gps):
        # create and optimize acquisition function using the GP
        acq = ExpectedImprovement(gp, train_y[j].max())
        candidate, _ = optimize_acqf(
            acq_function=acq,
            bounds=bounds,
            q=1,
            num_restarts=5,
            raw_samples=20,
        )
        candidates.append(candidate)

        # posterior and acquisition function
        with torch.no_grad():
            # get GP posterior
            post = gp.posterior(grid_x.unsqueeze(-1))

            # get posterior means and confidence regions
            mean = post.mean.flatten()
            l, u = post.mvn.confidence_region()

            # get acquisition function values (one q=1 batch per grid point)
            acq_val = acq(grid_x.reshape(-1, 1, 1))

        # plot GP model, data and ground truth
        panel = ax[j, i]
        panel.plot(test_x, mean, label="posterior mean")
        panel.fill_between(test_x, l.squeeze(), u.squeeze(), alpha=0.25, label="confidence region")
        panel.plot(train_x[j], train_y[j], "oC2", zorder=10)
        panel.plot(test_x, mismatched_gt.ground_truth(test_x), "C3--", zorder=1, label="ground truth")
        panel.set_xlim(x_lim)
        panel.set_xticks([0.0, 0.5, 1.0])
        if j == 0:
            panel.set_title(f"step {i+1}")
        if j == len(gps) - 1:
            panel.set_xlabel("x")

        # plot the acquisition function on a secondary axis; the y-range is
        # stretched so that the curve only fills the bottom of the panel
        ax2 = panel.twinx()
        ax2.plot(test_x, acq_val, "C1")
        ax2.fill_between(test_x, torch.zeros_like(acq_val), acq_val, alpha=0.5, fc="C1")
        ax2.set_yticklabels([])
        ax2.set_ylim(0, 10.0 * acq_val.max())

        # plot the maximum point of the acquisition function (on the grid),
        # with the marker drawn slightly above the curve
        best = torch.argmax(acq_val)
        ax2.plot(test_x[best], acq_val[best] * 1.7, marker="v", ms=10)

    # add maximization points and resulting observations to training data;
    # each candidate has shape (1, 1), so stacking gives (n_models, 1)
    new_x = torch.cat(candidates, dim=0).to(train_x)
    train_x = torch.cat([train_x, new_x], dim=-1)
    train_y = mismatched_gt.ground_truth(train_x)

for row in range(n_models):
    ax[row, 0].set_ylabel("f")

fig.tight_layout()

# Learned Hyperparameters

In [None]:
# Compare the learned calibration parameters with the mismatch built into the
# prior mean. The calibrated mean is referenced explicitly as `pcm` instead of
# through a loop variable left over from an earlier cell, so this cell does
# not depend on which model that loop happened to visit last.
# Shifts are negated and scales inverted before printing, presumably because
# the calibration is meant to undo the mismatch -- confirm against
# LinearCalibration. The true mismatch value is shown in parentheses.
to_mismatch_convention = {
    "x_shift": lambda param: -param,
    "x_scale": lambda param: 1 / param,
    "y_shift": lambda param: -param,
    "y_scale": lambda param: 1 / param,
}
for name, convert in to_mismatch_convention.items():
    # parameters the mean does not define are skipped silently
    if hasattr(pcm, name):
        learned = convert(getattr(pcm, name))
        reference = getattr(mismatched_gt, name)
        print("learned {}: {:.2f} ({:.2f})".format(name, *learned, *reference))