<a href="https://colab.research.google.com/github/soulsharp/Triton_kernels_ViT/blob/main/Triton_layer_norm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install triton

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


In [None]:
import triton
import triton.language as tl
import torch

In [None]:
@triton.jit
def _layer_norm_forward(
    input_ptr,  # Pointer to input, 2D
    output_ptr,  # Pointer to output, 2D (addressed with the same row stride as the input)
    weight_ptr,  # Pointer to weights, indexed by column offset
    bias_ptr,  # Pointer to biases, indexed by column offset
    mean_vector_ptr,  # Row-wise mean, one element per row
    rstd_ptr,  # Row-wise reciprocal of standard deviation, one element per row
    xhat_ptr,  # Stores xhat values for backward pass (addressed with the same row stride as the input)
    row_stride,  # Number of elements between the starts of consecutive rows
    num_cols,  # Number of valid columns in each row
    epsilon,  # For numerical stability
    BLOCK_SIZE: tl.constexpr,  # Number of columns processed per loop iteration
):
    """Layer-norm forward pass; each program instance normalizes one row.

    For the row selected by ``tl.program_id(0)`` the kernel:
      1. accumulates the sum and the sum of squares of the row in float32,
      2. derives ``mean`` and ``rstd = 1 / sqrt(var + epsilon)`` and stores them
         at index ``pid`` of ``mean_vector_ptr`` / ``rstd_ptr``,
      3. writes ``xhat = (x - mean) * rstd`` to ``xhat_ptr`` and
         ``y = weights * xhat + biases`` to ``output_ptr``.

    The input, output and xhat buffers are all addressed as
    ``pid * row_stride + column``, so they are assumed to share one row stride.
    """
    pid = tl.program_id(0)

    # Move every per-row pointer to the start of this program's row.
    # xhat_ptr must be advanced as well; otherwise every program would write
    # its normalized values into row 0 and overwrite the other rows' results.
    row_start = pid * row_stride
    input_ptr += row_start
    output_ptr += row_start
    xhat_ptr += row_start

    # Per-lane float32 accumulators for the sum and the sum of squares
    sum_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    squared_sum_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    # First pass: walk the row in BLOCK_SIZE-wide chunks
    for offset in range(0, num_cols, BLOCK_SIZE):
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols

        # Out-of-range lanes load 0 so they do not contribute to either sum
        row_segment = tl.load(input_ptr + offsets, mask=mask, other=0.).to(tl.float32)

        sum_acc += row_segment
        squared_sum_acc += row_segment * row_segment

    # Reduce the per-lane partial sums to row-level statistics
    mean = tl.sum(sum_acc, axis=0) / num_cols
    mean_squared_sum = tl.sum(squared_sum_acc, axis=0) / num_cols

    # Variance via E[x^2] - E[x]^2.
    # NOTE: this form can lose precision when |mean| is much larger than the
    # standard deviation of the row.
    variance = mean_squared_sum - mean * mean
    rstd = 1.0 / tl.sqrt(variance + epsilon)

    # Save the row statistics
    tl.store(mean_vector_ptr + pid, mean)
    tl.store(rstd_ptr + pid, rstd)

    # Second pass: normalize, apply the affine transform and write the results
    for offset in range(0, num_cols, BLOCK_SIZE):
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols

        row_segment = tl.load(input_ptr + offsets, mask=mask, other=0.).to(tl.float32)
        weights = tl.load(weight_ptr + offsets, mask=mask)
        biases = tl.load(bias_ptr + offsets, mask=mask)

        xhat = (row_segment - mean) * rstd
        y = weights * xhat + biases

        # Masked stores keep the padding lanes from touching memory
        tl.store(output_ptr + offsets, y, mask=mask)
        tl.store(xhat_ptr + offsets, xhat, mask=mask)


In [None]:
@triton.jit
def _backward_pass_layer_norm(
    Dy,  # Gradient of loss wrt y
    Dx,  # Gradient of loss wrt x (inputs, written by this kernel)
    Dw,  # Gradient of loss wrt weights (NOT written by this kernel, see docstring)
    Db,  # Gradient of loss wrt biases (NOT written by this kernel, see docstring)
    input_ptr,  # Unused by this kernel; kept for interface compatibility
    weights_ptr,  # Pointer to weights, indexed by column offset
    mean_vector_ptr,  # Unused by this kernel; xhat is read directly instead
    wdy_ptr,  # Unused by this kernel; w * dy is recomputed on the fly
    rstd_ptr,  # Row-wise reciprocal of standard deviation, one element per row
    xhat_ptr,  # Normalized inputs saved by the forward pass
    row_stride,  # Number of elements between the starts of consecutive rows
    num_cols,  # Number of valid columns in each row
    GROUP_SIZE_M: tl.constexpr,  # Unused by this kernel
    BLOCK_SIZE: tl.constexpr,  # Number of columns processed per loop iteration
):
    """Layer-norm backward pass for the input gradient; one program per row.

    For the row selected by ``tl.program_id(0)`` the kernel computes

        wdy = w * dy
        c1  = mean(xhat * wdy)
        c2  = mean(wdy)
        dx  = (wdy - c1 * xhat - c2) * rstd

    and stores ``dx`` into ``Dx``. ``Dy``, ``Dx`` and ``xhat_ptr`` are all
    addressed as ``pid * row_stride + column``, so they are assumed to share one
    row stride. The weights are indexed by column only.

    NOTE(review): ``Dw`` and ``Db`` are accepted but never written. Computing
    them needs a reduction across rows (i.e. across programs), which this
    kernel does not perform - confirm how callers expect them to be produced.
    """
    pid = tl.program_id(0)

    # Advance the per-row pointers to this program's row. The weights are
    # shared by all rows and therefore must NOT be offset by the row start.
    row_start = pid * row_stride
    Dy += row_start
    Dx += row_start
    xhat_ptr += row_start

    # Reciprocal standard deviation saved by the forward pass for this row
    rstd_value = tl.load(rstd_ptr + pid)

    # Per-lane float32 accumulators for sum(xhat * wdy) and sum(wdy).
    # Their shape is the constexpr BLOCK_SIZE; Triton tensors cannot be sized
    # by a runtime value or assigned element by element.
    xhat_wdy_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    wdy_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    # First pass: accumulate the two row-wide dot products
    for start in range(0, num_cols, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols

        # Out-of-range lanes load 0 so they do not contribute to the sums
        grad_y = tl.load(Dy + offsets, mask=mask, other=0.).to(tl.float32)
        weight_values = tl.load(weights_ptr + offsets, mask=mask, other=0.).to(tl.float32)
        xhat = tl.load(xhat_ptr + offsets, mask=mask, other=0.).to(tl.float32)

        wdy = grad_y * weight_values
        xhat_wdy_acc += xhat * wdy
        wdy_acc += wdy

    # Constants required for calculation of Dx
    c1 = tl.sum(xhat_wdy_acc, axis=0) / num_cols
    c2 = tl.sum(wdy_acc, axis=0) / num_cols

    # Second pass: compute and store the input gradient
    for start in range(0, num_cols, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_cols

        grad_y = tl.load(Dy + offsets, mask=mask, other=0.).to(tl.float32)
        weight_values = tl.load(weights_ptr + offsets, mask=mask, other=0.).to(tl.float32)
        xhat = tl.load(xhat_ptr + offsets, mask=mask, other=0.).to(tl.float32)

        wdy = grad_y * weight_values

        # Core computation
        grad_x = (wdy - c1 * xhat - c2) * rstd_value

        tl.store(Dx + offsets, grad_x, mask=mask)


In [None]:
# Smoke test for the forward kernel: run it on a small random matrix and
# compare the row statistics and the output with PyTorch reference values.

# Constants
NUM_ROWS = 4  # Number of rows in the input matrix
NUM_COLS = 16  # Number of columns in the input matrix
BLOCK_SIZE = 8  # Block size for Triton kernel (must be a power of two for tl.arange)
EPSILON = 1e-5  # Small constant for numerical stability

# Input data
input_tensor = torch.randn((NUM_ROWS, NUM_COLS), dtype=torch.float32, device='cuda')
output_tensor = torch.zeros_like(input_tensor, device='cuda')
weight_tensor = torch.ones(NUM_COLS, dtype=torch.float32, device='cuda')
bias_tensor = torch.zeros(NUM_COLS, dtype=torch.float32, device='cuda')
mean_vector = torch.zeros(NUM_ROWS, dtype=torch.float32, device='cuda')
rstd_vector = torch.zeros(NUM_ROWS, dtype=torch.float32, device='cuda')
# Buffer for the normalized inputs that the kernel saves for the backward pass
xhat_tensor = torch.zeros_like(input_tensor, device='cuda')

# Strides
row_stride = input_tensor.stride(0)

# Launch kernel: one program per row. Arguments follow the kernel signature,
# including the xhat buffer between rstd_vector and row_stride.
_layer_norm_forward[(NUM_ROWS,)](
    input_tensor,
    output_tensor,
    weight_tensor,
    bias_tensor,
    mean_vector,
    rstd_vector,
    xhat_tensor,
    row_stride,
    NUM_COLS,
    EPSILON,
    BLOCK_SIZE=BLOCK_SIZE
)

# Reference values computed with PyTorch (biased variance, as in layer norm)
expected_mean = input_tensor.mean(dim=1)
expected_rstd = torch.rsqrt(input_tensor.var(dim=1, unbiased=False) + EPSILON)
expected_output = torch.nn.functional.layer_norm(
    input_tensor, (NUM_COLS,), weight_tensor, bias_tensor, EPSILON
)

torch.testing.assert_close(mean_vector, expected_mean, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(rstd_vector, expected_rstd, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(output_tensor, expected_output, rtol=1e-4, atol=1e-5)

# Retrieve and verify results
print("Input Tensor:")
print(input_tensor.cpu().numpy())
print("Mean Vector:")
print(mean_vector.cpu().numpy())
print("Rstd Vector:")
print(rstd_vector.cpu().numpy())

Input Tensor:
[[ 0.17184934  0.47160515  0.33827543  0.27657324 -0.7147892   0.51168585
  -0.19860959 -0.11520658  1.1411026  -0.77810264 -0.78735423 -0.55497736
  -0.3488355  -0.4574265  -0.59218186 -1.3202435 ]
 [-0.07285843 -0.50958127 -0.8654773   0.726394    0.37132856 -0.88880646
   1.460418    1.1474667   0.2841265  -0.81712204  0.81673616 -0.1994375
   0.92693114 -0.2802274   0.58567303 -0.03764753]
 [-1.0386246  -0.2645126  -0.00845453  1.0592384   1.7422706  -1.1534754
   0.8803959   1.2779768  -0.7138989  -1.7249111  -1.1983298  -0.731013
  -0.28897485 -0.16875947 -0.55907196  0.5937596 ]
 [-0.5679358  -0.2420803   0.46496576  0.32869694 -0.3844268   0.2112381
  -0.5294461   0.7774498  -0.49508306  1.892584   -0.0091246   1.0345228
   0.47390002  0.56024647 -0.87162644 -0.9639047 ]]
Mean Vector:
[-0.18478972  0.16549478 -0.14352408  0.1049985 ]
Rstd Vector:
[1.6354115 1.3952965 1.0345864 1.3537248]
