<a href="https://colab.research.google.com/github/shahabday/DSR-LLMQuantization/blob/main/BONUS_Flash_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F

## Logits and Softmax

$$
\Large
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^Ce^{z_j}}
$$

In [None]:
# Toy logits: a single row of six scores; .float() ensures a float32 tensor.
x = torch.tensor([[1, 2, 3, 4.5, 1.8, 0]]).float()
# Reference result: PyTorch's built-in softmax along dim 1 (across the row).
F.softmax(x, dim=1)

tensor([[0.0214, 0.0581, 0.1578, 0.7074, 0.0475, 0.0079]])

In [None]:
def naive_softmax(x):
    """Row-wise softmax computed directly from the definition.

    Exponentiates the raw logits without any shift, so large inputs overflow
    to ``inf`` (which is what the "safe" variant below avoids).

    Args:
        x: tensor of logits; the normalization runs along dim 1.

    Returns:
        ``exp(x) / sum(exp(x))`` along dim 1, same shape as ``x``.
    """
    numerator = x.exp()
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator

In [None]:
# Should reproduce F.softmax(x, dim=1) for these small logits.
naive_softmax(x)

tensor([[0.0214, 0.0581, 0.1578, 0.7074, 0.0475, 0.0079]])

### Safe Softmax

$$
\Large
m = \text{max}(x)
\\
\Large
\text{expx}_i = e^{x_i-m}
\\
\Large
\text{sumexp} = \sum_{j=1}^Ce^{x_j-m} = \sum_{j=1}^C\text{expx}_j
\\
\Large
\text{softmax}(x_i) = \frac{\text{expx}_i}{\text{sumexp}}
$$

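Subtracting each row's maximum before exponentiating keeps every exponent at or below 0, so `exp` can never overflow. The shift cancels out in the ratio, so the result is unchanged.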

In [None]:
def safe_softmax(x, memory=False):
    """Row-wise softmax shifted by each row's maximum to avoid overflow.

    Args:
        x: tensor of logits; max/sum reductions run along dim 1
            (used with 2-D ``(rows, cols)`` tensors in this notebook).
        memory: when true, return the intermediate pieces instead of the
            normalized probabilities.

    Returns:
        ``exp(x - m) / sum`` when ``memory`` is false. Otherwise the tuple
        ``(exp(x - m), m, sum)``, where ``m`` is the per-row maximum and
        ``sum`` the per-row total of the shifted exponentials (dim 1 kept).
    """
    row_max = torch.max(x, dim=1, keepdim=True).values
    shifted = torch.exp(x - row_max)  # every exponent is <= 0: no overflow
    row_sum = shifted.sum(dim=1, keepdim=True)
    if not memory:
        return shifted / row_sum
    return shifted, row_max, row_sum

In [None]:
# Same probabilities as before: the max shift cancels in the ratio.
safe_softmax(x)

tensor([[0.0214, 0.0581, 0.1578, 0.7074, 0.0475, 0.0079]])

With `memory=True`, the function returns the "ingredients" instead of the probabilities: the shifted exponentials, the row maximum, and their sum (in that order).

In [None]:
# Retrieve the ingredients, then rebuild the softmax from them to show that
# exponentials / sum gives the same probabilities.
expx, m, expxsum = safe_softmax(x, memory=True)
expx, m, expxsum, expx/expxsum

(tensor([[0.0302, 0.0821, 0.2231, 1.0000, 0.0672, 0.0111]]),
 tensor([[4.5000]]),
 tensor([[1.4137]]),
 tensor([[0.0214, 0.0581, 0.1578, 0.7074, 0.0475, 0.0079]]))

### Online Softmax

We can use these ingredients to implement the online softmax. The input is processed one chunk (batch) of columns at a time, and the running maximum and sum are updated as each new chunk arrives:

In [None]:
def online_softmax(x, mi=None, si=None):
    """Fold one chunk of logits into a running (max, sum-of-exponentials) pair.

    Args:
        x: chunk of logits; reductions run along dim 1.
        mi: running max returned by the previous call (``None`` for the first
            chunk).
        si: running sum returned by the previous call, i.e. the sum of
            ``exp(x_j - mi)`` over earlier chunks (``None`` for the first chunk).

    Returns:
        ``(m, s)`` with dim 1 kept: the new running max, and the sum of
        ``exp(x_j - m)`` over this chunk plus the rescaled previous sum.
    """
    chunk_max = x.max(axis=1, keepdim=True).values
    m = chunk_max if mi is None else torch.maximum(chunk_max, mi)
    s = torch.exp(x - m).sum(axis=1, keepdim=True)
    if si is None:
        return m, s
    # si was accumulated relative to mi; re-express it relative to m.
    # m >= mi, so exp(mi - m) <= 1 (and exactly 1 when the max is unchanged).
    return m, s + si * torch.exp(mi - m)

$$
\Large
\text{sumexp} = \sum_{j=1}^Ce^{x_j-m} = \sum_{j=1}^C\frac{e^{x_j}}{e^{m}}
\\
\Large
s = \sum_{j\in\text{batch}}e^{x_j-m} + s_i\,e^{m_i-m}
\\
\Large
m > m_i \implies m_i - m < 0 \implies e^{m_i-m} < 1
$$

In [None]:
# Feed x in three column chunks ([:2], [2:4], [4:]), passing the running max
# and sum returned by each call into the next one.
m0, s0 = online_softmax(x[:, :2])
m1, s1 = online_softmax(x[:, 2:4], m0, s0)
m2, s2 = online_softmax(x[:, 4:], m1, s1)
# The final max m2 and sum s2 are enough to form the softmax of the full row.
torch.exp(x-m2), s2, torch.exp(x-m2)/s2

(tensor([[0.0302, 0.0821, 0.2231, 1.0000, 0.0672, 0.0111]]),
 tensor([[1.4137]]),
 tensor([[0.0214, 0.0581, 0.1578, 0.7074, 0.0475, 0.0079]]))

### Tiled Softmax

We can split the input tensor into tiles.

In [None]:
# Split x along dim 1 into tiles of (at most) 2 columns each.
tiles = torch.split(x, 2, 1)
tiles

(tensor([[1., 2.]]), tensor([[3.0000, 4.5000]]), tensor([[1.8000, 0.0000]]))

Then, we compute the "ingredients" for every tile:

In [None]:
# Safe-softmax ingredients per tile (exponentials, max, sum), each computed
# relative to that tile's own maximum. Rebinds m0/s0 from the online demo.
f0, m0, s0 = safe_softmax(tiles[0], True)
f1, m1, s1 = safe_softmax(tiles[1], True)
f2, m2, s2 = safe_softmax(tiles[2], True)

In [None]:
# Place the per-tile ingredients side by side (concatenated along dim 1).
fs = torch.hstack([f0, f1, f2])
ms = torch.hstack([m0, m1, m2])
ss = torch.hstack([s0, s1, s2])

In [None]:
# Per-tile maxima, shifted exponentials, and per-tile sums.
ms, fs, ss

(tensor([[2.0000, 4.5000, 1.8000]]),
 tensor([[0.3679, 1.0000, 0.2231, 1.0000, 1.0000, 0.1653]]),
 tensor([[1.3679, 1.2231, 1.1653]]))

We take the maximum of the per-tile maximums and use it to compute each tile's rescaling factor:

$$
\Large
\text{factors} = \left(e^{m_0-\text{max}(m_0,m_1,m_2)},e^{m_1-\text{max}(m_0,m_1,m_2)},e^{m_2-\text{max}(m_0,m_1,m_2)}\right)
$$

In [None]:
# Rescaling factor per tile: exp(tile max - row max). The max is taken along
# the tile axis (per row) instead of over the whole tensor, so the factors stay
# correct if x has more than one row; for the single row here it is identical.
factors = torch.exp(ms - ms.max(dim=1, keepdim=True).values)
factors

tensor([[0.0821, 1.0000, 0.0672]])

In [None]:
# Per-tile sums, their rescaling factors, and the rescaled sums.
ss, factors, factors*ss

(tensor([[1.3679, 1.2231, 1.1653]]),
 tensor([[0.0821, 1.0000, 0.0672]]),
 tensor([[0.1123, 1.2231, 0.0783]]))

Rescaling each tile's sum by its factor and adding the results gives the softmax denominator:

$$
\Large
\text{denom} = \text{sumexp}_0\text{factor}_0 + \text{sumexp}_1\text{factor}_1 + \text{sumexp}_2\text{factor}_2
$$

In [None]:
# .sum() with no dim adds every element. This equals the row total here only
# because x has a single row; several rows would need .sum(dim=1, keepdim=True).
denom = (factors*ss).sum()
denom

tensor(1.4137)

The exponentials (numerators) must be rescaled by the same factors, with each tile's factor repeated once for every element in that tile:

In [None]:
# Repeat each tile's factor once per element of that tile. Using every tile's
# actual width (instead of assuming all tiles are as wide as the first) keeps
# this aligned with fs when torch.split leaves a narrower last tile.
tile_sizes = torch.tensor([tile.shape[-1] for tile in tiles])
tiled_factors = factors.repeat_interleave(tile_sizes, dim=1)
tiled_factors, fs

(tensor([[0.0821, 0.0821, 1.0000, 1.0000, 0.0672, 0.0672]]),
 tensor([[0.3679, 1.0000, 0.2231, 1.0000, 1.0000, 0.1653]]))

In [None]:
# Move each exponential from its tile's max to the overall max: the result
# equals exp(x - max(x)).
numer = tiled_factors*fs
numer

tensor([[0.0302, 0.0821, 0.2231, 1.0000, 0.0672, 0.0111]])

Here is the resulting softmax:

In [None]:
# Normalize; matches F.softmax(x, dim=1) from the first cell.
numer/denom

tensor([[0.0214, 0.0581, 0.1578, 0.7074, 0.0475, 0.0079]])

## Regular Attention

$$
\Large
\text{attention}=\text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right)\cdot V
$$

In [None]:
import numpy as np

def naive_softmax(x):
    """NumPy row-wise softmax straight from the definition (no max shift).

    Args:
        x: array of logits; normalization runs along axis 1.

    Returns:
        ``exp(x) / sum(exp(x))`` along axis 1, same shape as ``x``.
    """
    numerator = np.exp(x)
    denominator = np.sum(numerator, axis=1, keepdims=True)
    return numerator / denominator

N, d = 4, 8  # Example dimensions - number of query/key/value rows (tokens), head dimension

# Fixed seed so the printed values in this notebook are reproducible; the draw
# order below (Q, then K, then V) matters for that.
np.random.seed(35)
Q = np.random.rand(N, d)  # Example matrix Q, shape (N, d)
K = np.random.rand(N, d)  # Example matrix K, shape (N, d)
V = np.random.rand(N, d)  # Example matrix V, shape (N, d)

In [None]:
# Here P holds the raw scores Q K^T and S the softmax probabilities; note that
# accumulate_sync further down uses the opposite names (S = scores, P = exps).
P = np.matmul(Q, K.T)
S = naive_softmax(P/np.sqrt(d))  # scale by sqrt(d), then row-wise softmax
att_out = np.matmul(S, V)
att_out

array([[0.30685401, 0.48892522, 0.51084285, 0.7089212 , 0.499922  ,
        0.44842117, 0.39609549, 0.33958129],
       [0.30520985, 0.46706403, 0.51787088, 0.7027851 , 0.50071816,
        0.46003932, 0.3905262 , 0.34611217],
       [0.29940654, 0.45656683, 0.51468743, 0.70210537, 0.4998878 ,
        0.4630266 , 0.39324783, 0.35211962],
       [0.30492672, 0.48038083, 0.51366597, 0.71231235, 0.50048432,
        0.45614122, 0.39000782, 0.33817095]])

In [None]:
P.shape, S.shape, att_out.shape # 16 + 16 + 32 = 64 elements materialized in total

((4, 4), (4, 4), (4, 8))

## Flash Attention

Flash Attention starts by splitting Q into `Tr` row blocks and K and V into `Tc` row blocks each:

In [None]:
# Tr: number of row blocks for Q; Tc: number of row blocks for K and V.
Tr = 4
Tc = 2

# np.array_split splits along axis 0 and allows pieces of unequal size.
Q_blocks = np.array_split(Q, Tr)
K_blocks = np.array_split(K, Tc)
V_blocks = np.array_split(V, Tc)  # split exactly like K so K/V blocks pair up

In [None]:
# Number of blocks actually produced for each matrix.
n_Q = len(Q_blocks)
n_K = len(K_blocks)
n_V = len(V_blocks)
n_Q, n_K, n_V

(4, 2, 2)

In [None]:
# Q next to its Tr row blocks.
Q, Q_blocks

(array([[0.45805495, 0.30834961, 0.23148705, 0.27742455, 0.81723481,
         0.11134664, 0.62643723, 0.27678789],
        [0.68217467, 0.67897078, 0.79671742, 0.04580216, 0.91259827,
         0.21381599, 0.3036373 , 0.98906362],
        [0.1858815 , 0.98872484, 0.75008423, 0.22238605, 0.14790391,
         0.51579028, 0.39425832, 0.06988013],
        [0.33822577, 0.01103722, 0.76752786, 0.87472213, 0.53359432,
         0.08441275, 0.8243312 , 0.5045812 ]]),
 [array([[0.45805495, 0.30834961, 0.23148705, 0.27742455, 0.81723481,
          0.11134664, 0.62643723, 0.27678789]]),
  array([[0.68217467, 0.67897078, 0.79671742, 0.04580216, 0.91259827,
          0.21381599, 0.3036373 , 0.98906362]]),
  array([[0.1858815 , 0.98872484, 0.75008423, 0.22238605, 0.14790391,
          0.51579028, 0.39425832, 0.06988013]]),
  array([[0.33822577, 0.01103722, 0.76752786, 0.87472213, 0.53359432,
          0.08441275, 0.8243312 , 0.5045812 ]])])

In [None]:
# K next to its Tc row blocks (V is split the same way).
K, K_blocks

(array([[0.88161863, 0.17404628, 0.40295789, 0.83212654, 0.97866247,
         0.61916477, 0.86992066, 0.2488769 ],
        [0.64303396, 0.30045066, 0.24536055, 0.54602368, 0.11976084,
         0.34309671, 0.63178697, 0.83155192],
        [0.35538789, 0.23541176, 0.80203533, 0.60371286, 0.49363014,
         0.93305116, 0.65311175, 0.67884942],
        [0.31165887, 0.12014239, 0.15491823, 0.76611197, 0.51250289,
         0.46160397, 0.75266263, 0.95110633]]),
 [array([[0.88161863, 0.17404628, 0.40295789, 0.83212654, 0.97866247,
          0.61916477, 0.86992066, 0.2488769 ],
         [0.64303396, 0.30045066, 0.24536055, 0.54602368, 0.11976084,
          0.34309671, 0.63178697, 0.83155192]]),
  array([[0.35538789, 0.23541176, 0.80203533, 0.60371286, 0.49363014,
          0.93305116, 0.65311175, 0.67884942],
         [0.31165887, 0.12014239, 0.15491823, 0.76611197, 0.51250289,
          0.46160397, 0.75266263, 0.95110633]])])

For every Q block, the accumulators are reset. The matching K/V block pairs (K block k with V block k) are then visited one at a time, updating the attention output incrementally with the online-softmax rescaling:

In [None]:
# Per-Q-block accumulators read and written (as globals) by accumulate_sync:
# running row max, running softmax denominator, normalized output so far.
# The max starts at -inf (the identity for max) so the first block's max wins;
# scaling_acc is initialized too, so the state is complete before any call.
max_acc = np.full((len(Q_blocks[0]), 1), -np.inf)
scaling_acc = np.zeros((len(Q_blocks[0]), 1))
output_acc = np.zeros_like(Q_blocks[0])

print(max_acc.shape, output_acc.shape) # max_acc (1, 1) + output_acc (1, 8) = 9 values

def accumulate_sync(Q, K, V, restart=False):
    """Fold one (K, V) block into the running attention output for a Q block.

    Online-softmax update as used by Flash Attention. The running row maximum
    (``max_acc``), the running softmax denominator (``scaling_acc``) and the
    already-normalized output (``output_acc``) live in module-level globals.
    They are rescaled whenever a block brings a larger score.

    Args:
        Q: query block of shape (rows, d).
        K: key block of shape (cols, d).
        V: value block of shape (cols, d_v), row-aligned with ``K``.
        restart: reset the accumulators first; pass ``True`` for the first
            K/V block of every Q block.

    Returns:
        The updated ``output_acc`` of shape (rows, d_v). This is the attention
        output for ``Q`` over all K/V blocks folded in since the last restart.
    """
    global output_acc, scaling_acc, max_acc

    if restart:
        # Size the state from this call's blocks (not a global block), so
        # uneven np.array_split pieces and d_v != d both work.
        n_rows = Q.shape[0]
        output_acc = np.zeros((n_rows, V.shape[1]))
        scaling_acc = np.zeros((n_rows, 1))
        # -inf is the identity for max(). Starting at 0 would make
        # exp(smax - max_new) underflow to 0 (and the output 0/0 = NaN)
        # whenever every score in a row is very negative.
        max_acc = np.full((n_rows, 1), -np.inf)

    # Scaled dot-product scores for this block; d_k is the query width.
    S = np.matmul(Q, K.T)/np.sqrt(Q.shape[-1])

    # Safe-softmax pieces of this block: local max, exponentials, row sums.
    smax = S.max(axis=1, keepdims=True)
    P = np.exp(S-smax)
    s_new = P.sum(axis=1, keepdims=True)

    # New running max, plus the factors that move the old statistics and this
    # block's statistics onto that common reference.
    max_new = np.maximum(smax, max_acc)
    old_factor = np.exp(max_acc - max_new)
    block_factor = np.exp(smax - max_new)

    scaling_new = scaling_acc*old_factor + block_factor*s_new

    # output_acc is stored normalized: un-normalize it with scaling_acc, add
    # this block's contribution, then renormalize by the new denominator.
    output_acc = (output_acc*scaling_acc*old_factor +
                  block_factor*np.matmul(P, V))/scaling_new

    scaling_acc = scaling_new
    max_acc = max_new

    print(S.shape, smax.shape, P.shape, s_new.shape, max_new.shape, scaling_new.shape)
    return output_acc

(1, 1) (1, 8)


In [None]:
# Each Q block is handled independently: the first K/V pair resets the
# accumulators, and the remaining pairs are folded in one after another.
for q_block in Q_blocks:
    for k in range(n_K):
        accumulate_sync(q_block, K_blocks[k], V_blocks[k], restart=(k == 0))
    print(output_acc)

(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
[[0.30685401 0.48892522 0.51084285 0.7089212  0.499922   0.44842117
  0.39609549 0.33958129]]
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
[[0.30520985 0.46706403 0.51787088 0.7027851  0.50071816 0.46003932
  0.3905262  0.34611217]]
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
[[0.29940654 0.45656683 0.51468743 0.70210537 0.4998878  0.4630266
  0.39324783 0.35211962]]
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
(1, 2) (1, 1) (1, 2) (1, 1) (1, 1) (1, 1)
[[0.30492672 0.48038083 0.51366597 0.71231235 0.50048432 0.45614122
  0.39000782 0.33817095]]
