In [1]:
import gc
import torch
from torch import nn
from torch.profiler import record_function, profile, ProfilerActivity, schedule
from einops import rearrange
from lrn_sparseatt.masks import BooleanMask
from contextlib import contextmanager

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
N_RUNS = 50


@contextmanager
def profile_and_report(
    name,
    n_runs=None,
    skip_first=10,
    wait=5,
    warmup=5,
    row_limit=15,
):
    """Profile the CPU cost of a benchmark loop and print a summary table.

    The caller must run exactly ``n_runs`` iterations inside the ``with`` block
    and call ``prof.step()`` once per iteration. The first
    ``skip_first + wait + warmup`` steps are discarded; the remaining
    ``n_runs - skip_first - wait - warmup`` steps are recorded and averaged.

    Args:
        name: label of the ``record_function`` range wrapping the whole loop.
        n_runs: number of loop iterations the caller performs. Defaults to the
            notebook-level ``N_RUNS``, read at call time.
        skip_first, wait, warmup: profiler schedule phases, in steps.
        row_limit: number of rows printed in the summary table.

    Yields:
        The active ``torch.profiler.profile`` object.

    Raises:
        ValueError: if the schedule would leave no recorded (active) steps.
    """
    if n_runs is None:
        n_runs = N_RUNS
    # Derive the recorded-step count from the schedule itself, so the per-run
    # average cannot drift out of sync with the schedule parameters.
    n_active = n_runs - skip_first - wait - warmup
    if n_active <= 0:
        raise ValueError(
            f"n_runs={n_runs} leaves no active steps after "
            f"skip_first={skip_first}, wait={wait}, warmup={warmup}"
        )

    gc.collect()
    sched = schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=n_active)
    with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
        with record_function(name):
            yield prof

    # Report only after a successful run. A run that raised would be averaged
    # over steps that never happened.
    averages = prof.key_averages()
    print(averages.table(sort_by="cpu_time_total", row_limit=row_limit))
    print(f"Total per run (us): {averages.self_cpu_time_total / n_active:.2f}")

# Initialize scenario

First, define the input and model sizes:

In [3]:
# Seed torch's global RNG so q/k/v are reproducible under Restart & Run All.
# The mask is also reproducible if BooleanMask.random draws from torch's RNG
# (TODO confirm).
SEED = 0
torch.manual_seed(SEED)

seq_len = 128  # T: tokens per sequence
d_model = 64   # model width, split evenly across heads
n_heads = 2    # H
if d_model % n_heads != 0:
    # Floor division below would silently drop the remainder dimensions.
    raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
head_dim = d_model // n_heads  # D: per-head width

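Generate a random boolean attention mask with a given sparsity level (`True` marks a query/key pair that may attend; `False` pairs are masked out):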

In [4]:
# Judging by the displayed mask (mostly False), this is roughly the fraction of
# masked-out pairs. Confirm against BooleanMask.random.
mask_sparsity = 0.7
random_mask = BooleanMask.random(seq_len, mask_sparsity)
mask = random_mask.as_tensor(seq_len)
mask

tensor([[False, False, False,  ...,  True, False, False],
        [False, False, False,  ..., False, False, False],
        [False,  True, False,  ...,  True,  True, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False,  True, False,  ...,  True,  True, False],
        [ True, False, False,  ..., False, False,  True]])

Initialize random `q, k, v`:

In [5]:
# All three projections share the [H, T, D] layout (head_dim == d_model // n_heads).
qkv_shape = (n_heads, seq_len, head_dim)
# The generator draws q, then k, then v, the same RNG order as three separate calls.
q, k, v = (torch.randn(qkv_shape) for _ in range(3))

# Masked MHSA

In [6]:
def compute_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Dense masked scaled-dot-product attention (reference implementation).

    Args:
        q, k, v: tensors of shape [H, T, D].
        attn_mask: boolean mask of shape [T, T] (shared by every head) or
            [H, T, T] (per head). ``True`` marks a (query, key) pair that may
            attend; ``False`` pairs are excluded from the softmax.

    Returns:
        Tensor of shape [H, T, D].

    Note:
        A query row whose mask is entirely ``False`` softmaxes over all
        ``-inf`` and yields NaN for that row.
    """
    head_dim = q.size(-1)

    # Scaled dot-product scores, shape [H, T, T].
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (head_dim**0.5)

    # masked_fill broadcasts the mask against attn_weights. A [T, T] mask is
    # therefore applied to every head, and an [H, T, T] mask per head. An
    # explicit unsqueeze(0) would turn an [H, T, T] mask into [1, H, T, T] and
    # leak a spurious leading dimension into the output.
    attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

In [7]:
# Time the dense reference. The last result is kept for the correctness check at the end.
with profile_and_report("dense_attention") as prof:
    for _run in range(N_RUNS):
        res_dense = compute_attention(q, k, v, mask)
        prof.step()

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
           ProfilerStep*        12.82%     388.098us       100.00%       3.028ms     100.936us            30  
       aten::masked_fill         1.03%      31.336us        42.82%       1.297ms      43.221us            30  
      aten::masked_fill_        37.11%       1.124ms        37.11%       1.124ms      37.458us            30  
            aten::matmul         2.25%      68.129us        21.49%     650.610us      10.844us            60  
               aten::bmm        11.41%     345.630us        15.18%     459.757us       7.663us            60  
           aten::softmax         0.33%      10.001us        14.18%     429.353us      14.312us            30  
 

  _warn_once(


# Sparse attention

First, let's get the `(query, key)` indices of the attention weights we actually need to compute:

In [8]:
# Convert the dense boolean mask into the list of kept (query, key) pairs.
# Judging by the output below, this is an [M, 2] tensor in row-major order.
# Confirm against BooleanMask.to_indices.
boolean_mask = BooleanMask(mask)
indices = boolean_mask.to_indices()
indices

tensor([[  0,   5],
        [  0,  10],
        [  0,  13],
        ...,
        [127, 119],
        [127, 121],
        [127, 127]])

In [9]:
# Column 0 holds query positions; column 1 holds the key/value position each query attends to.
q_indices, kv_indices = indices[:, 0], indices[:, 1]

## Step 1: Collect q and k

The idea is to only compute the dot products $q_i^T k_j$ for non-masked indices $(i, j)$.

**Option 1**: Gather all the queries and keys

In [10]:
# gather needs an index with the same rank as its source, so the 1-D token
# indices are broadcast to [H, M, D] before gathering along the token axis.
# The broadcast is done inside the loop on purpose, so its cost is profiled too.
with profile_and_report("gather_qk") as prof:
    for _run in range(N_RUNS):
        q_gather_idx = q_indices.unsqueeze(0).unsqueeze(2).expand((n_heads, -1, head_dim))
        qs_gather = q.gather(1, q_gather_idx)
        kv_gather_idx = kv_indices.unsqueeze(0).unsqueeze(2).expand((n_heads, -1, head_dim))
        ks_gather = k.gather(1, kv_gather_idx)
        prof.step()

--------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------  ------------  ------------  ------------  ------------  ------------  ------------  
       ProfilerStep*         5.98%     321.308us       100.00%       5.370ms     178.987us            30  
        aten::gather        92.27%       4.954ms        92.30%       4.956ms      82.604us            60  
     aten::unsqueeze         0.98%      52.590us         1.19%      63.749us       0.531us           120  
        aten::expand         0.42%      22.582us         0.53%      28.293us       0.472us            60  
    aten::as_strided         0.31%      16.870us         0.31%      16.870us       0.094us           180  
            aten::to         0.03%       1.832us         0.03%       1.832us       0.031us            60  
--------------------  ------------  -

**Option 2**: Use `index_select` instead of `gather`:

In [11]:
# index_select takes the 1-D token indices directly and copies whole [D]
# vectors along dim 1, with no index broadcast needed.
with profile_and_report("indsel_qk") as prof:
    for _run in range(N_RUNS):
        qs_indsel = q.index_select(1, q_indices.flatten()).view(n_heads, -1, head_dim)
        ks_indsel = k.index_select(1, kv_indices.flatten()).view(n_heads, -1, head_dim)
        prof.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        12.91%     147.498us       100.00%       1.142ms      38.079us            30  
    aten::index_select        84.96%     970.540us        85.37%     975.204us      16.253us            60  
            aten::view         1.42%      16.249us         1.42%      16.249us       0.271us            60  
           aten::empty         0.41%       4.664us         0.41%       4.664us       0.078us            60  
         aten::flatten         0.30%       3.413us         0.30%       3.413us       0.057us            60  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total

## Step 2: compute the attention scores

i.e., only the dot products for the unmasked $(i, j)$ pairs.

**Option 1**: elementwise multiplication and sum

In [12]:
# sqrt(D) is a plain Python float, so computing it once adds no profiled op.
score_scale = head_dim**0.5
with profile_and_report("mul_qk_indsel") as prof:
    for _run in range(N_RUNS):
        # Elementwise product then sum over D gives one dot product per pair: [H, M].
        attn_weights = (qs_indsel * ks_indsel).sum(dim=-1) / score_scale
        prof.step()

-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
          ProfilerStep*        21.51%     326.207us       100.00%       1.517ms      50.561us            30  
              aten::mul        38.01%     576.578us        38.01%     576.578us      19.219us            30  
              aten::sum        33.82%     513.004us        35.19%     533.840us      17.795us            30  
              aten::div         3.20%      48.488us         5.29%      80.192us       2.673us            30  
               aten::to         0.27%       4.167us         2.09%      31.704us       1.057us            30  
         aten::_to_copy         0.74%      11.167us         1.82%      27.537us       0.918us            30  
          

In [13]:
# Same computation as above, on the gather-based copies of q and k.
score_scale = head_dim**0.5
with profile_and_report("mul_qk_gather") as prof:
    for _run in range(N_RUNS):
        attn_weights = (qs_gather * ks_gather).sum(dim=-1) / score_scale
        prof.step()

-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
          ProfilerStep*        16.73%     292.496us       100.00%       1.748ms      58.277us            30  
              aten::sum        40.19%     702.569us        41.48%     725.242us      24.175us            30  
              aten::mul        37.43%     654.372us        37.43%     654.372us      21.812us            30  
              aten::div         2.71%      47.410us         4.36%      76.202us       2.540us            30  
               aten::to         0.25%       4.376us         1.65%      28.792us       0.960us            30  
         aten::_to_copy         0.55%       9.669us         1.40%      24.416us       0.814us            30  
          

**Option 2**: use einsum

In [14]:
# Distinct profiler label: this cell previously reused "mul_qk_indsel", which
# made its report indistinguishable from the mul/sum benchmark.
with profile_and_report("einsum_qk_indsel") as p:
    for _ in range(N_RUNS):
        # "hmd,hmd->hm": per head, one dot product over D for each kept pair.
        attn_weights = torch.einsum("hmd,hmd->hm", qs_indsel, ks_indsel) / (head_dim**0.5)
        p.step()

-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
          ProfilerStep*        19.67%     316.414us       100.00%       1.608ms      53.614us            30  
           aten::einsum         7.54%     121.334us        74.52%       1.199ms      39.953us            30  
              aten::bmm        59.06%     949.998us        59.17%     951.749us      31.725us            30  
              aten::div         3.55%      57.167us         5.81%      93.418us       3.114us            30  
          aten::permute         3.85%      61.871us         4.74%      76.202us       0.508us           150  
             aten::view         2.60%      41.832us         2.60%      41.832us       0.349us           120  
          

## Step 3: softmax on weights

In [15]:
# Sparse softmax grouped by query row i: p_ij = exp(s_ij - c) / sum_j' exp(s_ij' - c).
# c is the global max over all scores; any constant shift leaves the result
# unchanged mathematically. NOTE(review): a global rather than per-row shift
# can underflow a row whose scores all sit far below the max, giving 0/0 = NaN.
#
# Keep the Step-2 scores under their own name so every profiled run normalizes
# the same input. Previously the loop fed its own output back in. Re-run
# Step 2 before re-running this cell.
attn_scores = attn_weights
with profile_and_report("softmax") as p:
    for _ in range(N_RUNS):
        num = (attn_scores - attn_scores.max()).exp()
        # Per-query denominators: scatter-add each pair's numerator into its query row...
        den = torch.index_add(torch.zeros((n_heads, seq_len)), 1, q_indices.flatten(), num)
        # ...then broadcast each row's sum back to that row's pairs.
        den = den.index_select(1, q_indices.flatten())
        attn_weights = num / den
        p.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        14.68%     266.957us       100.00%       1.819ms      60.631us            30  
       aten::index_add         1.29%      23.461us        55.71%       1.013ms      33.775us            30  
    aten::scatter_add_        53.62%     975.244us        53.65%     975.871us      32.529us            30  
             aten::exp        12.95%     235.464us        12.95%     235.464us       7.849us            30  
    aten::index_select         8.08%     147.012us         8.23%     149.677us       4.989us            30  
             aten::max         2.99%      54.411us         3.41%      62.044us       2.068us            30  
             aten::

## Step 4: multiply weights and values

In [16]:
with profile_and_report("finalize") as prof:
    for _run in range(N_RUNS):
        # Value vector for every kept (query, key) pair.
        vs_indsel = v.index_select(1, kv_indices.flatten()).view(n_heads, -1, head_dim)
        # Scale each gathered value by its attention probability.
        weighted_vs = attn_weights.unsqueeze(-1) * vs_indsel
        # Accumulate the weighted values into their query rows: [H, T, D].
        out = torch.zeros((n_heads, seq_len, head_dim))
        out.index_add_(1, q_indices.flatten(), weighted_vs)
        prof.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*         5.53%     360.924us       100.00%       6.527ms     217.553us            30  
      aten::index_add_        72.92%       4.759ms        73.25%       4.781ms     159.363us            30  
    aten::index_select        10.84%     707.590us        10.90%     711.382us      23.713us            30  
             aten::mul         9.31%     607.629us         9.31%     607.629us      20.254us            30  
           aten::zeros         0.19%      12.504us         0.47%      30.918us       1.031us            30  
          aten::select         0.23%      14.790us         0.34%      21.916us       0.365us            60  
       aten::unsque

## Put all together

In [17]:
def sparse_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Masked attention computed only for the listed (query, key) pairs.

    All sizes are taken from the inputs rather than from notebook globals, so
    the function works for any shape, dtype and device.

    Args:
        q: queries of shape [H, T_q, D].
        k: keys of shape [H, T_k, D].
        v: values of shape [H, T_k, D_v].
        indices: integer tensor of shape [M, 2]. Row ``(i, j)`` means query
            ``i`` attends to key ``j``.

    Returns:
        Tensor of shape [H, T_q, D_v] with the dtype/device of ``q``. Query
        rows with no listed pair are left at zero.
    """
    n_heads, seq_len, head_dim = q.shape
    out_shape = (n_heads, seq_len, v.size(-1))
    if indices.numel() == 0:
        # Nothing attends to anything; also avoids max() on an empty tensor.
        return q.new_zeros(out_shape)

    q_indices = indices[:, 0]
    kv_indices = indices[:, 1]

    # Step 1: query/key vector of every kept pair, shape [H, M, D].
    qs = q.index_select(1, q_indices)
    ks = k.index_select(1, kv_indices)

    # Step 2: one scaled dot product per pair, shape [H, M].
    scores = (qs * ks).sum(dim=-1) / (head_dim**0.5)

    # Step 3: softmax grouped by query row. The global-max shift is a constant,
    # so it cancels out. NOTE(review): a per-row shift would be more robust
    # against underflow for widely spread scores.
    num = (scores - scores.max()).exp()
    den = q.new_zeros((n_heads, seq_len)).index_add(1, q_indices, num)
    probs = num / den.index_select(1, q_indices)

    # Step 4: weight the values and accumulate them into their query rows.
    weighted_vs = probs.unsqueeze(-1) * v.index_select(1, kv_indices)
    out = q.new_zeros(out_shape)
    out.index_add_(1, q_indices, weighted_vs)
    return out

In [18]:
# Time the full sparse pipeline end to end. The last result is kept for the comparison below.
with profile_and_report("sparse_attn") as prof:
    for _run in range(N_RUNS):
        res_sparse = sparse_attention(q, k, v, indices)
        prof.step()

-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
          ProfilerStep*         9.65%       1.397ms       100.00%      14.482ms     482.736us            30  
       aten::index_add_        31.51%       4.563ms        31.67%       4.586ms     152.867us            30  
     aten::index_select        24.86%       3.600ms        24.93%       3.611ms      30.088us           120  
              aten::mul        14.77%       2.138ms        14.77%       2.138ms      35.639us            60  
        aten::index_add         0.18%      25.835us         7.17%       1.039ms      34.635us            30  
     aten::scatter_add_         6.87%     994.261us         6.87%     995.385us      33.180us            30  
          

In [19]:
# The dense implementation is the reference. Pass it as `expected` so the
# relative tolerance is measured against it and mismatch reports read correctly.
torch.testing.assert_close(res_sparse, res_dense)

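## Effect of the `index_select` axis

Does the memory layout matter? The same token selection is timed along dim 1 of `[H, T, D]`, dim 0 of `[T, H, D]`, and dim 2 of `[H, D, T]`.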

In [36]:
# Baseline layout [H, T, D]: select along the token axis (dim 1).
with profile_and_report("indsel_qk_1") as prof:
    for _run in range(N_RUNS):
        qs_indsel_1 = q.index_select(1, q_indices.flatten()).view(n_heads, -1, head_dim)
        ks_indsel_1 = k.index_select(1, kv_indices.flatten()).view(n_heads, -1, head_dim)
        prof.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        14.30%     176.954us       100.00%       1.238ms      41.250us            30  
    aten::index_select        83.39%       1.032ms        83.87%       1.038ms      17.299us            60  
            aten::view         1.53%      18.920us         1.53%      18.920us       0.315us            60  
           aten::empty         0.48%       5.953us         0.48%       5.953us       0.099us            60  
         aten::flatten         0.30%       3.708us         0.30%       3.708us       0.062us            60  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total

In [37]:
# Token-major layout [T, H, D], materialized outside the timed region. Each
# index along dim 0 selects one contiguous H*D block.
qt, kt = (x.transpose(0, 1).contiguous() for x in (q, k))

with profile_and_report("indsel_qk_0") as prof:
    for _run in range(N_RUNS):
        qs_indsel_0 = qt.index_select(0, q_indices.flatten()).view(-1, n_heads, head_dim)
        ks_indsel_0 = kt.index_select(0, kv_indices.flatten()).view(-1, n_heads, head_dim)
        prof.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        16.34%     279.425us       100.00%       1.710ms      56.987us            30  
    aten::index_select        79.33%       1.356ms        81.92%       1.400ms      23.341us            60  
          aten::select         1.53%      26.172us         2.15%      36.835us       0.307us           120  
            aten::view         1.48%      25.372us         1.48%      25.372us       0.423us            60  
      aten::as_strided         0.62%      10.663us         0.62%      10.663us       0.089us           120  
           aten::empty         0.43%       7.329us         0.43%       7.329us       0.122us            60  
         aten::flat

In [38]:
# Feature-major layout [H, D, T], materialized outside the timed region.
# Selecting along the last dim copies one scalar per (head, feature, index)
# instead of whole contiguous vectors.
# Results use their own names (axis 2) so they no longer overwrite the
# axis-0 layout (`qt`/`kt`) and outputs (`qs_indsel_0`/`ks_indsel_0`).
q_dt = q.transpose(1, 2).contiguous()
k_dt = k.transpose(1, 2).contiguous()

with profile_and_report("indsel_qk_2") as p:
    for _ in range(N_RUNS):
        qs_indsel_2 = q_dt.index_select(2, q_indices.flatten()).view(n_heads, head_dim, -1)
        ks_indsel_2 = k_dt.index_select(2, kv_indices.flatten()).view(n_heads, head_dim, -1)
        p.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*         8.30%     378.663us       100.00%       4.562ms     152.067us            30  
    aten::index_select        89.88%       4.100ms        90.87%       4.145ms      69.091us            60  
          aten::select         0.57%      25.832us         0.81%      37.087us       0.309us           120  
            aten::view         0.72%      33.002us         0.72%      33.002us       0.550us            60  
      aten::as_strided         0.25%      11.255us         0.25%      11.255us       0.094us           120  
           aten::empty         0.18%       8.086us         0.18%       8.086us       0.135us            60  
         aten::flat