In [72]:
import os
import pandas as pd
import json
from IPython.display import display, Markdown, Latex
from tqdm.auto import tqdm

from common.consts import RESULTS_DIR, EVAL_SIZE
from common.utils import filename_to_obj, remove_index

In [73]:
def results_as_pandas(filename):
    path = os.path.join(RESULTS_DIR, filename)
    with open(path, "r") as f:
        data = f.readlines()
    data = [json.loads(d) for d in data]
    data = pd.DataFrame(data)

    params = filename_to_obj(filename)
    for k, v in params.items():
        data[k] = v

    data = data.explode("evaluations")
    data = data.rename_axis("question_idx").reset_index()

    data = pd.concat([data, data["evaluations"].apply(pd.Series)], axis=1)
    evaluation_keys = data["evaluations"].apply(pd.Series).columns
    for col in evaluation_keys:
        data = pd.concat([data, data[col].apply(pd.Series).add_prefix(f"{col}/")], axis=1)
        data = data.drop(columns=col)
    data = data.drop(columns=["evaluations"])

    return data


files = os.listdir(RESULTS_DIR)
params_names = list(filename_to_obj(files[0]).keys())
all_results = pd.concat([results_as_pandas(f) for f in tqdm(files)])
all_results.head()

100%|██████████| 3/3 [00:00<00:00,  6.01it/s]


Unnamed: 0,question_idx,question_id,llm,prompt_id,temperature,nli,ellm,sim,citations/ais_recall,citations/ais_precision,...,citations/supported,citations/citations,citations/correct_citations,citations/out_of_range,correctness/answer_overlap,correctness/answer_entail,correctness/citations_recall,correctness/citations_precision,quality/answer_relevance,quality/new_question
0,0,5abab42e55429955dce3eed2,gpt-3.5-turbo-0125,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,0.0,1.0,...,"[0, 0]","[[], [3]]","[[], [True]]","[0, 0]",0.5,0.0,0.5,1.0,0.730736,What is the relationship between Jordan Subban...
1,0,5abab42e55429955dce3eed2,gpt-3.5-turbo-0125,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,0.0,1.0,...,"[0, 0]","[[], [3]]","[[], [True]]","[0, 0]",0.5,0.0,0.5,1.0,0.822013,What professional hockey player was drafted fi...
2,0,5abab42e55429955dce3eed2,gpt-3.5-turbo-0125,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,0.0,1.0,...,"[0, 0]","[[], [3]]","[[], [True]]","[0, 0]",0.5,0.0,0.5,1.0,0.730736,What is the relationship between Jordan Subban...
3,1,5a761900554299109176e648,gpt-3.5-turbo-0125,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,1.0,1.0,...,[1],[[1]],[[True]],[0],1.0,1.0,0.5,1.0,0.965509,What is the name of the lobbying group that wa...
4,1,5a761900554299109176e648,gpt-3.5-turbo-0125,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,1.0,1.0,...,[1],[[1]],[[True]],[0],1.0,1.0,0.5,1.0,0.974511,What is the name of the lobbying group that wa...


In [74]:
all_obj_cols = all_results.select_dtypes(include=["object"]).columns
drop_obj_cols = list(set(all_obj_cols) - set(params_names))
drop_obj_cols.remove("question_id")
print(f"Dropping columns: {drop_obj_cols}")
all_num_results = all_results.drop(columns=drop_obj_cols)

Dropping columns: ['quality/new_question', 'citations/citations', 'citations/correct_citations', 'citations/supported', 'citations/out_of_range', 'citations/sentences']


In [75]:
eval_split = all_num_results[all_num_results["question_idx"] < EVAL_SIZE]
train_split = all_num_results[all_num_results["question_idx"] >= EVAL_SIZE]

In [76]:
def aggregate(split):
    split = split.drop(columns=["question_idx"])
    results_with_std_for_each_question = split.groupby([*params_names, "question_id"]).agg(["mean", "std"])
    results_for_each_model = results_with_std_for_each_question.groupby(params_names)
    results = results_for_each_model.mean()
    results["n_questions"] = results_for_each_model.size()
    return results


eval_results = aggregate(eval_split)
train_results = aggregate(train_split)

if eval_results["n_questions"].nunique() != 1:
    print("Warning: not all rows in evaluation have the same number of examples")

In [77]:
display(Markdown("### Prompts comparison"))
parameter_results = eval_results[eval_results.index.get_level_values("llm") == "Mistral-7B-Instruct-v0.2"]
parameter_results[parameter_results.index.get_level_values("temperature") == "0.1"]

### Prompts comparison

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,citations/ais_recall,citations/ais_recall,citations/ais_precision,citations/ais_precision,citations/n_sentences,citations/n_sentences,citations/n_total_citations,citations/n_total_citations,citations/n_correct_citations,citations/n_correct_citations,...,correctness/answer_overlap,correctness/answer_entail,correctness/answer_entail,correctness/citations_recall,correctness/citations_recall,correctness/citations_precision,correctness/citations_precision,quality/answer_relevance,quality/answer_relevance,n_questions
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,...,std,mean,std,mean,std,mean,std,mean,std,Unnamed: 26_level_1
llm,prompt_id,temperature,nli,ellm,sim,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2
Mistral-7B-Instruct-v0.2,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,0.575644,0.03656,0.871517,0.03835,2.573333,0.212475,3.016667,0.372709,2.676667,0.302403,...,0.014915,0.886667,0.017321,0.751667,0.025981,0.721619,0.050125,0.720431,0.02834,100


In [78]:
display(Markdown("### Temperature comparison"))
parameter_results[parameter_results.index.get_level_values("prompt_id") == "1"]

### Temperature comparison

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,citations/ais_recall,citations/ais_recall,citations/ais_precision,citations/ais_precision,citations/n_sentences,citations/n_sentences,citations/n_total_citations,citations/n_total_citations,citations/n_correct_citations,citations/n_correct_citations,...,correctness/answer_overlap,correctness/answer_entail,correctness/answer_entail,correctness/citations_recall,correctness/citations_recall,correctness/citations_precision,correctness/citations_precision,quality/answer_relevance,quality/answer_relevance,n_questions
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,...,std,mean,std,mean,std,mean,std,mean,std,Unnamed: 26_level_1
llm,prompt_id,temperature,nli,ellm,sim,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2
Mistral-7B-Instruct-v0.2,1,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,0.575644,0.03656,0.871517,0.03835,2.573333,0.212475,3.016667,0.372709,2.676667,0.302403,...,0.014915,0.886667,0.017321,0.751667,0.025981,0.721619,0.050125,0.720431,0.02834,100
Mistral-7B-Instruct-v0.2,1,0.101,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,0.567611,0.061196,0.885047,0.048794,2.713333,0.237846,3.22,0.444561,2.7,0.327998,...,0.000769,0.87,0.011547,0.786667,0.034641,0.743667,0.045828,0.738541,0.033567,100


In [79]:
display(Markdown("### Evaluation results"))
eval_display = eval_results[eval_results.index.get_level_values("prompt_id") == "1"]
eval_display = eval_display[eval_display.index.get_level_values("temperature") == "0.1"]
eval_display = remove_index(eval_display, "prompt_id")
eval_display = eval_display.sort_values(by=("correctness/citations_recall", "mean"), ascending=False)
eval_display

### Evaluation results

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,prompt_id,citations/ais_recall,citations/ais_recall,citations/ais_precision,citations/ais_precision,citations/n_sentences,citations/n_sentences,citations/n_total_citations,citations/n_total_citations,citations/n_correct_citations,...,correctness/answer_overlap,correctness/answer_entail,correctness/answer_entail,correctness/citations_recall,correctness/citations_recall,correctness/citations_precision,correctness/citations_precision,quality/answer_relevance,quality/answer_relevance,n_questions
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std,mean,...,std,mean,std,mean,std,mean,std,mean,std,Unnamed: 25_level_1
llm,temperature,nli,ellm,sim,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2
gpt-3.5-turbo-0125,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,1,0.737556,0.072604,0.995833,0.002887,2.02,0.214022,2.143333,0.207846,2.13,...,0.01299,0.91,0.011547,0.853333,0.023094,0.941667,0.023671,0.751901,0.041098,100
Mistral-7B-Instruct-v0.2,0.1,t5_xxl_true_nli_mixture,Mistral-7B-Instruct-v0.2,all-MiniLM-L6-v2,1,0.575644,0.03656,0.871517,0.03835,2.573333,0.212475,3.016667,0.372709,2.676667,...,0.014915,0.886667,0.017321,0.751667,0.025981,0.721619,0.050125,0.720431,0.02834,100


In [80]:
def show_cleaned_results(short_eval_display):
    short_eval_display = remove_index(short_eval_display, "temperature")
    short_eval_display = remove_index(short_eval_display, "nli")
    short_eval_display = remove_index(short_eval_display, "ellm")
    short_eval_display = remove_index(short_eval_display, "sim")
    important_columns = ["citations/ais_recall", "citations/ais_precision", "correctness/answer_overlap", "correctness/answer_entail", "correctness/citations_recall", "correctness/citations_precision", "quality/answer_relevance"]
    short_eval_display = short_eval_display[important_columns]
    return short_eval_display


show_cleaned_results(eval_display)

Unnamed: 0_level_0,citations/ais_recall,citations/ais_recall,citations/ais_precision,citations/ais_precision,correctness/answer_overlap,correctness/answer_overlap,correctness/answer_entail,correctness/answer_entail,correctness/citations_recall,correctness/citations_recall,correctness/citations_precision,correctness/citations_precision,quality/answer_relevance,quality/answer_relevance
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
llm,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2
gpt-3.5-turbo-0125,0.737556,0.072604,0.995833,0.002887,0.884244,0.01299,0.91,0.011547,0.853333,0.023094,0.941667,0.023671,0.751901,0.041098
Mistral-7B-Instruct-v0.2,0.575644,0.03656,0.871517,0.03835,0.872812,0.014915,0.886667,0.017321,0.751667,0.025981,0.721619,0.050125,0.720431,0.02834


In [81]:
display(Markdown("### Training results"))
train_display = remove_index(train_results, "prompt_id")
train_display

### Training results

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,prompt_id,citations/ais_recall,citations/ais_recall,citations/ais_precision,citations/ais_precision,citations/n_sentences,citations/n_sentences,citations/n_total_citations,citations/n_total_citations,citations/n_correct_citations,...,correctness/answer_overlap,correctness/answer_entail,correctness/answer_entail,correctness/citations_recall,correctness/citations_recall,correctness/citations_precision,correctness/citations_precision,quality/answer_relevance,quality/answer_relevance,n_questions
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std,mean,...,std,mean,std,mean,std,mean,std,mean,std,Unnamed: 25_level_1
llm,temperature,nli,ellm,sim,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2
