Skip to content

Conversation

ruisizhang123
Copy link
Contributor

@ruisizhang123 ruisizhang123 commented Oct 9, 2025

When the autobucketing pass is registered as aot_eager backend fw_compiler and bw_compiler, this pr ensures the tensors are all-gathers on "cpu/cuda" device instead of "meta" device.

When we do dist.all_gather_object, it will create new bytestorage outside no_dispatch here, which is on meta device. Thus, I updated the code to use unset_fake_temporarily, which would gather RealTensor from other ranks.

It is needed to unblock the aot_eager+autobucketing pass in this PR.

Otherwise, I hit the error as follows:

  traceback : Traceback (most recent call last):
    File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
      return f(*args, **kwargs)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
      self.train_step(data_iterator)
      ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
      loss = self.forward_backward_step(input_dict, labels)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
      pred = model_parts[0](inputs, **extra_inputs, **extra_args)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
      return super().__call__(*args, **kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
      raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
      raise BackendCompilerFailed(
          self.compiler_fn, e, inspect.currentframe()
      ).with_traceback(e.__traceback__) from None
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
      compiled_fn = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
      compiled_gm = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
      return self.compiler_fn(model_, inputs_, **self.kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
      cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
    File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
      compiled_fn, _ = aot_stage2_compile(
                       ~~~~~~~~~~~~~~~~~~^
          aot_state,
          ^^^^^^^^^^
      ...<4 lines>...
          inference_compiler,
          ^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
      return aot_stage2_autograd(aot_state, aot_graph_capture)
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
      compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
    File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
      schedule_overlap_bucketing(gm)
      ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
      ).run()
        ~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
      self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
      dist.all_gather_object(
      ~~~~~~~~~~~~~~~~~~~~~~^
          gathered_runtime_estimations, runtime_estimations, pg
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
      return func(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
      input_tensor, local_size = _object_to_tensor(obj, current_device, group)
                                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
      byte_tensor = torch.ByteTensor(byte_storage).to(device)
                    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
  RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta".  This is no longer allowed; the devices must match.
  
  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
  

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

Copy link

pytorch-bot bot commented Oct 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165063

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4454804 with merge base 96d91da (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ezyang
Copy link
Contributor

ezyang commented Oct 10, 2025

@ruisizhang123 I am a bit suspicious about the change here. Why are we tracing through the compiler collectives here? It seems to me it's more important to figure out how to stop tracing through the optimization code itself.

@ruisizhang123
Copy link
Contributor Author

@ruisizhang123 I am a bit suspicious about the change here. Why are we tracing through the compiler collectives here? It seems to me it's more important to figure out how to stop tracing through the optimization code itself.

hmmm I'm probably using the wrong terminology. During aot_autograd compile, the tensors are in FakeMode and they are in meta device. When we do dist.all_gather_object, it will create new bytestorage outside no_dispatch here, which is on meta device. We cannot gather the object with RealTensor using dist.all_gather_object.

I'm updating the code to use _functional_collectives.all_gather_tensor, which would gather RealTensor from other ranks. I will update pr description to make it more clear.

@ruisizhang123
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@ezyang
Copy link
Contributor

ezyang commented Oct 13, 2025

I think I'm still perplexed by the PR description. Would it be possible to get the full stack trace when the old code errors when doing the object collective?

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

@ruisizhang123
Copy link
Contributor Author

ruisizhang123 commented Oct 13, 2025

I think I'm still perplexed by the PR description. Would it be possible to get the full stack trace when the old code errors when doing the object collective?

Yep, updated here: #165063 (comment) @ezyang

@ruisizhang123
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
…rch#165063)

When the autobucketing pass  is registered as aot_eager backend `fw_compiler` and `bw_compiler`, this pr ensures the tensors are all-gathers on "cpu/cuda" device instead of "meta" device.

When we do `dist.all_gather_object`, it will create new bytestorage outside no_dispatch [here](https://github.com/pytorch/pytorch/blob/a2e2e1d8c026951baa345f0dd17668bd1718eda5/torch/distributed/distributed_c10d.py#L3303), which is on meta device. Thus, I updated the code to use `unset_fake_temporarily`, which would gather RealTensor from other ranks.

 It is needed to unblock the aot_eager+autobucketing pass in this [PR](pytorch/torchtitan#1813).

Otherwise, I hit the error as follows:

```bash
  traceback : Traceback (most recent call last):
    File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
      return f(*args, **kwargs)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
      self.train_step(data_iterator)
      ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
      loss = self.forward_backward_step(input_dict, labels)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
      pred = model_parts[0](inputs, **extra_inputs, **extra_args)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
      return super().__call__(*args, **kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
      raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
      raise BackendCompilerFailed(
          self.compiler_fn, e, inspect.currentframe()
      ).with_traceback(e.__traceback__) from None
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
      compiled_fn = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
      compiled_gm = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
      return self.compiler_fn(model_, inputs_, **self.kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
      cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
    File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
      compiled_fn, _ = aot_stage2_compile(
                       ~~~~~~~~~~~~~~~~~~^
          aot_state,
          ^^^^^^^^^^
      ...<4 lines>...
          inference_compiler,
          ^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
      return aot_stage2_autograd(aot_state, aot_graph_capture)
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
      compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
    File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
      schedule_overlap_bucketing(gm)
      ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
      ).run()
        ~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
      self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
      dist.all_gather_object(
      ~~~~~~~~~~~~~~~~~~~~~~^
          gathered_runtime_estimations, runtime_estimations, pg
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
      return func(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
      input_tensor, local_size = _object_to_tensor(obj, current_device, group)
                                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
      byte_tensor = torch.ByteTensor(byte_storage).to(device)
                    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
  RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta".  This is no longer allowed; the devices must match.

  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```

Pull Request resolved: pytorch#165063
Approved by: https://github.com/eellison
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants