🐛 Describe the bug
BatchNorm should be kept in FP32 when using mixed precision for numerical stability. This works as expected when BatchNorm is the first layer, e.g.:
import torch
from torch import nn
net = nn.Sequential(nn.BatchNorm1d(4)).cuda()
o = torch.randn(2, 4).cuda()
# auto-cast should automatically cast stuff
with torch.cuda.amp.autocast():
    for layer in net:
        o = layer(o)
        print(o.dtype)
Result:
torch.float32
But it does not stay in FP32 when BatchNorm follows other layers:
import torch
from torch import nn
net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 5)).cuda()
o = torch.randn(2, 4).cuda()
# auto-cast should automatically cast stuff
with torch.cuda.amp.autocast():
    for layer in net:
        o = layer(o)
        print(o.dtype)
Result:
torch.float16
torch.float16
torch.float16
Expected result:
torch.float16
torch.float32 # BN output
torch.float16
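For reference, a minimal sketch of one possible workaround (not a fix for autocast itself): locally disabling autocast around the BatchNorm layer and casting its input to float32, assuming the same three-layer model as above. This reproduces the expected dtypes.
import torch
from torch import nn
net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 5)).cuda()
o = torch.randn(2, 4).cuda()
with torch.cuda.amp.autocast():
    for layer in net:
        if isinstance(layer, nn.BatchNorm1d):
            # Hypothetical workaround: run BatchNorm outside autocast in full precision
            with torch.cuda.amp.autocast(enabled=False):
                o = layer(o.float())
        else:
            o = layer(o)
        print(o.dtype)
# expected to print: torch.float16, torch.float32, torch.float16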
Versions
[pip3] torch==1.12.1