In [1]:
# Analysis of spuco_grid results
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

CSV_PATH = "logs_spuco_grid/results.csv"
OUT_DIR = "analysis_plots"
os.makedirs(OUT_DIR, exist_ok=True)

SATIRE_NAME = "SATIRE"  # the new metric name

# Load
df = pd.read_csv(CSV_PATH)
print("Loaded:", df.shape)
print(df.head())

# Normalize types
# difficulty is already a string like 'SpuriousFeatureDifficulty.MAGNITUDE_SMALL'
# ensure keep_ratio, strength are numeric
for col in ["keep_ratio", "strength", "worst_group_acc", "average_acc"]:
    df[col] = pd.to_numeric(df[col], errors="coerce")



Loaded: (385, 8)
   seed                                 difficulty  strength  keep_ratio  \
0  1234  SpuriousFeatureDifficulty.MAGNITUDE_SMALL       0.9         1.0   
1  1234  SpuriousFeatureDifficulty.MAGNITUDE_SMALL       0.9         0.1   
2  1234  SpuriousFeatureDifficulty.MAGNITUDE_SMALL       0.9         0.1   
3  1234  SpuriousFeatureDifficulty.MAGNITUDE_SMALL       0.9         0.1   
4  1234  SpuriousFeatureDifficulty.MAGNITUDE_SMALL       0.9         0.1   

   heuristic  worst_group_acc  average_acc     timestamp  
0  base_line        93.850267        98.19  1.765443e+09  
1     random        94.191919        96.92  1.765443e+09  
2       loss        60.201511        94.03  1.765443e+09  
3   gradnorm        73.551637        96.10  1.765443e+09  
4  confident        86.146096        96.54  1.765443e+09  


In [2]:
# Keep-ratio curves per (difficulty, strength)
sns.set(style="whitegrid")
plot_df = df.copy()

# Simple slug for filenames
import re
slug = lambda s: re.sub(r"[^A-Za-z0-9_.-]", "_", str(s))
palette = sns.color_palette("tab10")
satire_color = "red"

for diff in sorted(plot_df["difficulty"].unique()):
    for strength in sorted(plot_df["strength"].unique()):
        sub = plot_df[(plot_df["difficulty"] == diff) & (plot_df["strength"] == strength)]
        if sub.empty:
            continue
        fname_prefix = f"keep_ratio_{slug(diff)}_str_{strength}"

        # Separate baseline, satire, and heuristics
        baseline_row = sub[sub["heuristic"] == "base_line"].head(1)
        satire_row = sub[sub["heuristic"] == SATIRE_NAME].head(1)
        heur_sub = sub[(sub["heuristic"] != "base_line") & (sub["heuristic"] != SATIRE_NAME)]

        # Worst-group curves
        plt.figure(figsize=(8, 4))
        if not baseline_row.empty:
            b_val = baseline_row["worst_group_acc"].values[0]
            plt.axhline(b_val, color="black", linestyle="-", linewidth=1.5, label="baseline")
        if not satire_row.empty:
            s_val = satire_row["worst_group_acc"].values[0]
            plt.axhline(s_val, color=satire_color, linestyle="--", linewidth=1.5, label=SATIRE_NAME)
        if not heur_sub.empty:
            sns.lineplot(data=heur_sub, x="keep_ratio", y="worst_group_acc", hue="heuristic", marker="o", palette=palette)
        plt.title(f"Worst-group vs keep_ratio\n{diff}, strength={strength}")
        plt.ylabel("Worst group acc")
        plt.xlabel("keep_ratio")
        plt.legend(title=None, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, fname_prefix + "_worst.png"), dpi=200, bbox_inches="tight")
        plt.close()

        # Average curves
        plt.figure(figsize=(8, 4))
        if not baseline_row.empty:
            b_val = baseline_row["average_acc"].values[0]
            plt.axhline(b_val, color="black", linestyle="-", linewidth=1.5, label="baseline")
        if not satire_row.empty:
            s_val = satire_row["average_acc"].values[0]
            plt.axhline(s_val, color=satire_color, linestyle="--", linewidth=1.5, label=SATIRE_NAME)
        if not heur_sub.empty:
            sns.lineplot(data=heur_sub, x="keep_ratio", y="average_acc", hue="heuristic", marker="o", palette=palette)
        plt.title(f"Average vs keep_ratio\n{diff}, strength={strength}")
        plt.ylabel("Average acc")
        plt.xlabel("keep_ratio")
        plt.legend(title=None, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, fname_prefix + "_avg.png"), dpi=200, bbox_inches="tight")
        plt.close()

print("Saved per-config keep-ratio curves to", OUT_DIR)



  sns.lineplot(data=heur_sub, x="keep_ratio", y="worst_group_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="average_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="worst_group_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="average_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="worst_group_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="average_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="worst_group_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="average_acc", hue="heuristic", marker="o", palette=palette)
  sns.lineplot(data=heur_sub, x="keep_ratio", y="worst_group_acc", hue="heuristic", marker="o", palette=palette)

Saved per-config keep-ratio curves to analysis_plots


In [3]:
# [Discarded]
# 2) Tradeoff scatter per (difficulty, strength)
# plot_df = df.copy()
#
# for diff in sorted(plot_df["difficulty"].unique()):
#     for strength in sorted(plot_df["strength"] == strength):
#         sub = plot_df[(plot_df["difficulty"] == diff) & (plot_df["strength"] == strength)]
#         if sub.empty:
#         	continue
#         fname = os.path.join(OUT_DIR, f"tradeoff_{slug(diff)}_str_{strength}.png")
#
#         # Separate baseline
#         baseline_row = sub[sub["heuristic"] == "base_line"].head(1)
#         heur_sub = sub[sub["heuristic"] != "base_line"]
#
#         plt.figure(figsize=(8, 4))
#         if not baseline_row.empty:
#             b_val = baseline_row["worst_group_acc"].values[0]
#             plt.axhline(b_val, color="black", linestyle="-", linewidth=1.5, label="baseline")
#         if not heur_sub.empty:
#             sns.scatterplot(data=heur_sub, x="average_acc", y="worst_group_acc",
#                             hue="heuristic",
#                             size="keep_ratio", sizes=(30, 120), palette=palette)
#         plt.title(f"Average vs Worst-group\n{diff}, strength={strength}")
#         plt.xlabel("Average acc")
#         plt.ylabel("Worst group acc")
#         plt.legend(title=None, bbox_to_anchor=(1.02, 1), loc="upper left", ncol=2)
#         plt.tight_layout()
#         plt.savefig(fname, dpi=200, bbox_inches="tight")
#         plt.close()