# 05.5 - Evaluate Auto Tuned Prompts

In this notebook, we evaluate the performance of the prompts tuned by the test models. We measure inter-model agreement and also look at how the prompt tuned by each model performs when it is used by the other models.

In [None]:
from genscai.prompts import PromptCatalog
from genscai import paths

catalog = PromptCatalog(paths.data / "prompt_catalog.db")
model_ids = catalog.retrieve_model_ids()
model_ids

In [None]:
import textwrap

for model_id in model_ids:
    print(f"model: {model_id}")

    prompt = catalog.retrieve_last(model_id)
    print(f"prompt revision: {prompt.version}")
    print(f"prompt metrics: {prompt.metrics}")
    print(textwrap.fill(f"prompt: {prompt.prompt}", 100))
    print()

## Delete poorly performing prompts from catalog

If some prompts perform poorly and shouldn't be used in future analyses, they can be removed from the catalog.

In [None]:
# catalog.delete_all("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")

## Classify the entire corpus using each of the local models

The following script classifies all of the papers in the corpus using each of the local models. This takes a long time, since the full corpus is large. The script outputs four files (data/modeling_papers_N.json). Unless you'd like to re-run the classifications, ignore the following cell and continue with the analyses.
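
Before skipping the classification step, it can help to confirm that the expected output files are actually present. The following is a minimal sketch, not part of the project: `missing_outputs` is a hypothetical helper, and the zero-based file names `modeling_papers_0.json` through `modeling_papers_3.json` are assumed from the loading code later in this notebook.

```python
from pathlib import Path
import tempfile


def missing_outputs(data_dir: Path, n_models: int = 4) -> list[Path]:
    """Return the expected classification files that are absent from data_dir.

    Hypothetical helper; file naming is assumed to be modeling_papers_<i>.json
    with i starting at 0.
    """
    expected = [data_dir / f"modeling_papers_{i}.json" for i in range(n_models)]
    return [p for p in expected if not p.exists()]


# Demonstration in a throwaway directory that holds only the first file.
with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    (tmp_dir / "modeling_papers_0.json").write_text("")
    demo_missing_names = [p.name for p in missing_outputs(tmp_dir)]

print(demo_missing_names)
```

In this notebook the call would be `missing_outputs(paths.data)`; an empty list suggests the classification script does not need to be re-run.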

In [None]:
# !python $paths.root/scripts/06_classification_all_models.py

## Evaluate inter-model agreement

In [None]:
from genscai.models import HuggingFaceClient
import pandas as pd

MODEL_IDS = [
    HuggingFaceClient.MODEL_LLAMA_3_1_8B,
    HuggingFaceClient.MODEL_GEMMA_2_9B,
    HuggingFaceClient.MODEL_PHI_4_14B,
    HuggingFaceClient.MODEL_MISTRAL_NEMO_12B,
]

data = []
for i, model_id in enumerate(MODEL_IDS):
    df = pd.read_json(paths.data / f"modeling_papers_{i}.json", orient="records", lines=True)
    data.append(df)

In [None]:
from genscai.data import load_midas_data

df1, df2, df3 = load_midas_data()
df_all = pd.concat([df1, df2, df3])

In [None]:
rates = [[0.0] * len(data) for _ in range(len(data))]

print(f"total: {len(df_all)}")

for i in range(len(data)):
    for j in range(len(data)):
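        # Papers listed by both models (inner join) and by at least one model (outer join).
        df_intersection = pd.merge(data[i], data[j], how="inner", on=["id"])
        df_union = pd.merge(data[i], data[j], how="outer", on=["id"])

        # Agreement = papers listed by neither model + papers listed by both, over the whole corpus.
        rates[i][j] = (len(df_all) - len(df_union) + len(df_intersection)) / len(df_all)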

        print(
            f"{MODEL_IDS[i]}, {MODEL_IDS[j]}: intersection={len(df_intersection)}, union={len(df_union)}, agreement_rate={rates[i][j]:.3f}"
        )

df = pd.DataFrame(data=rates, index=MODEL_IDS, columns=MODEL_IDS)
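
The agreement rate above counts the papers on which two models made the same decision: those listed by both models, plus those listed by neither. A small worked example may make the arithmetic easier to follow. It is a sketch under two assumptions: each `modeling_papers_N.json` file holds only the papers a model classified as positive, and `id` values are unique within a file. The toy ids below are made up for illustration.

```python
import pandas as pd

# Toy corpus of 10 papers; each frame lists the ids one model marked as positive.
total = 10
model_a = pd.DataFrame({"id": [0, 1, 2, 3]})
model_b = pd.DataFrame({"id": [2, 3, 4]})

# Inner join keeps ids present in both frames; outer join keeps ids present in either.
both_positive = len(pd.merge(model_a, model_b, how="inner", on=["id"]))
either_positive = len(pd.merge(model_a, model_b, how="outer", on=["id"]))

# Papers that neither model listed are treated as agreed negatives.
both_negative = total - either_positive

agreement_rate = (both_negative + both_positive) / total
print(both_positive, either_positive, both_negative, agreement_rate)  # 2 5 5 0.7
```

If a file contained duplicate `id` values, the joins would multiply rows and inflate both counts, so deduplicating on `id` first would be a reasonable precaution.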

In [None]:
import seaborn as sns

sns.heatmap(df, annot=True, fmt=".3f")

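## Evaluate tuned prompts across models

The next cell reads `inter_model_results.txt` and builds one table per metric (precision, recall, accuracy), with prompts as rows and models as columns.

In [None]: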
import genscai.paths as paths
import pandas as pd
import json

dfs = {"precision": pd.DataFrame(), "recall": pd.DataFrame(), "accuracy": pd.DataFrame()}

with open(paths.data / "inter_model_results.txt", "r") as fin:
    for line in fin:
        if line.startswith("prompt"):
            line = line.strip()
            parts = line.split(", ")

            parts[0] = parts[0][len("prompt: ") :]
            parts[1] = parts[1][len("model: ") :]
            parts[2] = parts[2][len("metrics: ") :]

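            # Re-join the metrics dict, which itself contains ", " separators.
            metrics = ", ".join(parts[2:])
            metrics = metrics.replace("'", '"')
            # Use a distinct name so the `data` list loaded earlier is not overwritten.
            scores = json.loads(metrics)

            dfs["precision"].loc[parts[0], parts[1]] = scores["precision"]
            dfs["recall"].loc[parts[0], parts[1]] = scores["recall"]
            dfs["accuracy"].loc[parts[0], parts[1]] = scores["accuracy"]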

for df in dfs.values():
    df.rename_axis("Model", axis="columns", inplace=True)
    df.rename_axis("Prompt", axis="rows", inplace=True)
    df["mean"] = df.mean(axis=1)
    df.sort_values(by="mean", ascending=False, inplace=True)

dfs["accuracy"]
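
The parsing above converts a printed Python dict to JSON by swapping quote characters, which fails if a value contains an apostrophe or a Python-only literal such as `None`. One possible alternative is `ast.literal_eval`, sketched below. The line layout `prompt: ..., model: ..., metrics: {...}` is inferred from the parsing code rather than from the file itself, and `parse_result_line` along with the sample values are hypothetical.

```python
import ast


def parse_result_line(line: str) -> tuple[str, str, dict]:
    """Split one 'prompt: ..., model: ..., metrics: {...}' line into its parts.

    Hypothetical helper; the line layout is an assumption.
    """
    # Split from the known field labels so commas inside the dict are left alone.
    head, metrics_text = line.strip().split(", metrics: ", 1)
    prompt_part, model_id = head.split(", model: ", 1)
    prompt_id = prompt_part[len("prompt: "):]

    # literal_eval safely evaluates a Python literal such as a dict repr.
    metrics = ast.literal_eval(metrics_text)
    return prompt_id, model_id, metrics


# Made-up sample line, for illustration only.
sample = "prompt: model-a, model: model-b, metrics: {'precision': 0.9, 'recall': 0.8, 'accuracy': 0.85}"
parsed = parse_result_line(sample)
print(parsed)
```

The returned tuple can feed the same `.loc[prompt_id, model_id]` assignments used in the cell above.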

In [None]:
import matplotlib.pyplot as plt

for key, df in dfs.items():
    ax = plt.axes()
    ax.set_title(key)
    p = sns.heatmap(df, ax=ax, annot=True, fmt=".3f")
    ax.set_xticklabels(p.get_xticklabels(), rotation=35, horizontalalignment="right")
    plt.show()