Autocast region with torch.scatter_add generates type mismatch errors #51730

@romerojosh

🐛 Bug

torch.scatter_add does not promote its inputs to the widest input type, so it can raise type mismatch errors when used in an autocast-enabled region.

To Reproduce

I have a small script here that reproduces the behavior:

import torch

# Define test model
class foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layer1 = torch.nn.Linear(2, 2)
        self.linear_layer2 = torch.nn.Linear(2, 2)
        self.norm = torch.nn.LayerNorm(2)

    def forward(self, x, y):
        x = self.linear_layer1(x)
        y = self.linear_layer2(y)

        for i in range(2):
            print(x.dtype, y.dtype)
            x = torch.zeros_like(x).scatter_add(0, torch.zeros_like(x, dtype=torch.int64), y)
            x = self.norm(x)
        return x

# Run forward
model = foo().cuda()
x = torch.ones(2,2).cuda()
y = torch.ones(2,2).cuda()
with torch.cuda.amp.autocast(enabled=True):
    res = model(x, y)

Running this script, I get the following result:

$ python main.py
torch.float16 torch.float16
torch.float32 torch.float16
Traceback (most recent call last):
  File "main.py", line 26, in <module>
    res = model(x, y)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "main.py", line 17, in forward
    x = torch.zeros_like(x).scatter_add(0, torch.zeros_like(x, dtype=torch.int64), y)
RuntimeError: scatter_add_cuda_(): Expected self.dtype to be equal to src.dtype

In the first iteration through the loop, torch.nn.LayerNorm promotes x to FP32. On the second iteration, torch.scatter_add fails because torch.zeros_like(x) is now FP32 while y remains in FP16.
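
The mismatch can likely be isolated from autocast and CUDA altogether. The following is a minimal sketch, assuming CPU tensors trigger the same dtype check as the CUDA path in the traceback; the tensors `x`, `y`, and `index` are stand-ins for the values seen on the second loop iteration, and the explicit cast at the end is only a call-site workaround, not a fix for autocast itself.

```python
import torch

# Stand-ins for the second loop iteration:
# x is FP32 (after LayerNorm), y is still FP16 (from the Linear layer).
x = torch.zeros(2, 2, dtype=torch.float32)
y = torch.ones(2, 2, dtype=torch.float16)
index = torch.zeros(2, 2, dtype=torch.int64)

# Mixed dtypes: scatter_add is expected to reject this call.
error_message = None
try:
    torch.zeros_like(x).scatter_add(0, index, y)
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Call-site workaround: cast src to the destination dtype explicitly.
result = torch.zeros_like(x).scatter_add(0, index, y.to(x.dtype))
print(result.dtype)
```

With the cast in place, every element of `y` is accumulated into row 0 of the FP32 destination, so the call no longer depends on which ops autocast happened to run earlier.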

Expected behavior

torch.scatter_add should promote y to FP32 in the autocast-enabled region.
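
To make the requested behavior concrete, here is a rough sketch of "promote to the widest input type" done by hand. The helper name `scatter_add_promoted` is hypothetical and not part of PyTorch; it uses `torch.promote_types` to pick the wider dtype and casts both the destination and the source to it. It only illustrates the semantics asked for here and does not claim to show how autocast would implement them.

```python
import torch

def scatter_add_promoted(dest, dim, index, src):
    """Hypothetical helper: run scatter_add in the widest dtype of dest and src."""
    widest = torch.promote_types(dest.dtype, src.dtype)
    # Cast both operands so their dtypes match before the call.
    return dest.to(widest).scatter_add(dim, index, src.to(widest))

index = torch.zeros(2, 2, dtype=torch.int64)

# FP32 destination, FP16 source (the failing case from the report).
out_a = scatter_add_promoted(
    torch.zeros(2, 2, dtype=torch.float32), 0, index,
    torch.ones(2, 2, dtype=torch.float16),
)

# FP16 destination, FP32 source (the reverse case is promoted as well).
out_b = scatter_add_promoted(
    torch.zeros(2, 2, dtype=torch.float16), 0, index,
    torch.ones(2, 2, dtype=torch.float32),
)
print(out_a.dtype, out_b.dtype)
```

Promoting both operands, rather than casting the source down to the destination dtype, avoids silently losing precision when the source is the wider tensor.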

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0a0+160689 (NGC 20.12)
  • OS (e.g., Linux): Linux
  • Python version: 3.8

cc: @mcarilli

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @mcarilli @ptrblck
