[autobucketing] aten autobucketing fix to enable aot_eager pass #165063
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165063
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 4454804 with merge base 96d91da.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 540e643 to cbb4538
@ruisizhang123 I am a bit suspicious about the change here. Why are we tracing through the compiler collectives here? It seems to me it's more important to figure out how to stop tracing through the optimization code itself.
hmmm, I'm probably using the wrong terminology. During aot_autograd compile, the tensors are in FakeTensorMode and they are on the meta device. When we do `dist.all_gather_object`, it creates a new byte storage outside `no_dispatch`, which ends up on the meta device. I'm updating the code to use `unset_fake_temporarily`, which gathers real tensors from the other ranks.
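For reference, a minimal sketch of that pattern, assuming `unset_fake_temporarily` is importable from `torch._subclasses.fake_tensor` and a process group is already initialized; the helper name `gather_runtime_estimations` is made up for illustration and is not the actual function in `overlap_scheduling.py`:

```python
import torch.distributed as dist
from torch._subclasses.fake_tensor import unset_fake_temporarily


def gather_runtime_estimations(runtime_estimations, pg):
    # During aot_autograd compilation a FakeTensorMode is active, so the byte
    # tensor that all_gather_object builds internally would be meta-backed and
    # fail to move onto the collective's device. Temporarily unsetting the fake
    # mode lets the object collective run on real cpu/cuda tensors.
    gathered = [None] * dist.get_world_size(pg)
    with unset_fake_temporarily():
        dist.all_gather_object(gathered, runtime_estimations, group=pg)
    return gathered
```

Per the traceback in the commit message below, the call being guarded is the `dist.all_gather_object` inside `_align_compute_nodes_runtime_estimations_across_all_distributed_ranks` in `torch/_inductor/fx_passes/overlap_scheduling.py`.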
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
I think I'm still perplexed by the PR description. Would it be possible to get the full stack trace from when the old code errors while doing the object collective?
Merge failed. Reason: 1 job has failed, first few of them are: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable). Details for Dev Infra team: raised by workflow job.
Yep, updated here: #165063 (comment) @ezyang
Force-pushed from cbb4538 to fa70ffa
Force-pushed from fa70ffa to 4454804
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[autobucketing] aten autobucketing fix to enable aot_eager pass (pytorch#165063)

When the autobucketing pass is registered as the aot_eager backend's `fw_compiler` and `bw_compiler`, this PR ensures the tensors are all-gathered on the "cpu"/"cuda" device instead of the "meta" device.

When we do `dist.all_gather_object`, it creates a new byte storage outside `no_dispatch` [here](https://github.com/pytorch/pytorch/blob/a2e2e1d8c026951baa345f0dd17668bd1718eda5/torch/distributed/distributed_c10d.py#L3303), which ends up on the meta device. Thus, I updated the code to use `unset_fake_temporarily`, which gathers real tensors from the other ranks.

This is needed to unblock the aot_eager + autobucketing pass in this [PR](pytorch/torchtitan#1813). Otherwise, I hit the following error:

```bash
traceback : Traceback (most recent call last):
  File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
    return f(*args, **kwargs)
  File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
    self.train_step(data_iterator)
  File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
    loss = self.forward_backward_step(input_dict, labels)
  File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
    pred = model_parts[0](inputs, **extra_inputs, **extra_args)
  File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
    raise BackendCompilerFailed(
        self.compiler_fn, e, inspect.currentframe()
    ).with_traceback(e.__traceback__) from None
  File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(
        aot_state,
        ...<4 lines>...
        inference_compiler,
    )
  File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
    return aot_stage2_autograd(aot_state, aot_graph_capture)
  File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
    schedule_overlap_bucketing(gm)
  File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
    ).run()
  File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
    self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
  File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
    dist.all_gather_object(
        gathered_runtime_estimations, runtime_estimations, pg
    )
  File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
    input_tensor, local_size = _object_to_tensor(obj, current_device, group)
  File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
    byte_tensor = torch.ByteTensor(byte_storage).to(device)
torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta". This is no longer allowed; the devices must match.

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

Pull Request resolved: pytorch#165063
Approved by: https://github.com/eellison
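To make the failure mode concrete, here is a rough standalone sketch; it is not the actual `_object_to_tensor` code path, and the bare `FakeTensorMode()` below only approximates the fake mode that is active during aot_autograd compilation:

```python
import pickle

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

payload = pickle.dumps({"rank": 0, "runtime_estimations": [1.2, 3.4]})

with FakeTensorMode():
    # Roughly what all_gather_object -> _object_to_tensor does: the pickled bytes
    # become a real ByteStorage, but wrapping it in a tensor while a meta-backed
    # fake mode is active trips the device-mismatch RuntimeError shown in the
    # traceback above.
    byte_storage = torch.ByteStorage.from_buffer(payload)
    try:
        byte_tensor = torch.ByteTensor(byte_storage).to("cpu")
    except RuntimeError as e:
        print(e)

    # The fix in this PR: drop out of the fake mode so the object collective's
    # tensor plumbing sees real storage on a real device.
    with unset_fake_temporarily():
        byte_tensor = torch.ByteTensor(byte_storage).to("cpu")
    print(byte_tensor.device)  # cpu
```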
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben