
[inductor] torch.bucketize in fused epilogue throws NameError('XBLOCK is not defined') #148764

@davidberard98

🐛 Describe the bug

Repro:

import torch
import torch._inductor.config
from torch._inductor.utils import fresh_inductor_cache

torch._inductor.config.max_autotune_gemm_backends = "TRITON"

def fn(x: torch.Tensor, y: torch.Tensor, buckets: torch.Tensor) -> torch.Tensor:
    z = torch.mm(x, y)
    return torch.bucketize(z, buckets)

buckets = torch.arange(-100, 100, 10, device="cuda")
x = torch.randn(64, 64, device="cuda")
y = torch.randn(64, 64, device="cuda")

with fresh_inductor_cache():
    torch.compile(fn, mode="max-autotune")(x, y, buckets)
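Any fix can be checked against eager mode, so it helps to pin down what the fused epilogue is supposed to compute. With the default right=False, torch.bucketize returns, for each value v, the index i such that boundaries[i-1] < v <= boundaries[i]. That is the same convention as Python's bisect.bisect_left. The following is a minimal, dependency-free sketch; reference_bucketize is a hypothetical name, not a PyTorch API.

```python
from bisect import bisect_left


def reference_bucketize(values, boundaries):
    """Pure-Python model of torch.bucketize(values, boundaries, right=False).

    Hypothetical helper for illustration. For each v it returns the index i
    with boundaries[i-1] < v <= boundaries[i]. Results range from 0 to
    len(boundaries) inclusive.
    """
    # torch.bucketize requires boundaries sorted in ascending order; so does bisect.
    return [bisect_left(boundaries, v) for v in values]


# Same boundaries as the repro: torch.arange(-100, 100, 10) -> -100, -90, ..., 90
boundaries = list(range(-100, 100, 10))
print(len(boundaries))  # 20
print(reference_bucketize([-150.0, -100.0, -99.5, 0.0, 5.0, 90.0, 95.0], boundaries))
# [0, 0, 1, 10, 11, 19, 20]
```

The 20 boundaries appear to match the two 20 arguments in the generated bucketize_binary_search call in the log below. On a CUDA machine, eager torch.bucketize(torch.mm(x, y), buckets) gives the ground truth to compare a compiled result against.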

Error:

/home/dberard/local/triton-env2/pytorch/torch/_inductor/compile_fx.py:244: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
AUTOTUNE mm(64x64, 64x64)
  triton_mm_1 0.0061 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_4 0.0071 ms 86.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_2 0.0074 ms 83.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_3 0.0075 ms 82.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_7 0.0083 ms 73.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_9 0.0084 ms 73.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_14 0.0084 ms 73.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_0 0.0086 ms 71.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2
  triton_mm_10 0.0088 ms 70.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_6 0.0090 ms 68.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.2732 seconds and 0.2760 seconds precompiling for 15 choices
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] Triton compilation failed: Placeholder.DESCRIPTIVE_NAME
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] def triton_(arg_A, arg_B, in_ptr2, out_ptr1):
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     GROUP_M : tl.constexpr = 8
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     EVEN_K : tl.constexpr = True
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     ALLOW_TF32 : tl.constexpr = False
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     ACC_TYPE : tl.constexpr = tl.float32
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     BLOCK_M : tl.constexpr = 32
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     BLOCK_N : tl.constexpr = 32
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     BLOCK_K : tl.constexpr = 64
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     A = arg_A
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     B = arg_B
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     M = 64
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     N = 64
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     K = 64
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     if M * N == 0:
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         # early exit due to zero-size input(s)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         return
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     stride_am = 64
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     stride_ak = 1
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     stride_bk = 64
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     stride_bn = 1
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     # based on triton.ops.matmul
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     pid = tl.program_id(0)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     grid_m = (M + BLOCK_M - 1) // BLOCK_M
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     grid_n = (N + BLOCK_N - 1) // BLOCK_N
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     # re-order program ID for better L2 performance
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     width = GROUP_M * grid_n
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     group_id = pid // width
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     pid_m = group_id * GROUP_M + (pid % group_size)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     pid_n = (pid % width) // (group_size)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M:
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     else:
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         offs_a_m = rm % M
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N:
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     else:
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         offs_b_n = rn % N
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     offs_k = tl.arange(0, BLOCK_K)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         idx_m = offs_a_m[:, None]
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         idx_n = a_k_idx_vals
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         xindex = idx_n + 64*idx_m
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         a = tl.load(A + (xindex))
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         idx_m = b_k_idx_vals
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         idx_n = offs_b_n[None, :]
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         xindex = idx_n + 64*idx_m
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         b = tl.load(B + (xindex))
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     # rematerialize rm and rn to save registers
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     idx_m = rm[:, None]
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     idx_n = rn[None, :]
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     mask = (idx_m < M) & (idx_n < N)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     # inductor generates a suffix
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     xindex = idx_n + 64*idx_m
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     tmp0 = triton_helpers.bucketize_binary_search(acc, in_ptr2, 20, 20, 1, 0, tl.int64, False, None, None, None, [XBLOCK], )
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     tl.store(out_ptr1 + (tl.broadcast_to(xindex, acc.shape)), tmp0, mask)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] metadata: {'signature': {'arg_A': '*fp32', 'arg_B': '*fp32', 'in_ptr2': '*i64', 'out_ptr1': '*i64'}, 'device': 0, 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'device_type': 'cuda', 'num_warps': 4, 'num_stages': 2, 'debug': True, 'cc': 90}
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] Traceback (most recent call last):
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]   File "/home/dberard/local/triton-env2/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 531, in _precompile_config
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     binary = triton.compile(*compile_args, **compile_kwargs)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]   File "/home/dberard/local/triton-env2/triton/python/triton/compiler/compiler.py", line 278, in compile
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     module = src.make_ir(options, codegen_fns, module_map, context)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]   File "/home/dberard/local/triton-env2/triton/python/triton/compiler/compiler.py", line 81, in make_ir
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] triton.compiler.errors.CompilationError: at 73:114:
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]         acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     # rematerialize rm and rn to save registers
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     idx_m = rm[:, None]
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     idx_n = rn[None, :]
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     mask = (idx_m < M) & (idx_n < N)
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] 
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     # inductor generates a suffix
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     xindex = idx_n + 64*idx_m
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]     tmp0 = triton_helpers.bucketize_binary_search(acc, in_ptr2, 20, 20, 1, 0, tl.int64, False, None, None, None, [XBLOCK], )
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533]                                                                                                                   ^
E0307 08:45:35.023000 3871554 torch/_inductor/runtime/triton_heuristics.py:533] NameError('XBLOCK is not defined')
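From the generated kernel, the failure looks like a name-binding problem. XBLOCK is the block-size tl.constexpr parameter that Inductor's standalone pointwise kernels take. The mm template, however, only defines BLOCK_M, BLOCK_N, BLOCK_K and the other constants at the top of triton_. As a result, the [XBLOCK] block-shape argument emitted for the fused bucketize has nothing to bind to. The epilogue also works on the 2-D tile acc of shape (BLOCK_M, BLOCK_N), so a 1-D [XBLOCK] shape would presumably be wrong even if the name resolved. This is an inference from the log, not a confirmed diagnosis.

The sketch below roughly approximates, statically, how Triton resolves free names in a kernel at compile time. It scans a trimmed copy of the kernel for names that are not parameters, local assignments, builtins, or the assumed globals tl and triton_helpers. unresolved_names is a hypothetical helper.

```python
import ast
import builtins

# Trimmed copy of the generated template kernel from the log above.
# Only parsed, never executed; the main K-loop is elided.
KERNEL_SRC = """
def triton_(arg_A, arg_B, in_ptr2, out_ptr1):
    GROUP_M : tl.constexpr = 8
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 32
    BLOCK_N : tl.constexpr = 32
    M = 64
    N = 64
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    xindex = idx_n + 64*idx_m
    tmp0 = triton_helpers.bucketize_binary_search(acc, in_ptr2, 20, 20, 1, 0, tl.int64, False, None, None, None, [XBLOCK], )
    tl.store(out_ptr1 + (tl.broadcast_to(xindex, acc.shape)), tmp0, mask)
"""


def unresolved_names(src, known_globals):
    """Return names read in the first function of src that are never bound.

    A name counts as bound if it is a parameter, is assigned anywhere in the
    body (including annotated constexpr assignments), is a Python builtin,
    or appears in known_globals.
    """
    fn = ast.parse(src).body[0]
    bound = {a.arg for a in fn.args.args}
    loaded = set()
    for node in ast.walk(fn):
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Store):
                bound.add(node.id)
            elif isinstance(node.ctx, ast.Load):
                loaded.add(node.id)
    allowed = bound | set(known_globals) | set(dir(builtins))
    return sorted(loaded - allowed)


print(unresolved_names(KERNEL_SRC, {"tl", "triton_helpers"}))  # ['XBLOCK']
```

Only XBLOCK is reported. Every other name in the epilogue, including BLOCK_M and BLOCK_N, is bound inside the template, which is consistent with the NameError above.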

Versions

PyTorch viable/strict as of Mar 7, running on an H100 GPU.

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

Labels: module: inductor, oncall: pt2, triaged
