
fma can cause drastically worse precision in torch.compile/Triton #122260

@Chillee


🐛 Describe the bug

import torch
torch.set_default_device('cuda')

scale = torch.tensor(0.180336877703666687)
x = torch.tensor(1134139801600.000000)

def f(x, scale):
    max_scaled = x * scale
    return torch.exp(max_scaled - x * scale)

print(f(x, scale))
print(torch.compile(f)(x, scale))
# eager:          tensor(1., device='cuda:0')
# torch.compile:  tensor(inf, device='cuda:0')

The root cause is the same as in #121558 (comment).

I think this general pattern can show up in many places, but this particular case is a Triton issue: Inductor actually CSEs the x * scale call, so the generated kernel computes the product only once (tmp4 below).
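A CPU-only sketch of the expected semantics follows. It assumes NumPy's float32 multiply and subtract round to nearest like CUDA's mul.rn.f32/sub.rn.f32; the function names f_reference and f_cse are hypothetical. Whether or not x * scale is CSE'd, each product is rounded to float32 in exactly the same way, so the subtraction is exactly 0 and exp returns 1:

```python
import numpy as np

# Inputs from the repro, stored as float32 like the CUDA tensors.
x = np.float32(1134139801600.0)
scale = np.float32(0.180336877703666687)

def f_reference(x, scale):
    # Two independent products, each rounded to float32; they are identical.
    max_scaled = x * scale
    return np.exp(max_scaled - x * scale)

def f_cse(x, scale):
    # One product reused, as in the Inductor kernel (tmp4 - tmp4).
    tmp4 = x * scale
    return np.exp(tmp4 - tmp4)  # a finite float minus itself is +0.0

print(f_reference(x, scale), f_cse(x, scale))
```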

The generated Triton code is

@triton_heuristics.pointwise(
    size_hints=[1],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {3: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(3,), ids_of_folded_args=(3,), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_exp_mul_sub_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'cddd2cc4107921f6715d58b181fdb08b23055085471461101752ba6efb772ae1'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
    tmp4 = tmp1 * tmp3
    tmp5 = tmp4 - tmp4
    tmp6 = tl_math.exp(tmp5)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp6, None)

You would expect tmp4 - tmp4 to be exactly 0. Unfortunately, Triton generates this PTX:

        mul.f32         %f5, %f3, %f4;
        .loc    1 30 18
        neg.f32         %f6, %f5;
        fma.rn.f32      %f7, %f3, %f4, %f6;

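Once again, the multiply has been duplicated into an FMA! %f5 is the product rounded by mul.f32, but fma.rn.f32 recomputes x * scale without intermediate rounding and then subtracts %f5. The result is therefore the rounding error of the product rather than 0, and at this magnitude that error is large enough for exp to overflow to inf.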
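The PTX above can be emulated on the CPU with NumPy, as a sketch under the same float32 rounding assumption. fma.rn.f32 computes round(exact product - rounded product) with a single rounding, i.e. the rounding error of x * scale. The product of two float32 values needs at most 48 significant bits, so float64 holds it exactly, and the subtraction from a nearby value is also exact; only the final cast to float32 rounds. The product here is about 2e11, where float32 spacing is 2^14, so the error can be as large as 8192 in magnitude:

```python
import numpy as np

x = np.float32(1134139801600.0)
scale = np.float32(0.180336877703666687)

# mul.f32 %f5, %f3, %f4: product rounded to float32.
p = x * scale

# fma.rn.f32 %f7, %f3, %f4, %f6 with %f6 = -%f5:
# exact product in float64 (no rounding for float32 inputs), exact
# subtraction (Sterbenz lemma), then one rounding back to float32.
exact = np.float64(x) * np.float64(scale)
fused = np.float32(exact - np.float64(p))

# The rounding error of p is bounded by half the float32 spacing at p.
half_ulp = np.spacing(p) / 2

print("tmp4 - tmp4 after contraction:", fused)
print("half-ulp bound at x * scale:", half_ulp)
with np.errstate(over="ignore"):
    print("exp of that value:", np.exp(fused))
```

Whether exp then returns inf or 0 depends on the sign of the rounding error, so this contraction can push a result that should be exactly 1 to either extreme.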

Versions

N/A

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire

Labels: module: inductor, oncall: pt2, triaged
