Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/grouped_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def grouped_gemm_jagged_persistent(

if m_size > 0:
# Compute tile grid dimensions for current group
num_m_tiles = (m_size + BLOCK_M - 1) // BLOCK_M
num_m_tiles = helion.cdiv(m_size, BLOCK_M) # pyright: ignore[reportArgumentType]
# Calculate number of N tiles (shared across all groups)
num_n_tiles = (N + BLOCK_N - 1) // BLOCK_N
num_n_tiles = helion.cdiv(N, BLOCK_N) # pyright: ignore[reportArgumentType]
num_group_tiles = num_m_tiles * num_n_tiles

# Distribute tiles among workers using strided access pattern
Expand Down
2 changes: 1 addition & 1 deletion examples/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def rms_norm_bwd(
m_block = hl.register_block_size(x.size(0))
grad_x = torch.empty_like(x)
grad_weight = x.new_empty(
[(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32
[helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32
)
weight_shape = hl.specialize(weight.size(0))
for mb_cta in hl.tile(x.size(0), block_size=m_block):
Expand Down
3 changes: 2 additions & 1 deletion test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3153,6 +3153,7 @@ def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher
Expand Down Expand Up @@ -3218,7 +3219,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
"""
m_block = 32
grad_x = torch.empty_like(x)
grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
grad_weight = x.new_empty([helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_2 = 64
_launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
Expand Down
Loading