In [None]:
from collections import namedtuple
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from pyDOE import lhs
from scipy.spatial import distance
from pyprojroot import here


# Pair of sample arrays making up a two-fidelity design: `high`, then `low`.
BiFidelityDoE = namedtuple("BiFidelityDoE", ["high", "low"])


def low_lhs_sample(ndim, nlow):
    """Draw `nlow` sample points in `ndim` dimensions.

    For `ndim == 1` an evenly spaced grid on [0, 1] (both end points
    included) is returned. For `ndim > 1` pyDOE's Latin hypercube sampler
    `lhs` is used; it draws from numpy's global random state.

    :param ndim: number of input dimensions, must be >= 1
    :param nlow: number of sample points
    :returns: array of shape (nlow, ndim)
    :raises ValueError: if `ndim < 1` (previously this silently returned None)
    """
    if ndim < 1:
        raise ValueError(f"ndim must be at least 1, got {ndim}")
    if ndim == 1:
        # reshape to a column so the result is 2D like the lhs() output
        return np.linspace(0, 1, nlow).reshape(-1, 1)
    return lhs(ndim, nlow)


def bi_fidelity_doe(ndim, num_high, num_low):
    """Create a Design of Experiments (DoE) for two fidelities in `ndim`
    dimensions. The high-fidelity samples are guaranteed to be a subset
    of the low-fidelity samples.

    Both sample sets are drawn independently with `low_lhs_sample`. Each
    high-fidelity point is then greedily paired with a distinct
    low-fidelity point. The globally closest remaining (high, low) pair is
    matched first, and the matched low-fidelity point is overwritten by
    its high-fidelity partner.

    :param ndim: number of input dimensions
    :param num_high: number of high-fidelity samples
    :param num_low: number of low-fidelity samples, must be >= `num_high`
    :returns: BiFidelityDoE(high, low) with the two sample arrays
    :raises ValueError: if `num_high > num_low`, because the high-fidelity
        samples could then not all be placed in the low-fidelity set
    """
    if num_high > num_low:
        raise ValueError(
            f"num_high ({num_high}) cannot exceed num_low ({num_low}): "
            "high-fidelity samples must be a subset of the low-fidelity ones"
        )

    high_x = low_lhs_sample(ndim, num_high)
    low_x = low_lhs_sample(ndim, num_low)

    dists = distance.cdist(high_x, low_x)

    # TODO: naive greedy matching, one full argmin per matched point;
    # potentially speed up for large designs.
    for _ in range(num_high):
        # First minimal entry in row-major order: the same pair that
        # `np.argwhere(dists == dists.min())[0]` selects, without the mask.
        high_idx, low_idx = np.unravel_index(np.argmin(dists), dists.shape)

        low_x[low_idx] = high_x[high_idx]
        # make sure just-matched samples cannot be selected again
        dists[high_idx, :] = np.inf
        dists[:, low_idx] = np.inf

    return BiFidelityDoE(high_x, low_x)

In [None]:
np.random.seed(20160501)  # fixed seed so the sampled design is reproducible

bfd = bi_fidelity_doe(2, 10, 20)

fig, ax = plt.subplots()
ax.scatter(*bfd.low.T, s=48, label='low')
ax.scatter(*bfd.high.T, s=12, label='high')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_title('Bi-fidelity DoE (high-fidelity points are a subset of low)')
ax.legend(loc=0)
plt.show()

In [None]:
def _plot_doe_step(low_x, high_x, step, num_high, fig_size, save_dir, arrow=None):
    """Show (and optionally save as PDF) one snapshot of the matching process.

    :param arrow: optional (start, end) pair of 2D points; if given, an arrow
        is drawn from `start` to `end`
    :param save_dir: Path of the output directory, or None to skip saving
    """
    fig, ax = plt.subplots(figsize=fig_size, constrained_layout=True)
    ax.scatter(*low_x.T, s=48, label='low')
    ax.scatter(*high_x.T, s=12, label='high')
    if arrow is not None:
        start, end = arrow
        ax.arrow(*start, *(end - start))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_title(f'step {step}/{num_high}')
    if save_dir is not None:
        fig.savefig(save_dir / f'illustrated-bi-fid-doe-{step}.pdf')
    plt.show()
    plt.close(fig)


def illustrated_bi_fidelity_doe(ndim, num_high, num_low, save_dir=None):
    """Create a Design of Experiments (DoE) for two fidelities in `ndim`
    dimensions. The high-fidelity samples are guaranteed to be a subset
    of the low-fidelity samples. Every matching step is plotted.

    The matching is the same greedy closest-pair procedure as in
    `bi_fidelity_doe`. Before each step a figure is shown with an arrow
    from the high-fidelity point to the low-fidelity point it is about to
    replace. After the last step a figure of the finished design is shown.
    Font sizes are set only while these figures are drawn; the global
    matplotlib settings are left unchanged.

    :param ndim: number of input dimensions; only 2 can be plotted
    :param num_high: number of high-fidelity samples
    :param num_low: number of low-fidelity samples, must be >= `num_high`
    :param save_dir: optional directory (str or path-like). If given, step k
        is saved there as `illustrated-bi-fid-doe-k.pdf` for k = 0..num_high
    :returns: BiFidelityDoE(high, low) with the two sample arrays
    :raises ValueError: if `ndim != 2` or `num_high > num_low`
    """
    if ndim != 2:
        raise ValueError(f"only 2D designs can be illustrated, got ndim={ndim}")
    if num_high > num_low:
        raise ValueError(
            f"num_high ({num_high}) cannot exceed num_low ({num_low}): "
            "high-fidelity samples must be a subset of the low-fidelity ones"
        )
    save_dir = Path(save_dir) if save_dir else None

    high_x = low_lhs_sample(ndim, num_high)
    low_x = low_lhs_sample(ndim, num_low)

    dists = distance.cdist(high_x, low_x)
    fig_size = (4, 4)

    # Scope the font settings to these figures instead of mutating the
    # global rcParams for every later plot in the session.
    with plt.rc_context({'font.size': 16, 'axes.labelsize': 20}):
        for step in range(num_high):
            # first minimal entry in row-major order = closest remaining pair
            high_idx, low_idx = np.unravel_index(np.argmin(dists), dists.shape)

            # draw before overwriting, so the arrow shows the move
            _plot_doe_step(low_x, high_x, step, num_high, fig_size, save_dir,
                           arrow=(high_x[high_idx], low_x[low_idx]))

            low_x[low_idx] = high_x[high_idx]
            # make sure just-matched samples cannot be selected again
            dists[high_idx, :] = np.inf
            dists[:, low_idx] = np.inf

        # final state: every high-fidelity point coincides with a low one
        _plot_doe_step(low_x, high_x, num_high, num_high, fig_size, save_dir)

    return BiFidelityDoE(high_x, low_x)

In [None]:
np.random.seed(20160501)  # same seed and sizes as the static example -> same design

plot_dir = here('plots').joinpath('illustrated-doe')
plot_dir.mkdir(parents=True, exist_ok=True)  # create the output folder if missing

bfd = illustrated_bi_fidelity_doe(ndim=2, num_high=10, num_low=20, save_dir=plot_dir)

In [None]:
from functools import partial

def generator(n):
    """Lazily yield the squares 0, 1, 4, ..., (n - 1)**2."""
    yield from (k * k for k in range(n))

def animator(i, gen):
    """Ignore `i` and return the next value from the iterator `gen`.

    StopIteration propagates once `gen` is exhausted.
    """
    next_value = next(gen)
    return next_value

squares = generator(10)
animate = partial(animator, gen=squares)  # animate(i) -> next square; i is ignored
for i in range(10):
    current = animate(i)
    print(current)


In [None]:
# TODO: save matplotlib animations as GIFs, see http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/