In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import seaborn # For rocket colormap

# Global font sizes shared by every figure produced below.
plt.rcParams.update(
    {
        "axes.titlesize": 20,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
    }
)


class SimpleResModel(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.ln = nn.LayerNorm(dim, elementwise_affine=False, bias=False, eps=1e-12)

    def forward(self, x, layernorm=False, residual=False):
        B, S, E = x.size()
        hidden = x
        if layernorm:
            hidden = self.ln(hidden)

        attn = hidden @ hidden.mT / (E**0.5)

        mask = torch.zeros(S, S, device=x.device)
        mask = torch.masked_fill(
            mask, torch.ones_like(mask).triu(1).bool(), float("-inf")
        )
        attn = attn + mask.unsqueeze(0)

        _out = F.softmax(attn, dim=-1)

        out = _out @ hidden
        if residual:
           out += x
        return out, attn, _out



def plot_theory(
    alpha, row, layernorm, total_rows, residual=False, num_layers=4
):
    """Draw one row of batch-averaged attention-score heatmaps.

    Random inputs are built as an independent per-token Gaussian plus a
    Gaussian vector shared by every token of a sequence, then pushed through
    ``num_layers`` applications of the same ``SimpleResModel``. After each
    application the scores, averaged over the batch, are drawn with
    ``imshow`` into the current matplotlib figure.

    Args:
        alpha: Fraction of each token's variance that comes from the shared
            per-sequence vector, in ``[0, 1]`` (0 = independent tokens,
            1 = all tokens identical).
        row: Zero-based row of the subplot grid to draw into.
        layernorm: Forwarded to the layer; also selects "LN"/"noLN" in titles.
        total_rows: Number of rows in the subplot grid.
        residual: Forwarded to the layer; also selects "Res"/"noRes" in titles.
        num_layers: Number of layer applications, i.e. columns in the grid.
    """
    B, S, E = 100000, 10, 64

    layer = SimpleResModel(E)

    # Weights of the shared and per-token components. With a per-token weight
    # of 1 and a shared weight of sqrt(alpha / (1 - alpha)) the shared part
    # carries a fraction `alpha` of the variance; the end points are handled
    # explicitly to avoid dividing by zero.
    if alpha == 1:
        shared_scale, token_scale = 1, 0
    elif alpha == 0:
        shared_scale, token_scale = 0, 1
    else:
        shared_scale, token_scale = math.sqrt(alpha / (1 - alpha)), 1

    out = token_scale * torch.randn(B, S, E) + shared_scale * torch.randn(B, 1, E)
    out = out / math.sqrt(E)

    # The title suffix does not depend on the layer index.
    suffix = ", LN" if layernorm else ", noLN"
    suffix += ", Res" if residual else ", noRes"

    for i in range(num_layers):
        # Only the running activations are kept; per-layer tensors are not
        # accumulated, so they can be freed as soon as they are plotted.
        out, attn, _ = layer(out, layernorm=layernorm, residual=residual)

        plt.subplot(total_rows, num_layers, row * num_layers + i + 1)
        mean_scores = attn.mean(0)
        # "rocket" is a seaborn colormap; it is available because seaborn is
        # imported at the top of the file.
        plt.imshow(mean_scores.detach().cpu().numpy(), cmap="rocket")
        plt.colorbar()
        plt.title(rf"Layer {i+1}, $\alpha$={alpha}" + suffix)

In [None]:
# Main figure: one row of per-layer heatmaps per configuration,
# given as (alpha, layernorm, residual).
main_configs = [
    (0.0, False, False),  # noln
    (0.0, True, False),  # ln
    (0.5, True, False),  # w/ alpha
    (0.5, True, True),  # w/ alpha + res
]

plt.figure(figsize=(20, 20))
for row_idx, (cfg_alpha, cfg_ln, cfg_res) in enumerate(main_configs):
    plot_theory(cfg_alpha, row_idx, cfg_ln, 4, residual=cfg_res)
plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/simul_main.pdf")

In [None]:
# Alpha sweep with LayerNorm and no residual: one row per alpha value.
plt.figure(figsize=(20, 20))
for row_idx, sweep_alpha in enumerate((0.2, 0.4, 0.6, 0.8)):
    plot_theory(sweep_alpha, row_idx, True, 4)
plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/simul_alpha.pdf")

In [None]:
# Alpha sweep with LayerNorm and residual connections: one row per alpha value.
plt.figure(figsize=(20, 25))
for row_idx, sweep_alpha in enumerate((0.0, 0.2, 0.4, 0.6, 0.8)):
    plot_theory(sweep_alpha, row_idx, True, 5, residual=True)
plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)
plt.tight_layout(pad=0.9)
plt.savefig("../figures/simul_alpha_residual.pdf")