Open
Labels: actionable, high priority, module: amp (automated mixed precision), autocast, module: autograd, triaged
Description
Hmm, I actually think your example is unexpected, can you file a separate bug for it.
Originally posted by @ezyang in #105211 (comment)
If I'm in an autocast context and I do a forward pass of a model under no_grad, then do another forward pass outside of that no_grad context but still in the autocast context, arbitrary nodes in the autograd graph are silently lost (their grad_fn is None). See the example:
```python
import torch
import torch.nn as nn

l1 = nn.Sequential(
    nn.Linear(2, 2),
).cuda()
l2 = nn.Sequential(
    nn.Linear(2, 2),
    nn.LayerNorm(2),
).cuda()
x = torch.randn(2, 2).cuda()

#################################
# Just linear
#################################
with torch.cuda.amp.autocast():
    with torch.no_grad():
        y1 = l1(x)
    y1_2 = l1(x)
    print(y1_2.grad_fn)  # None

# Remove autocast
with torch.no_grad():
    y1 = l1(x)
y1_2 = l1(x)
print(y1_2.grad_fn)  # AddmmBackward

#################################
# Linear -> LayerNorm makes output
# have grad_fn
#################################
with torch.cuda.amp.autocast():
    with torch.no_grad():
        y2 = l2(x)
    y2_2 = l2(x)
    print(y2_2.grad_fn)  # LayerNormBackward

with torch.no_grad():
    y2 = l2(x)
y2_2 = l2(x)
print(y2_2.grad_fn)  # Still LayerNormBackward

with torch.cuda.amp.autocast():
    with torch.no_grad():
        y2 = l2(x)
    y2_2 = l2(x)
    print(y2_2.grad_fn)  # LayerNormBackward
    y2_2.sum().backward()
    print(l2[0].weight.grad)  # None
    print(l2[1].weight.grad)  # Not None
```
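A minimal sketch of a possible workaround, under the assumption that the missing nodes come from autocast's weight-cast cache being populated inside the no_grad region: passing `cache_enabled=False` to `torch.autocast` forces the casts to be redone (with grad tracking) on the second forward pass. Shown here on CPU with bfloat16 only so it runs without a GPU; the report above is about CUDA.

```python
import torch
import torch.nn as nn

l1 = nn.Linear(2, 2)
x = torch.randn(2, 2)

# cache_enabled=False disables autocast's cast cache, so the forward pass
# outside no_grad re-casts the weights while grad mode is on.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    with torch.no_grad():
        y1 = l1(x)
    y1_2 = l1(x)

print(y1_2.grad_fn)  # not None
```

This trades away the cast caching that autocast normally does for performance, so it is a diagnostic/workaround sketch rather than a fix for the underlying interaction.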
PyTorch version: 2.1.0
cc @ezyang @gchanan @zou3519 @kadeng @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mcarilli @ptrblck @leslie-fang-intel @jgong5
jon-chuang