Add support for non-affine batch norm with float stats and half inputs #22750

Closed · wants to merge 13 commits

Conversation

@ptrblck (Collaborator) commented Jul 11, 2019

This PR adds support for non-affine batch norm with float running estimates and half inputs.
Changes were made similar to #16735.

I couldn't find a specific test for SyncBatchNorm, so I used this code (https://gist.github.com/ptrblck/ab45bfcde6df55ac28a7be18531f4718) to test it.
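For reference, a minimal sketch of the case this PR enables (module choice, shapes, and the printed dtypes are illustrative, not taken from the linked gist):

```python
import torch
import torch.nn as nn

# Non-affine batch norm: no learnable weight/bias, but running stats are tracked.
# The running estimates stay in float32 while the input and output are float16.
bn = nn.BatchNorm2d(3, affine=False, track_running_stats=True).cuda()
x = torch.randn(8, 3, 16, 16, device='cuda', dtype=torch.half)

out = bn(x)
print(out.dtype)              # torch.float16
print(bn.running_mean.dtype)  # torch.float32
```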

cc @ngimel

@pytorchbot added the module: cuda, module: nn, and module: operators labels on Jul 11, 2019
@ptrblck (Collaborator Author) commented Jul 11, 2019

CC @jjsjann123

@ptrblck (Collaborator Author) commented Jul 11, 2019

Lint issue seems to come from here:
#21323 (comment)

@jjsjann123 (Collaborator) commented:

Great work! Looks like we are doing the right thing everywhere.

Thanks very much for taking care of this!

@ptrblck (Collaborator Author) commented Jul 12, 2019

Thanks for the review @jjsjann123!

@ptrblck (Collaborator Author) commented Jul 12, 2019

Please don't merge it yet, as I would like to run some additional tests first.

@jerryzh168 added the triaged label on Jul 13, 2019
@jerryzh168 requested a review from li-roy on July 13, 2019 at 00:55
@ptrblck (Collaborator Author) commented Jul 16, 2019

Sorry for blocking this PR.
I've added support for backward calls, so that it can be reviewed now.

Thanks @jjsjann123 for the support! 🙂
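A quick sketch of the backward path mentioned above (shapes and the reduction are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn

# Backward through non-affine batch norm with float32 running stats and a
# half input: the input gradient now comes back in fp16 as well.
bn = nn.BatchNorm2d(3, affine=False).cuda()
x = torch.randn(4, 3, 8, 8, device='cuda', dtype=torch.half, requires_grad=True)

bn(x).sum().backward()
print(x.grad.dtype)  # torch.float16
```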

scalar_t inp = input[batch][plane][x];
accscalar_t proj = (inp - mean) * proj_scale;
grad_input[batch][plane][x] = static_cast<scalar_t>((go - proj - grad_mean) * grad_scale);
input_scalar_t inp = input[batch][plane][x];
Review comment (Collaborator):

Shouldn't inp be stat_accscalar_t, since it is cast right after?
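For context on why the kernels carry a separate fp32 accumulation type at all, a standalone illustration (not code from this diff):

```python
import torch

# fp16 tops out around 65504 and its spacing grows with magnitude, so
# per-channel sums of many activations can overflow or stop growing if kept
# in fp16. The kernels therefore accumulate statistics in a float32
# "accscalar" type and only cast back at the end.
print(torch.tensor(204800.0).half())                 # tensor(inf, dtype=torch.float16)
print(torch.tensor(2048.0, dtype=torch.half) + 1.0)  # tensor(2048., dtype=torch.float16)
```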

stat_accscalar_t m_c = mean[plane];
stat_accscalar_t m_dy_c = mean_dy[plane];
stat_accscalar_t factor_1_c = invstd[plane];
stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
Review comment (Collaborator):

The last one could just be stat_accscalar_t(1)

@@ -765,22 +765,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_templ
mean_dy_ = at::empty_like(mean_);
mean_dy_xmu_ = at::empty_like(mean_);
}
- auto grad_options = grad_out_.options();
+ auto grad_options = mean_.options();
Review comment (Collaborator):

Not sure if we discussed it here before.
IIRC, mean_ is passed in here as calculated by batch_norm_gather_stats_with_counts:

mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
grad_output,
saved_input,
mean,
invstd,
self.needs_input_grad[0],
self.needs_input_grad[1],
self.needs_input_grad[2]
)

If we have input and weight both in fp16, will this break?
We might have to pass in an optional variable (either weight or bias) here to make the decision; it cannot be deduced from the given information.
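To make the concern concrete at the module level, a sketch of the scenario under discussion (this uses the public nn.BatchNorm2d rather than the internal torch.batch_norm_backward_reduce call):

```python
import torch
import torch.nn as nn

# An fp16 layer with an fp16 input: grad_weight and grad_bias come back in the
# weight dtype (fp16). In the SyncBatchNorm path the gathered mean_ is float32
# after this PR, so allocating the grad buffers from mean_.options() would give
# them the wrong dtype; they need to follow the (optional) weight instead.
bn = nn.BatchNorm2d(3).cuda().half()   # fp16 weight, bias, and running stats
x = torch.randn(4, 3, 8, 8, device='cuda', dtype=torch.half, requires_grad=True)

bn(x).sum().backward()
print(bn.weight.grad.dtype)  # torch.float16
print(x.grad.dtype)          # torch.float16
```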

Review comment (Collaborator):

But I could imagine that we must have an fp16 layer with fp16 input. If the unit test is fine, I can take another look at this with a clearer head tomorrow morning :)

Reply (Collaborator Author):

This makes sense, and I'm now passing the weight tensor to this method. Since it can be None, I'm now calling weight_.options() for grad_weight_ and grad_bias_ separately.

Review comment (Collaborator):

Looks good. Thanks for the hard work.

@jjsjann123 (Collaborator) commented:

Looks good to me except for some code cleanup and that grad_options thing I mentioned there.

@soumith (Member) commented Jul 22, 2019

@pytorchbot rebase this please

@ptrblck (Collaborator Author) commented Aug 2, 2019

Any pointers on this failing test:

00:53:35 FAIL: test_wrong_cuda_fork (__main__.TestMultiprocessing)
00:53:35 ----------------------------------------------------------------------
00:53:35 Traceback (most recent call last):
00:53:35   File "test_multiprocessing.py", line 496, in test_wrong_cuda_fork
00:53:35     you must use the 'spawn' start method")
00:53:35 AssertionError: Regex didn't match: "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method" not found in 'Traceback (most recent call last):\r\n  File "<string>", line 1, in <module>\r\n  File "C:\\Jenkins\\Miniconda3\\lib\\multiprocessing\\spawn.py", line 105, in spawn_main\r\n    exitcode = _main(fd)\r\n  File "C:\\Jenkins\\Miniconda3\\lib\\multiprocessing\\spawn.py", line 115, in _main\r\n    self = reduction.pickle.load(from_parent)\r\nAttributeError: Can\'t get attribute \'run\' on <module \'__main__\' (built-in)>\r\nTraceback (most recent call last):\r\n  File "<string>", line 1, in <module>\r\n  File "C:\\Jenkins\\Miniconda3\\lib\\multiprocessing\\spawn.py", line 105, in spawn_main\r\n    exitcode = _main(fd)\r\n  File "C:\\Jenkins\\Miniconda3\\lib\\multiprocessing\\spawn.py", line 115, in _main\r\n    self = reduction.pickle.load(from_parent)\r\nAttributeError: Can\'t get attribute \'run\' on <module \'__main__\' (built-in)>\r\n'

Is this a valid failure?

@ptrblck (Collaborator Author) commented Aug 6, 2019

@pytorchbot retest this please

@ptrblck (Collaborator Author) commented Aug 7, 2019

@pytorchbot retest this please

@ptrblck (Collaborator Author) commented Aug 7, 2019

@pytorchbot rebase this please

@facebook-github-bot (Contributor) left a comment:

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 29, 2019
Add support for non-affine batch norm with float stats and half inputs (#22750)

Summary:
This PR adds support for non-affine batch norm with float running estimates and half inputs.
Changes were made similar to pytorch/pytorch#16735.

I couldn't find a specific test for `SyncBatchNorm`, so I used [this code](https://gist.github.com/ptrblck/ab45bfcde6df55ac28a7be18531f4718) to test it.

cc ngimel
Pull Request resolved: pytorch/pytorch#22750

Differential Revision: D17119965

Pulled By: ezyang

fbshipit-source-id: 2e8c5d63fc3c636b8a1338c43c9c101a0f5e9b22
@facebook-github-bot (Contributor) commented:

@ezyang merged this pull request in 8640aef.

Labels: Merged · module: cuda (Related to torch.cuda, and CUDA support in general) · module: nn (Related to torch.nn) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
8 participants