In [None]:
import math
import sys
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from datasets import load_from_disk


from transformers import LlamaForCausalLM, LlamaConfig
from transformers.models.llama.modeling_llama import repeat_kv

import transformers.models.llama.modeling_llama as modeling_llama

def noop_apply_rotary_pos_emb(q, k, *args, **kwargs):
    """Stand-in for a rotary-embedding helper that performs no rotation.

    Any positional or keyword arguments after ``q`` and ``k`` are accepted
    and ignored, so the function can be called with extra inputs without
    raising.

    Returns:
        The tuple ``(q, k)`` with both objects passed through untouched.
    """
    passthrough = (q, k)
    return passthrough


# Swap the module-level rotary helper for the no-op defined above so queries
# and keys pass through unrotated.
# NOTE(review): this only takes effect if the attention forward resolves
# `apply_rotary_pos_emb` from `modeling_llama`'s globals at call time — confirm
# for the installed transformers version.
modeling_llama.apply_rotary_pos_emb = noop_apply_rotary_pos_emb


# NOTE(review): "{NOPE_MODEL_PATH}" is a plain (non-f) string, so it is passed
# literally; it looks like an unfilled template placeholder — replace it with
# the real checkpoint path before running.
nope_model = LlamaForCausalLM.from_pretrained(
    "{NOPE_MODEL_PATH}",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)


# Dataset previously saved with `datasets`; only its "input_ids" column is used.
dset = load_from_disk("../data")

# Keep the first 50 token ids of every example and stack them into a single
# batch on the model's device. `torch.stack` needs equal lengths, so every
# example must contain at least 50 ids; the whole dataset becomes one batch.
input_ids = torch.stack([torch.LongTensor(i[:50]) for i in dset["input_ids"]]).to(
    nope_model.device
) 

# Single forward pass without gradient tracking, asking for the hidden states.
with torch.no_grad():
    nope_out = nope_model(
        input_ids=input_ids,
        output_hidden_states=True,
    )

# Later cells treat nope_hiddens[i] as the input of decoder layer i and
# nope_hiddens[i + 1] as that layer's output.
nope_hiddens = nope_out.hidden_states

In [None]:
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

# Font sizes shared by every figure drawn in this notebook.
plt.rcParams.update(
    {
        "axes.titlesize": 20,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
    }
)


def plot_matrix(mat, title):
    """Draw the causal (on/below-diagonal) part of a matrix as a heatmap.

    The tensor is converted to a float32 NumPy array. A 3-D input is averaged
    over its first axis and a 4-D input over its first two axes, leaving a
    2-D (query, key) matrix. Every entry strictly above the main diagonal is
    replaced by NaN before the matrix is drawn on the current matplotlib axes
    together with a colorbar, the title and the axis labels.

    Args:
        mat: torch tensor with 2, 3 or 4 dimensions; the last two are the
            query and key axes. The tensor itself is never modified.
        title: Text placed above the heatmap.
    """
    arr = mat.detach().float().cpu().numpy()
    if arr.ndim == 3:
        arr = arr.mean(0)
    elif arr.ndim == 4:
        arr = arr.mean((0, 1))
    else:
        # For a float32 CPU tensor `.numpy()` shares memory with the tensor,
        # so without a copy the NaN masking below would overwrite the
        # caller's data. The reductions above already return fresh arrays.
        arr = arr.copy()

    # Pass the column count explicitly so a non-square matrix is masked
    # everywhere the key index exceeds the query index.
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    arr[np.triu_indices(n_rows, k=1, m=n_cols)] = np.nan

    # NOTE(review): the "rocket" colormap is presumably registered by the
    # seaborn import at the top of the notebook — confirm.
    plt.imshow(arr, cmap="rocket")
    plt.title(title)
    plt.colorbar()
    plt.xlabel("Key Position")
    plt.ylabel("Query Position")


def plot_activations(layer_idx):
    """Re-trace one decoder layer by hand and plot its intermediate activations.

    Starting from the cached hidden state ``nope_hiddens[layer_idx]`` the
    function recomputes, with the weights of ``nope_model``, the input
    layer norm, the causal attention (without any positional rotation), the
    output projection, both residual additions and the MLP. It then applies
    the following layer's input layer norm and query/key projections.
    Nine panels are drawn on a 3x3 grid through ``plot_matrix`` and the
    figure is written to ``../figures/nope_params.pdf``.

    Args:
        layer_idx: Zero-based index into ``nope_model.model.layers``.
            ``layer_idx + 1`` must also be a valid layer index because the
            last two panels use the following layer.

    Raises:
        AssertionError: If the hand-computed layer output differs from
            ``nope_hiddens[layer_idx + 1]`` by more than ``atol=1e-5``.
    """
    # NOTE(review): per-head width and key/value repeat factor are hard-coded;
    # presumably they match the loaded checkpoint — confirm.
    head_dim = 64
    kv_groups = 4

    x = nope_hiddens[layer_idx]
    next_hidden = nope_hiddens[layer_idx + 1]
    layer = nope_model.model.layers[layer_idx]
    next_layer = nope_model.model.layers[layer_idx + 1]
    B, S, E = x.shape

    def split_heads(projected):
        # (B, S, heads * head_dim) -> (B, heads, S, head_dim)
        return projected.reshape(B, S, -1, head_dim).transpose(1, 2)

    # --- attention sub-block -------------------------------------------------
    ln = layer.input_layernorm(x)
    q = split_heads(layer.self_attn.q_proj(ln))
    k = repeat_kv(split_heads(layer.self_attn.k_proj(ln)), kv_groups)
    v = repeat_kv(split_heads(layer.self_attn.v_proj(ln)), kv_groups)

    attn_score = q @ k.mT / math.sqrt(head_dim)
    # Additive causal mask: the most negative finite value of the score dtype
    # is added strictly above the diagonal so the softmax ignores future keys.
    attn_score = attn_score + torch.triu(
        torch.ones((S, S), device=attn_score.device)
        * torch.finfo(attn_score.dtype).min,
        1,
    )

    # Softmax in float32, then cast back to the projection dtype.
    attn_prob = attn_score.softmax(dim=-1, dtype=torch.float32).to(q.dtype)
    qkv = (attn_prob @ v).transpose(1, 2).reshape(B, S, -1)
    qkvo = layer.self_attn.o_proj(qkv)
    qkvox = qkvo + x  # first residual connection

    # --- feed-forward sub-block ----------------------------------------------
    ln_f = layer.post_attention_layernorm(qkvox)
    f = layer.mlp(ln_f)
    fx = f + qkvox  # second residual connection

    # --- query/key scores of the following layer (no mask added here) --------
    next_ln = next_layer.input_layernorm(fx)
    next_q = split_heads(next_layer.self_attn.q_proj(next_ln))
    next_k = repeat_kv(split_heads(next_layer.self_attn.k_proj(next_ln)), kv_groups)
    next_qk = next_q @ next_k.mT / math.sqrt(head_dim)

    # Sanity check: the manual recomputation must reproduce the cached output.
    assert torch.allclose(fx, next_hidden, atol=1e-5)

    cur = layer_idx + 1  # one-based number of this layer, used in the titles
    nxt = layer_idx + 2  # one-based number of the following layer

    # All LaTeX fragments are raw strings so that "\m", "\i" and "\s" are not
    # treated as (invalid) escape sequences; explicit "\n" pieces stay non-raw.
    panels = [
        (
            ln @ ln.mT,
            rf"(a) $\mathbf{{Y^{{({cur})}}Y^{{({cur})\intercal}}}}$"
            + "\n$(Y=LN(X))$",
        ),
        (
            attn_score,
            rf"(b) $\mathbf{{Q^{{({cur})}}K^{{({cur})\intercal}}/\sqrt{{d}}}}$"
            + "\n$(Q=W_QY, K=W_KY)$",
        ),
        (
            qkv @ qkv.mT,
            rf"(c) $\mathbf{{(A^{{({cur})}}V^{{({cur})}})(A^{{({cur})}}V^{{({cur})}})^\intercal}}$"
            + "\n"
            + r"$(A=Softmax(Causal(QK^\intercal)))$",
        ),
        (
            qkvo @ qkvo.mT,
            "(d)\n"
            + rf"$\mathbf{{(A^{{({cur})}}V^{{({cur})}}W_O^{{({cur})}})(A^{{({cur})}}V^{{({cur})}}W_O^{{({cur})}})^\intercal}}$",
        ),
        (
            qkvox @ qkvox.mT,
            rf"(e) $\mathbf{{O^{{({cur})}}O^{{({cur})\intercal}}}}$"
            + "\n$(O=AVW_O+X)$",
        ),
        (
            ln_f @ ln_f.mT,
            rf"(f) $\mathbf{{LN(O^{{({cur})}})LN(O^{{({cur})}})^\intercal}}$",
        ),
        (
            fx @ fx.mT,
            rf"(g) $\mathbf{{X^{{({cur})}}X^{{({cur})\intercal}}}}$"
            + "\n$(X=FFN(LN(O))+O)$",
        ),
        (
            next_ln @ next_ln.mT,
            rf"(h) $\mathbf{{Y^{{({nxt})}}Y^{{({nxt})\intercal}}}}$",
        ),
        (
            # Scores of the *following* layer: labelled (i) with that layer's
            # number rather than repeating panel (b)'s title.
            next_qk,
            rf"(i) $\mathbf{{Q^{{({nxt})}}K^{{({nxt})\intercal}}/\sqrt{{d}}}}$",
        ),
    ]

    plt.figure(figsize=(20, 20))
    for position, (matrix, panel_title) in enumerate(panels, start=1):
        plt.subplot(3, 3, position)
        plot_matrix(matrix, panel_title)

    plt.tight_layout()
    plt.savefig("../figures/nope_params.pdf")

# Draw and save the 3x3 diagnostic grid for the first decoder layer (index 0).
plot_activations(0)

In [None]:
# Tall canvas for the per-layer heatmaps placed on an 8x4 subplot grid below.
plt.figure(figsize=(20, 40))


def normalize_lower_diagonals(attn: torch.Tensor):
    """Centre the main diagonal and every diagonal below it around zero.

    For each offset ``0 .. rows - 1`` the mean of the diagonal lying that
    many steps below the main one is subtracted from its entries. Entries
    above the main diagonal are left as they are.

    Args:
        attn: 2-D tensor; it is modified in place.

    Returns:
        The same tensor object that was passed in.
    """
    num_rows, _ = attn.shape
    for offset in range(num_rows):
        band = attn.diag(-offset)  # copy of the diagonal `offset` rows down
        centred = band - band.mean()
        row_idx = torch.arange(offset, num_rows)
        col_idx = torch.arange(num_rows - offset)
        attn[row_idx, col_idx] = centred
    return attn


# NOTE(review): `rope` is not read anywhere in the visible cells.
rope = False
# One heatmap per decoder layer: pre-softmax attention scores averaged over
# batch and heads, with each lower diagonal mean-centred.
# NOTE(review): 22 (layer count), 64 (head width) and 4 (key/value repeat
# factor) are hard-coded; presumably they match the loaded checkpoint — confirm.
for l in range(22):
    # Slot l + 1 of an 8x4 grid (32 slots for the 22 panels).
    plt.subplot(8, 4, l + 1)

    x = nope_hiddens[l]
    # NOTE(review): `next_hidden` is assigned but not used inside this loop.
    next_hidden = nope_hiddens[l + 1]
    model = nope_model

    ln = model.model.layers[l].input_layernorm(x)
    B, S, E = x.shape

    # Project to queries/keys and split heads: (B, S, H * 64) -> (B, H, S, 64).
    q = model.model.layers[l].self_attn.q_proj(ln).reshape(B, S, -1, 64).transpose(1, 2)
    k = model.model.layers[l].self_attn.k_proj(ln).reshape(B, S, -1, 64).transpose(1, 2)


    # Repeat key heads so their count matches the query heads.
    k = repeat_kv(k, 4)

    attn_score = q @ k.mT / math.sqrt(64)
    # Additive causal mask: most negative finite value strictly above the diagonal.
    attn_score = attn_score + torch.triu(
        torch.ones((S, S), device=attn_score.device)
        * torch.finfo(attn_score.dtype).min,
        1,
    )

    # Average over batch and heads, then centre each lower diagonal.
    # `.detach()` is needed because this cell does not disable gradient tracking.
    # NOTE(review): unlike plot_matrix, the upper triangle is not set to NaN
    # here; it keeps whatever the masked mean evaluates to.
    plt.imshow(
        normalize_lower_diagonals(
            attn_score.mean(dim=(0, 1)).detach().cpu().float()
        ).numpy(),
        cmap="rocket",
    )
    plt.colorbar()
    plt.title(f"Layer {l+1}")

plt.tight_layout()
plt.savefig("../figures/nope_attns_all.pdf", bbox_inches="tight")