-
Notifications
You must be signed in to change notification settings - Fork 77
Closed
Labels
Description
Repro:
#!/usr/bin/env python3
"""
Minimal SE Block test to reproduce bf16/fp32 mismatch error
"""
import torch
import helion
import helion.language as hl
@helion.kernel(
    config=helion.Config(
        block_sizes=[32],
        indexing='block_ptr',
        load_eviction_policies=['last', 'last'],
        num_stages=8,
        num_warps=1,
        pid_type='flat',
        range_flattens=[None],
        range_multi_buffers=[None],
        range_num_stages=[0],
        range_unroll_factors=[0],
        range_warp_specializes=[None]
    ),
    static_shapes=True
)
def se_block_fwd(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Minimal SE-block forward kernel that reproduces a bf16/fp32 store mismatch.

    Computes ``out = 2 * x * sigmoid(x @ w)`` one row tile at a time and writes
    the result into an output tensor allocated with ``x.dtype``.

    Args:
        x: 2-D input of shape ``[m, n]`` (bf16 when called from ``main()``).
        w: Weight multiplied against each row tile of ``x``. It is expected to be
            ``[n, n]`` (as in ``main()``) so that ``x_tile @ w`` has the same
            shape as ``x_tile``.

    Returns:
        Tensor of shape ``[m, n]`` with dtype ``x.dtype``.
    """
    m, n = x.size()
    # Store target has x's dtype, i.e. bf16 for bf16 inputs.
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :]  # same dtype as x (bf16 when called from main())
        # In eager PyTorch, bf16 @ bf16 and sigmoid(bf16) both stay bf16.
        # The generated Triton code instead computes the dot with
        # out_dtype=tl.float32 and applies tl.sigmoid in fp32 without casting
        # back, so this value is fp32 inside the compiled kernel.
        sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
        # In the generated code bf16 * fp32 promotes to fp32, so acc is fp32
        # there even though PyTorch semantics would make it bf16.
        acc = hl.full([1], 2.0, dtype=x.dtype) * x_tile * sigmoid_result
        # Triton compilation fails on this store. The generated code emits no
        # cast for .to(x.dtype) (presumably treated as a no-op because acc is
        # already bf16 under PyTorch semantics), so an fp32 value is stored
        # through a bf16 block pointer.
        out[tile_m, :] = acc.to(x.dtype)  # Error happens here
    return out
def main(
    *,
    m: int = 2018304,
    n: int = 128,
    dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Run ``se_block_fwd`` once on random CUDA inputs to trigger the mismatch.

    With bf16 inputs the call is expected to fail while Triton compiles the
    kernel (a ``CompilationError`` about a bf16/fp32 block element type
    mismatch), so the final print should not be reached.

    Args:
        m: Number of rows of ``x``. Defaults to the shape from the original report.
        n: Number of columns of ``x``. ``w`` is created as ``[n, n]``.
        dtype: Element type of both inputs. bf16 reproduces the failure.
    """
    print(f"Testing with M={m}, N={n}, dtype={dtype}")
    x = torch.randn(m, n, dtype=dtype, device="cuda")
    w = torch.randn(n, n, dtype=dtype, device="cuda")
    # The result is intentionally discarded. Only whether the kernel compiles
    # and runs matters for this repro.
    se_block_fwd(x, w)
    print("Test completed (should not reach here)")


if __name__ == "__main__":
    main()
$ python se_block_standalone_test.py
Testing with M=2018304, N=128, dtype=torch.bfloat16
Traceback (most recent call last):
File "/home/willfeng/local/pytorch-nightly/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/language/core.py", line 2192, in store
return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/language/semantic.py", line 1294, in store
return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/language/semantic.py", line 1224, in _store_block_pointer
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
AssertionError: Block element type(bf16) and value element type(fp32) mismatch
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/willfeng/local/helion/se_block_standalone_test.py", line 58, in <module>
main()
File "/home/willfeng/local/helion/se_block_standalone_test.py", line 53, in main
out = se_block_fwd(x, w)
^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 292, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 644, in __call__
return self._run(*args)
^^^^^^^^^^^^^^^^
File "/tmp/torchinductor_willfeng/5z/c5zo63ugbuwahfw7nihpqsmhfalcn7bn3zszpgo4nlo4pqfzoo7q.py", line 40, in se_block_fwd
_launcher(_helion_se_block_fwd, (triton.cdiv(2018304, _BLOCK_SIZE_0),), x, w, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=1, num_stages=8)
File "/home/willfeng/local/helion/helion/runtime/__init__.py", line 66, in default_launcher
return triton_kernel.run(
^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 733, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 861, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/compiler/compiler.py", line 300, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 17:4:
x_tile = tl.load(tl.make_block_ptr(x, [2018304, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_last')
# src[se_block_standalone_test.py:34]: sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
load_1 = tl.load(tl.make_block_ptr(w, [128, 128], [128, 1], [0, 0], [_RDIM_SIZE_1, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_last')
mm = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
v_0 = tl.sigmoid(tl.cast(mm, tl.float32))
# src[se_block_standalone_test.py:36]: acc = hl.full([1], 2.0, dtype=x.dtype) * x_tile * sigmoid_result
full = tl.full([1], 2.0, tl.bfloat16)
v_1 = full[None, :]
v_2 = v_1 * x_tile
v_3 = v_2 * v_0
# src[se_block_standalone_test.py:37]: out[tile_m, :] = acc.to(x.dtype) # Error happens here
tl.store(tl.make_block_ptr(out, [2018304, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
^
Block element type(bf16) and value element type(fp32) mismatch
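In other words, the symptom is that `torch.sigmoid(x_tile @ w[:, :])` is computed in fp32 in the generated Triton code: the dot uses `out_dtype=tl.float32` and the sigmoid runs in fp32, but the result is never cast back to bf16. Under PyTorch semantics the whole expression is bf16, so the final `.to(x.dtype)` is a no-op and emits no cast. As a result, an fp32 value is stored into a bf16 block pointer, causing the downstream dtype mismatch.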
I think we should disable this automatic upcasting so that users have explicit control over dtype-casting behavior.