Add support for non-affine batch norm with float stats and half inputs #22750
Conversation
CC @jjsjann123
Lint issue seems to come from here:
Great work! Looks like we are doing the right thing everywhere. Thanks very much for taking care of this!
Thanks for the review @jjsjann123!
Please don't merge it yet, as I would like to run some additional tests first.
Sorry for blocking this PR. Thanks @jjsjann123 for the support! 🙂
scalar_t inp = input[batch][plane][x];
accscalar_t proj = (inp - mean) * proj_scale;
grad_input[batch][plane][x] = static_cast<scalar_t>((go - proj - grad_mean) * grad_scale);
input_scalar_t inp = input[batch][plane][x];
Shouldn't inp be stat_accscalar_t? Since it is cast right after.
stat_accscalar_t m_c = mean[plane];
stat_accscalar_t m_dy_c = mean_dy[plane];
stat_accscalar_t factor_1_c = invstd[plane];
stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
The last one could just be stat_accscalar_t(1)
@@ -765,22 +765,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_templ
     mean_dy_ = at::empty_like(mean_);
     mean_dy_xmu_ = at::empty_like(mean_);
   }
-  auto grad_options = grad_out_.options();
+  auto grad_options = mean_.options();
Not sure if we discussed it here before. IIRC, mean_ is passed here as calculated from batch_norm_gather_stats_with_counts:
pytorch/torch/nn/modules/_functions.py, lines 62 to 70 in b93f29d
mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
    grad_output,
    saved_input,
    mean,
    invstd,
    self.needs_input_grad[0],
    self.needs_input_grad[1],
    self.needs_input_grad[2]
)
If we have input and weight both in fp16, will this break? We might have to pass in an optional variable (either weight or bias) here to make the decision; it cannot be deduced from the given information.
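For context, a minimal dtype sketch of the mix being described here (an illustration only, not code from the PR or the test gist; assumes a CUDA device):

import torch

# fp16 layer with fp16 input, but statistics gathered/accumulated in fp32:
x = torch.randn(8, 4, 16, 16, device="cuda", dtype=torch.half)   # fp16 input
weight = torch.ones(4, device="cuda", dtype=torch.half)          # fp16 affine weight
mean = x.float().mean(dim=(0, 2, 3))                             # stats computed in fp32
print(mean.dtype, weight.dtype)  # torch.float32 vs. torch.float16, so grad_weight's
                                 # dtype cannot be read off mean_ alone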
But I could imagine that we must have an fp16 layer with fp16 input. If the unit test is fine, I can take another look at this with a clearer head tomorrow morning :)
This makes sense, and I'm now passing the weight tensor to this method. Since it can be None, I'm now calling weight_.options() for grad_weight_ and grad_bias_ separately.
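A rough Python analogue of that allocation choice (a sketch of the idea only, with a hypothetical helper name; the actual change lives in the C++ batch_norm_backward_reduce_cuda_template):

import torch

def make_reduce_buffers(mean_, weight_):
    # Stats-shaped outputs follow mean_'s dtype/device; the parameter gradients
    # follow weight_'s, which may differ (e.g. fp32 stats with an fp16 layer).
    mean_dy_ = torch.empty_like(mean_)
    mean_dy_xmu_ = torch.empty_like(mean_)
    if weight_ is not None:                      # weight can be None (non-affine case)
        grad_weight_ = torch.empty_like(weight_)
        grad_bias_ = torch.empty_like(weight_)
    else:
        grad_weight_, grad_bias_ = None, None
    return mean_dy_, mean_dy_xmu_, grad_weight_, grad_bias_

# Example: fp32 gathered stats, fp16 affine weight.
buffers = make_reduce_buffers(torch.zeros(4), torch.ones(4, dtype=torch.half))
print([b.dtype if b is not None else None for b in buffers])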
Looks good. Thanks for the hard work.
Looks good to me except for some code cleaning and that
…duce_cuda_template, add requested changes
@pytorchbot rebase this please
Any pointers on this failing test:
Is this a valid failure?
@pytorchbot retest this please
@pytorchbot retest this please
@pytorchbot rebase this please
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…s (#22750)
Summary: This PR creates support for non-affine batch norm with float running estimates and half inputs. Changes were made similar to pytorch/pytorch#16735. I couldn't find a specific test for `SyncBatchNorm`, so I used [this code](https://gist.github.com/ptrblck/ab45bfcde6df55ac28a7be18531f4718) to test it. cc ngimel
Pull Request resolved: pytorch/pytorch#22750
Differential Revision: D17119965
Pulled By: ezyang
fbshipit-source-id: 2e8c5d63fc3c636b8a1338c43c9c101a0f5e9b22
This PR creates support for non-affine batch norm with float running estimates and half inputs.
Changes were made similar to #16735.
I couldn't find a specific test for SyncBatchNorm, so I used this code to test it.
cc @ngimel
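A rough smoke test for the case this PR enables (a sketch, not the linked gist, which exercised SyncBatchNorm under DistributedDataParallel; assumes a CUDA device and a PyTorch build that includes this change and #16735):

import torch
import torch.nn as nn

# Non-affine batch norm: fp32 running estimates, fp16 input and gradients.
bn = nn.BatchNorm2d(4, affine=False).cuda()          # running_mean/var stay float32
x = torch.randn(8, 4, 16, 16, device="cuda", dtype=torch.half, requires_grad=True)

out = bn(x)
out.sum().backward()
print(out.dtype, bn.running_mean.dtype, x.grad.dtype)  # fp16, fp32, fp16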