Closed
Labels: high priority, module: amp (automated mixed precision), triaged
Description
🐛 Bug
torch.scatter_add does not promote its inputs to the widest input type, so it can raise a dtype mismatch error when used in an autocast-enabled region.
To Reproduce
I have a small script here that reproduces the behavior:
import torch

# Define test model
class foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layer1 = torch.nn.Linear(2, 2)
        self.linear_layer2 = torch.nn.Linear(2, 2)
        self.norm = torch.nn.LayerNorm(2)

    def forward(self, x, y):
        x = self.linear_layer1(x)
        y = self.linear_layer2(y)

        for i in range(2):
            print(x.dtype, y.dtype)
            x = torch.zeros_like(x).scatter_add(0, torch.zeros_like(x, dtype=torch.int64), y)
            x = self.norm(x)
        return x

# Run forward
model = foo().cuda()
x = torch.ones(2,2).cuda()
y = torch.ones(2,2).cuda()
with torch.cuda.amp.autocast(enabled=True):
    res = model(x, y)
Running this script, I get the following result:
$ python main.py
torch.float16 torch.float16
torch.float32 torch.float16
Traceback (most recent call last):
  File "main.py", line 26, in <module>
    res = model(x, y)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "main.py", line 17, in forward
    x = torch.zeros_like(x).scatter_add(0, torch.zeros_like(x, dtype=torch.int64), y)
RuntimeError: scatter_add_cuda_(): Expected self.dtype to be equal to src.dtype
In the first iteration through the loop, x is promoted to FP32 by torch.nn.LayerNorm. On the second iteration, torch.zeros_like(x) is therefore FP32 while y remains in FP16, so torch.scatter_add fails on the dtype mismatch.
Expected behavior
torch.scatter_add should promote y to FP32 in the autocast-enabled region.
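As a user-side stopgap, the promotion can be done by hand at the call site. The following is only a minimal sketch of that idea, not how autocast itself would implement it: scatter_add_promoted is a hypothetical helper name (not a PyTorch API), and it assumes that casting both operands to the wider of the two dtypes is acceptable for the model. It runs on CPU with explicit dtypes so that it does not depend on CUDA or autocast.

```python
import torch


def scatter_add_promoted(dest, dim, index, src):
    """Hypothetical helper: cast dest and src to their widest common dtype,
    then call scatter_add, so mismatched FP16/FP32 inputs do not raise."""
    widest = torch.promote_types(dest.dtype, src.dtype)
    return dest.to(widest).scatter_add(dim, index, src.to(widest))


# Mimic the second loop iteration from the report: FP32 destination, FP16 source.
dest = torch.zeros(2, 2, dtype=torch.float32)
src = torch.ones(2, 2, dtype=torch.float16)
index = torch.zeros(2, 2, dtype=torch.int64)

out = scatter_add_promoted(dest, 0, index, src)
print(out.dtype)  # the wider of the two input dtypes
```

In the reproduction script above, the equivalent change would be to replace the direct scatter_add call inside the loop with a call to this helper, passing torch.zeros_like(x) as dest and y as src.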
Environment
- PyTorch Version (e.g., 1.0): 1.8.0a0+160689 (NGC 20.12)
- OS (e.g., Linux): Linux
- Python version: 3.8
cc: @mcarilli
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @mcarilli @ptrblck