Closed
Labels
module: inductor, oncall: pt2, triaged
Description
🐛 Describe the bug
import torch

torch.set_default_device('cuda')

scale = torch.tensor(0.180336877703666687)
x = torch.tensor(1134139801600.000000)

def f(x, scale):
    max_scaled = x * scale
    return torch.exp(max_scaled - x * scale)

print(f(x, scale))
print(torch.compile(f)(x, scale))
>>> tensor(1., device='cuda:0')
>>> tensor(inf, device='cuda:0')
The root cause here is the same as the one described in #121558 (comment).
I think this general pattern can show up in many cases, but in this instance it is a Triton issue: Inductor does CSE the x * scale call, so the generated kernel computes the product only once.
The generated Triton code is:
@triton_heuristics.pointwise(
    size_hints=[1],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {3: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(3,), ids_of_folded_args=(3,), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_exp_mul_sub_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'cddd2cc4107921f6715d58b181fdb08b23055085471461101752ba6efb772ae1'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
    tmp4 = tmp1 * tmp3
    tmp5 = tmp4 - tmp4
    tmp6 = tl_math.exp(tmp5)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp6, None)
Here, you would expect tmp4 - tmp4 to be 0. Unfortunately, Triton generates this PTX:
mul.f32 %f5, %f3, %f4;
.loc 1 30 18
neg.f32 %f6, %f5;
fma.rn.f32 %f7, %f3, %f4, %f6;
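This duplicates the multiplication, folding the second copy into an FMA instead of reusing %f5! Because fma.rn.f32 evaluates %f3 * %f4 without rounding the intermediate product, %f7 ends up being the rounding error of the product rather than 0, and exp of that value overflows to inf.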
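The effect can be reproduced on the CPU without Triton. The following is a minimal sketch, assuming the PTX above behaves as a multiply rounded to float32 followed by an FMA with a single final rounding. It emulates float32 with Python's struct module; the helper names to_f32 and exp_f32 are made up for illustration and are not part of PyTorch or Triton.

```python
import math
import struct


def to_f32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def exp_f32(value: float) -> float:
    """exp() that saturates to inf on overflow, like float32 arithmetic."""
    try:
        return to_f32(math.exp(value))
    except OverflowError:
        return math.inf


x = to_f32(1134139801600.0)
scale = to_f32(0.180336877703666687)

# mul.f32: the product is rounded to float32.
# The float64 product of two float32 values is exact (at most 48
# significant bits), so only the to_f32 call rounds here.
product_f32 = to_f32(x * scale)

# What the Triton source asks for: tmp4 - tmp4.
eager_diff = to_f32(product_f32 - product_f32)

# What the PTX does: fma(x, scale, -product_f32), i.e. the exact
# product minus the rounded product, rounded once at the end.
fused_diff = to_f32(x * scale - product_f32)

print(eager_diff, exp_f32(eager_diff))  # 0.0 1.0
print(fused_diff, exp_f32(fused_diff))  # 7916.0 inf
```

For these inputs the rounding error of the product is far larger than the float32 overflow threshold of exp (roughly 88.7), which is why the compiled kernel returns inf rather than a value close to 1.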
Versions
N/A
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire