In [1]:
import pandas as pd
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from autorocks.viz import plots_setup
import torch
import gpytorch
import botorch.posteriors

class Forrester(botorch.test_functions.SyntheticTestFunction):
    """One-dimensional Forrester test function on the unit interval.

    Evaluates ``f(x) = (6x - 2)^2 * sin(12x - 4)`` element-wise, so the output
    has the same shape as the input tensor.
    """

    dim = 1
    _bounds = [(0.0, 1.0)]
    # Global minimum of f on [0, 1]: f(x*) ~= -6.0207 at x* ~= 0.7572
    # (f(0.78) is only about -5.73, so 0.78 / -6.0 were rough approximations).
    _optimal_value = -6.0207
    # One tuple of length `dim` per optimizer. The trailing comma matters:
    # `(0.78)` is a plain float, not a 1-tuple.
    _optimizers = [(0.7572,)]

    def evaluate_true(self, X: torch.Tensor) -> torch.Tensor:
        """Return the noiseless function value for every entry of ``X``."""
        return torch.pow(6 * X - 2, 2) * torch.sin(12 * X - 4)


# negate=True is passed through to SyntheticTestFunction.
problem = Forrester(negate=True)



In [3]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from botorch import models
from botorch.optim import fit

# Base look: ggplot style, then seaborn "ticks" theme without top/right spines.
plt.style.use("ggplot")
sns.set_theme(
    style="ticks",
    rc={"axes.spines.right": False, "axes.spines.top": False},
)
sns.set_context("paper")  # , font_scale=1.5, rc={"lines.linewidth": 1.5})

# Font / export settings applied on top of the theme in a single update.
plt.rcParams.update(
    {
        "svg.fonttype": "none",
        "font.family": "Arial",
        "text.usetex": False,
        "xtick.labelsize": "large",
        "ytick.labelsize": "large",
        "axes.labelsize": "large",
        "pdf.use14corefonts": True,
    }
)



class ParametricMean(gpytorch.means.Mean):
    """GP prior mean of the form ``(x @ w_pow + b_pow)^2 * sin(x @ w_sin + b_sin)``.

    All four parameters are learnable and initialised from a standard normal.

    Args:
        input_size: Size of the last (feature) dimension of the inputs.
        batch_shape: Leading batch dimensions prepended to every parameter.
    """

    def __init__(self, input_size, batch_shape=torch.Size()):
        super().__init__()
        weight_shape = (*batch_shape, input_size, 1)
        bias_shape = (*batch_shape, 1)
        # Registration order is kept (weights first, then biases) so the random
        # draws consume the RNG in the same sequence as before.
        self.register_parameter(name="weight_pow", parameter=torch.nn.Parameter(torch.randn(*weight_shape)))
        self.register_parameter(name="weight_sin", parameter=torch.nn.Parameter(torch.randn(*weight_shape)))
        self.register_parameter(name="bias_pow", parameter=torch.nn.Parameter(torch.randn(*bias_shape)))
        self.register_parameter(name="bias_sin", parameter=torch.nn.Parameter(torch.randn(*bias_shape)))

    def forward(self, x):
        """Evaluate the mean at ``x`` (``... x n x input_size``) and return ``... x n``."""
        # Project the features with matmul (as gpytorch's LinearMean does) rather
        # than an element-wise product: the two agree for input_size == 1, but
        # only matmul is shape-correct for input_size > 1.
        # The bias gets a trailing dim so it lines up with the `... x n x 1`
        # projection when batch_shape is non-empty.
        pow_arg = x.matmul(self.weight_pow) + self.bias_pow.unsqueeze(-1)
        sin_arg = x.matmul(self.weight_sin) + self.bias_sin.unsqueeze(-1)
        return (torch.pow(pow_arg, 2) * torch.sin(sin_arg)).squeeze(-1)

# Candidate GP prior means, keyed by the label stored in the "Mean" column.
mean_functions = dict(
    Zero=gpytorch.means.ZeroMean(),
    Constant=gpytorch.means.ConstantMean(),
    Linear=gpytorch.means.LinearMean(1),
    Parametric=ParametricMean(1),
)

res = {}

# Full pool of observations; later cells fit on growing prefixes of it.
num_observations = 10
train_x_full, train_y_full = plots_setup.generate_data(num_observations, problem)


In [8]:
# Fit one GP per (mean function, number of observations) pair and collect the
# results as long-format records for the FacetGrid plot.
OBSERVATION_COUNTS = [1, 3, 7, 9]
# torch.linspace requires `steps` in current PyTorch; 100 was its old implicit default.
NUM_TEST_POINTS = 100

predicted = []
for func, mean_module in mean_functions.items():
    for i in OBSERVATION_COUNTS:
        train_x = train_x_full[:i]
        train_y = train_y_full[:i]
        torch.cuda.empty_cache()
        model = models.FixedNoiseGP(
            train_X=train_x,
            train_Y=train_y,
            train_Yvar=torch.zeros_like(train_y),
            # Fitting updates the mean module's parameters in place. Give every
            # model its own copy so one fit does not warm-start the next and
            # re-running this cell starts from the same state.
            mean_module=copy.deepcopy(mean_module),
        )
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
        fit.fit_gpytorch_scipy(mll)

        test_x = torch.linspace(0, 1, NUM_TEST_POINTS)
        cr = plots_setup.predict(model=model, test_x=test_x, observation_noise=False)

        # Columns shared by every record produced for this fit.
        tags = {"Obs": i, "Mean": func}

        sorted_x = np.argsort(test_x)
        x_vals = test_x[sorted_x].squeeze().cpu().numpy().tolist()
        mean_vals = cr.mean[sorted_x].squeeze().cpu().numpy().tolist()
        lower_vals = cr.lower[sorted_x].squeeze().cpu().numpy().tolist()
        upper_vals = cr.upper[sorted_x].squeeze().cpu().numpy().tolist()
        # Mean, lower and upper all go into the same "y" column (three rows per
        # x, in that order).
        for x_val, *band_vals in zip(x_vals, mean_vals, lower_vals, upper_vals):
            predicted.extend({"x": x_val, "y": y_val, **tags} for y_val in band_vals)

        # Observations this model was actually trained on.
        for observed_x, observed_y in zip(train_x.cpu().numpy().tolist(),
                                          train_y.cpu().numpy().tolist()):
            predicted.append({"x_train": observed_x[0], "y_train": observed_y[0], **tags})

        # NOTE(review): the "truth" records are the full pool of generated
        # observations, not a dense evaluation of `problem` -- confirm intended.
        for x_truth, y_truth in zip(train_x_full.squeeze().cpu().numpy().tolist(),
                                    train_y_full.squeeze().cpu().numpy().tolist()):
            predicted.append({"x_truth": x_truth, "y_truth": y_truth, **tags})

df = pd.DataFrame(predicted)
df

In [9]:
import numpy as np

DPI = 300  # default dpi for most printers

# Re-apply the theme with larger fonts and thicker lines for the faceted figure.
plt.style.use("ggplot")
sns.set_theme(
    style="ticks",
    rc={"axes.spines.right": False, "axes.spines.top": False},
)
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1.5})
plt.rcParams.update(
    {
        "svg.fonttype": "none",
        "font.family": "Arial",
        "text.usetex": False,
        "xtick.labelsize": "small",
        "ytick.labelsize": "small",
        "axes.labelsize": "medium",
        "pdf.use14corefonts": True,
    }
)

# One row per mean function, one column per number of observations.
# https://seaborn.pydata.org/examples/many_facets.html
grid = sns.FacetGrid(df, row="Mean", col="Obs", hue="Mean", palette="hls")

# FacetGrid.map returns the grid itself, so the layers are added without rebinding.
grid.map(
    sns.lineplot,
    "x_truth",
    "y_truth",
    label="True Function",
    alpha=0.4,
    color="black",
    linestyle="--",
)
grid.map(
    sns.scatterplot,
    "x_train",
    "y_train",
    color="black",
    marker="*",
    label="Observed",
    s=48,
)
# Snapshot of the legend entries before the prediction layer is added.
observed_func_legend = grid._legend_data.copy()
grid.map(sns.lineplot, "x", "y", label="prediction", alpha=0.7)

grid.set(xlim=(0, 1), ylim=(-6, 7))
grid.add_legend()

# When saving the figure, a legend placed outside the axes can be cut off; see
# https://stackoverflow.com/questions/10101700/moving-matplotlib-legend-outside-of-the-axis-makes-it-cutoff-by-the-figure-box

In [10]:
# Destination directory and file format for the exported figure.
output_location = "./"
output_format = "svg"

# bbox_inches="tight" is passed so the saved area is fitted to the drawn content.
figure_path = f"{output_location}/mean_funcs.{output_format}"
grid.savefig(figure_path, bbox_inches="tight", format=output_format, dpi=300)