In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to local helper modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import torch

In [None]:
from getAttentionLib import load_pg2_model_and_processor

# Load the model and its processor once for the whole notebook.
# Compilation is disabled and the "eager" attention implementation is requested
# -- presumably so that forward passes called with output_attentions=True can
# return attention weights; confirm against load_pg2_model_and_processor.
model, processor = load_pg2_model_and_processor(
    attn_implementation="eager",
    compile=False,
)

In [None]:
from getAttentionLib import (
    gaussian_noising,
    get_decoder_layer_outputs,
    get_vqa_balanced_pairs,
    paligemma_merge_text_and_image,
)
from tqdm import tqdm

# --- Experiment configuration -------------------------------------------------
n = 100  # number of (clean, symmetric-image-replacement) pairs requested
n_layers = 26  # number of per-layer attention maps compared for every example
n_examples = n
n_img_tokens = 256  # forwarded to gaussian_noising as num_img_tokens
std = 0.0045  # base noise std; every alpha below multiplies it
alphas = [3, 10, 20, 30]
seed = 0

# Make the noising reproducible across "Restart & Run All".
# NOTE(review): assumes gaussian_noising draws from torch's global RNG -- confirm.
torch.manual_seed(seed)


def _stacked_attentions(**model_kwargs):
    """Run `model` and return its per-layer attention maps stacked on dim 0."""
    outputs = model(**model_kwargs, output_attentions=True)
    return torch.stack(outputs.attentions)


def _attention_kl(reference, other):
    """Return KL(other || reference), averaged over attention rows.

    Both tensors are flattened to ``(-1, n_tokens)`` so that every row is one
    attention distribution over the key positions.

    * ``reduction="batchmean"`` divides the summed divergence by the number of
      rows only; the default ``"mean"`` would additionally divide by
      ``n_tokens`` and therefore not be a KL divergence.
    * The computation is done in float32, and ``reference`` is clamped to the
      smallest positive normal float before the log so that exact zeros cannot
      produce ``0 * -inf = NaN``.
    """
    n_tokens = reference.shape[-1]
    ref_rows = reference.reshape(-1, n_tokens).float()
    other_rows = other.reshape(-1, n_tokens).float()
    log_ref = torch.log(ref_rows.clamp_min(torch.finfo(torch.float32).tiny))
    return torch.nn.functional.kl_div(
        log_ref, target=other_rows, reduction="batchmean"
    )


example_pairs = get_vqa_balanced_pairs(n)

# kls[0]     : clean image vs. symmetric image replacement
# kls[1 + i] : clean image vs. clean embeddings noised with std * alphas[i]
kls = torch.zeros(1 + len(alphas), n_layers, n_examples)
n_seen = 0

# Only forward passes are needed. Without no_grad every call would build an
# autograd graph that stays alive through the in-place writes into `kls`.
with torch.no_grad():
    for example_idx, (cl_row, sir_row) in enumerate(
        tqdm(example_pairs, total=n_examples)
    ):
        # Both images are paired with the same prompt (the clean row's question).
        text = f"<image>Answer en {cl_row['question']}"
        cl_inputs = processor(
            text=text, images=cl_row["image"].convert("RGB"), return_tensors="pt"
        ).to(model.device)
        sir_inputs = processor(
            text=text, images=sir_row["image"].convert("RGB"), return_tensors="pt"
        ).to(model.device)

        cl_attns = _stacked_attentions(**cl_inputs)
        sir_attns = _stacked_attentions(**sir_inputs)
        for layer_idx in range(n_layers):
            kls[0, layer_idx, example_idx] = _attention_kl(
                cl_attns[layer_idx], sir_attns[layer_idx]
            )

        # Noised variants start from the merged text+image embeddings of the
        # clean example and are fed to the model through inputs_embeds.
        cl_embeds = paligemma_merge_text_and_image(model, cl_inputs)
        for noise_idx, alpha in enumerate(alphas, start=1):
            noisy_embeds = gaussian_noising(
                cl_embeds, num_img_tokens=n_img_tokens, std=std * alpha
            )
            noisy_attns = _stacked_attentions(inputs_embeds=noisy_embeds)
            for layer_idx in range(n_layers):
                kls[noise_idx, layer_idx, example_idx] = _attention_kl(
                    cl_attns[layer_idx], noisy_attns[layer_idx]
                )

        n_seen = example_idx + 1

# If fewer pairs were produced than requested, drop the never-written columns;
# otherwise their zeros would be counted as real measurements in the statistics.
if n_seen < n_examples:
    kls = kls[..., :n_seen]

In [None]:
import matplotlib.pyplot as plt

# Per-layer mean and standard deviation over the example dimension (last dim).
# detach() + cpu() make the statistics safe to hand to matplotlib even if `kls`
# ended up tracking gradients or living on an accelerator.
kl_means = kls.detach().cpu().mean(dim=-1)
kl_stds = kls.detach().cpu().std(dim=-1)

# Take the x-axis from the data itself rather than from a separate global.
layers = range(kl_means.shape[1])

fig, ax = plt.subplots(figsize=(5, 3))

# Row 0 is the clean-vs-symmetric-image-replacement baseline.
ax.plot(
    layers,
    kl_means[0].numpy(),
    label="Symmetric Image Replacement",
    marker="o",
    alpha=0.8,
)

# Rows 1.. hold one noising strength each, in the order of `alphas`.
for noise_idx, alpha in enumerate(alphas, start=1):
    mean = kl_means[noise_idx].numpy()
    spread = kl_stds[noise_idx].numpy()
    (line,) = ax.plot(
        layers,
        mean,
        label=f"Image Noising $\\alpha={alpha}$",
        marker="x",
        alpha=0.8,
    )
    # fill_between draws from its own colour cycle, independent of plot();
    # reuse the line's colour so each +/-1 std band matches its curve.
    ax.fill_between(
        layers,
        mean - spread,
        mean + spread,
        facecolor=line.get_color(),
        alpha=0.2,
    )

ax.set_xlabel("Layer")
ax.set_ylabel("KL Divergence")
ax.legend()
ax.grid(alpha=0.5)
fig.tight_layout()
plt.show()