[aot-cudagraph] fails on examples/imagenet #1708
🐛 Describe the bug
In meta_fk here: https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/training.py#L431
meta was an empty dict {}, and that case isn't handled in the function.
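For reference, that helper only looks for "val" or "fake_result" in node.meta, so an empty dict falls straight through to the KeyError in the traceback below. Roughly (the guarded variant is only a sketch of one possible handling, not a proposed fix):

# Current helper (paraphrased from training.py); an empty meta dict hits
# the KeyError: 'fake_result' shown in the traceback below.
def meta_fk(meta):
    return meta["val"] if "val" in meta else meta["fake_result"]

# Hypothetical guarded variant: return None for nodes without fake-tensor
# metadata and let the caller (find_input_mutations) skip them. Illustration only.
def meta_fk_or_none(meta):
    if "val" in meta:
        return meta["val"]
    return meta.get("fake_result")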
Error logs
Traceback (most recent call last):
  File "/home/soumith/code/examples/imagenet/main.py", line 515, in <module>
    main()
  File "/home/soumith/code/examples/imagenet/main.py", line 123, in main
    main_worker(args.gpu, ngpus_per_node, args)
  File "/home/soumith/code/examples/imagenet/main.py", line 282, in main_worker
    train(train_loader, model, criterion, optimizer, epoch, device, args)
  File "/home/soumith/code/examples/imagenet/main.py", line 340, in train
    loss.backward()
  File "/home/soumith/code/pytorch/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/soumith/code/pytorch/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/soumith/code/pytorch/torch/autograd/function.py", line 270, in apply
    return user_fn(self, *args)
  File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 464, in backward
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
  File "/home/soumith/code/pytorch/torch/_dynamo/optimizations/training.py", line 493, in _wrapped_bw_compiler
    return disable(bw_compiler(*args, **kwargs))  # type: ignore[operator]
  File "/home/soumith/code/pytorch/torch/_dynamo/optimizations/training.py", line 480, in cudagraphs
    apply_cuda_graphs(model)
  File "/home/soumith/code/pytorch/torch/_dynamo/optimizations/training.py", line 473, in apply_cuda_graphs
    mutated_inputs = find_input_mutations(submod.graph)
  File "/home/soumith/code/pytorch/torch/_dynamo/optimizations/training.py", line 439, in find_input_mutations
    inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx)
  File "/home/soumith/code/pytorch/torch/_dynamo/optimizations/training.py", line 432, in meta_fk
    return meta["val"] if "val" in meta else meta["fake_result"]
KeyError: 'fake_result'
Did Dynamo succeed?
- Does dynamo.optimize("eager") succeed?
Did AOT succeed?
- Does dynamo.optimize("aot_eager") succeed?
Did Inductor succeed?
- Does dynamo.optimize("inductor") succeed?
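For the questions above, something like the following could be used to check each backend against the minified module (sketch only; mod and args are as defined in the repro below):

import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

# Sketch: try each backend on the minified Repro module (mod, args from the repro below).
for backend in ("eager", "aot_eager", "inductor", "aot_cudagraphs"):
    torch._dynamo.reset()  # clear cached compilations between backends
    opt_mod = torch._dynamo.optimize(backend)(mod)
    run_fwd_maybe_bwd(opt_mod, args)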
Minified repro
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models
args = [((256, 64, 56, 56), (200704, 3136, 56, 1), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_layer2_0_conv1 = Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.self_self_layer2_0_bn1 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.self_self_layer2_0_relu = ReLU(inplace=True)
        self.self_self_layer2_0_conv2 = Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.self_self_layer2_0_bn2 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.self_self_layer2_0_downsample_0 = Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        self.self_self_layer2_0_downsample_1 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    def forward(self, self_self_layer1_1_relu_1):
        self_self_layer2_0_conv1 = self.self_self_layer2_0_conv1(self_self_layer1_1_relu_1)
        self_self_layer2_0_bn1 = self.self_self_layer2_0_bn1(self_self_layer2_0_conv1);  self_self_layer2_0_conv1 = None
        self_self_layer2_0_relu = self.self_self_layer2_0_relu(self_self_layer2_0_bn1);  self_self_layer2_0_bn1 = None
        self_self_layer2_0_conv2 = self.self_self_layer2_0_conv2(self_self_layer2_0_relu);  self_self_layer2_0_relu = None
        self_self_layer2_0_bn2 = self.self_self_layer2_0_bn2(self_self_layer2_0_conv2);  self_self_layer2_0_conv2 = None
        self_self_layer2_0_downsample_0 = self.self_self_layer2_0_downsample_0(self_self_layer1_1_relu_1);  self_self_layer1_1_relu_1 = None
        self_self_layer2_0_downsample_1 = self.self_self_layer2_0_downsample_1(self_self_layer2_0_downsample_0);  self_self_layer2_0_downsample_0 = None
        self_self_layer2_0_bn2 += self_self_layer2_0_downsample_1;  iadd_2 = self_self_layer2_0_bn2;  self_self_layer2_0_bn2 = self_self_layer2_0_downsample_1 = None
        return (iadd_2,)
mod = Repro().cuda()
opt_mod = torch._dynamo.optimize("aot_cudagraphs")(mod)
with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)