In [1]:
import pandas as pd
import sys
sys.path.append("..")
from src.Eval import eval_llm

import os
import io
from contextlib import redirect_stdout


In [2]:
import sys
sys.path.append("..")
from src.Eval import eval_llm

import os
import io
from contextlib import redirect_stdout

def clean_dict(d):
    """
    Recursively remove keys from a dictionary if they map to empty lists.
    """
    if isinstance(d, dict):
        return {k: clean_dict(v) for k, v in d.items() if clean_dict(v) != {} and v != []}
    elif isinstance(d, list):
        return [clean_dict(item) for item in d if clean_dict(item) != {} and item != []]
    else:
        return d

def parse_filename(filename):
    parts = filename.replace(".jsonl", "").split("_")
    dataset = parts[0]
    k_shot = parts[1].split(".")[0]
    k = int(k_shot.replace("shot", ""))
    if k and k != 0 and len(parts) == 3:
        sampling_method = parts[2]
    else:
        sampling_method = "rand"
    return dataset, k, sampling_method

def silent_eval_llm(fpath):
    with io.StringIO() as buf, redirect_stdout(buf):
        # Call the actual eval_llm function
        return eval_llm(fpath)
    
def aggregated_results_llm(
    results_dir,
    ent_types=["overall", "Artist", "WoA"],
    eval_schemas=["strict", "exact", "ent_type"],
    metrics=["f1", "f1_macro", "f1_micro", "precision", "precision_macro", "recall", "recall_macro", "missed", "spurious", "incorrect"],
    datasets=["dataset1", "dataset2", "dataset3", "dataset4", "dataset5"],
    sampling_methods=["rand", "tfidf", ""],
    ks=[0,5,15,25,35,45]
):
    results = {}
    for schema in eval_schemas:
        results[schema] = {}
        for ent_type in ent_types:
            results[schema][ent_type] = {}
            for metric in metrics:
                results[schema][ent_type][metric] = {}
                for model in [m for m in os.listdir(results_dir) if m != "archive"]:
                    results[schema][ent_type][metric][model] = {}
                    for sampling_method in sampling_methods:
                        results[schema][ent_type][metric][model][sampling_method] = {}
                        for k in ks:
                            results[schema][ent_type][metric][model][sampling_method][k] = []

    model_dirs = [
        os.path.join(results_dir, d)
        for d in os.listdir(results_dir)
        if os.path.isdir(os.path.join(results_dir, d)) and d != "archive"
    ]

    for model_dir in model_dirs:
        model = model_dir.split(os.sep)[-1]
        files = [
            os.path.join(model_dir, f)
            for f in os.listdir(model_dir)
            if os.path.isfile(os.path.join(model_dir, f))
        ]

        for fpath in files:
            print(fpath)
            filename = os.path.basename(fpath)
            dataset, k, sampling_method = parse_filename(filename)

            if dataset in datasets and k in ks:
                predictions = silent_eval_llm(fpath)

                for ent_type in ent_types:
                    for schema in eval_schemas:
                        for metric in metrics:
                            key = f"{ent_type}_{schema}_{metric}"
                            if key in predictions.keys():
                                results[schema][ent_type][metric][model][sampling_method][k].append(predictions[key])
    return clean_dict(results)

results = aggregated_results_llm("../output/reddit+shsyt/")
#results_tfidf = aggregated_results_llm("../output/tfidf_sampling")



2024-09-07 17:32:15 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1/dataset4_0shot.jsonl


2024-09-07 17:32:15 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1/dataset5_0shot.jsonl
../output/reddit+shsyt/llama3.1/dataset3_0shot.jsonl


2024-09-07 17:32:15 root INFO: Imported 658 predictions for 658 true examples
2024-09-07 17:32:16 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1/dataset2_0shot.jsonl


2024-09-07 17:32:16 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1/dataset1_5shot_rand.jsonl


2024-09-07 17:32:16 root INFO: Imported 193 predictions for 193 true examples


../output/reddit+shsyt/llama3.1/dataset1_25shot_rand.jsonl
../output/reddit+shsyt/llama3.1:70b/dataset2_5shot_tfidf.jsonl


2024-09-07 17:32:16 root INFO: Imported 660 predictions for 660 true examples
2024-09-07 17:32:17 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset1_25shot_tfidf.jsonl


2024-09-07 17:32:17 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset3_5shot_tfidf.jsonl


2024-09-07 17:32:17 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset5_5shot_tfidf.jsonl


2024-09-07 17:32:17 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset4_25shot_tfidf.jsonl


2024-09-07 17:32:17 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset5_5shot_rand.jsonl


2024-09-07 17:32:18 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset4_0shot.jsonl


2024-09-07 17:32:18 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset5_0shot.jsonl


2024-09-07 17:32:18 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset2_25shot_rand.jsonl


2024-09-07 17:32:18 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset3_25shot_tfidf.jsonl


2024-09-07 17:32:19 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset2_5shot_rand.jsonl


2024-09-07 17:32:19 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset1_5shot_tfidf.jsonl


2024-09-07 17:32:19 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset4_25shot_rand.jsonl


2024-09-07 17:32:19 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset3_0shot.jsonl


2024-09-07 17:32:19 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset2_0shot.jsonl


2024-09-07 17:32:20 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset5_25shot_rand.jsonl


2024-09-07 17:32:20 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset3_5shot_rand.jsonl


2024-09-07 17:32:20 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset1_5shot_rand.jsonl


2024-09-07 17:32:20 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset3_25shot_rand.jsonl


2024-09-07 17:32:21 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset1_25shot_rand.jsonl


2024-09-07 17:32:21 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset4_5shot_tfidf.jsonl


2024-09-07 17:32:21 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset4_5shot_rand.jsonl


2024-09-07 17:32:21 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/llama3.1:70b/dataset5_25shot_tfidf.jsonl


2024-09-07 17:32:21 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/llama3.1:70b/dataset2_25shot_tfidf.jsonl


2024-09-07 17:32:22 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/firefunction-v2/dataset2_5shot_tfidf.jsonl


2024-09-07 17:32:22 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/firefunction-v2/dataset1_25shot_tfidf.jsonl


2024-09-07 17:32:22 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset3_5shot_tfidf.jsonl


2024-09-07 17:32:22 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset5_5shot_tfidf.jsonl


2024-09-07 17:32:23 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset4_25shot_tfidf.jsonl


2024-09-07 17:32:23 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset5_5shot_rand.jsonl


2024-09-07 17:32:23 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/firefunction-v2/dataset2_25shot_rand.jsonl


2024-09-07 17:32:23 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset3_25shot_tfidf.jsonl
../output/reddit+shsyt/firefunction-v2/dataset2_5shot_rand.jsonl


2024-09-07 17:32:24 root INFO: Imported 660 predictions for 660 true examples
2024-09-07 17:32:24 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/firefunction-v2/dataset1_5shot_tfidf.jsonl


2024-09-07 17:32:24 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset4_25shot_rand.jsonl


2024-09-07 17:32:24 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset5_25shot_rand.jsonl


2024-09-07 17:32:24 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset3_5shot_rand.jsonl
../output/reddit+shsyt/firefunction-v2/dataset1_5shot_rand.jsonl


2024-09-07 17:32:25 root INFO: Imported 660 predictions for 660 true examples
2024-09-07 17:32:25 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset3_25shot_rand.jsonl
../output/reddit+shsyt/firefunction-v2/dataset1_25shot_rand.jsonl


2024-09-07 17:32:25 root INFO: Imported 660 predictions for 660 true examples
2024-09-07 17:32:25 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset4_5shot_tfidf.jsonl


2024-09-07 17:32:25 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset4_5shot_rand.jsonl


2024-09-07 17:32:26 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/firefunction-v2/dataset5_25shot_tfidf.jsonl


2024-09-07 17:32:26 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/firefunction-v2/dataset2_25shot_tfidf.jsonl


2024-09-07 17:32:26 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset2_5shot_tfidf.jsonl


2024-09-07 17:32:26 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset1_25shot_tfidf.jsonl


2024-09-07 17:32:27 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset3_5shot_tfidf.jsonl


2024-09-07 17:32:27 root INFO: Imported 498 predictions for 498 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset4_25shot_tfidf.jsonl
../output/reddit+shsyt/mixtral:8x22b/dataset2_25shot_rand.jsonl


2024-09-07 17:32:27 root INFO: Imported 660 predictions for 660 true examples
2024-09-07 17:32:27 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset3_25shot_tfidf.jsonl


2024-09-07 17:32:27 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset2_5shot_rand.jsonl


2024-09-07 17:32:28 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset1_5shot_tfidf.jsonl


2024-09-07 17:32:28 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset4_25shot_rand.jsonl


2024-09-07 17:32:28 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset3_5shot_rand.jsonl


2024-09-07 17:32:28 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset1_5shot_rand.jsonl


2024-09-07 17:32:29 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset3_25shot_rand.jsonl


2024-09-07 17:32:29 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset1_25shot_rand.jsonl


2024-09-07 17:32:29 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset4_5shot_tfidf.jsonl


2024-09-07 17:32:29 root INFO: Imported 658 predictions for 658 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset4_5shot_rand.jsonl


2024-09-07 17:32:29 root INFO: Imported 660 predictions for 660 true examples


../output/reddit+shsyt/mixtral:8x22b/dataset2_25shot_tfidf.jsonl


In [3]:
import pandas as pd
import numpy as np

def results_to_dataframe(aggregated_results, agg_func='mean'):
    data = []

    for schema, schema_dict in aggregated_results.items():
        for ent_type, ent_type_dict in schema_dict.items():
            for metric, metric_dict in ent_type_dict.items():
                for model, model_dict in metric_dict.items():
                    for sampling, sampling_dict in model_dict.items():
                        for k_shot, values in sampling_dict.items():
                            if values:
                                if agg_func == 'mean':
                                    agg_value = np.mean(values)
                                elif agg_func == 'sum':
                                    agg_value = np.sum(values)
                                else:
                                    raise ValueError("Invalid aggregation function. Use 'mean' or 'sum'.")
                                
                                data.append([schema, ent_type, metric, model, sampling, k_shot, agg_value])

    df = pd.DataFrame(data, columns=['Schema', 'Entity Type', 'Metric', 'Model', 'Sampling', 'k', 'Value'])
    df.Model = df.Model.str.replace("mistral", "Mistral-7B").str.replace("mixtral", "Mixtral-8x22B").str.replace("gpt-3.5-turbo-0125", "GPT-3.5-Turbo")
    return df.set_index(['Schema', 'Entity Type', 'Metric', 'Model', 'Sampling', 'k']).unstack(['Schema', 'Entity Type', 'Metric'])

results_to_dataframe(results)


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value
Unnamed: 0_level_1,Unnamed: 1_level_1,Schema,strict,strict,strict,strict,strict,strict,strict,strict,strict,strict,...,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type
Unnamed: 0_level_2,Unnamed: 1_level_2,Entity Type,overall,overall,overall,overall,Artist,Artist,Artist,Artist,Artist,Artist,...,Artist,Artist,Artist,Artist,WoA,WoA,WoA,WoA,WoA,WoA
Unnamed: 0_level_3,Unnamed: 1_level_3,Metric,f1_macro,f1_micro,precision_macro,recall_macro,f1,precision,recall,missed,spurious,incorrect,...,recall,missed,spurious,incorrect,f1,precision,recall,missed,spurious,incorrect
Model,Sampling,k,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4,Unnamed: 15_level_4,Unnamed: 16_level_4,Unnamed: 17_level_4,Unnamed: 18_level_4,Unnamed: 19_level_4,Unnamed: 20_level_4,Unnamed: 21_level_4,Unnamed: 22_level_4,Unnamed: 23_level_4
Mixtral-8x22B:8x22b,rand,5,0.747285,0.749061,0.717799,0.780716,0.786231,0.772794,0.800187,54.75,67.0,14.5,...,0.830722,54.75,67.0,4.0,0.782668,0.732515,0.840913,45.25,90.0,3.25
Mixtral-8x22B:8x22b,rand,25,0.769663,0.770885,0.736416,0.80764,0.800028,0.783301,0.817891,48.25,63.25,15.25,...,0.846985,48.25,63.25,5.25,0.801765,0.74809,0.864362,37.0,84.0,4.25
Mixtral-8x22B:8x22b,tfidf,5,0.757742,0.759988,0.722869,0.796984,0.792092,0.762174,0.824698,46.5,74.75,14.5,...,0.852179,46.5,74.75,5.0,0.791548,0.748319,0.841272,44.25,82.0,4.0
Mixtral-8x22B:8x22b,tfidf,25,0.78083,0.782625,0.747317,0.818234,0.810681,0.783039,0.840586,39.25,62.5,14.0,...,0.870796,39.25,62.5,4.25,0.810807,0.768576,0.858916,36.25,71.0,4.5
firefunction-v2,rand,5,0.752028,0.754567,0.761079,0.743824,0.777905,0.77588,0.780395,56.6,59.2,20.4,...,0.801034,56.6,59.2,13.2,0.783173,0.804999,0.762685,70.0,53.8,2.2
firefunction-v2,rand,25,0.764807,0.765908,0.778371,0.752175,0.777896,0.786833,0.769788,67.8,60.2,13.6,...,0.790731,67.8,60.2,6.2,0.808418,0.828142,0.789822,59.8,45.6,4.0
firefunction-v2,tfidf,5,0.758773,0.761246,0.768487,0.749964,0.782906,0.782945,0.783378,53.0,53.8,23.0,...,0.805271,53.0,53.8,15.4,0.787018,0.807983,0.767456,67.6,52.2,3.0
firefunction-v2,tfidf,25,0.776149,0.77807,0.781329,0.771571,0.796028,0.791832,0.800731,53.6,58.0,16.4,...,0.824112,53.6,58.0,8.2,0.80344,0.819015,0.788617,60.6,49.2,3.6
llama3.1,rand,0,0.730728,0.733185,0.727095,0.734732,0.76043,0.76195,0.759039,67.5,66.25,15.75,...,0.786751,67.5,66.25,6.25,0.763855,0.754432,0.773931,64.75,72.5,3.0
llama3.1,rand,5,0.739356,0.742094,0.725537,0.757064,0.803235,0.820937,0.78628,62.0,46.0,19.0,...,0.796834,62.0,46.0,15.0,0.76652,0.715068,0.825949,51.0,100.0,4.0


In [4]:
results

{'strict': {'overall': {'f1_macro': {'llama3.1': {'rand': {0: [0.7541778460519967,
       0.7161413110210668,
       0.7239345509893454,
       0.7286579278457901],
      5: [0.7393558703508001],
      25: [0.5873144157088122]}},
    'llama3.1:70b': {'rand': {0: [0.7456525363678046,
       0.7208214980394578,
       0.7166172106824926,
       0.7248613037447988],
      5: [0.7492697751746438,
       0.7360364215385755,
       0.7357053880823501,
       0.7359636283232738,
       0.7358553272676918],
      25: [0.777991070372942,
       0.7638679403385286,
       0.7585194626924209,
       0.7529138521503851,
       0.752461467678859]},
     'tfidf': {5: [0.7523381446779775,
       0.7359286775631501,
       0.7513060613466944,
       0.7583700111898546,
       0.761890157982859],
      25: [0.7714950717099187,
       0.792104628034521,
       0.729653832116364,
       0.74589320697542,
       0.7803715159697182]}},
    'firefunction-v2': {'rand': {5: [0.7409455944891575,
       0.76042

# Precision

In [5]:
_results = results_to_dataframe(results, "precision", "mean").reset_index()
_results[(_results.Schema == "strict") & (_results.k == 0)].drop(["Schema", "k"], axis=1)


TypeError: results_to_dataframe() takes from 1 to 2 positional arguments but 3 were given

# Recall

In [None]:
_results = results_to_dataframe(results, "recall", "mean").reset_index()
_results[(_results.Schema == "strict") & (_results.k == 0)].drop(["Schema", "k"], axis=1)


Unnamed: 0,Entity Type,Model,Value
0,Artist,llama3.1,0.756232
5,Artist,llama3.1-70b,0.779171
6,WoA,llama3.1,0.716194
11,WoA,llama3.1-70b,0.775184


# F1

In [None]:
results_to_dataframe(results, "f1_macro", "mean")


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Value
Schema,Entity Type,Model,k,Unnamed: 4_level_1
strict,overall,llama3.1,0,0.736789
strict,overall,llama3.1,5,0.736789
strict,overall,llama3.1,15,0.736789
strict,overall,llama3.1,25,0.736789
strict,overall,llama3.1,35,0.736789
strict,overall,llama3.1-70b,0,0.725325
exact,overall,llama3.1,0,0.751727
exact,overall,llama3.1,5,0.751727
exact,overall,llama3.1,15,0.751727
exact,overall,llama3.1,25,0.751727


In [None]:
results_to_dataframe(results, "f1", "mean")


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Value
Schema,Entity Type,Model,k,Unnamed: 4_level_1
strict,Artist,llama3.1-70b,0,0.771081
strict,Artist,llama3.1-8b,0,0.763753
strict,WoA,llama3.1-70b,0,0.67957
strict,WoA,llama3.1-8b,0,0.709825
exact,Artist,llama3.1-70b,0,0.833947
exact,Artist,llama3.1-8b,0,0.784851
exact,WoA,llama3.1-70b,0,0.683969
exact,WoA,llama3.1-8b,0,0.718603
ent_type,Artist,llama3.1-70b,0,0.793239
ent_type,Artist,llama3.1-8b,0,0.788424
