<a href="https://colab.research.google.com/github/ved1beta/Triton/blob/main/softmax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install triton torch
import torch
import triton
import triton.language as tl
def softmax(x: torch.Tensor) -> torch.Tensor:

    rows, cols = x.shape
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"

    block_size = triton.next_power_of_2(cols)

    num_warps = 4  # Each warp has 32 threads
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # Define grid size, each thread block (Block) processes one row of data
    grid = (rows,) # This creates a tuple containing only rows


    sm_out = torch.empty_like(x)

    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

    return sm_out







Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


In [4]:
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # Get the ID of the current program (row index)
    row_index = tl.program_id(0)

    # Calculate the starting pointer for the current row
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets

    # Create a mask to prevent out-of-bounds access
    row_mask = col_offsets < num_cols

    # Load data from global memory to shared SRAM
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))

    # Softmax calculation
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # Write results back to global memory
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)


In [5]:
_softmax_fwd_kernel[grid](
    sm_out,                # Pointer to the output tensor
    sm_out.stride(0),      # Stride of the output tensor in the row direction
    x,                     # Pointer to the input tensor
    x.stride(0),           # Stride of the input tensor in the row direction
    cols,                  # Number of columns in the input tensor
    # Kernel configuration parameters
    block_size=block_size,
    num_warps=num_warps
)

NameError: name 'grid' is not defined