[SyncBatchNorm] Support running with low precision parameters #98332
Conversation
thanks for the fix!
…ded" This PR fixes #96203. **Details** When using `nn.SyncBatchNorm` with the model converted to FP16, there is a dtype discrepancy in the `SyncBatchNorm.forward()` causing an error like: ``` File "/.../pytorch/torch/nn/modules/_functions.py", line 91, in forward mean, invstd = torch.batch_norm_gather_stats_with_counts( RuntimeError: Expected counts to have type Half but got Float ``` [`torch.batch_norm_gather_stats_with_counts()`](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L88-L97) requires the `running_mean`, `running_var`, and `counts` to have the same dtype. However, when the model has been converted to FP16, only `running_mean` and `running_var` use FP16, while the `counts` are in FP32 due to [`mean` being in FP32](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L25-L30). This PR resolves this by casting `counts` from FP32 to FP16 instead of the alternative to cast `mean` and `invstd` from FP32 to FP16. Moreover, for the backward, this PR casts `weight` from FP16 to FP32 to match the dtype of `mean` and `invstd` as required by `torch.batch_norm_backward_elemt()` instead of the alternative to cast `mean` and `invstd` from FP32 to FP16. **Test Plan** I dug up this run command from 2021: ``` WORLD_SIZE=2 BACKEND=nccl python -m pytest test/distributed/test_distributed_spawn.py -k test_DistributedDataParallel_SyncBatchNorm_half -vs ``` [ghstack-poisoned]
ghstack-source-id: e8a775d674c7d1c0f17b9b2348c9ca6184a59d54 Pull Request resolved: #98332
running_mean.dtype
if needed…ers" This PR fixes #96203. **Details** When using `nn.SyncBatchNorm` with the model converted to FP16, there is a dtype discrepancy in the `SyncBatchNorm.forward()` causing an error like: ``` File "/.../pytorch/torch/nn/modules/_functions.py", line 91, in forward mean, invstd = torch.batch_norm_gather_stats_with_counts( RuntimeError: Expected counts to have type Half but got Float ``` [`torch.batch_norm_gather_stats_with_counts()`](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L88-L97) requires the `running_mean`, `running_var`, and `counts` to have the same dtype. However, when the model has been converted to FP16, only `running_mean` and `running_var` use FP16, while the `counts` are in FP32 due to [`mean` being in FP32](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L25-L30). This PR resolves this by casting `counts` from FP32 to FP16 instead of the alternative to cast `mean` and `invstd` from FP32 to FP16. Moreover, for the backward, this PR casts `weight` from FP16 to FP32 to match the dtype of `mean` and `invstd` as required by `torch.batch_norm_backward_elemt()` instead of the alternative to cast `mean` and `invstd` from FP32 to FP16. **Test Plan** I dug up this run commands from 2021: For `world_size` in `{1,2}` and `backend` in `{nccl, gloo}`: ``` WORLD_SIZE=world_size BACKEND=backend python -m pytest test/distributed/test_distributed_spawn.py -k test_DistributedDataParallel_SyncBatchNorm_half -vs ``` [ghstack-poisoned]
ghstack-source-id: a804c303e0ebe15f4fb0680737c25992d2c44da8 Pull Request resolved: #98332
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR fixes #96203.

**Details**

When using `nn.SyncBatchNorm` with the model converted to FP16, there is a dtype discrepancy in `SyncBatchNorm.forward()` that causes an error like:

```
File "/.../pytorch/torch/nn/modules/_functions.py", line 91, in forward
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
RuntimeError: Expected counts to have type Half but got Float
```

[`torch.batch_norm_gather_stats_with_counts()`](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L88-L97) requires `running_mean`, `running_var`, and `counts` to have the same dtype. However, when the model has been converted to FP16, only `running_mean` and `running_var` are FP16, while `counts` is FP32 because [`mean` is computed in FP32](https://github.com/pytorch/pytorch/blob/fe9da29842a07a1f44d6b8c2a4c75053da9e84d0/torch/nn/modules/_functions.py#L25-L30). This PR resolves the mismatch by casting `counts` from FP32 to FP16, rather than casting `mean` and `invstd` from FP32 to FP16.

Moreover, for the backward pass, this PR casts `weight` from FP16 to FP32 to match the dtype of `mean` and `invstd`, as required by `torch.batch_norm_backward_elemt()`, rather than casting `mean` and `invstd` from FP32 to FP16.

**Test Plan**

I dug up this run command from 2021. For `world_size` in `{1, 2}` and `backend` in `{nccl, gloo}`:

```
WORLD_SIZE=world_size BACKEND=backend python -m pytest test/distributed/test_distributed_spawn.py -k test_DistributedDataParallel_SyncBatchNorm_half -vs
```

Pull Request resolved: #98332
Approved by: https://github.com/rohan-varma
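To make the dtype handling concrete, here is a minimal illustrative sketch (not the actual patch) of the casts described above, using stand-in tensors; the real logic lives in `torch/nn/modules/_functions.py`, and the shapes and values here are hypothetical.

```python
# Illustrative sketch only, not the actual patch: the dtype alignment described
# above, shown with stand-in tensors (shapes and values are hypothetical).
import torch

# After model.half(), the BN parameters/buffers are FP16 ...
running_mean = torch.zeros(8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)

# ... but the locally computed statistics and per-rank counts stay FP32.
mean = torch.zeros(8, dtype=torch.float32)
invstd = torch.ones(8, dtype=torch.float32)
count_all = torch.full((2,), 32.0, dtype=torch.float32)  # one count per rank

# Forward: cast counts to the running-stat dtype so that
# torch.batch_norm_gather_stats_with_counts() sees matching dtypes.
counts = count_all.to(running_mean.dtype)  # FP32 -> FP16 when params are FP16

# Backward: cast weight up to FP32 to match mean/invstd, as required by
# torch.batch_norm_backward_elemt().
weight_fp32 = weight.to(torch.float32)
```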
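For completeness, a standalone repro sketch of the failing configuration from #96203 (FP16 `nn.SyncBatchNorm` in a multi-process group). This is an assumption-laden sketch, not part of the PR's test plan: it assumes a machine with at least two CUDA GPUs and NCCL, and the layer size, input shape, and port are arbitrary choices.

```python
# Hedged repro sketch: FP16 SyncBatchNorm under a 2-rank NCCL group
# (assumes >= 2 CUDA GPUs; layer size, input shape, and port are arbitrary).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    bn = nn.SyncBatchNorm(8).cuda().half()  # FP16 parameters and buffers
    x = torch.randn(4, 8, 16, 16, device="cuda", dtype=torch.float16, requires_grad=True)
    out = bn(x)           # before this fix: "Expected counts to have type Half but got Float"
    out.sum().backward()  # backward needs weight cast to FP32 to match mean/invstd

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```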
Stack from ghstack (oldest at bottom):

- `requires_grad_mask` #98299
- `_use_sharded_views()` for `SHARD_GRAD_OP` #98250
- `requires_grad` for `use_orig_params=True` #98221