
Conversation

@mautier (Contributor) commented Jun 30, 2021

Fixes #60691

Changes

Per the discussion in the above issue, this PR makes 2 changes:

  1. When error_if_nonfinite=False, the NaN/Inf checks are truly skipped, and no device synchronization occurs.
    • Additionally, when performing the checks, the 2 results are combined with torch.logical_or to incur only a single sync (instead of 2 in the happy/finite path).
  2. The clip_coef conditional is removed in favor of a call to clamp(..., max=1.0) and an unconditional multiplication (a rough sketch of the resulting logic is shown below).
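
For illustration, here is a minimal Python sketch of the resulting control flow (not the exact PR diff; the helper name and the 1e-6 epsilon are purely illustrative, while the argument names mirror the public torch.nn.utils.clip_grad_norm_ signature):

```python
import torch

def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.0)
    device = grads[0].device
    # Reduce the per-parameter norms on the device; no .item() / host round-trip here.
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
        norm_type,
    )
    if error_if_nonfinite:
        # The two checks are combined so at most one device sync happens here;
        # with error_if_nonfinite=False this branch (and its sync) is skipped entirely.
        if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
            raise RuntimeError(
                f"The total norm of order {norm_type} for gradients is non-finite")
    # clamp(..., max=1.0) plus an unconditional multiply replaces `if clip_coef < 1:`,
    # which would otherwise force a device-to-host sync to evaluate the condition.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
```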

Testing

@facebook-github-bot (Contributor) commented Jun 30, 2021

💊 CI failures summary and remediations

As of commit b156402 (more details on the Dr. CI page and at hud.pytorch.org/pr/61042):



🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build Windows CI (pytorch-win-vs2019-cpu-py3) / test (default, 1, 2, windows.4xlarge) (1/3)

Step: "Store PyTorch Test Reports"

2021-07-21T14:10:07.1091726Z + PYTORCH_FINAL_PACKAGE_DIR=/c/1051982143/build-results/
2021-07-21T14:10:07.1155749Z ++ cygpath -w /c/1051982143/build-results/
2021-07-21T14:10:07.1257889Z + PYTORCH_FINAL_PACKAGE_DIR_WIN='C:\1051982143\build-results\'
2021-07-21T14:10:07.1258432Z + export PYTORCH_FINAL_PACKAGE_DIR_WIN
2021-07-21T14:10:07.1259332Z + export PYTORCH_TEST_SKIP_NOARCH=1
2021-07-21T14:10:07.1259750Z + PYTORCH_TEST_SKIP_NOARCH=1
2021-07-21T14:10:07.1260448Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/build/win_tmp/build/torch
2021-07-21T14:10:07.1621933Z + CI_SCRIPTS_DIR=/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/build/win_tmp/ci_scripts
2021-07-21T14:10:07.1622817Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/build/win_tmp/ci_scripts
2021-07-21T14:10:07.1821847Z ++ ls '/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/build/win_tmp/ci_scripts/*'
2021-07-21T14:10:07.2119187Z ls: cannot access '/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/build/win_tmp/ci_scripts/*': No such file or directory
2021-07-21T14:10:07.2121563Z + '[' -n '' ']'
2021-07-21T14:10:07.2122927Z + export SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/.jenkins/pytorch/win-test-helpers
2021-07-21T14:10:07.2123939Z + SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982143/.jenkins/pytorch/win-test-helpers
2021-07-21T14:10:07.2124575Z + IN_PULL_REQUEST=
2021-07-21T14:10:07.2124871Z + '[' -n '' ']'
2021-07-21T14:10:07.2125275Z + [[ pytorch-win-vs2019-cpu-py3 == *cuda11* ]]
2021-07-21T14:10:07.2125705Z + run_tests
2021-07-21T14:10:07.2126277Z + for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe
2021-07-21T14:10:07.2127031Z + [[ -x /c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe ]]
2021-07-21T14:10:07.2127640Z + '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe'

See GitHub Actions build Windows CI (pytorch-win-vs2019-cpu-py3) / test (default, 2, 2, windows.4xlarge) (2/3)

Step: "Store PyTorch Test Reports"

2021-07-21T14:52:20.5108691Z   c:\jenkins\miniconda3\lib\site-packages\coverage\execfile.py(247): run
2021-07-21T14:52:20.5109354Z   c:\jenkins\miniconda3\lib\site-packages\coverage\cmdline.py(746): do_run
2021-07-21T14:52:20.5110020Z   c:\jenkins\miniconda3\lib\site-packages\coverage\cmdline.py(588): command_line
2021-07-21T14:52:20.5110692Z   c:\jenkins\miniconda3\lib\site-packages\coverage\cmdline.py(871): main
2021-07-21T14:52:20.5111326Z   C:\Jenkins\Miniconda3\Scripts\coverage.exe\__main__.py(7): <module>
2021-07-21T14:52:20.5111854Z   c:\jenkins\miniconda3\lib\runpy.py(87): _run_code
2021-07-21T14:52:20.5112355Z   c:\jenkins\miniconda3\lib\runpy.py(194): _run_module_as_main
2021-07-21T14:52:20.5112646Z 
2021-07-21T14:52:20.5112889Z ok (0.000s)
2021-07-21T14:52:20.5139220Z   test_add_done_callback_maintains_callback_order (__main__.TestFuture) ... ok (0.016s)
2021-07-21T14:52:20.5154082Z   test_add_done_callback_no_arg_error_is_ignored (__main__.TestFuture) ... [E pybind_utils.h:200] Got the following error when running the callback: TypeError: no_arg() takes 0 positional arguments but 1 was given
2021-07-21T14:52:20.5154964Z ok (0.000s)
2021-07-21T14:52:20.5183047Z   test_add_done_callback_simple (__main__.TestFuture) ... ok (0.000s)
2021-07-21T14:52:20.5239908Z   test_chained_then (__main__.TestFuture) ... ok (0.000s)
2021-07-21T14:52:20.6329406Z   test_collect_all (__main__.TestFuture) ... ok (0.121s)
2021-07-21T14:52:20.6349374Z   test_done (__main__.TestFuture) ... ok (0.000s)
2021-07-21T14:52:20.6382764Z   test_done_exception (__main__.TestFuture) ... ok (0.000s)
2021-07-21T14:52:20.6424457Z   test_interleaving_then_and_add_done_callback_maintains_callback_order (__main__.TestFuture) ... ok (0.000s)
2021-07-21T14:52:20.6451442Z   test_interleaving_then_and_add_done_callback_propagates_error (__main__.TestFuture) ... [E pybind_utils.h:200] Got the following error when running the callback: ValueError: Expected error
2021-07-21T14:52:20.6452279Z 
2021-07-21T14:52:20.6452501Z At:

See GitHub Actions build Windows CI (pytorch-win-vs2019-cuda10-cudnn7-py3) / test (default, 1, 2, windows.8xlarge.nvidia.gpu) (3/3)

Step: "Store PyTorch Test Reports"

2021-07-21T14:40:16.9865038Z + PYTORCH_FINAL_PACKAGE_DIR=/c/1051982129/build-results/
2021-07-21T14:40:16.9943148Z ++ cygpath -w /c/1051982129/build-results/
2021-07-21T14:40:17.0069816Z + PYTORCH_FINAL_PACKAGE_DIR_WIN='C:\1051982129\build-results\'
2021-07-21T14:40:17.0070979Z + export PYTORCH_FINAL_PACKAGE_DIR_WIN
2021-07-21T14:40:17.0072217Z + export PYTORCH_TEST_SKIP_NOARCH=1
2021-07-21T14:40:17.0072814Z + PYTORCH_TEST_SKIP_NOARCH=1
2021-07-21T14:40:17.0073609Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/build/win_tmp/build/torch
2021-07-21T14:40:17.0263353Z + CI_SCRIPTS_DIR=/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/build/win_tmp/ci_scripts
2021-07-21T14:40:17.0264631Z + mkdir -p /c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/build/win_tmp/ci_scripts
2021-07-21T14:40:17.0506518Z ++ ls '/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/build/win_tmp/ci_scripts/*'
2021-07-21T14:40:17.0599885Z ls: cannot access '/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/build/win_tmp/ci_scripts/*': No such file or directory
2021-07-21T14:40:17.0603151Z + '[' -n '' ']'
2021-07-21T14:40:17.0604464Z + export SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/.jenkins/pytorch/win-test-helpers
2021-07-21T14:40:17.0606383Z + SCRIPT_HELPERS_DIR=/c/actions-runner/_work/pytorch/pytorch/pytorch-1051982129/.jenkins/pytorch/win-test-helpers
2021-07-21T14:40:17.0607186Z + IN_PULL_REQUEST=
2021-07-21T14:40:17.0607546Z + '[' -n '' ']'
2021-07-21T14:40:17.0608185Z + [[ pytorch-win-vs2019-cuda10-cudnn7-py3 == *cuda11* ]]
2021-07-21T14:40:17.0608859Z + run_tests
2021-07-21T14:40:17.0609598Z + for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe
2021-07-21T14:40:17.0610567Z + [[ -x /c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe ]]
2021-07-21T14:40:17.0611385Z + '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe'

🚧 3 fixed upstream failures:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Preview docs built from this PR

This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@mautier mautier marked this pull request as ready for review July 1, 2021 08:22
@mautier (Contributor, Author) commented Jul 1, 2021

I briefly looked into the pr/pytorch-linux-bionic-rocm4.2-py3.6 CI failure: test_reduce_stress_cuda in test_c10d_gloo.py is the failing test; as far as I can tell it is not related to clip_grad_norm_ in any way, so let me know how to proceed!

Secondly, I'm not sure if the testing I did is sufficient: is there a way to assert in a test that only 0 or 1 device syncs occur during the call to clip_grad_norm_? Or is my manual check sufficient?

@jbschlosser (Contributor)

> I briefly looked into the pr/pytorch-linux-bionic-rocm4.2-py3.6 CI failure: test_reduce_stress_cuda in test_c10d_gloo.py is the failing test; as far as I can tell it is not related to clip_grad_norm_ in any way, so let me know how to proceed!

There are pre-existing failures for that test; we can ignore it for the purposes of this PR.

> Secondly, I'm not sure if the testing I did is sufficient: is there a way to assert in a test that only 0 or 1 device syncs occur during the call to clip_grad_norm_? Or is my manual check sufficient?

No way to assert on device syncs that I'm aware of - cc @ngimel for confirmation
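
As a hedged aside (this did not exist when the PR was written): later PyTorch releases added torch.cuda.set_sync_debug_mode, which can at least surface unexpected CUDA synchronizations during a manual check, assuming a build new enough to ship it:

```python
import torch

# Assumes a CUDA device and a PyTorch version that provides set_sync_debug_mode.
if torch.cuda.is_available() and hasattr(torch.cuda, "set_sync_debug_mode"):
    torch.cuda.set_sync_debug_mode("warn")  # use "error" to fail hard on any sync
    params = [torch.nn.Parameter(torch.randn(16, device="cuda")) for _ in range(4)]
    for p in params:
        p.grad = torch.randn_like(p)
    # Any synchronizing op inside the call now triggers a warning.
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    torch.cuda.set_sync_debug_mode("default")
```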

@jbschlosser (Contributor) left a comment:

Changes on the python side lgtm - thanks for the update! Two things:

  • Unfortunately, there's currently a C++ reimplementation of this function in torch/csrc/api/include/torch/nn/utils/clip_grad.h that should be updated as well
  • I'm a bit curious about the performance implications of the unconditional multiplication on CPU (a rough way to eyeball this is sketched below)

@JackCaoG Does it make sense to apply similar updates on the XLA side?
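
Regarding the second bullet above, a rough way to eyeball the CPU cost of the now-unconditional multiply (a hedged sketch with made-up tensor sizes, not a measured claim about the actual overhead):

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical gradient shapes purely for illustration.
grads = [torch.randn(1024, 1024) for _ in range(16)]
coef = torch.tensor(1.0)  # the no-clipping case: the new code multiplies by 1.0 anyway

timer = benchmark.Timer(
    stmt="[g.mul_(coef) for g in grads]",
    globals={"grads": grads, "coef": coef},
)
print(timer.timeit(100))
```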

@codecov bot commented Jul 1, 2021

Codecov Report

Merging #61042 (40150e3) into master (fab1b6c) will decrease coverage by 0.47%.
The diff coverage is 100.00%.

❗ Current head 40150e3 differs from pull request most recent head a2e6e26. Consider uploading reports for the commit a2e6e26 to get more accurate results

@@            Coverage Diff             @@
##           master   #61042      +/-   ##
==========================================
- Coverage   76.22%   75.74%   -0.48%     
==========================================
  Files        2062     2062              
  Lines      205577   209332    +3755     
==========================================
+ Hits       156693   158552    +1859     
- Misses      48884    50780    +1896     

@JackCaoG (Collaborator) commented Jul 2, 2021

Thanks @jbschlosser, will take a look tomorrow.

@JackCaoG (Collaborator) commented Jul 2, 2021

@jbschlosser I think pt/xla can also benefit from this change, so I will update our patch as well. We have been thinking about removing the patch for clip_grad_norm_ for a while, since the item() call is removed, but we haven't benchmarked the result yet.

@JackCaoG (Collaborator) commented Jul 2, 2021

Ah, actually pt/xla already does the manual clamping and the error_if_nonfinite check, so pt/xla should be fine.

@jbschlosser (Contributor)

@mautier Would you be willing to update the C++ version in torch/csrc/api/include/torch/nn/utils/clip_grad.h as well?

@mautier (Contributor, Author) commented Jul 8, 2021

@jbschlosser Sorry for the radio silence; life things and a fried power supply have gotten in the way 🙃

I have a patch for the C++ version that I'll push soon; unfortunately it appears that the C++ version has a slightly different API, in that it returns a double (instead of a scalar tensor as in Python). So while I was able to eliminate the unwanted syncs and align the error_if_nonfinite semantics, there remains a final sync that cannot be removed without changing the function signature.
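
In Python terms, the API difference boils down to the snippet below (an illustration with arbitrary sizes, not code from the patch): keeping the norm as a tensor stays asynchronous, while converting it to a host-side number, as the C++ double return value requires, forces a synchronization.

```python
import torch

if torch.cuda.is_available():
    grads = [torch.randn(1000, device="cuda") for _ in range(4)]
    # The reduction stays on the GPU; no synchronization yet.
    total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
    # Copying the scalar back to the host (the moral equivalent of returning a
    # double from the C++ API) is what forces the device to sync.
    as_python_float = total_norm.item()
```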

@jbschlosser (Contributor)

> @jbschlosser Sorry for the radio silence; life things and a fried power supply have gotten in the way 🙃
>
> I have a patch for the C++ version that I'll push soon; unfortunately it appears that the C++ version has a slightly different API, in that it returns a double (instead of a scalar tensor as in Python). So while I was able to eliminate the unwanted syncs and align the error_if_nonfinite semantics, there remains a final sync that cannot be removed without changing the function signature.

No worries! RIP your power supply :/

Thanks for checking out the C++ version as well! It's unfortunate that the API differences make that last sync unremovable, but it does make sense. We want to keep the current API, but a comment about the discrepancy resulting in a final unremovable sync wouldn't hurt.

@iramazanli added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jul 8, 2021
@albanD (Collaborator) commented Jul 9, 2021

cc @ngimel

} else if (norm_type == 0) {
  total_norm = static_cast<double>(params_with_grad.size());
  total_norm_tensor = torch::full({}, static_cast<double>(params_with_grad.size()));
@mautier (Contributor, Author) commented:

I'm not all that familiar with the C++ tensor creation APIs; is there a better way of creating a scalar tensor?

And on an unrelated note: order 0 norm is defined in torch.linalg.norm as the number of non-zero entries, but in this implementation it's just the number of parameters?

A Contributor replied:

torch::full() is fine here imo.

Yeah, it does look like there's a discrepancy between here and the order 0 norm definition in torch.linalg.norm. I guess there's an implicit assumption that all grads are nonzero here? Regardless, we should leave as-is for backwards compatibility.
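
For reference, a tiny illustration of the torch.linalg.norm definition mentioned above (not part of the PR):

```python
import torch

x = torch.tensor([0.0, 1.5, 0.0, -2.0])
# ord=0 counts non-zero entries, unlike the "number of parameters" used in clip_grad.h.
print(torch.linalg.norm(x, ord=0))  # tensor(2.)
```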

@jbschlosser (Contributor) left a comment:

Thanks for the awesome fix & commenting! LGTM

Mind rebasing to address merge conflicts?

@facebook-github-bot (Contributor)

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mautier force-pushed the mautier/clip_grad_norm_remove_device_syncs branch from a2e6e26 to a82f25c on July 19, 2021 18:04
@mautier (Contributor, Author) commented Jul 19, 2021

@jbschlosser Rebased on master and (force-)pushed. Let me know if you want me to squash the commits into one too (not sure if you merge or squash-merge PRs)!

@facebook-github-bot (Contributor)

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

mautier and others added 3 commits on July 21, 2021 12:20:
Previously, when the caller opted out of nan/inf checks on the gradient
norm, the checks would still be run in order to produce a warning.

Unfortunately doing so incurs a synchronization cost when the gradients
live on, say, a CUDA device, as the CPU-side control flow depends on the
result of CUDA-side computations.

This commit removes the warning codepath; when the user opts out of
finite-ness checks (`error_if_nonfinite=False`), the checks are skipped,
and there is no performance penalty.

Additionally, the 2 separate checks (`isnan()` and `isinf()`) are now
combined using `torch.logical_or`. This means that when the checks do
run on a non-CPU device, a single synchronization will be required
instead of 2. (The previous behavior allowed for short-circuiting, but
only in the NaN case, not in the happy path.)

The `if clip_coef < 1:` conditional incurs a device synchronization when
the gradients are on a non-CPU device. This commit removes the
conditional in favor of a clamping step and unconditional scaling.

Just like in the Python version of this function, this commit removes
all nan/inf checks when `error_if_nonfinite = false`.

It also changes the computation of the norm to be based on standard
pytorch tensor operations (instead of more manual computations with
`std::max` and `std::pow`), in order to make it possible for those
computations to run directly on the device with no CPU synchronization.

Unfortunately, even then, since the C++ API returns the norm of the
gradients as a `double` (and not a scalar tensor), the implementation
must inevitably synchronize the CPU and device at the end.

Nonetheless, this function now synchronizes only once, as late as
possible, instead of many times (once per param).
@mautier force-pushed the mautier/clip_grad_norm_remove_device_syncs branch from a82f25c to b156402 on July 21, 2021 09:21
@mautier (Contributor, Author) commented Jul 21, 2021

Rebased again to pick up 2 fixes that went into master since the last rebase; this should address the quick-checks and shellcheck lint failures.

@facebook-github-bot (Contributor)

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@jbschlosser merged this pull request in e858f6e.

Labels: cla signed, Merged, open source, triaged
Development

Successfully merging this pull request may close these issues.

torch.nn.utils.clip_grad_norm_: bad GPU utilization due to GPU-data-dependent control-flow
7 participants