In [1]:
# Reload edited modules (e.g. getAttentionLib) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from getAttentionLib import load_pg2_model_and_processor

# NOTE(review): eager attention is presumably requested so the model can return
# per-head attention weights (fused SDPA/flash kernels typically do not) —
# confirm in getAttentionLib. Compilation is disabled for this analysis run.
model, processor = load_pg2_model_and_processor(
    attn_implementation="eager",
    compile=False,
)

In [None]:
from getAttentionLib import (
    compute_mult_attn_sums_over_noisy_vqa,
    compute_mult_attn_sums_over_vqa,
)

# Experiment configuration shared by the cells below.
n_img_tokens = 256  # NOTE(review): presumably image tokens per prompt — confirm against the processor
n_vqa_samples = 100  # alternative: 1000 for a larger run
# NOTE(review): 26 is hard-coded; presumably every decoder layer of the language
# model — confirm it matches the loaded checkpoint.
layers = [*range(26)]

# Only the first returned value is kept (presumably the noise-free attention sums);
# the other two outputs are discarded.
vqa, _, _ = compute_mult_attn_sums_over_vqa(
    model,
    processor,
    n_vqa_samples,
    layers,
    n_img_tokens,
    get_responses=False,
)

In [None]:
# Noise sweep: the noise std passed to the library is sigma * alpha for each alpha.
# NOTE(review): no random seed is set in this notebook, so results may vary
# between runs unless getAttentionLib seeds internally — confirm.
sigma = 0.0045
alphas = [3, 10, 20, 30]

nvqa_list = []
for alpha in alphas:
    noise_std = sigma * alpha
    nvqa = compute_mult_attn_sums_over_noisy_vqa(
        model,
        processor,
        n_vqa_samples,
        layers,
        n_img_tokens,
        std=noise_std,
    )
    nvqa_list += [nvqa]

In [None]:
# Sanity check: the clean tensor and the first noisy run should share a shape,
# since compute_kl indexes both as [example, layer, ...].
(vqa.size(), nvqa_list[0].size())

In [None]:
import torch

def compute_kl(vqa: torch.Tensor, nvqa: torch.Tensor) -> torch.Tensor:
    """Per-example, per-layer KL divergence KL(vqa || nvqa).

    For every example ``i`` and layer ``j`` the slices ``vqa[i, j]`` and
    ``nvqa[i, j]`` are flattened and treated as discrete distributions ``P``
    and ``Q``. The result is ``sum_k P_k * (log P_k - log Q_k)``.

    NOTE(review): the KL is only meaningful if each slice is a non-negative,
    normalized distribution. This function does not renormalize. A zero in
    ``nvqa`` where ``vqa`` is positive yields ``inf``.

    Args:
        vqa: Reference values, shape ``(n_examples, n_layers, *rest)``.
        nvqa: Values compared against ``vqa``; must have the same shape.

    Returns:
        CPU float32 tensor of shape ``(n_examples, n_layers)``.

    Raises:
        ValueError: if ``vqa`` and ``nvqa`` have different shapes.
    """
    if vqa.shape != nvqa.shape:
        raise ValueError(
            "vqa and nvqa must have the same shape, "
            f"got {tuple(vqa.shape)} and {tuple(nvqa.shape)}"
        )
    n_examples, n_layers = vqa.shape[:2]
    # Product of the trailing dims (1 if there are none). Unlike reshape(..., -1),
    # this also works when n_examples == 0.
    n_elems = vqa.shape[2:].numel()
    # Compute in at least float32 so low-precision inputs (e.g. bf16) keep accuracy.
    dtype = torch.promote_types(
        torch.promote_types(vqa.dtype, nvqa.dtype), torch.float32
    )
    p = vqa.reshape(n_examples, n_layers, n_elems).to(dtype)
    q = nvqa.reshape(n_examples, n_layers, n_elems).to(dtype)
    # kl_div(input=log Q, target=P) gives pointwise P * (log P - log Q).
    # Summing over the distribution axis yields the true KL. The default
    # reduction="mean" divides by the element count, so it does not return
    # the KL divergence.
    pointwise = torch.nn.functional.kl_div(
        input=torch.log(q), target=p, reduction="none"
    )
    kls = pointwise.sum(dim=-1)
    return kls.to(device="cpu", dtype=torch.float32)

# One (n_examples, n_layers) KL matrix per noise multiplier, in the same order as `alphas`.
kls_list = [compute_kl(vqa, noisy_sums) for noisy_sums in nvqa_list]
kls_list[0].size()

In [None]:
import itertools

import matplotlib.pyplot as plt

from getAttentionLib import plot_metric_with_std_over_layers

# KL(vqa || noisy vqa) per layer: one curve per noise multiplier alpha.
fig, ax = plt.subplots(figsize=(5, 3))
markers = ["o", "s", "D", "P"]
# cycle() so that adding more alphas than markers reuses markers instead of
# silently dropping curves (plain zip truncates to its shortest input).
for kls, alpha, marker in zip(kls_list, alphas, itertools.cycle(markers)):
    plot_metric_with_std_over_layers(
        metric=kls,
        ylabel="KL(vqa || GN vqa)",
        ax=ax,
        label=f"$\\alpha = {alpha}$",
        marker=marker,
    )
ax.grid(alpha=0.5)
ax.legend()
# fig.savefig("imgs/gn_vs_str_comparison/kls_over_layers.png")
plt.show()

In [None]:
# Mean KL at the last computed layer, one entry per alpha (same order as `alphas`).
[kl_matrix.select(1, -1).mean() for kl_matrix in kls_list]

In [None]:
from getAttentionLib import (
    compute_attn_sums,
    compute_mult_attn_sums,
    plot_mult_attn_sums,
)


# NOTE(review): this rebinds the notebook-global `layers` (previously the full
# layer list). Re-running the compute cells above after this one would silently
# use only these three layers — consider a distinct name such as `plot_layers`.
layers = [2, 15, 25]
nvqa = nvqa_list[1]  # the run for alphas[1], i.e. alpha = 10
vqa_means = vqa.mean(dim=0)
nvqa_means = nvqa.mean(dim=0)
diffs = (vqa - nvqa).abs()
figsize = (8, 4)
kwargs = {"n_img_tokens": n_img_tokens, "figsize": figsize, "layers": layers}

# Three panels, each showing the mean and std over examples:
# vqa, noisy vqa, and |vqa - noisy vqa|. The difference panel is drawn in red,
# with vmax fixed to the global max of `diffs`.
panel_specs = [
    (vqa_means, vqa.std(dim=0), {}),
    (nvqa_means, nvqa.std(dim=0), {}),
    (
        diffs.mean(dim=0),
        diffs.std(dim=0),
        {"vmax": diffs.max(), "color_threshold": 0.15, "cmap": "Reds"},
    ),
]
for panel_means, panel_stds, panel_extra in panel_specs:
    plot_mult_attn_sums(
        None,
        None,
        mult_attn_sums=panel_means[layers],
        stds=panel_stds[layers],
        **kwargs,
        **panel_extra,
    ).show()