-
Notifications
You must be signed in to change notification settings - Fork 22.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement BatchNorm double backwards #2207
Conversation
…ctly from C++. This will be converted to C++ code once ATen is integrated with autograd.
The math isn't that clean here; I'm sure it can be simplified and made more efficient (as one example, I didn't use in-place functions). |
all_inputs.push_back(input_var); | ||
if (weight.get()) { | ||
all_inputs.push_back(weight_var); | ||
all_inputs.push_back(bias_var); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
def sum_exclude_dim1(to_sum, keepdim=True): | ||
to_sum = to_sum.sum(dim=0, keepdim=True) | ||
dim = 2 |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
# define some terms we will reuse | ||
M = reduce(mul, input.size()[0:1] + input.size()[2:]) | ||
mu = sum_exclude_dim1(input).div_(M) | ||
input_sub_mu = input - mu |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
gG = None | ||
if affine and ggI is not None: | ||
# gG is just the first backwards with the gamma term removed (then shaped properly) | ||
gG = ggI * first_back_grad_input(gO, 1) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
updated for @albanD's suggestions. |
woohoo! thanks Greg. |
* commit '8262920b72374b1d9643f35057663ab02ab20330': (272 commits) Add ATen overload to AutoGPU. (pytorch#2234) Add comments for default value (pytorch#2242) Remove dead THPP code that has been replaced with ATen objects. (pytorch#2235) fix a bug where an uninitialized at::Tensor was passed to createPyObject (pytorch#2239) Replace thpp::Tensor with ATen Tensor in autograd csrc (pytorch#2170) Added aarch64 support (pytorch#2226) Increase tol. for float tensor qr big test. Improve Variable.retain_grad add `retain_grad` method, to variable, so gradient gets stored during backpop, on non-user variables Implement BatchNorm double backwards (pytorch#2207) [bugfix] in bce_with_logits logsumexp calculation (pytorch#2221) fix for ATen API Change Opt into Trusty builds. (pytorch#2214) allow retain to be specified for unsafeTensorFromTH Deduplicate THPUtils_checkLong/THPUtils_unpackLong (pytorch#2218) fix osx build errors related to long/int64_t Note [Undefined-dim versus 0-dim] Remove __func__ hack in auto nn. Enable Conv groups gradgradchecks. (pytorch#2216) fix a bug where some scalars were getting truncated to integers incorrectly. ...
* commit '8262920b72374b1d9643f35057663ab02ab20330': (272 commits) Add ATen overload to AutoGPU. (pytorch#2234) Add comments for default value (pytorch#2242) Remove dead THPP code that has been replaced with ATen objects. (pytorch#2235) fix a bug where an uninitialized at::Tensor was passed to createPyObject (pytorch#2239) Replace thpp::Tensor with ATen Tensor in autograd csrc (pytorch#2170) Added aarch64 support (pytorch#2226) Increase tol. for float tensor qr big test. Improve Variable.retain_grad add `retain_grad` method, to variable, so gradient gets stored during backpop, on non-user variables Implement BatchNorm double backwards (pytorch#2207) [bugfix] in bce_with_logits logsumexp calculation (pytorch#2221) fix for ATen API Change Opt into Trusty builds. (pytorch#2214) allow retain to be specified for unsafeTensorFromTH Deduplicate THPUtils_checkLong/THPUtils_unpackLong (pytorch#2218) fix osx build errors related to long/int64_t Note [Undefined-dim versus 0-dim] Remove __func__ hack in auto nn. Enable Conv groups gradgradchecks. (pytorch#2216) fix a bug where some scalars were getting truncated to integers incorrectly. ...
* update comments in segment_reduction_op * Update segment_reduction_op.cc
The double backwards function is currently implemented as a python function called direction from C++, but the plan is to convert it to C++ code once ATen is integrated with autograd.