Automatic Mixed Precision (AMP) doesn't run BatchNorm in FP32 #88929

@jramapuram

🐛 Describe the bug

BatchNorm should be kept in FP32 when using mixed precision for numerical stability. This works as expected when the BatchNorm is the first layer, e.g.:

import torch
from torch import nn

net = nn.Sequential(nn.BatchNorm1d(4)).cuda()

o = torch.randn(2, 4).cuda()

# autocast should keep the BatchNorm computation (and its output) in FP32
with torch.cuda.amp.autocast():
    for layer in net:
        o = layer(o)
        print(o.dtype)

Result:

torch.float32

But it does not hold when the BatchNorm follows other layers:

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 5)).cuda()
o = torch.randn(2, 4).cuda()

# autocast should keep the BatchNorm computation (and its output) in FP32
with torch.cuda.amp.autocast():
    for layer in net:
        o = layer(o)
        print(o.dtype)

Result:

torch.float16
torch.float16
torch.float16

Expected result:

torch.float16
torch.float32  # BN output
torch.float16
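
A possible user-level workaround (not a fix in PyTorch itself) is to disable autocast locally around the normalization layer and upcast its input, so the op runs in FP32 regardless of what precedes it. The FP32BatchNorm1d wrapper below is a hypothetical name used only for illustration:

import torch
from torch import nn

class FP32BatchNorm1d(nn.BatchNorm1d):
    # Hypothetical wrapper: force BatchNorm to run in FP32 even inside autocast.
    def forward(self, x):
        # Locally disable autocast and upcast the (possibly FP16) input.
        with torch.cuda.amp.autocast(enabled=False):
            return super().forward(x.float())

net = nn.Sequential(nn.Linear(4, 4), FP32BatchNorm1d(4), nn.Linear(4, 5)).cuda()
o = torch.randn(2, 4).cuda()

with torch.cuda.amp.autocast():
    for layer in net:
        o = layer(o)
        print(o.dtype)  # float16, float32, float16

This trades an extra dtype cast per call for the expected FP32 numerics shown above; the surrounding Linear layers still run in FP16 under autocast.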

Versions

[pip3] torch==1.12.1

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5
