In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import seaborn as sns


class SimpleResModel(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.ln = nn.LayerNorm(dim, elementwise_affine=False, bias=False, eps=1e-12)

    def forward(self, x, layernorm=False, ln_scaling=False):
        B, S, E = x.size()
        hidden = x
        if layernorm:
            hidden = self.ln(hidden)
        else:
            hidden = hidden / (hidden.norm(dim=-1, keepdim=True) + 1e-12)
        attn = hidden @ hidden.transpose(-2, -1)
        if layernorm:
            if ln_scaling:
                attn = attn / E
            else:
                attn = attn / (E**0.5)

        mask = torch.zeros(S, S, device=x.device)
        mask = torch.masked_fill(
            mask, torch.ones_like(mask).triu(1).bool(), float("-inf")
        )
        attn = attn + mask.unsqueeze(0)

        _out = F.softmax(attn, dim=-1)  # softmax to get attention weights

        out = _out @ hidden  # B, S, E

        out += x
        return out, attn, _out


import math

# Enlarge text in all subsequent figures so titles, axis labels and tick
# labels stay legible in the wide multi-panel layouts below.
plt.rcParams.update(
    {
        "axes.titlesize": 20,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
    }
)


def plot_theory(
    alpha,
    row,
    layernorm,
    total_rows,
    ln_scaling=False,
    num_layers=4,
    batch_size=100000,
    seq_len=50,
    dim=64,
):
    """Plot batch-averaged attention scores of ``num_layers`` stacked layers.

    Draws one row of ``num_layers`` panels into the current pyplot figure on a
    ``total_rows x num_layers`` subplot grid. Every call that targets the same
    figure must pass the same ``total_rows`` and ``num_layers``.

    Inputs are unit-normalized sums of per-token noise (weight ``m``) and a
    noise vector shared by all tokens of a sequence (weight ``l``). For
    ``0 <= alpha < 1``, ``m = 1`` and ``l = sqrt(alpha / (1 - alpha))``, so
    before normalization a fraction ``alpha`` of each coordinate's variance is
    shared. ``alpha == 1`` uses only the shared component.

    Args:
        alpha: Shared-variance fraction in ``[0, 1]``.
        row: Zero-based row of the subplot grid to draw into.
        layernorm: Forwarded to ``SimpleResModel.forward``.
        total_rows: Number of rows in the figure's subplot grid.
        ln_scaling: Forwarded to ``SimpleResModel.forward``.
        num_layers: How many times the layer is applied (one panel each).
        batch_size: Number of random sequences averaged per panel. Each
            ``(batch_size, seq_len, dim)`` float32 tensor is ~1.3 GB at the
            defaults.
        seq_len: Tokens per sequence.
        dim: Token dimension.

    Returns:
        List with one ``(seq_len, seq_len)`` tensor per layer: the pre-softmax
        scores averaged over the batch (``-inf`` above the diagonal).
    """
    if alpha == 1:
        token_w, shared_w = 0.0, 1.0
    else:
        token_w, shared_w = 1.0, math.sqrt(alpha / (1 - alpha))

    layer = SimpleResModel(dim)
    out = F.normalize(
        token_w * torch.randn(batch_size, seq_len, dim)
        + shared_w * torch.randn(batch_size, 1, dim),
        dim=-1,
    )

    # LayerNorm'd tokens have squared norm ~dim, so sqrt(d)-scaled scores peak
    # near sqrt(dim); cosine scores and d-scaled scores peak near 1.
    vmax = dim**0.5 if (layernorm and not ln_scaling) else 1
    suffix = ""
    if layernorm:
        suffix = " (Scaling=$d$)" if ln_scaling else r" (Scaling=$\sqrt{d}$)"

    # Keep only the small per-layer mean maps: retaining the full per-layer
    # outputs/scores would hold several GB alive at the default sizes.
    mean_scores = []
    for i in range(num_layers):
        out, attn, _ = layer(out, layernorm=layernorm, ln_scaling=ln_scaling)
        avg = attn.mean(0).detach().cpu()
        mean_scores.append(avg)

        plt.subplot(total_rows, num_layers, row * num_layers + i + 1)
        plt.imshow(avg.numpy(), cmap="rocket")  # "rocket" is registered by seaborn
        plt.clim(vmin=0, vmax=vmax)
        plt.colorbar()
        plt.title(rf"Layer {i+1}, $\alpha$={alpha}" + suffix)
    return mean_scores


# Two-row L2-normalization figure: alpha = 0 on top, alpha = 0.2 below.
plt.figure(figsize=(20, 10))
plot_theory(alpha=0, row=0, layernorm=False, total_rows=2)
plot_theory(alpha=0.2, row=1, layernorm=False, total_rows=2)
plt.subplots_adjust(bottom=0.1, top=0.9, left=0.05, right=0.95)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/l2norm.pdf")

In [None]:
# Single-row L2-normalization figure (alpha = 0 only); height 5 per row, as in
# the other figures. All plot_theory calls on one figure must share the same
# `total_rows`. Adding a second call with total_rows=2 here would place its
# axes on a different grid that collides with the first row's panels.
# NOTE(review): this saves to the same path as the two-row figure and
# overwrites it when cells run in order — confirm which version is intended.
plt.figure(figsize=(20, 5))
plot_theory(0, 0, False, 1)
plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/l2norm.pdf")

In [None]:
# Three-row alpha sweep under L2 normalization (alpha = 0.4, 0.6, 0.8).
plt.figure(figsize=(20, 15))
plot_theory(alpha=0.4, row=0, layernorm=False, total_rows=3)
plot_theory(alpha=0.6, row=1, layernorm=False, total_rows=3)
plot_theory(alpha=0.8, row=2, layernorm=False, total_rows=3)
plt.subplots_adjust(bottom=0.1, top=0.9, left=0.05, right=0.95)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/l2norm_alpha.pdf")

In [None]:
# Four-row LayerNorm figure: rows 0-1 scale scores by sqrt(d), rows 2-3 by d;
# each pair shows alpha = 0 then alpha = 0.2.
plt.figure(figsize=(20, 20))
plot_theory(alpha=0.0, row=0, layernorm=True, total_rows=4)
plot_theory(alpha=0.2, row=1, layernorm=True, total_rows=4)
plot_theory(alpha=0.0, row=2, layernorm=True, total_rows=4, ln_scaling=True)
plot_theory(alpha=0.2, row=3, layernorm=True, total_rows=4, ln_scaling=True)
plt.subplots_adjust(bottom=0.1, top=0.9, left=0.05, right=0.95)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/layernorm.pdf")