In [1]:
import pandas as pd
import matplotlib.pyplot as plt

In [ ]:
def get_single_model_results_(model, df_true_accuracy, df_cost):
    """Summarise one model's total cost and average accuracy.

    Args:
        model: Name of the model; its string form is used as the column
            label looked up in both frames.
        df_true_accuracy: DataFrame with one column per model. The column
            sum is divided by the number of rows of this frame.
            (Presumably per-example correctness indicators - TODO confirm.)
        df_cost: DataFrame with one column per model holding cost values,
            which are summed.

    Returns:
        A two-element list ``[total_cost, accuracy]``.

    Raises:
        KeyError: If the model's column is missing from either frame.
    """
    column = f"{model}"
    n_rows = len(df_true_accuracy)

    # Column total divided by the full row count of the accuracy frame.
    accuracy = df_true_accuracy[column].sum() / n_rows
    total_cost = df_cost[column].sum()

    return [total_cost, accuracy]

In [ ]:
from prediction.prediction_model import data_preprocess

# ---------------------------------------------------------------------------
# Cost-vs-accuracy comparison: for every dataset, plot each single model as
# one point (total cost, average accuracy) together with the two prediction
# result sets read from disk, then save the figure as a PNG.
# ---------------------------------------------------------------------------

# Configuration.
datasets = ['overruling', 'agnews', 'coqa', 'headlines', 'sciq']
model_list = ['gptneox_20B', 'gptj_6B', 'fairseq_gpt_13B', 'text-davinci-002', 'text-curie-001', 'gpt-3.5-turbo',
              'gpt-4', 'j1-jumbo', 'j1-grande', 'j1-large', 'xlarge', 'medium']
data_dir = "datasets/text_classification"
test_data_size = 0.95
# NOTE(review): `save_dir` was referenced but never defined in the visible
# notebook. It defaults here to the directory the result files are read from
# (which must already exist for the reads below to succeed) - adjust if the
# figures belong somewhere else.
save_dir = "output/text_classification"

# Axis-label styling shared by every figure (loop-invariant, so built once).
font = {
    'color': 'black',
    'weight': 'normal',
    'size': 15,
}

for dataset in datasets:
    # NOTE(review): these paths carry no file extension - confirm that this
    # matches how the result files were written.
    pred_opt = pd.read_csv(f"output/text_classification/{dataset}_{test_data_size}")
    pred_ = pd.read_csv(f"output/text_classification/prediction/{dataset}_{test_data_size}")

    df_pre_accuracy, df_true_accuracy, df_cost = data_preprocess(data_dir, dataset, model_list, test_size=test_data_size)

    # model name -> [total cost, accuracy]
    single_model_res = {
        model: get_single_model_results_(model, df_true_accuracy, df_cost)
        for model in model_list
    }

    # One figure per dataset; the plotting code must live inside the dataset
    # loop because it uses the per-dataset frames and file name.
    fig, ax = plt.subplots(figsize=(12, 7))

    for model_name, (model_cost, model_accuracy) in single_model_res.items():
        ax.scatter(model_cost, model_accuracy, alpha=1, label=model_name)

    ax.scatter(pred_opt['cost'], pred_opt["true_accuracy"], alpha=1, c="firebrick", label='pred_opt')
    ax.scatter(pred_['cost'], pred_["true_accuracy"], alpha=1, c="blue", label='pred')

    ax.set_xlabel('Cost (USD)', fontdict=font)
    ax.set_ylabel('Accuracy', fontdict=font)
    ax.tick_params(axis='both', labelsize=14)
    ax.grid(zorder=0, linestyle='--', axis='y')

    # Legend sits below the axes; keep the handle so savefig can include it
    # when computing the tight bounding box.
    lgd = ax.legend(loc='upper center', bbox_to_anchor=(0.47, -0.13), ncol=7, fontsize=12)
    fig.tight_layout()

    # Save before showing, then close, so figures do not pile up in memory
    # across datasets.
    fig.savefig(f"{save_dir}/Comparison_{dataset}.png", dpi=500,
                bbox_extra_artists=(lgd,), bbox_inches='tight')
    plt.show()
    plt.close(fig)