
Compiling a model with 'aot_eager' raises an error involving backward() and detach() #97745

@Z-Fran

Description


🐛 Describe the bug

When I compile the model with the 'aot_eager' backend, backward() raises RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward. See the minified repro below for details. In torch/_dynamo/eval_frame.py, line 188 (line 215 in the newest version: https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/eval_frame.py#L215), the argument received for the call self.discriminator(batch_outputs.detach()) is batch_outputs, not batch_outputs.detach().
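
The error message itself describes standard eager-mode autograd behavior. The following is a minimal sketch of it, assuming CPU tensors, small 8x8 inputs, and a simplified step without set_requires_grad (none of these come from the original repro; the d_step(outputs, detach) helper is a hypothetical variant of the repro's method). After g_loss.backward() frees the generator's saved tensors, a discriminator loss computed on the detached output backpropagates fine, while one computed on the non-detached output walks back into the freed generator graph and raises this RuntimeError.

```python
import torch
import torch.nn as nn

# Assumption: CPU tensors and small shapes instead of the issue's CUDA setup.
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
discriminator = nn.Conv2d(3, 3, kernel_size=3, padding=1)

batch_outputs = generator(torch.rand(1, 3, 8, 8))
g_loss = (batch_outputs - torch.rand(1, 3, 8, 8)).abs().mean()
# This backward frees the tensors the generator's conv saved for backward.
g_loss.backward()


def d_step(outputs: torch.Tensor, detach: bool) -> None:
    """Hypothetical variant of the repro's d_step with a detach switch."""
    inp = outputs.detach() if detach else outputs
    d_loss = discriminator(inp).mean()
    d_loss.backward()


# With detach(), the graph ends at `inp`, so only the discriminator is visited.
d_step(batch_outputs, detach=True)

# Without detach(), backward reaches the already-freed generator graph.
try:
    d_step(batch_outputs, detach=False)
    error_message = ""
except RuntimeError as exc:
    error_message = str(exc)

print("backward through the graph a second time" in error_message)
```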

Error logs

Traceback (most recent call last):
  File "/repro.py", line 49, in <module>
    res = run_fwd_maybe_bwd(opt_mod, (torch.rand((1,3,32,32)).cuda()))
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 633, in run_fwd_maybe_bwd
    out = gm(args)
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1222, in g
    return f(*args)
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "repro.py", line 38, in forward
    set_requires_grad(self.discriminator, False)
  File "repro.py", line 39, in <graph break in forward>
    g_loss = abs(batch_outputs-batch_gt_data).mean()
  File "repro.py", line 40, in <graph break in forward>
    g_loss.backward()
  File "repro.py", line 41, in <graph break in forward>
    set_requires_grad(self.discriminator, True)
  File "repro.py", line 42, in <graph break in forward>
    self.d_step(batch_outputs)
  File "repro.py", line 33, in d_step
    d_loss.backward()
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "anaconda3/envs/pt20/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2310, in backward
    list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args)
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Minified repro

With the original, uncompiled model there is no error and the code runs successfully:

import torch
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
from torch.nn import *

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).cuda()
    def forward(self, inputs : torch.Tensor):
        return self.conv(inputs)

def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = Model()
        self.discriminator = Model()
    def d_step(self, batch_outputs: torch.Tensor):
        fake_d_pred = self.discriminator(batch_outputs.detach())
        d_loss = fake_d_pred.mean()
        d_loss.backward()
    def forward(self, inputs : torch.Tensor):
        batch_inputs = torch.rand((1,3,32,32)).cuda()
        batch_gt_data =  torch.rand((1,3,32,32)).cuda()
        batch_outputs = self.generator(batch_inputs)
        set_requires_grad(self.discriminator, False)
        g_loss = abs(batch_outputs-batch_gt_data).mean()
        g_loss.backward()
        set_requires_grad(self.discriminator, True)
        self.d_step(batch_outputs)

mod = Repro()
opt_mod = torch._dynamo.optimize("aot_eager")(mod)

with torch.cuda.amp.autocast(enabled=False):
    # The trailing comma makes the inputs a 1-tuple; run_fwd_maybe_bwd expects a sequence of args.
    ref = run_fwd_maybe_bwd(mod, (torch.rand((1,3,32,32)).cuda(),))
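
In eager mode the call succeeds because detach() returns a tensor that shares storage with batch_outputs but has no grad_fn, so d_loss.backward() stops at the discriminator's parameters and never revisits the generator graph already freed by g_loss.backward(). A small check of those properties, assuming CPU tensors and an 8x8 input rather than the repro's CUDA setup:

```python
import torch
import torch.nn as nn

# Assumption: CPU tensors and a small input instead of the issue's CUDA setup.
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
batch_outputs = generator(torch.rand(1, 3, 8, 8))
detached = batch_outputs.detach()

# The original output is part of the generator's autograd graph ...
print(batch_outputs.requires_grad, type(batch_outputs.grad_fn).__name__)
# ... while the detached tensor has no history ...
print(detached.requires_grad, detached.grad_fn)
# ... yet still shares the same underlying storage (no copy is made).
print(detached.data_ptr() == batch_outputs.data_ptr())
```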

With the compiled model, the code raises the same RuntimeError (Trying to backward through the graph a second time ...) at the line d_loss.backward():

import torch
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
from torch.nn import *

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).cuda()
    def forward(self, inputs : torch.Tensor):
        return self.conv(inputs)

def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = Model()
        self.discriminator = Model()
    def d_step(self, batch_outputs: torch.Tensor):
        fake_d_pred = self.discriminator(batch_outputs.detach())
        d_loss = fake_d_pred.mean()
        d_loss.backward()
    def forward(self, inputs : torch.Tensor):
        batch_inputs = torch.rand((1,3,32,32)).cuda()
        batch_gt_data =  torch.rand((1,3,32,32)).cuda()
        batch_outputs = self.generator(batch_inputs)
        set_requires_grad(self.discriminator, False)
        g_loss = abs(batch_outputs-batch_gt_data).mean()
        g_loss.backward()
        set_requires_grad(self.discriminator, True)
        self.d_step(batch_outputs)

mod = Repro()
opt_mod = torch._dynamo.optimize("aot_eager")(mod)

with torch.cuda.amp.autocast(enabled=False):
    # ref = run_fwd_maybe_bwd(mod, (torch.rand((1,3,32,32)).cuda(),))
    # The trailing comma makes the inputs a 1-tuple; run_fwd_maybe_bwd expects a sequence of args.
    res = run_fwd_maybe_bwd(opt_mod, (torch.rand((1,3,32,32)).cuda(),))
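
The error message suggests retain_graph=True. One hypothetical way to use it as a diagnostic rather than a fix is to keep the generator graph alive and check whether the discriminator backward changes the generator's gradients: if it does, the detach() inside d_step was not honored. The sketch below only demonstrates the check in eager mode on CPU with 8x8 inputs (generator_grad_changes is a hypothetical helper); whether it behaves the same way when applied to opt_mod under 'aot_eager' is an untested assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def generator_grad_changes(detach: bool) -> bool:
    """Hypothetical helper: True if the discriminator backward reaches the generator."""
    generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    discriminator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    batch_outputs = generator(torch.rand(1, 3, 8, 8))
    g_loss = (batch_outputs - torch.rand(1, 3, 8, 8)).abs().mean()
    # retain_graph=True keeps the generator's saved tensors alive, so a
    # non-detached discriminator backward no longer raises; instead it
    # accumulates extra gradient into the generator's parameters.
    g_loss.backward(retain_graph=True)
    before = generator.weight.grad.clone()
    inp = batch_outputs.detach() if detach else batch_outputs
    discriminator(inp).mean().backward()
    return not torch.equal(before, generator.weight.grad)


print(generator_grad_changes(detach=True))
print(generator_grad_changes(detach=False))
```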

However, when I call detach() before passing the tensor into d_step, there is no error and the code runs successfully. (It seems there is a bug in how detach() interacts with backward().)

import torch
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
from torch.nn import *

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).cuda()
    def forward(self, inputs : torch.Tensor):
        return self.conv(inputs)

def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = Model()
        self.discriminator = Model()
    def d_step(self, batch_outputs: torch.Tensor):
        fake_d_pred = self.discriminator(batch_outputs.detach())
        d_loss = fake_d_pred.mean()
        d_loss.backward()
    def forward(self, inputs : torch.Tensor):
        batch_inputs = torch.rand((1,3,32,32)).cuda()
        batch_gt_data =  torch.rand((1,3,32,32)).cuda()
        batch_outputs = self.generator(batch_inputs)
        set_requires_grad(self.discriminator, False)
        g_loss = abs(batch_outputs-batch_gt_data).mean()
        g_loss.backward()
        set_requires_grad(self.discriminator, True)
        self.d_step(batch_outputs.detach())  # workaround: call detach() before passing the tensor to d_step

mod = Repro()
opt_mod = torch._dynamo.optimize("aot_eager")(mod)

with torch.cuda.amp.autocast(enabled=False):
    # ref = run_fwd_maybe_bwd(mod, (torch.rand((1,3,32,32)).cuda(),))
    # The trailing comma makes the inputs a 1-tuple; run_fwd_maybe_bwd expects a sequence of args.
    res = run_fwd_maybe_bwd(opt_mod, (torch.rand((1,3,32,32)).cuda(),))
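
In eager mode this workaround should not change the results: d_step already detaches internally, and detaching a second time yields the same values. A minimal sketch comparing the discriminator gradients with and without the extra call-site detach(), assuming CPU tensors, 8x8 inputs, and a fixed seed (discriminator_grads is a hypothetical helper, and set_requires_grad is omitted for brevity):

```python
import torch
import torch.nn as nn


def discriminator_grads(pre_detach: bool) -> list:
    """Hypothetical helper: run one G step and one D step, return D's grads."""
    torch.manual_seed(0)  # identical weights and inputs for both runs
    generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    discriminator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    batch_inputs = torch.rand(1, 3, 8, 8)
    batch_gt_data = torch.rand(1, 3, 8, 8)
    batch_outputs = generator(batch_inputs)
    g_loss = (batch_outputs - batch_gt_data).abs().mean()
    g_loss.backward()
    # The repro's d_step always detaches internally; the workaround
    # additionally detaches once at the call site.
    arg = batch_outputs.detach() if pre_detach else batch_outputs
    d_loss = discriminator(arg.detach()).mean()
    d_loss.backward()
    return [p.grad.clone() for p in discriminator.parameters()]


original = discriminator_grads(pre_detach=False)
workaround = discriminator_grads(pre_detach=True)
print(all(torch.allclose(a, b) for a, b in zip(original, workaround)))
```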

Versions

PyTorch version: 2.0.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.6.1810 (Core) (x86_64)
GCC version: (GCC) 9.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.17

Python version: 3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-3.10.0-957.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB

Nvidia driver version: 470.129.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8358P CPU @ 2.60GHz
Stepping: 6
CPU MHz: 3400.000
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 1280K
L3 cache: 49152K
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 intel_pt ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] open-clip-torch==2.16.0
[pip3] torch==2.0.0+cu118
[pip3] torchaudio==2.0.0+cu118
[pip3] torchvision==0.15.0+cu118
[pip3] triton==2.0.0
[conda] magma-cuda113 2.5.2 1 pytorch
[conda] mkl 2023.0.0 h6d00ec8_25399
[conda] mkl-include 2023.0.0 h06a4308_25399
[conda] numpy 1.23.5 pypi_0 pypi
[conda] open-clip-torch 2.16.0 pypi_0 pypi
[conda] torch 2.0.0+cu118 pypi_0 pypi
[conda] torchaudio 2.0.0+cu118 pypi_0 pypi
[conda] torchvision 0.15.0+cu118 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @soumith @wconstab @ngimel

Metadata


Assignees

No one assigned

    Labels

    module: aotdispatch, module: pt2-dispatcher, oncall: pt2, triaged
