
NameError('math is not defined') - malformed Triton codegen with math.trunc #133172

@ezyang


🐛 Describe the bug

Internal xref: https://fb.workplace.com/groups/260102303573409/posts/469995645917406/

Failure looks like this:

backend='inductor' raised:
CompilationError: at 12:44:
def triton_poi_fused_gather_12(in_ptr0, in_ptr1, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // ks1)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = libdevice.trunc(0.120000000000000*(ks0.to(tl.float64))).to(tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert(((0 <= tmp4) & (tmp4 < math.trunc(0.120000000000000*(float(ks0))))) | ~(xmask), "index out of bounds: 0 <= tmp4 < math.trunc(0.120000000000000*(float(ks0)))")
                                            ^
NameError('math is not defined')

tlparse https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/aps-Zen_3L_mod_every_12-999b9b9f5a/attempt_0/version_0/20240810/rank_0/5_0_0/compilation_metrics_171.html
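To make the failing line concrete, here is a plain-Python sketch of what tmp1 through tmp4 and the tl.device_assert compute for one element. The helper names gathered_size and wrap_and_check are hypothetical. The size formula is copied from the FX annotation TruncToInt(0.12*ToFloat(s0)). The check itself is ordinary negative-index wraparound plus a bounds test. Only the spelling of the bound inside the kernel breaks.

```python
import math


def gathered_size(ks0: int) -> int:
    # Hypothetical helper: size of the gathered dimension,
    # TruncToInt(0.12*ToFloat(s0)) in the FX graph. In the kernel this is
    # tmp1, and it is also the upper bound inside tl.device_assert.
    return math.trunc(0.12 * float(ks0))


def wrap_and_check(idx: int, size: int, active: bool = True) -> int:
    # Hypothetical helper mirroring tmp2..tmp4: a negative index is wrapped
    # by adding the dimension size.
    wrapped = idx + size if idx < 0 else idx
    # Mirrors tl.device_assert: active lanes must land in [0, size);
    # masked-out lanes are exempt via the "| ~(xmask)" term.
    if active and not (0 <= wrapped < size):
        raise IndexError(f"index out of bounds: 0 <= {wrapped} < {size}")
    return wrapped
```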

The original code is a bit long, but I think this is the relevant context:

         # File: /packages/aps.ads.icvr/icvr_launcher-inplace#link-tree/aps_models/ads/joint_arch_exploration/modules.py:416 in forward, code: top_k_w, top_k_indices = torch.topk(
        sym_float: "Sym(ToFloat(s0))" = torch.sym_float(primals_32)
        mul_14: "Sym(0.12*ToFloat(s0))" = 0.12 * sym_float;  sym_float = None
        trunc: "Sym(TruncToInt(0.12*ToFloat(s0)))" = math_trunc(mul_14);  mul_14 = None
        topk_1 = torch.ops.aten.topk.default(squeeze_1, trunc, 1);  squeeze_1 = None
        getitem_38: "bf16[1024, TruncToInt(0.12*ToFloat(s0))][TruncToInt(0.12*ToFloat(s0)), 1]cuda:0" = topk_1[0]
        getitem_39: "i64[1024, TruncToInt(0.12*ToFloat(s0))][TruncToInt(0.12*ToFloat(s0)), 1]cuda:0" = topk_1[1];  topk_1 = None
        
         # File: /packages/aps.ads.icvr/icvr_launcher-inplace#link-tree/aps_models/ads/joint_arch_exploration/modules.py:420 in forward, code: top_k_indices_in_order, ordering = torch.sort(top_k_indices, dim=1)
        sort_1 = torch.ops.aten.sort.default(getitem_39, 1)
        getitem_40: "i64[1024, TruncToInt(0.12*ToFloat(s0))][TruncToInt(0.12*ToFloat(s0)), 1]cuda:0" = sort_1[0]
        getitem_41: "i64[1024, TruncToInt(0.12*ToFloat(s0))][TruncToInt(0.12*ToFloat(s0)), 1]cuda:0" = sort_1[1];  sort_1 = None
        
         # File: /packages/aps.ads.icvr/icvr_launcher-inplace#link-tree/aps_models/ads/joint_arch_exploration/modules.py:421 in forward, code: top_k_w_in_order = top_k_w.gather(dim=1, index=ordering)
        gather_3: "bf16[1024, TruncToInt(0.12*ToFloat(s0))][TruncToInt(0.12*ToFloat(s0)), 1]cuda:0" = torch.ops.aten.gather.default(getitem_38, 1, getitem_41);  getitem_38 = getitem_41 = None
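A minimal sketch of a possible reproducer follows. It assumes the model code at modules.py:416-421 reduces to topk, sort, and gather, with k derived from a dynamic size. The function select_top_fraction and the use of x.size(1) as the source of s0 are assumptions. The excerpt only shows that k is TruncToInt(0.12*ToFloat(s0)) for some symbolic s0.

```python
import math

import torch


def select_top_fraction(x: torch.Tensor, frac: float = 0.12) -> torch.Tensor:
    # Hypothetical stand-in for modules.py:416-421. k depends on a size that
    # Dynamo may treat symbolically (assumed here to be x.size(1)).
    k = math.trunc(frac * x.size(1))
    top_k_w, top_k_indices = torch.topk(x, k, dim=1)
    # Sort the selected indices and reorder the values to match, restoring
    # the original column order of the selected entries.
    top_k_indices_in_order, ordering = torch.sort(top_k_indices, dim=1)
    return top_k_w.gather(dim=1, index=ordering)


# dynamic=True asks torch.compile to trace sizes symbolically, so k should be
# traced as an expression in s0 rather than baked in as a constant.
compiled = torch.compile(select_top_fraction, dynamic=True)
```

Calling compiled on a CUDA tensor, e.g. torch.randn(1024, 100, device="cuda", dtype=torch.bfloat16), is a plausible way to reach a Triton gather kernel like the one above. It has not been confirmed to hit the same NameError.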

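Note that tmp1 renders the same size expression correctly as libdevice.trunc(...). Only the bound inside tl.device_assert uses math.trunc(...), and math is not available inside the Triton kernel. It's probably not too difficult to create a reproducer and fix this.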

One interesting detail: 0.120000000000000 is SymPy-style float printing. I thought I had banned this completely in #130027, but that PR only fixed guard printing, not codegen!
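To illustrate the two spellings, here is a toy printer over a hypothetical four-node expression type. None of these classes or functions are Inductor's. print_triton reproduces the tmp1 spelling, and print_python reproduces the spelling that leaked into the device_assert. Floats go through repr(), which yields the shortest round-tripping string ("0.12") rather than fixed 15-significant-digit output. One plausible direction for a fix is having the bounds check use the device-side spelling, as tmp1 does. That is an assumption, not a confirmed fix.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical expression nodes mirroring TruncToInt(0.12*ToFloat(s0)).
# These are NOT Inductor's real classes.
@dataclass(frozen=True)
class Sym:
    name: str


@dataclass(frozen=True)
class ToFloat:
    arg: "Expr"


@dataclass(frozen=True)
class Mul:
    coeff: float
    arg: "Expr"


@dataclass(frozen=True)
class TruncToInt:
    arg: "Expr"


Expr = Union[Sym, ToFloat, Mul, TruncToInt]


def fmt_float(x: float) -> str:
    # repr() gives the shortest string that round-trips, e.g. "0.12",
    # instead of SymPy-style fixed precision like "0.120000000000000".
    return repr(x)


def print_python(e: Expr) -> str:
    """Host-side Python spelling (what ended up in the device_assert)."""
    if isinstance(e, Sym):
        return e.name
    if isinstance(e, ToFloat):
        return f"float({print_python(e.arg)})"
    if isinstance(e, Mul):
        return f"{fmt_float(e.coeff)}*({print_python(e.arg)})"
    if isinstance(e, TruncToInt):
        return f"math.trunc({print_python(e.arg)})"
    raise TypeError(f"unsupported node: {e!r}")


def print_triton(e: Expr) -> str:
    """Device-side Triton spelling (what tmp1 uses)."""
    if isinstance(e, Sym):
        return e.name
    if isinstance(e, ToFloat):
        return f"{print_triton(e.arg)}.to(tl.float64)"
    if isinstance(e, Mul):
        return f"{fmt_float(e.coeff)}*({print_triton(e.arg)})"
    if isinstance(e, TruncToInt):
        return f"libdevice.trunc({print_triton(e.arg)}).to(tl.int32)"
    raise TypeError(f"unsupported node: {e!r}")
```

For comparison, f"{0.12:#.15g}" gives "0.120000000000000", matching the SymPy-style output in the kernel. The same format applied to 0.1 + 0.2 gives "0.300000000000000", which parses back to 0.3 rather than the original value, whereas repr() round-trips.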

Versions

main

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire
