In [None]:
# Principle performance across datasets: load votes and compute metrics for the results table

import inverse_cai.app.loader
import inverse_cai.app.metrics
import pathlib
import pandas as pd
import numpy as np
from loguru import logger

# Load the votes from one experiment's results directory.
RESULTS_PATH = pathlib.Path("../exp/outputs/2024-11-17_19-22-42_large_dataset/results")
votes_df = inverse_cai.app.loader.create_votes_df(RESULTS_PATH)

# Metrics over all votes first, then one metrics dict per dataset, keyed by the
# dataset name in the order `unique()` returns the names.
metrics = {"overall": inverse_cai.app.metrics.compute_metrics(votes_df)}
metrics["by_dataset"] = {
    name: inverse_cai.app.metrics.compute_metrics(votes_df[votes_df["dataset"] == name])
    for name in votes_df["dataset"].unique()
}

In [None]:
EXCLUDE_OVERALL = True  # drop the "Overall" metric columns from the exported table
INCLUDE_BF = False  # bold a per-dataset value when it is the best across datasets
MIN_RELEVANCE = 0.05  # keep principles whose relevance exceeds this on at least one dataset

# Create DataFrame with principle metrics
datasets_names = votes_df["dataset"].unique()
# Number of distinct comparisons per dataset.
dataset_size = {
    dataset_name: len(votes_df[votes_df["dataset"] == dataset_name]["comparison_id"].unique())
    for dataset_name in datasets_names
}

overall_metrics = metrics["overall"]["metrics"]
data = []
for principle in metrics["overall"]["principles"]:
    row = {
        # Short label: strip the shared instruction prefix and every ".", cap at 50 chars.
        "Principle": principle.replace("Select the response that ", "").replace(".", "")[:50],
        "Overall Performance": overall_metrics["perf"]["by_principle"][principle],
        "Overall Accuracy": overall_metrics["acc"]["by_principle"][principle],
        "Overall Relevance": overall_metrics["relevance"]["by_principle"][principle],
    }
    # Per-dataset accuracy and relevance (per-dataset performance is not reported).
    for dataset, dataset_metrics in metrics["by_dataset"].items():
        row[f"{dataset} Accuracy"] = dataset_metrics["metrics"]["acc"]["by_principle"][principle]
        row[f"{dataset} Relevance"] = dataset_metrics["metrics"]["relevance"]["by_principle"][principle]
    data.append(row)

principles_df = pd.DataFrame(data)

# Keep principles whose relevance is above MIN_RELEVANCE on at least one dataset.
# The relevance columns come from the datasets actually present instead of a
# hard-coded alpacaeval/chatbotarena/prism list, which raised KeyError otherwise.
relevance_columns = [f"{dataset} Relevance" for dataset in metrics["by_dataset"]]
principles_df = principles_df[principles_df[relevance_columns].gt(MIN_RELEVANCE).any(axis=1)]

# Best-performing principles first. The ranking uses the overall metric even
# when the Overall columns are dropped from the table below.
principles_df = principles_df.sort_values(by="Overall Performance", ascending=False)

# Every column other than the label holds a metric that is formatted later.
metric_columns = list(filter(lambda name: name != "Principle", principles_df.columns))

if not EXCLUDE_OVERALL:
    # The Overall block adds three right-aligned columns (Perf, Acc, Rel).
    overall_col_format = "rrr"
else:
    metric_columns = [name for name in metric_columns if "Overall" not in name]
    overall_col_format = ""
    # Remove every column whose name starts with "Overall".
    kept_columns = [name for name in principles_df.columns if not name.startswith("Overall")]
    principles_df = principles_df[kept_columns]

def _format_metric_cell(row: pd.Series, col: str) -> str:
    """Render one metric cell of ``row`` as a LaTeX percentage string.

    The value is multiplied by 100 and shown with one decimal. For per-dataset
    columns the text is bolded when INCLUDE_BF is set and the value equals the
    best value of the same metric across datasets. It is greyed out when the
    principle was relevant for fewer than 50 preferences in that dataset
    (relevance * number of comparisons < 50).
    """
    text = f"{row[col]*100:.1f}"
    if "Overall" in col:
        return text
    # rsplit: only the last token is the metric type, so dataset names that
    # contain spaces still unpack into exactly two parts.
    dataset_name, metric_type = col.rsplit(" ", 1)
    max_val = max(row[f"{name} {metric_type}"] for name in datasets_names)
    if row[col] == max_val and INCLUDE_BF:
        text = f"\\textbf{{{text}}}"
    if row[f"{dataset_name} Relevance"] * dataset_size[dataset_name] < 50:
        # make text color grey
        text = f"\\textcolor{{lightgray}}{{{text}}}"
    return text


# Format from the untouched numeric rows and attach the results as new columns.
# This replaces writing strings into float64 columns via `.at`, which is
# deprecated since pandas 2.1 and may raise SettingWithCopyWarning on the
# previously sliced frame. `assign` returns a fresh frame and keeps column order.
formatted_columns = {col: [] for col in metric_columns}
for _, row in principles_df.iterrows():
    for col in metric_columns:
        formatted_columns[col].append(_format_metric_cell(row, col))
principles_df = principles_df.assign(**formatted_columns)
# Create Styler object for LaTeX. Metric cells already carry LaTeX markup
# (\textbf, \textcolor), so only the free-text Principle column is escaped.
# Unescaped %, &, _, $ or # in a principle would otherwise break compilation.
# The caption is passed to to_latex below, so no separate set_caption is needed.
styler = (
    principles_df.style
    .hide(axis="index")
    .format(escape="latex", subset=["Principle"])
)

# Generate LaTeX table. Layout: a left-aligned Principle column, then
# optionally three Overall columns, then one "|rr" Acc/Rel pair per dataset.
# NOTE(review): the dataset sizes in the caption are hard-coded; keep them in
# sync with `dataset_size` if the results directory changes.
latex_table = styler.to_latex(
    column_format="l" + overall_col_format + "|rr" * (len(metrics["by_dataset"])),
    caption="\\textbf{Reconstruction results of principles across three datasets}: \\emph{AlpacaEval} ($648$ preferences), \\emph{ChatbotArena} ($5,115$), and \\emph{PRISM} ($7,490$). Metrics shown are accuracy (\\emph{Acc}) and relevance (\\emph{Rel}) scores. All principles are generated by ICAI based on a separate training set of $1000$ preferences from \\emph{PRISM} and \\emph{ChatbotArena}. Sorted by overall performance, where performance combines accuracy and relevance scores. Greyed-out values indicate that the principle was relevant for less than $50$ preferences on the respective dataset.",
    label="tab:principles_performance",
    position="H",
    position_float="centering",
    hrules=True,
)

# Replace Styler's one-line header with a two-row header: dataset names
# spanning their metric columns on top, metric abbreviations underneath.
datasets = list(metrics["by_dataset"].keys())
if not EXCLUDE_OVERALL:
    datasets = ["Overall"] + datasets

# Display names for the header; unknown datasets fall back to their raw name.
DATASET_NAMES = {
    "Overall": "Overall",
    "alpacaeval": "AlpacaEv.",
    "chatbotarena": "ChatbotAr.",
    "prism": "PRISM",
}

# Cells are collected without separators and joined with " & " below.
# Raw strings avoid invalid escape sequences such as "\e" in "\emph".
top_row = [r"\multicolumn{1}{c}{\textbf{Principle}}"]
bottom_row = [r"\multicolumn{1}{c}{\emph{Select the response that...}}"]
for dataset in datasets:
    # Overall reports Perf/Acc/Rel; each dataset reports Acc/Rel.
    metric_labels = ["Perf", "Acc", "Rel"] if dataset == "Overall" else ["Acc", "Rel"]
    display_name = DATASET_NAMES.get(dataset, dataset)
    top_row.append(f"\\multicolumn{{{len(metric_labels)}}}{{c}}{{\\textbf{{{display_name}}}}}")
    bottom_row.extend(f"\\textbf{{{label}}}" for label in metric_labels)

# Combine rows
new_header = " & ".join(top_row) + r" \\" + "\n" + " & ".join(bottom_row) + r" \\"

# Insert the new header right after \toprule, then drop the header line that
# Styler generated. With the index hidden, that line starts with "Principle &".
latex_table = latex_table.replace(r"\toprule", r"\toprule" + "\n" + new_header)
latex_table = "\n".join(
    line for line in latex_table.split("\n") if not line.startswith("Principle &")
)


# Save to file, creating the target directory if needed.
output_path = pathlib.Path("appendix/numerical_results/principles_performance.tex")
output_path.parent.mkdir(parents=True, exist_ok=True)
# Explicit UTF-8: principle texts may contain non-ASCII characters, and the
# platform default encoding (used by a bare open(..., "w")) may not handle them.
output_path.write_text(latex_table, encoding="utf-8")

print("Generated LaTeX table:")
print(latex_table)

In [None]:
# Inspect the table as exported. The metric cells now hold formatted strings
# (possibly wrapped in \textbf / \textcolor LaTeX markup), not numbers.
principles_df