### Conventional Softmax

#### Pseudocode

\section*{Pseudocode}

1. Initialize \( M_0 = -\infty \)
2. For \( i = 1 \) to \( N \):
    \[
        M_i = \max(M_{i-1}, X_i)
    \]
3. Initialize \( L_0 = 0 \)
4. For \( j = 1 \) to \( N \):
    \[
        L_j = L_{j-1} + e^{X_j - M_N}
    \]
5. For \( k = 1 \) to \( N \):
    \[
        X_k \gets \frac{e^{X_k - M_N}}{L_N}
    \]

In [12]:
# Conventional softmax: build a random sample row to work with
import torch

# One row of ten whole-number entries drawn from [0, 10), stored as float32
tensor = torch.randint(low=0, high=10, size=(1, 10)).to(torch.float32)
tensor

tensor([[2., 6., 5., 0., 2., 8., 0., 0., 4., 1.]])

In [2]:
# Running maximum of the row, printed after every element
m = float("-inf")
for x in tensor[0]:
    value = x.item()
    if value > m:
        m = value
    print(m)

6.0
9.0
9.0
9.0
9.0
9.0
9.0
9.0
9.0
9.0


In [3]:
# Computing the normalization factor: running sum of exp(x - m), printed per step
l = 0
for x in tensor[0]:
    shifted_exp = (x - m).exp()
    l = l + shifted_exp.item()
    print(l)

0.049787066876888275
1.0497870668768883
1.0501225295010954
1.0510344114736654
2.0510344114736654
2.0577723584719934
2.1075594253488816
2.15734649222577
2.159825244511012
2.162303996796254


In [4]:
# Softmax of each element: exp(x - m) divided by the normalization factor l
softmax_row = [((x - m).exp() / l).item() for x in tensor[0]]
# Wrap the row in an outer list so the result is nested
result = [softmax_row]

In [5]:
# Display the nested list of softmax values
result

[[0.02302500791847706,
  0.4624696671962738,
  0.00015514128608629107,
  0.0004217177629470825,
  0.4624696671962738,
  0.0031160961370915174,
  0.02302500791847706,
  0.02302500791847706,
  0.0011463477276265621,
  0.0011463477276265621]]

In [6]:
# Consolidated Function 
from typing import List, Optional, Union, Tuple
import torch
from typing import List

def softmax_row(tensor: torch.Tensor) -> List[List[float]]:
    """
    Three-pass softmax over the first row of ``tensor``.

    The row maximum is subtracted from every element before exponentiating,
    then each exponential is divided by the sum of all of them.

    Args:
        tensor (torch.Tensor): Input tensor; only ``tensor[0]`` is read.

    Returns:
        List[List[float]]: A one-element list holding the softmax values of
            that row as Python floats.
    """
    row = tensor[0]

    # Pass 1: row maximum, tracked as a plain Python float
    peak = float('-inf')
    for entry in row:
        candidate = entry.item()
        if candidate > peak:
            peak = candidate

    # Pass 2: denominator, accumulated element by element in row order
    denominator = 0
    for entry in row:
        denominator = denominator + torch.exp(entry - peak).item()

    # Pass 3: normalize each shifted exponential
    normalized = []
    for entry in row:
        normalized.append((torch.exp(entry - peak) / denominator).item())

    return [normalized]

# Example usage: run softmax_row on a fresh random row and show input and output
tensor = torch.randint(low=0, high=10, size=(1, 10)).to(torch.float32)
softmax_result = softmax_row(tensor)
print("Input Tensor:", tensor)
print("Softmax Result:", softmax_result)


Input Tensor: tensor([[9., 2., 9., 5., 0., 4., 6., 7., 3., 0.]])
Softmax Result: [[0.4517092704772949, 0.00041190555202774704, 0.4517092704772949, 0.008273344486951828, 5.5745353165548295e-05, 0.0030435931403189898, 0.022489279508590698, 0.06113220006227493, 0.0011196753475815058, 5.5745353165548295e-05]]


In [7]:
# Fixed test row built from integer literals (no .float() conversion here)
t1 = torch.tensor([[1, 2, 3, 4, 1, 2, 3]])
softmax_row(t1)

[[0.023640543222427368,
  0.06426165997982025,
  0.17468130588531494,
  0.47483301162719727,
  0.023640543222427368,
  0.06426165997982025,
  0.17468130588531494]]

## Safe Softmax (Online Normalizer)
#### Pseudocode

1. Initialize \( m_0 = -\infty \), \( l_0 = 0 \)
2. For \( i = 1 \) to \( N \):
    - Compute \( m_i = \max(m_{i-1}, X_i) \)
    - Compute \( l_i = l_{i-1} \cdot e^{m_{i-1} - m_i} + e^{X_i - m_i} \)
3. For \( k = 1 \) to \( N \):
    - Compute \( X_k \gets \frac{e^{X_k - m_N}}{l_N} \)


In [8]:
# Fresh random row for the single-pass (running max / running normalizer) variant
import torch

# One row of ten whole-number entries drawn from [0, 10), stored as float32
tensor = torch.randint(low=0, high=10, size=(1, 10)).to(torch.float32)
tensor

tensor([[5., 2., 0., 6., 0., 8., 6., 4., 7., 6.]])

In [9]:
# Single pass over the row: track the running maximum and the running
# normalizer, rescaling the normalizer to the new maximum at every step
m_prev = float("-inf")
l_prev = 0
results = []
for i in tensor[0]:
    m_curr = max(m_prev, i)
    rescale = torch.exp(m_prev - m_curr).item()
    l_curr = l_prev * rescale + torch.exp(i - m_curr).item()
    m_prev, l_prev = m_curr, l_curr

# Normalize every element with the final maximum and the final normalizer
softmax_row = [torch.exp(x - m_prev).item() / l_prev for x in tensor[0]]
results.append(softmax_row)
results

[[0.026982846539221398,
  0.0013433969244828043,
  0.00018180899330761034,
  0.0733469826782807,
  0.00018180899330761034,
  0.5419649766864885,
  0.0733469826782807,
  0.009926435051607186,
  0.199377777716766,
  0.0733469826782807]]

In [10]:
def softmax_new(tensor: torch.Tensor) -> List[List[float]]:
    """
    Computes the softmax of the first row of ``tensor`` in a single pass over
    the data for the statistics: the running maximum and the running
    normalizer are updated together, and the normalizer is rescaled whenever
    the maximum changes.

    Args:
        tensor (torch.Tensor): Input tensor; only ``tensor[0]`` is read.
            Non-floating inputs are converted to the default floating dtype.

    Returns:
        List[List[float]]: A one-element list holding the softmax values of
            the row as Python floats. Entries equal to ``-inf`` map to 0.0
            when the row also has a finite entry. If the row has no value
            above ``-inf`` every entry is NaN, and an empty row yields
            ``[[]]``.
    """
    row = tensor[0]
    if not row.is_floating_point():
        row = row.to(torch.get_default_dtype())

    neg_inf = float("-inf")
    # Keep the running maximum as a tensor throughout; mixing a Python float
    # with tensors here made torch.exp fail on a leading -inf entry.
    m_prev = torch.tensor(neg_inf, dtype=row.dtype, device=row.device)
    l_prev = 0.0
    for x in row:
        m_curr = torch.maximum(m_prev, x)
        if m_curr.item() == neg_inf:
            # Only -inf seen so far: exp(-inf - (-inf)) would be NaN, and such
            # entries contribute nothing to the normalizer.
            continue
        l_prev = (
            l_prev * torch.exp(m_prev - m_curr).item()
            + torch.exp(x - m_curr).item()
        )
        m_prev = m_curr

    if l_prev == 0.0:
        # Empty row or nothing above -inf: the softmax is undefined, so
        # return NaN per entry instead of dividing by zero.
        return [[float("nan") for _ in row]]

    return [[torch.exp(x - m_prev).item() / l_prev for x in row]]

In [None]:
# Run the single-pass version on the same fixed row t1 that was passed to softmax_row
softmax_new(t1)

[[0.02364054202726851,
  0.06426165690335149,
  0.17468130082440936,
  0.47483299399271744,
  0.02364054202726851,
  0.06426165690335149,
  0.17468130082440936]]

In [None]:
import torch

# Problem sizes used by the cells that follow
BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM = 8, 10, 12, 128

# A 1x2 tensor holding [SEQ_LEN, BATCH_SIZE * NUM_HEADS] as values, and a
# zero tensor with the same shape and dtype
a1 = torch.tensor([[SEQ_LEN, BATCH_SIZE * NUM_HEADS]])
grid = torch.zeros(a1.shape, dtype=a1.dtype)
grid.shape

torch.Size([1, 2])

In [None]:
# Random float16 Q, K and V tensors, all with the same 4-D shape
qkv_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
Q = torch.randn(qkv_shape, dtype=torch.float16)
K = torch.randn(qkv_shape, dtype=torch.float16)
V = torch.randn(qkv_shape, dtype=torch.float16)
Q.shape

torch.Size([8, 12, 10, 128])

In [None]:
def grid(args):
    """Return the 3-tuple (SEQ_LEN / BLOCK_SIZE_Q rounded up, BATCH_SIZE * NUM_HEADS, 1)."""
    block_size_q = args["BLOCK_SIZE_Q"]
    # Adding block_size_q - 1 before the floor division rounds the quotient up
    num_q_blocks = (SEQ_LEN + block_size_q - 1) // block_size_q
    return (num_q_blocks, BATCH_SIZE * NUM_HEADS, 1)


# Evaluate the grid for a block size of 4 and materialize a zero tensor of that shape
args = {"BLOCK_SIZE_Q": 4}
grid_shape = grid(args)
grid_tensor = torch.zeros(grid_shape)
print(grid_tensor)

tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [

In [None]:
# First element of the grid tuple: SEQ_LEN divided by BLOCK_SIZE_Q, rounded up
# NOTE(review): despite the name this is a count of blocks, not one block's index — confirm intent
block_index_q = grid(args)[0]
block_index_q

# Consecutive integers 0 .. BATCH_SIZE * NUM_HEADS - 1
index_batch_head = torch.arange(BATCH_SIZE * NUM_HEADS)
index_batch_head

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        90, 91, 92, 93, 94, 95])

In [None]:
# Values 0, BLOCK_SIZE_Q, 2 * BLOCK_SIZE_Q, ... that are below SEQ_LEN
qkv_offset = torch.arange(0, SEQ_LEN, args["BLOCK_SIZE_Q"])