Fix SyncBatchNorm usage without stats tracking #50126
Conversation
💊 CI failures summary and remediations, as of commit 4e5308c (more details on the Dr. CI page):
- 1 job timed out
- 🚧 1 fixed upstream failure: this was probably caused by upstream breakages that were already fixed. Please rebase on the
@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
cc @jjsjann123 fyi. Is it true that
Force-pushed from 908e570 to 2d29fd1
Codecov Report
@@ Coverage Diff @@
## master #50126 +/- ##
===========================================
+ Coverage 70.25% 80.68% +10.43%
===========================================
Files 1900 1900
Lines 206246 206246
===========================================
+ Hits 144894 166408 +21514
+ Misses 61352 39838 -21514
Summary:
- In `batch_norm_gather_stats_with_counts_cuda`, use `input.scalar_type()` if `running_mean` is not defined.
- In the `SyncBatchNorm` forward function, create the count tensor with `torch.float32` type if `running_mean` is None.
- Fix a few typos.

Pull Request resolved: pytorch#50126

Test Plan:
```
python -c "import torch; print(torch.batch_norm_gather_stats_with_counts(torch.randn(1, 3, 3, 3, device='cuda'), mean=torch.ones(2, 3, device='cuda'), invstd=torch.ones(2, 3, device='cuda'), running_mean=None, running_var=None, momentum=0.1, eps=1e-5, counts=torch.ones(2, device='cuda')))"
```

Fixes pytorch#49730
Reviewed By: ngimel
Differential Revision: D25797930
Pulled By: malfet
fbshipit-source-id: 22a91e3969b5e9bbb7969d9cc70b45013a42fe83
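The fix described in the summary boils down to a dtype fallback when running stats are absent: with `track_running_stats=False`, `running_mean` is None, so its dtype cannot be read and the code must fall back to `input.scalar_type()` (C++ side) or `torch.float32` (Python-side count tensor). A minimal stand-alone sketch of that decision logic (`pick_stats_dtype` is a hypothetical illustrative name, not a PyTorch function):

```python
# Hypothetical sketch of the dtype-fallback rule this PR introduces.
# Not PyTorch internals; dtypes are modeled as plain strings so the
# example runs without torch or a GPU.

def pick_stats_dtype(input_dtype, running_mean_dtype):
    """Choose the dtype for the gathered-stats / count buffers.

    running_mean_dtype is None when track_running_stats=False; in that
    case fall back to float32 instead of dereferencing a missing tensor
    (which is what crashed before the fix).
    """
    if running_mean_dtype is None:
        return "float32"  # fallback when stats are not tracked
    return running_mean_dtype  # tracked-stats dtype wins as before

print(pick_stats_dtype("float16", None))       # fallback path
print(pick_stats_dtype("float16", "float64"))  # unchanged path
```

The same rule read top-down: only the previously crashing None branch changes behavior; when running stats are tracked, the dtype choice is exactly what it was before the fix.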
Hi @malfet @ngimel, it seems like this still fails with track_running_stats=False when doing mixed-precision training (in distributed data-parallel). Version details:

With track_running_stats=False, I get the following stack trace:

Can you please take a look (or should I create a new issue)?