In [1]:
import gc
import torch
from torch import nn
from torch.profiler import record_function, profile, ProfilerActivity, schedule
from einops import rearrange
from lrn_sparseatt.masks import BooleanMask, boolean_mask_to_jagged_indices
from lrn_sparseatt.ops import sparse_attn
from contextlib import contextmanager

In [2]:
N_RUNS = 50

@contextmanager
def profile_and_report(name):
    gc.collect()
    sched = schedule(skip_first=10, wait=5, warmup=5, active=N_RUNS - 20)
    try:
        with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
            with record_function(name):
                yield prof
    finally:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
        print(f"Total per run (us): {prof.key_averages().self_cpu_time_total / (N_RUNS-20):.2f}")

# Initialize scenario

First define the sizes of the inputs and of the model

In [3]:
seq_len = 128
d_model = 64
n_heads = 1
head_dim = d_model // n_heads

Get an attention mask with some sparsity level:

In [4]:
mask = BooleanMask.random(seq_len, 0.7).as_tensor(seq_len)
mask

tensor([[ True, False,  True,  ..., False, False,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ...,  True, False,  True],
        ...,
        [ True, False,  True,  ...,  True, False, False],
        [False, False,  True,  ..., False, False, False],
        [False,  True, False,  ...,  True,  True, False]])

Initialize random `q, k, v`:

In [5]:
q = torch.randn((n_heads, seq_len, d_model // n_heads))
k = torch.randn((n_heads, seq_len, d_model // n_heads))
v = torch.randn((n_heads, seq_len, d_model // n_heads))

# Masked MHSA

In [6]:
def compute_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    # q, k, v have shape [H, T, D]
    # attn_mask has shape [T, T] or [T, T]
    head_dim = q.size(2)

    attn_weights: torch.Tensor = torch.matmul(q, k.transpose(-2, -1)) / (
        head_dim**0.5
    )
    # attn_weights has shape [H, T, T]

    # attn_mask shape should be broadcastable to attn_weights shape
    attn_mask = attn_mask.unsqueeze(0)  # shape [1, T, T]
    attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

In [7]:
with profile_and_report("dense_attention") as p:
    for _ in range(N_RUNS):
        res_dense = compute_attention(q, k, v, mask)
        p.step()

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
           ProfilerStep*        18.50%     383.222us       100.00%       2.071ms      69.047us            30  
       aten::masked_fill         0.87%      18.004us        33.81%     700.441us      23.348us            30  
      aten::masked_fill_        29.28%     606.479us        29.28%     606.479us      20.216us            30  
            aten::matmul         3.24%      67.095us        23.87%     494.472us       8.241us            60  
           aten::softmax         0.42%       8.625us        14.90%     308.552us      10.285us            30  
               aten::bmm        11.85%     245.466us        14.83%     307.217us       5.120us            60  
 

  _warn_once(


# Sparse attention

First, lets get the indices of attention weights to compute:

In [8]:
indices = BooleanMask(mask).to_indices()
values, offsets = boolean_mask_to_jagged_indices(mask)

# Sparse matmul with values and offsets

In [9]:
with profile_and_report("cpp_attn") as p:
    for _ in range(N_RUNS):
        res_sparse = sparse_attn(q.squeeze(0), k.squeeze(0), v.squeeze(0), values, offsets, head_dim**0.5)
        p.step()

------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 ProfilerStep*         5.62%     188.126us       100.00%       3.350ms     111.676us            30  
    extension_cpp::sparse_attn        91.92%       3.080ms        93.18%       3.122ms     104.055us            30  
                 aten::squeeze         1.04%      34.705us         1.21%      40.498us       0.450us            90  
               aten::new_zeros         0.28%       9.288us         0.63%      21.205us       0.707us            30  
                   aten::empty         0.56%      18.876us         0.56%      18.876us       0.210us            90  
                   aten::zero_         0.22%       7.333us      

In [10]:
torch.testing.assert_close(res_dense, res_sparse.unsqueeze(0))