Skip to content

Conversation

zhxchen17
Copy link
Contributor

@zhxchen17 zhxchen17 commented Feb 2, 2025

Fixing the following issue when compiling the following program:

                window = torch.hann_window(N_FFT).to(x.device)
                stft = torch.stft(
                    x, N_FFT, HOP_LENGTH, window=window, return_complex=True
                )
                magnitudes = stft[..., :-1].abs() ** 2
                return magnitudes
Traceback (most recent call last):
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 57, in testPartExecutor
    yield
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 623, in run
    self._callTestMethod(testMethod)
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 579, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/home/zhxchen17/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
    method(*args, **kwargs)
  File "/home/zhxchen17/pytorch/test/inductor/test_torchinductor.py", line 12356, in new_test
    return value(self)
           ^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor.py", line 4334, in test_stft
    self.check_model(model, example_inputs)
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 185, in check_model
    actual = AOTIRunnerUtil.run(
             ^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 137, in run
    optimized = AOTIRunnerUtil.load(device, so_path)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 119, in load
    return torch._export.aot_load(so_path, device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/torch/_export/__init__.py", line 165, in aot_load
    runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected extern kernel aten::hann_window to have serialized argument type as_scalar_type for argument 1 but got as_device

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @chauhang @aakhundov

Copy link

pytorch-bot bot commented Feb 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146263

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 1 New Failure

As of commit 911b791 with merge base 07b9fe0 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zhxchen17
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 3, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@atalman
Copy link
Contributor

atalman commented Feb 4, 2025

@pytorchmergebot revert -c ghfirst -m "multiple build failures, please see associated diff"

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Collaborator

@zhxchen17 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Feb 4, 2025
…lues. (#146263)"

This reverts commit 11f6980.

Reverted #146263 on behalf of https://github.com/atalman due to multiple build failures, please see associated diff ([comment](#146263 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Feb 4, 2025
@facebook-github-bot
Copy link
Contributor

@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhxchen17
Copy link
Contributor Author

internal test failures fixed, see D69124148

…146263)

Summary:
Fixing the following issue when compiling the following program:
```
                window = torch.hann_window(N_FFT).to(x.device)
                stft = torch.stft(
                    x, N_FFT, HOP_LENGTH, window=window, return_complex=True
                )
                magnitudes = stft[..., :-1].abs() ** 2
                return magnitudes
```
```
Traceback (most recent call last):
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 57, in testPartExecutor
    yield
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 623, in run
    self._callTestMethod(testMethod)
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/unittest/case.py", line 579, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/home/zhxchen17/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
    method(*args, **kwargs)
  File "/home/zhxchen17/pytorch/test/inductor/test_torchinductor.py", line 12356, in new_test
    return value(self)
           ^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor.py", line 4334, in test_stft
    self.check_model(model, example_inputs)
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 185, in check_model
    actual = AOTIRunnerUtil.run(
             ^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 137, in run
    optimized = AOTIRunnerUtil.load(device, so_path)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/test/inductor/test_aot_inductor_utils.py", line 119, in load
    return torch._export.aot_load(so_path, device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/pytorch/torch/_export/__init__.py", line 165, in aot_load
    runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected extern kernel aten::hann_window to have serialized argument type as_scalar_type for argument 1 but got as_device
```


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang aakhundov


Differential Revision: D69124148

Pulled By: zhxchen17
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D69124148

@zhxchen17
Copy link
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / linux-focal-rocm6.3-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the zhxchen17/aoti/0 branch March 8, 2025 01:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants