In [1]:
import gc
from contextlib import contextmanager

import torch
from einops import rearrange
from torch import nn
from torch.nested import nested_tensor_from_jagged
from torch.profiler import record_function, profile, ProfilerActivity, schedule

from lrn_sparseatt.masks import BooleanMask, boolean_mask_to_nested_indices

In [2]:
N_RUNS = 50

# Profiler schedule defaults: the first PROFILE_SKIP_FIRST + PROFILE_WAIT + PROFILE_WARMUP
# steps are discarded, the remaining steps of each benchmark loop are recorded.
PROFILE_SKIP_FIRST = 10
PROFILE_WAIT = 5
PROFILE_WARMUP = 5


@contextmanager
def profile_and_report(
    name,
    n_runs=N_RUNS,
    skip_first=PROFILE_SKIP_FIRST,
    wait=PROFILE_WAIT,
    warmup=PROFILE_WARMUP,
):
    """Profile the CPU activity of a benchmark loop and print a summary table.

    The caller must call ``prof.step()`` once per iteration, ``n_runs`` times in
    total. The first ``skip_first + wait + warmup`` steps are not recorded; the
    remaining ``n_runs - (skip_first + wait + warmup)`` steps are, and the
    reported per-run time is the total self CPU time divided by that count.

    Args:
        name: label passed to ``record_function`` around the whole block.
            NOTE(review): the range opens before the schedule starts recording,
            so it may not show up in the table.
        n_runs: number of ``prof.step()`` calls the caller will make.
        skip_first, wait, warmup: forwarded to ``torch.profiler.schedule``.

    Yields:
        The ``torch.profiler.profile`` object; call ``.step()`` on it.

    Raises:
        ValueError: if ``n_runs`` leaves no step to record.
    """
    skipped = skip_first + wait + warmup
    active = n_runs - skipped
    if active <= 0:
        raise ValueError(
            f"n_runs={n_runs} leaves no profiled steps after "
            f"skip_first+wait+warmup={skipped}"
        )
    gc.collect()
    # repeat=1: record exactly one cycle, so extra steps cannot start a new one.
    sched = schedule(
        skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=1
    )
    prof = None
    try:
        with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
            with record_function(name):
                yield prof
    finally:
        # prof stays None if the profiler itself failed to start; in that case
        # let the original error propagate instead of masking it here.
        if prof is not None:
            averages = prof.key_averages()
            print(averages.table(sort_by="cpu_time_total", row_limit=15))
            print(f"Total per run (us): {averages.self_cpu_time_total / active:.2f}")

# Initialize scenario

First, define the sizes of the inputs and of the model.

In [3]:
# Seed torch's global RNG so q, k, v (and the mask, if BooleanMask.random
# draws from torch's RNG — NOTE(review): confirm) are reproducible.
SEED = 0
torch.manual_seed(SEED)

seq_len = 128
d_model = 64
n_heads = 2
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
head_dim = d_model // n_heads

Get a random boolean attention mask with a given sparsity level (True marks the query/key pairs that are kept):

In [16]:
# 0.9 appears to be the fraction of masked-out entries: the nested indices
# below keep 12 of 128 keys per query. NOTE(review): confirm against
# BooleanMask.random.
MASK_SPARSITY = 0.9
mask = BooleanMask.random(seq_len, MASK_SPARSITY).as_tensor(seq_len)
mask

tensor([[False, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False, False, False,  ..., False,  True, False]])

Initialize random `q, k, v`:

In [17]:
# One random draw each for q, k and v (in that order), laid out as [H, T, D].
qkv_shape = (n_heads, seq_len, head_dim)
q, k, v = (torch.randn(qkv_shape) for _ in range(3))

# Masked MHSA

In [18]:
def compute_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Dense scaled dot-product attention with a boolean mask.

    Args:
        q, k, v: tensors of shape [..., T, D], e.g. [H, T, D]; leading dims
            must broadcast against each other.
        attn_mask: boolean tensor broadcastable to [..., T, T], e.g. [T, T]
            (shared by all heads) or [H, T, T] (per head). True marks the
            query/key pairs that may attend; False pairs get weight 0.

    Returns:
        Tensor of shape [..., T, D].

    Note:
        A query row whose mask is entirely False yields NaN, because the
        softmax is taken over a row of -inf values.
    """
    head_dim = q.size(-1)

    # [..., T, T] logits, scaled by 1/sqrt(D).
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / head_dim**0.5

    # Broadcasting applies a [T, T] mask to every head and an [H, T, T] mask
    # per head, with no extra leading dim added to the output.
    attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

def full_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> torch.Tensor:
    """Dense scaled dot-product attention without any mask.

    Args:
        q, k, v: tensors of shape [..., T, D], e.g. [H, T, D]; leading dims
            must broadcast against each other.

    Returns:
        Tensor of shape [..., T, D].
    """
    head_dim = q.size(-1)

    # [..., T, T] logits, scaled by 1/sqrt(D); every query attends to every key.
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / head_dim**0.5
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

In [19]:
# Benchmark the dense masked attention; keep the last result as the reference.
with profile_and_report("dense_attention") as prof:
    for _run in range(N_RUNS):
        res_dense = compute_attention(q, k, v, mask)
        prof.step()

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
           ProfilerStep*        13.68%     368.746us       100.00%       2.695ms      89.838us            30  
       aten::masked_fill         0.91%      24.422us        34.68%     934.542us      31.151us            30  
      aten::masked_fill_        28.75%     774.788us        28.75%     774.788us      25.826us            30  
            aten::matmul         2.54%      68.544us        24.26%     653.878us      10.898us            60  
               aten::bmm        12.72%     342.721us        17.09%     460.582us       7.676us            60  
           aten::softmax         0.25%       6.793us        16.87%     454.587us      15.153us            30  
 

# Sparse attention

First, let's get the indices of the attention weights to compute, i.e. for each query position, the key positions kept by the mask:

In [20]:
# Convert the boolean mask into jagged per-query key indices.
# NOTE(review): `mask` already comes from `.as_tensor(seq_len)`, so this
# BooleanMask round-trip looks redundant; it is kept as-is.
mask_tensor = BooleanMask(mask).as_tensor(seq_len)
indices = boolean_mask_to_nested_indices(mask_tensor)
indices

NestedTensor(size=(128, j2), offsets=tensor([   0,   12,   24,   36,   48,   60,   72,   84,   96,  108,  120,  132,
         144,  156,  168,  180,  192,  204,  216,  228,  240,  252,  264,  276,
         288,  300,  312,  324,  336,  348,  360,  372,  384,  396,  408,  420,
         432,  444,  456,  468,  480,  492,  504,  516,  528,  540,  552,  564,
         576,  588,  600,  612,  624,  636,  648,  660,  672,  684,  696,  708,
         720,  732,  744,  756,  768,  780,  792,  804,  816,  828,  840,  852,
         864,  876,  888,  900,  912,  924,  936,  948,  960,  972,  984,  996,
        1008, 1020, 1032, 1044, 1056, 1068, 1080, 1092, 1104, 1116, 1128, 1140,
        1152, 1164, 1176, 1188, 1200, 1212, 1224, 1236, 1248, 1260, 1272, 1284,
        1296, 1308, 1320, 1332, 1344, 1356, 1368, 1380, 1392, 1404, 1416, 1428,
        1440, 1452, 1464, 1476, 1488, 1500, 1512, 1524, 1536]), contiguous=True)

In [31]:
# Length of the packed index buffer = total number of kept (query, key) pairs.
indices.values().size()

torch.Size([1536])

In [47]:
# Gather the attended keys/values for every kept (query, key) pair.
# index_select along T already yields [H, nnz, D], so no reshape is needed.
flat_key_idx = indices.values()
ks = k.index_select(1, flat_key_idx)
vs = v.index_select(1, flat_key_idx)

In [49]:
# Split the packed nnz dim back into one jagged row per query: [T, H, j, D].
# (nested_tensor_from_jagged is imported in the top import cell.)
nk = nested_tensor_from_jagged(ks, indices.offsets(), jagged_dim=2)
nv = nested_tensor_from_jagged(vs, indices.offsets(), jagged_dim=2)
nv.shape

torch.Size([128, 2, j2, 32])

In [50]:
# Dense q layout [H, T, D], to compare with the nested nv shape above.
q.size()

torch.Size([2, 128, 32])

In [82]:
# Step-by-step sparse attention, checking intermediate shapes.
# Scale q by 1/sqrt(D) so the logits match the dense `compute_attention`.
q_scaled = q / q.size(-1) ** 0.5
qk = (q_scaled.transpose(0, 1).unsqueeze(2) * nk).sum(dim=-1)  # [T, H, j]
# Subtracting one global max is a valid softmax shift (same constant in every row).
num = (qk - qk.max()).exp()
den = num.sum(dim=-1, keepdim=True)
attn_weights = num / den
print(attn_weights.shape, nv.shape)
(attn_weights.unsqueeze(-1) * nv).sum(dim=2).shape

torch.Size([128, 2, j2]) torch.Size([128, 2, j2, 32])


torch.Size([128, 2, 32])

In [85]:
def test(q, k, v, indices):
    ks = k.index_select(1, indices.values()).view(n_heads, -1, head_dim)
    vs = v.index_select(1, indices.values()).view(n_heads, -1, head_dim)
    nk = nested_tensor_from_jagged(ks, indices.offsets(), jagged_dim=2)
    nv = nested_tensor_from_jagged(vs, indices.offsets(), jagged_dim=2)
    qk = (q.transpose(0, 1).unsqueeze(2) * nk).sum(dim=-1)
    num = (qk - qk.max()).exp()
    den = num.sum(dim=-1, keepdims=True)
    attn_weights = num/den
    return (attn_weights.unsqueeze(-1) * nv).sum(dim=2).transpose(0, 1)

In [86]:
# Benchmark the nested (sparse) implementation.
with profile_and_report("nested") as prof:
    for _run in range(N_RUNS):
        out = test(q, k, v, indices)
        prof.step()

# The sparse result must match the dense masked-attention reference.
torch.testing.assert_close(res_dense, out)

-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           PythonSubclass        22.35%      29.743ms       180.48%     240.174ms     363.900us           660  
                            ProfilerStep*         2.91%       3.872ms       100.00%     133.075ms       4.436ms            30  
                                aten::mul         1.43%       1.906ms        66.41%      88.377ms     589.181us           150  
         aten::_nested_from_padded_tensor         0.03%      35.297us        66.36%      88.307ms       1.472ms            60  
                              aten::copy_        46.01%      61.228ms        46.01%      61.228ms       

AssertionError: Tensor-likes are not close!

Mismatched elements: 8192 / 8192 (100.0%)
Greatest absolute difference: 3.691958427429199 at index (0, 102, 13) (up to 1e-05 allowed)
Greatest relative difference: 1476.8680419921875 at index (0, 1, 7) (up to 1.3e-06 allowed)