torch.nn.utils.clip_grad_norm_: remove device syncs #61042
Conversation
💊 CI failures summary and remediations
As of commit b156402 (more details on the Dr. CI page and at hud.pytorch.org/pr/61042):
🕵️ 3 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
I briefly looked into the failing test. Secondly, I'm not sure if the testing I did is sufficient: is there a way to assert in a test that 0 or 1 device syncs occur during the call to clip_grad_norm_?
There are pre-existing failures for that test - we can ignore it for the purposes of this PR.
No way to assert on device syncs that I'm aware of - cc @ngimel for confirmation
Changes on the python side lgtm - thanks for the update! Two things:
- Unfortunately, there's currently a C++ reimplementation of this function in torch/csrc/api/include/torch/nn/utils/clip_grad.h that should be updated as well
- I'm a bit curious about the performance implications of the unconditional multiplication on CPU
@JackCaoG Does it make sense to apply similar updates on the XLA side?
Codecov Report
@@ Coverage Diff @@
## master #61042 +/- ##
==========================================
- Coverage 76.22% 75.74% -0.48%
==========================================
Files 2062 2062
Lines 205577 209332 +3755
==========================================
+ Hits 156693 158552 +1859
- Misses 48884 50780 +1896
Thanks @jbschlosser, will take a look tmr.
@jbschlosser I think pt/xla can also benefit from this change, so I will update our patch as well. We are thinking about removing the patch for this function.
Ah, actually pt/xla already does the manual computation.
@mautier Would you be willing to update the C++ version in torch/csrc/api/include/torch/nn/utils/clip_grad.h as well?
@jbschlosser Sorry for the radio silence; life things and a fried power supply have gotten in the way 🙃 I have a patch for the C++ version that I'll push soon; unfortunately it appears that the C++ version has a slightly different API, in that it returns a double rather than a tensor, so one final sync at the end cannot be removed.
No worries! RIP your power supply :/ Thanks for checking out the C++ version as well! It's unfortunate that the API differences make that last sync unremovable, but it does make sense. We want to keep the current API, but a comment about the discrepancy resulting in a final unremovable sync wouldn't hurt.
cc @ngimel
} else if (norm_type == 0) {
  total_norm = static_cast<double>(params_with_grad.size());
  total_norm_tensor = torch::full({}, static_cast<double>(params_with_grad.size()));
I'm not all that familiar with the C++ tensor creation APIs - is there a better way of creating a scalar tensor?
And on an unrelated note: order 0 norm is defined in torch.linalg.norm as the number of non-zero entries, but in this implementation it's just the number of parameters?
torch::full() is fine here imo.
Yeah, it does look like there's a discrepancy between here and the order 0 norm definition in torch.linalg.norm. I guess there's an implicit assumption that all grads are nonzero here? Regardless, we should leave as-is for backwards compatibility.
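For illustration, here's a small Python snippet (values made up) showing the ord=0 behavior of torch.linalg.norm referenced above:

```python
import torch

# torch.linalg.norm with ord=0 counts the non-zero entries of a vector,
# whereas the norm_type == 0 branch above simply uses the number of
# parameters with gradients.
x = torch.tensor([0.0, 2.0, 3.0])
print(torch.linalg.norm(x, ord=0))  # tensor(2.) -- two non-zero entries
```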
Thanks for the awesome fix & commenting! LGTM
Mind rebasing to address merge conflicts?
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from a2e6e26 to a82f25c.
@jbschlosser Rebased on master and (force-)pushed. Let me know if you want me to squash the commits into one too (not sure if you merge or squash-merge PRs)!
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Previously, when the caller opted out of nan/inf checks on the gradient norm, the checks would still be run in order to produce a warning. Unfortunately, doing so incurs a synchronization cost when the gradients live on, say, a CUDA device, as the CPU-side control flow depends on the result of CUDA-side computations. This commit removes the warning codepath; when the user opts out of finiteness checks (`error_if_nonfinite=False`), the checks are skipped and there is no performance penalty. Additionally, the 2 separate checks (`isnan()` and `isinf()`) are now combined using `torch.logical_or`. This means that when the checks do run on a non-CPU device, a single synchronization is required instead of 2. (The previous behavior allowed for short-circuiting, but only in the NaN case, not in the happy path.)
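A minimal sketch of the combined check described in this commit message (function and variable names are illustrative, not the exact diff):

```python
import torch

def check_total_norm(total_norm: torch.Tensor, error_if_nonfinite: bool) -> None:
    # When the caller opts out, skip the check entirely: no warning path,
    # and therefore no device-to-host synchronization.
    if not error_if_nonfinite:
        return
    # A single fused boolean check; evaluating it in `if` syncs once,
    # rather than once for isnan() and once more for isinf().
    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            "The total norm for gradients is non-finite, so it cannot be clipped.")
```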
The `if clip_coef < 1:` conditional incurs a device synchronization when the gradients are on a non-CPU device. This commit removes the conditional in favor of a clamping step and an unconditional scaling.
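A minimal sketch of the clamp-based scaling (names are illustrative; the real implementation lives in torch/nn/utils/clip_grad.py):

```python
import torch

def scale_grads_(grads, max_norm: float, total_norm: torch.Tensor) -> None:
    # The scale factor stays on the device; there is no `if clip_coef < 1:`
    # branch, so no value has to be copied back to the CPU for control flow.
    clip_coef = max_norm / (total_norm + 1e-6)
    # For gradients that are already small enough, the clamp pins the factor
    # at 1.0, making the unconditional multiply a no-op.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
```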
Just like in the Python version of this function, this commit removes all nan/inf checks when `error_if_nonfinite = false`. It also changes the computation of the norm to be based on standard PyTorch tensor operations (instead of more manual computations with `std::max` and `std::pow`), in order to make it possible for those computations to run directly on the device with no CPU synchronization. Unfortunately, even then, since the C++ API returns the norm of the gradients as a `double` (and not a scalar tensor), the implementation must inevitably synchronize the CPU and device at the end. Nonetheless, this function now synchronizes only once, as late as possible, instead of many times (once per param).
Force-pushed from a82f25c to b156402.
Rebased again to pick up 2 fixes that went into master since the last rebase; this should address the remaining CI failures.
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@jbschlosser merged this pull request in e858f6e.
Fixes #60691
Changes
Per the discussion in the above issue, this PR makes 2 changes:
- When `error_if_nonfinite=False`, the NaN/Inf checks are truly skipped, and no device synchronization occurs. When the checks do run, they are combined via `torch.logical_or` to incur only a single sync (instead of 2 in the happy/finite path).
- The `clip_coef` conditional is removed, in favor of a call to `clamp(..., max=1.0)` and an unconditional multiplication.

Testing
- Existing unit tests for `clip_grad_norm_` pass.
- Behavior verified with `error_if_nonfinite=False`.
- Behavior verified with `error_if_nonfinite=True`.
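For reference, a typical call site looks like the sketch below (illustrative only; the model and shapes are made up, and this snippet is not part of the PR's test suite):

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(10, 10).to(device)

loss = model(torch.randn(4, 10, device=device)).sum()
loss.backward()

# With error_if_nonfinite=False, the finiteness checks are skipped and the
# clipping itself no longer forces a device synchronization.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0,
                             error_if_nonfinite=False)
```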