In [None]:
import matplotlib.pyplot as plt
import numpy as np
import multifidelityfunctions as mff
from matplotlib import colors
from scipy.spatial import distance
from high_v_low_experiment import multi_fidelity_doe
from itertools import product

In [None]:
def plot_distances(ndim, max_high=51, max_low=126, nreps=15):
    """Plot heatmaps of mean min/max pairwise distances of multi-fidelity DoEs.

    For every pair ``(h, l)`` with ``2 <= h < max_high``, ``3 <= l < max_low``
    and ``l > h``, ``multi_fidelity_doe(ndim, h, l)`` is called ``nreps``
    times. ``h`` and ``l`` are presumably the requested high- and
    low-fidelity sample sizes (confirm against ``multi_fidelity_doe``).
    For each generated design, the smallest and largest pairwise Euclidean
    distances within ``high_x`` and within ``low_x`` are recorded. These are
    averaged over the repetitions and shown as four heatmaps, with rows
    indexed by ``h`` and columns by ``l``. All four panels share one colour
    scale. Pairs that are never evaluated (``l <= h`` or below the lower
    bounds) stay NaN and render blank.

    Parameters
    ----------
    ndim : int
        Dimensionality passed to ``multi_fidelity_doe``; also used in the title.
    max_high : int, optional
        Exclusive upper bound for ``h`` (default 51).
    max_low : int, optional
        Exclusive upper bound for ``l`` (default 126).
    nreps : int, optional
        Number of designs generated and averaged per ``(h, l)`` pair (default 15).
    """
    # Last axis holds [min high, max high, min low, max low]; NaN = not evaluated.
    minmax_dists = np.full((max_high, max_low, nreps, 4), np.nan)

    for h, l in product(range(2, max_high), range(3, max_low)):
        if l <= h:
            continue
        for i in range(nreps):
            high_x, low_x = multi_fidelity_doe(ndim, h, l)
            # pdist returns the condensed 1-D vector of all pairwise distances.
            h_dists = distance.pdist(high_x)
            l_dists = distance.pdist(low_x)
            minmax_dists[h, l, i] = [h_dists.min(), h_dists.max(),
                                     l_dists.min(), l_dists.max()]

    means = np.mean(minmax_dists, axis=2)
    # A single shared colour scale keeps all four panels directly comparable.
    norm = colors.Normalize(vmin=np.nanmin(means), vmax=np.nanmax(means))

    titles = ['minimum distance: high_x', 'maximum distance: high_x',
              'minimum distance: low_x', 'maximum distance: low_x']
    fig, axes = plt.subplots(2, 2, figsize=(16, 9))
    for k, (ax, title) in enumerate(zip(axes.flatten(), titles)):
        # Array index == sample size, so tick values read directly as h / l.
        img = ax.imshow(means[:, :, k], norm=norm, origin='lower')
        ax.set_title(title)
        ax.set_xlabel('l (low-fidelity size)')
        ax.set_ylabel('h (high-fidelity size)')
        fig.colorbar(img, ax=ax, shrink=.6)
    fig.suptitle(f'{ndim}D')
    fig.tight_layout()
    plt.show()

In [None]:
# Render one 2x2 distance-heatmap figure per dimensionality, 1D through 8D.
for d in range(1, 8 + 1):
    plot_distances(ndim=d)