Foreach gradient clipping #91846
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91846
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d00dafa.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I'll let @janeyx99 take a look at this.
She is building a generic tool for doing this per_device_and_dtype_grads
collection that will simplify this code.
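For illustration, here is a minimal sketch of the kind of grouping such a utility performs; the function name below is hypothetical and this is not the helper that eventually landed:

```python
from collections import defaultdict
import torch

def group_grads_by_device_and_dtype(parameters):
    # Hypothetical helper: bucket gradients so that each foreach call only
    # sees tensors living on one device with one dtype, which is what the
    # torch._foreach_* kernels expect.
    grouped = defaultdict(list)
    for p in parameters:
        if p.grad is not None:
            grouped[(p.grad.device, p.grad.dtype)].append(p.grad)
    return grouped

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)

# Each bucket can then be scaled with a single fused call instead of a Python loop.
for (device, dtype), grads in group_grads_by_device_and_dtype(params).items():
    torch._foreach_mul_(grads, 0.5)
```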
Hey! Can you add a test near test_clip_grad_norm in test/test_nn.py to ensure the calculations are the same?
Regarding work on consolidating a util for creating this dictionary: I'm currently landing #92014, which has a version of this grouping function. It would be best if the functionality used across this PR could be abstracted to a common util in that file too!
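For context, a parity check along these lines could look like the sketch below; it is a simplified stand-in, not the test that actually landed in test/test_nn.py, and it assumes the device and build support the foreach kernels:

```python
import torch
from torch.nn.utils import clip_grad_norm_

def check_foreach_matches_reference(norm_type=2.0, max_norm=0.5):
    # Hypothetical parity check: identical gradients must give the same total
    # norm and the same clipped values with and without the foreach fast path.
    torch.manual_seed(0)
    params_ref = [torch.nn.Parameter(torch.randn(3, 3)) for _ in range(4)]
    params_fe = [torch.nn.Parameter(p.detach().clone()) for p in params_ref]
    for p_ref, p_fe in zip(params_ref, params_fe):
        g = torch.randn_like(p_ref)
        p_ref.grad = g.clone()
        p_fe.grad = g.clone()

    norm_ref = clip_grad_norm_(params_ref, max_norm, norm_type=norm_type, foreach=False)
    norm_fe = clip_grad_norm_(params_fe, max_norm, norm_type=norm_type, foreach=True)

    assert torch.allclose(norm_ref, norm_fe)
    for p_ref, p_fe in zip(params_ref, params_fe):
        assert torch.allclose(p_ref.grad, p_fe.grad)

check_foreach_matches_reference()
```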
I added tests and used |
@janeyx99 CI is green. I tweaked the import in your util file to avoid import race issues.
Thanks for the fast turnaround--looks awesome overall!
I had some nits and noob questions.
@@ -11486,7 +11408,8 @@ def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, pre
     @onlyCUDA
     @deviceCountAtLeast(2)
-    def test_clip_grad_norm_multi_device(self, devices):
+    @parametrize_test('foreach', (False, True))
+    def test_clip_grad_norm_multi_device(self, devices, foreach):
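For readers unfamiliar with the test harness: the parametrize decorator generates one test per value, so both the reference path and the foreach path run. A simplified, self-contained sketch is below (the real test also keeps the device decorators shown above, and torch.testing._internal is a private helper module):

```python
import torch
from torch.testing._internal.common_utils import (
    TestCase, parametrize, instantiate_parametrized_tests, run_tests)

class ClipGradExample(TestCase):
    # The decorator expands into one test per value, e.g.
    # test_clip_foreach_False and test_clip_foreach_True,
    # so both code paths are exercised in CI.
    @parametrize("foreach", (False, True))
    def test_clip(self, foreach):
        p = torch.nn.Parameter(torch.ones(4))
        p.grad = torch.full((4,), 2.0)  # norm 4.0, well above max_norm
        torch.nn.utils.clip_grad_norm_([p], max_norm=1.0, foreach=foreach)
        self.assertLessEqual(p.grad.norm().item(), 1.0 + 1e-6)

instantiate_parametrized_tests(ClipGradExample)

if __name__ == "__main__":
    run_tests()
```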
Not a concern with your PR, but I am realizing we never run this in CI because we only have one CI config where there is more than one GPU and we don't run this test in that config. 🤔 Filed #92173
If you want to see this run in CI, you would have to add a line similar to https://github.com/pytorch/pytorch/pull/92048/files#diff-5356a2d45f3d28e01b954926d7f1681cccb0c9ad2cafd722478de8c283090bd3R48
And then you would also need to add the ciflow/periodic label to get the multigpu tests to trigger.
Looks good to me! Thank you very much for the perf speedup :D
(sorry about the merge conflict--that's my bad)
@janeyx99 I rebased, but the signature change is breaking torch XLA since the patching there expects the old signature.
How should we approach that? First merge a fix to torch xla to allow both old and new signatures, then merge this PR?
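One way a patch could tolerate both signatures is a small shim like the sketch below; this is purely illustrative and not the code that was actually merged into pytorch/xla:

```python
import inspect
import torch
from torch.nn.utils import clip_grad_norm_ as _orig_clip_grad_norm_

def clip_grad_norm_compat(parameters, max_norm, norm_type=2.0,
                          error_if_nonfinite=False, foreach=None):
    # Hypothetical shim: forward the new `foreach` keyword only when the
    # installed PyTorch already knows about it, so callers keep working
    # both before and after the signature change lands.
    if "foreach" in inspect.signature(_orig_clip_grad_norm_).parameters:
        return _orig_clip_grad_norm_(parameters, max_norm, norm_type,
                                     error_if_nonfinite, foreach=foreach)
    # Older PyTorch: drop the new keyword so the old signature still matches.
    return _orig_clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite)
```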
Yes, we'd want to sync the lands. If we're able to fix xla without breaking pytorch, then go for it and land a patch there first. In general, we follow these steps, but the force merging may be unnecessary if xla can be green the whole time:
(1) Make a pytorch/pytorch PR and a pytorch/xla patch
The force merge is because we're betting on the fact that nothing should have changed from the last run to the next, and we don't want to keep XLA CI red for an unnecessary 3 hours. And the rebasing beforehand is because, at least so far, merge conflicts have been a frequent source of "the pytorch/xla PR merged, but the pytorch/pytorch PR is no longer ready".
An example can be found by following https://github.com/pytorch/xla/blob/d636e7774b63cc070d7ebbfeec950e4892efa713/.circleci/README.md?plain=1#L10
I don't think I can do it from step (2) since I don't have write access to the xla repo; I only have a fork, and I think xla.txt can't point to a fork. I tried setting it to |
@wonjoolee95 I rebased this PR. Once it's green you can merge the XLA PR. Then I'll update the pin on this PR and force merge.
@wonjoolee95 CI passed, can you merge the XLA PR?
@milesial, just merged to master. The new pin should be |
Thanks, let's hope it goes smoothly.
@pytorchbot merge -f "coordinating merge with XLA, CI passed"
You are not authorized to force merges to this repository. Please use the regular |
FYI, next time we can merge this PR first (the pin can point to an unmerged branch, and that's by design), and then merge the xla one. Otherwise the xla head will be red until this PR is merged.
Haha I can't force merge, @janeyx99 can you help?
Okay, since we can't force merge this right now, I'm going to revert the XLA PR lol.
@milesial don't worry about our revert; as long as pytorch still points to the correct pytorch/xla pin, pytorch CI will be fine.
You can update the XLA pin in this PR to |
I think keeping @pytorchbot merge -g |
yea, you can keep the old pin
Oh I can force merge :D
@pytorchbot merge -f "coordinating with xla, prev ci was all green!"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Nice, I guess the XLA side can re-revert, sorry for the commit mess haha. And thanks for the help, that was a fun force-push-to-prod Friday
don't speak too soon 🙃
@milesial @crcrpar can you check if a debug build using this op errors out? We have reports of debug builds erroring out with
Edit: the issue is most likely not with this PR, which is just Python enablement, but with the for_each_norm implementation itself.
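A minimal repro sketch for such a check might look like the following; it assumes a CUDA device, and the op name torch._foreach_norm is an assumption based on the comment above:

```python
import torch

# Sketch of a debug-build check: exercise both the raw foreach norm kernel and
# the full clipping entry point with the fast path enabled.
if torch.cuda.is_available():
    grads = [torch.randn(5, device="cuda") for _ in range(3)]
    print(torch._foreach_norm(grads, 2.0))

    params = [torch.nn.Parameter(torch.randn(5, device="cuda")) for _ in range(3)]
    for p in params:
        p.grad = torch.randn_like(p)
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, foreach=True)
```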
I'll check.
Faster gradient clipping using the foreach functions
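For reference, opting into the faster path from user code is a one-argument change. A usage sketch (with foreach=None, the default, PyTorch picks the foreach implementation automatically when the device supports it):

```python
import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(128, 128)
loss = model(torch.randn(32, 128)).sum()
loss.backward()

# Explicitly request the fused foreach path; this raises on devices without
# foreach kernel support, whereas foreach=None falls back silently.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
print(total_norm)
```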