In [None]:
%matplotlib inline
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable   
from pyDOE import lhs
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import sys
import warnings

# Make the directory two levels above the current working directory importable
# (presumably so that the project-local packages imported below resolve - confirm).
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from itertools import product
from more_itertools import chunked
from functools import partial
import multifidelityfunctions as mff
import multiLevelCoSurrogates as mlcs
from sklearn.gaussian_process import GaussianProcessRegressor, kernels
from sklearn.ensemble import RandomForestRegressor

np.random.seed(20160501)  # Setting seed for reproducibility
# Shorthand for the one-dimensional multi-fidelity function used in the experiments below
OD = mff.oneDimensional

from IPython.core.display import clear_output
from pprint import pprint
# Wide, fixed-point array printing with 4 decimals
np.set_printoptions(linewidth=200, edgeitems=10, precision=4, suppress=True)
# Directories are relative to the working directory. The trailing slash is required:
# file names are appended to these strings by plain f-string formatting.
plot_dir = '../../plots/'  # figures are saved here
data_dir = '../../files/'  # cached .npy arrays are loaded from / saved here

# Trade-off heatmap: number of high- vs. low-fidelity points

This section covers an experiment on the influence of low-fidelity points in a co-surrogate setup.

Let $n_L$ be the number of low-fidelity points and $n_H$ the number of high-fidelity points. Create a sample $x_L$ of $n_L$ points using some initial sampling method (random, LHS, grid, etc), and take from that a subsample $x_H \subset x_L$ through some heuristic (maximal distance, random, etc). Then we train a number of models:
 - direct low-fidelity model using $x_L, f_L(x_L)$ only
 - direct high-fidelity model using $x_H, f_H(x_H)$ only
 - hierarchical high-fidelity model using both $x_L, f_L(x_L)$ and $x_H, f_H(x_H)$
 
Independently, a function-dependent sample $x_{mse}$ of size 1000 is also created. This sample is used to calculate a Mean Squared Error (MSE) value for the state of a model after training.

For the experiments, we examine all combinations of $n_L \in \{3, \ldots, 100\}$ and $n_H \in \{2, \ldots, 40\}$, with the restriction that $n_L > n_H$. Each combination is repeated 30 times.

In [None]:
# Experiment size.
# NOTE(review): these values equal the keyword defaults of create_mse_tracking, and the
# visible calls to that function do not pass them explicitly, so changing them here
# alone does not change the experiment.
max_high = 40  # largest number of high-fidelity points
max_low = 100  # largest number of low-fidelity points
num_reps = 30  # repetitions per (n_H, n_L) combination

In [None]:
# %%writefile -a function_defs.py

def low_random_sample(ndim, nlow):
    """Draw `nlow` uniformly random points from the unit hypercube [0, 1)^ndim.

    The points come from NumPy's global random state, so the result depends
    on the current seed.

    :param ndim: dimensionality of each point
    :param nlow: number of points to draw
    :return:     array of shape (nlow, ndim)
    """
    shape = (nlow, ndim)
    return np.random.rand(*shape)

def low_lhs_sample(ndim, nlow):
    """Create a space-filling sample of `nlow` points in the unit hypercube.

    For a single dimension this is a regular grid of `nlow` evenly spaced
    points on [0, 1] (both end points included), not a randomised design.
    For more dimensions the sample is generated by pyDOE's `lhs`.

    :param ndim: dimensionality of each point, must be at least 1
    :param nlow: number of points to generate
    :return:     array of shape (nlow, ndim)
    :raises ValueError: if `ndim` is smaller than 1
    """
    # Fail loudly instead of silently returning None for an invalid dimension
    if ndim < 1:
        raise ValueError(f'ndim must be at least 1, got {ndim}')

    if ndim == 1:
        return np.linspace(0, 1, nlow).reshape(-1, 1)
    return lhs(ndim, nlow)

In [None]:
# %%writefile -a function_defs.py

def create_mse_tracking(func, sample_generator,
                        max_high=40, max_low=100, num_reps=30,
                        min_high=2, min_low=3):
    """Record model MSE values for every valid (n_high, n_low, repetition) case.

    For each case a low-fidelity sample is generated, a high-fidelity subset is
    drawn from it at random without replacement, both are evaluated and stored
    in a fresh archive, and the values returned by `MultiFidelityBO.getMSE()`
    are stored.

    :param func:             multi-fidelity function exposing `ndim`, `low` and `high`
    :param sample_generator: callable `(ndim, n) -> array` producing the low-fidelity sample
    :param max_high:         largest number of high-fidelity points (inclusive)
    :param max_low:          largest number of low-fidelity points (inclusive)
    :param num_reps:         number of repetitions per combination
    :param min_high:         smallest number of high-fidelity points
    :param min_low:          smallest number of low-fidelity points
    :return:                 array of shape (max_high+1, max_low+1, num_reps, 3), indexed
                             directly by the point counts; entries for combinations that
                             were not evaluated are NaN
    """
    ndim = func.ndim

    # Start from all-NaN so that skipped combinations are recognisable afterwards
    mse_tracking = np.full((max_high+1, max_low+1, num_reps, 3), np.nan)

    high_counts = range(min_high, max_high+1)
    low_counts = range(min_low, max_low+1)
    cases = list(product(high_counts, low_counts, range(num_reps)))

    for idx, (n_high, n_low, rep) in enumerate(cases):

        # The high-fidelity set has to be a strict subset of the low-fidelity set
        if n_high >= n_low:
            continue

        if idx % 100 == 0:
            clear_output()
            print(f'{idx}/{len(cases)}')

        low_x = sample_generator(ndim, n_low)
        subset_idx = np.random.choice(n_low, n_high, replace=False)
        high_x = low_x[subset_idx]

        archive = mlcs.CandidateArchive(ndim=ndim, fidelities=['high', 'low', 'high-low'])
        archive.addcandidates(low_x, func.low(low_x), fidelity='low')
        archive.addcandidates(high_x, func.high(high_x), fidelity='high')

        # NOTE(review): output_range is hard-coded; presumably it matches the 1D function - confirm
        mfbo = mlcs.MultiFidelityBO(func, archive, output_range=(-10, 16))
        mse_tracking[n_high, n_low, rep] = mfbo.getMSE()

    clear_output()
    print(f'{len(cases)}/{len(cases)}')
    return mse_tracking

In [None]:
# %%writefile -a function_defs.py

def plot_high_vs_low_num_samples(data, name, vmin=.5, vmax=100, save_as=None):
    """Show a log-scaled heatmap of median MSE with two marginal strips.

    The main image shows model index 0 of the last axis of `data`. The strip on
    the left shows model index 1 averaged over the low-fidelity axis, the strip
    at the bottom shows model index 2 averaged over the high-fidelity axis
    (presumably the direct high- and low-fidelity models - confirm).

    :param data:    array indexed as [n_high, n_low, repetition, model]
    :param name:    not used by this function
    :param vmin:    lower limit of the logarithmic colour scale
    :param vmax:    upper limit of the logarithmic colour scale
    :param save_as: optional file name; if given, the figure is saved there
    """
    log_norm = colors.LogNorm(vmin=vmin, vmax=vmax, clip=True)
    image_style = {'cmap': 'viridis_r', 'norm': log_norm}

    fig, ax = plt.subplots(figsize=(9,3.5))
    ax.set_aspect('equal')

    # Collapse the repetition axis: one median per (n_high, n_low, model)
    median_mse = np.nanmedian(data, axis=2)

    ax.set_title('Median MSE for high (hierarchical) model')
    ax.imshow(median_mse[:,:,0], **image_style)

    # Thin strips attached to the main image, sharing its axes
    divider = make_axes_locatable(ax)
    ax_bottom = divider.append_axes("bottom", size=.2, pad=0.05, sharex=ax)
    ax_left = divider.append_axes("left", size=.2, pad=0.05, sharey=ax)

    # Only the outer strips keep their tick labels
    ax.xaxis.set_tick_params(labelbottom=False)
    ax.yaxis.set_tick_params(labelleft=False)
    ax_left.xaxis.set_tick_params(labelbottom=False)
    ax_bottom.yaxis.set_tick_params(labelleft=False)

    per_high = np.nanmean(median_mse[:,:,1], axis=1)
    per_low = np.nanmean(median_mse[:,:,2], axis=0)
    ax_left.imshow(per_high[:, np.newaxis], **image_style)
    bottom_img = ax_bottom.imshow(per_low[np.newaxis, :], **image_style)

    # All images share one norm and colour map, so a single colour bar serves all of them
    fig.colorbar(bottom_img, ax=ax, orientation='vertical')
    ax_left.set_ylabel('#High-fid samples')
    ax_bottom.set_xlabel('#Low-fid samples')

    fig.tight_layout()
    if save_as:
        plt.savefig(save_as)
    plt.show()

In [None]:
# %%writefile -a function_defs.py

def plot_high_vs_low_num_samples_diff(data, name, vmin=.5, vmax=100, save_as=None):
    """Plot the median paired MSE difference between model index 1 and model index 0.

    The difference is computed per repetition as `data[..., 1] - data[..., 0]`
    and then reduced to its median over the repetition axis. Elsewhere in this
    notebook index 0 is labelled as the hierarchical model and index 1 as the
    direct high-fidelity model, so positive values (green) mean the
    hierarchical model has the lower error.

    The colour scale is symmetric around zero, with limits at twice the smaller
    of the two extreme magnitudes; values beyond the limits are clipped.

    :param data:    array indexed as [n_high, n_low, repetition, model]
    :param name:    not used by this function
    :param vmin:    not used by this function
    :param vmax:    not used by this function
    :param save_as: optional file name; if given, the figure is saved there
    """
    to_plot = np.nanmedian(data[:,:,:,1] - data[:,:,:,0], axis=2)
    lowest, highest = np.nanmin(to_plot), np.nanmax(to_plot)
    print(lowest, highest)

    # Use magnitudes on both sides: if all differences are negative, a signed
    # maximum would give a negative limit and thus vmin > vmax.
    max_diff = 2*min(abs(lowest), abs(highest))
    norm = colors.Normalize(vmin=-max_diff, vmax=max_diff, clip=True)

    fig, ax = plt.subplots(figsize=(9,3.5))
    img = ax.imshow(to_plot, cmap='RdYlGn', norm=norm)
    fig.colorbar(img, ax=ax, orientation='vertical')
    ax.set_ylabel('#High-fid samples')
    ax.set_xlabel('#Low-fid samples')

    # Title states the subtraction in the order it is actually computed
    plt.title('Median of paired (high (direct) - high (hierarchical)) MSE')
    plt.tight_layout()
    if save_as:
        plt.savefig(save_as)
    plt.show()

## Random Sample generation

In [None]:
# Load the cached result if present; otherwise run the (long) experiment and cache it.
mse_file = os.path.join(data_dir, '1d_mse_tracking.npy')
if os.path.isfile(mse_file):
    mse_tracking = np.load(mse_file)
else:
    # Create the directory up front so the result is not lost on a failing save
    os.makedirs(data_dir, exist_ok=True)
    mse_tracking = create_mse_tracking(OD, low_random_sample)
    np.save(mse_file, mse_tracking)

### Test sample inspection

In [None]:
# Load the stored test sample and show how its values are distributed
# (presumably this is the MSE evaluation sample described in the introduction - confirm).
sample = np.load(f'{data_dir}1d_test_sample.npy')
img = plt.hist(sample)
plt.show()

### Error distribution

In [None]:
# Load the stored errors and average them over the first three axes.
# NOTE(review): presumably those axes are (n_high, n_low, repetition), as in the MSE
# tracking arrays - confirm. If unevaluated combinations are stored as NaN in this file
# as well, np.mean propagates them and np.nanmean would be needed instead.
errors = np.load(f'{data_dir}1d_error_tracking.npy')
mean_errors = np.mean(errors, axis=(0,1,2))

In [None]:
# Mean error per test-sample location for entry 0 of mean_errors (the hierarchical model, per the title)
plt.scatter(x=sample.flatten(), y=mean_errors[0], s=2)
plt.title('Mean error - high fidelity (hierarchical) model')
plt.show()

In [None]:
# Mean error per test-sample location for entry 1 of mean_errors (the direct high-fidelity model, per the title)
plt.scatter(x=sample.flatten(), y=mean_errors[1], s=2)
plt.title('Mean error - high fidelity (direct) model')
plt.show()

In [None]:
# Mean error per test-sample location for entry 2 of mean_errors (the direct low-fidelity model, per the title)
plt.scatter(x=sample.flatten(), y=mean_errors[2], s=2)
plt.title('Mean error - low fidelity (direct) model')
plt.show()

### Global MSE inspection

In [None]:
print('median')
# Upper percentiles (95..100) of the per-combination median MSE, over all models.
# Unevaluated combinations are NaN, so the NaN-aware percentile is required;
# np.percentile would return NaN for every entry.
median_flat = np.nanmedian(mse_tracking, axis=2).flatten()
pprint([(f'{95+i}%-ile', np.nanpercentile(median_flat, 95+i)) for i in range(6)])

In [None]:
# Median-MSE heatmap for the random low-fidelity sample; saved as a PDF in plot_dir
name = 'high-low-samples-random'
plot_high_vs_low_num_samples(mse_tracking, name, save_as=f'{plot_dir}{name}.pdf')

In [None]:
# Paired MSE difference plot for the random low-fidelity sample; saved with a '_diff' suffix
name = 'high-low-samples-random'
plot_high_vs_low_num_samples_diff(mse_tracking, name, save_as=f'{plot_dir}{name}_diff.pdf')

## Linspace, random subsample generation

In [None]:
# Load the cached result if present; otherwise run the (long) experiment and cache it.
lin_mse_file = os.path.join(data_dir, '1d_lin_mse_tracking.npy')
if os.path.isfile(lin_mse_file):
    lin_mse_tracking = np.load(lin_mse_file)
else:
    # Create the directory up front so the result is not lost on a failing save
    os.makedirs(data_dir, exist_ok=True)
    lin_mse_tracking = create_mse_tracking(OD, low_lhs_sample)
    np.save(lin_mse_file, lin_mse_tracking)

In [None]:
print('median')
# Upper percentiles (95..100) of the per-combination median MSE, over all models.
# Unevaluated combinations are NaN, so the NaN-aware percentile is required;
# np.percentile would return NaN for every entry.
lin_median_flat = np.nanmedian(lin_mse_tracking, axis=2).flatten()
pprint([(f'{95+i}%-ile', np.nanpercentile(lin_median_flat, 95+i)) for i in range(6)])

In [None]:
# Median-MSE heatmap for the evenly spaced low-fidelity sample; saved as a PDF in plot_dir
name = 'high-low-samples-linear'
plot_high_vs_low_num_samples(lin_mse_tracking, name, save_as=f'{plot_dir}{name}.pdf')

In [None]:
# Paired MSE difference plot for the evenly spaced low-fidelity sample; saved with a '_diff' suffix
name = 'high-low-samples-linear'
plot_high_vs_low_num_samples_diff(lin_mse_tracking, name, save_as=f'{plot_dir}{name}_diff.pdf')

## Difference in error between linear and random sample

In [None]:
# %%writefile -a function_defs.py

def plot_inter_method_diff(data_A, data_B, name, save_as=None):
    """Plot the median difference in model-index-0 MSE between two experiments.

    The difference `data_A[..., 0] - data_B[..., 0]` is taken per repetition
    and reduced to its median over the repetition axis. The colour scale is
    symmetric around zero, with limits at 5% of the smaller of the two extreme
    magnitudes; values beyond the limits are clipped.

    :param data_A:  array indexed as [n_high, n_low, repetition, model]
    :param data_B:  array with the same layout as `data_A`
    :param name:    text appended to the figure title
    :param save_as: optional file name; if given, the figure is saved there
    """
    fig, ax = plt.subplots(figsize=(9,3.5))

    plt.title(f'high (hierarchical) MSE: {name}')
    to_plot = np.nanmedian(data_A[:,:,:,0] - data_B[:,:,:,0], axis=2)

    lowest, highest = np.nanmin(to_plot), np.nanmax(to_plot)
    print(lowest, highest)

    # Use magnitudes on both sides: if all differences are negative, a signed
    # maximum would give a negative limit and thus vmin > vmax.
    max_diff = .05*min(abs(lowest), abs(highest))
    norm = colors.Normalize(vmin=-max_diff, vmax=max_diff, clip=True)

    img = ax.imshow(to_plot, cmap='RdYlGn', norm=norm)

    fig.colorbar(img, ax=ax, orientation='vertical')
    ax.set_ylabel('#High-fid samples')
    ax.set_xlabel('#Low-fid samples')

    plt.tight_layout()
    if save_as:
        plt.savefig(save_as)
    plt.show()

In [None]:
# Random-sample MSE minus evenly-spaced-sample MSE.
# Note: `name` is used both in the figure title and, verbatim, as the PDF file name.
name = "1D, random - LHS"
plot_inter_method_diff(mse_tracking, lin_mse_tracking, name, save_as=f'{plot_dir}{name}.pdf')