In [135]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

In [136]:
# Run names to compare, in plotting order. Each name must match the column
# prefixes '{model} - train.{metric}' / '{model} - valid.{metric}' in the
# per-metric CSV files. 'DeepSTABp' is special-cased (drawn in black) by the
# plotting cell.
model_ls = """
    DeepSTABp
    base1cd5-no_graphs
    base3cd5_s1n64-all_graphs
    base3d5_s1n64-all_graphs
    base3d5_s1n16-all_graphs-bs64
""".split()

# Metric name -> [ymin, ymax] y-axis limits for that metric's panel.
# Insertion order is panel order; each key also names the file '{metric}.csv'.
metric_ls = dict(
    mae=[0, 7],
    # mse=[0, 100],
    rmse=[0, 7],
    pcc=[0.8, 1],
    r2=[0.5, 1],
)

In [137]:
# Plot train (dashed) and validation (solid) curves for every model in
# `model_ls`, one panel per metric in `metric_ls`, and save the figure to
# metrics.png. Each '{metric}.csv' is read from the working directory and must
# contain an 'epoch' column plus '{model} - train.{metric}' and
# '{model} - valid.{metric}' columns for every model.

NCOLS = 2  # panels per row; the row count is derived from len(metric_ls)


def _model_style(model_idx, model):
    """Return the (colour, legend label) for `model`.

    Shared by the curves and the legend so the two always agree. DeepSTABp is
    drawn in black; every other model uses the default colour cycle entry
    'C{model_idx}', keyed by its position in `model_ls`.
    """
    if model == 'DeepSTABp':
        return 'black', 'DeepSTABp'
    # NOTE(review): [5:] drops the first five characters, i.e. 'base' plus the
    # following digit ('base1cd5-...' -> 'cd5-...'), so base1/base3 runs lose
    # that digit in the legend. Confirm this is intended.
    return f'C{model_idx}', f'{model[5:]}'


def _plot_metric(axis, df, metric, ylim):
    """Draw every model's train/valid curves for `metric` on `axis`.

    Parameters
    ----------
    axis : matplotlib Axes to draw on.
    df : DataFrame read from '{metric}.csv'.
    metric : metric name, used in column names and as the y-axis label.
    ylim : (ymin, ymax) y-axis limits for this panel.
    """
    for model_idx, model in enumerate(model_ls):
        color, _ = _model_style(model_idx, model)
        # Train first, then valid, so draw order matches the legend.
        for split, linestyle in (('train', '--'), ('valid', '-')):
            axis.plot(
                df['epoch'], df[f'{model} - {split}.{metric}'], linestyle,
                c=color, linewidth=1, alpha=0.8,
            )

    axis.set_xlabel('epoch')
    axis.set_ylabel(metric)
    # sharex=True hides x tick labels on upper rows; show them on every panel.
    axis.tick_params(labelbottom=True)

    axis.set_xlim(df['epoch'].min(), df['epoch'].max())
    axis.set_ylim(ylim)
    axis.grid()


# Ceiling division so any number of metrics fits (2x2 for the default four).
nrows = max(1, -(-len(metric_ls) // NCOLS))
fig, axes = plt.subplots(
    nrows, NCOLS, sharex=True, sharey=False, figsize=(16, 4 * nrows),
    layout='constrained', squeeze=False,
)
axes = axes.ravel()

for axis, (metric, ylim) in zip(axes, metric_ls.items()):
    df = pd.read_csv(f'{metric}.csv')
    _plot_metric(axis, df, metric, ylim)

# Hide leftover panels when the metric count is not a multiple of NCOLS.
for axis in axes[len(metric_ls):]:
    axis.set_visible(False)

# Legend 1: one colour swatch per model.
model_handles = [
    mpatches.Patch(color=color, label=label)
    for color, label in (
        _model_style(model_idx, model)
        for model_idx, model in enumerate(model_ls)
    )
]
fig.legend(
    handles=model_handles, bbox_to_anchor=(0.725, 1.05),
    ncols=max(1, len(model_handles)), fancybox=True, shadow=True,
)

# Legend 2: line style encodes the data split.
split_handles = [
    mlines.Line2D([], [], color='black', linestyle='--', label='train'),
    mlines.Line2D([], [], color='black', linestyle='-', label='valid'),
]
fig.legend(
    handles=split_handles, bbox_to_anchor=(0.875, 1.05),
    ncols=2, fancybox=True, shadow=True,
)

fig.savefig('metrics.png', dpi=300, bbox_inches='tight')
plt.close(fig)