
Fix SyncBatchNorm usage without stats tracking #50126

Conversation

@malfet (Contributor) commented Jan 6, 2021

- In `batch_norm_gather_stats_with_counts_cuda`, use `input.scalar_type()` if `running_mean` is not defined
- In `SyncBatchNorm`'s forward function, create the count tensor with `torch.float32` type if `running_mean` is `None`
- Fix a few typos

Test Plan:

python -c "import torch; print(torch.batch_norm_gather_stats_with_counts(torch.randn(1, 3, 3, 3, device='cuda'), mean=torch.ones(2, 3, device='cuda'), invstd=torch.ones(2, 3, device='cuda'), running_mean=None, running_var=None, momentum=0.1, eps=1e-5, counts=torch.ones(2, device='cuda')))"

Fixes #49730
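As a rough illustration (not the PyTorch source), the dtype rule this PR applies to the `counts` tensor can be sketched as follows; the function names here are hypothetical, chosen only to mirror the two code paths the description mentions:

```python
# Hedged sketch of the counts-dtype rule described in this PR.
# Neither function exists in PyTorch; they model the two fallbacks.

def counts_dtype_cuda_kernel(running_mean_dtype, input_dtype):
    """Models batch_norm_gather_stats_with_counts_cuda: use
    running_mean's dtype, or fall back to input.scalar_type()
    when running_mean is not defined."""
    return running_mean_dtype if running_mean_dtype is not None else input_dtype

def counts_dtype_python_wrapper(running_mean_dtype):
    """Models SyncBatchNorm's forward: create the count tensor
    as torch.float32 when running_mean is None."""
    return running_mean_dtype if running_mean_dtype is not None else "Float"

# With stats tracking, both paths agree on running_mean's dtype:
assert counts_dtype_cuda_kernel("Float", "Half") == "Float"
# Without stats tracking and a float32 input (the Test Plan case),
# the two fallbacks also agree:
assert counts_dtype_python_wrapper(None) == counts_dtype_cuda_kernel(None, "Float")
```

Note that the Test Plan above exercises exactly the `running_mean=None` path with a float32 input, where the two fallbacks coincide.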

@facebook-github-bot (Contributor) commented Jan 6, 2021

💊 CI failures summary and remediations

As of commit 4e5308c (more details on the Dr. CI page):


  • 1/2 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)
  • 1/2 broken upstream at merge base c517e15 on Jan 06 from 7:16am to 1:33pm

1 job timed out:

  • pytorch_linux_bionic_py3_8_gcc9_coverage_test1

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.



@facebook-github-bot (Contributor) left a comment

@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

aten/src/ATen/native/cuda/Normalization.cu (review thread resolved)
torch/nn/modules/_functions.py (outdated; review thread resolved)
@ngimel (Collaborator) commented Jan 6, 2021

cc @jjsjann123 FYI. Is it true that the counts type should always match the mean/running_mean type, and that the mean and running_mean types should be the same whenever running_mean is defined?

@malfet force-pushed the malfet/fix-SyncBatchNorm-without-stats-tracking branch from 908e570 to 2d29fd1 on January 6, 2021 at 15:55
@facebook-github-bot (Contributor) left a comment

@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@codecov bot commented Jan 6, 2021

Codecov Report

Merging #50126 (2d29fd1) into master (2ac180a) will increase coverage by 10.43%.
The diff coverage is 0.00%.

@@             Coverage Diff             @@
##           master   #50126       +/-   ##
===========================================
+ Coverage   70.25%   80.68%   +10.43%     
===========================================
  Files        1900     1900               
  Lines      206246   206246               
===========================================
+ Hits       144894   166408    +21514     
+ Misses      61352    39838    -21514     

@facebook-github-bot (Contributor) left a comment

@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@malfet merged this pull request in bf4fcab.

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 14, 2021
Summary:
In `batch_norm_gather_stats_with_counts_cuda` use `input.scalar_type()` if `running_mean` is not defined
In `SyncBatchNorm` forward function create count tensor with `torch.float32` type if `running_mean` is None
Fix a few typos

Pull Request resolved: pytorch#50126

Test Plan:
```
python -c "import torch; print(torch.batch_norm_gather_stats_with_counts(torch.randn(1, 3, 3, 3, device='cuda'), mean=torch.ones(2, 3, device='cuda'), invstd=torch.ones(2, 3, device='cuda'), running_mean=None, running_var=None, momentum=0.1, eps=1e-5, counts=torch.ones(2, device='cuda')))"
```

Fixes pytorch#49730

Reviewed By: ngimel

Differential Revision: D25797930

Pulled By: malfet

fbshipit-source-id: 22a91e3969b5e9bbb7969d9cc70b45013a42fe83
@rangwani-harsh commented:

Hi @malfet, @ngimel: it seems this still fails with track_running_stats=False when doing mixed-precision training (with DistributedDataParallel).

Version details: torch==1.10.1, CUDA 11.3

With track_running_stats=False, I get the following stack trace:

  File "/home/auk/miniconda3/envs/torch1.10/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/auk/Naman/PyTorch-StudioGAN/src/utils/model_ops.py", line 130, in forward
    out = self.bn(x)
  File "/home/auk/miniconda3/envs/torch1.10/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/auk/miniconda3/envs/torch1.10/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 749, in forward
    return sync_batch_norm.apply(
  File "/home/auk/miniconda3/envs/torch1.10/lib/python3.9/site-packages/torch/nn/modules/_functions.py", line 59, in forward
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
RuntimeError: Expected counts to have type Half but got Float

Can you please take a look (or should I create a new issue)?
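For what it's worth, the two fallbacks this PR describes (`input.scalar_type()` in the CUDA kernel, `torch.float32` in the Python wrapper) disagree when the input is half precision, which would produce exactly the error in the trace above. A torch-free sketch of that mismatch, with hypothetical function names that merely model the check, not PyTorch internals:

```python
# Hedged sketch of how a Half/Float counts mismatch could arise under
# mixed precision with track_running_stats=False. Function names are
# illustrative only.

def expected_counts_dtype(running_mean_dtype, input_dtype):
    # Kernel side: counts must match running_mean's dtype, or the
    # input's dtype when running_mean is not defined.
    return running_mean_dtype if running_mean_dtype is not None else input_dtype

def check_counts(running_mean_dtype, input_dtype, counts_dtype):
    expected = expected_counts_dtype(running_mean_dtype, input_dtype)
    if counts_dtype != expected:
        raise RuntimeError(
            f"Expected counts to have type {expected} but got {counts_dtype}")

# track_running_stats=False under autocast: the input (and hence mean)
# is Half, but the wrapper created counts as float32.
try:
    check_counts(None, "Half", "Float")
except RuntimeError as e:
    print(e)  # prints: Expected counts to have type Half but got Float
```

Under this reading, the fp32 fallback in the Python wrapper and the input-dtype fallback in the kernel only agree for fp32 inputs, so a fresh issue for the mixed-precision case seems warranted.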

Successfully merging this pull request may close these issues:

Program throws exception when using SyncBatchNorm with track_running_stats = False
5 participants