In [None]:
import torch.nn as nn
from typing import Optional
import transformers
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Enlarge title, axis-label, and tick-label fonts for every figure drawn below.
plt.rcParams.update(
    {
        "axes.titlesize": 20,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
    }
)


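# The eager attention forward is identical for Phi-3, Llama 3, and Qwen3,
# so this single replacement is patched into all three modeling modules below.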
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Eager attention that reports averaged pre-softmax logits.

    The attention output is computed the usual way (scaled ``Q @ K^T``,
    optional additive mask, float32 softmax, dropout, weighted sum of values).
    The second return value is NOT the softmax probabilities: it is the masked,
    scaled logits averaged over dims 0 and 1 (presumably batch and heads).

    Args:
        module: Attention module; ``num_key_value_groups`` and ``training``
            are read from it.
        query: Query states; assumed (batch, heads, q_len, head_dim) — TODO confirm.
        key: Key states, expanded along the head dim via ``repeat_kv``.
        value: Value states, expanded the same way as ``key``.
        attention_mask: Optional additive mask; its last dim is sliced to the
            key length before being added to the logits.
        scaling: Multiplier applied to ``Q @ K^T``.
        dropout: Dropout probability, only active when ``module.training``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(attn_output, mean_logits)``: the attention output with dims 1 and 2
        swapped and made contiguous, and the logits averaged over dims 0 and 1.
    """
    repeat_kv = transformers.models.llama.modeling_llama.repeat_kv
    n_groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, n_groups)
    expanded_values = repeat_kv(value, n_groups)

    logits = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Trim the mask to the key length before adding it to the logits.
        logits = logits + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(logits, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()

    # Return the raw (masked, scaled) logits averaged over dims 0 and 1.
    return context, logits.mean(dim=(0, 1))


# Swap the stock eager attention in each modeling module for the version
# defined above, so models loaded with attn_implementation="eager" use it.
for _modeling_module in (
    transformers.models.llama.modeling_llama,
    transformers.models.phi3.modeling_phi3,
    transformers.models.qwen3.modeling_qwen3,
):
    _modeling_module.eager_attention_forward = eager_attention_forward

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_from_disk
import numpy as np
import torch


def normalize_lower_diagonals(attn: torch.Tensor):
    """Mean-centre every lower diagonal of a square matrix and blank the rest.

    For each offset ``i`` in ``0..n-1`` the ``-i``-th diagonal (entries
    ``[i + k, k]``) has its own mean subtracted. Everything strictly above the
    main diagonal is replaced with NaN.

    The input tensor is left untouched; all work happens on a copy.

    Args:
        attn: 2-D tensor, expected to be square (``n x n``).

    Returns:
        A new tensor of the same shape with mean-centred lower diagonals and
        NaN in the strict upper triangle.
    """
    # Work on a copy: the indexed assignment below would otherwise overwrite
    # the caller's tensor as a hidden side effect.
    out = attn.clone()
    n, _ = out.shape
    for offset in range(n):
        diag = out.diag(-offset)
        rows = torch.arange(offset, n, device=out.device)
        cols = torch.arange(n - offset, device=out.device)
        out[rows, cols] = diag - diag.mean()

    # Strict upper triangle -> NaN.
    upper = torch.triu(torch.ones_like(out), diagonal=1).bool()
    return out.masked_fill(upper, float("nan"))


def plot_normalized_attention(
    model_name, normalize=False, num_samples=20, batch_size=20, seq_len=1024
):
    """Plot one averaged attention map per layer for ``model_name``.

    Loads the dataset from ``../data``, tokenizes the ``text`` field, runs the
    model with ``output_attentions=True`` and draws each layer's averaged map
    in a 4-column grid on a new 20x40 figure. The caller is expected to call
    ``plt.tight_layout()`` / ``plt.savefig(...)`` afterwards.

    Assumptions to confirm against the data and setup:
      * every text tokenizes to at least ``seq_len`` tokens, so the truncated
        rows stack into a rectangular array;
      * ``out.attentions[j]`` is a ``(seq_len, seq_len)`` tensor, i.e. the
        model's eager attention forward has been replaced by one that returns
        a batch/head-averaged map.

    Args:
        model_name: Hugging Face checkpoint name passed to ``from_pretrained``.
        normalize: If True, pass each map through ``normalize_lower_diagonals``
            before plotting.
        num_samples: Number of leading dataset rows to average over.
        batch_size: Rows per forward pass.
        seq_len: Token truncation length and side length of each map.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is not positive, or
            the dataset yields no rows.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")

    dset = load_from_disk("../data")
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    model.eval()

    dset = dset.map(
        lambda x: {"input_ids": tok.encode(x["text"])[:seq_len]}, num_proc=24
    )
    input_ids = (
        torch.from_numpy(np.asarray([i for i in dset["input_ids"]])).long().cuda()
    )
    input_ids = input_ids[:num_samples]
    total = input_ids.shape[0]
    if total == 0:
        raise ValueError("dataset produced no samples to average over")

    num_layers = len(model.model.layers)
    attns = [torch.zeros(seq_len, seq_len) for _ in range(num_layers)]
    with torch.no_grad():
        for start in range(0, total, batch_size):
            batch = input_ids[start : start + batch_size]
            out = model(input_ids=batch, output_attentions=True)
            # Weight each batch by its share of the samples so the result is a
            # true mean over the rows actually processed. The previous fixed
            # divisor of 50 did not match the single batch that was run.
            weight = batch.shape[0] / total
            for j in range(num_layers):
                attns[j] += out.attentions[j].cpu().float() * weight

    plt.figure(figsize=(20, 40))
    # Ceil division so a layer count that is not a multiple of 4 still fits.
    num_rows = (num_layers + 3) // 4
    for layer_idx, attn in enumerate(attns):
        plt.subplot(num_rows, 4, layer_idx + 1)
        if normalize:
            attn = normalize_lower_diagonals(attn.clone())

        # Colour limits come from the 1st/99th percentile of finite entries
        # only (NaN and +/-inf entries are excluded).
        finite = attn[torch.isfinite(attn)]
        vmin = torch.quantile(finite, 0.01).item()
        vmax = torch.quantile(finite, 0.99).item()
        plt.imshow(
            attn.numpy(),
            cmap="rocket",
            vmin=vmin,
            vmax=vmax,
        )
        plt.colorbar()
        plt.title(f"Layer {layer_idx+1}")


In [None]:
import gc

# (checkpoint, normalize flag, output path), in the order the figures are
# produced: all raw plots first, then all diagonal-normalized plots.
_figure_jobs = (
    ("meta-llama/Llama-3.1-8B", False, "../figures/llama3_rope_orig.pdf"),
    ("microsoft/phi-4", False, "../figures/phi4_rope_orig.pdf"),
    ("Qwen/Qwen3-8B", False, "../figures/qwen3_rope_orig.pdf"),
    ("meta-llama/Llama-3.1-8B", True, "../figures/llama3_rope_norm.pdf"),
    ("microsoft/phi-4", True, "../figures/phi4_rope_norm.pdf"),
    ("Qwen/Qwen3-8B", True, "../figures/qwen3_rope_norm.pdf"),
)

for _checkpoint, _normalize, _pdf_path in _figure_jobs:
    plot_normalized_attention(_checkpoint, _normalize)
    plt.tight_layout()
    plt.savefig(_pdf_path)
    # Release Python garbage and cached GPU memory before the next model loads.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def plot_all_models(models, aliases, num_samples=1000, batch_size=20, seq_len=1024):
    """Plot diagonal-normalized attention maps of the first four layers per model.

    Creates a new 20x12 figure with one row per model and four columns (layers
    1-4). For every model the dataset in ``../data`` is tokenized, the model is
    run with ``output_attentions=True``, and each layer's averaged map is
    passed through ``normalize_lower_diagonals`` before plotting.

    Assumptions to confirm against the data and setup:
      * every text tokenizes to at least ``seq_len`` tokens, so the truncated
        rows stack into a rectangular array;
      * ``out.attentions[j]`` is a ``(seq_len, seq_len)`` tensor, i.e. the
        models' eager attention forward has been replaced by one that returns
        a batch/head-averaged map.

    Args:
        models: Hugging Face checkpoint names, one subplot row each.
        aliases: Display names used in subplot titles, indexed like ``models``.
        num_samples: Number of leading dataset rows to average over.
        batch_size: Rows per forward pass.
        seq_len: Token truncation length and side length of each map.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is not positive, or
            the dataset yields no rows.
    """
    # Local import keeps this function independent of other notebook cells.
    import gc

    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")

    layers_shown = 4
    plt.figure(figsize=(20, 12))

    # The raw dataset does not depend on the model, so load it once.
    raw_dset = load_from_disk("../data")

    for idx, model_name in enumerate(models):
        tok = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",
        )
        model.eval()

        dset = raw_dset.map(
            lambda x: {"input_ids": tok.encode(x["text"])[:seq_len]}, num_proc=24
        )
        input_ids = (
            torch.from_numpy(np.asarray([i for i in dset["input_ids"]])).long().cuda()
        )
        input_ids = input_ids[:num_samples]
        total = input_ids.shape[0]
        if total == 0:
            raise ValueError("dataset produced no samples to average over")

        num_layers = len(model.model.layers)
        attns = [torch.zeros(seq_len, seq_len) for _ in range(num_layers)]
        with torch.no_grad():
            for start in range(0, total, batch_size):
                batch = input_ids[start : start + batch_size]
                out = model(input_ids=batch, output_attentions=True)
                # Weight each batch by its share of the samples so the result
                # is a true mean regardless of how many batches are run.
                weight = batch.shape[0] / total
                for j in range(num_layers):
                    attns[j] += out.attentions[j].cpu().float() * weight
                del out

        for layer_idx in range(layers_shown):
            plt.subplot(len(models), layers_shown, idx * layers_shown + layer_idx + 1)
            # Normalize once, on a clone, so the accumulated map is not
            # overwritten and the costly per-diagonal loop is not repeated.
            normalized = normalize_lower_diagonals(attns[layer_idx].clone())
            finite = normalized[torch.isfinite(normalized)]
            vmin = torch.quantile(finite, 0.01).item()
            vmax = torch.quantile(finite, 0.99).item()

            plt.imshow(normalized.numpy(), cmap="rocket", vmin=vmin, vmax=vmax)
            plt.colorbar()
            plt.title(f"{aliases[idx]}, Layer {layer_idx+1}")

        # Drop this model before the next one loads; otherwise both would be
        # referenced (and resident on the GPU) at the same time.
        del model, attns, input_ids
        gc.collect()
        torch.cuda.empty_cache()


# Pair each checkpoint with the short label used in its subplot titles.
_model_aliases = (
    ("meta-llama/Llama-3.1-8B", "Llama-3"),
    ("microsoft/phi-4", "Phi-4"),
    ("Qwen/Qwen3-8B", "Qwen3"),
)
models = [checkpoint for checkpoint, _ in _model_aliases]
alias = [label for _, label in _model_aliases]

# One combined figure: a row per model, first four layers each.
plot_all_models(models, alias)
plt.tight_layout()
plt.savefig("../figures/all_rope.pdf")