### Conventional Softmax

#### Pseudocode

\section*{Pseudocode}

1. Initialize \( M_0 = -\infty \)
2. For \( i = 1 \) to \( N \):
    \[
        M_i = \max(M_{i-1}, X_i)
    \]
3. Initialize \( L_0 = 0 \)
4. For \( j = 1 \) to \( N \):
    \[
        L_j = L_{j-1} + e^{X_j - M_N}
    \]
5. For \( k = 1 \) to \( N \):
    \[
        X_k \gets \frac{e^{X_k - M_N}}{L_N}
    \]

In [8]:
# Conventional (3-pass) softmax: build a random 1 x 10 row of integer-valued floats.
import torch

tensor = torch.randint(low=0, high=10, size=(1, 10)).float()
tensor

tensor([[8., 5., 1., 5., 6., 8., 3., 3., 5., 8.]])

In [10]:
# Pass 1 of 3: running maximum of the row, printed after every element.
m = float("-inf")
for x in tensor[0]:
    candidate = x.item()
    if candidate > m:
        m = candidate
    print(m)

8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0


In [11]:
# Pass 2 of 3: accumulate the normalization factor sum(exp(x - m)),
# printing the partial sum after every element.
l = 0
for x in tensor[0]:
    term = torch.exp(x - m).item()
    l = l + term
    print(l)

1.0
1.0497870668768883
1.0506989488494582
1.1004860157263465
1.2358212972176261
2.235821297217626
2.242559244215954
2.249297191214282
2.2990842580911703
3.2990842580911703


In [12]:
# Pass 3 of 3: divide every exp(x - m) by the normalization factor l.
# The values are kept under a name other than `softmax_row`, so that
# re-running this cell never overwrites the `softmax_row` function that the
# notebook defines later.
softmax_values = [(torch.exp(x - m) / l).item() for x in tensor[0]]
result = [softmax_values]

In [13]:
# Display the collected softmax values (a list holding one list per row).
result

[[0.3031144142150879,
  0.015091178007423878,
  0.0002764045784715563,
  0.015091178007423878,
  0.041022077202796936,
  0.3031144142150879,
  0.0020423689857125282,
  0.0020423689857125282,
  0.015091178007423878,
  0.3031144142150879]]

In [43]:
# Consolidated Function 
from typing import List, Optional, Union, Tuple
import torch
from typing import List

def softmax_row(tensor: torch.Tensor) -> List[List[float]]:
    """
    Compute the softmax of every row of a tensor with the 3-pass algorithm.

    The work is done element by element on purpose, mirroring the
    conventional-softmax pseudocode:
    pass 1 finds the row maximum, pass 2 sums exp(x - max), pass 3 normalizes.

    Args:
        tensor (torch.Tensor): Input of shape (R, N). The typical input is a
            single row of shape (1, N). A 1-D tensor of shape (N,) is treated
            as one row. Integer tensors are accepted.

    Returns:
        List[List[float]]: One list of N softmax values per input row.
    """
    # Treat a plain vector as a single row instead of failing on tensor[0].
    rows = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor
    results: List[List[float]] = []

    for row in rows:
        # Step 1: row maximum. Subtracting it keeps exp() from overflowing.
        m = float("-inf")
        for x in row:
            m = max(m, x.item())

        # Step 2: normalization factor (the denominator).
        denom = 0.0
        for x in row:
            denom += torch.exp(x - m).item()

        # Step 3: normalize every element of this row.
        results.append([(torch.exp(x - m) / denom).item() for x in row])

    return results

# Demo on a fresh random 1 x 10 row: print the input next to its softmax.
random_row = torch.randint(0, 10, (1, 10))
tensor = random_row.float()
softmax_result = softmax_row(tensor)
for caption, shown in (("Input Tensor:", tensor), ("Softmax Result:", softmax_result)):
    print(caption, shown)


Input Tensor: tensor([[3., 1., 1., 6., 0., 5., 2., 7., 3., 9.]])
Softmax Result: [[0.002048383466899395, 0.0002772185252979398, 0.0002772185252979398, 0.04114287719130516, 0.0001019830015138723, 0.015135619789361954, 0.0007535581244155765, 0.11183793842792511, 0.002048383466899395, 0.8263767957687378]]


In [44]:
# Integer-valued test row; torch.tensor infers int64 from the Python ints.
t1 = torch.tensor([[1, 2, 3, 4, 1, 2, 3]])
softmax_row(t1)

[[0.023640543222427368,
  0.06426165997982025,
  0.17468130588531494,
  0.47483301162719727,
  0.023640543222427368,
  0.06426165997982025,
  0.17468130588531494]]

: 

## Safe Softmax
#### Pseudocode

1. Initialize \( m_0 = -\infty \), \( l_0 = 0 \)
2. For \( i = 1 \) to \( N \):
    - Compute \( m_i = \max(m_{i-1}, X_i) \)
    - Compute \( l_i = l_{i-1} \cdot e^{m_{i-1} - m_i} + e^{X_i - m_i} \)
3. For \( k = 1 \) to \( N \):
    - Compute \( X_k \gets \frac{e^{X_k - m_N}}{l_N} \)


In [27]:
# Online ("safe") softmax: start again from a fresh random 1 x 10 row.
import torch

tensor = torch.randint(0, 10, size=(1, 10), dtype=torch.int64).float()
tensor

tensor([[3., 9., 7., 1., 2., 5., 8., 2., 5., 6.]])

In [None]:
# Online (safe) softmax. A single pass over the row keeps a running maximum
# m and a running normalizer l. Whenever the maximum grows, the old l is
# rescaled by exp(m_prev - m_curr) so that it is relative to the new maximum.
m_prev = float(-torch.inf)
l_prev = 0
results = []
for x in tensor[0]:
    # max() of a float and a 0-d tensor returns whichever is larger, so
    # m_curr is usually a 0-d tensor here.
    m_curr = max(m_prev, x)
    l_curr = l_prev * torch.exp(m_prev - m_curr).item() + torch.exp(x - m_curr).item()
    m_prev = m_curr
    l_prev = l_curr

# Final pass: normalize with the global maximum m_N and normalizer l_N.
# Stored as `softmax_values`, not `softmax_row`, so that this cell does not
# overwrite the softmax_row function defined earlier in the notebook.
softmax_values = [torch.exp(x - m_prev).item() / l_prev for x in tensor[0]]
results.append(softmax_values)
results

[[0.0015547872336663783,
  0.6272459103409497,
  0.08488850184024635,
  0.00021041755910617083,
  0.0005719742380081532,
  0.011488409875583918,
  0.2307508807124477,
  0.0005719742380081532,
  0.011488409875583918,
  0.03122873408639953]]

In [34]:
def softmax_new(tensor: torch.Tensor) -> List[List[float]]:
    """
    Compute the softmax of every row with the online ("safe") 2-pass algorithm.

    Pass 1 walks each row once, keeping a running maximum ``m`` and a running
    normalizer ``l``. When an element raises the maximum, the previous ``l``
    is multiplied by ``exp(m_prev - m_curr)`` so it stays relative to the
    current maximum. Pass 2 returns ``exp(x - m_N) / l_N`` for every element.

    Args:
        tensor (torch.Tensor): Input of shape (R, N). A 1-D tensor of shape
            (N,) is treated as a single row. Integer tensors are accepted.

    Returns:
        List[List[float]]: One list of N Python floats per row.
    """
    import math  # local import: only this helper needs a scalar exp()

    rows = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor
    results: List[List[float]] = []
    for row in rows:
        # Work on plain floats so the running maximum never mixes floats and 0-d tensors.
        values = [float(v) for v in row.tolist()]
        m_prev = float("-inf")  # running maximum m_{i-1}
        l_prev = 0.0  # running normalizer l_{i-1}
        for x in values:
            m_curr = max(m_prev, x)
            # math.exp(-inf) == 0.0, so the first finite element sets l to exp(0) == 1.
            l_prev = l_prev * math.exp(m_prev - m_curr) + math.exp(x - m_curr)
            m_prev = m_curr
        results.append([math.exp(x - m_prev) / l_prev for x in values])
    return results

In [41]:
# Reference result from the conventional 3-pass softmax_row for the same input t1.
softmax_row(t1)

[[0.3031144142150879,
  0.015091178007423878,
  0.0002764045784715563,
  0.015091178007423878,
  0.041022077202796936,
  0.3031144142150879,
  0.0020423689857125282,
  0.0020423689857125282,
  0.015091178007423878,
  0.3031144142150879],
 [0.3031144142150879,
  0.015091178007423878,
  0.0002764045784715563,
  0.015091178007423878,
  0.041022077202796936,
  0.3031144142150879,
  0.0020423689857125282,
  0.0020423689857125282,
  0.015091178007423878,
  0.3031144142150879],
 [0.3031144142150879,
  0.015091178007423878,
  0.0002764045784715563,
  0.015091178007423878,
  0.041022077202796936,
  0.3031144142150879,
  0.0020423689857125282,
  0.0020423689857125282,
  0.015091178007423878,
  0.3031144142150879],
 [2.789106723355417e-10,
  2.6782898589386787e-33,
  6.14341615801095e-06,
  2.789106723355417e-10,
  4.780273739395541e-25,
  2.609940782897381e-23,
  0.00012339380919001997,
  0.9998704195022583,
  3.5321712459204884e-24,
  5.379488853812136e-32],
 [0.09367210417985916,
  0.254627168

In [42]:
# Online-softmax result for t1. Mathematically the same as the 3-pass softmax;
# only floating-point rounding may differ.
softmax_new(t1)

[[0.02364054202726851,
  0.06426165690335149,
  0.17468130082440936,
  0.47483299399271744,
  0.02364054202726851,
  0.06426165690335149,
  0.17468130082440936]]

In [17]:
import torch

# Toy problem sizes reused by the following exploration cells.
BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM = 8, 10, 12, 128

# A (1, 2) tensor holding [SEQ_LEN, BATCH_SIZE * NUM_HEADS], plus a zero tensor of the same shape.
a1 = torch.tensor([[SEQ_LEN, BATCH_SIZE * NUM_HEADS]])
grid = torch.zeros_like(a1)
grid.shape

torch.Size([1, 2])

In [13]:
# Random half-precision Q, K and V, each of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).
qkv_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
Q, K, V = (torch.randn(qkv_shape, dtype=torch.float16) for _ in range(3))
Q.shape

torch.Size([8, 12, 10, 128])

In [19]:
def grid(args):
    """Return the launch grid: (number of Q blocks, BATCH_SIZE * NUM_HEADS, 1)."""
    block_q = args["BLOCK_SIZE_Q"]
    num_q_blocks = (SEQ_LEN + block_q - 1) // block_q  # ceil(SEQ_LEN / BLOCK_SIZE_Q)
    return (num_q_blocks, BATCH_SIZE * NUM_HEADS, 1)


args = {"BLOCK_SIZE_Q": 4}
grid_shape = grid(args)
# Visualize the grid as a zero tensor with one entry per program.
grid_tensor = torch.zeros(grid_shape)
print(grid_tensor)

tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [

In [22]:
# Grid dimension 0. Despite its name, this holds the number of Q blocks
# (ceil(SEQ_LEN / BLOCK_SIZE_Q)), not the index of a single block.
block_index_q = grid(args)[0]
block_index_q

# Grid dimension 1 enumerates every (batch, head) pair: 0 .. BATCH_SIZE * NUM_HEADS - 1.
index_batch_head = torch.arange(0, BATCH_SIZE * NUM_HEADS, 1)
index_batch_head

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        90, 91, 92, 93, 94, 95])

In [None]:
# Starting sequence position of every Q block: 0, BLOCK_SIZE_Q, 2 * BLOCK_SIZE_Q, ... (< SEQ_LEN).
qkv_offset = torch.arange(start=0, end=SEQ_LEN, step=args["BLOCK_SIZE_Q"])

In [None]:
# Forward pass of the kernel

import torch
import triton
import triton.language as tl

# Inner loop of the FlashAttention forward pass for one block of queries.
# It walks the K/V blocks selected by STAGE and updates three running values
# with the online-softmax recurrence: the output accumulator O_block, the
# per-row sums l_i and the per-row maxima m_i. It returns all three.
# NOTE(review): K/V are loaded without boundary_check, so SEQ_LEN presumably
# has to be a multiple of BLOCK_SIZE_KV (and of BLOCK_SIZE_Q) — confirm.
@triton.jit
def _attn_fwd_inner(
    O_block,  # (BLOCK_SIZE_Q, HEAD_DIM) float32 accumulator of un-normalized P @ V
    l_i,  # (BLOCK_SIZE_Q,) running sum of exp(score - m_i) per query row
    m_i,  # (BLOCK_SIZE_Q,) running maximum of the scaled scores per query row
    Q_block,  # (BLOCK_SIZE_Q, HEAD_DIM) queries, already loaded by the caller
    K_block_ptr,  # block pointer over K viewed as (HEAD_DIM, SEQ_LEN), i.e. K^T
    V_block_ptr,  # block pointer over V viewed as (SEQ_LEN, HEAD_DIM)
    block_index_q,  # which block of BLOCK_SIZE_Q queries this program handles
    softmax_scale,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,  # 1: keys left of the diagonal, 2: the diagonal block (masked), 3: all keys
    offs_q: tl.constexpr,  # absolute query positions of this block
    offs_kv: tl.constexpr,  # 0 .. BLOCK_SIZE_KV - 1, shifted by start_kv for masking
    SEQ_LEN: tl.constexpr,
):
    # Range of key positions [lo, hi) handled by this stage
    if STAGE == 1: # Left part of the diagonal
        # From key 0 up to (excluding) the first query of this block
        lo, hi = 0, block_index_q * BLOCK_SIZE_Q
    elif STAGE == 2: # Exactly along the diagonal
        # Used only for the block in which there is a transition between non-masked and masked keys
        lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
        lo = tl.multiple_of(lo, BLOCK_SIZE_Q)
    else:
        # Only used for non-causal attention: every key
        lo, hi = 0, SEQ_LEN

    # Move both pointers to key position `lo` (K is transposed, so it advances along its columns)
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # Loop over K, V blocks and update the accumulator
    for start_kv in range(lo, hi, BLOCK_SIZE_KV):
        # Tell the compiler that start_kv is a multiple of BLOCK_SIZE_KV, so it can optimize
        start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)

        # -- compute qk ----
        K_block = tl.load(K_block_ptr)
        QK_block = tl.dot(Q_block, K_block)

        if STAGE == 2:
            # Causal mask: keep the score where query position >= key position;
            # the other scores get -1e6 so that exp() turns them into ~0.
            mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
            QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
            QK_block -= m_ij[:, None]
        else:
            # New row maximum: max of the old maximum and this block's scaled scores.
            # Scaling after tl.max equals the max of scaled scores only if
            # softmax_scale > 0 (presumably 1/sqrt(HEAD_DIM)).
            m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
            QK_block = QK_block * softmax_scale - m_ij[:, None]

        # Exponential of each shifted score: exp(qk_ij - m_ij)
        P_block = tl.math.exp(QK_block)
        # Row sums of this block's attention scores
        l_ij = tl.sum(P_block, 1)

        # Correction factor exp(m_old - m_new) <= 1 that re-expresses earlier sums relative to the new maximum
        alpha = tl.math.exp(m_i - m_ij) # previous estimate - current estimate
        # Apply the correction factor to the previous l_i and add the new l_ij
        l_i = l_i * alpha + l_ij

        V_block = tl.load(V_block_ptr)
        # NOTE(review): hard-coded float16 cast; assumes V is float16 — confirm before using other dtypes
        P_block = P_block.to(tl.float16)
        # This computes the following: O_new = P x V + O_old * alpha
        O_block = O_block * alpha[:, None]
        O_block = tl.dot(P_block, V_block, O_block) # O_block += P_block @ V_block

        m_i = m_ij

        # Move to the next block of K and V
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0)) # V[SEQ_LEN, HEAD_DIM]
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV)) # K^T[HEAD_DIM, SEQ_LEN]
    return O_block, l_i, m_i


# FlashAttention forward kernel. Each program handles BLOCK_SIZE_Q query rows
# of one (batch, head) pair: grid axis 0 selects the query block and grid
# axis 1 selects the (batch, head) pair. It writes those rows of O and stores
# the per-query logsumexp (m_i + log(l_i)) into M for the backward pass.
#
# The launcher's grid lambda reads args["BLOCK_SIZE_Q"] but the launch does
# not pass the block sizes, so they are supplied here by autotuning.
# NOTE(review): the block pointers are loaded without boundary_check, so
# SEQ_LEN presumably has to be a multiple of the chosen block sizes. The
# static_assert below also requires HEAD_DIM >= BLOCK_SIZE_KV.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_Q": block_q, "BLOCK_SIZE_KV": block_kv},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for block_q in [64, 128]
        for block_kv in [32, 64]
        for num_stages in [3, 4, 7]
        for num_warps in [2, 4]
    ],
    key=["SEQ_LEN", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    K,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    V,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    softmax_scale,
    M,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN  (logsumexp per query, written here)
    O,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM  (output, written here)
    stride_Q_batch,
    stride_Q_head,
    stride_Q_seq,
    stride_Q_dim,
    stride_K_batch,
    stride_K_head,
    stride_K_seq,
    stride_K_dim,
    stride_V_batch,
    stride_V_head,
    stride_V_seq,
    stride_V_dim,
    stride_O_batch,
    stride_O_head,
    stride_O_seq,
    stride_O_dim,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,  # 1: non-causal, 3: causal
):
    tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM)

    # Which block of queries along the sequence this program processes
    block_index_q = tl.program_id(0)

    # Which (batch, head) pair this program processes
    index_batch_head = tl.program_id(1)
    index_batch = index_batch_head // NUM_HEADS  # each batch element has NUM_HEADS heads
    index_head = index_batch_head % NUM_HEADS

    # Offset of the (SEQ_LEN, HEAD_DIM) slice [index_batch, index_head, :, :].
    # It is computed from Q's strides and reused for K, V and O, which
    # assumes all four tensors share Q's batch/head strides.
    qkv_offset = (
        index_batch.to(tl.int64) * stride_Q_batch
        + index_head.to(tl.int64) * stride_Q_head
    )

    # Q[index_batch, index_head, block_index_q * BLOCK_SIZE_Q :, :]
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    # V[index_batch, index_head, :, :], walked in blocks of BLOCK_SIZE_KV rows
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        order=(1, 0),
    )

    # K[index_batch, index_head, :, :] viewed transposed as (HEAD_DIM, SEQ_LEN).
    # A block pointer only addresses `block_shape` elements at a time; think
    # of it as a tensor of pointers with shape `block_shape`.
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_K_dim, stride_K_seq),  # swapped w.r.t. Q, which transposes K
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        order=(0, 1),
    )

    # O[index_batch, index_head, block_index_q * BLOCK_SIZE_Q :, :] -- where this program writes
    O_block_ptr = tl.make_block_ptr(
        base=O + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    # offs_q: absolute sequence positions of the queries in this block,
    # e.g. block_index_q=3, BLOCK_SIZE_Q=4 -> [12, 13, 14, 15]
    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    # offs_kv: positions within one K/V block, e.g. BLOCK_SIZE_KV=4 -> [0, 1, 2, 3].
    # Not shifted by program id, because every query block visits the K/V blocks starting from 0.
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)

    # m_i: running maximum of each query row
    m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")
    # l_i: running sum of each query row. The initial 1.0 is multiplied by
    # exp(-inf - m) = 0 in the first inner iteration, so it never affects the result.
    l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0
    # Accumulator for this program's rows of O (float32 for accuracy)
    O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)

    # Load the Q block once; it is reused against every K/V block
    Q_block = tl.load(Q_block_ptr)

    # STAGE is 3 for causal attention and 1 otherwise
    if STAGE == 1 or STAGE == 3:
        # Non-causal: all keys (inner stage 3). Causal: keys strictly left of the diagonal (inner stage 1).
        O_block, l_i, m_i = _attn_fwd_inner(
            O_block,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            softmax_scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            4 - STAGE,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )

    if STAGE == 3:
        # Causal only: the diagonal block, where the causal mask is applied
        O_block, l_i, m_i = _attn_fwd_inner(
            O_block,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            softmax_scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            2,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )

    # Epilogue: runs for causal AND non-causal launches (it was previously nested
    # under `if STAGE == 3`, so non-causal launches never wrote O or M).
    m_i += tl.math.log(l_i)  # logsumexp per query, needed by the backward pass
    O_block = O_block / l_i[:, None]  # apply the softmax denominator
    m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, O_block.to(O.type.element_ty))


class TritonAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton FlashAttention forward kernel `_attn_fwd`.

    NOTE(review): only `forward` is defined here. Without a `backward`,
    calling `.backward()` on the returned tensor raises, so gradient checks
    against this op cannot pass yet.
    """

    @staticmethod
    def forward(ctx, Q, K, V, casual, softmax_scale):
        """Compute softmax(Q @ K^T * softmax_scale) @ V with the Triton kernel.

        Args:
            ctx: autograd context. Q, K, V, O and M are saved on it for the backward pass.
            Q, K, V: tensors of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
                with equal head dimensions.
            casual: if truthy, apply the causal mask (kernel stage 3);
                otherwise attend to every key (kernel stage 1).
            softmax_scale: multiplier applied to Q @ K^T, e.g. 1 / sqrt(HEAD_DIM).

        Returns:
            torch.Tensor: the attention output O, allocated with ``torch.empty_like(Q)``.
        """
        HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1]
        HEAD_DIM_V = V.shape[-1]

        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape

        # All three head dimensions must agree.
        assert HEAD_DIM_Q == HEAD_DIM_K == HEAD_DIM_V

        O = torch.empty_like(Q)
        # The kernel computes one base offset from Q's batch/head strides and
        # adds it to K, V and O too, so their leading two strides must match Q's.
        assert Q.stride()[:2] == K.stride()[:2] == V.stride()[:2] == O.stride()[:2]

        stage = 3 if casual else 1

        grid = lambda args: (
            # ceil(SEQ_LEN / BLOCK_SIZE_Q): which group of queries a program works on
            triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),
            # which head of which batch element a program works on
            BATCH_SIZE * NUM_HEADS,
            1,  # third grid dimension is unused
        )

        # Number of parallel programs: BATCH_SIZE * NUM_HEADS * NUM_BLOCKS_Q

        # M holds the logsumexp per query, needed by the backward pass
        M = torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
        )

        _attn_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            softmax_scale=softmax_scale,
            M=M,  # logsumexp output, one value per query
            O=O,  # attention output
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            BATCH_SIZE=Q.shape[0],
            NUM_HEADS=Q.shape[1],
            SEQ_LEN=Q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            STAGE=stage,
        )

        ctx.save_for_backward(Q, K, V, O, M)
        ctx.grid = grid
        ctx.softmax_scale = softmax_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.casual = casual
        # Return the computed attention output (previously `return 0`).
        return O


def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, casual, dtype=torch.float16):
    """Compare TritonAttention against a naive PyTorch attention (output and gradients).

    Args:
        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM: shape of the random Q, K, V (on CUDA).
        casual: whether to apply the causal mask.
        dtype: dtype of Q, K, V and of the compared outputs (float16 by default).

    Raises:
        AssertionError: if the output or any gradient differs by more than ``atol``.
    """

    def _random_input():
        # Leaf tensor ~ N(0, 0.5^2) that records gradients.
        return (
            torch.empty(
                (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
            )
            .normal_(mean=0.0, std=0.5)
            .requires_grad_()
        )

    Q = _random_input()
    K = _random_input()
    V = _random_input()

    softmax_scale = 1 / (HEAD_DIM**0.5)  # Q K^T / sqrt(HEAD_DIM)
    dO = torch.randn_like(Q)  # upstream gradient for the backward pass

    # Reference implementation: naive PyTorch attention
    MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device="cuda"))
    P = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
    if casual:
        P[:, :, MASK == 0] = float("-inf")
    # Softmax over the key dimension in float32, then back to the input dtype
    P = torch.softmax(P.float(), dim=-1).to(dtype)
    ref_O = torch.matmul(P, V)
    ref_O.backward(dO)
    ref_dV, V.grad = V.grad.clone(), None
    ref_dK, K.grad = K.grad.clone(), None
    ref_dQ, Q.grad = Q.grad.clone(), None

    # Triton implementation
    # NOTE(review): requires TritonAttention to define a backward pass.
    tri_out = TritonAttention.apply(Q, K, V, casual, softmax_scale).to(dtype)
    tri_out.backward(dO)
    tri_dV, V.grad = V.grad.clone(), None
    tri_dK, K.grad = K.grad.clone(), None
    tri_dQ, Q.grad = Q.grad.clone(), None

    # Compare using an absolute tolerance only
    rtol = 0.0
    atol = 1e-2
    assert torch.allclose(ref_O, tri_out, atol=atol, rtol=rtol)
    assert torch.allclose(ref_dK, tri_dK, atol=atol, rtol=rtol)
    assert torch.allclose(ref_dV, tri_dV, atol=atol, rtol=rtol)
    assert torch.allclose(ref_dQ, tri_dQ, atol=atol, rtol=rtol)