
dtype mismatch error due to automatic upcasting #1038

@yf225

Description

Repro:

#!/usr/bin/env python3
"""
Minimal SE Block test to reproduce bf16/fp32 mismatch error
"""

import torch
import helion
import helion.language as hl

@helion.kernel(
    config=helion.Config(
        block_sizes=[32],
        indexing='block_ptr',
        load_eviction_policies=['last', 'last'],
        num_stages=8,
        num_warps=1,
        pid_type='flat',
        range_flattens=[None],
        range_multi_buffers=[None],
        range_num_stages=[0],
        range_unroll_factors=[0],
        range_warp_specializes=[None]
    ),
    static_shapes=True
)
def se_block_fwd(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Minimal kernel that triggers bf16/fp32 mismatch"""
    m, n = x.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :]  # bf16
        # This creates fp32 intermediate result
        sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
        # Multiply bf16 * fp32 gives fp32, but out expects bf16
        acc = hl.full([1], 2.0, dtype=x.dtype) * x_tile * sigmoid_result
        out[tile_m, :] = acc.to(x.dtype)  # Error happens here

    return out


def main():
    # Minimal test to trigger the error
    M, N = 2018304, 128
    dtype = torch.bfloat16

    print(f"Testing with M={M}, N={N}, dtype={dtype}")

    x = torch.randn(M, N, dtype=dtype, device="cuda")
    w = torch.randn(N, N, dtype=dtype, device="cuda")

    # This will trigger the error
    out = se_block_fwd(x, w)
    print("Test completed (should not reach here)")


if __name__ == "__main__":
    main()

Error:

$ python se_block_standalone_test.py 
Testing with M=2018304, N=128, dtype=torch.bfloat16
Traceback (most recent call last):
  File "/home/willfeng/local/pytorch-nightly/triton/language/core.py", line 43, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/language/core.py", line 2192, in store
    return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/language/semantic.py", line 1294, in store
    return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/language/semantic.py", line 1224, in _store_block_pointer
    assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
AssertionError: Block element type(bf16) and value element type(fp32) mismatch

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/willfeng/local/helion/se_block_standalone_test.py", line 58, in <module>
    main()
  File "/home/willfeng/local/helion/se_block_standalone_test.py", line 53, in main
    out = se_block_fwd(x, w)
          ^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 292, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 644, in __call__
    return self._run(*args)
           ^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_willfeng/5z/c5zo63ugbuwahfw7nihpqsmhfalcn7bn3zszpgo4nlo4pqfzoo7q.py", line 40, in se_block_fwd
    _launcher(_helion_se_block_fwd, (triton.cdiv(2018304, _BLOCK_SIZE_0),), x, w, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=1, num_stages=8)
  File "/home/willfeng/local/helion/helion/runtime/__init__.py", line 66, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 733, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 861, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/compiler/compiler.py", line 300, in compile
    module = src.make_ir(target, options, codegen_fns, module_map, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/compiler/compiler.py", line 80, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 17:4:
    x_tile = tl.load(tl.make_block_ptr(x, [2018304, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_last')
    # src[se_block_standalone_test.py:34]: sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
    load_1 = tl.load(tl.make_block_ptr(w, [128, 128], [128, 1], [0, 0], [_RDIM_SIZE_1, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_last')
    mm = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
    v_0 = tl.sigmoid(tl.cast(mm, tl.float32))
    # src[se_block_standalone_test.py:36]: acc = hl.full([1], 2.0, dtype=x.dtype) * x_tile * sigmoid_result
    full = tl.full([1], 2.0, tl.bfloat16)
    v_1 = full[None, :]
    v_2 = v_1 * x_tile
    v_3 = v_2 * v_0
    # src[se_block_standalone_test.py:37]: out[tile_m, :] = acc.to(x.dtype)  # Error happens here
    tl.store(tl.make_block_ptr(out, [2018304, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
    ^
Block element type(bf16) and value element type(fp32) mismatch

i.e. the symptom is that torch.sigmoid(x_tile @ w[:, :]) automatically upcasts the matmul result to fp32, and that fp32 value propagates all the way to the tl.store (note that the explicit acc.to(x.dtype) in the kernel source does not produce a cast in the generated Triton above, where v_3 is stored directly), causing the dtype mismatch at the store.
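
For comparison, the eager-mode equivalent of the same math stays in bf16 end to end, so no mismatch occurs there. A minimal sketch (smaller shapes used here just for a quick check):

import torch

M, N = 1024, 128  # smaller than the repro, just to illustrate the dtypes
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, N, dtype=torch.bfloat16, device="cuda")

sigmoid_result = torch.sigmoid(x @ w)   # bf16 @ bf16 -> bf16; sigmoid preserves dtype in eager
out = torch.full([1], 2.0, dtype=x.dtype, device=x.device) * x * sigmoid_result
assert out.dtype == torch.bfloat16      # no automatic upcast to fp32 in eager mode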

I think we should disable automatic upcasting, so that users have explicit control over the dtype casting behavior.
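
In the meantime, one possible workaround is to cast the intermediate back down explicitly before the elementwise multiply, so the value reaching the store is already bf16. This is only a sketch and assumes Helion honors an explicit .to() on the sigmoid result (I have not verified that it avoids the same lowering issue); the tuned config from the repro is omitted for brevity:

import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=True)  # config omitted; the one from the repro should also apply
def se_block_fwd_explicit_cast(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :]  # bf16
        # Cast the fp32 matmul/sigmoid intermediate back to bf16 immediately,
        # so the rest of the chain (and the store) stays in bf16.
        sigmoid_result = torch.sigmoid(x_tile @ w[:, :]).to(x.dtype)
        acc = hl.full([1], 2.0, dtype=x.dtype) * x_tile * sigmoid_result
        out[tile_m, :] = acc

    return out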
