## Compiling result for the few-shot regression paper

In [1]:
import os
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import pickle
import itertools
import matplotlib as mpl
from collections import defaultdict
mpl.rcParams['font.family'] = 'Arial'

In [2]:
def make_hash(depth=None, type=None):
    """Utility method to make a multilevel dict"""
    if (depth, type) == (None, None):
        return defaultdict(makehash)
    elif depth == 0:
        return defaultdict(type)
    else:
        return defaultdict(partial(makehash, depth - 1, type))

In [3]:
from metalearn.datasets.loaders import load_dataset
from metalearn.models.factory import ModelFactory
from metalearn.utils.metric import mse, vse, r2, pcc

f = "/home/prtos/.invivo/invivoai-sagemaker-artifacts/iscb-expts4/iscb-2019-01-27-17-28-00-450/output/model/6d3a0b6fb82df5661be94b8c740ff86e_params.json"
def unflatten(dictionary):
    resultDict = dict()
    for key, value in dictionary.items():
        parts = key.split(".")
        d = resultDict
        for part in parts[:-1]:
            if part not in d:
                d[part] = dict()
            d = d[part]
        d[parts[-1]] = value
    return resultDict
    
def load_model(folder):
    param_file = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('_params.json')]
    model_file = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('_ckp.ckp')]
    if len(param_file) == 0 or len(model_file) == 0:
        return 
    param_file = param_file[0]
    model_file = model_file[0]
    
    with open(param_file) as fd:
        params = json.load(fd)
        
    model_name, model_params = params['model_name'], params['model_params']
    dataset_name, dataset_params = params['dataset_name'], params['dataset_params']
    model = ModelFactory()(model_name, **model_params)
    _, _, meta_test = load_dataset(dataset_name, **dataset_params)
    
    model.load(ckp_fname)
    
print(unflatten(dd))

ModuleNotFoundError: No module named 'metalearn'

In [3]:
CACHE_DIR = os.path.join(os.getenv('INVIVO_CACHE_ROOT', '~/.invivo'), 'invivoai-sagemaker-artifacts')
KEY1 = 'iscb-expts4'
PATH_TEMPLATE = "output/model/"

In [5]:
key1_path = os.path.join(CACHE_DIR, KEY1, '*', PATH_TEMPLATE)
key1_path

'~/.invivo/invivoai-sagemaker-artifacts/iscb-expts4/*/output/model/'

In [40]:
res_per_dt =  make_hash(3)

def get_params(filename):
    with open(filename) as JSON:
        params = json.load(JSON)
        model_name = params.get('model_name', '')
        algo_name = params.get("model_params.algo", '')
        if algo_name:
            model_name = algo_name
        fold = str(params.get('dataset_params.fold', 0))
        dataset = params.get('dataset_name')
        feat_extract = params.get('model_params.feature_extractor_params.arch', params.get('model_params.fp', ''))
        min_ep = params.get('dataset_params.max_examples_per_episode', 10)
        kernel = params.get('model_params.kernel', '')
        cross_valid = params.get('model_params.do_cv')
        fixed_hps = params.get('model_params.fixe_hps', '')
        
        if cross_valid is not None:
            cross_valid = 'cv:{}'.format(int(cross_valid))
        else:
            cross_valid = ''
        if fixed_hps != '':
            fixed_hps = 'fL:{}'.format(int(fixed_hps))
            
        memory_shape = "x".join(map(str, params.get('model_params.memory_shape', [])))
        controller = str(params.get('model_params.controller_size', ''))
        cnn_size = "x".join(map(str, params.get('model_params.feature_extractor_params.cnn_sizes', [])))
        hyperparams = "|".join([kernel, fixed_hps, str(min_ep), cross_valid, memory_shape, cnn_size, controller, fold])
        return {'dataset':dataset, 'name':model_name+fixed_hps+cross_valid, 'kernel':kernel, 'cv':cross_valid, 'fhps':fixed_hps, 'ep': min_ep, 'feat':feat_extract, 'fold':fold, 'hp':hyperparams}

def get_valid_performance(filename):
    valid = pd.read_csv(filename, header=0, skip_blank_lines=True, delim_whitespace=True)
    return valid['val_loss'].min()

def get_name(x, dataset=''):
    path, _ = os.path.splitext(x)
    dirs, name = os.path.split(path)
    last_dir = os.path.basename(os.path.normpath(dirs))
    if 'pubchem' in path or 'pubchem' in dataset:
        name = os.path.join(last_dir, name)
    if 'pXC50' in name:
        name = name.split('_')[1]
    return name
    
for name in glob.glob(key1_path):
    if not (os.path.exists(os.path.join(name, 'failure')) or not glob.glob(os.path.join(name, '*csv'))):
        param_file = glob.glob(os.path.join(name, '*json'))[0]
        res_file = glob.glob(os.path.join(name, '*csv'))[0]
        param = get_params(param_file)
        model = param['name'] + "_" + param['feat'] + "_{}".format(param['ep'])
        fold = param['fold']
        valid_file = glob.glob(os.path.join(name, '*log'))
        cur_val = res_per_dt[param['dataset']][model][fold].get('val', np.inf)
        best_valid = None
        if len(valid_file) > 0:
              best_valid = get_valid_performance(valid_file[0])
        if (best_valid is None) or (best_valid <= cur_val):
            res_per_dt[param['dataset']][model][fold]['best_hp'] = param['hp'] 
            res_per_dt[param['dataset']][model][fold]['val'] = best_valid or np.inf
            val = pd.read_csv(res_file, delim_whitespace=True, skip_blank_lines=True, header=0)
            val['name'] = [get_name(x) for x in val['name']]
            res_per_dt[param['dataset']][model][fold]['res'] = val


In [41]:
lod = {'metaqsar': 'chembl', 'pubchem': 'pubchemtox'}
def patch_low_data():
    all_pkl = glob.glob(os.path.join(key2_path, '*pkl'))
    for f in all_pkl:
        res = os.path.basename(f).strip('.pkl')
        _, lstm, cur_dataset, hp = res.split('_', 3)
        res_dict = {}
        with open(f, 'rb') as IN:
            lstm_dt = pickle.load(IN)
            score = lstm_dt['score']
            std = lstm_dt['std']
            # cannot use std as it was computed on the wrong format
            res_dict['msemean'] = score.get('mse')#, dict((x, y**2 if not np.isnan(y) else y) for x,y in score['rms'].items()))
            res_dict['r2mean'] = score['r2']
            res_dict['vsemean'] = score['vse']
            res_dict['pccmean'] = score.get('pcc')#, dict((x, np.sqrt(y) if not np.isnan(y) else y) for x,y in score['r2'].items()))
            res_dict['msestd'] = std.get('mse')#, std['rms'])
            res_dict['pccstd'] = std.get('pcc')#, std['r2'])
            res_dict['r2std'] = std.get('r2')#, std['r2'])
            res_dict['vsestd'] = std.get('vse')#, std['r2'])
            cur_df = pd.DataFrame(res_dict)
            cur_df.index.name = 'name'
            cur_df = cur_df.reset_index()
            cur_df['name'] = [get_name(x, lod[cur_dataset]) for x in cur_df['name']]
            res_per_dt[lod[cur_dataset]]['ldt_'+hp]['0']['res'] = cur_df
            res_per_dt[lod[cur_dataset]]['ldt_'+hp]['0']['best_hp'] = hp

patch_low_data()


In [42]:
def compile_fold(dt_val, dataset, old_format=False):
    if dataset == 'mhc' and old_format:
        all_fold = []
        for x in range(len(dt_val.keys())):
            h = dt_val[str(x)]['res']
            h['name'] = [x+1] 
            all_fold.append(h)
        all_fold = pd.concat(all_fold)

    else:
        all_fold = pd.concat([dt_val[str(x)]['res'] for x in range(len(dt_val.keys()))])
    return all_fold
    

def merge_dict_of_pd(pd_dict, keyname='method'):
    return pd.concat(pd_dict, keys=pd_dict.keys(), sort=True).reset_index(level=0).reset_index(drop=True).rename(columns = {'level_0':keyname})
    
def reformat_dt(dt):
    return dt.set_index(['name', 'method', 'size']).stack().reset_index().rename(columns={'level_3': 'metric', 0: 'value'})

def select_best_met(df):
    methods = df.method.unique()
    mean_val = df.groupby("method").mean().reset_index()
    cur_to_met = dict((x, x.split('_')[0]) for x in methods)
    best_list = {}
    best_method = {}
    for k,v in cur_to_met.items():
        score = mean_val.loc[mean_val.method==k]["value"].values[0]
        if score > best_list.get(v, -np.inf):
            best_list[v] = score
            best_method[v] = k
    r = df.loc[df['method'].isin(list(best_method.values()))]
    return r


In [43]:
res_per_dt.keys()

dict_keys(['mhc', 'pubchemtox', 'chembl'])

In [44]:
algo_per_dt = {}
for dtname, dt in res_per_dt.items():
    dframes = {}
    for algo in dt.keys():
        dframes[algo] = compile_fold(dt[algo], dtname)
    algo_per_dt[dtname] = merge_dict_of_pd(dframes)

In [45]:
from ivbase.utils.memorize import memorize, hash_dict
import ipywidgets
from IPython.display import display
sns.set(rc={'figure.figsize':(10,6)})
sns.set_style('ticks')
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 2})
metrics = ['msemean', 'r2mean', 'pccmean']

w_metrics = ipywidgets.Dropdown(
    options=metrics,
    value="pccmean",
    description='Metrics',
    disabled=False
)

w_plot = ipywidgets.RadioButtons(
    options=['violin', 'box'],
    description='Type of plot',
    disabled=False,
)

w_dataset = ipywidgets.SelectMultiple(
    options=['mhc', 'chembl', 'pubchemtox'],
    value=['mhc'],
    description='Dataset',
    disabled=False
)

w_best_KRR = ipywidgets.ToggleButton(
    value=True,
    description='Best MetaKRR',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Use Best Algo on Test',
    icon='check'
)

metric_to_name = dict(
    pccmean='Pearson\'s r',
    msemean= 'Mean square error',
    r2mean= 'R2'
)


In [46]:
def get_plotter(metric, plot_type="violin", dtset=['chembl'], best_mkrr=False):
    hash_val = hash_dict(metric=metric, plot_type=plot_type, dtset=dtset, best_mkrr=best_mkrr)
    output = plot_methods(hash_val, metric=metric, plot_type=plot_type, dtset=dtset, best_mkrr=best_mkrr)
    return display(output)

@memorize
def plot_methods(metric, plot_type, dtset, best_mkrr):
    if len(dtset) > 0:
        df = merge_dict_of_pd(dict((x, reformat_dt(algo_per_dt[x])) for x in dtset), 'dataset')
        if best_mkrr:
            r = select_best_met(df.loc[(df.metric == metric)])
        else:
            r =  df.loc[(df.metric == metric)]
        method_list = sorted(r.method.unique())
        
        #method_list = [x for x in r.method.unique() if 'meta' in x]
        #r = r.loc[r['method'].isin(list(method_list))]
        #colors =  sns.color_palette("Reds", 5)#len(method_list))
        
        colors =  sns.color_palette("Set3", len(method_list))
        pal = dict((met, colors[i]) for i, met in enumerate(method_list))
        out = ipywidgets.Output()
        with out:
            if len(dtset) > 1:
                if plot_type == "violin":
                    ax = sns.violinplot(x='dataset', y="value", hue="method", data=r, palette=pal)
                else:
                    ax = sns.boxplot(x='dataset', y="value", hue="method", data=r, palette=pal)
                lgd = plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fancybox=True, shadow=True)

            else:
                if plot_type == "violin":
                    ax = sns.violinplot(x='method', y="value", data=r, palette=pal)
                else:
                    ax = sns.boxplot(x='method', y="value", data=r, palette=pal)
                    
            ax.set(ylabel=metric_to_name[metric])
            ax.set_xticklabels(ax.get_xticklabels(), rotation=-80)
            plt.savefig('{metric}_{dataset}.svg'.format(metric=metric, dataset="_".join(dtset)), dpi=600)
            plt.title("Performance of selected methods on {} ({})".format(metric, dtset))
            #plt.savefig(os.path.join(res_dir, "{}}{}.pdf".format(plot_type, metric), bbox_extra_artists=(lgd,), bbox_inches='tight')
            #plt.autoscale(enable=True, axis='x', tight=True)

            plt.show()
        return out   

In [47]:
ipywidgets.interact(get_plotter, metric=w_metrics, plot_type=w_plot, dtset=w_dataset, best_mkrr=w_best_KRR)

interactive(children=(Dropdown(description='Metrics', index=2, options=('msemean', 'r2mean', 'pccmean'), value…

<function __main__.get_plotter(metric, plot_type='violin', dtset=['chembl'], best_mkrr=False)>

In [36]:
def get_formated(df, gp='name'):
    return df.round(3).groupby(gp).apply(lambda x: x.astype(str).apply('±'.join, 1)).reset_index(level=0)[0]

def mhc_to_latex(mhc_dt, metric='mse'):
    met_dt_list = {}
    for method in mhc_dt.method.unique():
        met_dt = mhc_dt.loc[mhc_dt.method==method]
        met_dt = met_dt.reset_index().set_index('name')
        met_dt = met_dt[[metric+'mean', metric+'std']]
        met_dt = get_formated(met_dt)
        met_dt_list[method] = met_dt
    df = pd.concat(met_dt_list, axis=1, sort=False)
    return df
df= mhc_to_latex(algo_per_dt['mhc'], 'pcc')
print(df)
df.to_latex()


              metakrr_skfL:1cv:0_cnn_10 metakrr_skfL:1cv:0_cnn_20  \
name                                                                
HLA-DRB1*0101               0.275±0.048               0.224±0.173   
HLA-DRB1*0301               0.423±0.035               0.283±0.194   
HLA-DRB1*0401               0.339±0.053               0.262±0.184   
HLA-DRB1*0404               0.401±0.038               0.321±0.122   

              metakrr_skfL:1cv:0_cnn_40 metakrr_skfL:1cv:0_cnn_5  
name                                                              
HLA-DRB1*0101               0.181±0.122              0.218±0.095  
HLA-DRB1*0301               0.362±0.055              0.418±0.024  
HLA-DRB1*0401               0.314±0.148               0.36±0.021  
HLA-DRB1*0404               0.385±0.069              0.406±0.064  


'\\begin{tabular}{lllll}\n\\toprule\n{} & metakrr\\_skfL:1cv:0\\_cnn\\_10 & metakrr\\_skfL:1cv:0\\_cnn\\_20 & metakrr\\_skfL:1cv:0\\_cnn\\_40 & metakrr\\_skfL:1cv:0\\_cnn\\_5 \\\\\nname          &                           &                           &                           &                          \\\\\n\\midrule\nHLA-DRB1*0101 &               0.275±0.048 &               0.224±0.173 &               0.181±0.122 &              0.218±0.095 \\\\\nHLA-DRB1*0301 &               0.423±0.035 &               0.283±0.194 &               0.362±0.055 &              0.418±0.024 \\\\\nHLA-DRB1*0401 &               0.339±0.053 &               0.262±0.184 &               0.314±0.148 &               0.36±0.021 \\\\\nHLA-DRB1*0404 &               0.401±0.038 &               0.321±0.122 &               0.385±0.069 &              0.406±0.064 \\\\\n\\bottomrule\n\\end{tabular}\n'

In [51]:
print(algo_per_dt['mhc'].groupby('method', as_index=False).agg({'pccmean':['mean', 'std']}).to_latex())

\begin{tabular}{llrr}
\toprule
{} &                     method & \multicolumn{2}{l}{pccmean} \\
{} &      mean &       std \\
\midrule
0  &                mann\_cnn\_10 &  0.458083 &  0.071901 \\
1  &      metakrr\_skcv:0\_cnn\_10 &  0.261707 &       NaN \\
2  &      metakrr\_skcv:1\_cnn\_10 &  0.160249 &  0.021339 \\
3  &  metakrr\_skfL:0cv:0\_cnn\_10 &  0.366902 &  0.062001 \\
4  &  metakrr\_skfL:0cv:0\_cnn\_20 &  0.245782 &  0.055418 \\
5  &  metakrr\_skfL:0cv:0\_cnn\_30 &  0.297835 &  0.139706 \\
6  &  metakrr\_skfL:0cv:0\_cnn\_40 &  0.272624 &  0.094266 \\
7  &   metakrr\_skfL:0cv:0\_cnn\_5 &  0.357735 &  0.092841 \\
8  &  metakrr\_skfL:0cv:0\_cnn\_50 &  0.275130 &  0.119495 \\
9  &  metakrr\_skfL:0cv:1\_cnn\_10 &  0.211991 &  0.087992 \\
10 &  metakrr\_skfL:0cv:1\_cnn\_20 &  0.238331 &  0.091323 \\
11 &  metakrr\_skfL:0cv:1\_cnn\_30 &  0.344085 &  0.063694 \\
12 &  metakrr\_skfL:0cv:1\_cnn\_40 &  0.264911 &  0.175041 \\
13 &   metakrr\_skfL:0cv:1\_cnn\_5 &  0.342913 &  0.055792 \

In [None]:
df= mhc_to_latex('pcc', 'pubchemtox')
print(df)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
sns.set_style('whitegrid')

def partial(func, *args, **keywords):
    def newfunc(*fargs, **fkeywords):
        newkeywords = keywords.copy()
        newkeywords.update(fkeywords)
        return func(*(args + fargs), **newkeywords)
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc


def get_axis_limit(metric):
    if 'pcc' in metric:
        return (-1, 1)
    return None

    
def plot_pairwise(df, a1, a2, metric='pccmean', with_kde=False, size_range=(5, 100), cmap='binary'):
    y_lim = x_lim = get_axis_limit(metric)
    min_color = 0.7
    a1_dt = df.loc[df.method==a1][['name', metric]].set_index('name')
    a2_dt = df.loc[df.method==a2][['name', metric]].set_index('name')
    size_dt = df.groupby('name')['size'].first()
    table = pd.concat({a1:a1_dt, a2: a2_dt, 'size':size_dt}, axis=1, sort=False)
    table = table.T.reset_index(level=1,drop=True).T.dropna()
    if not x_lim:
        cur_dist = table[[a1, a2]].values
        min_val = np.min(cur_dist)# - cur_dist.std()*5
        max_val = np.max(cur_dist)# + cur_dist.std()*5
        y_lim = x_lim = (min_val, max_val)
    size = table['size'].values.ravel()
    s = (size - size.min()) / (size.max() - size.min())*(size_range[1] - size_range[0]) + size_range[0]
    c = (size - size.min()) / (size.max() - size.min())*(1-min_color) + min_color

    g = sns.JointGrid(x=a1, y=a2, data=table, space=0, xlim=x_lim, ylim=y_lim)
    if with_kde:
        g = g.plot_joint(sns.kdeplot, cmap="Blues_d", shade_lowest=False)
    g = g.plot_joint(plt.scatter, s=s, c=c, edgecolors='#dddddd', cmap=cmap)#, alpha=0.8)
    g = g.plot_marginals(sns.kdeplot, shade=True, color='k')
    g.ax_joint.set_xticks(np.linspace(x_lim[0], x_lim[1], 5))
    g.ax_joint.set_yticks(np.linspace(y_lim[0], y_lim[1], 5))
    g.ax_joint.plot(np.linspace(x_lim[0], x_lim[1]), np.linspace(y_lim[0], y_lim[1]), c='k')
    g.set_axis_labels(a1, a2)

    divider = make_axes_locatable(g.ax_marg_y)
    cax = divider.append_axes('right', size='25%', pad=0.15)
    #g.ax_marg_y.set_yscale('log')
    cbar = plt.colorbar(cax=cax, ticks=[min_color, 1], orientation='vertical')
    cbar.ax.set_yticklabels(['Small', 'Large'])  # horizontal colorbar
    cbar.set_label('Dataset size' , fontsize='large')
    out = ipywidgets.Output()
    with out:
        plt.show()
    plt.close()
    return out
    
def all_pairwise(dataset, **kwargs):
    df = algo_per_dt[dataset]
    algos = sorted(df.method.unique())
    for a1, a2 in itertools.combinations(algos, 2):
        plot_pairwise(df, a1, a2, **kwargs)

        
def for_interactive(dataset='pubchemtox'):
    df = algo_per_dt[dataset]
    algos = sorted(df.method.unique())

    metric = ipywidgets.Dropdown(
        options=['msemean', 'r2mean', 'pccmean'],
        value="pccmean",
        description='Metric',
        disabled=False
    )

    with_kde = ipywidgets.ToggleButton(
        value=False,
        description='KDE',
        disabled=False,
        button_style='', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Use KDE',
        icon='check'
    )

    cmap = ipywidgets.Text(
        value='binary',
        placeholder='Enter plt color map',
        description='cmap:',
        disabled=False,
        continuous_update=False
    )

    size_range = ipywidgets.IntRangeSlider(
        value=[5, 100],
        min=0,
        max=500,
        step=1,
        description='Size:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )
    
    return df, algos, metric, with_kde, size_range, cmap
    

df, algos, wmetric, with_kde, size_range, cmap = for_interactive('pubchemtox')
a1 = ipywidgets.Select(
    options=algos,
    value=algos[0],
    description='Method 1',
    disabled=False
)

a2 = ipywidgets.Select(
    options=algos[1:],
    value=algos[1],
    description='Method 2',
    disabled=False
)

def valid_algo(*args):
    allowed = [x for x in algos if x!= a1.value]
    with a2.hold_trait_notifications():
        a2.value = set(a2.options).intersection(allowed).pop()
        a2.options=allowed


a1.observe(valid_algo, 'value')
ipywidgets.interact(partial(plot_pairwise, df=df), a1=a1, a2=a2, metric=wmetric, with_kde=with_kde, size_range=size_range, cmap=cmap)
  
#all_pairwise('pubchemtox', metric='pccmean', with_kde=False, cmap='hot')

In [None]:
sns.set_style('ticks')
def plot_with_reference(df, metric='pccmean', ref=None):
    all_algos = set(df.method.unique())
    if ref and ref in all_algos:
        all_algos.remove(ref)
        ref = df.loc[df.method==ref][['name', metric]].set_index('name')
    else:
        ref = None
    for algo in all_algos:
        dt = df.loc[df.method==algo][['name', metric]].set_index('name')
        if ref is not None:
            dt = dt - ref
        dt = dt.dropna()
        sns.kdeplot(dt[metric], shade=True, label=algo)
    plt.legend(frameon=True, fontsize='medium', fancybox=True)
    sns.despine(top=True, right=True, offset=5)
    plt.autoscale(enable=True, axis='x', tight=True)
    #plt.savefig("xxx.pdf")
    plt.show()

In [None]:
plot_with_reference(algo_per_dt['pubchemtox'], ref=None)

In [None]:
algo_per_dt['pubchemtox'].loc[algo_per_dt['pubchemtox'].method=='metakrr_sk_cnn'].head()