In [1]:
import gc
import torch
from torch import nn
from torch.profiler import record_function, profile, ProfilerActivity, schedule
from einops import rearrange
from lrn_sparseatt.masks import BooleanMask
from lrn_sparseatt.ops import sparse_matmul
from contextlib import contextmanager

In [2]:
N_RUNS = 50  # total profiler steps each benchmark loop is expected to execute

@contextmanager
def profile_and_report(name):
    """Profile the enclosed block on CPU and print a summary when it exits.

    The profiler schedule ignores the first ``skip_first + wait + warmup``
    steps and records the remaining ``N_RUNS - 20`` "active" steps, so the
    caller is expected to call ``prof.step()`` once per iteration for
    ``N_RUNS`` iterations.

    Args:
        name: Label passed to ``record_function`` for the profiled region.

    Yields:
        The active ``torch.profiler.profile`` object.

    On exit (including when the body raises) the key-averages table and the
    self CPU time averaged over the active steps are printed. If the profiler
    itself could not be created, nothing is printed and the original error
    propagates unchanged.
    """
    # Steps excluded from measurement; together they account for the "20"
    # that is subtracted from N_RUNS.
    skip_first, wait, warmup = 10, 5, 5
    active = N_RUNS - (skip_first + wait + warmup)

    gc.collect()
    sched = schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active)

    # Pre-bind so the `finally` clause never hits an unbound local (which
    # would mask the real exception) when `profile(...)` fails to start.
    prof = None
    try:
        with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
            with record_function(name):
                yield prof
    finally:
        if prof is not None:
            averages = prof.key_averages()
            print(averages.table(sort_by="cpu_time_total", row_limit=15))
            print(f"Total per run (us): {averages.self_cpu_time_total / active:.2f}")

# Initialize scenario

First, define the sizes of the inputs and of the model:

In [3]:
# Problem dimensions: sequence length, model width and number of heads.
seq_len, d_model, n_heads = 128, 64, 1
# Per-head feature width (the model width is split evenly across heads).
head_dim = d_model // n_heads

Get an attention mask with some sparsity level:

In [4]:
# Draw a random boolean attention mask and convert it to a tensor (displayed
# below as a 2-D bool tensor).
# NOTE(review): 0.7 is presumably the sparsity level (fraction of entries that
# are masked out) — confirm against BooleanMask.random.
mask = BooleanMask.random(seq_len, 0.7).as_tensor(seq_len)
mask

tensor([[False, False, False,  ..., False,  True, False],
        [ True,  True, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False,  True,  True],
        [ True, False, False,  ...,  True, False, False]])

Initialize random `q, k, v`:

In [5]:
# Seed torch's global RNG so q, k and v are identical on every
# "Restart Kernel -> Run All".
torch.manual_seed(0)

# One [n_heads, seq_len, head_dim] tensor each for queries, keys and values.
q = torch.randn((n_heads, seq_len, head_dim))
k = torch.randn((n_heads, seq_len, head_dim))
v = torch.randn((n_heads, seq_len, head_dim))

# Masked MHSA

In [6]:
def compute_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Dense scaled dot-product attention restricted by a boolean keep-mask.

    Args:
        q: Queries of shape ``[H, T, D]`` (extra leading batch dims are allowed).
        k: Keys with the same shape as ``q``.
        v: Values with the same shape as ``q``.
        attn_mask: Boolean tensor that is ``True`` where a query may attend to
            a key. Any shape broadcastable to the ``[H, T, T]`` score tensor is
            accepted, e.g. ``[T, T]`` (shared by all heads) or ``[H, T, T]``
            (one mask per head).

    Returns:
        The attention output, with the same shape as ``v``.

    Note:
        A query row whose mask is entirely ``False`` yields NaN in the output,
        because softmax is taken over a row of ``-inf`` only.
    """
    # Feature width is always the last axis, regardless of leading dims.
    head_dim = q.size(-1)

    # Scaled similarity scores, shape [H, T, T].
    attn_weights: torch.Tensor = torch.matmul(q, k.transpose(-2, -1)) / (
        head_dim**0.5
    )

    # Broadcasting aligns trailing dimensions, so a [T, T] mask is shared
    # across heads while an [H, T, T] mask is applied per head. No explicit
    # unsqueeze is needed (and it would add a spurious leading axis for
    # per-head masks).
    attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

In [7]:
# Profile the dense masked-attention baseline: run it N_RUNS times and advance
# the profiler schedule once per iteration with p.step().
with profile_and_report("dense_attention") as p:
    for _ in range(N_RUNS):
        res_dense = compute_attention(q, k, v, mask)
        p.step()

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
           ProfilerStep*        25.74%     775.288us       100.00%       3.012ms     100.388us            30  
       aten::masked_fill         4.90%     147.465us        32.21%     970.185us      32.340us            30  
      aten::masked_fill_        20.13%     606.369us        20.13%     606.369us      20.212us            30  
            aten::matmul         2.66%      80.088us        18.64%     561.454us       9.358us            60  
           aten::softmax         4.52%     136.178us        16.16%     486.740us      16.225us            30  
          aten::_softmax        11.64%     350.562us        11.64%     350.562us      11.685us            30  
 

  _warn_once(


# Sparse attention

First, let's get the indices of the attention weights to compute:

In [8]:
# Convert the boolean mask into an index tensor; the output below shows one
# pair of integers per row.
# NOTE(review): presumably each row is the (query, key) position of a True
# entry of `mask` — confirm against BooleanMask.to_indices.
indices = BooleanMask(mask).to_indices()
indices

tensor([[  0,   6],
        [  0,   7],
        [  0,  11],
        ...,
        [127, 114],
        [127, 116],
        [127, 125]])

In [12]:
# Split the index pairs into the query position (column 0) and the key
# position (column 1) of every attention weight to compute.
q_indices = indices[:, 0]
kv_indices = indices[:, 1]
# Profile a pure-PyTorch sparse q·k: gather the selected query/key rows along
# the sequence axis, multiply element-wise and sum over the feature axis.
# The 1/sqrt(head_dim) scaling is deliberately left commented out here.
# NOTE(review): `.flatten()` on the already 1-D index slices and
# `.view(n_heads, -1, head_dim)` on the index_select result look like no-ops
# that still get measured (aten::view shows up in the table) — consider
# dropping them.
with profile_and_report("indsel_qk") as p:
    for _ in range(N_RUNS):
        qs_indsel = q.index_select(1, q_indices.flatten()).view(n_heads, -1, head_dim)
        ks_indsel = k.index_select(1, kv_indices.flatten()).view(n_heads, -1, head_dim)
        attn_weights = (qs_indsel * ks_indsel).sum(dim=-1) #/ (head_dim**0.5)
        p.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        16.75%     578.709us       100.00%       3.455ms     115.168us            30  
    aten::index_select        35.52%       1.227ms        35.78%       1.236ms      20.601us            60  
             aten::mul        28.09%     970.680us        28.09%     970.680us      32.356us            30  
             aten::sum        17.81%     615.208us        18.43%     636.879us      21.229us            30  
            aten::view         0.81%      27.874us         0.81%      27.874us       0.465us            60  
           aten::fill_         0.50%      17.128us         0.50%      17.128us       0.571us            30  
           aten::em

In [13]:
# Profile the `sparse_matmul` op from lrn_sparseatt on the same inputs.
# q and k are passed without their leading head axis; squeeze(0) only removes
# that axis because n_heads == 1 in this notebook. The 1/sqrt(head_dim)
# scaling is left commented out, matching the index_select variant.
# NOTE(review): the table below shows aten::contiguous/clone/copy_ inside the
# op, so presumably one of the inputs is non-contiguous and copied on every
# call — confirm, and consider making it contiguous once outside the loop.
with profile_and_report("cpp") as p:
    for _ in range(N_RUNS):
        attn_weights_2 = sparse_matmul(q.squeeze(0), k.squeeze(0), indices) #/ (head_dim**0.5)
        p.step()

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   ProfilerStep*         7.55%     162.807us       100.00%       2.156ms      71.853us            30  
    extension_cpp::sparse_matmul        66.94%       1.443ms        91.09%       1.963ms      65.448us            30  
                aten::contiguous         0.33%       7.074us        23.89%     514.885us       5.721us            90  
                     aten::clone         0.42%       9.042us        23.56%     507.811us      16.927us            30  
                     aten::copy_        22.52%     485.440us        22.71%     489.439us      16.315us            30  
                   aten::squeeze         1.17%  

In [11]:
# Check that the index_select result and the sparse_matmul result agree;
# unsqueeze(0) adds a leading axis to attn_weights_2 so its shape matches
# attn_weights. Both sides are the unscaled q·k values.
torch.testing.assert_close(attn_weights, attn_weights_2.unsqueeze(0))