In [None]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import torch
from getAttentionLib import ActivationPatchingResult


def load_healthy_tok_logit_diffs(
    root: str | Path, is_gaussian_noising: bool
) -> torch.Tensor:
    root = Path(root)
    dirs = [root / d for d in os.listdir(root)]
    logit_diffs = torch.zeros(len(dirs))
    for i, dir in enumerate(dirs):
        pr = ActivationPatchingResult.load(dir)
        if is_gaussian_noising:
            logit_diffs[i] = pr.logit_diff_denominator_gn()
        else:
            logit_diffs[i] = pr.logit_diff_denominator_str()
    return logit_diffs

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

# One histogram panel per results folder under PATCHING_ROOT.
# "str" holds image-replacement runs; every other folder holds Gaussian-noising
# runs at the noise level given in its label.
# Raw strings keep the "\s" in the LaTeX "$\sigma$" from being parsed as an
# (invalid) Python escape sequence.
folder2label = dict(
    str="Image Replacement",
    gn_03o=r"Image Noising 3$\sigma$",
    gn_10o=r"Image Noising 10$\sigma$",
    gn_20o=r"Image Noising 20$\sigma$",
    gn_30o=r"Image Noising 30$\sigma$",
)
PATCHING_ROOT = Path("./vqa_patching")
DENSITY_YLIM = (0, 0.6)  # fixed so every panel uses the same y-scale

# squeeze=False keeps `axes` a 2-D array even when there is only one folder.
fig, axes = plt.subplots(
    len(folder2label), 1, figsize=(4, 6), sharex=True, squeeze=False
)
axes = axes[:, 0]

for ax, (folder, label) in zip(axes, folder2label.items()):
    ht_lds = load_healthy_tok_logit_diffs(
        PATCHING_ROOT / folder, is_gaussian_noising=folder != "str"
    )
    sns.histplot(
        ht_lds.numpy(),
        label=f"{label} (n={len(ht_lds)})",
        bins=10,
        alpha=0.6,
        ax=ax,
        kde=True,
        stat="density",
    )
    ax.axvline(
        x=float(ht_lds.mean()), color="navy", linestyle="--", alpha=0.7, label="mean"
    )
    ax.axvline(x=0, color="red", linestyle="--", alpha=0.7)  # zero reference line
    ax.set_ylim(*DENSITY_YLIM)
    ax.grid(alpha=0.5)
    # sharex=True hides inner x tick labels by default; show them on every panel.
    ax.tick_params(axis="x", which="both", bottom=True, labelbottom=True)
    ax.legend()

axes[-1].set_xlabel("Logit difference")
fig.tight_layout()
plt.show()
# fig.savefig("logitdiffs_histograms.png", dpi=300)
