
Conversation

mieshkiwrk
Contributor

@mieshkiwrk mieshkiwrk commented May 19, 2025

This adds support for `call_module` in `copy_paste_aot_backward_graph`, which was added recently with PT2.7.

The problem is observed with the HPU backend in the example repro below, because that backend creates fused modules.

import torch


device = 'cpu'  # 'hpu' to reproduce on the HPU backend
backend = 'inductor'  # 'hpu_backend' for HPU

def fn(t1):
    t1 = t1 * 1
    t1_grad = torch.ones_like(t1, device=device)
    t1.backward(t1_grad, retain_graph=True)
    return t1

t1 = torch.ones(1, requires_grad=True, device=device)  # add .squeeze() to hit the scalar case
compiled_fn = torch.compile(fn, backend=backend)
result = compiled_fn(t1)


# Compile the backward pass with compiled autograd
with torch._dynamo.compiled_autograd._enable(torch.compile(backend=backend)):
    result_grad = torch.ones_like(result, device=device)
    result.backward(result_grad) 

print(f'{result_grad=}')
print(f'{t1.grad=}')

With this change I'm getting the same results as on CPU; however, I'm facing the problem below when running with a scalar (the t1 tensor after squeeze):
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function getitem>(*(FakeTensor(..., device='hpu:0', size=()), 0), **{}): got IndexError('invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number')

While on CPU there's the following warning and None is returned:
repro.py:23: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at pytorch/build/aten/src/ATen/core/TensorBody.h:489.)
  print(f'{t1.grad=}')
t1.grad=None

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @xmfan @jeromean @bsochack @sujoysaraswati


pytorch-bot bot commented May 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153827

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 274055f with merge base c45515c:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented May 19, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@mieshkiwrk
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label May 20, 2025
@HDCharles HDCharles added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 20, 2025
@xmfan xmfan added the module: hpu Issues related to the hpu device (Habana/Gaudi) label May 20, 2025
@xmfan
Member

xmfan commented May 20, 2025

Could you provide more details on how you're running into the issue? Is compiling via torch.compile(backend="hpu_backend") + the hpu device sometimes installing submodules on the AOT backward graph (via custom joint passes)? call_module nodes are unexpected in the AOT backward graph under most backends.

I don't have an HPU handy, but if you are able to, could you provide the TORCH_LOGS="aot_graphs" obtained from running the script with hpu backend?

@mieshkiwrk
Contributor Author

mieshkiwrk commented May 21, 2025

Sure, full logs with TORCH_LOGS="aot_graphs" are at the end (run with a 1-dim tensor, since with a scalar I'm getting the error - I haven't found the root cause of where that happens yet, but a brief description is below the logs in case it matters).

Here is one example of what copy_paste_aot_backward_graph is trying to copy:

TRACED GRAPH
===== Backward graph 0 =====
<eval_with_key>.13 class GraphModule(torch.nn.Module):
    def forward(self, tangents_1: "f32[1][1]hpu:0"):
         # File: /home/reproducers/shouldnt_get_here_assert/repro.py:8 in fn, code: t1 = t1 * 1
        mul_1: "f32[1][1]hpu:0" = torch.ops.aten.mul.Tensor(tangents_1, 1);  tangents_1 = None
        return (mul_1,)
		
		
# Before HPU backend processing
	graph():
		%tangents_1 : [num_users=1] = placeholder[target=tangents_1]
		%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, 1), kwargs = {})
		return (mul_1,)


# After HPU backend processing
	graph():
		%input_list : list [num_users=2] = placeholder[target=input_list]
		%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%input_list, 0), kwargs = {})
		%_lambda_ : [num_users=0] = call_function[target=habana_frameworks.torch.dynamo.compile_backend.passes.list_clear](args = (%input_list,), kwargs = {})
		%fused_0 : [num_users=1] = call_module[target=fused_0](args = (%getitem,), kwargs = {})
		return (fused_0,)
	# fused_0
		graph():
			%tangents_1 : [num_users=1] = placeholder[target=tangents_1]
			%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, 1), kwargs = {})
			return mul_1


# [copy_paste_aot_backward_graph] ctx._bw_module.graph=
	graph():
		%input_list : list [num_users=2] = placeholder[target=input_list]
		%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%input_list, 0), kwargs = {})
		%_lambda_ : [num_users=0] = call_function[target=habana_frameworks.torch.dynamo.compile_backend.passes.list_clear](args = (%input_list,), kwargs = {})
		%fused_0 : [num_users=1] = call_module[target=fused_0](args = (%getitem,), kwargs = {})
		return (fused_0,)
	# fused_0:
		graph():
			%tangents_1 : [num_users=1] = placeholder[target=tangents_1]
			%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, 1), kwargs = {})
			return mul_1
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:895] [0/0_1] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[FakeTensor(..., device='hpu:0', size=(1,))], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None)], subclass_fw_graph_out_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None)], subclass_tangent_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=torch.contiguous_format)], is_train=True, traced_tangent_metas=None, num_symints_saved_for_bw=0, grad_enabled_mutation=None, deterministic=False, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=[], num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None), inner_meta=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[FakeTensor(..., device='hpu:0', size=(1,))], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None)], subclass_fw_graph_out_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None)], subclass_tangent_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=torch.contiguous_format)], is_train=True, traced_tangent_metas=None, num_symints_saved_for_bw=0, grad_enabled_mutation=None, deterministic=False, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=[], num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None)
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs] TRACED GRAPH
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]  ===== Forward graph 0 =====
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]  /home/venv_base/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]     def forward(self, primals_1: "f32[1][1]hpu:0"):
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]          # File: /home/reproducers/shouldnt_get_here_assert/repro.py:8 in fn, code: t1 = t1 * 1
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]         mul: "f32[1][1]hpu:0" = torch.ops.aten.mul.Tensor(primals_1, 1);  primals_1 = None
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]         return (mul,)
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]
I0521 09:32:51.858000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1001] [0/0_1] [__aot_graphs]
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs] TRACED GRAPH
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]  ===== Backward graph 0 =====
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]  <eval_with_key>.13 class GraphModule(torch.nn.Module):
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]     def forward(self, tangents_1: "f32[1][1]hpu:0"):
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]          # File: /home/reproducers/shouldnt_get_here_assert/repro.py:8 in fn, code: t1 = t1 * 1
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]         mul_1: "f32[1][1]hpu:0" = torch.ops.aten.mul.Tensor(tangents_1, 1);  tangents_1 = None
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]         return (mul_1,)
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]
I0521 09:32:51.859000 2240152 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:1012] [0/0_1] [__aot_graphs]
/home/venv_base/lib/python3.10/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/venv_base/lib/python3.10/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/venv_base/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:1263: UserWarning: Dynamo does not know how to trace the builtin `habana_frameworks.torch.lib.fork_pybind._recipe_compiler_C.PyCapsule.batch_empty.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
V0521 09:32:52.148000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:123] [!0/3/0_1] [__aot_graphs] aot_config id: 1, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.alias_of_input: 2>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=0, dynamic_dims=set(), requires_grad=False, functional_tensor=<torch._functorch._aot_autograd.functional_utils.FunctionalTensorMetadataEq object at 0x7f75fe534280>)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None), PlainTensorMeta(unwrapped_idx=2, memory_format=None)], subclass_fw_graph_out_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None)], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None),subclass_metadata=None
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs] TRACED GRAPH
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]  ===== Forward graph 1 =====
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]  /home/venv_base/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]     def forward(self, arg0_1: "f32[1][1]hpu:0", arg1_1: "f32[1][1]hpu:0", arg2_1: "f32[1][1]hpu:0"):
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]         # No stacktrace found for following nodes
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]         select: "f32[][]hpu:0" = torch.ops.aten.select.int(arg0_1, 0, 0);  arg0_1 = None
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]         return (select,)
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]
I0521 09:32:52.155000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/3/0_1] [__aot_graphs]
/home/venv_base/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:1263: UserWarning: Dynamo does not know how to trace the builtin `habana_frameworks.torch.lib.fork_pybind._recipe_compiler_C.PyCapsule.graph_launch.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
V0521 09:32:52.249000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:123] [!0/7/0] [__aot_graphs] aot_config id: 2, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True), InputAliasInfo(is_leaf=True, mutates_data=True, mutates_metadata=False, mutations_hidden_from_autograd=False, mutations_under_no_grad_or_inference_mode=True, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[PlainTensorMeta(unwrapped_idx=0, memory_format=None), PlainTensorMeta(unwrapped_idx=1, memory_format=None)], subclass_fw_graph_out_meta=[], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0, num_graphsafe_rng_states=0, graphsafe_rng_state_index=None),subclass_metadata=None
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs] TRACED GRAPH
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]  ===== Forward graph 2 =====
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]  /home/venv_base/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]     def forward(self, arg0_1: "f32[1][1]hpu:0", arg1_1: "f32[1][1]hpu:0"):
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]          # File: /home/venv_base/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py:80 in accumulate_grad, code: new_grad = torch.clone(new_grad)
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]         clone: "f32[1][1]hpu:0" = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]          # File: /home/venv_base/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py:84 in accumulate_grad, code: x.grad.add_(new_grad)
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]         add: "f32[1][1]hpu:0" = torch.ops.aten.add.Tensor(arg1_1, clone);  clone = None
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]         copy_: "f32[1][1]hpu:0" = torch.ops.aten.copy_.default(arg1_1, add);  arg1_1 = add = copy_ = None
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]         return ()
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]
I0521 09:32:52.258000 2240152 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:202] [!0/7/0] [__aot_graphs]
/home/venv_base/lib/python3.10/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
result=tensor([1.], device='hpu:0', grad_fn=<CompiledFunctionBackward>)
t1.grad=tensor([2.], device='hpu:0')

Regarding the problem with the scalar, the following graph is being processed:

graph():
    %l_inputs_ : list [num_users=3] = placeholder[target=L_inputs_]
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%l_inputs_, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%l_inputs_, 1), kwargs = {})
    %getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%l_inputs_, 2), kwargs = {})
    %validate_outputs : [num_users=1] = call_function[target=torch._dynamo.compiled_autograd.ops.validate_outputs](args = ([%getitem], [((None, None, hpu:0, 6, 0, None), [], False)]), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%validate_outputs, 0), kwargs = {})
    %call_aot_bwd_prologue : [num_users=1] = call_function[target=torch._dynamo.compiled_autograd.call_aot_bwd_prologue](args = ((), [], %getitem_3), kwargs = {})
    %getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%call_aot_bwd_prologue, 0), kwargs = {})
    %getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})

The output of call_aot_bwd_prologue is a list of one scalar tensor, so getitem_5 is a scalar and getitem_6 tries to index the first dim of a scalar tensor, which ends up with the error:

torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function getitem>(*(FakeTensor(..., device='hpu:0', size=()), 0), **{}): got IndexError('invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number')
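
For reference, the failing getitem is easy to reproduce outside of the graph; a minimal CPU-only sketch of what getitem_6 effectively does (no HPU needed):

import torch

scalar = torch.ones(1).squeeze()  # 0-dim tensor, size=()
try:
    scalar[0]  # integer indexing of a 0-dim tensor
except IndexError as e:
    print(e)  # invalid index of a 0-dim tensor. Use `tensor.item()` ...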

@xmfan
Member

xmfan commented May 21, 2025

I think what's happening is the HPU backend mutates the graph produced by AOTDispatch during lowering. And because there's no guarantee that the graph is still runnable after lowering, it will error out when traced (by Dynamo in this case).

For regular backends, we send a copy of the graph for lowering, and preserve the AOTDispatch traced graph: https://github.com/pytorch/pytorch/blob/72a3c8dfa8580f03c992ca06be89b92a6c163b0b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py#L1218-L122. It is possible that the HPU backend implementation does not call their backend through this code path.

The code change may get us past the initial error, but I don't think Compiled Autograd will work unless the HPU backend either preserves the AOTDispatch graph even after lowering, or guarantees that the graph is still executable.
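
A minimal sketch of that "lower on a copy" pattern (backend_lower here is a hypothetical stand-in for whatever rewrites/fusions the backend performs; the point is only that bw_module itself is left untouched for compiled autograd to copy-paste later):

import copy
import torch.fx as fx

def lower_backward_on_copy(bw_module: fx.GraphModule, backend_lower):
    # Run backend-specific lowering on a deep copy so the original
    # AOTDispatch-traced backward graph stays runnable as-is.
    bw_copy = copy.deepcopy(bw_module)
    compiled_fn = backend_lower(bw_copy)  # any mutation happens on the copy only
    return compiled_fn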

@mieshkiwrk
Contributor Author

The HPU and inductor backends don't go through the aot_dispatch_export function with this example from the description.

With the proposed implementation of copying for call_module, it's working for non-scalar tensors. For scalar tensors, I see the problem occurring while Dynamo is building the FX graph from the instructions; I tried to find out why it considers the output of call_aot_bwd_prologue a list of non-scalar tensors while it's actually a scalar tensor, but without success so far.

@xmfan
Member

xmfan commented May 23, 2025

The HPU and inductor backends don't go through the aot_dispatch_export function with this example from the description.

That is only used by torch.export. I'd expect aot_dispatch_autograd to be called when torch.compile compiles at forward-pass runtime, and aot_dispatch_base to be called when Compiled Autograd compiles at backward-pass runtime.

For scalar tensors, I see the problem occurring while Dynamo is building the FX graph from the instructions; I tried to find out why it considers the output of call_aot_bwd_prologue a list of non-scalar tensors while it's actually a scalar tensor, but without success so far.

I believe the issue is that the AOT backward modified by the HPU backend treats its input as a list, and does not expect to be called directly by torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional, which is what torch.compile and compiled autograd do.

# Before HPU backend processing, expects graph input args[0] to be tensor
	graph():
		%tangents_1 : [num_users=1] = placeholder[target=tangents_1]
		...


# After HPU backend processing, expects graph input args[0] to be list
	graph():
		%input_list : list [num_users=2] = placeholder[target=input_list]
		%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%input_list, 0), kwargs = {})
		...

But the inputs to the AOT backward are created by running torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional on the output gradients, and _backward_prologue_functional does not wrap the inputs in a list. So then in the compiled-autograd-generated graph:

graph():
    ...
    # create the AOT backward's inputs
    %call_aot_bwd_prologue : [num_users=1] = call_function[target=torch._dynamo.compiled_autograd.call_aot_bwd_prologue](args = ((), [], %getitem_3), kwargs = {})
    %getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%call_aot_bwd_prologue, 0), kwargs = {})
    # call the copy pasted AOT backward code
    # this line corresponds to %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%input_list, 0), kwargs = {})
    # but getitem_5 is NOT a list
    %getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})

To dig deeper, we would need to figure out how the HPU backend passes a list (and not just a single tensor) to its AOT backward. AOTAutograd and Compiled Autograd both need the prologue to be able to format the inputs properly.
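
To make the convention mismatch concrete, here is a small sketch with hypothetical stand-ins for the two backward signatures (not the real modules):

import torch

def aot_backward(tangents_1):   # AOTDispatch backward: takes the tangent tensor directly
    return (tangents_1 * 1,)

def hpu_backward(input_list):   # HPU-rewritten backward: expects a list
    tangents_1 = input_list[0]  # the %getitem on the list placeholder
    return (tangents_1 * 1,)

grad_1d = torch.ones(1)   # 1-dim tangent: grad_1d[0] silently selects element 0,
hpu_backward(grad_1d)     # so no hard error is raised when a bare tensor is passed
grad_0d = torch.ones(())  # 0-dim tangent (the squeezed repro)
# hpu_backward(grad_0d)   # would raise the IndexError quoted above

If that reading is right, it would also explain why only the scalar repro surfaces a hard error: indexing a 1-dim tensor with [0] succeeds, while indexing a 0-dim tensor raises.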

@mieshkiwrk
Contributor Author

mieshkiwrk commented May 26, 2025

Thanks for your comments. Indeed, the problem is that the HPU backend creates a list for the inputs; with this optimization disabled, scalars also work as expected - the same as when the graph is copied for compilation.

I think this is how we'll handle the problem for now, since brief testing suggests it doesn't have much performance impact - but for the future, do you think this problem could somehow be handled on the PyTorch side?
Or should there be a requirement that a backend always performs optimizations for backward graphs on a copy?

@mieshkiwrk
Contributor Author

@pytorchbot merge


pytorch-bot bot commented May 28, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@xmfan
Member

xmfan commented May 28, 2025

Or should there be a requirement that a backend always performs optimizations for backward graphs on a copy?

I think this is the better option for compiled autograd. Compiled autograd will invoke the backend again on the backward graph traced at runtime; if backends can mutate the AOT backward that we copy-paste, then backends will need to handle already-lowered graphs.

@xmfan
Member

xmfan commented May 28, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 28, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
…3827)


Pull Request resolved: pytorch#153827
Approved by: https://github.com/xmfan

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: compiled autograd compiled_autograd module: dynamo module: hpu Issues related to the hpu device (Habana/Gaudi) open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
