In [12]:
from utils_generate_edits import prep_sample_indices
import json, os

with open("all_finegrained_clean.json", "r") as f:
    data = json.load(f)

data = [d for d in data if d["split"] == "test"]

for d in data:
    for edit in d["fine_grained_edits"]:
        edit["categorization"] = edit["categorization"].replace("/ ", "/").replace(" (Unnecessary ornamental and overly verbose)", "")

id2sample = {d["id"]: d for d in data}

for anno_fn in os.listdir("data/detection_preds/"):
    model, prompt_id = anno_fn.replace(".jsonl", "").split("_")
    with open("data/detection_preds/"+anno_fn, "r") as f:
        for line in f:
            d = json.loads(line)
            if d["id"] in id2sample:
                id2sample[d["id"]][f"pred_{model}_{prompt_id}"] = d["detection"]

for sample in data:
    prep_sample_indices(sample)

# categories = sorted(set([span["categorization"] for sample in data for span in sample["fine_grained_edits"]]))
all_cats = list(data[0]["gold_indices"].keys())
pred_keys = list(set([k for d in data for k in d if k.startswith("pred_")]))



In [13]:
from utils_misc import display_results
import numpy as np, pandas as pd

results_ps, results_rs, results_f1s = {}, {}, {}
for pred_key in pred_keys:
    _, model, prompt = pred_key.split("_")
    results_ps[pred_key] = {"Model": model, "Prompt": prompt}
    results_rs[pred_key] = {"Model": model, "Prompt": prompt}
    results_f1s[pred_key] = {"Model": model, "Prompt": prompt}
    
    precs = {cat: [] for cat in all_cats}
    recs = {cat: [] for cat in all_cats}
    f1s = {cat: [] for cat in all_cats}

    for sample in data:
        if pred_key not in sample:
            continue
        idx_key = pred_key.replace("pred_", "idx_")
        sample["f1_" + pred_key] = {}
        for cat in all_cats:
            if len(sample["gold_indices"][cat]) == 0:
                continue
            
            gold = set(sample["gold_indices"][cat])
            pred = set(sample[idx_key][cat])
            tp, fp, fn = len(gold & pred), len(pred - gold), len(gold - pred)
            if tp == 0:
                precision, recall, f1 = 0, 0, 0
            else:
                precision = tp / (tp + fp)
                recall = tp / (tp + fn)
                f1 = 2 * precision * recall / (precision + recall)
            precs[cat].append(precision)
            recs[cat].append(recall)
            f1s[cat].append(f1)

    results_ps[pred_key]["N"] = len(precs["all"])
    for cat in all_cats:
        results_ps[pred_key][cat] = np.mean(precs[cat])
        results_rs[pred_key][cat] = np.mean(recs[cat])
        results_f1s[pred_key][cat] = np.mean(f1s[cat])
        
# sort by F1
pred_keys = sorted(pred_keys, key=lambda x: results_f1s[x]["all"], reverse=True)
results_ps = sorted(results_ps.values(), key=lambda x: pred_keys.index(f"pred_{x['Model']}_{x['Prompt']}"))
results_rs = sorted(results_rs.values(), key=lambda x: pred_keys.index(f"pred_{x['Model']}_{x['Prompt']}"))
results_f1s = sorted(results_f1s.values(), key=lambda x: pred_keys.index(f"pred_{x['Model']}_{x['Prompt']}"))

# display_results(results_ps, results_rs, results_f1s)

# global results focused on "all", put promps on the same row
all_prompts = sorted(set([r["Prompt"] for r in results_f1s]))
results_ps_all, results_rs_all, results_f1s_all = [], [], []
all_models = [r["Model"] for r in results_f1s]
models = sorted(set(all_models), key=lambda x: all_models.index(x))
for model in models:
    results_ps_all.append({"Model": model, "N": 0})
    results_rs_all.append({"Model": model})
    results_f1s_all.append({"Model": model})
    for prompt in all_prompts:
        for results, results_all in [(results_ps, results_ps_all), (results_rs, results_rs_all), (results_f1s, results_f1s_all)]:
            for r in results:
                if r["Model"] == model and r["Prompt"] == prompt:
                    if results == results_ps:
                        results_all[-1]["N"] += r["N"]
                    results_all[-1][prompt] = r["all"]
display_results(results_ps_all, results_rs_all, results_f1s_all)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,Model,N,v2-fs25,v2-fs5
0,gemini-1.5-pro,896,0.418,0.423
1,llama3.1-70b,2,0.412,
2,gemini-1.5-flash,1445,0.403,0.391
3,mistral-large2,1460,0.431,0.417
4,claude3.5-sonnet,1460,0.446,0.434
5,gpt-4o,1460,0.46,0.451
6,claude3-haiku,1460,0.418,0.382
7,gpt-4o-mini,1460,0.452,0.434

Unnamed: 0,Model,v2-fs25,v2-fs5
0,gemini-1.5-pro,0.803,0.858
1,llama3.1-70b,0.679,
2,gemini-1.5-flash,0.817,0.896
3,mistral-large2,0.705,0.774
4,claude3.5-sonnet,0.645,0.715
5,gpt-4o,0.534,0.604
6,claude3-haiku,0.546,0.693
7,gpt-4o-mini,0.494,0.54

Unnamed: 0,Model,v2-fs25,v2-fs5
0,gemini-1.5-pro,0.514,0.532
1,llama3.1-70b,0.513,
2,gemini-1.5-flash,0.504,0.513
3,mistral-large2,0.494,0.505
4,claude3.5-sonnet,0.486,0.501
5,gpt-4o,0.456,0.477
6,claude3-haiku,0.424,0.451
7,gpt-4o-mini,0.432,0.439
