# Imports

In [None]:
%matplotlib inline
from matplotlib import colors
from pyDOE import lhs
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import sys

# Make the repository root (the parent directory) importable so the local packages resolve.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from itertools import product
from multifidelityfunctions import oneDimensional as OD
from multifidelityfunctions import MultiFidelityFunction, row_vectorize
from multiLevelCoSurrogates import CandidateArchive, Surrogate, HierarchicalSurrogate, MultiFidelityBO, create_random_sample_set
from multiLevelCoSurrogates.bo import ScatterPoints
from multiLevelCoSurrogates.Utils import Surface, createsurface, plotsurfaces, plotsurfaceonaxis, plotcmaponaxis
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.ensemble import RandomForestRegressor

np.random.seed(seed=20160501)  # fix NumPy's global RNG so every random sample below is reproducible

Print settings/helpers

In [None]:
# Output helpers for the long-running experiment loops below.
# `IPython.display` is the public home of clear_output; `IPython.core.display` is an internal/deprecated path.
from IPython.display import clear_output
from pprint import pprint

np.set_printoptions(linewidth=200)  # print wide arrays on a single line
plot_dir = '../multiLevelCoSurrogates/plots/'  # all saved figures go to this directory

# Recreating the example plot in [Forrester2007 (Multi-fidelity optimization via surrogate modelling)](https://royalsocietypublishing.org/doi/full/10.1098/rspa.2007.1900)

<img src="https://royalsocietypublishing.org/cms/attachment/efa57e07-5384-4503-8b2b-ccbe632ffe87/3251fig1.jpg" alt="Forrester2007 example plot" width="400"/>

## Step by step construction

The function in question:

In [None]:
# Dense column-vector grid used to draw the true functions and the model predictions.
plot_x = np.linspace(0, 1, 501)[:, np.newaxis]

# The paper's design: 11 evenly spaced low-fidelity points, of which those at
# indices 0, 4, 6 and 10 (x = 0, .4, .6, 1) are also evaluated in high fidelity.
low_x = np.linspace(0, 1, 11)[:, np.newaxis]
high_x = low_x[[0, 4, 6, 10], :]

In [None]:
# True high- and low-fidelity curves on the dense grid.
plot_high, plot_low = OD.high(plot_x), OD.low(plot_x)

for curve, label in ((plot_high, 'high'), (plot_low, 'low')):
    plt.plot(plot_x, curve, label=label)
plt.legend(loc=1)
plt.show()

Showing the datapoints selected by the paper.

In [None]:
# Evaluate both fidelities at the paper's sample locations.
high_y = OD.high(high_x)
low_y = OD.low(low_x)

for curve, xs, ys, label in ((plot_high, high_x, high_y, 'high'),
                             (plot_low, low_x, low_y, 'low')):
    line, = plt.plot(plot_x, curve, label=label)
    plt.scatter(xs, ys, color=line.get_color())  # sample points share their curve's colour
plt.legend(loc=1)
plt.show()

Training Gaussian Process models for each fidelity exclusively. Low-fidelity is a good fit, high fidelity is not.

In [None]:
# One GP per fidelity, each trained exclusively on its own data (fit returns the estimator).
gp_direct = GaussianProcessRegressor().fit(high_x, high_y)
gp_low = GaussianProcessRegressor().fit(low_x, low_y)

for curve, xs, ys, label in ((plot_high, high_x, high_y, 'high'),
                             (plot_low, low_x, low_y, 'low')):
    line, = plt.plot(plot_x, curve, label=label)
    plt.scatter(xs, ys, color=line.get_color())
for model, label in ((gp_direct, 'high-fit GP'), (gp_low, 'low-fit GP')):
    plt.plot(plot_x, model.predict(plot_x), label=label)
plt.legend(loc=1)
plt.show()

Co-Kriging formulation is $\hat{f}_h(x) = \rho * f_l(x) + \delta(x)$. <br>
$\hat{f}_h(x)$ is the high-fidelity prediction at $x$<br>
$\rho$ is a scaling factor<br>
$f_l(x)$ is a low-fidelity information input (either actual or another model) at $x$<br>
$\delta(x)$ is a prediction for the difference between $f_h(x)$ and $\rho * f_l(x)$<br>

$\rho$ is estimated as $\rho = \left(\frac{1}{n}\sum_{i=1}^{n} \frac{f_h(x_i)}{f_l(x_i)}\right)^{-1}$, i.e. `1/mean(f_high(x_high) / f_low(x_high))`, where `x_high` are all $n$ inputs for which we have high-fidelity outcomes.

Here we start by plotting just the parts of this equation.<br>
In this example, there is an explicit scaling factor of __2__ between high and low fidelity that is seen to be easily captured by the difference model $\delta(x)$, i.e. `gp_diff`

In [None]:
# Low-fidelity function values at the high-fidelity inputs, as a column vector.
low_at_high = np.array(OD.low(list(high_x))).reshape(-1, 1)
# Scaling factor rho, estimated from the points where both fidelities are known.
scale = 1 / np.mean(high_y / low_at_high)

# Training data for the difference model delta(x) = f_h(x) - rho * f_l(x).
diff_x = high_x
diff_y = np.array([OD.high(row) - scale * OD.low(row) for row in diff_x])
gp_diff = GaussianProcessRegressor().fit(diff_x, diff_y)

for curve, xs, ys, label in ((plot_high, high_x, high_y, 'high'),
                             (plot_low, low_x, low_y, 'low')):
    line, = plt.plot(plot_x, curve, label=label)
    plt.scatter(xs, ys, color=line.get_color())
for model, label in ((gp_direct, 'high-fit GP'), (gp_low, 'low-fit GP')):
    plt.plot(plot_x, model.predict(plot_x), label=label)
plt.plot(plot_x, plot_high - plot_low, label='diff')
plt.plot(plot_x, gp_diff.predict(plot_x), label='scaled diff-fit GP')
plt.legend(loc=1)
plt.show()

In [None]:
# Show the estimated scaling factor rho (Jupyter renders the cell's last expression).
scale

The `scale` parameter here is an estimate based on the data points we have. With only four high-fidelity points in this example, it is a reasonable but not exact estimate. According to the function definition the actual value should be 2, while the paper states that 1.87 gives the best match over the x-range [0, 1].

And now with the actual co-kriging prediction plotted.

In [None]:
def co_y(x):
    """Co-kriging prediction: rho times the low-fidelity GP plus the difference GP."""
    return scale * gp_low.predict(x) + gp_diff.predict(x)


for curve, xs, ys, label in ((plot_high, high_x, high_y, 'high'),
                             (plot_low, low_x, low_y, 'low')):
    line, = plt.plot(plot_x, curve, label=label)
    plt.scatter(xs, ys, color=line.get_color())
for model, label in ((gp_direct, 'high-fit GP'), (gp_low, 'low-fit GP')):
    plt.plot(plot_x, model.predict(plot_x), label=label)
plt.plot(plot_x, co_y(plot_x), label='co-kriging')
plt.legend(loc=1)
plt.show()

## Direct construction with (Hierarchical)Surrogate

Recreating the same plot as above using our own (Hierarchical)Surrogate interface.

In [None]:
# The archive is shared by all surrogate variants below, so it is built only once.
archive = CandidateArchive(ndim=1, fidelities=['high', 'low', 'high-low'])
for xs, ys, fidelity in ((low_x, low_y, 'low'), (high_x, high_y, 'high')):
    archive.addcandidates(xs, ys, fidelity=fidelity)

### Without normalization by Surrogate

In [None]:
# Independent Kriging surrogates per fidelity, plus a hierarchical (co-kriging) surrogate
# stacked on the low-fidelity one; none of them normalize internally.
surr_high = Surrogate.fromname('Kriging', archive, fidelity='high', normalized=False)
surr_low = Surrogate.fromname('Kriging', archive, fidelity='low', normalized=False)
surr_hier = HierarchicalSurrogate('Kriging', surr_low, archive, ['high', 'low'], normalized=False)

for surrogate in (surr_high, surr_low, surr_hier):
    surrogate.train()

# Plotting: true functions first, then the three surrogate predictions.
for fn, label in ((OD.high, 'high'),
                  (OD.low, 'low'),
                  (surr_high.predict, 'high-fit GP'),
                  (surr_low.predict, 'low-fit GP'),
                  (surr_hier.predict, 'co-kriging')):
    plt.plot(plot_x, fn(plot_x), label=label)
plt.legend(loc=0)
plt.tight_layout()
plt.show()

### With normalization by Surrogate

Just to show that the normalization is correctly implemented.<br>
Because of the values in this example, it's not really needed, but if the results at least don't get worse in this case, it's probably correct.

In [None]:
# Same three surrogates as before, but letting each one normalize its data internally.
surr_high = Surrogate.fromname('Kriging', archive, fidelity='high', normalized=True)
surr_low = Surrogate.fromname('Kriging', archive, fidelity='low', normalized=True)
surr_hier = HierarchicalSurrogate('Kriging', surr_low, archive, ['high', 'low'], normalized=True)

for surrogate in (surr_high, surr_low, surr_hier):
    surrogate.train()

# Plotting
for fn, label in ((OD.high, 'high'),
                  (OD.low, 'low'),
                  (surr_high.predict, 'high-fit GP'),
                  (surr_low.predict, 'low-fit GP'),
                  (surr_hier.predict, 'co-kriging')):
    plt.plot(plot_x, fn(plot_x), label=label)
plt.legend(loc=0)
plt.tight_layout()
plt.show()

## Direct construction with MultiFidelityBO

Recreating the same plot again with the MultiFidelityBO (Bayesian Optimization) interface.<br>
This interface automatically creates a full set of hierarchical models for any number of fidelities.

In [None]:
# MultiFidelityBO builds the direct models and the hierarchical models from the archive.
mfbo = MultiFidelityBO(OD, archive, output_range=(-10, 16))

# Plotting
plt.plot(plot_x, OD.high(plot_x), label='high')
plt.plot(plot_x, OD.low(plot_x), label='low')
plt.plot(plot_x, mfbo.direct_models['high'].predict(plot_x), label='high-fit GP')
plt.plot(plot_x, mfbo.models['low'].predict(plot_x), label='low-fit GP')
plt.plot(plot_x, mfbo.models['high'].predict(plot_x), label='co-kriging')
plt.legend(loc=0)
plt.tight_layout()
# Explicit extensions, matching the other figure cells (PNG for viewing, PDF for print).
plt.savefig(f'{plot_dir}forrester2007_recreated.png')
plt.savefig(f'{plot_dir}forrester2007_recreated.pdf')
plt.show()

## Making the match exact

We make two changes to the procedure to really recreate the plot:
 1. Using $f_l(x)$ directly rather than model $\hat{f}_l(x)$
 2. Using better scaling values. `1.87` gives the match seen in the original picture, while `2` gives a perfect match 

The first change is worth adopting in practice as well. When predicting $\hat{f}_h(x)$ for a completely new point $x$, the lower-fidelity models are the only available source of information. When selecting which point to evaluate in higher fidelity, however, the exact lower-fidelity information is usually available and can therefore be used.

The value `1.87` comes from taking the mean over the entire range rather than just the 4 common datapoints we have, while the value `2` is derived from the function definition.

In [None]:
1 / (plot_high / plot_low).mean()  # scaling factor estimated over the whole plotting grid

In [None]:
# Difference GPs for two fixed scaling factors: the exact value 2 and the paper's 1.87.
gp_diff_20, gp_diff_187 = [
    GaussianProcessRegressor().fit(diff_x, np.array([OD.high(x) - rho * OD.low(x) for x in diff_x]))
    for rho in (2, 1.87)
]


# Co-kriging using the exact low-fidelity function rather than its GP model.
def cokriging_y_20(x):
    return 2 * OD.low(x) + gp_diff_20.predict(x)


def cokriging_y_187(x):
    return 1.87 * OD.low(x) + gp_diff_187.predict(x)


for curve, xs, ys, label in ((plot_high, high_x, high_y, 'high'),
                             (plot_low, low_x, low_y, 'low')):
    line, = plt.plot(plot_x, curve, label=label)
    plt.scatter(xs, ys, color=line.get_color())
for model, label in ((gp_direct, 'high-fit GP'), (gp_low, 'low-fit GP')):
    plt.plot(plot_x, model.predict(plot_x), label=label)
plt.plot(plot_x, cokriging_y_20(plot_x), label='co-kriging (2)')
plt.plot(plot_x, cokriging_y_187(plot_x), label='co-kriging (1.87)')
plt.legend(loc=0)
plt.tight_layout()
for extension in ('png', 'pdf'):
    plt.savefig(f'{plot_dir}accurate_forrester2007.{extension}')
plt.show()

### Side by side comparison
<img src="https://royalsocietypublishing.org/cms/attachment/efa57e07-5384-4503-8b2b-ccbe632ffe87/3251fig1.jpg" alt="Forrester2007 example plot" width="362"/><img src="../multiLevelCoSurrogates/plots/accurate_forrester2007.png" alt="Recreated Forrester2007 example plot"/>

# Trade-off heatmap: number of high- vs. low-fidelity points

This section covers an experiment on the influence of the number of low-fidelity points in a co-surrogate setup.

Let $n_L$ be the number of low-fidelity points and $n_H$ the number of high-fidelity points. Create a sample $x_L$ of $n_L$ points using some initial sampling method (random, LHS, grid, etc), and take from that a subsample $x_H \subset x_L$ through some heuristic (maximal distance, random, etc). Then we train a number of models:
 - direct low-fidelity model using $x_L, f_L(x_L)$ only
 - direct high-fidelity model using $x_H, f_H(x_H)$ only
 - hierarchical high-fidelity model using both $x_L, f_L(x_L)$ and $x_H, f_H(x_H)$
 
Independently, a function-dependent sample $x_{mse}$ of size 1000 is also created. This sample is used to calculate a Mean Squared Error (MSE) value for the state of a model after training.

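For the experiments, we examine all combinations of $n_L \in \{3, \ldots, 100\}$ and $n_H \in \{2, \ldots, 40\}$, with the restriction that $n_L > n_H$. Each combination is repeated 30 times. Note that the configuration cell below uses much smaller values (`max_high = 10`, `max_low = 10`, `num_reps = 1`) so that the notebook runs quickly. To reproduce the full experiment, set them to 40, 100 and 30 respectively.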

In [None]:
# Grid limits for the sample-size experiment, deliberately small so the notebook runs quickly.
# The full experiment described above uses max_high=40, max_low=100, num_reps=30.
max_high, max_low, num_reps = 10, 10, 1

In [None]:
def low_random_sample(ndim, nlow):
    """Draw ``nlow`` points uniformly from [0, 1)^ndim using NumPy's global RNG.

    Returns an array of shape (nlow, ndim).
    """
    return np.random.random_sample(size=(nlow, ndim))

def low_lhs_sample(ndim, nlow):
    """Space-filling sample of ``nlow`` points in the unit hypercube [0, 1]^ndim.

    For ``ndim == 1`` an evenly spaced grid including both end points is returned;
    for higher dimensions a Latin Hypercube Sample from pyDOE's ``lhs`` is used.

    Returns:
        Array of shape (nlow, ndim).

    Raises:
        ValueError: if ``ndim`` is smaller than 1 (previously this silently returned None).
    """
    if ndim < 1:
        raise ValueError(f'ndim must be at least 1, got {ndim}')
    if ndim == 1:
        return np.linspace(0, 1, nlow).reshape(-1, 1)
    return lhs(ndim, nlow)

In [None]:
def create_mse_tracking(func, sample_generator, output_range=(-10, 16),
                        high_limit=None, low_limit=None, reps=None):
    """Run the high- vs. low-fidelity sample-size experiment for ``func``.

    For every ``num_high`` in [2, high_limit], ``num_low`` in [3, low_limit] with
    ``num_high < num_low``, and every repetition, a low-fidelity sample is drawn with
    ``sample_generator(ndim, num_low)``. A random subset of ``num_high`` of those points
    is also evaluated in high fidelity, and the values returned by
    ``MultiFidelityBO.getMSE()`` are stored.

    Args:
        func: multi-fidelity function exposing ``ndim``, ``low`` and ``high``.
        sample_generator: callable ``(ndim, n)`` returning ``n`` sample points.
        output_range: forwarded to ``MultiFidelityBO``; defaults to the previously
            hard-coded value, so existing calls behave as before.
        high_limit, low_limit, reps: experiment grid size. ``None`` falls back to the
            notebook globals ``max_high``, ``max_low`` and ``num_reps``.

    Returns:
        Array of shape (high_limit+1, low_limit+1, reps, 3). Entries for combinations
        that were not run are NaN.
    """
    # Resolve defaults at call time so later changes to the globals are honoured.
    high_limit = max_high if high_limit is None else high_limit
    low_limit = max_low if low_limit is None else low_limit
    reps = num_reps if reps is None else reps

    ndim = func.ndim
    mse_tracking = np.full((high_limit + 1, low_limit + 1, reps, 3), np.nan)
    cases = list(product(range(2, high_limit + 1), range(3, low_limit + 1), range(reps)))

    for idx, (num_high, num_low, rep) in enumerate(cases):
        if num_high >= num_low:
            continue  # the high-fidelity subset must be strictly smaller than the low-fidelity sample
        if idx % 100 == 0:
            clear_output()
            print(f'{idx}/{len(cases)}')

        low_x = sample_generator(ndim, num_low)
        high_x = low_x[np.random.choice(num_low, num_high, replace=False)]

        archive = CandidateArchive(ndim=ndim, fidelities=['high', 'low', 'high-low'])
        archive.addcandidates(low_x, func.low(low_x), fidelity='low')
        archive.addcandidates(high_x, func.high(high_x), fidelity='high')

        mfbo = MultiFidelityBO(func, archive, output_range=output_range)
        mse_tracking[num_high, num_low, rep] = mfbo.getMSE()

    clear_output()
    print(f'{len(cases)}/{len(cases)}')
    return mse_tracking

In [None]:
def plot_high_vs_low_num_samples(data, name, vmin=.5, vmax=100):
    """Plot the sample-size experiment as three stacked heatmaps and save them as PDF.

    Args:
        data: array indexed as ``[num_high, num_low, model]``. Model indices 0, 1 and 2
            are drawn as 'high (hierarchical)', 'high (direct)' and 'low (direct)'.
        name: file name (without extension) written under ``plot_dir``.
        vmin, vmax: limits of the shared logarithmic colour scale; values outside are clipped.
    """
    norm = colors.LogNorm(vmin=vmin, vmax=vmax, clip=True)
    titles = ['high (hierarchical)', 'high (direct)', 'low (direct)']

    # Create the three panels explicitly so no stray full-figure axes sits behind them,
    # and let constrained layout make room for the shared colorbar.
    fig, axes = plt.subplots(3, 1, figsize=(9, 9), constrained_layout=True)
    for model_idx, (ax, title) in enumerate(zip(axes, titles)):
        img = ax.imshow(data[:, :, model_idx], cmap='viridis_r', norm=norm)
        ax.set_title(title)

    axes[1].set_ylabel('#High-fid samples')
    axes[-1].set_xlabel('#Low-fid samples')
    fig.colorbar(img, ax=axes, orientation='vertical')  # all panels share `norm`
    fig.savefig(f'{plot_dir}{name}.pdf')
    plt.show()

## Random Sample generation

In [None]:
# The experiment is expensive, so results are cached in the working directory.
if 'mse_tracking.npy' not in os.listdir('.'):
    mse_tracking = create_mse_tracking(OD, low_random_sample)
    np.save('mse_tracking.npy', mse_tracking)
else:
    mse_tracking = np.load('mse_tracking.npy')

In [None]:
# 95th-100th percentiles of the median (over repetitions) MSE.
# Combinations that were never run (n_H >= n_L) are NaN; np.percentile would return NaN
# for every percentile, so the NaN-aware variant is required.
print('median')
pprint([(f'{95+i}%-ile', np.nanpercentile(np.median(mse_tracking, axis=2).flatten(), 95+i)) for i in range(6)])

In [None]:
plot_high_vs_low_num_samples(data=np.median(mse_tracking, axis=2), name='high-low-samples-random')

## Linspace, random subsample generation

In [None]:
# Cached results for the evenly spaced (1D) / LHS sampling variant.
if 'lin_mse_tracking.npy' not in os.listdir('.'):
    lin_mse_tracking = create_mse_tracking(OD, low_lhs_sample)
    np.save('lin_mse_tracking.npy', lin_mse_tracking)
else:
    lin_mse_tracking = np.load('lin_mse_tracking.npy')

In [None]:
# 95th-100th percentiles of the median MSE; nanpercentile skips the NaN entries of
# combinations that were never run (np.percentile would report NaN everywhere).
print('median')
pprint([(f'{95+i}%-ile', np.nanpercentile(np.median(lin_mse_tracking, axis=2).flatten(), 95+i)) for i in range(6)])

In [None]:
plot_high_vs_low_num_samples(data=np.median(lin_mse_tracking, axis=2), name='high-low-samples-linear')

The top-most plot here shows some strange vertical features that seem to occur when the number of low-fidelity sample points is a multiple of 8. We still have to look at what causes this...

### Todo
 - [ ] high (direct) and low (direct) plots have no reason to be a 2d heatmap: only one value has any impact anyway. Any other effect is simply through random sampling
 - [ ] For linear low-fidelity sampling: consider changing random to exhaustive sub-sample sets for small values of $n_L$ and $n_H$
 - [ ] Plot MSE difference between high (direct) and high (hierarchical) to show accuracy difference
 - [ ] Investigate factor-of-eight vertical lines of bad performance

# EGO - 1D function

First, we create an inverted (negated) function, since the BO implementation is currently hard-coded for maximization problems.

In [None]:
# BO maximizes, so both fidelities are negated to turn minimization into maximization.
def _inv_od_high(x):
    return -OD.high(x)


def _inv_od_low(x):
    return -OD.low(x)


inv_OD = MultiFidelityFunction(
    u_bound=np.array(OD.u_bound),
    l_bound=np.array(OD.l_bound),
    functions=[_inv_od_high, _inv_od_low],
    fidelity_names=['high', 'low'],
)

In [None]:
# Initial design: 6 evenly spaced low-fidelity points; the middle two are also evaluated in high fidelity.
low_x = np.linspace(0,1,6).reshape((-1,1))
high_x = low_x[[2,3]].reshape((-1,1))

archive = CandidateArchive(ndim=1, fidelities=['high', 'low', 'high-low'])
archive.addcandidates(low_x, inv_OD.low(low_x), fidelity='low')
archive.addcandidates(high_x, inv_OD.high(high_x), fidelity='high')

np.random.seed(20160501)
mfbo = MultiFidelityBO(inv_OD, archive, output_range=(-16, 10), schema=[1,1])

# Plotting: each panel shows the models *before* the BO iteration with the same index runs.
fig, axes = plt.subplots(3,3, figsize=(12,9))

for idx, ax in enumerate(axes.flatten()):
    # Mean and standard deviation of the co-kriging model, needed for the +-3 sigma band.
    # NOTE(review): assumes the hierarchical model supports predict(..., mode='std') like the
    # low-fidelity and diff models plotted below — confirm.
    plot_hier = np.asarray(mfbo.models['high'].predict(plot_x)).flatten()
    std_hier = np.asarray(mfbo.models['high'].predict(plot_x, mode='std')).flatten()

    line_1, = ax.plot(plot_x, inv_OD.high(plot_x), label='high')
    line_2, = ax.plot(plot_x, inv_OD.low(plot_x), label='low')
    line_high, = ax.plot(plot_x, mfbo.direct_models['high'].predict(plot_x), label='high-fit GP')
    line_hier, = ax.plot(plot_x, plot_hier, label='co-kriging')
    scat_2 = ax.scatter(*archive.getcandidates(fidelity='low'), color=line_2.get_color())
    scat_1 = ax.scatter(*archive.getcandidates(fidelity='high'), color=line_1.get_color())
    ax.fill_between(plot_x.flatten(), plot_hier - 3*std_hier, plot_hier + 3*std_hier,
                    alpha=.25, color=line_hier.get_color())

    # Secondary y-axis (clipped at 0) for the acquisition function and model standard deviations.
    ax2 = ax.twinx()
    line_acq, = ax2.plot(plot_x, mfbo.utility(plot_x, gp=mfbo.models['high'], y_max=archive.max['high']),
                         alpha=.5, label='acq', color='C4')
    line_std_low, = ax2.plot(plot_x, mfbo.models['low'].predict(plot_x, mode='std'),
                             alpha=.5, label='std low', color='C5', ls='--', )
    line_std_diff, = ax2.plot(plot_x, mfbo.models['high'].diff_model.predict(plot_x, mode='std'),
                              alpha=.5, label='std diff', color='C6', ls=':', )
    ax2.set_ylim(bottom=0)

    lines = [line_1, line_2, line_high, line_hier, line_std_low, line_std_diff, line_acq]

    ax.set_title(f'Iteration {idx}')
    ax.set_xlim([0, 1])
    ax.set_ylim([-16, 10])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax2.tick_params('y', colors='#555555')

    mfbo.iteration(idx)

plt.tight_layout()
plt.subplots_adjust(bottom=0.1)
# One shared legend in an invisible strip along the bottom of the figure.
ax = fig.add_axes([0,0, 1,.05])
ax.axis('off')
ax.legend(
    lines, [l.get_label() for l in lines], ncol=4,
    loc='upper center', bbox_to_anchor=(.5,1),
)

plt.savefig(f'{plot_dir}1D_BO.png')
plt.savefig(f'{plot_dir}1D_BO.pdf')
plt.show()

# Testing models on 2D functions

In [None]:
# Scatter styles: small red dots for high-fidelity points, hollow blue circles for low-fidelity points.
red_dot = dict(marker='.', color='red')
blue_circle = dict(marker='o', facecolors='none', color='blue')

## 2D version of 1D function

### Creating the 2D function

In [None]:
@row_vectorize
def td_inv_high(xx):
    """Negated 2D extension of the 1D high-fidelity function: -(f_h(x1) + f_h(x2))."""
    first, second = xx
    return -(OD.high(first) + OD.high(second))


@row_vectorize
def td_inv_low(xx):
    """Negated 2D extension of the 1D low-fidelity function: -(f_l(x1) + f_l(x2))."""
    first, second = xx
    return -(OD.low(first) + OD.low(second))


# NOTE(review): `OD.u_bound * 2` presumably repeats a list of 1D bounds into 2D bounds;
# if the bounds were NumPy arrays this would double the values instead — confirm.
TD_inv = MultiFidelityFunction(
    u_bound=np.array(OD.u_bound * 2),
    l_bound=np.array(OD.l_bound * 2),
    functions=[td_inv_high, td_inv_low],
    fidelity_names=['high', 'low'],
)

In [None]:
u_bound, l_bound = TD_inv.u_bound, TD_inv.l_bound
steps = [.025, .025]  # grid resolution in each dimension
surf_high, surf_low = [createsurface(fidelity_fn, u_bound=u_bound, l_bound=l_bound, step=steps)
                       for fidelity_fn in (TD_inv.high, TD_inv.low)]

In [None]:
true_surfaces = [surf_high, surf_low]  # high- and low-fidelity surfaces of TD_inv
plotsurfaces(true_surfaces)

### Creating models

In [None]:
n_low, n_high = 16, 6  # number of low- and high-fidelity samples for the 2D comparisons

In [None]:
def create_models_and_compare(func, low, high, output_range=(-16, 10)):
    """Train multi-fidelity models on the given samples and plot them next to the true surfaces.

    Args:
        func: multi-fidelity function exposing ``ndim``, ``low`` and ``high``.
        low: low-fidelity sample locations, one point per row.
        high: high-fidelity sample locations, one point per row.
        output_range: forwarded to ``MultiFidelityBO`` (default is the previously hard-coded value).

    Note:
        The reference surfaces ``surf_high``/``surf_low`` and the grid settings ``u_bound``,
        ``l_bound`` and ``steps`` are notebook globals computed for ``TD_inv``.
    """
    # Use the function's dimensionality, as create_mse_tracking does, instead of hard-coding 2.
    archive = CandidateArchive(ndim=func.ndim, fidelities=['high', 'low', 'high-low'])
    archive.addcandidates(low, func.low(low), fidelity='low')
    archive.addcandidates(high, func.high(high), fidelity='high')

    mfbo = MultiFidelityBO(func, archive, output_range=output_range, schema=[1,1])

    surf_high_model = createsurface(mfbo.models['high'].predict, u_bound=u_bound, l_bound=l_bound, step=steps)
    surf_low_model = createsurface(mfbo.models['low'].predict, u_bound=u_bound, l_bound=l_bound, step=steps)

    points_high = [ScatterPoints(*archive.getcandidates(fidelity='high'), red_dot)]
    points_low = [ScatterPoints(*archive.getcandidates(fidelity='low'), blue_circle)]

    # Same sample markers on the true surfaces and on the model surfaces.
    points = [points_high, points_low] * 2

    plotsurfaces([surf_high, surf_low, surf_high_model, surf_low_model], shape=(2,2),
                 titles=['high', 'low', 'high (model)', 'low (model)'], all_points=points)

As a first attempt, let's simply use all 2d combinations of the coordinates originally used for the example figure

In [None]:
# Rebuild the 1D coordinates of the original Forrester example (11 low-, 4 high-fidelity
# points) explicitly: `low_x`/`high_x` were reassigned by the 1D EGO section above.
example_low_x = np.linspace(0, 1, 11).reshape(-1, 1)
example_high_x = example_low_x[[0, 4, 6, 10]]

# All 2D combinations of those coordinates; the high grid is a subset of the low grid.
high_xy = np.array(list(product(example_high_x.flatten(), repeat=2)))
low_xy = np.array(list(product(example_low_x.flatten(), repeat=2)))

create_models_and_compare(TD_inv, low_xy, high_xy)

### With a random sample

In [None]:
low_xy = low_random_sample(ndim=2, nlow=n_low)
# The high-fidelity points are a random subset (without replacement) of the low-fidelity sample.
high_idx = np.random.choice(n_low, n_high, replace=False)
high_xy = low_xy[high_idx]

create_models_and_compare(TD_inv, low_xy, high_xy)

### With an LHS

In [None]:
low_xy = low_lhs_sample(ndim=2, nlow=n_low)
# Random subset (without replacement) of the LHS points, evaluated in high fidelity.
high_idx = np.random.choice(n_low, n_high, replace=False)
high_xy = low_xy[high_idx]

create_models_and_compare(TD_inv, low_xy, high_xy)

## MSE errors per sample size combination

### Random sampling

In [None]:
# Cached 2D results for random sampling (note: reuses the `mse_tracking` name).
if '2d_mse_tracking.npy' not in os.listdir('.'):
    mse_tracking = create_mse_tracking(TD_inv, low_random_sample)
    np.save('2d_mse_tracking.npy', mse_tracking)
else:
    mse_tracking = np.load('2d_mse_tracking.npy')

In [None]:
plot_data = np.median(mse_tracking, axis=2)  # median over repetitions

# nanpercentile skips the NaN entries of combinations that were never run;
# np.percentile would report NaN for every percentile.
print('median')
pprint([(f'{95+i}%-ile', np.nanpercentile(plot_data.flatten(), 95+i)) for i in range(6)])

In [None]:
# Distinct file name so the 1D random-sampling heatmap saved earlier is not overwritten.
plot_high_vs_low_num_samples(plot_data, '2d-high-low-samples-random', vmax=10000)

### LHS

In [None]:
# Cached 2D results for LHS sampling (note: reuses the `lin_mse_tracking` name).
if '2d_lin_mse_tracking.npy' not in os.listdir('.'):
    lin_mse_tracking = create_mse_tracking(TD_inv, low_lhs_sample)
    np.save('2d_lin_mse_tracking.npy', lin_mse_tracking)
else:
    lin_mse_tracking = np.load('2d_lin_mse_tracking.npy')

In [None]:
lin_plot_data = np.median(lin_mse_tracking, axis=2)  # median over repetitions

# nanpercentile skips the NaN entries of combinations that were never run;
# np.percentile would report NaN for every percentile.
print('median')
pprint([(f'{95+i}%-ile', np.nanpercentile(lin_plot_data.flatten(), 95+i)) for i in range(6)])

In [None]:
# Distinct file name so the 1D linear-sampling heatmap saved earlier is not overwritten.
plot_high_vs_low_num_samples(lin_plot_data, '2d-high-low-samples-linear', vmax=3000)

# EGO on 2D functions

In [None]:
# Initial 2D design: all coordinate pairs of the current 1D `high_x`/`low_x` samples.
high_xy = np.array(list(product(high_x.flatten(), repeat=2)))
low_xy = np.array(list(product(low_x.flatten(), repeat=2)))

archive = CandidateArchive(ndim=2, fidelities=['high', 'low', 'high-low'])
archive.addcandidates(low_xy, TD_inv.low(low_xy), fidelity='low')
archive.addcandidates(high_xy, TD_inv.high(high_xy), fidelity='high')

mfbo = MultiFidelityBO(TD_inv, archive, output_range=(-16, 10), schema=[2,1])

# One colour-map panel of the high-fidelity model per step, drawn before that step's iterations.
fig, axes = plt.subplots(3, 3, figsize=(18, 20))
for idx, ax in enumerate(axes.flat):
    surface = createsurface(mfbo.models['high'].predict, u_bound=u_bound, l_bound=l_bound, step=steps)
    title = f'high model - iteration {idx}'
    points = [
        ScatterPoints(*archive.getcandidates(fidelity='high'), red_dot),
        ScatterPoints(*archive.getcandidates(fidelity='low'), blue_circle),
    ]
    plotcmaponaxis(ax, surface, title, points)
    # NOTE(review): two BO iterations run per panel, both with the same idx — confirm this is
    # intended (the panel title suggests a single iteration per panel).
    for _ in range(2):
        mfbo.iteration(idx)
plt.tight_layout()
for extension in ('pdf', 'png'):
    plt.savefig(f'{plot_dir}2D_BO.{extension}')
plt.show()

# Extension to 3 fidelities