In [1]:
%load_ext autoreload
%autoreload 2
import torch
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
    Gemma2Model,
    Gemma2ForCausalLM,
)
from PIL import Image
import matplotlib.pyplot as plt

# Inference/analysis only: disabling autograd globally keeps activation memory low.
torch.set_grad_enabled(False)
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Checkpoint name shared by the model and its processor.
model_id = "google/paligemma2-3b-pt-224"

# Load the weights in bfloat16, then move them to `device` and switch to eval mode.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
)
model = model.to(device).eval()

processor = PaliGemmaProcessor.from_pretrained(model_id)

In [None]:
# Load the frisbee image used as the healthy input and preview it inline.
img_path = "imgs/frisbee.jpg"
image = Image.open(img_path)
# Hide the axes; the imshow call below draws onto the same implicit pyplot axes.
plt.axis("off")
# Bind the return value so the cell does not echo the image object's repr.
_ = plt.imshow(image)

In [None]:
from getAttentionLib import get_response

# Prompt used for the healthy run.
text = "<image>Answer en what is the frisbee's color?"
inputs_tokens, response = get_response(model, processor, text, image)

# Replace the last token with a literal backslash-n so the token list prints nicely.
inputs_tokens[-1] = "\\n"

print(f"len(inputs_tokens) = {len(inputs_tokens)}")
for printable in (inputs_tokens, response):
    print(printable)

In [None]:
# Build the model inputs for the prompt/image pair and move them next to the model.
inputs = processor(text=text, images=image, return_tensors="pt")
inputs = inputs.to(model.device)
inputs.keys()

In [None]:
from getAttentionLib import (
    Hook,
    State,
    get_activations,
    paligemma_merge_text_and_image,
    gaussian_noising,
)

# Healthy run: merge text and image into input embeddings and record the activations.
healthy_embeds = paligemma_merge_text_and_image(model, inputs)
healthy_activations, healthy_outputs = get_activations(model, healthy_embeds)

# Decode the most likely token at the last position of the healthy run.
healthy_top_token = healthy_outputs.logits[0, -1, :].argmax()
print(processor.decode(healthy_top_token))

In [None]:
# Seed torch's global RNG so the corrupted embeddings (and everything derived from
# them further down) are reproducible across kernel restarts.
# NOTE(review): assumes gaussian_noising draws from torch's global RNG — confirm.
torch.manual_seed(0)

# Corrupted run: noise the embeddings (256 image tokens) and record the activations.
unhealthy_embeds = gaussian_noising(healthy_embeds, num_img_tokens=256)
jammed_activations, jammed_outputs = get_activations(model, unhealthy_embeds)
# Decode the most likely token at the last position of the corrupted run.
print(processor.decode(jammed_outputs.logits[0, -1, :].argmax()))


In [8]:
from getAttentionLib import tok_prob

# Token id of the healthy answer; verified against the tokenizer right below.
purple_token = 34999
assert (
    processor.tokenizer.decode(purple_token) == "purple"
), "token id 34999 no longer decodes to 'purple'"
assert (
    jammed_activations.shape == healthy_activations.shape
), "healthy and corrupted activations must have the same shape"

# The probability of "purple" in the healthy run should be the top probability of a
# plain forward pass. The two values come from different forward paths in reduced
# precision, so compare with a tolerance instead of exact float equality.
healthy_purple_prob = (
    torch.as_tensor(tok_prob(healthy_outputs, purple_token)).float().cpu()
)
reference_top_prob = (
    model(**inputs).logits[0, -1, :].softmax(dim=-1).max().float().cpu()
)
assert torch.isclose(
    healthy_purple_prob, reference_top_prob, rtol=1e-2, atol=1e-3
), f"P(purple)={healthy_purple_prob} differs from top probability {reference_top_prob}"

In [9]:
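# Kept commented out because the result is cached on disk and loaded in the next cell;
# uncomment to regenerate the cache.
# from getAttentionLib import loop_over_restore_all_activations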


# purple_probs = loop_over_restore_all_activations(
#     model, healthy_activations, unhealthy_embeds, healthy_response_tok_idx=purple_token
# )
# torch.save(purple_probs, "purple_probs_of_noisy_frisbee_img.pt")

In [10]:
from pathlib import Path

from getAttentionLib import loop_over_restore_all_activations

# Cached output of loop_over_restore_all_activations for the noisy-image run.
# NOTE: an existing cache file reflects the unhealthy_embeds that were in memory when
# it was written; delete the file to recompute it for the current embeddings.
purple_probs_path = Path("purple_probs_of_noisy_frisbee_img.pt")
if purple_probs_path.exists():
    purple_probs = torch.load(purple_probs_path)
else:
    # Cache miss (e.g. fresh checkout): compute once and store for later runs.
    purple_probs = loop_over_restore_all_activations(
        model,
        healthy_activations,
        unhealthy_embeds,
        healthy_response_tok_idx=purple_token,
    )
    torch.save(purple_probs, purple_probs_path)

# Does the `<bos>` token get attention?
As we can see below, the noisy input embeddings have a significantly different attention pattern than the healthy image input. 
Almost all attention goes to the image tokens. Almost no attention is allocated to the `<bos>` token, which explains why it cannot influence the output, even when restored.
This suggests that noisy input data disrupts the learned attention patterns. To preserve the attention patterns, we need a different approach to corrupt the input image.
Using a different, unrelated image might work better, because it will preserve the attention patterns, as shown in the VQA example.

In [None]:
from getAttentionLib import maxpool_img_tokens, plot_pooled_probs_plotly

# Pool the image-token entries first, then plot the result for the noisy-image run.
pooled_purple_probs = maxpool_img_tokens(purple_probs)
noisy_probs_fig = plot_pooled_probs_plotly(
    pooled_purple_probs,
    inputs_tokens,
    healthy_response_tok_name="purple",
)
noisy_probs_fig.show()

In [12]:
from getAttentionLib import (
    compute_mult_attn_sums,
    plot_region_attn_progression,
    plot_mult_attn_sums,
)

# Attention sums for the noisy embeddings, computed for every language-model layer.
all_layer_indices = list(range(len(model.language_model.model.layers)))
noisy_model_inputs = {"inputs_embeds": unhealthy_embeds}
mult_attn_sums = compute_mult_attn_sums(
    model,
    noisy_model_inputs,
    layers=all_layer_indices,
    n_img_tokens=256,
)

In [None]:
# Earlier variants, kept for reference:
# plot_mult_attn_sums(
#     None, None, layers=[0, 15, 25], mult_attn_sums=mult_attn_sums
# ).show()
# plot_region_attn_progression(None, None, mult_attn_sums=mult_attn_sums).show()

# Plot the precomputed attention sums for three selected layers only.
shown_layers = [0, 15, 25]
shown_attn_sums = [mult_attn_sums[layer] for layer in shown_layers]
plot_mult_attn_sums(
    None, None, layers=shown_layers, mult_attn_sums=shown_attn_sums
).show()

# Use a different image as the corrupted input

In [None]:
# Load the second frisbee image (used as the corrupted input) and preview it inline.
frisbee2_img = Image.open("imgs/frisbee2.png")
# Hide the axes; the imshow call below draws onto the same implicit pyplot axes.
plt.axis("off")
# Bind the return value so the cell does not echo the image object's repr.
_ = plt.imshow(frisbee2_img)

In [None]:
# Tomas: I thought the frisbee was white, but LLMs disagree.
# ChatGPT: "The frisbee in the image is light blue."
# Gemini: "The frisbee's color is light blue."
# Claude: "The frisbee is light blue."

# Only the response is of interest here; the token list is discarded.
_, frisbee2_response = get_response(model, processor, text, frisbee2_img)
print(frisbee2_response)

In [16]:
# Same prompt, second image: build the inputs and merge them into input embeddings.
frisbee2_inputs = processor(text=text, images=frisbee2_img, return_tensors="pt")
frisbee2_inputs = frisbee2_inputs.to(model.device)
frisbee2_embeds = paligemma_merge_text_and_image(model, frisbee2_inputs)

In [None]:
from getAttentionLib import plot_mult_attn_sums

# Same three layers as for the noisy input, now with the second image's embeddings.
frisbee2_attn_fig = plot_mult_attn_sums(
    model,
    {"inputs_embeds": frisbee2_embeds},
    layers=[0, 15, 25],
    n_img_tokens=256,
)
frisbee2_attn_fig.show()

In [18]:
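# Kept commented out because the result is cached on disk and loaded in the next cell;
# uncomment to regenerate the cache.
# from getAttentionLib import loop_over_restore_all_activations


# frisbee2_purple_probs = loop_over_restore_all_activations(
#     model, healthy_activations, unhealthy_embeds=frisbee2_embeds, healthy_response_tok_idx=purple_token
# )
# torch.save(frisbee2_purple_probs, "purple_probs_of_frisbee2_img.pt")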

In [None]:
from pathlib import Path

from getAttentionLib import loop_over_restore_all_activations
from getAttentionLib import maxpool_img_tokens
from getAttentionLib import plot_pooled_probs_plotly
from getAttentionLib import plot_pooled_probs_plt
from getAttentionLib import plot_and_browse_img_token_in_probs

# Cached output of loop_over_restore_all_activations with the second frisbee image
# as the corrupted input.
frisbee2_probs_path = Path("purple_probs_of_frisbee2_img.pt")
if frisbee2_probs_path.exists():
    frisbee2_purple_probs = torch.load(frisbee2_probs_path)
else:
    # Cache miss (e.g. fresh checkout): compute once and store for later runs.
    frisbee2_purple_probs = loop_over_restore_all_activations(
        model,
        healthy_activations,
        unhealthy_embeds=frisbee2_embeds,
        healthy_response_tok_idx=purple_token,
    )
    torch.save(frisbee2_purple_probs, frisbee2_probs_path)

frisbee2_pooled_purple_probs = maxpool_img_tokens(frisbee2_purple_probs)

# Alternative views (NOTE(review): call signatures not verified here):
# plot_and_browse_img_token_in_probs(frisbee2_purple_probs, inputs_tokens)
# plot_pooled_probs_plt(frisbee2_pooled_purple_probs, inputs_tokens).show()
plot_pooled_probs_plotly(
    frisbee2_pooled_purple_probs,
    inputs_tokens,
    healthy_response_tok_name="purple",
).show()

# Does the `<bos>` token have no outflowing information?

In [None]:
n_img_tokens = 256
text = "<image>Answer en what color is the frisbee?"
inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

# To dump the value states, add the following to
# transformers/models/gemma2/modeling_gemma2.py line 402 (forward right before attention):
#
# ```python
# from pathlib import Path
# for i in range(1, 30):
#     fname = f"value_states/layer{i}.pt"
#     if not Path(fname).exists():
#         torch.save(value_states, fname)
#         print(f"Saved {fname}")
#         break
# ```
#
# and then run a forward pass:
# model(**inputs)

# Token strings for the new prompt; the last one is replaced so it prints nicely.
token_ids = inputs.input_ids[0]
inputs_tokens = processor.tokenizer.convert_ids_to_tokens(token_ids)
inputs_tokens[-1] = "\\n"

# Show only what follows the first `n_img_tokens` entries.
print(inputs_tokens[n_img_tokens:])

In [54]:
# Load the dumped value states of layer 1 and layer 26.
vsl1, vsl26 = (
    torch.load(f"value_states/layer{layer_no}.pt") for layer_no in (1, 26)
)

In [55]:
# Sanity checks on the dumped value states: neither tensor contains NaNs, and the
# two layers are not (numerically) the same tensor.
assert not torch.isnan(vsl1).any(), "layer 1 value states contain NaNs"
assert not torch.isnan(vsl26).any(), "layer 26 value states contain NaNs"
assert not torch.allclose(
    vsl1, vsl26
), "layer 1 and layer 26 value states are identical; the dump is probably wrong"

In [None]:
# Per-token L2 norm of the value vectors: layer 1, batch example 0, head 0, all
# tokens, reduced over the feature dimension. The norm is accumulated in float32
# via torch.linalg.vector_norm (torch.norm is deprecated).
norms = torch.linalg.vector_norm(
    vsl1[0, 0, :, :], dim=1, dtype=torch.float32
).cpu()
# Image tokens occupy the first `n_img_tokens` positions.
img_norms = norms[:n_img_tokens]
img_norms.mean(), img_norms.std()

In [None]:
# Everything after the first `n_img_tokens` positions.
txt_norms = norms[n_img_tokens:]
txt_norm_mean, txt_norm_std = txt_norms.mean(), txt_norms.std()
txt_norm_mean, txt_norm_std

In [59]:
# Load the dumped value states of layers 1 through 26, in order.
vsls = []
for layer_no in range(1, 27):
    vsls.append(torch.load(f"value_states/layer{layer_no}.pt"))
assert len(vsls) == 26

In [None]:
# Stack all layers, drop the batch dimension (dim 1 after stacking), and take the L2
# norm over the last dimension so that every head keeps its own per-token norm.
# The norm is accumulated in float32 via torch.linalg.vector_norm (torch.norm is
# deprecated).
stacked_value_states = torch.stack(vsls).squeeze(1)
all_norms = torch.linalg.vector_norm(
    stacked_value_states, dim=-1, dtype=torch.float32
).cpu()
all_norms.shape

In [None]:
# Keep only the positions after the image tokens and average over all heads.
all_text_norms = all_norms[:, :, n_img_tokens:].mean(dim=1)
print(all_text_norms.shape)

# Derive tick labels and tick count from this cell's own data, so the plot does not
# depend on a variable left over from another cell.
text_token_labels = inputs_tokens[n_img_tokens:]
assert len(text_token_labels) == all_text_norms.shape[1], (
    "inputs_tokens does not match the prompt the value states were dumped for"
)

# Rows: text tokens, columns: entries of `vsls` in load order (layer 1 first).
plt.imshow(
    all_text_norms.T.numpy(),
    cmap="Blues",
    vmin=0,
    vmax=all_text_norms.max().item(),
)
plt.colorbar(label="Value-state L2 norm (mean over heads)")
plt.yticks(ticks=range(len(text_token_labels)), labels=text_token_labels)
plt.xlabel("Layer")
plt.ylabel("Text token")
plt.title("Value-state norms of text tokens across layers")
plt.tight_layout()
plt.show()