In [1]:
import pandas as pd
import sys
sys.path.append("..")
from src.Eval import eval_llm

import os
import io
from contextlib import redirect_stdout


In [17]:
import sys
sys.path.append("..")
from src.Eval import eval_llm

import os
import io
from contextlib import redirect_stdout

def clean_dict(d):
    """
    Recursively remove keys from a dictionary if they map to empty lists.
    """
    if isinstance(d, dict):
        return {k: clean_dict(v) for k, v in d.items() if clean_dict(v) != {} and v != []}
    elif isinstance(d, list):
        return [clean_dict(item) for item in d if clean_dict(item) != {} and item != []]
    else:
        return d

def parse_filename(filename):
    parts = filename.replace(".jsonl", "").split("_")
    dataset = parts[0]
    k_shot = parts[1].split(".")[0]
    k = int(k_shot.replace("shot", ""))
    if k and k != 0 and len(parts) == 3:
        sampling_method = parts[2]
    else:
        sampling_method = "rand"
    return dataset, k, sampling_method

def silent_eval_llm(fpath):
    with io.StringIO() as buf, redirect_stdout(buf):
        # Call the actual eval_llm function
        return eval_llm(fpath)
    
def aggregated_results_llm(
    results_dir,
    ent_types=["overall", "Artist", "WoA"],
    eval_schemas=["strict", "exact", "ent_type"],
    metrics=["f1", "f1_macro", "f1_micro", "precision", "precision_macro", "recall", "recall_macro", "missed", "spurious", "incorrect"],
    datasets=["dataset1", "dataset2", "dataset3", "dataset4"],
    sampling_methods=["rand", "tfidf", ""],
    ks=[0,5,15,25,35,45]
):
    results = {}
    for schema in eval_schemas:
        results[schema] = {}
        for ent_type in ent_types:
            results[schema][ent_type] = {}
            for metric in metrics:
                results[schema][ent_type][metric] = {}
                for model in [m for m in os.listdir(results_dir) if m != "archive"]:
                    results[schema][ent_type][metric][model] = {}
                    for sampling_method in sampling_methods:
                        results[schema][ent_type][metric][model][sampling_method] = {}
                        for k in ks:
                            results[schema][ent_type][metric][model][sampling_method][k] = []

    model_dirs = [
        os.path.join(results_dir, d)
        for d in os.listdir(results_dir)
        if os.path.isdir(os.path.join(results_dir, d)) and d != "archive"
    ]

    for model_dir in model_dirs:
        model = model_dir.split(os.sep)[-1]
        files = [
            os.path.join(model_dir, f)
            for f in os.listdir(model_dir)
            if os.path.isfile(os.path.join(model_dir, f))
        ]

        for fpath in files:
            filename = os.path.basename(fpath)
            dataset, k, sampling_method = parse_filename(filename)

            if dataset in datasets and k in ks:
                predictions = silent_eval_llm(fpath)

                for ent_type in ent_types:
                    for schema in eval_schemas:
                        for metric in metrics:
                            key = f"{ent_type}_{schema}_{metric}"
                            if key in predictions.keys():
                                results[schema][ent_type][metric][model][sampling_method][k].append(predictions[key])
    return clean_dict(results)

results = aggregated_results_llm("../output/reddit+shsyt/")
#results_tfidf = aggregated_results_llm("../output/tfidf_sampling")



2024-09-04 07:21:04 root INFO: Imported 644 predictions for 644 true examples
2024-09-04 07:21:05 root INFO: Imported 367 predictions for 367 true examples
2024-09-04 07:21:05 root INFO: Imported 644 predictions for 644 true examples
2024-09-04 07:21:05 root INFO: Imported 660 predictions for 660 true examples
2024-09-04 07:21:05 root INFO: Imported 658 predictions for 658 true examples
2024-09-04 07:21:05 root INFO: Imported 644 predictions for 644 true examples
2024-09-04 07:21:05 root INFO: Imported 367 predictions for 367 true examples
2024-09-04 07:21:06 root INFO: Imported 660 predictions for 660 true examples
2024-09-04 07:21:06 root INFO: Imported 367 predictions for 367 true examples
2024-09-04 07:21:06 root INFO: Imported 660 predictions for 660 true examples
2024-09-04 07:21:06 root INFO: Imported 658 predictions for 658 true examples
2024-09-04 07:21:06 root INFO: Imported 644 predictions for 644 true examples
2024-09-04 07:21:07 root INFO: Imported 658 predictions for 658 

In [32]:
import pandas as pd
import numpy as np

def results_to_dataframe(aggregated_results, agg_func='mean'):
    data = []

    for schema, schema_dict in aggregated_results.items():
        for ent_type, ent_type_dict in schema_dict.items():
            for metric, metric_dict in ent_type_dict.items():
                for model, model_dict in metric_dict.items():
                    for sampling, sampling_dict in model_dict.items():
                        for k_shot, values in sampling_dict.items():
                            if values:
                                if agg_func == 'mean':
                                    agg_value = np.mean(values)
                                elif agg_func == 'sum':
                                    agg_value = np.sum(values)
                                else:
                                    raise ValueError("Invalid aggregation function. Use 'mean' or 'sum'.")
                                
                                data.append([schema, ent_type, metric, model, sampling, k_shot, agg_value])

    df = pd.DataFrame(data, columns=['Schema', 'Entity Type', 'Metric', 'Model', 'Sampling', 'k', 'Value'])
    df.Model = df.Model.str.replace("mistral", "Mistral-7B").str.replace("mixtral", "Mixtral-8x22B").str.replace("gpt-3.5-turbo-0125", "GPT-3.5-Turbo")
    return df.set_index(['Schema', 'Entity Type', 'Metric', 'Model', 'Sampling', 'k']).unstack(['Schema', 'Entity Type', 'Metric'])



In [33]:
results_to_dataframe(results)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value,Value
Unnamed: 0_level_1,Unnamed: 1_level_1,Schema,strict,strict,strict,strict,strict,strict,strict,strict,strict,strict,...,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type,ent_type
Unnamed: 0_level_2,Unnamed: 1_level_2,Entity Type,overall,overall,overall,overall,Artist,Artist,Artist,Artist,Artist,Artist,...,Artist,Artist,Artist,Artist,WoA,WoA,WoA,WoA,WoA,WoA
Unnamed: 0_level_3,Unnamed: 1_level_3,Metric,f1_macro,f1_micro,precision_macro,recall_macro,f1,precision,recall,missed,spurious,incorrect,...,recall,missed,spurious,incorrect,f1,precision,recall,missed,spurious,incorrect
Model,Sampling,k,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4,Unnamed: 15_level_4,Unnamed: 16_level_4,Unnamed: 17_level_4,Unnamed: 18_level_4,Unnamed: 19_level_4,Unnamed: 20_level_4,Unnamed: 21_level_4,Unnamed: 22_level_4,Unnamed: 23_level_4
llama3.1,rand,0,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,rand,5,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,rand,15,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,rand,25,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,rand,35,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,tfidf,5,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,tfidf,15,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,tfidf,25,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1,tfidf,35,0.736789,0.739164,0.7379,0.736213,0.763753,0.771562,0.756232,62.0,54.75,14.0,...,0.780763,62.0,54.75,7.0,0.770921,0.765083,0.777583,59.25,65.25,2.5
llama3.1-70b,rand,0,0.725325,0.726048,0.684567,0.777178,0.771081,0.763302,0.779171,46.75,54.0,30.0,...,0.801547,46.75,54.0,22.25,0.744471,0.664101,0.848518,44.75,129.0,1.5


# Precision

In [12]:
_results = results_to_dataframe(results, "precision", "mean").reset_index()
_results[(_results.Schema == "strict") & (_results.k == 0)].drop(["Schema", "k"], axis=1)


Unnamed: 0,Entity Type,Model,Value
0,Artist,llama3.1,0.771562
5,Artist,llama3.1-70b,0.763302
6,WoA,llama3.1,0.704237
11,WoA,llama3.1-70b,0.605833


# Recall

In [13]:
_results = results_to_dataframe(results, "recall", "mean").reset_index()
_results[(_results.Schema == "strict") & (_results.k == 0)].drop(["Schema", "k"], axis=1)


Unnamed: 0,Entity Type,Model,Value
0,Artist,llama3.1,0.756232
5,Artist,llama3.1-70b,0.779171
6,WoA,llama3.1,0.716194
11,WoA,llama3.1-70b,0.775184


# F1

In [14]:
results_to_dataframe(results, "f1_macro", "mean")


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Value
Schema,Entity Type,Model,k,Unnamed: 4_level_1
strict,overall,llama3.1,0,0.736789
strict,overall,llama3.1,5,0.736789
strict,overall,llama3.1,15,0.736789
strict,overall,llama3.1,25,0.736789
strict,overall,llama3.1,35,0.736789
strict,overall,llama3.1-70b,0,0.725325
exact,overall,llama3.1,0,0.751727
exact,overall,llama3.1,5,0.751727
exact,overall,llama3.1,15,0.751727
exact,overall,llama3.1,25,0.751727


In [8]:
results_to_dataframe(results, "f1", "mean")


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Value
Schema,Entity Type,Model,k,Unnamed: 4_level_1
strict,Artist,llama3.1-70b,0,0.771081
strict,Artist,llama3.1-8b,0,0.763753
strict,WoA,llama3.1-70b,0,0.67957
strict,WoA,llama3.1-8b,0,0.709825
exact,Artist,llama3.1-70b,0,0.833947
exact,Artist,llama3.1-8b,0,0.784851
exact,WoA,llama3.1-70b,0,0.683969
exact,WoA,llama3.1-8b,0,0.718603
ent_type,Artist,llama3.1-70b,0,0.793239
ent_type,Artist,llama3.1-8b,0,0.788424
