In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from functools import  reduce
from itertools import combinations
from IPython.display import display

In [2]:
def highlight_max(s):
    """Build per-cell CSS styles that mark the largest value of a Series in yellow.

    The values are coerced to numbers first; entries that cannot be parsed
    become NaN and therefore never compare equal to the maximum. Every cell
    tied for the maximum is highlighted.

    Returns a list with one CSS string per element of ``s`` (empty string for
    cells that are not highlighted).
    """
    numeric = pd.to_numeric(s, errors='coerce')
    peak = numeric.max()
    styles = []
    for hit in numeric == peak:
        styles.append('background-color: yellow' if hit else '')
    return styles

In [23]:
def _pick_best_log(log_fnames):
    """Return the log file whose minimum ``val_loss`` is the lowest, or None.

    Empty files and files without a ``val_loss`` column are ignored. ``None``
    is returned when no file yields a comparable validation loss.
    """
    best_val_loss = float('Inf')
    best_log = None
    for fname in log_fnames:
        try:
            log = pd.read_csv(fname, delimiter='\t')
        except pd.errors.EmptyDataError:
            continue
        if 'val_loss' not in log.columns:
            continue
        val_loss = pd.to_numeric(log['val_loss'], errors='coerce').min()
        # Keep the running minimum; a NaN loss compares False and is skipped.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_log = fname
    return best_log


def summarize_results(expts_folder, show_figures=True, select_best_on_valid=False):
    """Summarize the per-algorithm result files found under ``expts_folder``.

    Every sub-directory of ``expts_folder`` is treated as one algorithm run.
    For each of them the tab-separated result files are loaded, the row with
    the highest ``pcc_mean`` is kept for every ``name``, and that table is
    written to ``recap_<sub-directory>.txt`` inside ``expts_folder``. The
    per-algorithm scores are then joined into one table indexed by ``name``
    with one column per algorithm (the last two ``_``-separated tokens of the
    sub-directory name).

    Parameters
    ----------
    expts_folder : str
        Directory containing one sub-directory per algorithm.
    show_figures : bool, default True
        If True, draw a size-vs-metric scatter plot per algorithm and a
        pairwise scatter plot for every pair of algorithms. If False, print
        the joined table instead.
    select_best_on_valid : bool, default False
        If True, only the results file matching the ``log*`` file with the
        lowest validation loss is used for each sub-directory; otherwise all
        ``results*`` files are used.

    Returns
    -------
    None
        Output is produced through printing, plotting and the recap files.
    """
    res_folders = [el for el in os.listdir(expts_folder)
                   if os.path.isdir(os.path.join(expts_folder, el))]
    metric = "pcc_mean"
    dfs = []
    for res_folder in res_folders:
        print(res_folder)
        folder_path = os.path.join(expts_folder, res_folder)
        if select_best_on_valid:
            best_log = _pick_best_log(glob(os.path.join(folder_path, 'log*')))
            if best_log is None:
                # Nothing to select from: skip the folder instead of crashing.
                print('no usable log file in', res_folder)
                continue
            # Swap only the leading "log" of the file name, so a "log"
            # appearing in a directory name is left untouched.
            log_dir, log_name = os.path.split(best_log)
            fnames = [os.path.join(log_dir, log_name.replace('log', 'results', 1))]
        else:
            fnames = glob(os.path.join(folder_path, 'results*'))

        res_all_hps = []
        for fname in fnames:
            try:
                temp = pd.read_csv(fname, delimiter='\t')
            except pd.errors.EmptyDataError:
                continue
            if temp.shape[1] < 2:
                # A single parsed column means the first line is not a
                # tab-separated header; retry skipping it.
                # NOTE(review): presumably such files carry a banner line and
                # repeated header rows -- confirm against the writer.
                temp = pd.read_csv(fname, delimiter='\t', skiprows=1)
                temp = temp.dropna(how='any', axis=0)
                # keep=False removes every copy of a duplicated row.
                temp = temp.drop_duplicates(keep=False)
                # Convert the metric before ranking so idxmax compares
                # numbers rather than strings.
                temp = temp.assign(
                    **{metric: pd.to_numeric(temp[metric], errors='coerce')})
                temp = temp[temp[metric].notnull()]
                temp = temp.loc[temp.groupby(["name"])[metric].idxmax()]
            res_all_hps.append(temp)

        if len(res_all_hps) > 0:
            algo_name = '_'.join(res_folder.split('_')[-2:])
            res = pd.concat(res_all_hps, ignore_index=True)
            res[metric] = pd.to_numeric(res[metric])
            res['size'] = pd.to_numeric(res['size'])
            # Best row per name across all hyper-parameter files.
            res = res.loc[res.groupby(["name"])[metric].idxmax()]
            res.to_csv(os.path.join(expts_folder, 'recap_'+res_folder+'.txt'),
                       index=False,  float_format='%.3f')
            if show_figures:
                res.plot.scatter('size', metric)
                plt.title(algo_name)
                plt.show()
            dfs.append(res[['name', metric]].set_index('name').rename(
                columns={metric: algo_name}))

    if not dfs:
        # pd.concat raises on an empty list; report and stop instead.
        print('No results found in', expts_folder)
        return

    res = pd.concat(dfs, axis=1).apply(pd.to_numeric)
    if show_figures:
        for x, y in combinations(res.columns.tolist(), 2):
            # Fraction of names on which algorithm x beats algorithm y.
            print('% {} > {}'.format(x, y), res[x].gt(res[y]).mean())
            ax = res.plot.scatter(x, y)
            # Identity line as a visual reference for "x equals y".
            ax.plot((-1, 1), (-1, 1), ls="-", c=".3")
            plt.show()
    else:
        # For a styled table in a notebook, highlight_max can be applied
        # row-wise through res.style.apply(highlight_max, axis=1).
        print(res)

In [24]:
# Print the joined score table for the MHC test experiments (no figures).
summarize_results(
    expts_folder='/home/prtos/workspace/code/few_shot_regression/expt_results/mhc_test',
    show_figures=False,
)
# Alternative run kept for reference (helios experiments, with figures):
# summarize_results(expts_folder='/home/prtos/workspace/code/few_shot_regression/expt_results/expts_helios', show_figures=True)

results_mhcpan_mann_cnn
results_mhcpan_pretrain_cnn
results_mhcpan_maml_cnn
results_mhcpan_krr_cnn
               mann_cnn  pretrain_cnn  maml_cnn   krr_cnn
name                                                     
HLA-DRB1*0101  0.184831      0.041413  0.486823  0.426860
HLA-DRB1*0301  0.360999      0.063412  0.542506  0.485820
HLA-DRB1*0401  0.287210      0.124768  0.546736  0.532126
HLA-DRB1*0404  0.323568      0.037275  0.561135  0.578007
HLA-DRB1*0405  0.455905      0.172535  0.600145  0.641682
HLA-DRB1*0701  0.301558      0.198806  0.693124  0.662625
HLA-DRB1*0802  0.155813      0.045334  0.425447  0.316689
HLA-DRB1*0901  0.133562      0.076647  0.524341  0.505290
HLA-DRB1*1101  0.205826      0.072667  0.577498  0.623800
HLA-DRB1*1302  0.107683      0.110303  0.499297  0.426291
HLA-DRB1*1501  0.283141      0.202750  0.612747  0.604902
HLA-DRB3*0101  0.176131      0.018379  0.380906  0.291169
HLA-DRB4*0101  0.406634      0.073280  0.545620  0.504074
HLA-DRB5*0101  0.237611      0.