In [1]:
import itertools
import os
import re

import numpy as np
import pandas as pd

In [2]:
# Directory holding the experiment logs (files named train_*.log / test_*.log).
d_experiment_logs = './logs/'
# Directory the figure-export cell writes its .html / .png files into.
dir_output = './figures'

In [3]:
def _natural_key(name: str) -> list:
    """Sort key comparing digit runs numerically, so 'test_10.log' sorts after 'test_9.log'."""
    # re.split with a capturing group alternates [text, digits, text, digits, ...],
    # so every odd position is a digit run.
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r'(\d+)', name))]


# List the log directory once; train and test lists are sorted with the same key, so
# (assuming both share the same run numbering) their i-th entries belong to the same run.
log_names = sorted(os.listdir(d_experiment_logs), key=_natural_key)
test_logs = [os.path.join(d_experiment_logs, name) for name in log_names if name.startswith('test')]
train_logs = [os.path.join(d_experiment_logs, name) for name in log_names if name.startswith('train')]
test_logs, train_logs

(['./logs//test_1.log',
  './logs//test_2.log',
  './logs//test_3.log',
  './logs//test_4.log',
  './logs//test_5.log',
  './logs//test_6.log',
  './logs//test_7.log',
  './logs//test_8.log'],
 ['./logs//train_1.log',
  './logs//train_2.log',
  './logs//train_3.log',
  './logs//train_4.log',
  './logs//train_5.log',
  './logs//train_6.log',
  './logs//train_7.log',
  './logs//train_8.log'])

In [4]:
def add_metric_val(metrics: dict, dataset: str, method: str, metric_name: str, metric_val: float) -> None:
    """Append ``metric_val`` to ``metrics[dataset][base_method][metric_name]``.

    Missing nesting levels are created on the fly. ``base_method`` is ``method``
    truncated at its first ``'-'`` (``'adda-run2'`` -> ``'adda'``), so runs whose
    names differ only after a dash are pooled under one key.

    Mutates ``metrics`` in place; returns ``None``.
    """
    base_method = method.split('-')[0]
    by_method = metrics.setdefault(dataset, {})
    by_metric = by_method.setdefault(base_method, {})
    by_metric.setdefault(metric_name, []).append(metric_val)


def _parse_metric_line(line: str):
    """Return ``(metric_name, value)`` if ``line`` is a recognised metric line, else ``None``.

    Recognised prefixes:
      ' * Acc@1'               -> accuracy, last whitespace token
      ' * Acc1'                -> accuracy, third whitespace token
      'PR AUC' / 'F1 PR AUC'   -> pr_auc, last whitespace token
      'ROC AUC' / 'F1 ROC AUC' -> roc_auc, last whitespace token
    """
    if line.startswith(' * Acc@1'):
        return 'accuracy', float(line.split()[-1])
    if line.startswith(' * Acc1'):
        return 'accuracy', float(line.split()[2])
    if line.startswith(('PR AUC', 'F1 PR AUC')):
        return 'pr_auc', float(line.split()[-1])
    if line.startswith(('ROC AUC', 'F1 ROC AUC')):
        return 'roc_auc', float(line.split()[-1])
    return None


def _parse_test_log(path: str) -> dict:
    """Collect ``{dataset: {method: {metric: [values, ...]}}}`` from one test log.

    Every metric line is attributed to the ``data_name`` / ``log`` arguments of the
    most recent ``Namespace(...)`` line seen before it.

    Raises:
        ValueError: if a metric line appears before any ``Namespace(...)`` line.
    """
    run_args = None
    metrics = {}
    with open(path) as log_in:
        for line in log_in:
            if line.startswith('Namespace('):
                # NOTE(review): eval() executes whatever the log file contains.
                # Only parse logs you produced yourself; an ast-based parser would be safer.
                run_args = eval(line.replace('Namespace', 'dict'))
                continue
            parsed = _parse_metric_line(line)
            if parsed is None:
                continue
            if run_args is None:
                raise ValueError(f"{path}: metric line before any Namespace(...) line: {line!r}")
            metric_name, metric_val = parsed
            add_metric_val(metrics, run_args['data_name'], run_args['log'], metric_name, metric_val)
    return metrics


def _metrics_to_frame(ds: dict) -> pd.DataFrame:
    """Turn ``{method: {metric: [values]}}`` into a float frame indexed by ``method``.

    Repeated values of a metric are averaged and rounded to 3 decimals. Columns are
    the union of all methods' metrics in first-seen order; a metric a method never
    reported becomes NaN instead of raising ``KeyError``.
    """
    metric_names = list(dict.fromkeys(name for method_data in ds.values() for name in method_data))
    rows = [
        {'method': method_name,
         **{name: round(float(np.mean(method_data[name])), 3) if name in method_data else np.nan
            for name in metric_names}}
        for method_name, method_data in ds.items()
    ]
    return pd.DataFrame(rows, columns=['method', *metric_names]).set_index('method').astype(float)


total_dfs = dict()

# The train logs are never read: zip() only pairs them with the test logs (and stops
# at the shorter list). Run arguments and metrics all come from the test logs.
for _train_log, test_log in zip(train_logs, test_logs):
    for ds_name, ds in _parse_test_log(test_log).items():
        df = _metrics_to_frame(ds)
        if ds_name in total_dfs:
            # Dataset already seen in an earlier log: stack this log's methods underneath.
            df = pd.concat([total_dfs[ds_name], df])
        total_dfs[ds_name] = df.sort_index()

import functools
# Side-by-side table with one column group per dataset. Passing the dict to concat
# labels each group with that dataset's OWN metric columns, so datasets whose metric
# sets differ are not mislabelled the way MultiIndex.from_product(first_columns) would.
resulting_df = pd.concat(total_dfs, axis=1)
resulting_df.columns.names = ['data', 'metric_type']
display(resulting_df)

data,His.ALL.05.H3K79me1.AllCell.dm6.mm10,His.ALL.05.H3K79me1.AllCell.dm6.mm10,His.ALL.05.H3K79me1.AllCell.dm6.mm10,His.ALL.05.H3K79me1.AllCell.hg38.ce11,His.ALL.05.H3K79me1.AllCell.hg38.ce11,His.ALL.05.H3K79me1.AllCell.hg38.ce11,His.ALL.05.H3K79me1.AllCell.ce11.mm10,His.ALL.05.H3K79me1.AllCell.ce11.mm10,His.ALL.05.H3K79me1.AllCell.ce11.mm10,His.ALL.05.H3K79me1.AllCell.mm10.hg38,...,His.ALL.05.H3K79me1.AllCell.mm10.ce11,His.ALL.05.H3K79me1.AllCell.hg38.dm6,His.ALL.05.H3K79me1.AllCell.hg38.dm6,His.ALL.05.H3K79me1.AllCell.hg38.dm6,His.ALL.05.H3K79me1.AllCell.ce11.hg38,His.ALL.05.H3K79me1.AllCell.ce11.hg38,His.ALL.05.H3K79me1.AllCell.ce11.hg38,His.ALL.05.H3K79me1.AllCell.dm6.ce11,His.ALL.05.H3K79me1.AllCell.dm6.ce11,His.ALL.05.H3K79me1.AllCell.dm6.ce11
metric_type,accuracy,pr_auc,roc_auc,accuracy,pr_auc,roc_auc,accuracy,pr_auc,roc_auc,accuracy,...,roc_auc,accuracy,pr_auc,roc_auc,accuracy,pr_auc,roc_auc,accuracy,pr_auc,roc_auc
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
adda,72.064,80.279,80.571,51.136,61.861,55.846,57.091,62.195,64.091,53.693,...,57.921,62.253,64.639,67.116,47.443,53.278,49.99,63.494,62.002,66.662
afn,76.649,87.276,85.09,56.676,61.314,56.082,68.191,78.216,78.335,55.966,...,39.119,61.614,64.98,65.478,53.125,58.624,51.68,60.085,66.107,67.956
cdan,67.753,74.619,76.44,53.125,55.542,60.98,68.444,79.944,78.894,62.216,...,50.538,57.118,59.263,59.394,48.295,58.72,52.903,52.841,57.759,51.153
dan,70.62,78.989,81.306,43.466,48.993,40.648,64.796,73.241,75.901,57.67,...,45.33,61.457,65.206,67.608,49.716,51.831,45.855,54.972,55.677,44.358
dann,67.342,74.567,77.923,50.0,49.948,49.961,68.03,79.929,78.784,57.102,...,45.272,55.769,58.815,58.67,48.864,51.333,45.643,50.426,56.507,49.697
jan,66.458,71.571,75.149,46.449,50.905,46.63,66.413,80.02,78.362,52.841,...,46.013,58.817,63.734,62.241,49.716,54.156,49.025,48.011,51.013,44.03
mcc,66.32,74.205,74.552,47.953,50.376,42.016,69.432,80.009,79.269,63.889,...,44.294,59.457,63.817,64.796,49.444,54.256,46.346,58.48,58.471,54.263
mcd,78.057,87.813,86.89,50.426,55.938,44.054,62.351,73.019,76.963,55.966,...,42.054,60.562,62.703,65.939,52.841,57.227,51.395,56.25,73.088,72.658
mdd,66.148,74.36,72.414,57.244,57.201,59.631,65.287,73.219,76.922,57.955,...,44.193,62.131,64.817,66.202,52.841,57.772,54.612,51.989,56.934,57.816
src_only,75.684,84.573,84.069,59.801,62.41,55.869,66.015,71.292,75.143,61.08,...,46.656,60.994,64.756,66.204,50.852,59.018,52.002,60.938,65.223,59.565


In [6]:
# Both transfer directions between mm10 and hg38. The dots are escaped and the pattern is
# anchored at the end, so only dataset names ending in ".mm10.hg38" / ".hg38.mm10" match.
resulting_df.loc[:, resulting_df.columns.get_level_values('data').str.contains(r'\.(?:mm10\.hg38|hg38\.mm10)$')]

data,His.ALL.05.H3K79me1.AllCell.mm10.hg38,His.ALL.05.H3K79me1.AllCell.mm10.hg38,His.ALL.05.H3K79me1.AllCell.mm10.hg38,His.ALL.05.H3K79me1.AllCell.hg38.mm10,His.ALL.05.H3K79me1.AllCell.hg38.mm10,His.ALL.05.H3K79me1.AllCell.hg38.mm10
metric_type,accuracy,pr_auc,roc_auc,accuracy,pr_auc,roc_auc
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
adda,53.693,54.429,54.667,67.958,78.54,80.29
afn,55.966,54.956,52.15,72.523,79.769,78.661
cdan,62.216,60.335,60.802,63.674,69.564,69.36
dan,57.67,58.349,57.568,69.719,72.068,75.466
dann,57.102,55.563,55.531,62.009,65.869,66.634
jan,52.841,58.29,57.403,62.769,64.356,67.152
mcc,63.889,64.291,63.264,66.808,74.132,72.744
mcd,55.966,60.269,58.065,66.695,71.375,73.396
mdd,57.955,56.252,55.068,68.376,73.156,75.076
src_only,61.08,63.96,62.24,67.845,78.722,76.738


In [201]:
import plotly.express as px

# Long-format table of each method's metric minus the "src_only" baseline:
# one row per (data_name, method, metric), the shape plotly express expects.
diff_frames = []
for data_name, data_df in total_dfs.items():
    n_baseline = int((data_df.index == "src_only").sum())
    if n_baseline != 1:
        raise ValueError(
            f"{data_name}: expected exactly one 'src_only' row, found {n_baseline}"
        )

    diff_df = (
        # Subtract the baseline row from every row, aligned on metric columns.
        data_df.sub(data_df.loc["src_only"], axis="columns")
        # Drop the baseline by label, not position: sort_index() only puts
        # "src_only" last when no other method name sorts after it.
        .drop(index="src_only")
        .reset_index()
        # "<prefix>.<src>.<tgt>" -> "<src>-><tgt>"
        .assign(data_name="->".join(data_name.rsplit(".", maxsplit=2)[-2:]))
        .melt(
            id_vars=["data_name", "method"],
            value_vars=list(data_df.columns),
            var_name='metric',
            value_name='diff'
        )
    )
    diff_frames.append(diff_df)

# Concatenate once instead of growing a frame inside the loop.
plotly_df = pd.concat(diff_frames).sort_values(
    ['data_name', 'method', 'metric'], ascending=[False, True, True]
)
plotly_df


Unnamed: 0,data_name,method,metric,diff
0,mm10->hg38,adda,accuracy,-7.387
9,mm10->hg38,adda,pr_auc,-9.531
18,mm10->hg38,adda,roc_auc,-7.573
1,mm10->hg38,afn,accuracy,-5.114
10,mm10->hg38,afn,pr_auc,-9.004
...,...,...,...,...
16,ce11->dm6,mcd,pr_auc,2.110
25,ce11->dm6,mcd,roc_auc,0.794
8,ce11->dm6,mdd,accuracy,3.350
17,ce11->dm6,mdd,pr_auc,3.702


In [220]:
# Grouped bar chart of the differences vs. src_only, one facet per transfer direction.
# The common dataset prefix ("<prefix>.<src>.<tgt>" minus the last two parts) is
# used as the title and as the output file name.
experiment_name = next(iter(total_dfs)).rsplit(".", maxsplit=2)[0]

fig = px.bar(
    plotly_df,
    facet_col="data_name",
    x="metric",
    y="diff",
    color="method",
    barmode="group",
)
fig.update_layout(
    title=experiment_name,
    height=400,
    width=1800,
    margin=dict(l=10, r=10, t=50, b=10),
)

# Facet titles default to "data_name=<value>"; keep only the value.
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
# Replace the per-facet x-axis titles with one shared "metric" label below the plot.
fig.for_each_xaxis(lambda xaxis: xaxis.update(title=""))
fig.add_annotation(
    y=-0.25,
    x=0.5,
    text="metric",
    xref="paper",
    yref="paper",
    showarrow=False,
)

# The writers below do not create missing directories.
os.makedirs(dir_output, exist_ok=True)
path_output = os.path.join(dir_output, experiment_name)
fig.write_html(f"{path_output}.html")
fig.write_image(f"{path_output}.png")