Commit

reduce computation of batch_norm when weight or bias is none
[ghstack-poisoned]
XiaobingSuper committed Jul 5, 2023
1 parent 12ca224 commit 1cf15bb
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions torch/_decomp/decompositions.py
@@ -1434,19 +1434,16 @@ def native_batch_norm_helper(
     invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
     output = (input - mean) * invstd
 
-    if weight is None:
-        weight = input.new_ones(())
-    else:
+    if weight is not None:
         weight = weight.flatten()
+        weight = _unsqueeze_to_dim(weight, input.dim() - 1)
+        output = output * weight
 
-    if bias is None:
-        bias = input.new_zeros(())
-    else:
+    if bias is not None:
         bias = bias.flatten()
+        bias = _unsqueeze_to_dim(bias, input.dim() - 1)
+        output = output + bias
 
-    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
-    bias = _unsqueeze_to_dim(bias, input.dim() - 1)
-    output = output * weight + bias
     if input.device.type == "cpu":
         save_mean = save_mean.to(dtype=input.dtype)
         save_rstd = save_rstd.to(dtype=input.dtype)
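For illustration, below is a minimal, self-contained sketch (not the actual decomposition) of the normalize/scale/shift control flow after this change: when weight or bias is None, the corresponding multiply or add is skipped entirely instead of materializing a ones/zeros tensor and applying it. The names _unsqueeze_to_dim_sketch and normalize_scale_shift are hypothetical stand-ins introduced here; the comparison assumes eval-mode batch_norm semantics with precomputed statistics.

import torch

def _unsqueeze_to_dim_sketch(x, dim):
    # Hypothetical stand-in mirroring _unsqueeze_to_dim: append trailing
    # singleton dimensions until x has `dim` dimensions.
    while x.dim() < dim:
        x = x.unsqueeze(-1)
    return x

def normalize_scale_shift(input, mean, invstd, weight=None, bias=None):
    # Normalize, then apply scale/shift only when weight/bias are provided,
    # mirroring the control flow introduced by this commit.
    mean = _unsqueeze_to_dim_sketch(mean, input.dim() - 1)
    invstd = _unsqueeze_to_dim_sketch(invstd, input.dim() - 1)
    output = (input - mean) * invstd
    if weight is not None:
        output = output * _unsqueeze_to_dim_sketch(weight.flatten(), input.dim() - 1)
    if bias is not None:
        output = output + _unsqueeze_to_dim_sketch(bias.flatten(), input.dim() - 1)
    return output

# Quick sanity check against eager batch_norm with weight=None and bias=None.
x = torch.randn(2, 3, 4, 4)
mean = x.mean(dim=(0, 2, 3))
var = x.var(dim=(0, 2, 3), unbiased=False)
invstd = torch.rsqrt(var + 1e-5)
ref = torch.nn.functional.batch_norm(
    x, mean, var, weight=None, bias=None, training=False, eps=1e-5
)
print(torch.allclose(normalize_scale_shift(x, mean, invstd), ref, atol=1e-6))

With weight=None and bias=None, the sketch performs only the subtract and multiply; the extra elementwise multiply by ones and add of zeros is the work this commit avoids.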

0 comments on commit 1cf15bb
