Foreach gradient clipping #91846

Closed
wants to merge 19 commits

Conversation

@milesial (Contributor) commented Jan 7, 2023

Faster gradient clipping using the foreach functions

[------------------------ (tensors, scalar) -------------------------]
                                   |  without foreach  |  with foreach |    apex 
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4         |         120.5     |       61.1    |     50.3
      100 tensors of size 4        |         946.2     |      239.5    |    136.3
      1000 tensors of size 4       |        9808.5     |     2151.1    |   1006.9
      10000 tensors of size 4      |       96871.2     |    22637.4    |  10119.1
      10 tensors of size 16        |         121.0     |       64.1    |     52.5
      100 tensors of size 16       |         993.4     |      252.6    |    136.7
      1000 tensors of size 16      |        9427.7     |     2151.2    |   1049.5
      10000 tensors of size 16     |       97437.1     |    22203.1    |  10340.0
      10 tensors of size 256       |         118.9     |       62.3    |     51.5
      100 tensors of size 256      |         955.2     |      243.1    |    134.2
      1000 tensors of size 256     |        9374.9     |     2140.7    |   1009.6
      10000 tensors of size 256    |       95302.5     |    21849.4    |  10215.5
      10 tensors of size 65536     |         118.5     |       62.4    |     51.1
      100 tensors of size 65536    |        1740.7     |      243.3    |    225.3
      1000 tensors of size 65536   |       17364.1     |     2228.7    |   2004.5
      10000 tensors of size 65536  |      177510.1     |    25410.4    |  20678.2
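
For readers unfamiliar with the foreach path: the idea is to replace the per-parameter Python loop with fused `torch._foreach_*` calls, so the per-tensor launch overhead no longer dominates. A minimal sketch of the approach (not the exact code in this PR; it assumes all gradients share one device and dtype, and uses a `.item()` sync for simplicity):

```python
import torch

def clip_grad_norm_foreach_sketch(parameters, max_norm, norm_type=2.0):
    # Sketch only: assumes every gradient lives on one device with one dtype;
    # the PR groups grads by (device, dtype) before issuing the foreach calls.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.0)
    # One fused call computes the norm of every gradient tensor at once...
    norms = torch._foreach_norm(grads, norm_type)
    # ...and the per-tensor norms reduce to a single total norm.
    total_norm = torch.linalg.vector_norm(torch.stack(norms), norm_type)
    # One fused in-place multiply rescales all gradients together.
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0).item()
    torch._foreach_mul_(grads, clip_coef)
    return total_norm
```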

@pytorch-bot (bot) commented Jan 7, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91846

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d00dafa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@milesial marked this pull request as ready for review January 9, 2023 19:32
@albanD (Collaborator) left a comment

I'll let @janeyx99 take a look at this.
She is building a generic tool for doing this per_device_and_dtype_grads collection that will simplify this code.

@janeyx99 self-requested a review January 11, 2023 16:45
@janeyx99 (Contributor) left a comment

Hey! Can you add a test near test_clip_grad_norm in test/test_nn.py to ensure the calculations are the same?
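
Something along these lines, perhaps (a rough sketch only, with a hypothetical test name; `foreach=` is the keyword argument this PR adds to `clip_grad_norm_`):

```python
import torch
import torch.nn as nn

def test_clip_grad_norm_foreach_matches_default():
    for norm_type in (1.0, 2.0, float("inf")):
        # Two identical models, so each clipping path mutates its own grads.
        torch.manual_seed(0)
        ref = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
        fast = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
        fast.load_state_dict(ref.state_dict())

        x = torch.randn(16, 8)
        ref(x).sum().backward()
        fast(x).sum().backward()

        n_ref = nn.utils.clip_grad_norm_(ref.parameters(), 0.5, norm_type, foreach=False)
        n_fast = nn.utils.clip_grad_norm_(fast.parameters(), 0.5, norm_type, foreach=True)

        # Both the returned total norm and the clipped gradients should match.
        torch.testing.assert_close(n_ref, n_fast)
        for p_ref, p_fast in zip(ref.parameters(), fast.parameters()):
            torch.testing.assert_close(p_ref.grad, p_fast.grad)
```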

@janeyx99 (Contributor)

Regarding work on consolidating a util for creating this dictionary: I'm currently landing #92014, which has a version of this grouping function. It would be best if the functionality used across this PR could be abstracted to a common util in that file too!

@milesial (Contributor, Author)

I added tests and used _group_tensors_by_device_and_dtype, let's see what CI says
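
For context, the grouping exists because foreach kernels want all of their inputs on a single device with a single dtype. A simplified stand-in for the util (not its actual signature, just the idea) looks roughly like this:

```python
from collections import defaultdict
from typing import Dict, List, Tuple
import torch

def group_by_device_and_dtype(
    grads: List[torch.Tensor],
) -> Dict[Tuple[torch.device, torch.dtype], List[torch.Tensor]]:
    # Mixed setups (e.g. fp16 + fp32 params, or CPU + CUDA params) get
    # bucketed so each bucket can go through one fused foreach call.
    grouped: Dict[Tuple[torch.device, torch.dtype], List[torch.Tensor]] = defaultdict(list)
    for g in grads:
        grouped[(g.device, g.dtype)].append(g)
    return grouped

# Each bucket then gets its own fused norm computation, e.g.:
#   for (device, dtype), bucket in group_by_device_and_dtype(grads).items():
#       norms.extend(torch._foreach_norm(bucket, norm_type))
```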

@drisspg added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jan 12, 2023
@milesial (Contributor, Author)

@janeyx99 CI is green. I tweaked the import in your util file to avoid import race issues.

@janeyx99 (Contributor) left a comment

Thanks for the fast turnaround--looks awesome overall!

I had some nits and noob questions.

@janeyx99 (Contributor)

I confirm the tests passed!

@@ -11486,7 +11408,8 @@ def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, pre

     @onlyCUDA
     @deviceCountAtLeast(2)
-    def test_clip_grad_norm_multi_device(self, devices):
+    @parametrize_test('foreach', (False, True))
+    def test_clip_grad_norm_multi_device(self, devices, foreach):
Contributor comment on the diff:

Not a concern with your PR, but I am realizing we never run this in CI because we only have one CI config where there is more than one GPU and we don't run this test in that config. 🤔 Filed #92173

Contributor reply:

And then you would also need to add the ciflow/periodic label to get the multigpu tests to trigger.

@janeyx99 (Contributor) left a comment

Looks good to me! Thank you very much for the perf speedup :D

(sorry about the merge conflict--that's my bad)

@milesial (Contributor, Author) commented Jan 18, 2023

@janeyx99 I rebased, but the signature change breaks torch XLA, since the patching there expects the old signature:

 Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/xla/test/../../test/test_view_ops.py", line 15, in <module>
    from torch.testing._internal.common_device_type import \
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 603, in <module>
    mod = runpy.run_path(path, init_globals=globals())  # type: ignore[func-returns-value]
  File "/opt/conda/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/opt/conda/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/var/lib/jenkins/workspace/xla/test/pytorch_test_base.py", line 8, in <module>
    import torch_xla
  File "/opt/conda/lib/python3.7/site-packages/torch_xla-2.0.0-py3.7-linux-x86_64.egg/torch_xla/__init__.py", line 140, in <module>
    _apply_patches()
  File "/opt/conda/lib/python3.7/site-packages/torch_xla-2.0.0-py3.7-linux-x86_64.egg/torch_xla/_patched_functions.py", line 54, in _apply_patches
    nn.utils.clip_grad_norm_ = _patch(nn.utils.clip_grad_norm_, clip_grad_norm_)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla-2.0.0-py3.7-linux-x86_64.egg/torch_xla/_patched_functions.py", line 16, in _patch
    fn, xfingerprint, fingerprint))
RuntimeError: Unable to patch <function clip_grad_norm_ at 0x7fc8b282d950>, signature mismatch: (parameters: Union[torch.Tensor, Iterable[torch.Tensor]], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool = None) -> torch.Tensor vs (parameters: Union[torch.Tensor, Iterable[torch.Tensor]], max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False) -> torch.Tensor

from https://github.com/pytorch/xla/blob/d636e7774b63cc070d7ebbfeec950e4892efa713/torch_xla/_patched_functions.py#L21-L54

How should we approach that? Should we first merge a fix to torch XLA that accepts both the old and new signatures, and then merge this PR?
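
For reference, the XLA-side fix is essentially just widening the patched override's signature so that torch_xla's fingerprint check passes again; a hypothetical sketch (the real override in torch_xla/_patched_functions.py keeps its XLA-specific body):

```python
from typing import Iterable, Union
import torch

def clip_grad_norm_(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: bool = None,
) -> torch.Tensor:
    # XLA-specific clipping body unchanged; `foreach` only has to be accepted
    # (it can simply be ignored) for the signature fingerprints to match.
    ...
```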

@janeyx99 (Contributor)

Yes, we'd want to sync the lands. If we're able to fix xla without breaking pytorch, then go for it and land a patch there first. In general, we follow these steps, but the force merging may be unnecessary if xla can be green the whole time:

(1) Make a pytorch/pytorch PR and a pytorch/xla patch
(2) update xla.txt on the pytorch/pytorch PR to point to your patch
(3) once pytorch/pytorch CI is fully green, rebase on tip-of-master again to minimize the chance of merge conflicts / last minute problems
(4) once you get fully green CI on the newly rebased pytorch/pytorch PR, merge the pytorch/xla PR (XLA CI will start failing)
(5) update the pytorch/pytorch PR to the new tip-of-master XLA commit hash (no other changes should be required), and immediately force-merge. You’re counting on the fact that CI was green ~3 hours ago.

The force merge is because we’re betting on the fact that nothing should have changed from the last run to the next, and we don’t want to keep XLA CI red for an unnecessary 3 hours. And the rebasing beforehand is because at least so far, merge conflicts have been a frequent source of “the pytorch/xla PR merged, but the pytorch/pytorch PR is no longer ready”

An example can be found by following https://github.com/pytorch/xla/blob/d636e7774b63cc070d7ebbfeec950e4892efa713/.circleci/README.md?plain=1#L10

@milesial requested a review from a team as a code owner January 18, 2023 22:31
@milesial (Contributor, Author) commented Jan 19, 2023

I don't think I can do step (2), since I don't have write access to the xla repo; I only have a fork, and I don't think xla.txt can point to a fork. I tried setting it to pull/4471/head, but that didn't work.

pytorch/xla#4471

@milesial (Contributor, Author)

@wonjoolee95 I rebased this PR. Once it's green you can merge the XLA PR. Then I'll update the pin on this PR and force merge.

@milesial (Contributor, Author)

@wonjoolee95 CI passed, can you merge the XLA PR?

@wonjoolee95 (Collaborator)

> @wonjoolee95 CI passed, can you merge the XLA PR?

@milesial, just merged to master. The new pin should be eac4e547138ab22a9b41c6f96208613fd7dd19d5.

@milesial (Contributor, Author)

Thanks, let's hope it goes smoothly

@pytorchbot merge -f "coordinating merge with XLA, CI passed"

@pytorch-bot (bot) commented Jan 20, 2023

You are not authorized to force merges to this repository. Please use the regular @pytorchmergebot merge command instead

@JackCaoG (Collaborator)

FYI, next time we can merge this PR first (the pin can point to an unmerged branch, and that's by design), and then merge the xla one. Otherwise xla head will be red until this PR is merged.

@milesial (Contributor, Author)

Haha I can't force merge, @janeyx99 can you help?

@wonjoolee95 (Collaborator)

Okay since we can't force merge this right now, I'm going to revert the XLA's PR lol.

@JackCaoG (Collaborator)

@milesial don't worry about our revert; as long as pytorch still pins to the correct pytorch/xla commit, pytorch CI will be fine.

@wonjoolee95 (Collaborator)

You can update the XLA pin in this PR to 8dcab83819368f468dadbe6e81b064d268830df2 and merge -g. I'll merge the XLA's companion PR once this merges.

@milesial (Contributor, Author)

I think keeping eac4e is fine no? No need to switch to 8dca right?

@pytorchbot merge -g

@pytorch-bot (bot) added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 20, 2023
@JackCaoG (Collaborator)

yea, you can keep the old pin

@janeyx99 (Contributor)

Oh I can force merge :D

@janeyx99 (Contributor)

@pytorchbot merge -f "coordinating with xla, prev ci was all green!"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@milesial (Contributor, Author)

Nice, I guess the XLA side can re-revert, sorry for the commit mess haha. And thanks for the help, that was a fun force-push-to-prod Friday

@janeyx99 (Contributor)

don't speak too soon 🙃

@ngimel (Collaborator) commented Feb 1, 2023

@milesial @crcrpar can you check whether a debug build using this op errors out? We have reports of debug builds erroring out with:

** RuntimeError: t.storage().use_count() == 1 INTERNAL ASSERT FAILED at "caffe2/torch/csrc/autograd/autograd_not_implemented_fallback.cpp":189, please report a bug to PyTorch.

Edit: the issue is most likely not with this PR, which is just the Python enablement, but with the _foreach_norm implementation itself.

@milesial (Contributor, Author) commented Feb 2, 2023

I'll check.
By debug build you mean building with DEBUG=1 right?
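
A minimal check under such a build might look like this (hypothetical repro; it assumes the device supports the foreach path, and the second call is only a guess at a more direct trigger):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Exercise the new foreach clipping path end to end.
params = [torch.randn(64, device=device, requires_grad=True) for _ in range(8)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()
print(torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, foreach=True))

# The assert message points at the autograd-not-implemented fallback, so also
# hit _foreach_norm directly on a tensor that requires grad.
print(torch._foreach_norm([torch.randn(64, device=device, requires_grad=True)], 2.0))
```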

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

9 participants