
Conversation

@JackCaoG (Collaborator) commented Aug 25, 2023

`import torch_xla.core.xla_model as xm` no longer triggers the XLA runtime to initialize, hence we explicitly create the device here. This is a workaround for pytorch/xla#4174.

The `is_correct` reference has been deleted; I think it was dead code.

After this patch, I am able to run

```
python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=openxla --only resnet50
```

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @anijain2305

@JackCaoG JackCaoG requested a review from shunting314 August 25, 2023 01:30
pytorch-bot bot commented Aug 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107919

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c7a1b50 with merge base c99a70c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@JackCaoG JackCaoG requested a review from wconstab August 25, 2023 18:09
```
try:
    import torch_xla.core.xla_model as xm

    # Creating the device forces the XLA runtime to initialize.
    device = xm.xla_device()
```
Contributor

What's the purpose of calling xla_device()?

I assume if it failed it would throw a different exception, so catching ImportError isn't enough. Also, it's odd to bind it to a name. Is it enough to just do the import here and do the device call somewhere later, at first use?

Collaborator Author

The actual bug is discussed in pytorch/xla#4174. I think it requires pytorch/xla to register something with PyTorch so backward can run correctly (when a CPU backward was run first).

The reason the old workaround worked is that we used to eagerly init the runtime, but now that doesn't happen until an actual device is initialized. I need to root-cause this issue...

Collaborator Author

Initializing the device is not ideal, as it actually makes it impossible to run more than one test at the same time on TPU, since the main process inits the device outside of the spawn function and keeps holding it.

I tried a few other approaches and they didn't work. I decided it is better to at least make training runnable and fix this later in nightly.
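
For illustration, a hypothetical sketch of the ordering being described (not code from this PR): the main process claims the TPU runtime before xmp.spawn starts the test workers.

```
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Problematic ordering: the harness initializes the device in the
# main process, before any worker is spawned.
device = xm.xla_device()

def _mp_fn(index):
    # The workers spawned below also ask for the device, but the parent
    # process is already holding the TPU runtime.
    print(index, xm.xla_device())

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())
```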

Contributor

Yeah, let's make sure the script still works when XLA is missing.

Collaborator Author

It should, since the device init only happens when the torch_xla import succeeds.
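
Concretely, the guard looks roughly like this (a minimal sketch; the exact `except` handling in the benchmark script is an assumption here):

```
try:
    import torch_xla.core.xla_model as xm

    # Only reached when torch_xla is installed; creating the device
    # forces the XLA runtime to initialize.
    device = xm.xla_device()
except ImportError:
    # torch_xla is missing: skip the XLA path and keep the script
    # usable for the other backends.
    device = None
```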

Collaborator Author

Let me try to dig into this issue a bit more this afternoon and see if I can work around it.

Contributor

I think a few things may make it better:

  1. Have an API to explicitly initialize the XLA runtime and call that rather than 'xla_device'. This can avoid some confusion.
  2. Or figure out the root cause :)

But I'm ok to stamp the PR now to unblock, unless @wconstab has other thoughts.

Collaborator Author

Thanks Shunting, I added an API to explicitly init the runtime. Let me update it.
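
Something along these lines is what the explicit-init approach would look like (a sketch only; the `init_xla_runtime` name is hypothetical, since the real entry point added on the torch_xla side is not shown in this thread):

```
try:
    import torch_xla.core.xla_model as xm

    # Hypothetical name for the explicit runtime-initialization API; it
    # states the intent directly instead of creating a throwaway device.
    xm.init_xla_runtime()
except ImportError:
    pass
```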

@janeyx99 janeyx99 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Aug 25, 2023
@JackCaoG JackCaoG requested a review from a team as a code owner August 25, 2023 20:51
@JackCaoG (Collaborator Author)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Aug 25, 2023
@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team (raised by workflow job)

@JackCaoG (Collaborator Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing (topic category) label Aug 25, 2023
@JackCaoG (Collaborator Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@JackCaoG (Collaborator Author)

```
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
- generated xml file: /var/lib/jenkins/workspace/test/test-reports/python-pytest/inductor.test_foreach/inductor.test_foreach-016b9206757a35ab.xml -
=========================== short test summary info ============================
FAILED [0.2053s] inductor/test_foreach.py::ForeachTests::test_2d_blocking__foreach_add - AssertionError: Scalars are not equal!

Expected 6 but got 5.
Absolute difference: 1
Relative difference: 0.16666666666666666
```

The failure doesn't seem relevant.

@wconstab wconstab (Contributor) left a comment

Agree that test_foreach is not relevant. If other tests all pass, you can merge with the -f flag (or I can help), but you must wait for tests to finish before merging or -f will ignore still-pending tests.

@JackCaoG (Collaborator Author)

The only test failure seems irrelevant. I will force merge.

@JackCaoG (Collaborator Author)

@pytorchbot merge -f "test failure is irrelevant"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
`import torch_xla.core.xla_model as xm` no longer triggers the XLA runtime to initialize, hence we explicitly create the device here. This is a workaround for pytorch/xla#4174.

The `is_correct` reference has been deleted; I think it was dead code.

After this patch, I am able to run

```
python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=openxla --only resnet50
```

Pull Request resolved: #107919
Approved by: https://github.com/shunting314, https://github.com/wconstab
@izaitsevfb (Contributor)

@pytorchbot revert -m 'Conflicts with the revert of #106914' -c ghfirst

Hi @JackCaoG, sorry for the churn, but I have to unland your PR temporarily, as it conflicts with another revert (#106914) (in xla hash).

Please rebase and reland at your convenience.

@izaitsevfb (Contributor)

@pytorchbot revert -m 'Conflicts with the revert of 106914' -c ghfirst

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

@JackCaoG your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Aug 29, 2023
This reverts commit ed8f212.

Reverted #107919 on behalf of https://github.com/izaitsevfb due to conflicts with the revert of 106914.
@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

@JackCaoG your PR has been successfully reverted.

@JackCaoG JackCaoG reopened this Sep 5, 2023
@JackCaoG (Collaborator Author) commented Sep 5, 2023

The PR was reverted because it touches the XLA pin and another PR that also touches the XLA pin got reverted. Reopening the PR to try to merge it again.

@JackCaoG (Collaborator Author) commented Sep 6, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the JackCaoG/fix_xla_torchbench branch March 7, 2025 02:07

Labels

ciflow/inductor, ciflow/trunk, Merged, module: dynamo, open source, Reverted, topic: not user facing, triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants