Repro:
# Run with:
# rm -rf /tmp/torchinductor_${USER}/ && CUDA_LAUNCH_BLOCKING=1 HELION_AUTOTUNE_RANDOM_SEED=2189049218 python repro_helion_686.py
import torch

import helion
import helion.language as hl


@helion.kernel(config=helion.Config(block_sizes=[32, 32], flatten_loops=[False], indexing='pointer', l2_groupings=[1], loop_orders=[[0, 1]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=False)
def bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x, y = torch.broadcast_tensors(x, y)
    out = torch.empty(
        x.shape,
        dtype=torch.promote_types(x.dtype, y.dtype),
        device=x.device,
    )
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out


def check(m: int, n: int):
    x = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    y = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    bf16_add(x, y)


def main():
    check(51200, 51200)


if __name__ == "__main__":
    main()
Error:
Traceback (most recent call last):
  File "/data/users/willfeng/helion/repro_helion_686.py", line 31, in <module>
    main()
  File "/data/users/willfeng/helion/repro_helion_686.py", line 28, in main
    check(51200, 51200)
  File "/data/users/willfeng/helion/repro_helion_686.py", line 25, in check
    bf16_add(x, y)
  File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 286, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 628, in __call__
    return self._run(*args)
           ^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_willfeng/zb/czbypmdpmclo7k7sdi2vssklgemvqm5pwtdoe3qndneegz3p576y.py", line 29, in bf16_add
    _launcher(_helion_bf16_add, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, y, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
  File "/data/users/willfeng/helion/helion/runtime/__init__.py", line 66, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 757, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/home/willfeng/local/pytorch-nightly/triton/backends/nvidia/driver.py", line 712, in __call__
    self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
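Possibly relevant context (a hypothesis, not a confirmed root cause): at this shape the flattened element count and the per-tensor byte span both exceed the int32 range, so any offset arithmetic done in 32 bits with indexing='pointer' would wrap around. A quick check of the numbers:

# Hypothesis only: 51200 x 51200 puts both the element count and the byte
# offsets past int32, which would corrupt 32-bit pointer/offset arithmetic.
import torch

numel = 51200 * 51200                         # 2,621,440,000 elements
byte_span = numel * 2                         # bf16 is 2 bytes -> ~5.24 GB per tensor
int32_max = torch.iinfo(torch.int32).max      # 2,147,483,647
print(numel > int32_max, byte_span > int32_max)  # True True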