[XPU] XPU cannot pass TestDot.test_baddbmm_pipeline_debug_dtype_asserts #772

@EikanWang

Description

Describe the bug
Running test_dot.py on XPU, TestDot.test_baddbmm_pipeline_debug_dtype_asserts fails with RuntimeError: PassManager::run failed. The captured stderr below shows the underlying crash: an out-of-bounds ArrayRef assertion inside the TritonIntelGPUAccelerateMatmul pass of the Intel Triton backend.

To Reproduce

pip3 install torch==2.9.0 --index-url https://download.pytorch.org/whl/test/xpu
git clone https://github.com/pytorch/helion.git
cd helion
pip install -e .'[dev]'
pytest test_dot.py -k test_baddbmm_pipeline_debug_dtype_asserts
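
Optionally, verify the XPU device is visible before running the suite (a minimal sanity check; assumes the XPU wheel above installed correctly):

import torch
# Both calls exist in the PyTorch 2.9 XPU build installed above.
print(torch.xpu.is_available())        # expect True
print(torch.xpu.get_device_name(0))    # the Intel GPU under test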

Expected behavior
PASS

Versions
PyTorch: 2.9
Triton: 3.5.0
Helion: a226049

Additional context

---------------- Captured stderr call ----------------
python3.10: /root/.triton/llvm/llvm-57088512-almalinux-x64/include/llvm/ADT/ArrayRef.h:268: const T& llvm::ArrayRef<T>::operator[](size_t) const [with T = unsigned int; size_t = long unsigned int]: Assertion `Index < Length && "Invalid index!"' failed.
module attributes {"ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
  tt.func public @_helion_repro_baddbmm_kernel(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<1.000000e+00> : tensor<1x64x64xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x64x64xf32>
    %cst_1 = arith.constant dense<64> : tensor<1x1x64xi32>
    %cst_2 = arith.constant dense<64> : tensor<1x64x1xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
    %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x64xi32> -> tensor<1x64x1xi32>
    %3 = arith.muli %2, %cst_2 : tensor<1x64x1xi32>
    %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1x64xi32> -> tensor<1x1x64xi32>
    %5 = tt.broadcast %3 : tensor<1x64x1xi32> -> tensor<1x64x64xi32>
    %6 = tt.broadcast %4 : tensor<1x1x64xi32> -> tensor<1x64x64xi32>
    %7 = arith.addi %5, %6 : tensor<1x64x64xi32>
    %8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x64x64x!tt.ptr<bf16>>
    %9 = tt.addptr %8, %7 : tensor<1x64x64x!tt.ptr<bf16>>, tensor<1x64x64xi32>
    %10 = tt.load %9 : tensor<1x64x64x!tt.ptr<bf16>>
    %11 = arith.muli %4, %cst_1 : tensor<1x1x64xi32>
    %12 = tt.broadcast %2 : tensor<1x64x1xi32> -> tensor<1x64x64xi32>
    %13 = tt.broadcast %11 : tensor<1x1x64xi32> -> tensor<1x64x64xi32>
    %14 = arith.addi %12, %13 : tensor<1x64x64xi32>
    %15 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<1x64x64x!tt.ptr<bf16>>
    %16 = tt.addptr %15, %14 : tensor<1x64x64x!tt.ptr<bf16>>, tensor<1x64x64xi32>
    %17 = tt.load %16 : tensor<1x64x64x!tt.ptr<bf16>>
    %18 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<1x64x64x!tt.ptr<bf16>>
    %19 = tt.addptr %18, %7 : tensor<1x64x64x!tt.ptr<bf16>>, tensor<1x64x64xi32>
    %20 = tt.load %19 : tensor<1x64x64x!tt.ptr<bf16>>
    %21 = tt.dot %10, %17, %cst_0, inputPrecision = tf32 : tensor<1x64x64xbf16> * tensor<1x64x64xbf16> -> tensor<1x64x64xf32>
    %22 = arith.truncf %21 : tensor<1x64x64xf32> to tensor<1x64x64xbf16>
    %23 = arith.extf %22 : tensor<1x64x64xbf16> to tensor<1x64x64xf32>
    %24 = arith.subf %cst_0, %23 : tensor<1x64x64xf32>
    %25 = math.exp %24 : tensor<1x64x64xf32>
    %26 = arith.addf %25, %cst : tensor<1x64x64xf32>
    %27 = arith.divf %cst, %26 : tensor<1x64x64xf32>
    %28 = arith.mulf %23, %27 : tensor<1x64x64xf32>
    %29 = arith.truncf %28 : tensor<1x64x64xf32> to tensor<1x64x64xbf16>
    %30 = tt.dot %29, %20, %cst_0, inputPrecision = tf32 : tensor<1x64x64xbf16> * tensor<1x64x64xbf16> -> tensor<1x64x64xf32>
    %31 = arith.truncf %30 : tensor<1x64x64xf32> to tensor<1x64x64xbf16>
    %32 = tt.splat %arg3 : !tt.ptr<bf16> -> tensor<1x64x64x!tt.ptr<bf16>>
    %33 = tt.addptr %32, %7 : tensor<1x64x64x!tt.ptr<bf16>>, tensor<1x64x64xi32>
    tt.store %33, %31 : tensor<1x64x64x!tt.ptr<bf16>>
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=xpu threads-per-warp=16}, tritonintelgpu-coalesce, tritonintelgpu-remove-layout-conversions, tritonintelgpu-accelerate-matmul, tritonintelgpu-materialize-block-pointer, tritonintelgpu-remove-layout-conversions, tritonintelgpu-optimize-dot-operands, tritonintelgpu-pipeline{num-stages=3 split-barriers-scope=none}, tritonintelgpu-reduce-variable-liveness, tritongpu-fuse-nested-loops, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, tritongpu-combine-tensor-select-and-if, tritongpu-optimize-thread-locality, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritonintelgpu-remove-layout-conversions, tritonintelgpu-reduce-data-duplication, tritongpu-reorder-instructions, cse, symbol-dce, sccp, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, arith-emulate-unsupported-floats{source-types={bf16} target-type=f32})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/tmp/tmp1o5klau1/3b/c3bgf5agfxo67dn6dfdwuidbilt67mpwj2eg5oesaflw25xnymjs.py:9:0: error: Failures have been detected while processing an MLIR pass pipeline
/tmp/tmp1o5klau1/3b/c3bgf5agfxo67dn6dfdwuidbilt67mpwj2eg5oesaflw25xnymjs.py:9:0: note: Pipeline failed while executing [`TritonIntelGPUAccelerateMatmul` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
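
For reference, the TTIR above reduces to roughly the following computation (a sketch read off the IR, not the Helion kernel source; the transposed load of %arg1 is inferred from its address arithmetic, offset = row + 64 * col):

import torch

# Shapes match the tensor<1x64x64> types in the dump; bf16 operands, f32 accumulation.
a = torch.randn(1, 64, 64, dtype=torch.bfloat16)   # %arg0, loaded row-major
b = torch.randn(1, 64, 64, dtype=torch.bfloat16)   # %arg1, loaded transposed
c = torch.randn(1, 64, 64, dtype=torch.bfloat16)   # %arg2, loaded row-major

x = a.float() @ b.transpose(-1, -2).float()     # first tt.dot, f32 accumulator (%21)
x = x.to(torch.bfloat16).float()                # truncf + extf round-trip (%22, %23)
x = x * torch.sigmoid(x)                        # subf/exp/addf/divf/mulf = x * sigmoid(x) (%24-%28)
out = x.to(torch.bfloat16).float() @ c.float()  # second tt.dot (%30)
out = out.to(torch.bfloat16)                    # final truncf before tt.store to %arg3

The math itself is unremarkable; the assertion fires while TritonIntelGPUAccelerateMatmul walks the module, most likely while handling the 3D (batched) tt.dot ops, whose leading batch dimension of size 1 the pass may not be indexing correctly.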
