
Nesting no_grad in autocast causes backwards graph to be (partially) lost outside of no_grad #112583

@haydn-jones

Description

@haydn-jones Hmm, I actually think your example is unexpected, can you file a separate bug for it.

Originally posted by @ezyang in #105211 (comment)

If I run a forward pass of a model under no_grad inside an autocast context, and then run another forward pass outside of that no_grad context but still inside the same autocast context, part of the backward graph is lost in the second pass: depending on the model, the output's grad_fn is None, or only some layers end up with gradients. See the example below:

import torch
import torch.nn as nn

l1 = nn.Sequential(
    nn.Linear(2, 2),
).cuda()

l2 = nn.Sequential(
    nn.Linear(2, 2),
    nn.LayerNorm(2),
).cuda()

x = torch.randn(2, 2).cuda()

#################################
# Just linear
#################################
with torch.cuda.amp.autocast():
    with torch.no_grad():
        y1 = l1(x)

    y1_2 = l1(x)
    print(y1_2.grad_fn) # None

# Remove autocast
with torch.no_grad():
    y1 = l1(x)

y1_2 = l1(x)
print(y1_2.grad_fn) # AddmmBackward

#################################
# Linear -> LayerNorm makes output
# have grad_fn
#################################
with torch.cuda.amp.autocast():
    with torch.no_grad():
        y2 = l2(x)

    y2_2 = l2(x)
    print(y2_2.grad_fn) # LayerNormBackward

with torch.no_grad():
    y2 = l2(x)

y2_2 = l2(x)
print(y2_2.grad_fn) # Still LayerNormBackward

with torch.cuda.amp.autocast():
    with torch.no_grad():
        y2 = l2(x)

    y2_2 = l2(x)
    print(y2_2.grad_fn) # LayerNormBackward

y2_2.sum().backward()
print(l2[0].weight.grad) # None
print(l2[1].weight.grad) # Not none

PyTorch version: 2.1.0
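
For completeness, here is a minimal, untested sketch of a possible mitigation. It assumes the grad_fn is dropped because autocast caches the weight casts it performs, and the cast made under no_grad (which carries no grad_fn) is then reused for the second forward pass; under that assumption, disabling the cache with cache_enabled=False, or clearing it with torch.clear_autocast_cache() after leaving no_grad, should force a fresh, differentiable cast. Both knobs exist in PyTorch 2.1, but this is only a sketch of a workaround, not a confirmed fix or a root-cause analysis.

import torch
import torch.nn as nn

l1 = nn.Sequential(nn.Linear(2, 2)).cuda()
x = torch.randn(2, 2).cuda()

# Option 1 (assumption): disable the autocast weight-cast cache entirely,
# so every forward pass re-casts the weights under the current grad mode.
with torch.cuda.amp.autocast(cache_enabled=False):
    with torch.no_grad():
        y1 = l1(x)

    y1_2 = l1(x)
    print(y1_2.grad_fn)  # expected: a grad_fn instead of None

# Option 2 (assumption): keep the cache, but clear it after the no_grad
# block so the next forward pass re-casts the weights with autograd on.
with torch.cuda.amp.autocast():
    with torch.no_grad():
        y1 = l1(x)

    torch.clear_autocast_cache()
    y1_2 = l1(x)
    print(y1_2.grad_fn)  # expected: a grad_fn instead of None

If either option restores the grad_fn, that would point at the autocast cast cache interacting badly with no_grad.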

cc @ezyang @gchanan @zou3519 @kadeng @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mcarilli @ptrblck @leslie-fang-intel @jgong5
