This repository was archived by the owner on Aug 1, 2025. It is now read-only.

AMP Training Tracker for Aot Autograd bugs #835

@anijain2305

Description


This issue tracks bugs related to AMP training with the aot_eager / aot_nvfuser backends.

  • A large number of TorchBench models fail with the following error:

Repro - python benchmarks/torchbench.py --amp -dcuda --no-skip --training --nvfuser --accuracy-aot-ts-mincut --use-eval-mode --isolate --only=resnet50

  File "/scratch/anijain/work/pytorch/functorch/functorch/_src/aot_autograd.py", line 98, in joint_forward_backward
    backward_out = torch.autograd.grad(
  File "/scratch/anijain/work/pytorch/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/scratch/anijain/work/pytorch/torch/utils/_python_dispatch.py", line 74, in wrapped
    return f(self, *args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 408, in __torch_dispatch__
    return proxy_call(self, func_overload, args, kwargs)
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 173, in proxy_call
    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))
  File "/scratch/anijain/work/pytorch/torch/_ops.py", line 60, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Expected save_mean to have type Float but got Half
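
  • Something is wrong with the partitioner: the generated graph's return statement contains Invalid Node entries, so recompiling the graph module fails with a SyntaxError.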

Repro - python benchmarks/timm_models.py --amp -dcuda --no-skip --training --nvfuser --accuracy-aot-ts-mincut --use-eval-mode --isolate --only=convnext_base

  File "/scratch/anijain/work/pytorch/torch/fx/graph_module.py", line 641, in recompile
    cls.forward = _forward_from_src(self._code, python_code.globals)
  File "/scratch/anijain/work/pytorch/torch/fx/graph_module.py", line 77, in _forward_from_src
    _exec_with_source(src, globals_copy)
  File "/scratch/anijain/work/pytorch/torch/fx/graph_module.py", line 71, in _exec_with_source
    exec(compile(src, key, 'exec'), globals)
  File "<eval_with_key>.10", line 20
    return [_to_copy_default_337, _to_copy_default_338, getitem_122, getitem_121, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, Invalid Node, .... , None]
                                                                                          ^
SyntaxError: invalid syntax
ERROR
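
  • With the inductor backend, resnet50 fails because of a missing decomposition for aten.t.default:

Repro - python benchmarks/torchbench.py --amp -dcuda --no-skip --training --nvfuser --inductor --use-eval-mode --only=resnet50
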
  File "/scratch/anijain/work/pytorch/torch/fx/interpreter.py", line 162, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 216, in call_function
    raise MissingOperatorWithDecomp(target, args, kwargs)
torchinductor.exc.MissingOperatorWithDecomp: missing decomposition
  target: aten.t.default
  args[0]: TensorBox(StorageBox(
    ConstantBuffer(name='constant53', layout=FixedLayout('cuda', torch.float16, size=[1000, 2048], stride=[2048, 1]))
  ))

There is a decomposition available for aten.t.default in
torch._decomp.get_decompositions().  Please add this operator to the
`decompositions` list in `./torchinductor/decomposition.py`.
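The error message says a decomposition for aten.t.default already exists in torch._decomp.get_decompositions(). The sketch below shows the general mechanism under the assumption of a recent PyTorch build where torch._decomp.get_decompositions and torch.fx.experimental.proxy_tensor.make_fx are available: build a decomposition table that includes aten.t, trace with it, and check that aten.t.default no longer appears in the traced graph. It is not the actual patch to ./torchinductor/decomposition.py; the function name transpose_weight is hypothetical, and it uses a small CPU float32 tensor instead of the CUDA float16 [1000, 2048] constant from the log, purely for portability.

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten

# Look up registered decompositions for aten.t. Passing the overload packet
# returns an entry for every overload that has one (e.g. aten.t.default).
decomps = get_decompositions([aten.t])


def transpose_weight(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the failing case: transposing a 2-D weight
    # (in the log, a [1000, 2048] float16 constant on CUDA).
    return w.t()


# Small CPU float32 tensor so the sketch runs without a GPU (an assumption;
# the original failure involved CUDA float16).
w = torch.randn(4, 3)

# Trace with the decomposition table: aten.t is replaced by the ops its
# decomposition calls, instead of being recorded as aten.t.default.
gm = make_fx(transpose_weight, decomposition_table=decomps)(w)

# Collect the ATen ops recorded in the traced graph.
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(targets)
```

Decomposing at trace time means a backend such as inductor only needs lowerings for the simpler ops that aten.t is expressed in, which is why the message points at the decomposition list rather than asking for a new lowering.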
