# Online Softmax
Flash Attention has substantially reduced the run-time of attention in transformer models. Its speed-ups come from two sets of optimizations:
1. GPU-aware I/O optimizations. These concern how data moves between GPU global memory and on-chip SRAM. They are not discussed here.
2. Online softmax. This is an algorithmic optimization that lets us compute softmax in chunks. Each chunk fits into the GPU's on-chip SRAM, and chunks can be processed in parallel across multiple streaming multiprocessors (SMs).

In this blog post we will discuss the online softmax algorithm and the simple mathematical tricks that make it possible.

```python
import torch
from torch import nn
```

## Softmax Formula

Given a vector of scores $z = [z_1, z_2, \ldots, z_n]$, the softmax function is defined as:

$$
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}
$$

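Here is the reference softmax for a random query against a set of random keys. No seed is set, so the printed values change from run to run and will not match the outputs shown in later cells.

```python
query = torch.randn(1, 8)
keys = torch.randn(8, 8)

# Attention scores: one dot product per key, shape [1, 8].
dot_products = torch.matmul(query, keys.T)
sofmax = nn.functional.softmax(dot_products, dim=-1)
sofmax
```

```
tensor([[1.3880e-01, 5.7845e-03, 5.3377e-07, 8.2416e-01, 1.1100e-02, 1.4832e-03,
         1.8518e-02, 1.5404e-04]])
```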
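As a sanity check, the formula can also be implemented directly with an element-wise `torch.exp` and a sum. This is a minimal sketch: `naive_softmax` is a hypothetical helper, not a PyTorch API, and it assumes the scores are small enough that `exp` does not overflow.

```python
import torch

def naive_softmax(z: torch.Tensor) -> torch.Tensor:
    # Direct translation of the formula: exponentiate every score,
    # then divide by the sum of exponentials along the last dimension.
    exps = torch.exp(z)
    return exps / exps.sum(dim=-1, keepdim=True)

scores = torch.tensor([[1.0, 2.0, 3.0]])
probs = naive_softmax(scores)

# Should agree with PyTorch's built-in softmax, and each row should sum to 1.
print(torch.allclose(probs, torch.nn.functional.softmax(scores, dim=-1)))
print(probs.sum(dim=-1))
```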

## Softmax Subtract By Max
To avoid overflow in the exponentials, softmax implementations usually subtract the maximum score from every score before exponentiating. Mathematically this is equivalent to the unshifted softmax, for the following reason.

In both the numerator and the denominator we can factor out the constant
$e^{-c}$, where $c = \max(x)$:

Numerator:
$$
e^{x_i - c} = e^{x_i} \cdot e^{-c}
$$

Denominator:
$$
\sum_{j=1}^{n} e^{x_j - c} = \sum_{j=1}^{n} e^{x_j} \cdot e^{-c}
$$

Thus we have:
$$
\text{softmax}(x_i - c) = \frac{e^{x_i - c}}{\sum_{j=1}^{n} e^{x_j - c}} = 
\frac{e^{x_i} \cdot \cancel{e^{-c}}}{\sum_{j=1}^{n} e^{x_j} \cdot \cancel{e^{-c}}} = 
\frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} = \text{softmax}(x_i)
$$

```python
# Subtract each row's maximum (in place) before applying softmax.
maxs = torch.max(dot_products, dim=-1, keepdim=True)[0]
dot_products -= maxs
softmax_post_max = torch.nn.functional.softmax(dot_products, -1)
```

Let's verify that these two softmax computations are equivalent:

```python
assert torch.allclose(sofmax, softmax_post_max)
```
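To see why the shift matters, here is a small sketch with deliberately large scores (values chosen only for illustration). In float32, `exp(1000)` overflows to `inf`, so the unshifted formula computes `inf / inf = nan`, while the max-subtracted version stays finite.

```python
import torch

scores = torch.tensor([1000.0, 1001.0, 1002.0])  # float32 by default

# Unshifted formula: every exp overflows to inf, and inf / inf is nan.
naive = torch.exp(scores) / torch.exp(scores).sum()

# Shifted formula: subtract the max first, so the largest exponent is exp(0) = 1.
shifted = scores - scores.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()

print(naive)  # tensor([nan, nan, nan])
print(stable)
```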

## Online Softmax Numerator

Now let's assume that we have limited memory and cannot compute all the dot-products at once. Instead, we need to process them in chunks small enough to fit in memory.

But hold on: the subtract-by-max trick relies on knowing the maximum across all dot-products. If we only see one chunk of dot-products at a time, how can we know the global maximum?

We can track the maximum in an online fashion, keeping the largest value seen so far as we process each chunk. Let's assume that we have two chunks of dot-products, each shifted by its own local maximum:

$$
\begin{align*}
\text{chunk}_1 &= [d_1, d_2, \ldots, d_k] \\
\text{chunk\_max}_1 &= \max(\text{chunk}_1) \\
\text{chunk\_1\_numerator} &= e^{\text{chunk}_1 - \text{chunk\_max}_1}
\end{align*}
$$

$$
\begin{align*}
\text{chunk}_2 &= [d_{k+1}, d_{k+2}, \ldots, d_n] \\
\text{chunk\_max}_2 &= \max(\text{chunk}_2) \\
\text{chunk\_2\_numerator} &= e^{\text{chunk}_2 - \text{chunk\_max}_2}
\end{align*}
$$

To compute the overall softmax numerators, we need to adjust each chunk's numerators based on the difference between that chunk's maximum and the global maximum. We get there with a little high-school math, namely the product rule for exponents:

$$
\begin{align*}
s &= s_0 + s_1 \\
e^{s} &= e^{s_0 + s_1} = e^{s_0} \cdot e^{s_1}
\end{align*}
$$

The global maximum is simply the largest of the chunk maxima:

$$
\text{global\_max} = \max(\text{chunk\_max}_1, \text{chunk\_max}_2)
$$

The corrected numerator for chunk 1, i.e. its exponentials shifted by the global maximum instead of the local one, is:
$$
\mathrm{corrected\_chunk\_1\_numerator} = e^{\mathrm{dot\_products\_chunk}_1 - \mathrm{global\_max}}
$$

Let's add and subtract the term `chunk_max_1` to re-write the above as:
$$
e^{\mathrm{dot\_products\_chunk}_1 - \mathrm{global\_max}} = e^{\mathrm{dot\_products\_chunk}_1 - \mathrm{global\_max} - \mathrm{chunk\_max}_1 + \mathrm{chunk\_max}_1}
$$

Regrouping the exponent and applying the product rule above, we get:
$$
e^{\mathrm{dot\_products\_chunk}_1 - \mathrm{global\_max}} = e^{\mathrm{dot\_products\_chunk}_1 - \mathrm{chunk\_max}_1} \cdot e^{\mathrm{chunk\_max}_1 - \mathrm{global\_max}}
$$

Tada! To compute the softmax numerators in chunks, all we need to do is record each chunk's maximum, take the global maximum over them, and multiply each chunk's locally shifted exponentials by its correction factor $e^{\mathrm{chunk\_max}_i - \mathrm{global\_max}}$.
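The recipe can be checked numerically. This is a minimal sketch on a random 1-D vector of 8 scores split into two chunks with `torch.chunk`; the variable names follow the notation above, but the setup is our own.

```python
import torch

scores = torch.randn(8)
chunks = scores.chunk(2)  # two chunks of 4 scores

# Per chunk: local maximum and exponentials shifted by that local maximum.
chunk_maxes = [c.max() for c in chunks]
local_numerators = [torch.exp(c - m) for c, m in zip(chunks, chunk_maxes)]

# Global maximum = max over the chunk maxima.
global_max = torch.stack(chunk_maxes).max()

# Multiply each chunk by its correction factor exp(chunk_max - global_max).
corrected = torch.cat([
    n * torch.exp(m - global_max) for n, m in zip(local_numerators, chunk_maxes)
])

# Reference: exponentials shifted by the global maximum directly.
expected = torch.exp(scores - global_max)
print(torch.allclose(corrected, expected))
```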

## Online Softmax Denominator
Okay, now that we've worked out how to compute the numerator in an online fashion, let's look at the denominator.

The denominator of the softmax function is the sum of exponentials of all dot-products. We can accumulate it chunk by chunk, much like we tracked the maximum value seen so far:

$$
\begin{align*}
\text{denominator} &= \sum_{j=1}^{n} e^{x_j}
\end{align*}
$$

If we split the dot-products into chunks, we can compute the denominator for each chunk separately and then sum them up:

$$
\begin{align*}
\text{denominator\_chunk}_i &= \sum_{j \in \text{chunk}_i} e^{x_j} \\
\text{denominator} &= \sum_{i} \text{denominator\_chunk}_i
\end{align*}
$$

But each chunk only knows its own maximum, so what it can actually compute is:
$$
\text{denominator\_chunk}_i = \sum_{j \in \text{chunk}_i} e^{x_j - \text{chunk\_max}_i}
$$

The correct denominator for each chunk can be computed as:
$$
\text{corrected\_denominator\_chunk}_i = \sum_{j \in \text{chunk}_i} e^{x_j - \text{global\_max}} = \sum_{j \in \text{chunk}_i} e^{x_j - \text{chunk\_max}_i} \cdot e^{\text{chunk\_max}_i - \text{global\_max}}
$$

The correction factor is the same for all elements within a chunk, so we can factor it out of the sum:
$$
\text{corrected\_denominator\_chunk}_i = e^{\text{chunk\_max}_i - \text{global\_max}} \cdot \sum_{j \in \text{chunk}_i} e^{x_j - \text{chunk\_max}_i}
$$
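A quick numerical check of this factoring, as a sketch on a random 1-D score vector split into two chunks (sizes are arbitrary): rescaling each chunk's scalar sum once gives the same result as summing exponentials shifted by the global maximum.

```python
import torch

scores = torch.randn(8)
chunks = scores.chunk(2)
chunk_maxes = [c.max() for c in chunks]
global_max = torch.stack(chunk_maxes).max()

corrected_sums, direct_sums = [], []
for c, m in zip(chunks, chunk_maxes):
    # What a chunk can compute on its own: exp-sum shifted by its local max.
    local_sum = torch.exp(c - m).sum()
    # One scalar correction per chunk instead of one per element.
    corrected_sums.append(torch.exp(m - global_max) * local_sum)
    # Reference: exponentials shifted by the global max, summed directly.
    direct_sums.append(torch.exp(c - global_max).sum())

print(torch.stack(corrected_sums), torch.stack(direct_sums))
```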

The overall denominator can then be computed by summing up the corrected denominators from each chunk:
$$
\text{denominator} = \sum_{i} \text{corrected\_denominator\_chunk}_i = \sum_{i} e^{\text{chunk\_max}_i - \text{global\_max}} \cdot \sum_{j \in \text{chunk}_i} e^{x_j - \text{chunk\_max}_i}
$$

We can store the sum of exponentials for each chunk as we compute them, and then apply the correction factor based on the global maximum when we compute the final denominator.
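Putting the numerator and denominator together gives a complete two-pass chunked softmax. This is a sketch for a 1-D score vector; `chunked_softmax` and `num_chunks` are hypothetical names of our own, and a real kernel would process each chunk separately rather than holding them all in a Python list.

```python
import torch

def chunked_softmax(scores: torch.Tensor, num_chunks: int) -> torch.Tensor:
    """Two-pass chunked softmax over a 1-D tensor (illustrative sketch)."""
    chunks = scores.chunk(num_chunks)
    # Pass 1: each chunk stores only (local max, exp-sum shifted by local max).
    stats = [(c.max(), torch.exp(c - c.max()).sum()) for c in chunks]
    global_max = torch.stack([m for m, _ in stats]).max()
    # Combine the per-chunk sums with their correction factors.
    denominator = sum(torch.exp(m - global_max) * s for m, s in stats)
    # Pass 2: numerators against the global max, divided by the shared denominator.
    return torch.cat([torch.exp(c - global_max) for c in chunks]) / denominator

scores = torch.randn(10)
probs = chunked_softmax(scores, num_chunks=3)  # chunks of 4, 4 and 2 scores
print(torch.allclose(probs, torch.nn.functional.softmax(scores, dim=-1)))
```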

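```python
# Note: Tensor.chunk takes the *number* of chunks, not a chunk size.
num_chunks = 2
query_chunks = query.chunk(num_chunks, dim=0)  # a 1-row query stays in one chunk
key_chunks = keys.chunk(num_chunks, dim=0)     # 8 keys -> 2 chunks of 4 keys
```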
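`Tensor.chunk` asks for a number of pieces and returns fewer when the dimension is too small to split. A quick check with the same shapes as this notebook (a 1×8 query and 8×8 keys, freshly drawn here so the snippet runs on its own):

```python
import torch

query = torch.randn(1, 8)
keys = torch.randn(8, 8)

# chunk(2, dim=0) asks for 2 pieces along the rows.
query_chunks = query.chunk(2, dim=0)
key_chunks = keys.chunk(2, dim=0)

# A single row cannot be split, so the query stays in one piece.
print([tuple(q.shape) for q in query_chunks])  # [(1, 8)]
print([tuple(k.shape) for k in key_chunks])    # [(4, 8), (4, 8)]
```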

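```python
maxs = []                 # per-chunk maximum score
dot_products_chunks = []  # per-chunk scores, shifted by the chunk maximum
dot_products_sum = []     # per-chunk sum of exp(shifted scores)
# This loop assumes a single query row: torch.max(..., dim=-1) drops the
# last dimension and .sum() adds over every element of the chunk.
for q in query_chunks:
    for k in key_chunks:
        dot_products_chunks.append(q @ k.T)  # [q_S, k_S]
        maxs.append(torch.max(dot_products_chunks[-1], dim=-1)[0])
        dot_products_chunks[-1] -= maxs[-1]
        dot_products_sum.append(torch.exp(dot_products_chunks[-1]).sum())
```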
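To handle several query rows at once, the per-chunk maxima and sums have to stay per row. Here is a sketch of that generalization (our own variant, with arbitrary shapes: 4 queries and 8 keys of dimension 8) that keeps the reduced dimension with `keepdim=True` and sums along the key dimension only.

```python
import torch

queries = torch.randn(4, 8)
keys = torch.randn(8, 8)

chunk_maxes, chunk_scores, chunk_sums = [], [], []
for k in keys.chunk(2, dim=0):
    scores = queries @ k.T                           # [num_queries, chunk_len]
    m = scores.max(dim=-1, keepdim=True).values      # [num_queries, 1]
    chunk_maxes.append(m)
    chunk_scores.append(scores - m)                  # shift each row by its own max
    chunk_sums.append(torch.exp(scores - m).sum(dim=-1, keepdim=True))  # per-row sums

# Per-row global max and per-chunk correction factors.
global_max = torch.cat(chunk_maxes, dim=-1).max(dim=-1, keepdim=True).values
corrections = [torch.exp(m - global_max) for m in chunk_maxes]

denominator = sum(s * c for s, c in zip(chunk_sums, corrections))  # [num_queries, 1]
probs = torch.cat(
    [torch.exp(s) * c for s, c in zip(chunk_scores, corrections)], dim=-1
) / denominator

print(torch.allclose(probs, torch.nn.functional.softmax(queries @ keys.T, dim=-1), atol=1e-6))
```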

```python
maxs
```

```
[tensor([1.6802]), tensor([0.8117])]
```

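```python
# The global max is the largest of the per-chunk maxima.
global_max = torch.max(torch.cat(maxs), dim=-1, keepdim=True)[0]
# Correction exponent per chunk: chunk_max - global_max (always <= 0).
correction_factor = [local_max - global_max for local_max in maxs]
```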

```python
dot_products_chunks
```

```
[tensor([[-4.9023, -2.5188,  0.0000, -1.4852]]),
 tensor([[-0.1187, -0.7713,  0.0000, -0.5620]])]
```

```python
correction_factor
```

```
[tensor([0.]), tensor([-0.8685])]
```

```python
dot_products_sum
```

```
[tensor(1.3144), tensor(2.9205)]
```

```python
# Denominator = sum over chunks of exp(correction exponent) * local exp-sum.
denominator = sum([
    dot_products_sum[i] * torch.exp(correction_factor[i])
    for i in range(len(dot_products_chunks))
])
```

```python
# denominator has shape [1]; Python's sum() unwraps it to a scalar for display.
sum(denominator)
```

```
tensor(2.5398)
```

```python
# Re-apply each chunk's correction to its shifted scores, then normalize.
online_softmax = [torch.exp(dp + cf) / denominator for dp, cf in zip(dot_products_chunks, correction_factor)]
```

```python
online_softmax
```

```
[tensor([[0.0029, 0.0317, 0.3937, 0.0892]]),
 tensor([[0.1467, 0.0764, 0.1652, 0.0942]])]
```

```python
assert torch.allclose(torch.cat(online_softmax, dim=-1), sofmax)
```
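The code above is two-pass: it stores every chunk's maximum and only then computes the global maximum. Because the correction factor depends only on the difference between two maxima, the same idea also works in a single streaming pass: keep a running maximum and a running sum, and rescale the running sum whenever the maximum grows. The sketch below is our own illustration for a 1-D score vector (`online_softmax_1d` and `num_chunks` are hypothetical names); the final normalization still reads the scores a second time.

```python
import torch

def online_softmax_1d(scores: torch.Tensor, num_chunks: int) -> torch.Tensor:
    """Streaming softmax over a 1-D tensor using a running max and running sum."""
    running_max = torch.tensor(float("-inf"))
    running_sum = torch.tensor(0.0)
    for c in scores.chunk(num_chunks):
        new_max = torch.maximum(running_max, c.max())
        # Rescale the old sum from the old max to the new max
        # (exp(-inf) = 0 on the first chunk), then add this chunk's terms.
        running_sum = (
            running_sum * torch.exp(running_max - new_max)
            + torch.exp(c - new_max).sum()
        )
        running_max = new_max
    # Second pass: normalize with the final max and sum.
    return torch.exp(scores - running_max) / running_sum

scores = torch.randn(10)
probs = online_softmax_1d(scores, num_chunks=4)  # chunks of 3, 3, 3 and 1 scores
print(torch.allclose(probs, torch.nn.functional.softmax(scores, dim=-1)))
```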