In [1]:
import pickle
import time
import numpy as np
import math
from scipy.special import binom
from itertools import combinations, permutations, product
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from olympus.surfaces import CatCamel, CatMichalewicz, CatSlope, CatDejong

In [2]:
#------------------
# HELPER FUNCTIONS
#------------------
def stirling_sum(Ns):
    """Return the Bell number B(Ns): the number of partitions of an Ns-element set.

    B(Ns) = sum_{k=0}^{Ns} S(Ns, k), where S(n, k) is the Stirling number of
    the second kind. The Stirling numbers are built row by row with the exact
    integer recurrence S(n, k) = k * S(n-1, k) + S(n-1, k-1). This keeps the
    result exact for any Ns and gives B(0) = 1.

    Args:
        Ns (int): number of elements in the set (must be >= 0).

    Returns:
        int: the Bell number B(Ns). For a set of Ns distinct elements this
            equals the number of partitions produced by gen_partitions.

    Raises:
        ValueError: if Ns is negative.
    """
    if Ns < 0:
        raise ValueError(f'Ns must be non-negative, got {Ns}')
    # row[k] holds S(n, k) for the current n; start from S(0, 0) = 1.
    row = [1]
    for n in range(1, Ns + 1):
        # Pad with S(n-1, n) = 0 so prev[k] is defined for k = n.
        prev = row + [0]
        # S(n, 0) = 0 for n >= 1, hence the leading 0.
        row = [0] + [k * prev[k] + prev[k - 1] for k in range(1, n + 1)]
    return sum(row)

def partition(S):
    """Yield every partition of the sequence S.

    Each partition is a list of non-empty, disjoint blocks (lists) whose
    union is S. The partitions of S[1:] are generated recursively. For each
    one, S[0] is inserted into every existing block in turn, and then also
    placed in a new singleton block.

    Args:
        S (sequence): elements to partition. Any sequence is accepted; it is
            copied into a list so block concatenation works for tuples too.

    Yields:
        list[list]: one partition per iteration. The empty sequence has
            exactly one partition, the empty partition [].
    """
    S = list(S)
    if not S:
        # The empty set has exactly one (empty) partition.
        yield []
        return
    if len(S) == 1:
        yield [S]
        return

    first, rest = S[0], S[1:]
    for smaller in partition(rest):
        # Put `first` into each existing block of the smaller partition ...
        for n, subset in enumerate(smaller):
            yield smaller[:n] + [[first] + subset] + smaller[n + 1:]
        # ... or give it a block of its own.
        yield [[first]] + smaller
    
def gen_partitions(S):
    """
    generate all possible partitions of Ns-element set S

    Args:
        S (list): list of non-functional parameters S

    Returns:
        (list): every partition of S, in the order produced by partition().
            Each partition is a list of non-empty, disjoint subsets (lists)
            whose union is S.
    """
    # partition() is a generator; materialize it so callers can take len()
    # and iterate over the result more than once.
    return list(partition(S))


def gen_permutations(X_funcs, Ng):
    """ generate all possible functional parameter permutations
    given number of non-functional parameter subsets Ng

    Args:
        X_funcs (np.ndarray): array whose rows are the possible functional
            parameter combinations, shape (# options, # params)
        Ng (int): number of non-functional parameter subsets

    Returns
        (np.ndarray): every ordered selection of Ng distinct rows of
            X_funcs, with shape (# perms, Ng, # params) where
            # perms = n! / (n - Ng)! for n rows. If Ng > n there are no
            selections, and an empty array of shape (0, Ng, # params)
            is returned.
    """
    X = np.asarray(X_funcs)
    perms = list(permutations(X, Ng))
    # np.array([]) would collapse to shape (0,) and lose the trailing
    # dimensions; reshape explicitly so the documented shape always holds.
    return np.array(perms, dtype=X.dtype).reshape((len(perms), Ng) + X.shape[1:])

def measure_objective(xgs, G, surf_map):
    """Score partition G when group g uses functional parameters xgs[g].

    Every member of the g-th subset of G is evaluated with
    measure_single_obj at xgs[g]. The member values are averaged within the
    subset, and the per-subset averages are summed.

    Args:
        xgs: indexed by subset position; xgs[g] is passed to
            measure_single_obj for every member of the g-th subset.
        G (list): partition of the non-functional parameters (list of subsets).
        surf_map (dict): forwarded unchanged to measure_single_obj.

    Returns:
        The sum over subsets of the mean member objective. It starts from
        0.0; the exact numeric type depends on what measure_single_obj returns.
    """
    def group_mean(pos, members):
        # Mean objective of one subset, all evaluated at xgs[pos].
        acc = 0.
        for member in members:
            acc += measure_single_obj(xgs[pos], member, surf_map)
        return acc / len(members)

    total = 0.
    for pos, members in enumerate(G):
        total += group_mean(pos, members)
    return total



def _functional_grid(num_opts, param_dim):
    """Return all combinations of functional parameter options as a 2-D string array.

    Each of the `param_dim` functional parameters takes one of the options
    'x0' ... 'x{num_opts-1}'. The rows are their Cartesian product, so the
    shape is (num_opts ** param_dim, param_dim).
    """
    param_opts = [f'x{i}' for i in range(num_opts)]
    # `repeat=param_dim` takes the product of param_dim copies of the option
    # list. Unpacking the option strings themselves would iterate over their
    # characters instead.
    combos = [list(combo) for combo in product(param_opts, repeat=param_dim)]
    # Explicit reshape keeps a 2-D shape even when combos is empty.
    return np.array(combos, dtype=str).reshape(len(combos), param_dim)


def record_merits(S, surf_map, X_func_truncate=20, num_opts=21, param_dim=2):
    """Exhaustively score every (partition, functional-parameter) assignment.

    For each partition G of the non-functional parameters S, every ordered
    choice of len(G) distinct rows of the functional-parameter grid is
    assigned to the subsets of G. Each assignment is scored with
    measure_objective.

    Args:
        S (list): non-functional parameter options (keys of surf_map).
        surf_map (dict): forwarded unchanged to measure_objective.
        X_func_truncate (int or None): if an int, keep only the first
            X_func_truncate rows of the functional grid. The full grid has
            num_opts ** param_dim rows. For any other value, the whole grid
            is used.
        num_opts (int): number of options per functional parameter
            ('x0' ... 'x{num_opts-1}'). The default 21 matches the toy
            surfaces' num_opts.
        param_dim (int): number of functional parameters (default 2).

    Returns:
        list[dict]: one record per evaluation, with keys:
            - 'G': the partition.
            - 'X_func': array of shape (len(G), param_dim); row g is
              assigned to subset g.
            - 'f_x': the objective value.
    """
    # list of dictionaries to store G, X_func, f_x
    f_xs = []
    start_time = time.time()

    # generate all the partitions of non-functional parameters
    Gs = gen_partitions(S)
    print('total non-functional partitions : ', len(Gs))

    # generate all the possible values of functional parameters
    X_funcs = _functional_grid(num_opts, param_dim)
    if isinstance(X_func_truncate, int):
        X_funcs = X_funcs[:X_func_truncate, :]
    print('cardnality of functional params : ', X_funcs.shape[0])

    for G_ix, G in enumerate(Gs):
        print(f'[INFO] Evaluating partition {G_ix+1}/{len(Gs)}')
        # one functional-parameter row per subset of G, in every ordering
        X_func_perms = gen_permutations(X_funcs, len(G))
        for X_func in X_func_perms:
            f_x = measure_objective(X_func, G, surf_map)
            f_xs.append({'G': G, 'X_func': X_func, 'f_x': f_x})

    total_time = round(time.time() - start_time, 2)
    print(f'[INFO] Done in {total_time} s')
    return f_xs

In [3]:
#-------------
# TOY PROBLEM
#-------------
# Non-functional parameter options (keys into surf_map).
# Use S = [0, 1, 2, 3] to also include the CatSlope surface (key 3).
S = [0, 1, 2]

# Key i -> i-th surface class below, each built with param_dim=2, num_opts=21.
surf_map = {
    key: surface_cls(param_dim=2, num_opts=21)
    for key, surface_cls in enumerate([CatCamel, CatDejong, CatMichalewicz, CatSlope])
}

In [None]:
def measure_single_obj(X_func, si, surf_map):
    """Evaluate the surface stored under key `si` at functional parameters X_func.

    Args:
        X_func: functional parameter values, forwarded to the surface's run().
        si: key into surf_map selecting the surface.
        surf_map (dict): maps keys to objects exposing run().

    Returns:
        The element [0][0] of run()'s result (presumably the scalar objective
        value; TODO confirm against the surface's run() return format).
    """
    surface = surf_map[si]
    outputs = surface.run(X_func)
    first_row = outputs[0]
    return first_row[0]

In [None]:
# Sample functional parameters and surface key, presumably for calling
# measure_single_obj by hand (TODO confirm or remove this scratch cell).
X_func, si = ['x10', 'x20'], 1

In [None]:
# Exhaustive evaluation over all partitions of S, keeping the first 20 functional-grid rows.
f_xs = record_merits(S=S, surf_map=surf_map, X_func_truncate=20)

In [4]:
# All partitions of the toy non-functional parameter set S.
gen_partitions(S=S)

[[[0, 1, 2]], [[0], [1, 2]], [[0, 1], [2]], [[1], [0, 2]], [[0], [1], [2]]]

In [9]:
# Full functional-parameter grid: every ordered pair of the options 'x0'..'x20'.
param_opts = [f'x{i}' for i in range(21)]
params_opts = [param_opts] * 2
cart_product = np.array([list(pair) for pair in product(*params_opts)])
cart_product.shape

(441, 2)

['x0',
 'x1',
 'x2',
 'x3',
 'x4',
 'x5',
 'x6',
 'x7',
 'x8',
 'x9',
 'x10',
 'x11',
 'x12',
 'x13',
 'x14',
 'x15',
 'x16',
 'x17',
 'x18',
 'x19',
 'x20']