In [None]:
%load_ext autoreload
%autoreload 2
import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

torch.set_grad_enabled(False)  # avoid blowing up mem
device = "cuda"
model_id = "google/paligemma2-3b-pt-224"
model = (
    PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16
    )
    .to(device)
    .eval()
)
processor = PaliGemmaProcessor.from_pretrained(model_id)

In [None]:
from getAttentionLib import (
    compute_mult_attn_sums_over_noisy_vqa,
    compute_mult_attn_sums_over_vqa,
)

# Derive checkpoint-dependent sizes from the loaded model/processor instead of
# hard-coding them, so changing `model_id` (e.g. a 448px or 10B variant) keeps
# them consistent.
n_img_tokens = processor.image_seq_length  # 256 for the 224px checkpoints
n_vqa_samples = 20  # small for quick iteration; use 1000 for the full run
# NOTE(review): presumably these index the language-model decoder layers
# (26 for paligemma2-3b, matching the previous hard-coded range) — confirm in getAttentionLib.
layers = list(range(model.config.text_config.num_hidden_layers))

# Attention sums on the clean VQA samples (the second return value is unused here).
vqa, _, vqa_responses = compute_mult_attn_sums_over_vqa(
    model, processor, n_vqa_samples, layers, n_img_tokens
)
# Same measurement on the noisy-image version of the samples.
nvqa = compute_mult_attn_sums_over_noisy_vqa(
    model, processor, n_vqa_samples, layers, n_img_tokens
)

In [None]:
# Clean and noisy runs are compared position-by-position downstream, so their
# shapes must match exactly; fail loudly here rather than broadcasting silently later.
assert vqa.shape == nvqa.shape, f"shape mismatch: {vqa.shape} vs {nvqa.shape}"
vqa.shape, nvqa.shape

In [None]:
n_examples, n_layers = vqa.shape[:2]
# KL(vqa || noisy vqa) for every (example, layer) pair. Each vqa[i, j] block is
# treated as one flattened distribution, so collapse all trailing dims at once.
# Work in float32: the inputs may come out of a bfloat16 model.
# NOTE(review): assumes each flattened block is a probability distribution and
# that nvqa has no exact zeros (log(0) = -inf) — TODO confirm in getAttentionLib.
vqa_distrs = vqa.reshape(n_examples, n_layers, -1).float()
nvqa_distrs = nvqa.reshape(n_examples, n_layers, -1).float()
# F.kl_div(input=log Q, target=P) yields the pointwise terms P * (log P - log Q).
# Summing them per distribution gives KL(P || Q); the default reduction="mean"
# would instead divide by the number of elements, which is not the KL divergence.
kls = (
    torch.nn.functional.kl_div(
        input=torch.log(nvqa_distrs), target=vqa_distrs, reduction="none"
    )
    .sum(dim=-1)
    .cpu()
)
kls.shape

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt

from getAttentionLib import plot_metric_with_std_over_layers

# Plot the per-layer KL between clean and noisy ("GN") VQA attention.
fig = plot_metric_with_std_over_layers(metric=kls, ylabel="KL(vqa || GN vqa)")
out_path = Path("imgs/gn_vs_str_comparison/kls_over_layers.png")
# savefig does not create missing directories; ensure the target exists on a fresh checkout.
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)

In [None]:
# Shape of the across-example standard deviation of the clean attention sums.
torch.std(vqa, dim=0).shape

In [None]:
from getAttentionLib import (
    compute_attn_sums,
    compute_mult_attn_sums,
    plot_mult_attn_sums,
)


# Inspect a few layers: the first two and the last.
layers = [0, 1, 25]
vqa_means = vqa.mean(dim=0)
nvqa_means = nvqa.mean(dim=0)
# Per-example absolute difference between clean and noisy attention sums.
diffs = (vqa - nvqa).abs()
figsize = (8, 4)
kwargs = dict(n_img_tokens=n_img_tokens, figsize=figsize, layers=layers)

# One entry per figure: (mean map, std map, figure-specific plotting options).
# Order matters: clean, noisy, then the absolute difference.
panels = [
    (vqa_means, vqa.std(dim=0), {}),
    (nvqa_means, nvqa.std(dim=0), {}),
    (
        diffs.mean(dim=0),
        diffs.std(dim=0),
        {"vmax": diffs.max(), "color_threshold": 0.15, "cmap": "Reds"},
    ),
]
for means, stds, extra in panels:
    plot_mult_attn_sums(
        None,
        None,
        mult_attn_sums=means[layers],
        stds=stds[layers],
        **kwargs,
        **extra,
    ).show()