diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py
index d6d51b597..1a57d9f55 100644
--- a/examples/grouped_gemm.py
+++ b/examples/grouped_gemm.py
@@ -165,9 +165,9 @@ def grouped_gemm_jagged_persistent(
 
         if m_size > 0:
             # Compute tile grid dimensions for current group
-            num_m_tiles = (m_size + BLOCK_M - 1) // BLOCK_M
+            num_m_tiles = helion.cdiv(m_size, BLOCK_M)  # pyright: ignore[reportArgumentType]
             # Calculate number of N tiles (shared across all groups)
-            num_n_tiles = (N + BLOCK_N - 1) // BLOCK_N
+            num_n_tiles = helion.cdiv(N, BLOCK_N)  # pyright: ignore[reportArgumentType]
             num_group_tiles = num_m_tiles * num_n_tiles
 
             # Distribute tiles among workers using strided access pattern
diff --git a/examples/rms_norm.py b/examples/rms_norm.py
index 0e2c28d25..bf44214aa 100644
--- a/examples/rms_norm.py
+++ b/examples/rms_norm.py
@@ -96,7 +96,7 @@ def rms_norm_bwd(
     m_block = hl.register_block_size(x.size(0))
     grad_x = torch.empty_like(x)
     grad_weight = x.new_empty(
-        [(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32
+        [helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32
     )
     weight_shape = hl.specialize(weight.size(0))
     for mb_cta in hl.tile(x.size(0), block_size=m_block):
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 5e56dbe53..4610a115c 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -3153,6 +3153,7 @@ def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.
 from __future__ import annotations
 
 import torch
+import helion
 import triton
 import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
@@ -3218,7 +3219,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
     """
     m_block = 32
     grad_x = torch.empty_like(x)
-    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
+    grad_weight = x.new_empty([helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_2 = 64
     _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
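
Sanity-check sketch (not part of the patch): the refactor relies on helion.cdiv matching the manual ceiling-division expression it replaces. The snippet below is a minimal host-side check, assuming helion.cdiv accepts plain Python ints the same way the generated code above calls it with x.size(0) and an int block size; the test values are illustrative only.

import helion

# helion.cdiv(a, b) is expected to equal (a + b - 1) // b for positive sizes,
# which is what makes the replacement behavior-preserving.
for a, b in [(1, 32), (32, 32), (33, 32), (4095, 64), (4096, 64)]:
    assert helion.cdiv(a, b) == (a + b - 1) // b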