In [290]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from denn import *
import pandas as pd
import seaborn as sns
import io 

In [291]:
# Functions
nn_p = 3
nn_w = 5
nn_tw = 1
freqlimit = str(1)
labels_order = ['noNN_RI', 'NN_RI','noNN_HMu',  'NN_HMu', 'noNN_CwN', 'NN_CwN', 'noNN_No', 
                 'NN_No','noNN_Rst', 'NN_Rst']
path= Path(f'../../data/cluster_results')
col_by_freq = True
col_palette = 'Set3'

pat = re.compile('.*/(exp\d)/(\w*)/nonn/freq([0-9\.]+)div(\w+)/(\w+)_\w+.csv')
decode_keys = ['experiment','function','freq','div','method']

nn_pat = re.compile('.*/(exp\d)/(\w*)/nn/freq([0-9\.]+)nn_w(\d+)nn_p(\d+)\w+nn_tw(\d+)\w+div([A-Za-z]+)/(\w+)_(\w+)_\w+.csv')
nn_decode_keys = ['experiment','function','freq','nnw','nnp','nntw','div','method','replace_mech']#,

def get_files(m): return list(path.glob(f'**/nonn/**/*{m}.csv'))
def get_nn_files(m): return list(path.glob(f'**/nn/**/*{m}.csv'))

def read_csv(f,m):
    df = pd.read_csv(f)
    for k,v in zip(decode_keys,pat.search(str(f)).groups()): df[k] = v
    df['freq'] = df['freq'].astype(float)
    df['method'] = df['method'] + '_' + df['div']
#     df['method'] = df['method'].str.replace('noNNRestart', 'noNN_Rst')
    df['method'] = df['method'].str.replace('noNNRestart_No', 'noNN_Rst')
    df.drop('div', axis=1, inplace=True)
    df.rename({'0':m.upper(), m:m.upper()}, axis=1, inplace=True)
    return df

def read_nn_csv(f,m):
    df = pd.read_csv(f)
    for k,v in zip(nn_decode_keys,nn_pat.search(str(f)).groups()): df[k] = v
    df['freq'] = df['freq'].astype(float)
    df['method'] = df['method'] + '_' + df['replace_mech'] + '_' + df['div']
    df['method'] = df['method'].str.replace('NNnorm_Worst', 'NN')
    df['method'] = df['method'].str.replace('NNconv_Worst', 'NN')
    df.drop(['replace_mech','div'], axis=1, inplace=True)
    df.rename({'0':m.upper(), m:m.upper()}, axis=1, inplace=True)
    return df

def get_data(m, normalize=False):
    files = get_files(m)
    nn_files = get_nn_files(m)
#     data = pd.concat([read_csv(f,m) for f in files] + [read_nn_csv(f,m) for f in nn_files])
    nn_data = pd.concat([read_nn_csv(f,m) for f in nn_files])
    nn_data = nn_data[nn_data['nnw']==str(nn_w)]
    nn_data = nn_data[nn_data['nntw']==str(nn_tw)]
    nonn_data = pd.concat([read_csv(f,m) for f in files])
    data = pd.concat([nn_data , nonn_data])

    if normalize:
        data_norm = (data.groupby(['experiment','function','freq','method'])[m.upper()].mean().reset_index()
                         .groupby(['experiment','function'])[m.upper()].min().reset_index()
                         .rename({m.upper():m.upper()+'_norm'}, axis=1))
        data = data.merge(data_norm, 'left')
        data[m.upper()+'_norm'] = data[m.upper()] / data[m.upper()+'_norm']
    return data.reset_index(drop=True)
    return data

def plot_one(data, m, normalize=False, title='', title_size=14, col_by_freq=col_by_freq, col_palette=col_palette,
             legend=False, hide_x=True, hide_y=True, ax=None, do_lim=True, ll=0.2, ul=0.8):
    m = m.upper()
    if normalize: m = m + '_norm' 
    if ax is None: fig,ax = plt.subplots(1,1,figsize=(6,4))
    sns.boxplot('method' if col_by_freq else 'freq', m, hue='freq' if col_by_freq else 'method',
                data=data, palette=col_palette,
                # width=2.5,
                linewidth=0.5, fliersize=0, ax=ax)
    if not legend: ax.get_legend().remove()
    if do_lim:
        ax.set_ylim(-0.05,1.05)
    else:
        g = data.groupby(['method','freq'])[m]
        q1,q3 = g.quantile(ll),g.quantile(ul)
        iqr = q3-q1
        lower_lim = (q1 - 1.5*iqr).min()
        upper_lim = (q3 + 1.5*iqr).max()
        lower_lim = data[m][data[m]>=lower_lim].min()
        upper_lim = data[m][data[m]<=upper_lim].max()
        ax.set_ylim(lower_lim,upper_lim)
        
    ax.set_xlabel('')
    if hide_x: ax.set_xticklabels([])
    if hide_y: ax.set_ylabel('')
    ax.set_title(title, size=title_size)
    return ax

def plot_all_vertical(m, normalize=False, title_size=14, col_by_freq=col_by_freq, col_palette=col_palette, do_lim=True,
                      ll=0.2, ul=0.8):
    data = get_data(m, normalize=normalize)
#     fig,axss = plt.subplots(4, 3, figsize=(14,16), sharex=do_lim, sharey=do_lim)
    fig,axss = plt.subplots(4, 3, figsize=(26,18), sharex=False, sharey=False)
    for i,(axs,exp) in enumerate(zip(axss,['exp1','exp2','exp3','exp4'])):
        first_exp = i==0
        for j,(ax,func) in enumerate(zip(axs,['sphere', 'rosenbrock', 'rastrigin'])):
            first_func = j==0
            test = data.query(f'experiment=={exp!r} and function=={func!r}').sort_values('method', ascending=False)
            plot_one(test, m, normalize=normalize, title=f'{exp}-{func.title()}', title_size=title_size, hide_x=False,
                     hide_y=not first_func, col_by_freq=col_by_freq, col_palette=col_palette, ax=ax,
                     do_lim=do_lim, ll=ll, ul=ul)
        ax.legend(loc='upper left', ncol=3)
    plt.tight_layout()
    return fig,axss

def plot_all_horizontal(m, normalize=False, title_size=14, col_by_freq=col_by_freq, col_palette=col_palette, do_lim=True,
                        ll=0.2, ul=0.8):
    data = get_data(m, normalize=normalize)
#     fig,axss = plt.subplots(4, 3, figsize=(14,16), sharex=do_lim, sharey=do_lim)
    fig,axss = plt.subplots(3, 4, figsize=(28,20), sharex=False, sharey=False)#20,12
    for i,exp in enumerate(['exp1','exp2','exp3','exp4']):#
        first_exp = i==0
        axs = axss[:,i]
        for j,(ax,func) in enumerate(zip(axs,['sphere', 'rosenbrock', 'rastrigin'])):
            first_func = i==0
            test = data.query(f'experiment=={exp!r} and function=={func!r}').sort_values('method', ascending=False)
            plot_one(test, m, normalize=normalize, title=f'{exp}-{func.title()}', title_size=title_size, hide_x=False,
                     hide_y=not first_func, col_by_freq=col_by_freq, col_palette=col_palette, ax=ax,
                     do_lim=do_lim, ll=ll, ul=ul)
        
    for axs in axss: axs[-1].legend(loc='upper right', ncol=3)
    plt.tight_layout()
    return fig,axss

def get_heatmap_data(m):
    df = get_data(m)
    df = df[df.nnp.isna() | (df.nnp == str(nn_p))].drop('nnp', axis=1)
    df.function = df.function.str.title()
    m = m.upper()
    df_pivot = df.pivot_table(index=['experiment','function','freq'], columns=['method'], values=[m],
                              aggfunc='mean')[m]
    return df_pivot

# Plots

In [292]:
df_pivot = get_heatmap_data('mof')#[labels_order]
df_pivot_rank = df_pivot.rank(axis=1)
df_pivot = df_pivot.reset_index()
df_pivot.experiment = df_pivot.experiment.str.slice(3)
df_pivot.set_index(['experiment','function','freq'], inplace=True)

Sorting because non-concatenation axis is not aligned. A future version
of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.




In [293]:
df_pivot_rank = df_pivot_rank.query(f'freq=={freqlimit}')
df_pivot = df_pivot.query(f'freq=={freqlimit}')
df_pivot

Unnamed: 0_level_0,Unnamed: 1_level_0,method,NN_CwN,NN_HMu,NN_RI,noNN_CwN,noNN_HMu,noNN_No,noNN_RI,noNN_Rst
experiment,function,freq,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1,Rastrigin,1.0,15.4336,4.860777,4.457794,10.15614,4.382079,22.49666,3.980478,8.104017
1,Rosenbrock,1.0,19426.82,3370.7508,3610.889,29398.45,2920.402,85839.4,2889.698,9686.117
1,Sphere,1.0,5.492393,1.008488,1.025177,8.186963,0.8859331,23.66202,0.8715007,2.857165
2,Rastrigin,1.0,3.670675,,3.510239,9.091213,3.385957,1.601281,2.518225,10.64801
2,Rosenbrock,1.0,405.753,,141.9824,3130.698,169.3449,150.2851,98.77811,2562.096
2,Sphere,1.0,0.5215018,0.2585,0.1566152,3.180963,0.200303,0.454492,0.117002,2.068908
3,Rastrigin,1.0,9.083339,,1.737545,10.5434,1.924866,22.17239,1.414138,9.600763
3,Rosenbrock,1.0,321.7864,,24.8787,25796.35,21.48184,69.36475,18.28413,25436.64
3,Sphere,1.0,0.4428295,,0.02559957,5.849812,0.01909811,0.01587667,0.01799961,4.392413
4,Rastrigin,1.0,113.1201,,32.0481,78.73978,31.64391,178.1399,34.10237,26.11627


In [294]:
df_pivot.style.background_gradient(cmap=plt.cm.Greens, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,method,NN_CwN,NN_HMu,NN_RI,noNN_CwN,noNN_HMu,noNN_No,noNN_RI,noNN_Rst
experiment,function,freq,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1,Rastrigin,1.0,15.4336,4.86078,4.45779,10.1561,4.38208,22.4967,3.98048,8.10402
1,Rosenbrock,1.0,19426.8,3370.75,3610.89,29398.5,2920.4,85839.4,2889.7,9686.12
1,Sphere,1.0,5.49239,1.00849,1.02518,8.18696,0.885933,23.662,0.871501,2.85717
2,Rastrigin,1.0,3.67068,,3.51024,9.09121,3.38596,1.60128,2.51823,10.648
2,Rosenbrock,1.0,405.753,,141.982,3130.7,169.345,150.285,98.7781,2562.1
2,Sphere,1.0,0.521502,0.2585,0.156615,3.18096,0.200303,0.454492,0.117002,2.06891
3,Rastrigin,1.0,9.08334,,1.73755,10.5434,1.92487,22.1724,1.41414,9.60076
3,Rosenbrock,1.0,321.786,,24.8787,25796.4,21.4818,69.3648,18.2841,25436.6
3,Sphere,1.0,0.442829,,0.0255996,5.84981,0.0190981,0.0158767,0.0179996,4.39241
4,Rastrigin,1.0,113.12,,32.0481,78.7398,31.6439,178.14,34.1024,26.1163


In [295]:
df_pivot_rank.style.background_gradient(cmap=plt.cm.Greens, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,method,NN_CwN,NN_HMu,NN_RI,noNN_CwN,noNN_HMu,noNN_No,noNN_RI,noNN_Rst
experiment,function,freq,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
exp1,Rastrigin,1.0,7,4.0,3,6,2,8,1,5
exp1,Rosenbrock,1.0,6,3.0,4,7,2,8,1,5
exp1,Sphere,1.0,6,3.0,4,7,2,8,1,5
exp2,Rastrigin,1.0,5,,4,6,3,1,2,7
exp2,Rosenbrock,1.0,5,,2,7,4,3,1,6
exp2,Sphere,1.0,6,4.0,2,8,3,5,1,7
exp3,Rastrigin,1.0,4,,2,6,3,7,1,5
exp3,Rosenbrock,1.0,5,,3,7,2,4,1,6
exp3,Sphere,1.0,5,,4,7,3,1,2,6
exp4,Rastrigin,1.0,6,,3,5,2,7,4,1


In [296]:
summary = df_pivot.groupby(['experiment', 'function']).mean()
summary.round(2).style.background_gradient(cmap=plt.cm.Greens, axis=1)

Unnamed: 0_level_0,method,NN_CwN,NN_HMu,NN_RI,noNN_CwN,noNN_HMu,noNN_No,noNN_RI,noNN_Rst
experiment,function,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,Rastrigin,15.43,4.86,4.46,10.16,4.38,22.5,3.98,8.1
1,Rosenbrock,19426.8,3370.75,3610.89,29398.5,2920.4,85839.4,2889.7,9686.12
1,Sphere,5.49,1.01,1.03,8.19,0.89,23.66,0.87,2.86
2,Rastrigin,3.67,,3.51,9.09,3.39,1.6,2.52,10.65
2,Rosenbrock,405.75,,141.98,3130.7,169.34,150.29,98.78,2562.1
2,Sphere,0.52,0.26,0.16,3.18,0.2,0.45,0.12,2.07
3,Rastrigin,9.08,,1.74,10.54,1.92,22.17,1.41,9.6
3,Rosenbrock,321.79,,24.88,25796.3,21.48,69.36,18.28,25436.6
3,Sphere,0.44,,0.03,5.85,0.02,0.02,0.02,4.39
4,Rastrigin,113.12,,32.05,78.74,31.64,178.14,34.1,26.12


In [297]:
out = summary.round(2).style
for i in range(len(labels_order)//2): 
    out = out.highlight_min(axis=1, subset=labels_order[i*2:(i+1)*2])

out

Unnamed: 0_level_0,method,NN_CwN,NN_HMu,NN_RI,noNN_CwN,noNN_HMu,noNN_No,noNN_RI,noNN_Rst
experiment,function,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,Rastrigin,15.43,4.86,4.46,10.16,4.38,22.5,3.98,8.1
1,Rosenbrock,19426.8,3370.75,3610.89,29398.5,2920.4,85839.4,2889.7,9686.12
1,Sphere,5.49,1.01,1.03,8.19,0.89,23.66,0.87,2.86
2,Rastrigin,3.67,,3.51,9.09,3.39,1.6,2.52,10.65
2,Rosenbrock,405.75,,141.98,3130.7,169.34,150.29,98.78,2562.1
2,Sphere,0.52,0.26,0.16,3.18,0.2,0.45,0.12,2.07
3,Rastrigin,9.08,,1.74,10.54,1.92,22.17,1.41,9.6
3,Rosenbrock,321.79,,24.88,25796.3,21.48,69.36,18.28,25436.6
3,Sphere,0.44,,0.03,5.85,0.02,0.02,0.02,4.39
4,Rastrigin,113.12,,32.05,78.74,31.64,178.14,34.1,26.12


# Interactive

In [207]:
from ipywidgets import interact
def show_table(dm):
    out = summary[[f'noNN_{dm}', f'NN_{dm}']].round(2)
    return out.style.background_gradient(cmap=plt.cm.Blues_r, axis=1)
    
interact(show_table, dm=['RI', 'HMu','No','Rst'])

interactive(children=(Dropdown(description='dm', options=('RI', 'HMu', 'No', 'Rst'), value='RI'), Output()), _…

<function __main__.show_table(dm)>