<a href="https://colab.research.google.com/github/soulsharp/Triton_kernels_ViT/blob/main/Triton_Softmax_Stable.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install triton

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.5/209.5 MB 5.1 MB/s eta 0:00:00
Installing collected packages: triton
Successfully installed triton-3.1.0


In [4]:
import torch
import triton
import triton.language as tl

**APPROACH 1:**

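Walks each row in BLOCK_SIZE-wide chunks: the row is read from global memory multiple times, but only one chunk is held on-chip at a time, so there is no pressure on shared memory.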



In [5]:
@triton.jit
def _softmax_stable_blocked(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    num_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Numerically stable row-wise softmax computed in BLOCK_SIZE-wide chunks.

    Each program instance handles the row selected by ``tl.program_id(0)``,
    so the kernel is meant to be launched with one program per row. The row
    is traversed three times (max, sum of exponentials, normalize) so that
    only one chunk has to be resident on-chip at any moment.

    Args:
        input_ptr: Pointer to the first element of the input matrix.
        output_ptr: Pointer to the first element of the output matrix.
        input_row_stride: Number of elements between consecutive input rows.
        output_row_stride: Number of elements between consecutive output rows.
        num_cols: Number of valid columns in each row.
        BLOCK_SIZE: Compile-time chunk width. ``tl.arange`` requires it to be
            a power of two.

    Elements inside a row are addressed with unit stride.
    """
    pid = tl.program_id(0)
    input_row = input_ptr + pid * input_row_stride
    output_row = output_ptr + pid * output_row_stride

    num_blocks = tl.cdiv(num_cols, BLOCK_SIZE)

    # Triton tensors need compile-time shapes and do not support item
    # assignment, so per-chunk partial results cannot be kept in an array of
    # length num_blocks. Instead, BLOCK_SIZE-wide accumulators are carried
    # across the loop and reduced once at the end.

    # Pass 1: row maximum.
    running_max = tl.full([BLOCK_SIZE], -float('inf'), dtype=tl.float32)
    for block_id in range(num_blocks):
        offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols
        # -inf is the identity for max, so masked lanes never win.
        values = tl.load(input_row + offsets, mask=mask, other=-float('inf'))
        running_max = tl.maximum(running_max, values.to(tl.float32))
    row_max = tl.max(running_max, axis=0)

    # Pass 2: sum of exp(x - row_max).
    running_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for block_id in range(num_blocks):
        offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols
        # Masked lanes must load -inf (not 0): exp(-inf - row_max) == 0, so
        # they contribute nothing to the sum. Loading 0 would add
        # exp(-row_max) per masked lane and corrupt the denominator.
        values = tl.load(input_row + offsets, mask=mask, other=-float('inf'))
        running_sum += tl.exp(values.to(tl.float32) - row_max)
    row_sum = tl.sum(running_sum, axis=0)

    # Pass 3: normalize and write the result. The exponentials are recomputed
    # from the input rather than staged in the output buffer, which saves one
    # full write of the row and avoids rounding the intermediate values to
    # the output dtype before the division.
    for block_id in range(num_blocks):
        offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols
        values = tl.load(input_row + offsets, mask=mask, other=-float('inf'))
        softmax_out = tl.exp(values.to(tl.float32) - row_max) / row_sum
        tl.store(output_row + offsets, softmax_out, mask=mask)


**APPROACH 2:**

Loads each row of the matrix once and performs all computations using
shared memory or registers.

Becomes shared memory-bound when the size of rows is large.

In [7]:
@triton.jit
def _softmax_stable(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    num_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Numerically stable row-wise softmax that loads each row exactly once.

    Each program instance handles the row selected by ``tl.program_id(0)``,
    so the kernel is meant to be launched with one program per row. The whole
    row is loaded in a single masked read and the max, exponentials, sum and
    normalization are all computed on that on-chip copy before one write.

    Args:
        input_ptr: Pointer to the first element of the input matrix.
        output_ptr: Pointer to the first element of the output matrix.
        input_row_stride: Number of elements between consecutive input rows.
        output_row_stride: Number of elements between consecutive output rows.
        num_cols: Number of valid columns in each row.
        BLOCK_SIZE: Compile-time width of the on-chip row buffer. It must be
            a power of two (``tl.arange`` requirement) and at least
            ``num_cols``; columns at index ``BLOCK_SIZE`` or beyond are
            neither read nor written.

    Elements inside a row are addressed with unit stride.
    """
    pid = tl.program_id(0)
    input_row = input_ptr + pid * input_row_stride
    # Every program must write to its own output row, not to row 0.
    output_row = output_ptr + pid * output_row_stride

    # Tensor shapes must be compile-time constants, so the row buffer is
    # BLOCK_SIZE wide and the lanes beyond num_cols are masked off.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_cols

    # Masked lanes load -inf: they cannot win the max and exp() maps them to
    # 0, so they do not contribute to the sum either.
    row = tl.load(input_row + offsets, mask=mask, other=-float('inf'))
    row = row.to(tl.float32)

    # Subtract the row maximum before exponentiating for numerical stability.
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    row_sum = tl.sum(numerator, axis=0)

    tl.store(output_row + offsets, numerator / row_sum, mask=mask)