<a href="https://colab.research.google.com/github/yangliupku/cs336_assignment2_systems/blob/main/notebooks/Triton_weighted_sum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [62]:
import triton
import triton.language as tl
import torch
from einops import rearrange

In [70]:
@triton.jit
def weighted_sum_fwd(
    x_ptr, weight_ptr,
    output_ptr,
    x_stride_row, x_stride_dim,
    weight_stride_dim,
    output_stride_row,
    ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,
):
    # Computes output[r] = sum_d x[r, d] * weight[d] for one tile of
    # ROWS_TILE_SIZE consecutive rows; tl.program_id(0) selects the tile.
    # x is addressed as a (ROWS, D) matrix and weight/output as 1-D vectors,
    # all through the strides supplied by the host.
    row_tile_idx = tl.program_id(0)

    # Block pointer over this program's (ROWS_TILE_SIZE, D_TILE_SIZE) slice of x,
    # starting at column 0 and advanced along D inside the loop below.
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(ROWS, D),
        strides=(x_stride_row, x_stride_dim),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )

    # Matching D_TILE_SIZE-long window over the weight vector.
    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )

    # This program's ROWS_TILE_SIZE-long slice of the output vector.
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )

    # Per-row accumulator, kept in float32.
    output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)

    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        # Out-of-range rows/columns (last tile) are loaded as zero, so they
        # contribute nothing to the sum.
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option='zero')
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option='zero')
        output += tl.sum(row * weight[None, :], axis=1)
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE))
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,))

    # boundary_check skips rows past ROWS in the final tile.
    tl.store(output_block_ptr, output, boundary_check=(0,))

class WeightedSumFunc(torch.autograd.Function):
    """Autograd function computing ``(x * weight).sum(dim=-1)``.

    ``x`` has shape ``(..., D)`` and ``weight`` has shape ``(D,)``; the result
    has shape ``x.shape[:-1]``. The forward pass runs the Triton kernel
    ``weighted_sum_fwd`` on a 2-D ``(rows, D)`` view of ``x``; the backward pass
    is written with plain PyTorch ops.
    """

    @staticmethod
    def forward(ctx, x, weight):
        """Compute the weighted sum of ``x`` along its last dim on the GPU.

        Raises AssertionError if ``weight`` is not a 1-D tensor of length
        ``x.shape[-1]``, if either tensor is not on CUDA, or if the flattened
        ``x`` is not contiguous.
        """
        D, output_dims = x.shape[-1], x.shape[:-1]
        input_shape = x.shape
        # Collapse all leading dims so the kernel sees a (rows, D) matrix.
        x = rearrange(x, "... d-> (...) d")
        ctx.save_for_backward(x, weight)
        assert len(weight.shape) == 1 and weight.shape[0] == D
        assert x.is_cuda and weight.is_cuda
        assert x.is_contiguous()
        ctx.D_TILE_SIZE = triton.next_power_of_2(D)
        ctx.ROWS_TILE_SIZE = 16
        ctx.input_shape = input_shape

        n_rows = x.shape[0]
        # The kernel addresses the output as a flat (ROWS,) vector with a single
        # stride, so allocate it 1-D. A buffer shaped like output_dims would have
        # stride(0) != 1 for >2-D inputs (and no stride(0) at all for 1-D x).
        y = torch.empty(n_rows, device=x.device)
        if n_rows == 0:
            # Nothing to compute; avoid launching a kernel with an empty grid.
            return y.view(output_dims)

        grid = (triton.cdiv(n_rows, ctx.ROWS_TILE_SIZE),)

        weighted_sum_fwd[grid](
            x, weight,
            y,
            x.stride(0), x.stride(1),
            weight.stride(0),
            y.stride(0),
            ROWS=n_rows, D=D,
            ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE, D_TILE_SIZE=ctx.D_TILE_SIZE,
        )
        return y.view(output_dims)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients of the loss w.r.t. ``x`` and ``weight``.

        ``grad_output`` is ∂Loss/∂y for the forward output ``y`` and has the same
        shape as ``y``. For example, if ``y`` has shape [5, 8], ``grad_output``
        also has shape [5, 8].

        Since y[r] = sum_d x[r, d] * weight[d]:
          ∂Loss/∂x[r, d]    = grad_output[r] * weight[d]
          ∂Loss/∂weight[d]  = sum_r grad_output[r] * x[r, d]

        A gradient not requested via ``ctx.needs_input_grad`` is returned as None.
        """
        x_flat, weight = ctx.saved_tensors
        input_shape = ctx.input_shape

        # Flatten grad_output so its length matches the rows of x_flat.
        grad_output_flat = grad_output.contiguous().view(-1)

        grad_x = None
        grad_weight = None

        if ctx.needs_input_grad[0]:
            # Outer product (rows, 1) * (1, D), then reshaped back to x's shape.
            grad_x = grad_output_flat.unsqueeze(-1) * weight.unsqueeze(0)
            grad_x = grad_x.view(input_shape)

        if ctx.needs_input_grad[1]:
            # Reduce the per-row contributions over all rows.
            grad_weight = torch.sum(grad_output_flat.unsqueeze(-1) * x_flat, dim=0)

        return grad_x, grad_weight

# Functional alias: f_weighted_sum(x, weight) is WeightedSumFunc.apply(x, weight).
f_weighted_sum = WeightedSumFunc.apply


In [74]:
# Check the Triton-backed WeightedSumFunc against the pure-PyTorch reference
# (x * weight).sum(-1): forward output and gradients w.r.t. x and weight.
x = torch.rand(5, 10, requires_grad=True, device='cuda')
weight = torch.rand(10, requires_grad=True, device='cuda')
y = WeightedSumFunc.apply(x, weight)
print(y)
y.sum().backward()
print(x.grad)
# Stash the custom-op gradients, then clear .grad so the reference backward
# does not accumulate on top of them.
x_grad_custom = x.grad.clone()
weight_grad_custom = weight.grad.clone()
x.grad = None
weight.grad = None
z = (x * weight).sum(dim=-1)
z.sum().backward()
print(x.grad)
torch.testing.assert_close(y.detach(), z.detach())
torch.testing.assert_close(x_grad_custom, x.grad)
torch.testing.assert_close(weight_grad_custom, weight.grad)

tensor([2.7370, 2.4548, 2.7630, 2.4602, 2.9581], device='cuda:0',
       grad_fn=<WeightedSumFuncBackward>)
tensor([[0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763]], device='cuda:0')
tensor([[0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
         0.1763],
        [0.5056, 0.4949, 0.1311, 0.5317, 0.7577, 0.0072, 0.6537, 0.9445, 0.2022,
 

In [65]:
# Display the current `weight` tensor.
weight

tensor([0.9127, 0.6412, 0.9072, 0.7362, 0.2823, 0.8845, 0.2758, 0.0404, 0.5012,
        0.5784], device='cuda:0', requires_grad=True)

In [69]:
# Shape of `weight` with a leading axis added, as broadcast in WeightedSumFunc.backward.
weight.unsqueeze(0).shape

torch.Size([1, 10])