[aoti] Add handling of ir.Constants in promote_constants #122419
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122419
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 57acd90 with merge base 16935de.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from ab0553f to 45a5325.
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 45a5325 to 57acd90.
LGTM. Thanks!
@pytorchbot merge

Merge failed. Reason: this PR needs a label; to add one, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This issue popped up when enabling predispatch IR on the benchmarks (#122225). On the following model:

```python
class M(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, x):
        t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
        t = torch.sqrt(t * 3)
        return x * t
```

we get the following error:

```
======================================================================
ERROR: test_constant_abi_compatible_cuda (__main__.AOTInductorTestABICompatibleCuda)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 2741, in wrapper
    method(*args, **kwargs)
  File "/data/users/angelayi/pytorch/test/inductor/test_torchinductor.py", line 9232, in new_test
    return value(self)
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor.py", line 922, in test_constant
    self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor.py", line 91, in check_model
    actual = AOTIRunnerUtil.run(
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor_utils.py", line 102, in run
    so_path = AOTIRunnerUtil.compile(
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor_utils.py", line 40, in compile
    so_path = torch._inductor.aot_compile_ep(
  File "/data/users/angelayi/pytorch/torch/_inductor/__init__.py", line 150, in aot_compile_ep
    return compile_fx_aot(
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1005, in compile_fx_aot
    compiled_lib_path = compile_fx(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1111, in compile_fx
    return compile_fx(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1145, in compile_fx
    return compile_fx(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1336, in compile_fx
    return inference_compiler(unlifted_gm, example_inputs_)
  File "/data/users/angelayi/pytorch/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1266, in fw_compiler_base
    return inner_compile(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/data/users/angelayi/pytorch/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 447, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 707, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/data/users/angelayi/pytorch/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 612, in run
    return super().run(*args)
  File "/data/users/angelayi/pytorch/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 957, in run_node
    result = super().run_node(n)
  File "/data/users/angelayi/pytorch/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 819, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 816, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 298, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 5340, in mul
    return make_pointwise(fn)(a, b)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 409, in inner
    inputs = promote_constants(inputs, override_return_dtype)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 373, in promote_constants
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView)))
torch._inductor.exc.LoweringException: StopIteration:
  target: aten.mul.Tensor
  args[0]: Constant(value=5.0, dtype=torch.float32, device=device(type='cuda', index=0))
  args[1]: 3
```

So I added an additional case in `promote_constants` to handle the `ir.Constant`s, and now it works! Although please let me know if this is the wrong approach. Here's a paste of the full run with the inductor logs: P1198927007

Pull Request resolved: #122419
Approved by: https://github.com/eellison, https://github.com/desertfire, https://github.com/chenyang78
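To make the failure mode concrete: the traceback's last frame shows `promote_constants` searching its inputs with `next(...)` for a `TensorBox` or `ExpandView` to use as the promotion example, and `next()` raises `StopIteration` when every non-scalar input is an `ir.Constant`. The sketch below reproduces that shape outside Inductor; `TensorBox`, `ExpandView`, and `Constant` are simplified stand-ins for the classes of the same name in `torch._inductor.ir`, and the promotion logic is an assumption for illustration, not the actual Inductor source.

```python
# Simplified stand-ins (NOT the real torch._inductor.ir classes) showing
# why promote_constants raised StopIteration, and how adding Constant to
# the isinstance check avoids it.
from dataclasses import dataclass


@dataclass
class TensorBox:      # stand-in for torch._inductor.ir.TensorBox
    dtype: str


@dataclass
class ExpandView:     # stand-in for torch._inductor.ir.ExpandView
    dtype: str


@dataclass
class Constant:       # stand-in for torch._inductor.ir.Constant
    value: float
    dtype: str
    device: str


def promote_constants(inputs):
    """Turn bare Python scalars into Constants matching an example input."""
    if not any(isinstance(x, (int, float)) for x in inputs):
        return inputs
    # The failing generator only checked (TensorBox, ExpandView); when every
    # non-scalar input was an ir.Constant -- as with aten.mul(Constant, 3)
    # in the traceback above -- next() found nothing and raised StopIteration.
    # Including Constant gives us an example to promote the scalars against.
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, Constant)))
    return [
        Constant(value=float(x), dtype=ex.dtype, device=getattr(ex, "device", "cpu"))
        if isinstance(x, (int, float))
        else x
        for x in inputs
    ]


# The failing case from the error: a Constant multiplied by a Python int.
inputs = [Constant(value=5.0, dtype="float32", device="cuda:0"), 3]
promoted = promote_constants(inputs)
print([type(x).__name__ for x in promoted])  # ['Constant', 'Constant']
```

With only the two-class check, the same call raises `StopIteration` inside the generator, which Inductor then surfaces as the `LoweringException` above.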
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang