In [1]:
import gc
import torch
from torch import nn
from torch.profiler import record_function, profile, ProfilerActivity, schedule
from einops import rearrange
from lrn_sparseatt.masks import BooleanMask, boolean_mask_to_jagged_indices
from lrn_sparseatt.ops import sparse_matmul, sparse_matmul_vo
from contextlib import contextmanager

In [2]:
N_RUNS = 50

@contextmanager
def profile_and_report(name):
    gc.collect()
    sched = schedule(skip_first=10, wait=5, warmup=5, active=N_RUNS - 20)
    try:
        with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
            with record_function(name):
                yield prof
    finally:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
        print(f"Total per run (us): {prof.key_averages().self_cpu_time_total / (N_RUNS-20):.2f}")

# Initialize scenario

First define the sizes of the inputs and of the model

In [3]:
seq_len = 128
d_model = 64
n_heads = 1
head_dim = d_model // n_heads

Get an attention mask with some sparsity level:

In [4]:
mask = BooleanMask.random(seq_len, 0.7).as_tensor(seq_len)
mask

tensor([[False, False,  True,  ...,  True,  True, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False,  True, False,  ..., False,  True,  True],
        [False, False, False,  ..., False,  True, False],
        [False, False, False,  ..., False,  True, False]])

Initialize random `q, k, v`:

In [5]:
q = torch.randn((n_heads, seq_len, d_model // n_heads))
k = torch.randn((n_heads, seq_len, d_model // n_heads))
v = torch.randn((n_heads, seq_len, d_model // n_heads))

# Masked MHSA

In [6]:
def compute_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    # q, k, v have shape [H, T, D]
    # attn_mask has shape [T, T] or [T, T]
    head_dim = q.size(2)

    attn_weights: torch.Tensor = torch.matmul(q, k.transpose(-2, -1)) / (
        head_dim**0.5
    )
    # attn_weights has shape [H, T, T]

    # attn_mask shape should be broadcastable to attn_weights shape
    attn_mask = attn_mask.unsqueeze(0)  # shape [1, T, T]
    attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

In [7]:
with profile_and_report("dense_attention") as p:
    for _ in range(N_RUNS):
        res_dense = compute_attention(q, k, v, mask)
        p.step()

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
           ProfilerStep*        18.30%     412.687us       100.00%       2.256ms      75.184us            30  
       aten::masked_fill         0.83%      18.666us        34.88%     786.668us      26.222us            30  
      aten::masked_fill_        30.65%     691.409us        30.65%     691.409us      23.047us            30  
            aten::matmul         3.19%      72.047us        22.33%     503.652us       8.394us            60  
           aten::softmax         0.51%      11.502us        15.73%     354.850us      11.828us            30  
          aten::_softmax        15.22%     343.348us        15.22%     343.348us      11.445us            30  
 

  _warn_once(


# Sparse attention

First, lets get the indices of attention weights to compute:

In [8]:
indices = BooleanMask(mask).to_indices()
indices

tensor([[  0,   2],
        [  0,   3],
        [  0,   7],
        ...,
        [127, 123],
        [127, 124],
        [127, 126]])

In [9]:
q_indices = indices[:, 0]
kv_indices = indices[:, 1]
with profile_and_report("indsel_qk") as p:
    for _ in range(N_RUNS):
        qs_indsel = q.index_select(1, q_indices.flatten()).view(n_heads, -1, head_dim)
        ks_indsel = k.index_select(1, kv_indices.flatten()).view(n_heads, -1, head_dim)
        attn_weights = (qs_indsel * ks_indsel).sum(dim=-1) #/ (head_dim**0.5)
        p.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        14.04%     324.161us       100.00%       2.309ms      76.971us            30  
    aten::index_select        44.79%       1.034ms        45.10%       1.041ms      17.355us            60  
             aten::mul        22.32%     515.473us        22.32%     515.473us      17.182us            30  
             aten::sum        16.83%     388.585us        17.45%     402.876us      13.429us            30  
            aten::view         0.91%      21.006us         0.91%      21.006us       0.350us            60  
           aten::fill_         0.50%      11.582us         0.50%      11.582us       0.386us            30  
           aten::em

In [16]:
with profile_and_report("cpp") as p:
    for _ in range(N_RUNS):
        attn_weights_2 = sparse_matmul(q.squeeze(0), k.squeeze(0), indices) #/ (head_dim**0.5)
        p.step()

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   ProfilerStep*         7.46%     171.589us       100.00%       2.299ms      76.646us            30  
    extension_cpp::sparse_matmul        67.30%       1.547ms        91.32%       2.100ms      69.993us            30  
                aten::contiguous         0.28%       6.508us        23.77%     546.641us       6.074us            90  
                     aten::clone         0.38%       8.624us        23.49%     540.133us      18.004us            30  
                     aten::copy_        22.50%     517.305us        22.70%     521.885us      17.396us            30  
                   aten::squeeze         1.03%  

In [11]:
torch.testing.assert_close(attn_weights, attn_weights_2.unsqueeze(0))

# Sparse matmul with values and offsets

In [12]:
values, offsets = boolean_mask_to_jagged_indices(mask)

In [18]:
with profile_and_report("cpp_2") as p:
    for _ in range(N_RUNS):
        attn_weights_3 = sparse_matmul_vo(q.squeeze(0), k.squeeze(0), values, offsets) #/ (head_dim**0.5)
        p.step()

-----------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      ProfilerStep*         9.13%     156.958us       100.00%       1.719ms      57.310us            30  
    extension_cpp::sparse_matmul_vo        88.74%       1.526ms        89.25%       1.534ms      51.148us            30  
                      aten::squeeze         1.37%      23.624us         1.62%      27.920us       0.465us            60  
                        aten::empty         0.39%       6.788us         0.39%       6.788us       0.226us            30  
                   aten::as_strided         0.25%       4.296us         0.25%       4.296us       0.072us            60  
                   aten:

tensor([  38,   76,  114,  152,  190,  228,  266,  304,  342,  380,  418,  456,
         494,  532,  570,  608,  646,  684,  722,  760,  798,  836,  874,  912,
         950,  988, 1026, 1064, 1102, 1140, 1178, 1216, 1254, 1292, 1330, 1368,
        1406, 1444, 1482, 1520, 1558, 1596, 1634, 1672, 1710, 1748, 1786, 1824,
        1862, 1900, 1938, 1976, 2014, 2052, 2090, 2128, 2166, 2204, 2242, 2280,
        2318, 2356, 2394, 2432, 2470, 2508, 2546, 2584, 2622, 2660, 2698, 2736,
        2774, 2812, 2850, 2888, 2926, 2964, 3002, 3040, 3078, 3116, 3154, 3192,
        3230, 3268, 3306, 3344, 3382, 3420, 3458, 3496, 3534, 3572, 3610, 3648,
        3686, 3724, 3762, 3800, 3838, 3876, 3914, 3952, 3990, 4028, 4066, 4104,
        4142, 4180, 4218, 4256, 4294, 4332, 4370, 4408, 4446, 4484, 4522, 4560,
        4598, 4636, 4674, 4712, 4750, 4788, 4826, 4864])