In [None]:
import sys
import os

import matplotlib.pyplot as plt
import numpy as np

# Append the parent of the current working directory to the import search path;
# presumably this is what lets the `evaluation` package imported below resolve
# when the notebook is run from its own sub-directory.
cwd = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(cwd, "..")))

from evaluation.prediction_reader import score_result

In [None]:
def multi_boxplot(data, xlabel, ylabel, xticks, labels, location, save_figure=None, ylim=None):
    """Draw several series of boxplots side by side, grouped per category.

    Series ``i`` is drawn in its own colour and shifted right by ``i`` box
    widths, so the k-th boxes of all series sit next to each other around one
    shared x tick.

    Parameters
    ----------
    data : sequence
        One entry per series. ``data[i]`` is handed to ``plt.boxplot`` and is
        expected to be a sequence of 1-D samples (one box per sample); all
        series should contain the same number of samples, because the tick
        positions are derived from ``data[0]``.
    xlabel, ylabel : str or None
        Axis labels, forwarded to ``plt.xlabel`` / ``plt.ylabel``.
    xticks : sequence of str
        Tick labels, one per group.
    labels : sequence of str
        Legend label for each series; needs at least ``len(data)`` entries.
    location : str or None
        ``loc`` argument for ``plt.legend``. ``None`` draws no legend.
    save_figure : str or None, optional
        If given, the figure is written to ``figures/<save_figure>.pdf``
        (the ``figures`` directory is created when missing).
    ylim : float or None, optional
        Upper y-axis limit. The lower limit is then fixed at ``-0.005``.

    Raises
    ------
    ValueError
        If ``data`` is empty or ``labels`` has fewer entries than ``data``.
    """
    n_series = len(data)
    if n_series == 0:
        raise ValueError("multi_boxplot needs at least one data series")
    if len(labels) < n_series:
        raise ValueError(
            f"multi_boxplot got {n_series} data series but only {len(labels)} labels"
        )

    plt.figure(figsize=(6, 3))

    colour_scheme = ['#52B297', '#8E4A93', '#009EE7', '#EF766E']

    # 0.2 is the width used for up to four series. With more series the boxes
    # are narrowed so one group never spans more than 0.8 of the unit distance
    # between neighbouring groups.
    box_width = min(0.2, 0.8 / n_series)

    for i, series in enumerate(data):
        # The props dicts are rebuilt for every call so that no state is
        # shared between the individual boxplot invocations.
        boxprops = dict(
            linestyle='-',
            linewidth=1,
            color='black',
            # Cycle through the palette instead of indexing past its end.
            facecolor=colour_scheme[i % len(colour_scheme)],
        )
        medianprops = dict(linestyle='-', linewidth=1, color='black')
        # The mean line is requested but drawn with colour 'none', i.e. invisible.
        meanprops = dict(linestyle='-', linewidth=1, color='none')

        plt.boxplot(
            series,
            positions=np.arange(len(series)) + i * box_width,
            widths=box_width,
            patch_artist=True,
            boxprops=boxprops,
            showmeans=True,
            meanline=True,
            medianprops=medianprops,
            meanprops=meanprops,
            label=labels[i],
        )

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    # Put each tick in the middle of its group of n_series adjacent boxes.
    central_positions = np.arange(len(data[0])) + (n_series * box_width / 2) - (box_width / 2)
    plt.xticks(central_positions, xticks)

    if ylim is not None:
        # Lower limit sits slightly below zero (presumably so that artists at
        # exactly 0 are not clipped by the axis).
        plt.ylim(-0.005, ylim)

    plt.grid(axis='y', linestyle='--', alpha=0.7)

    if location is not None:
        plt.legend(loc=location)

    if save_figure is not None:
        os.makedirs('figures', exist_ok=True)
        plt.savefig(f'figures/{save_figure}.pdf', format='pdf', bbox_inches='tight')

    plt.show()

In [None]:
# Best Models
# Compare two result files in one grouped boxplot and save it as
# figures/best_model.pdf.
bs = 4  # second argument of score_result (meaning not visible here — TODO confirm)

# Alternative result files (currently disabled):
#   '../evaluation/alohomora/result_comb_metatrans_random_5_per_model.csv'  -> *_random
#   '../evaluation/alohomora/result_base_rand.csv'                          -> *_base
#   '../evaluation/alohomora/result_comb_random_split_5_per_model.csv'      -> *_random
recall_base, precision_base, score_one_base, score_all_base = score_result(
    '../evaluation/alohomora/result_metatrans.csv', bs, False
)
recall_initial, precision_initial, score_one_initial, score_all_initial = score_result(
    '../evaluation/alohomora/result_chemf_base_metabolic.csv', bs, False
)

# Take the second-to-last entry of every metric, ordered to match `xticks`.
# NOTE(review): what index -2 stands for depends on score_result — confirm.
base = [
    metric[-2]
    for metric in (precision_base, recall_base, score_one_base, score_all_base)
]
initial = [
    metric[-2]
    for metric in (precision_initial, recall_initial, score_one_initial, score_all_initial)
]

xlabel, ylabel = None, 'Proportion'
xticks = ['Precision', 'Recall', 'At Least One Met.', 'All Met.']
# Legend entries follow the order of `data`: first `base`, then `initial`.
labels = ['ChemVA-Met Fine-Tuned Rand', 'Chemformer Fine-Tuned']

data = [base, initial]
multi_boxplot(
    data,
    xlabel,
    ylabel,
    xticks,
    labels,
    location='lower right',
    save_figure='best_model',
    ylim=1,
)