In [1]:
%load_ext autoreload
%autoreload 2
import torch
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
    Gemma2Model,
    Gemma2ForCausalLM,
)
from PIL import Image
import matplotlib.pyplot as plt

# Analysis-only notebook: disable autograd globally so forward passes do not
# keep activations around for a backward pass (avoids blowing up memory).
torch.set_grad_enabled(False)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the PaliGemma 2 checkpoint in bfloat16, move it to `device`, and put it
# in eval mode. The processor is loaded from the same checkpoint id.
model_id = "google/paligemma2-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
)
model = model.to(device).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

In [3]:
# Build the model inputs for one image / question pair.
img_path = "imgs/frisbee.jpg"
text = "<image>Answer en what is the frisbee's color?"
image = Image.open(img_path)
inputs = processor(text=text, images=image, return_tensors="pt")
inputs = inputs.to(model.device)  # keep the inputs on the same device as the model

In [None]:
# Single forward pass over the prompt. output_attentions=True is required
# because the following cells read the per-layer weights from `output.attentions`.
output = model(**inputs, output_attentions=True)

In [5]:
# Compute unnormalized importances (f(x) norm * attention weight) for all
# tokens in every layer (not only layer 25).
n_layers = len(output.attentions)  # one attention tensor per layer
n_heads = output.attentions[0].shape[1]
seq_len = inputs.input_ids.shape[1]

per_layer_fx_norms = []
per_layer_importances = []
for layer in range(n_layers):
    layer_attns = output.attentions[layer][0, :, :, :]  # 0th batch: all heads, all tokens
    # Load onto the attention tensor's device so the product below cannot hit
    # a CPU/GPU mismatch, wherever the norms were saved from.
    layer_fx_norms = torch.load(
        f"fx_norms/layer{layer}.pt", map_location=layer_attns.device
    )
    per_layer_fx_norms.append(layer_fx_norms)
    # presumably fx norms are (heads, tokens); unsqueeze(1) then broadcasts
    # them over the destination-token axis of the attention map — TODO confirm
    per_layer_importances.append(layer_fx_norms.unsqueeze(1) * layer_attns)

all_fx_norms = torch.stack(per_layer_fx_norms).float()
all_importances = torch.stack(per_layer_importances).float()
# Guards against an unintended broadcast in the product above.
assert all_importances.shape == (n_layers, n_heads, seq_len, seq_len)


In [None]:
# Normalize the importances so that, for every (layer, head, destination
# token), the values over the last axis (source tokens) sum to 1.
normalized_imp = all_importances / all_importances.sum(dim=3, keepdim=True)
print(normalized_imp.shape)
# The property must hold for every layer / head / token, so check all rows
# instead of a single hand-picked one.
row_sums = normalized_imp.sum(dim=3)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)

In [7]:
from getAttentionLib import dump_attn
from getAttentionLib import get_img_grid_sizes


# Token strings for plot labels. The prompt's last token is a newline, which is
# replaced by its escaped form so it stays visible.
token_strings = processor.tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
assert token_strings[-1] == "\n"
token_strings = token_strings[:-1] + ["\\n"]
# Only the second value returned by the helper is needed here.
_, grid_side_len = get_img_grid_sizes(model, inputs)

In [None]:
# Dump the unnormalized importance maps for the first, a middle, and the last layer.
layers = [0, 15, 25]
for layer_idx in layers:
    dump_attn(
        # add a leading axis of size 1 in front of the per-layer map
        attn_weights=all_importances[layer_idx].unsqueeze(0),
        layer_idx=layer_idx,
        name="PaliGemma2_transformedInputs",
        tokens=token_strings,
        img_path=img_path,
        grid_side_len=grid_side_len,
    )

In [None]:
from matplotlib.pylab import multi_dot
from getAttentionLib import (
    compute_attn_sums,
    compute_mult_attn_sums,
    plot_attn_sums,
    plot_mult_attn_sums,
    plot_region_attn_progression,
)

# Summarize the normalized importances for the layers selected in `layers`.
n_img_tokens = grid_side_len**2
# Index with `layers` rather than a second hard-coded list, so the data always
# matches the layer labels handed to the plot below.
imp_sums = torch.stack(
    [
        compute_attn_sums(normalized_imp[layer_idx], n_img_tokens=n_img_tokens)
        for layer_idx in layers
    ]
)
plot_mult_attn_sums(
    None, None, layers=layers, mult_attn_sums=imp_sums, n_img_tokens=n_img_tokens
).show()

In [None]:
# Same summary as for the importances, but on the raw attention weights.
# Pick the layers directly from `layers` (keeps data and plot labels in sync)
# instead of stacking every layer's attention just to select a few of them.
attn_sums = torch.stack(
    [
        compute_attn_sums(
            output.attentions[layer_idx][0], n_img_tokens=n_img_tokens
        ).float()
        for layer_idx in layers
    ]
)
plot_mult_attn_sums(
    None, None, layers=layers, mult_attn_sums=attn_sums, n_img_tokens=n_img_tokens
).show()


In [None]:
# Importance sums minus raw-attention sums, drawn on a diverging colormap with
# color limits symmetric around zero.
diffs = torch.sub(imp_sums, attn_sums)
diff_color_limit = 0.4
plot_mult_attn_sums(
    None,
    None,
    layers=layers,
    mult_attn_sums=diffs,
    n_img_tokens=n_img_tokens,
    cmap="bwr",
    vmin=-diff_color_limit,
    vmax=diff_color_limit,
).show()

# Sanity Checks

In [None]:
# The expected values below are only quoted to three decimals and the model
# runs in bfloat16, so compare with a tolerance instead of exact float equality.
ATTN_TOL = 2e-3
final_tok_attns = output.attentions[-1][0, :, -1, :].float()  # last layer, 0th batch, dest=last token

# The final token in the final layer attends on average (over all heads) 0.336 to the <bos> token
bos_attn = final_tok_attns[:, n_img_tokens].mean().item()
assert abs(bos_attn - 0.336) < ATTN_TOL, bos_attn
# The final token in the final layer attends on average (over all heads) 0.175 to itself
self_attn = final_tok_attns[:, -1].mean().item()
assert abs(self_attn - 0.175) < ATTN_TOL, self_attn

# The fx norm of the <bos> token in the final layer and that of the final token should be within 50% of each other
# because their importances are very similar to their attention values
l25_fx_norms = torch.load(
    "fx_norms/layer25.pt", map_location=final_tok_attns.device
)
l25_fx_norms[:, n_img_tokens], l25_fx_norms[:, -1]  # does not hold

In [None]:
# compute unnormalized importances in the final layer with destination token=final token
print(l25_fx_norms.shape)
l25_dest_last_tkn_attns = output.attentions[-1][
    0, :, -1, :
]  # last layer, 0th batch, all heads, dest=last token, src=all tokens
print(l25_dest_last_tkn_attns.shape)
# Each head's weights should sum to 1. Sum in float32 and compare with a
# tolerance: exact equality on low-precision sums is brittle.
head_attn_totals = l25_dest_last_tkn_attns.float().sum(dim=1)
assert torch.allclose(
    head_attn_totals, torch.ones_like(head_attn_totals), atol=1e-2
)
# Weight each head's attention by the fx norms, then add up the heads.
l25_imps = (l25_fx_norms * l25_dest_last_tkn_attns).sum(dim=0)
assert len(l25_imps) == inputs.input_ids.shape[1]  # one value per source token
l25_imps[n_img_tokens], l25_imps[-1]  # sanity check

In [None]:
# now they are close: for the source positions `n_img_tokens` and the last
# token, print the value computed by hand next to the head-summed entry of
# `all_importances` (final layer, destination = final token).
for src_idx in (n_img_tokens, -1):
    print(l25_imps[src_idx], all_importances[-1, :, -1, src_idx].sum())

# Do the f(x) norms develop over the layers?

In [None]:
from getAttentionLib import aggregate_layer_norms, plot_fx_norms_progressions


# Aggregate the per-layer f(x) norms (the helper returns two values, unpacked
# here as max and average norms) and plot how they progress over the layers.
max_norms, avg_norms = aggregate_layer_norms(all_fx_norms, n_img_tokens)
fx_norms_fig = plot_fx_norms_progressions(max_norms, avg_norms, sharey=False)
fx_norms_fig.show()

# How different are f(x) norms for image tokens?

In [None]:
# Dispersion coefficient (coefficient of variation = std / mean) per layer:
# a relative measure of how much the f(x) norms vary within a layer.
# NOTE(review): the statistics run over every token position; nothing here
# restricts them to the first `n_img_tokens` entries — confirm this is intended
# given that the section heading asks about image tokens.
head_summed_norms = all_fx_norms.sum(dim=1)  # collapse dim 1 (presumably heads)
layer_means = head_summed_norms.mean(dim=1)
layer_stds = head_summed_norms.std(dim=1)
dispersion_coefficients = layer_stds / layer_means
dispersion_coefficients.min(), dispersion_coefficients.max()

# Are Max-Norm Img Tokens also the Most Important Causal Ones?

The answer seems to be no. The most important causal image tokens are centered on the frisbee's text.
The Max-Norm img tokens follow no clear pattern. The importance of image tokens (f(x) norm * attention) seems to be slightly correlated with the causally important ones on the frisbee.

In [17]:
# Rebuild the token strings for plot labels (same steps as the earlier cell):
# the trailing newline token is replaced by its escaped, visible form.
token_strings = processor.tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
assert token_strings[-1] == "\n"
token_strings = token_strings[:-1] + ["\\n"]

In [None]:
import torch
from getAttentionLib import (
    maxpool_img_tokens,
    avgpool_img_tokens,
    plot_img_and_text_probs_side_by_side,
    plot_img_probs,
    plot_pooled_probs_plotly,
)
from getAttentionLib import plot_pooled_probs_plt
from getAttentionLib import plot_and_browse_img_token_in_probs

# Load the precomputed "purple" probabilities from disk and plot them.
frisbee2_purple_probs = torch.load("purple_probs_of_frisbee2_img.pt")
# frisbee2_pooled_purple_probs = avgpool_img_tokens(frisbee2_purple_probs)
# plot_pooled_probs_plotly(
#     frisbee2_pooled_purple_probs, token_strings, healthy_response_tok_name="purple"
# ).show()
# NOTE(review): 256 is hard-coded here; presumably it equals `n_img_tokens` — confirm.
side_by_side_fig = plot_img_and_text_probs_side_by_side(
    frisbee2_purple_probs, n_img_tokens=256
)
side_by_side_fig.show()

In [None]:
# Browse the loaded "purple" probabilities with the project's plotting helper.
plot_and_browse_img_token_in_probs(frisbee2_purple_probs)

In [None]:
# Average the f(x) norms over dim 1 (presumably heads), move the result to the
# CPU, and browse it with a fixed upper color limit.
mean_fx_norms = torch.mean(all_fx_norms, dim=1).cpu()
plot_and_browse_img_token_in_probs(probs=mean_fx_norms, cmax=20)

In [None]:
# how important are the img tokens for the final token?
# Average the normalized importances over dim 1, then keep only the row whose
# destination is the final token.
imp_for_last_tok = torch.mean(normalized_imp, dim=1)[:, -1].cpu()
plot_and_browse_img_token_in_probs(probs=imp_for_last_tok, cmax=0.05)

In [None]:
# attention only "importances" for the last token
# Stack every layer's attention, sum over all heads, then take the 0th batch
# with dest_token=last token and all src tokens. Despite the `lastlayer_`
# prefix, the result keeps one row per layer.
lastlayer_lasttok_attn = torch.sum(
    torch.stack(output.attentions).to(torch.float32), dim=2
)[:, 0, -1].cpu()
plot_and_browse_img_token_in_probs(probs=lastlayer_lasttok_attn, cmax=0.1)

In [None]:
from correlation_analysis import (
    plot_correlation_in_midlayers,
    plot_correlation_progression,
)


# Correlate each per-layer score (importance for the last token, then mean
# f(x) norm) with the "purple" probabilities, using the same p-value threshold.
for per_layer_scores in (imp_for_last_tok, mean_fx_norms):
    _ = plot_correlation_progression(
        per_layer_scores, frisbee2_purple_probs, p_threshold=0.001
    )
# plot_correlation_in_midlayers(mean_fx_norms, frisbee2_purple_probs, start_layer=8, end_layer=13)

In [24]:
# from correlation_analysis import plot_correlation_in_midlayers
# plot_correlation_in_midlayers(imp_for_last_tok, frisbee2_purple_probs)

In [None]:
# Same correlation plot as for the importances, but using the attention-only
# scores for the last token.
_ = plot_correlation_progression(
    lastlayer_lasttok_attn, frisbee2_purple_probs, p_threshold=0.001
)

# Compute Log Prob Increase by Token
This metric looks, in each layer, at the change in log prob for the correct answer.
The comparison is between the hidden state before and after the multi-head attention.

In [26]:
# Need a hook in each layer that saves the input & the module
from getAttentionLib import paligemma_merge_text_and_image

# Merge text and image into a single embedding sequence with the project
# helper, then run the model on those embeddings, requesting both the
# attention weights and the hidden states.
inputs_embeds = paligemma_merge_text_and_image(model, inputs)
forward_flags = {"output_attentions": True, "output_hidden_states": True}
outputs = model(inputs_embeds=inputs_embeds, **forward_flags)

In [None]:
# Stack the returned hidden states along a new leading axis and show the shape.
hidden_states = torch.stack(outputs.hidden_states, dim=0)
hidden_states.shape

In [28]:
# compute base log probs for correct answer
purple_token = 34999
assert processor.tokenizer.decode(purple_token) == "purple"

# Hidden state of the final position (0th batch) for every stacked hidden state.
hidden_states_T = hidden_states[:, 0, -1, :]
# Project through the LM head and normalize over the vocabulary with
# log_softmax, so the values really are log probabilities as the variable
# name says (plain softmax would yield probabilities). Done in float32 for
# numerical stability.
# NOTE(review): the hidden states go straight into lm_head with no
# normalization applied in between here — confirm that is intended.
purple_log_probs = (
    model.language_model.lm_head(hidden_states_T)
    .float()
    .log_softmax(dim=1)[:, purple_token]
)

In [None]:
import matplotlib.pyplot as plt

# Matplotlib converts its inputs to NumPy arrays, which fails for tensors that
# live on the GPU, so move the values to the CPU before plotting.
fig, ax = plt.subplots()
ax.scatter(
    range(1, len(purple_log_probs) + 1), purple_log_probs.float().cpu()
)
ax.set_xlabel("hidden state index (1-based)")
ax.set_ylabel("purple_log_probs")
ax.set_title('Score of the "purple" token at the final position')
plt.show()

In [None]:
# Display the first layer module of the language model to inspect its
# submodules.
model.language_model.model.layers[0]