In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import torch
from matplotlib import pyplot as plt

# From Puzzles to Real Code

All of these puzzles push you to think about smart ways of using broadcasting rules. But it turns out that broadcasting is useful for more than solving "puzzles". To illustrate this, I'm going to show two code snippets taken from my recent research projects. For both problems, I used broadcasting to write an optimized version of standard PyTorch functions.

Since in real problems we usually have tensors with a batch dimension to leverage GPUs, these optimizations have to deal with the `batch` dimension. Note also that contrary to some problems, the ones covered here cannot be solved via reshaping since each sequence in the batch is an independent example.

# Aggregating word pieces

When we tokenize texts into word pieces (e.g., BPEs), we end up splitting not only one word from another, but also pieces of the word itself. For example:

```python
>>> wordpiece_tokenize("Welcome to the jungle")
["_Wel", "come", "_to", "_the", "_jungle"]
```

The symbol `_` represents the first piece of a tokenized word. Word piece tokenization has some advantages such as limiting the size of the vocabulary. However, consider a word labelling problem where each word is associated with a label, such as POS tagging or NER. A direct consequence of using word piece tokenization is that the number of "tokens" $m$ becomes larger than the actual number of words/labels $n$. Therefore, we need to **map** the tokenized pieces back to their actual words, such that $m = n$ again.

A simple way to solve this problem is following the strategy adopted by BERT: we only select the information from the first word piece. For example:

```python
>>> map_pieces_to_words(["_Wel", "come", "_to", "_the", "_jungle"])
["_Wel", "_to", "_the", "_jungle"]
```

And so `len(map_pieces_to_words(pieces)) == len(input.split())`.
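
To make the idea concrete, here is a minimal pure-Python sketch of the first-piece strategy on a list of strings. The name `map_pieces_to_words` comes from the example above; the `marker` argument and this particular implementation are my own assumptions for illustration, not part of any tokenizer library.

```python
from typing import List


def map_pieces_to_words(pieces: List[str], marker: str = "_") -> List[str]:
    """Keep only the first piece of each word, i.e., pieces starting with `marker`."""
    return [piece for piece in pieces if piece.startswith(marker)]


pieces = ["_Wel", "come", "_to", "_the", "_jungle"]
first_pieces = map_pieces_to_words(pieces)
print(first_pieces)  # ['_Wel', '_to', '_the', '_jungle']
```

This works on one sentence at a time; the rest of this section is about doing the same thing on a padded batch of tensors.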

## Setup

Consider a model that receives three input tensors:

- `input_ids`, a `torch.IntTensor` with a shape of `(batch_size, sequence_length)`
- `attention_mask`, a `torch.IntTensor` with a shape of `(batch_size, sequence_length)`
- `first_piece_mask`, a `torch.IntTensor` with a shape of `(batch_size, sequence_length)`

`input_ids` contains indices of word pieces in the vocabulary. `attention_mask` contains boolean values denoting valid and padded positions. `first_piece_mask` contains boolean values denoting whether a token is the first word piece of that word or not. All tensors are properly padded to the right. For example:

```python
input_texts = ["Welcome to the jungle", "Hello darkness my old friend"]
input_pieces = [wordpiece_tokenize(text) for text in input_texts]
```

Let's say that the output of this code would be:
```python
>>> input_pieces
[
    ["_Wel", "come", "_to", "_the", "_jungle"],
    ["_He", "llo", "_dark", "ness", "_my", "_old", "_fri", "end"]
]
```

Creating our inputs:
```python
>>> input_ids = pad([pieces_to_ids(pieces) for pieces in input_pieces], pad_value=-1)
>>> input_ids
[
    [10, 11, 12, 13, 14, -1, -1, -1], 
    [15, 16, 17, 18, 19, 20, 21, 22]
]

>>> attention_mask = pad([[True]*len(pieces) for pieces in input_pieces], pad_value=0)
>>> attention_mask
[
    [1, 1, 1, 1, 1, 0, 0, 0], 
    [1, 1, 1, 1, 1, 1, 1, 1]
]

>>> first_piece_mask = pad([[p.startswith('_') for p in pieces] for pieces in input_pieces], pad_value=0)
>>> first_piece_mask
[
    [1, 0, 1, 1, 1, 0, 0, 0], 
    [1, 0, 1, 0, 1, 1, 1, 0]
]
```
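
The helpers `pad` and `pieces_to_ids` used above are left undefined at this point. A minimal pure-Python sketch of what they could look like is shown below; `pad_lists`, the toy `vocab` (ids handed out from 10 in order of first appearance), and the two-argument `pieces_to_ids` are assumptions made only to reproduce the numbers of this example.

```python
from typing import Dict, List, Sequence


def pad_lists(rows: Sequence[Sequence[int]], pad_value: int = 0) -> List[List[int]]:
    """Right-pad every row with `pad_value` up to the length of the longest row."""
    max_len = max(len(row) for row in rows)
    return [list(row) + [pad_value] * (max_len - len(row)) for row in rows]


def pieces_to_ids(pieces: Sequence[str], vocab: Dict[str, int]) -> List[int]:
    """Look up each piece in a (hypothetical) vocabulary."""
    return [vocab[piece] for piece in pieces]


input_pieces = [
    ["_Wel", "come", "_to", "_the", "_jungle"],
    ["_He", "llo", "_dark", "ness", "_my", "_old", "_fri", "end"],
]

# toy vocabulary: ids start at 10 and follow the order of first appearance
vocab: Dict[str, int] = {}
for pieces in input_pieces:
    for piece in pieces:
        vocab.setdefault(piece, 10 + len(vocab))

input_ids = pad_lists([pieces_to_ids(p, vocab) for p in input_pieces], pad_value=-1)
attention_mask = pad_lists([[1] * len(p) for p in input_pieces], pad_value=0)
first_piece_mask = pad_lists(
    [[int(p.startswith("_")) for p in pieces] for pieces in input_pieces], pad_value=0
)
```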

Creating tensors:

```python
>>> input_ids = torch.as_tensor(input_ids)
>>> attention_mask = torch.as_tensor(attention_mask)
>>> first_piece_mask = torch.as_tensor(first_piece_mask)
>>> input_ids.shape  # batch_size = 2, sequence_length = 8
torch.Size([2, 8])
```

That is it. Our setup is done. In actual code:

In [None]:
input_texts = ["Welcome to the jungle", "Hello darkness my old friend"]
input_pieces = [["_Wel", "come", "_to", "_the", "_jungle"], 
                ["_He", "llo", "_dark", "ness", "_my", "_old", "_fri", "end"]]
input_ids = torch.as_tensor([[10, 11, 12, 13, 14, -1, -1, -1], [15, 16, 17, 18, 19, 20, 21, 22]])
attention_mask = torch.as_tensor([[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]])
first_piece_mask = torch.as_tensor([[1, 0, 1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 1, 1, 0]])
batch_size, seq_len = input_ids.shape

### First word-piece selection

Now, how can we efficiently select the input ids of the **first word piece** for each sequence in the batch?

Note that a simple binary-indexing strategy might give us trouble for two reasons:
1. The number of 1s differs across the sequences in the batch.
2. Even if the number of 1s were equal for all sequences in the batch, binary indexing does not return a tensor with the same shape as the original tensor.

If you try that out, the result would be the following:
```python
>>> input_ids[first_piece_mask.bool()]
tensor([10, 12, 13, 14, 15, 17, 19, 20, 21])
```

Which is not what we want. To circumvent this behavior, we can resort to positional indexing. That is, we want to select the elements in the following positions:
```python
[
    [0, 2, 3, 4], 
    [0, 2, 4, 5, 6]
]
```

But here we also face the first issue, namely, that the number of "selected ids" is different for the two sequences in the batch. A simple fix to this problem is to _pad_ the first tensor with a dummy index value (e.g., `-1`).

Therefore, we are looking for a function that is a vectorized version of `torch.nonzero`. In simple words, we want a function that returns the indices of nonzero input elements as a padded tensor. Here are two ways of achieving this: one using for loops + `torch.nonzero`, and another that is vectorized over the batch.

In [None]:
from torch.nn.utils.rnn import pad_sequence

def pad(sequences, pad_value=0):
    return pad_sequence(sequences, batch_first=True, padding_value=pad_value)

def simple_padded_nonzero(mask: torch.LongTensor, pad_value: int = -1) -> torch.LongTensor:
    """
    Returns a (right-padded) tensor containing indices of nonzero elements from a binary mask tensor.
    
    Example:
        [[1, 0, 0, 1, 0, 1, 0, 1],
         [1, 1, 0, 0, 0, 1, 0, 0]]
        will be transformed to:
        [[0, 3, 5, 7],
         [0, 1, 5, -1]]
        where -1 indicates pad positions.
    
    Args:
        mask: torch.LongTensor with shape of (batch_size, sequence_length)
    
    Returns:
        torch.LongTensor with shape of (batch_size, original_sequence_length)
        where original_sequence_length = max(sum(mask, dim=-1))
    """
    batch_size, seq_len = mask.shape
    non_zero_tensors = [torch.nonzero(mask[i]).flatten() for i in range(batch_size)]
    return pad(non_zero_tensors, pad_value).to(mask.device)

def vectorized_padded_nonzero(mask: torch.LongTensor, pad_value: int = -1) -> torch.LongTensor:
    non_zeros = mask.nonzero()
    non_zero_rows = non_zeros[:, 0]
    non_zero_cols = non_zeros[:, 1]
    count_unique_ids = non_zero_rows.bincount().cumsum(dim=0).cpu()
    non_zero_tensors = non_zero_cols.tensor_split(count_unique_ids[:-1])
    return pad(non_zero_tensors, pad_value).to(mask.device)

In [None]:
first_piece_idxs = simple_padded_nonzero(first_piece_mask)
first_piece_idxs

In [None]:
first_piece_idxs = vectorized_padded_nonzero(first_piece_mask)
first_piece_idxs

Great! Both versions return the same output. Later on we will compare everything in terms of running time to see the impact of vectorizing. Now that we have the indices of the first pieces, we can simply do an index selection as follows:

In [None]:
ar = torch.arange(batch_size)
input_ids[ar.unsqueeze(-1), first_piece_idxs]

In [None]:
attention_mask[ar.unsqueeze(-1), first_piece_idxs]
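
One caveat worth spelling out: the pad value `-1` is itself a valid index in PyTorch (it selects the last position of the sequence). In the toy batch this goes unnoticed because the shorter sequence ends in padding, but in general the padded slots should be masked explicitly. The snippet below is a small sketch of that, with the tensors copied by hand from the example above; the names `valid`, `safe_idxs`, and `word_ids` are mine.

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13, 14, -1, -1, -1],
                          [15, 16, 17, 18, 19, 20, 21, 22]])
# padded indices of the first pieces, as returned by the functions above
first_piece_idxs = torch.tensor([[0, 2, 3, 4, -1],
                                 [0, 2, 4, 5, 6]])

# remember which slots hold real indices and which hold padding
valid = first_piece_idxs != -1
# replace the pad index with a harmless in-range index before selecting
safe_idxs = first_piece_idxs.clamp(min=0)
# select along the sequence dimension, then overwrite the padded slots
word_ids = input_ids.gather(1, safe_idxs).masked_fill(~valid, -1)
print(word_ids.tolist())  # [[10, 12, 13, 14, -1], [15, 17, 19, 20, 21]]
```

The `valid` mask doubles as the word-level attention mask for whatever comes next in the model.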

In [None]:
x = torch.randint(0, 2, size=(128, 512))
%timeit simple_padded_nonzero(x.cpu())
%timeit vectorized_padded_nonzero(x.cpu())

In [None]:
%timeit simple_padded_nonzero(x.cuda())
%timeit vectorized_padded_nonzero(x.cuda())
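
For comparison, here is a third variant that stays entirely in batched tensor ops by using a stable sort, the same trick that `clustered_attention_sorted` relies on later in this notebook. This is only a sketch (the name `sorted_padded_nonzero` is mine), and it assumes a PyTorch version in which `sort` accepts `stable=True`.

```python
import torch


def sorted_padded_nonzero(mask: torch.Tensor, pad_value: int = -1) -> torch.Tensor:
    """Indices of nonzero entries of a 2D binary mask, right-padded with `pad_value`."""
    # a stable descending sort moves the 1s to the front and keeps their original order
    vals, idxs = mask.long().sort(dim=-1, descending=True, stable=True)
    # positions whose sorted value is 0 are not real hits, so they become padding
    idxs = idxs.masked_fill(vals == 0, pad_value)
    # truncate to the largest number of nonzeros found in any row
    max_nonzeros = int(mask.long().sum(dim=-1).max())
    return idxs[:, :max_nonzeros]


first_piece_mask = torch.tensor([[1, 0, 1, 1, 1, 0, 0, 0],
                                 [1, 0, 1, 0, 1, 1, 1, 0]])
first_piece_idxs = sorted_padded_nonzero(first_piece_mask)
print(first_piece_idxs.tolist())  # [[0, 2, 3, 4, -1], [0, 2, 4, 5, 6]]
```

Sorting costs more than a single `nonzero` call, so whether this is faster depends on the shapes and the device; it is worth timing alongside the two functions above.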

### Summing and averaging pieces

Instead of selecting only the first piece of each word, we can think of other aggregation strategies, such as summing or averaging all piece vectors.

In [None]:
print(first_piece_mask)
print(first_piece_idxs)

In [None]:
first_piece_mask_aug = first_piece_mask + (1 - attention_mask)
last_piece_mask = ((first_piece_mask_aug - first_piece_mask_aug.roll(-1, dims=-1)) <= 0).long() * attention_mask
last_piece_idxs = vectorized_padded_nonzero(last_piece_mask)
print(last_piece_mask)
print(last_piece_idxs)

In [None]:
cumsummed_pieces = input_ids.cumsum(dim=1)
cumsummed_pieces

In [None]:
a = cumsummed_pieces[ar.unsqueeze(-1), last_piece_idxs]
shifted_cumsummed_pieces = torch.cat((cumsummed_pieces.new_zeros(batch_size, 1), cumsummed_pieces[:, :-1]), dim=1)
summed_pieces = a - shifted_cumsummed_pieces[ar.unsqueeze(-1), first_piece_idxs]
summed_pieces

In [None]:
lengths = last_piece_idxs - first_piece_idxs + 1
summed_pieces / lengths
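
A different way to get the same sums and averages, useful as a cross-check of the cumsum trick, is to assign a word index to every piece and accumulate with `scatter_add`. This is a sketch under my own naming (`word_idx`, `sums`, `counts`, `means`); as in the cells above, `input_ids` stands in for the piece vectors.

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13, 14, -1, -1, -1],
                          [15, 16, 17, 18, 19, 20, 21, 22]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1, 1]])
first_piece_mask = torch.tensor([[1, 0, 1, 1, 1, 0, 0, 0],
                                 [1, 0, 1, 0, 1, 1, 1, 0]])
batch_size = input_ids.shape[0]

# word index of every piece: 0 for the pieces of the first word, 1 for the second, ...
word_idx = first_piece_mask.cumsum(dim=-1) - 1
# largest number of words in the batch; padded pieces go to an extra "trash" slot
num_words = int(first_piece_mask.sum(dim=-1).max())
word_idx = word_idx.masked_fill(attention_mask == 0, num_words)

# accumulate piece values and piece counts per word, then drop the trash slot
values = input_ids.float() * attention_mask
sums = torch.zeros(batch_size, num_words + 1).scatter_add(1, word_idx, values)[:, :num_words]
counts = torch.zeros(batch_size, num_words + 1).scatter_add(1, word_idx, attention_mask.float())[:, :num_words]
# clamp avoids dividing by zero for padded word slots
means = sums / counts.clamp(min=1)

print(sums.tolist())   # [[21.0, 12.0, 13.0, 14.0, 0.0], [31.0, 35.0, 19.0, 20.0, 43.0]]
print(means.tolist())  # [[10.5, 12.0, 13.0, 14.0, 0.0], [15.5, 17.5, 19.0, 20.0, 21.5]]
```

Padded word slots come out as exact zeros here, which makes them easy to mask downstream.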

---

# Clustered attention

Self-attention in transformers works with three tensors:

- `queries` with a shape of `(batch_size, num_heads, sequence_length, hidden_size)`
- `keys` with a shape of `(batch_size, num_heads, sequence_length, hidden_size)`
- `values` with a shape of `(batch_size, num_heads, sequence_length, hidden_size)`

Attention is computed as follows:

```python
import math

# 1. compute logits in O(n^2 * d)
logits = queries @ keys.transpose(-1, -2) / math.sqrt(hidden_size)
# 2. mask out padding positions (the mask is broadcast over heads and query positions)
logits = logits.masked_fill(attention_mask[:, None, None, :] == 0, -9999999.)
# 3. map logits to probabilities
probas = torch.softmax(logits, dim=-1)
# 4. compute a weighted sum of value vectors
output = probas @ values
```

Let's say we want to improve the self-attention performance in transformers by working with clusters. The idea is that instead of computing all $n \times n$ dot-products, we can map queries, keys, and values to some clusters, and then compute dot-products only inside those clusters. Concretely, if we have $c$ **balanced** clusters, we can reduce the self-attention cost to:

$$
O \left(c \times \frac{n}{c} \times \frac{n}{c} \times d \right) = O\left(\frac{n^2}{c} \times d\right)
$$

If we set $c = \sqrt{n}$, which is a reasonable choice for the number of clusters for most applications, we get $O(n\sqrt{n} \times d)$, which is better than the quadratic cost $O(n^2 \times d)$.
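
To get a feeling for the numbers, the snippet below counts dot-products for full attention and for perfectly balanced clusters. It is plain arithmetic on the formula above; the function names and the choice of $n = 4096$ are mine.

```python
import math


def full_attention_dots(n: int) -> int:
    """Number of query-key dot-products in vanilla self-attention."""
    return n * n


def clustered_attention_dots(n: int, c: int) -> int:
    """Number of dot-products with c perfectly balanced clusters of size n / c."""
    cluster_size = n // c
    return c * cluster_size * cluster_size


n = 4096
c = math.isqrt(n)  # c = sqrt(n) = 64
full = full_attention_dots(n)
clustered = clustered_attention_dots(n, c)
print(full)               # 16777216
print(clustered)          # 262144
print(full // clustered)  # 64, i.e., a factor of c fewer dot-products
```

Keep in mind that this only holds for balanced clusters; if all tokens fall into one cluster, we are back to the quadratic cost.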


In practical terms, consider that you are given $c$ clusters, represented by their centroids for each head:

- `centroids` with a shape of `(num_heads, num_centroids, hidden_size)`

How can we compute attention efficiently?
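
Before looking at batched implementations, it helps to have a slow but obvious reference to compare against. The sketch below handles a single sequence and a single head with a Python loop over clusters, and ignores padding; `naive_clustered_attention` is a name I made up, and queries whose cluster contains no keys simply get a zero output vector.

```python
import torch


def naive_clustered_attention(q, k, v, centroids):
    """q, k, v: (seq_len, d); centroids: (num_centroids, d). No batch, heads or padding."""
    d = q.shape[-1]
    # assign every query and key to its nearest centroid (Euclidean distance)
    q_cluster = torch.cdist(q, centroids, p=2).argmin(dim=-1)
    k_cluster = torch.cdist(k, centroids, p=2).argmin(dim=-1)
    out = torch.zeros(q.shape[0], v.shape[-1], dtype=v.dtype, device=v.device)
    for c in range(centroids.shape[0]):
        q_ids = (q_cluster == c).nonzero().flatten()
        k_ids = (k_cluster == c).nonzero().flatten()
        if len(q_ids) == 0 or len(k_ids) == 0:
            continue  # nothing to attend to inside this cluster
        # standard scaled dot-product attention restricted to the cluster
        logits = q[q_ids] @ k[k_ids].T / d ** 0.5
        out[q_ids] = torch.softmax(logits, dim=-1) @ v[k_ids]
    return out


torch.manual_seed(0)
q, k, v = torch.randn(6, 4), torch.randn(6, 4), torch.randn(6, 4)
centroids = torch.randn(3, 4)
print(naive_clustered_attention(q, k, v, centroids).shape)  # torch.Size([6, 4])
```

With a single centroid every token lands in the same cluster, so the function reduces to vanilla attention, which is a handy sanity check.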

In [3]:
def clustered_attention_sorted(q, k, v, centroids, mask):
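    # NOTE: `device` and `head_size` are not arguments of this function;
    # they are read from the notebook globals defined in the setup cells below
    # get sequence lengths for q and k (might be different for seq2seq problems)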
    q_seq_len = q.shape[-2]
    k_seq_len = k.shape[-2]
    batch_size = q.shape[0]
    num_heads = q.shape[1]
    num_centroids = centroids.shape[-2]
    
    # add `batch` dimension
    # (batch_size, num_heads, 1, num_centroids, num_projections)
    expanded_centroids = centroids[None, :, None, :, :].expand(batch_size, -1, 1, -1, -1)

    # add `middle` dimension
    # (batch_size, num_heads, 1, q_seq_len, num_projections)
    expanded_q = q.unsqueeze(2)
    # (batch_size, num_heads, 1, k_seq_len, num_projections)
    expanded_k = k.unsqueeze(2)

    # q_dists.shape is (batch, num_heads, 1, q_seq_len, num_centroids)
    q_dists = torch.cdist(expanded_q, expanded_centroids, p=2)
    # k_dists.shape is (batch, num_heads, 1, k_seq_len, num_centroids)
    k_dists = torch.cdist(expanded_k, expanded_centroids, p=2)

    # q_clustered.shape is (batch, num_heads, 1, q_seq_len)
    q_clustered = torch.argmin(q_dists, dim=-1)
    # k_clustered.shape is (batch, num_heads, 1, k_seq_len)
    k_clustered = torch.argmin(k_dists, dim=-1)

    # transpose to get `1` as different hashing rounds
    # q_clustered.shape is (batch, num_heads, q_seq_len, 1)
    q_clustered = q_clustered.transpose(2, 3)
    # k_clustered.shape is (batch, num_heads, k_seq_len, 1)
    k_clustered = k_clustered.transpose(2, 3)
    
    # deal with mask later, but we can also
    # set cluster id for padding positions as `num_centroids` (ids start with 0)
    # q_clustered = q_clustered.masked_fill(~mask.view(batch_size, 1, q_seq_len, 1), num_centroids)
    # k_clustered = k_clustered.masked_fill(~mask.view(batch_size, 1, k_seq_len, 1), num_centroids)

    # we need to divide q_clustered into (similarly for k_clustered)
    # (batch, num_heads, num_centroids, max_cluster_size_q_for_all_batch_and_heads, 1)

    # q_clustered_bin.shape is (batch, num_heads, q_seq_len, num_centroids)
    q_clustered_bin = q_clustered == torch.arange(num_centroids, device=device)
    # k_clustered_bin.shape is (batch, num_heads, k_seq_len, num_centroids)
    k_clustered_bin = k_clustered == torch.arange(num_centroids, device=device)

    # q_clustered_bin.shape is (batch, num_heads, num_centroids, q_seq_len)
    q_clustered_bin = q_clustered_bin.transpose(-1, -2).int()
    # k_clustered_bin.shape is (batch, num_heads, num_centroids, k_seq_len)
    k_clustered_bin = k_clustered_bin.transpose(-1, -2).int()

    # get the max cluster size across all batches and heads
    max_cluster_size_q = q_clustered_bin.sum(-1).max().item()
    max_cluster_size_k = k_clustered_bin.sum(-1).max().item()

    # utopically, max_cluster_size_q = q_seq_len / num_centroids
    # but in this implementation I'm ignoring this assumption
    # `q_clustered_vals` contains only 0 or 1 ints (due to one hot binarization)
    q_clustered_vals, q_clustered_idxs = q_clustered_bin.sort(dim=-1, descending=True, stable=True)
    k_clustered_vals, k_clustered_idxs = k_clustered_bin.sort(dim=-1, descending=True, stable=True)
    # values that are 0 correspond to padding positions, so we mask them with q_seq_len - 1 (last token)
    q_clustered_idxs[~q_clustered_vals.bool()] = q_seq_len - 1
    k_clustered_idxs[~k_clustered_vals.bool()] = k_seq_len - 1
    # get 0 and 1s as masks
    mask_clustered_q = q_clustered_vals.bool()
    mask_clustered_k = k_clustered_vals.bool()

    # deal with padding
    lengths = mask.sum(-1)[:, None, None, None]
    pad_mask_bucketed_q = q_clustered_idxs < lengths
    pad_mask_bucketed_k = k_clustered_idxs < lengths
    
    # combine masks
    full_mask_bucketed_q = mask_clustered_q & pad_mask_bucketed_q
    full_mask_bucketed_k = mask_clustered_k & pad_mask_bucketed_k

    # q_bucketed.shape is (batch, num_heads, num_centroids, max_cluster_size_q)
    q_bucketed = q_clustered_idxs[:, :, :, :max_cluster_size_q]
    # k_bucketed.shape is (batch, num_heads, num_centroids, max_cluster_size_k)
    k_bucketed = k_clustered_idxs[:, :, :, :max_cluster_size_k]
    # same shape as above
    mask_bucketed_q = mask_clustered_q[:, :, :, :max_cluster_size_q]
    mask_bucketed_k = mask_clustered_k[:, :, :, :max_cluster_size_k]
    full_mask_bucketed_q = full_mask_bucketed_q[:, :, :, :max_cluster_size_q]
    full_mask_bucketed_k = full_mask_bucketed_k[:, :, :, :max_cluster_size_k]
    # create pairwise mask with shape (batch, num_heads, num_centroids, max_cluster_size_q, max_cluster_size_k)
    mask_bucketed = full_mask_bucketed_q.unsqueeze(-1) & full_mask_bucketed_k.unsqueeze(-2)

    # (batch, num_heads, num_clusters * max_cluster_size)
    squished_inds_q = q_bucketed.reshape(batch_size, num_heads, -1)
    squished_inds_k = k_bucketed.reshape(batch_size, num_heads, -1)

    # keys and values are bucketed with the same buckets
    # the bucketed tensors are (batch, num_heads, num_clusters * max_cluster_size, head_size)
    bucketed_q = q.gather(2, squished_inds_q.unsqueeze(-1).expand(-1, -1, -1, head_size))
    bucketed_k = k.gather(2, squished_inds_k.unsqueeze(-1).expand(-1, -1, -1, head_size))
    bucketed_v = v.gather(2, squished_inds_k.unsqueeze(-1).expand(-1, -1, -1, head_size))

    # we now expand the squished dim into (num_centroids, max_cluster_size)
    bucketed_q = bucketed_q.view(batch_size, num_heads, num_centroids, -1, head_size)
    bucketed_k = bucketed_k.view(batch_size, num_heads, num_centroids, -1, head_size)
    bucketed_v = bucketed_v.view(batch_size, num_heads, num_centroids, -1, head_size)

    # dots are (batch, num_heads, num_centroids, max_cluster_size_q, max_cluster_size_k)
    sqrt_d = head_size ** 0.5
    dots = bucketed_q @ bucketed_k.transpose(-1, -2) / sqrt_d

    # mask the dots past key length; add `max_cluster_size_q` dim for broadcasting
    neg_inf = -9999999.0
    dots = dots.masked_fill(~mask_bucketed, neg_inf)  # float('-inf') will generate nans in softmax

    # att_dist is (batch, num_heads, num_centroids, max_cluster_size_q, max_cluster_size_k)
    att_dist = torch.softmax(dots, dim=-1)

    # fix the uniform numbers for padding positions
    att_dist = att_dist * mask_bucketed.float()

    # output is (batch, num_heads, num_centroids, max_cluster_size_q, head_size)
    output = torch.matmul(att_dist, bucketed_v)

    # make sure squashed indices for pad positions are higher than last valid token id
    squished_mask_q = mask_bucketed_q.reshape(batch_size, num_heads, -1)
    # squished_mask_k = mask_bucketed_k.reshape(batch_size, num_heads, -1)
    fixed_squished_inds_q = squished_inds_q.masked_fill(~squished_mask_q, q_seq_len + 1)
    # fixed_squished_inds_k = squished_inds_q.masked_fill(~squished_mask_k, k_seq_len + 1)

    # get indices of valid contextualized query vectors
    _, rev_inds_q = fixed_squished_inds_q.sort(dim=-1, stable=True)
    # truncate to get only the first q_seq_len vectors -> the valid ones
    rev_inds_q = rev_inds_q[:, :, :q_seq_len]
    # fix order
    rev_inds_q, _ = rev_inds_q.sort(dim=-1)

    # squish output and gather correct vectors
    squished_output = output.view(batch_size, num_heads, -1, head_size)
    # output.shape is (batch, num_heads, q_seq_len, head_size)
    output = squished_output.gather(2, rev_inds_q.unsqueeze(-1).expand(-1, -1, -1, head_size))

    # concat heads back
    output = output.transpose(1, 2).reshape(batch_size, -1, num_heads * head_size)
    
    return output

In [4]:
from torch.nn.utils.rnn import pad_sequence

def pad(sequences, pad_value=0):
    return pad_sequence(sequences, batch_first=True, padding_value=pad_value)

def clustered_attention_vectorized(q, k, v, centroids, mask):
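    # NOTE: `device` and `head_size` are not arguments of this function;
    # they are read from the notebook globals defined in the setup cells below
    # get sequence lengths for q and k (might be different for seq2seq problems)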
    q_seq_len = q.shape[-2]
    k_seq_len = k.shape[-2]
    batch_size = q.shape[0]
    num_heads = q.shape[1]
    num_centroids = centroids.shape[-2]
    
    # add `batch` dimension
    # (batch_size, num_heads, 1, num_centroids, num_projections)
    expanded_centroids = centroids[None, :, None, :, :].expand(batch_size, -1, 1, -1, -1)

    # add `middle` dimension
    # (batch_size, num_heads, 1, q_seq_len, num_projections)
    expanded_q = q.unsqueeze(2)
    # (batch_size, num_heads, 1, k_seq_len, num_projections)
    expanded_k = k.unsqueeze(2)

    # q_dists.shape is (batch, num_heads, 1, q_seq_len, num_centroids)
    q_dists = torch.cdist(expanded_q, expanded_centroids, p=2)
    # k_dists.shape is (batch, num_heads, 1, k_seq_len, num_centroids)
    k_dists = torch.cdist(expanded_k, expanded_centroids, p=2)

    # q_clustered.shape is (batch, num_heads, 1, q_seq_len)
    q_clustered = torch.argmin(q_dists, dim=-1)
    # k_clustered.shape is (batch, num_heads, 1, k_seq_len)
    k_clustered = torch.argmin(k_dists, dim=-1)

    # transpose to get `1` as different hashing rounds
    # q_clustered.shape is (batch, num_heads, q_seq_len, 1)
    q_clustered = q_clustered.transpose(2, 3)
    # k_clustered.shape is (batch, num_heads, k_seq_len, 1)
    k_clustered = k_clustered.transpose(2, 3)
    
    # we will deal with masking later, but we could
    # set cluster id for padding positions as `num_centroids` (ids start with 0)
    # q_clustered = q_clustered.masked_fill(~mask.view(batch_size, 1, q_seq_len, 1), num_centroids)
    # k_clustered = k_clustered.masked_fill(~mask.view(batch_size, 1, k_seq_len, 1), num_centroids)

    # we need to divide q_clustered into (similarly for k_clustered)
    # (batch, num_heads, num_centroids, max_cluster_size_q_for_all_batch_and_heads, 1)

    # q_clustered_bin.shape is (batch, num_heads, q_seq_len, num_centroids)
    q_clustered_bin = q_clustered == torch.arange(num_centroids, device=device)
    # k_clustered_bin.shape is (batch, num_heads, k_seq_len, num_centroids)
    k_clustered_bin = k_clustered == torch.arange(num_centroids, device=device)

    # q_clustered_bin.shape is (batch, num_heads, num_centroids, q_seq_len)
    q_clustered_bin = q_clustered_bin.transpose(-1, -2).int()
    # k_clustered_bin.shape is (batch, num_heads, num_centroids, k_seq_len)
    k_clustered_bin = k_clustered_bin.transpose(-1, -2).int()
    
    # arange tensors for queries and keys
    q_ar = 1 + torch.arange(q_seq_len, device=device).view(1, 1, 1, -1).expand_as(q_clustered_bin)
    k_ar = 1 + torch.arange(k_seq_len, device=device).view(1, 1, 1, -1).expand_as(k_clustered_bin)
    
    # q_nz.shape is (num_nonzero_entries, 4)
    # where each column contains the nonzero ids for each original dimension,
    # namely: batch, num_heads, num_centroids, q_seq_len
    q_nz = (q_ar * q_clustered_bin).nonzero()
    k_nz = (k_ar * k_clustered_bin).nonzero()
    
    # convert the first three columns into a single column
    q_rows = q_nz[:, 0] * (num_heads * num_centroids) + q_nz[:, 1] * num_centroids + q_nz[:, 2]
    k_rows = k_nz[:, 0] * (num_heads * num_centroids) + k_nz[:, 1] * num_centroids + k_nz[:, 2]
    
    # the last column is the sequence dimension (the one we care about) 
    q_cols = q_nz[:, -1]
    k_cols = k_nz[:, -1]
    
    # count the number of unique row ids and cumsum them to create continuous slices 
    q_split_slices = q_rows.bincount().cumsum(dim=0)[:-1].cpu().tolist()
    k_split_slices = k_rows.bincount().cumsum(dim=0)[:-1].cpu().tolist()
    
    # pad for missing slices since the last head of the last batch might be empty
    num_total_centroids = batch_size * num_heads * num_centroids
    q_num_missing_centroids = num_total_centroids - len(q_split_slices)
    k_num_missing_centroids = num_total_centroids - len(k_split_slices)
    q_split_slices.extend([q_split_slices[-1]] * (q_num_missing_centroids - 1))
    k_split_slices.extend([k_split_slices[-1]] * (k_num_missing_centroids - 1))
    
    # merge the sequence ids in tensors following the slices
    q_splited = q_cols.tensor_split(q_split_slices)
    k_splited = k_cols.tensor_split(k_split_slices)
    
    # pad the smaller tensors with -1 and reshape back to 
    # (batch_size, num_heads, num_centroids, max_cluster_size_q_for_all_batch_and_heads)
    q_bucketed_idxs = pad(q_splited, -1).view(batch_size, num_heads, num_centroids, -1)
    k_bucketed_idxs = pad(k_splited, -1).view(batch_size, num_heads, num_centroids, -1)
    
    # get the max cluster size across all batches and heads
    # utopically, max_cluster_size_q = q_seq_len / num_centroids
    # but in this implementation I'm ignoring this assumption
    max_cluster_size_q = q_bucketed_idxs.shape[-1]
    max_cluster_size_k = k_bucketed_idxs.shape[-1]
    mask_bucketed_q = q_bucketed_idxs != -1
    mask_bucketed_k = k_bucketed_idxs != -1
    
    # deal with padding
    lengths = mask.sum(-1)[:, None, None, None]
    pad_mask_bucketed_q = q_bucketed_idxs < lengths
    pad_mask_bucketed_k = k_bucketed_idxs < lengths

    # combine masks
    full_mask_bucketed_q = mask_bucketed_q & pad_mask_bucketed_q
    full_mask_bucketed_k = mask_bucketed_k & pad_mask_bucketed_k
    
    # create pairwise mask with shape 
    # (batch, num_heads, num_centroids, max_cluster_size_q, max_cluster_size_k)
    # this is where having balanced clusters with num_centroids = sqrt(n)
    # leads to performance improvements
    mask_bucketed = full_mask_bucketed_q.unsqueeze(-1) & full_mask_bucketed_k.unsqueeze(-2)

    # (batch, num_heads, num_clusters * max_cluster_size)
    q_bucketed = q_bucketed_idxs.masked_fill(~mask_bucketed_q, q_seq_len - 1)
    k_bucketed = k_bucketed_idxs.masked_fill(~mask_bucketed_k, k_seq_len - 1)
    squished_inds_q = q_bucketed.reshape(batch_size, num_heads, -1)
    squished_inds_k = k_bucketed.reshape(batch_size, num_heads, -1)

    # keys and values are bucketed with the same ids
    # the bucketed tensors are (batch, num_heads, num_clusters * max_cluster_size, head_size)
    bucketed_q = q.gather(2, squished_inds_q.unsqueeze(-1).expand(-1, -1, -1, head_size))
    bucketed_k = k.gather(2, squished_inds_k.unsqueeze(-1).expand(-1, -1, -1, head_size))
    bucketed_v = v.gather(2, squished_inds_k.unsqueeze(-1).expand(-1, -1, -1, head_size))

    # we now expand the squished dim into (num_centroids, max_cluster_size)
    bucketed_q = bucketed_q.view(batch_size, num_heads, num_centroids, -1, head_size)
    bucketed_k = bucketed_k.view(batch_size, num_heads, num_centroids, -1, head_size)
    bucketed_v = bucketed_v.view(batch_size, num_heads, num_centroids, -1, head_size)

    # dots are (batch, num_heads, num_centroids, max_cluster_size_q, max_cluster_size_k)
    sqrt_d = head_size ** 0.5
    dots = bucketed_q @ bucketed_k.transpose(-1, -2) / sqrt_d
    # mask the dots past key length; add `max_cluster_size_q` dim for broadcasting
    # float('-inf') will generate nans in softmax, so we use a very small value
    # instead. This happens because some clusters might be empty
    neg_inf = -9999999.0
    dots = dots.masked_fill(~mask_bucketed, neg_inf)

    # att_dist is (batch, num_heads, num_centroids, max_cluster_size_q, max_cluster_size_k)
    att_dist = torch.softmax(dots, dim=-1)

    # fix the uniform numbers for padding positions
    att_dist = att_dist * mask_bucketed.float()

    # output is (batch, num_heads, num_centroids, max_cluster_size_q, head_size)
    att_output = torch.matmul(att_dist, bucketed_v)

    # squish output and mask
    squished_output = att_output.view(batch_size, num_heads, -1, head_size)
    squished_mask_q = mask_bucketed_q.view(batch_size, num_heads, -1)
    
    # get indices of valid contextualized query vectors
    ar = torch.arange(num_centroids * max_cluster_size_q, device=device)
    ar = ar.view(1, 1, -1).expand(batch_size, num_heads, -1)
    squished_idxs_q = ar[squished_mask_q].view(batch_size, num_heads, -1)
    
    # output.shape is (batch, num_heads, q_seq_len, head_size)
    output = squished_output.gather(2, squished_idxs_q.unsqueeze(-1).expand(-1, -1, -1, head_size))
    
    # concat heads back
    # output.shape is (batch, q_seq_len, hidden_size)
    output = output.transpose(1, 2).reshape(batch_size, -1, num_heads * head_size)
    
    return output

In [None]:
batch_size = 2
sequence_length = 2048
num_heads = 4
head_size = 4
hidden_size = num_heads * head_size
num_centroids = 3
device = 'cuda' if torch.cuda.is_available() else 'cpu'

centroids = torch.randn(num_heads, num_centroids, hidden_size).to(device)
q = torch.randn(batch_size, num_heads, sequence_length, hidden_size).to(device)
k = torch.randn(batch_size, num_heads, sequence_length, hidden_size).to(device)
v = torch.randn(batch_size, num_heads, sequence_length, hidden_size).to(device)
mask = torch.ones(batch_size, sequence_length).bool().to(device)
# mask = torch.tensor([5, 8]).unsqueeze(-1) >= torch.arange(sequence_length).unsqueeze(0).expand(batch_size, -1)
# mask = mask.to(device)

In [None]:
with torch.no_grad():
    %timeit clustered_attention_sorted(q, k, v, centroids, mask).shape

In [None]:
with torch.no_grad():
    %timeit clustered_attention_vectorized(q, k, v, centroids, mask).shape

In [None]:
del q, k, v, mask, centroids
torch.cuda.empty_cache()

---

## Timing

In [6]:
import torch.utils.benchmark as benchmark
from itertools import product

batch_size = 1
num_heads = 1
head_size = 1
hidden_size = num_heads * head_size
sequence_lengths = [64, 128, 512, 1024, 2048, 4096, 8192, 8192*2]
num_centroids = [2, 3, 5, 9, 11]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_threads = 1
results = []

for num_c, seq_len in product(num_centroids, sequence_lengths):
    label = 'clustered attention'
    sub_label = f'[{num_c}, {seq_len}]'
    q = torch.randn(batch_size, num_heads, seq_len, hidden_size).to(device)
    k = torch.randn(batch_size, num_heads, seq_len, hidden_size).to(device)
    v = torch.randn(batch_size, num_heads, seq_len, hidden_size).to(device)
    centroids = torch.randn(num_heads, num_c, hidden_size).to(device)
    mask = torch.ones(batch_size, seq_len).bool().to(device)
    results.append(benchmark.Timer(
        stmt='clustered_attention_sorted(q, k, v, centroids, mask)',
        setup='from __main__ import clustered_attention_sorted',
        globals={'q': q, 'k': k, 'v': v, 'centroids': centroids, 'mask': mask},
        num_threads=num_threads,
        label=label,
        sub_label=sub_label,
        description='sorted',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='clustered_attention_vectorized(q, k, v, centroids, mask)',
        setup='from __main__ import clustered_attention_vectorized',
        globals={'q': q, 'k': k, 'v': v, 'centroids': centroids, 'mask': mask},
        num_threads=num_threads,
        label=label,
        sub_label=sub_label,
        description='vectorized',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

[---------- clustered attention ----------]
                   |  sorted  |  vectorized
1 threads: --------------------------------
      [2, 64]      |    1.0   |      1.4   
      [2, 128]     |    1.1   |      1.4   
      [2, 512]     |    1.1   |      1.6   
      [2, 1024]    |    1.1   |      1.7   
      [2, 2048]    |    1.1   |      1.8   
      [2, 4096]    |    2.6   |      2.8   
      [2, 8192]    |   10.6   |     10.3   
      [2, 16384]   |   33.2   |     31.5   
      [3, 64]      |    1.0   |      1.4   
      [3, 128]     |    1.1   |      1.4   
      [3, 512]     |    1.1   |      1.5   
      [3, 1024]    |    1.2   |      1.8   
      [3, 2048]    |    1.8   |      2.3   
      [3, 4096]    |    2.3   |      2.5   
      [3, 8192]    |    8.0   |      7.4   
      [3, 16384]   |   63.3   |     55.8   
      [5, 64]      |    1.0   |      1.4   
      [5, 128]     |    1.1   |      1.4   
      [5, 512]     |    1.2   |      1.5   
      [5, 1024]    |    1.2   | 

--- 
## Profiling

In [7]:
import torch.autograd.profiler as profiler

batch_size = 2
sequence_length = 1024
num_heads = 4
head_size = 4
hidden_size = num_heads * head_size
num_centroids = 3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
centroids = torch.randn(num_heads, num_centroids, hidden_size).to(device)
q = torch.randn(batch_size, num_heads, sequence_length, hidden_size).to(device)
k = torch.randn(batch_size, num_heads, sequence_length, hidden_size).to(device)
v = torch.randn(batch_size, num_heads, sequence_length, hidden_size).to(device)
mask = torch.ones(batch_size, sequence_length).bool().to(device)

In [11]:
with profiler.profile(profile_memory=True, use_cuda=True, with_flops=True) as prof:
    out = clustered_attention_sorted(q, k, v, centroids, mask)

print(prof.key_averages().table(sort_by="self_cuda_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total MFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             aten::sort         2.32%     655.000us         4.78%       1.349ms     337.250us       1.826ms        19.32%       2.006ms     501.500us           0 

In [10]:
with profiler.profile(profile_memory=True, use_cuda=True, with_flops=True) as prof:
    out = clustered_attention_vectorized(q, k, v, centroids, mask)

print(prof.key_averages().table(sort_by="self_cuda_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         1.66%     562.000us         5.81%       1.972ms      30.812us       1.478ms         9.66%       1.478ms      23.094us           0 

## Going the extra mile: blockfied attention

BigBird's self-attention relies on two simplifications to get speed improvements over vanilla self-attention:

1. **Local + global + random connections:** Only attend to pre-determined elements, leading to a fixed pattern in the attention matrix
2. **Blocks:** Group contiguous tokens into chunks, leading to a blockfied pattern in the attention matrix

While the first point is important for attending to relevant elements, the second point is crucial for optimizing runtime. How can we introduce blocks in our implementation? Fortunately, we can rely (again) on broadcasting to implement one routine that blockfies the inputs and another that unblockfies the output.

In [22]:
q.shape, q.shape[:-2]

(torch.Size([2, 4, 1024, 16]), torch.Size([2, 4]))

In [58]:
import math

def blockfy(input, block_size=2):
    seq_len = input.shape[-2]
    n_blocks = math.ceil(seq_len / float(block_size))
    # pad so that seq_len becomes divisible by block_size
    # (no padding is added when it already is)
    # (batch, heads, seq_len, hdim) -> (batch, heads, n_blocks * block_size, hdim)
    input_pad = torch.nn.functional.pad(input, (0, 0, 0, n_blocks * block_size - seq_len))
    # reshape to get contiguous chunks
    # (batch, heads, n_blocks * block_size, hdim) -> (batch, heads, n_blocks, block_size * hdim)
    return input_pad.reshape(*input_pad.shape[:-2], n_blocks, -1)


def unblockfy(output, seq_len, block_size=2):
    n_blocks = output.shape[-2]
    # (batch, heads, n_blocks, block_size * hdim) -> (batch, heads, n_blocks * block_size, hdim)
    # take the leading dims from the tensor itself instead of relying on globals
    output_pad = output.reshape(*output.shape[:-2], n_blocks * block_size, -1)
    # cut pad out
    return output_pad[..., :seq_len, :]


def unblockfy_attn(att_dist, seq_len, block_size=2, pad_mask=None, causal_mask=None):
    # (batch, heads, n_blocks, n_blocks) -> (batch, heads, n_blocks * block_size, n_blocks * block_size)
    att = att_dist.repeat_interleave(block_size, dim=-1).repeat_interleave(block_size, dim=-2)
    # mask out padding and "future" positions
    if pad_mask is not None:
        # (batch, seq_len) -> (batch, n_blocks * block_size, n_blocks * block_size)
        pairwise_mask = pad_mask.unsqueeze(-1) & pad_mask.unsqueeze(1)
        # add head dimension
        pairwise_mask = pairwise_mask.unsqueeze(1)
        if causal_mask is not None:
            # add elements of the triu to the mask
            pairwise_mask = pairwise_mask & causal_mask.unsqueeze(0).unsqueeze(1)
        # mask out 
        att = att.masked_fill(~pairwise_mask, 0)
    # note that att is not a distribution anymore
    return att[..., :seq_len, :seq_len]

In [61]:
print(q.shape)
print(blockfy(q, block_size=3).shape)
print(unblockfy(blockfy(q, block_size=3), seq_len=q.shape[-2], block_size=3).shape)

torch.Size([2, 4, 1024, 16])
torch.Size([2, 4, 342, 48])
torch.Size([2, 4, 1024, 16])


In [64]:
att_dist = torch.randn(2, 4, 342, 342)
unblockfy_attn(att_dist, seq_len=q.shape[-2], block_size=3).shape

torch.Size([2, 4, 1024, 1024])

That's it! Before calling our attention module, we simply blockfy its inputs:

```python
q_block = blockfy(q, block_size=3)
k_block = blockfy(k, block_size=3)
v_block = blockfy(v, block_size=3)
mask_block = blockfy(mask.unsqueeze(-1), block_size=3).any(-1)
```

And then in the end we can reshape the output back to the original sequence length:
```python
output = unblockfy(output, seq_len=seq_len, block_size=3)
# note that mask and causal_mask should be padded to
# have a length of n_blocks * block_size
att_dist = unblockfy_attn(att_dist, seq_len=seq_len, block_size=3, pad_mask=mask, causal_mask=causal_mask)
```
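
To make the whole flow concrete, here is a minimal end-to-end sketch. It is only an illustration under a few assumptions: `to_blocks` and `from_blocks` are hypothetical stand-ins for the blockfy/unblockfy routines (restated so the snippet runs on its own), the attention itself is a plain dense softmax over blocks rather than any particular sparse pattern, and padded positions inside the last block are not masked, for brevity.

```python
import math
import torch
import torch.nn.functional as F


def to_blocks(x, block_size):
    # (batch, heads, seq_len, hdim) -> (batch, heads, n_blocks, block_size * hdim)
    seq_len = x.shape[-2]
    n_blocks = math.ceil(seq_len / block_size)
    x = F.pad(x, (0, 0, 0, n_blocks * block_size - seq_len))
    return x.reshape(*x.shape[:-2], n_blocks, -1)


def from_blocks(x, seq_len, block_size):
    # (batch, heads, n_blocks, block_size * hdim) -> (batch, heads, seq_len, hdim)
    x = x.reshape(*x.shape[:-2], x.shape[-2] * block_size, -1)
    return x[..., :seq_len, :]


def block_attention_sketch(q, k, v, block_size):
    # hypothetical helper: dense attention computed between blocks
    seq_len = q.shape[-2]
    qb, kb, vb = (to_blocks(t, block_size) for t in (q, k, v))
    # scaling by the block feature size is an assumption of this sketch
    scores = qb @ kb.transpose(-1, -2) / math.sqrt(qb.shape[-1])
    att = scores.softmax(dim=-1)  # (batch, heads, n_blocks, n_blocks)
    out = from_blocks(att @ vb, seq_len, block_size)
    return out, att


q = torch.randn(2, 4, 10, 16)
k = torch.randn(2, 4, 10, 16)
v = torch.randn(2, 4, 10, 16)
out, att = block_attention_sketch(q, k, v, block_size=3)
print(out.shape)  # torch.Size([2, 4, 10, 16])
print(att.shape)  # torch.Size([2, 4, 4, 4])
```

With 10 tokens and blocks of 3, the attention matrix shrinks from 10 x 10 to 4 x 4 per head, which is where the runtime savings come from.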

---

# Computing attention statistics

Given that we are working with batches and heads, how can we compute independent statistics for attention maps? For example, we could be interested in computing the amount of sparsity, or the recall with respect to a gold attention pattern. The traditional (and safe) way of doing this would involve flattening tensors and calling a standard method from a well-tested library. However, we would be ignoring all the power that PyTorch gives us. In fact, cases like this are where PyTorch's broadcasting shines.

To see the difference, let's implement a traditional version and some PyTorch versions of two statistics: **sparsity** and **recall**.

Quick setup:

In [235]:
batch_size = 2
sequence_length = 9
num_heads = 3
head_size = 2
hidden_size = num_heads * head_size
device = 'cuda' if torch.cuda.is_available() else 'cpu'

att_dist = torch.randn(batch_size, num_heads, sequence_length, sequence_length, device=device).softmax(dim=-1)
gold_att_dist = torch.randint(0, 2, size=att_dist.shape, device=device)
mask = torch.tensor([5, 8]).unsqueeze(-1) >= torch.arange(sequence_length).unsqueeze(0).expand(batch_size, -1)
mask = mask.to(device)

Since softmax always produces nonzero probabilities, let's arbitrarily set some values to zero:

In [236]:
att_dist = att_dist.masked_fill(att_dist < 0.1, 0)
att_dist[0, 0]

tensor([[0.0000, 0.0000, 0.0000, 0.4710, 0.0000, 0.0000, 0.1022, 0.0000, 0.0000],
        [0.2364, 0.0000, 0.2462, 0.0000, 0.0000, 0.1265, 0.1126, 0.1362, 0.0000],
        [0.0000, 0.0000, 0.1799, 0.0000, 0.0000, 0.2459, 0.0000, 0.2417, 0.0000],
        [0.1417, 0.0000, 0.0000, 0.0000, 0.0000, 0.1482, 0.1141, 0.1394, 0.3346],
        [0.0000, 0.2016, 0.2307, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2868],
        [0.1814, 0.0000, 0.0000, 0.0000, 0.0000, 0.3844, 0.0000, 0.0000, 0.1693],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3962, 0.0000, 0.2924, 0.0000, 0.0000],
        [0.0000, 0.2786, 0.0000, 0.1485, 0.0000, 0.1361, 0.0000, 0.0000, 0.2319],
        [0.0000, 0.2886, 0.1361, 0.0000, 0.1505, 0.0000, 0.2319, 0.0000, 0.0000]],
       device='cuda:0')

One thing to keep in mind is that a Transformer's self-attention is masked for keys (last dimension) but not for queries (second-to-last dimension). In practice this does not matter, since the padding positions are simply ignored for the final task. However, when computing attention statistics we have to take these padded positions into account.
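
As a small illustration of what "taking padded positions into account" can look like, the sketch below builds a query-by-key validity mask from a padding mask using broadcasting. The toy `lengths` tensor and the variable names are assumptions made for this example only.

```python
import torch

# two toy sequences with 2 and 3 valid tokens, padded to length 3
lengths = torch.tensor([2, 3])
seq_len = 3
mask = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(-1)  # (batch, seq_len)

# a (query, key) pair is valid only if both positions are valid
# (batch, 1, seq_len, 1) & (batch, 1, 1, seq_len) -> (batch, 1, seq_len, seq_len)
pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]

print(pairwise_mask[0, 0].tolist())
# [[True, True, False], [True, True, False], [False, False, False]]
print(pairwise_mask.sum(dim=(-1, -2)).flatten().tolist())  # [4, 9]
```

The singleton dimension at position 1 lets the same mask broadcast over all heads.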

We will start with plain Python to keep things simple.

In [237]:
def calc_sparsity_vanilla(att_dist, mask):
    batch_size, num_heads, _, _ = att_dist.shape
    p = 0  # positive count
    n = 0  # total count
    for i in range(batch_size):
        valid_seq_len = mask[i].sum().item()  
        n += num_heads * valid_seq_len ** 2
        for h in range(num_heads):
            p += sum([int(att_dist[i, h, k1, k2].item() > 0) 
                      for k1 in range(valid_seq_len) 
                      for k2 in range(valid_seq_len)])
    return 1 - p / n

In [238]:
mask.sum(-1)

tensor([6, 9], device='cuda:0')

In [239]:
calc_sparsity_vanilla(att_dist, mask)

0.5982905982905983

Now with PyTorch:

In [240]:
def calc_sparsity_vectorized(att_dist, mask):
    pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]
    p = (att_dist > 0).masked_fill(~pairwise_mask, False).sum().item()
    # read the number of heads from the input instead of relying on a global
    n = att_dist.shape[1] * pairwise_mask.sum().item()
    return 1 - p / n

In [241]:
calc_sparsity_vectorized(att_dist, mask)

0.5982905982905983

Simple and efficient!
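
If independent values per example and per head are needed instead of a single number, one possible variation is to reduce only over the last two dimensions. This is a sketch, not part of the original notebook; `sparsity_per_head_sketch` is a hypothetical name.

```python
import torch


def sparsity_per_head_sketch(att_dist, mask):
    # (batch, 1, seq_len, seq_len): valid (query, key) pairs
    pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]
    # count nonzero entries among valid pairs -> (batch, heads)
    p = ((att_dist > 0) & pairwise_mask).sum(dim=(-1, -2)).float()
    # number of valid pairs per example -> (batch, 1), broadcast over heads
    n = pairwise_mask.sum(dim=(-1, -2)).float()
    return 1 - p / n


# one example, two heads, three positions (the last one is padding)
toy_att = torch.tensor([[
    [[0.5, 0.5, 0.0], [1.0, 0.0, 0.0], [0.3, 0.3, 0.4]],
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.3, 0.3, 0.4]],
]])
toy_mask = torch.tensor([[True, True, False]])
print(sparsity_per_head_sketch(toy_att, toy_mask).tolist())  # [[0.25, 0.5]]
```

The nonzero values in the padded third row are ignored, so only the 2 x 2 valid region of each head contributes.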

If we were using an encoder-decoder transformer we would also need to account for causal masks (future position masking). This could be implemented easily as follows:

```python
pairwise_mask = pairwise_mask & causal_mask[:, None, :, :]
```
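
For completeness, here is one way such a `causal_mask` could be built and combined with the padding mask. The shapes are assumptions chosen to match the indexing above: `causal_mask` is taken to be a boolean tensor of shape `(batch, seq_len, seq_len)` where `True` marks allowed (query, key) pairs.

```python
import torch

seq_len = 4
mask = torch.tensor([[True, True, True, False]])  # (batch, seq_len), last position is padding

# valid (query, key) pairs with respect to padding
pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]

# lower-triangular matrix: each query may only look at itself and earlier keys
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
# add a batch dimension so that causal_mask[:, None, :, :] works as above
causal_mask = causal_mask.unsqueeze(0).expand(mask.shape[0], -1, -1)

combined = pairwise_mask & causal_mask[:, None, :, :]
print(combined[0, 0].int().tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]]
print(combined.sum().item())  # 6
```

With the combined mask, the total count of valid pairs drops from 9 to 6 for this example, which changes the denominator of the sparsity computation accordingly.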

Now let's turn to the other statistic: recall. 

In [303]:
def calc_recall_vanilla(gold_att_dist, pred_att_dist, mask):
    from sklearn.metrics import recall_score
    batch_size, num_heads, _, _ = pred_att_dist.shape
    recalls = torch.zeros(batch_size, num_heads)
    for i in range(batch_size):
        valid_seq_len = mask[i].sum().item()  
        for h in range(num_heads):
            g = (gold_att_dist[i, h, :valid_seq_len, :valid_seq_len] > 0).long().flatten().tolist()
            p = (pred_att_dist[i, h, :valid_seq_len, :valid_seq_len] > 0).long().flatten().tolist()
            recalls[i, h] = recall_score(g, p)
    return recalls.mean()  # macro-averaged

In [304]:
calc_recall_vanilla(gold_att_dist, att_dist, mask)

tensor(0.4218)

In [305]:
def calc_recall_vectorized(gold_att_dist, pred_att_dist, mask):
    pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]
    g = (gold_att_dist > 0).masked_fill(~pairwise_mask, False)
    p = (pred_att_dist > 0).masked_fill(~pairwise_mask, False)
    matches_per_head_and_batch = (p & g).sum(-1).sum(-1).float() / g.sum(-1).sum(-1).float()
    return matches_per_head_and_batch.mean()  # macro-averaged

In [306]:
calc_recall_vectorized(gold_att_dist, att_dist, mask)

tensor(0.4218, device='cuda:0')
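
One edge case worth keeping in mind: if some head has no positive gold entries inside the valid region, the division above becomes 0 / 0 and the mean turns into NaN. A possible way to handle this, sketched below, is to average only over heads that have at least one gold positive. Skipping such heads is a design choice of this sketch, and `recall_ignore_empty_sketch` is a hypothetical name.

```python
import torch


def recall_ignore_empty_sketch(gold_att_dist, pred_att_dist, mask):
    pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]
    g = (gold_att_dist > 0) & pairwise_mask
    p = (pred_att_dist > 0) & pairwise_mask
    true_pos = (p & g).sum(dim=(-1, -2)).float()  # (batch, heads)
    total = g.sum(dim=(-1, -2)).float()           # (batch, heads)
    has_gold = total > 0
    # heads without gold positives are left out of the average
    # (the result is still NaN if no head has any gold positive)
    return (true_pos[has_gold] / total[has_gold]).mean()


# one example, two heads; the second head has no gold positives at all
toy_gold = torch.tensor([[[[1, 0], [1, 1]], [[0, 0], [0, 0]]]])
toy_pred = torch.tensor([[[[0.5, 0.0], [0.0, 0.5]], [[0.5, 0.5], [0.0, 1.0]]]])
toy_mask = torch.tensor([[True, True]])
print(recall_ignore_empty_sketch(toy_gold, toy_pred, toy_mask))
```

Here the first head recovers 2 of its 3 gold positives and the second head is skipped, so the result is about 0.667.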

That's it! To wrap up, let's write one more statistic: a function that returns the fraction of fully recovered predictions, i.e., the fraction of predicted attention distributions (one per query) with 100% recall.



In [313]:
def compute_exact_fraction(gold_att_dist, pred_att_dist, mask):
    pairwise_mask = mask[:, None, :, None] & mask[:, None, None, :]
    g = (gold_att_dist > 0).masked_fill(~pairwise_mask, False)
    p = (pred_att_dist > 0).masked_fill(~pairwise_mask, False)
    matches = p & g
    matches_per_query = matches.sum(-1).float()
    total_per_query = g.sum(-1).float()
    # might get nans due to zero division
    recall_per_query = matches_per_query / total_per_query
    exact_per_query = recall_per_query == 1.0
    # filter nans out
    valid_exact_per_query = exact_per_query.masked_fill(~mask[:, None, :], False)
    lengths = mask.sum(-1).unsqueeze(-1).float()
    exact_per_head_and_batch = valid_exact_per_query.sum(-1) / lengths
    return exact_per_head_and_batch.mean()  # macro-averaged

In [314]:
compute_exact_fraction(gold_att_dist, att_dist, mask)

tensor(0.0741, device='cuda:0')