In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)
from dataclasses import dataclass
import seaborn as sns


def normalize_lower_diagonals(attn: torch.Tensor):
    """Subtract from every diagonal of ``attn`` its own mean, in place.

    Despite the name, both the sub-diagonals and the super-diagonals are
    centred. The main diagonal is visited in each of the two passes.

    Args:
        attn: 2-D tensor; the index arithmetic below assumes it is square.

    Returns:
        The same tensor object that was passed in, modified in place.
    """
    size, _ = attn.shape

    # Pass 1 (sign = -1) walks offsets 0, -1, ..., -(size - 1);
    # pass 2 (sign = +1) walks offsets 0, 1, ..., size - 1.
    for sign in (-1, 1):
        for k in range(size):
            band = attn.diag(sign * k)
            centred = band - band.mean()

            near = torch.arange(size - k)  # coordinate that starts at 0
            far = torch.arange(k, size)  # coordinate that starts at k
            rows, cols = (far, near) if sign < 0 else (near, far)
            attn[rows, cols] = centred

    return attn


@dataclass
class Config:
    """Minimal stand-in config object handed to ``LlamaRotaryEmbedding``.

    ``SimpleResModel.__init__`` builds one of these (overriding ``head_dim``)
    and passes it to the rotary-embedding constructor. Which of these
    attributes the constructor actually reads depends on the installed
    ``transformers`` version -- TODO confirm against that version.
    """

    # Presumably the longest sequence the rotary embedding is set up for.
    max_position_embeddings: int = 100
    # Presumably the base of the rotary frequencies.
    rope_theta: int = 10000
    # Per-head width; SimpleResModel overrides this with its ``dim`` argument.
    head_dim: int = 64
    # NOTE(review): not overridden by SimpleResModel, so it stays 64 even when
    # ``dim`` differs -- verify this is harmless for the library version used.
    hidden_size: int = 64
    num_attention_heads: int = 1


class SimpleResModel(nn.Module):
    """Residual self-attention layer without learned projections.

    The input vectors are L2-normalised and rotated with rotary position
    embeddings; the rotated vectors serve as both queries and keys, and the
    (unrotated) normalised vectors serve as values. The raw input is added
    back as a residual. Scores are not rescaled before the softmax.
    """

    def __init__(self, dim=64):
        """Create the layer.

        Args:
            dim: width of the input vectors; used as the rotary ``head_dim``.
        """
        super().__init__()
        # NOTE: ``ln`` is never applied in forward(); it is kept so that the
        # module's parameters / state_dict stay exactly as before.
        self.ln = nn.LayerNorm(dim)
        config = Config(head_dim=dim)
        self.rope = LlamaRotaryEmbedding(config)

    def forward(self, x, causal=True):
        """Apply one attention + residual step.

        Args:
            x: 3-D tensor unpacked as (B, S, E).
            causal: when true, position ``i`` may only attend to ``j <= i``.

        Returns:
            Tuple ``(out, attn, probs)``:
            ``out``   -- attention output plus the residual ``x``, (B, S, E);
            ``attn``  -- pre-softmax scores, (B, S, S), with ``-inf`` above
                         the diagonal when ``causal`` is true;
            ``probs`` -- softmax of ``attn`` over the last axis.
        """
        _, S, _ = x.size()

        # Unit-normalise every vector; the epsilon guards all-zero rows.
        hidden = x / (x.norm(dim=-1, keepdim=True) + 1e-12)

        pos = torch.arange(S, device=x.device).unsqueeze(0)
        cos, sin = self.rope(hidden, pos)
        # A temporary singleton axis is inserted at dim 1 for the rotary
        # helper (presumably its head axis) and removed again right after.
        q, k = apply_rotary_pos_emb(hidden.unsqueeze(1), hidden.unsqueeze(1), cos, sin)
        q, k = q.squeeze(1), k.squeeze(1)

        attn = q @ k.transpose(-2, -1)

        if causal:
            # A boolean mask with masked_fill keeps ``attn`` in its own dtype.
            # Adding a float32 mask would promote half/bfloat16 scores to
            # float32 and make the later ``probs @ hidden`` a mixed-dtype
            # matmul. The mask is also only built when it is needed.
            future = torch.ones(S, S, device=x.device).triu(1).bool()
            attn = attn.masked_fill(future, float("-inf"))

        _out = F.softmax(attn, dim=-1)  # attention weights

        out = _out @ hidden  # B, S, E

        out = out + x
        return out, attn, _out


import math

# Global matplotlib font sizes: titles, axis labels, then x/y tick labels.
plt.rcParams.update(
    {
        "axes.titlesize": 20,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
    }
)

import torch


def plot_theory(
    alpha,
    causal,
    *,
    batch_size=100000,
    seq_len=50,
    dim=64,
    num_layers=4,
):
    """Plot batch-averaged attention scores over repeated ``SimpleResModel`` passes.

    Random inputs are built as ``m * token_noise + l * shared_noise`` (the
    shared noise is drawn once per sequence and broadcast over positions) and
    then L2-normalised. ``alpha`` selects the mix: ``alpha == 0`` uses only
    per-token noise, ``alpha == 1`` only shared noise, and otherwise
    ``m = 1`` and ``l = sqrt(alpha / (1 - alpha))``. The value of ``l`` is
    printed.

    The same layer is applied ``num_layers`` times. The top row of the figure
    shows the batch mean of the pre-softmax scores after each pass (colour
    limits fixed to [0, 1]); the bottom row shows the same matrix after
    ``normalize_lower_diagonals``.

    Args:
        alpha: mixing coefficient in [0, 1].
        causal: forwarded to the layer's ``causal`` argument.
        batch_size: number of random sequences averaged over.
        seq_len: sequence length.
        dim: vector width, also passed to ``SimpleResModel``.
        num_layers: number of passes, i.e. columns in the figure.

    Returns:
        Tuple ``(attn_outs, attn_scores, attn_probs)``: three lists with one
        entry per pass holding the layer's three return values.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
    """
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")

    if alpha == 1:
        shared_scale, token_scale = 1, 0
    elif alpha == 0:
        shared_scale, token_scale = 0, 1
    else:
        shared_scale, token_scale = math.sqrt(alpha / (1 - alpha)), 1
    print(shared_scale)

    plt.figure(figsize=(20, 8))
    layer = SimpleResModel(dim)

    attn_outs, attn_scores, attn_probs = [], [], []

    out = torch.nn.functional.normalize(
        token_scale * torch.randn(batch_size, seq_len, dim)
        + shared_scale * torch.randn(batch_size, 1, dim),
        dim=-1,
    )

    for i in range(num_layers):
        out, attn, probs = layer(out, causal=causal)
        attn_outs.append(out)
        attn_scores.append(attn)
        attn_probs.append(probs)

        mean_scores = attn.mean(0)

        # Top row: raw batch-averaged scores.
        plt.subplot(2, num_layers, i + 1)
        plt.imshow(mean_scores.detach().cpu().numpy(), cmap="rocket")
        plt.colorbar()
        plt.clim(0, 1)
        plt.title(f"Layer {i+1}")

        # Bottom row: every diagonal centred on its own mean. Work on a clone
        # because the helper mutates its argument and the array handed to
        # imshow above shares memory with ``mean_scores``.
        centred = normalize_lower_diagonals(mean_scores.clone())
        plt.subplot(2, num_layers, num_layers + i + 1)
        plt.imshow(centred.detach().cpu().numpy(), cmap="rocket")
        plt.colorbar()
        plt.title(f"Layer {i+1} (Normalized)")

    return attn_outs, attn_scores, attn_probs

In [None]:
# Causal attention, alpha = 0.0; figure is written under ../figures.
# A single tight_layout call with explicit padding is enough.
plot_theory(0.0, causal=True)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/rope_causal_0.pdf")

In [None]:
# Causal attention, alpha = 0.5; figure is written under ../figures.
plot_theory(0.5, causal=True)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/rope_causal_05.pdf")

In [None]:
# Non-causal attention, alpha = 0.0; figure is written under ../figures.
plot_theory(0.0, causal=False)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/rope_noncausal_0.pdf")