Add decomposition for aten.native_batch_norm_backward op #675

Open · wants to merge 1 commit into main
55 changes: 55 additions & 0 deletions functorch/_src/decompositions.py
@@ -497,6 +497,61 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
    return (d_input, d_weight, d_bias)


@register_decomposition(aten.native_batch_norm_backward)
def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor],
                               running_mean: Optional[Tensor], running_var: Optional[Tensor],
                               save_mean: Optional[Tensor], save_invstd: Optional[Tensor],
                               train: bool, eps: float,
                               output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    # Number of elements reduced over per channel (N * H * W * ... for NCHW input)
    num_features = prod(input_shape) / input_shape[axis]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert save_mean is not None and save_invstd is not None, \
            "when train=True, save_mean and save_invstd are required"
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    # Shape used to broadcast per-channel statistics along the channel axis
    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    # Reduce over every dimension except the channel axis
    reduction_axes = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    grad_scale = None
    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
    grad_input = None
    if train:
        # In training mode the batch statistics depend on the input, so the
        # projection onto the centered input and the mean gradient are subtracted.
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    grad_weight = None
    if output_mask[1]:
        grad_weight = dot_p * invstd

    grad_bias = None
    if output_mask[2]:
        grad_bias = grad_output_sum
    return (grad_input, grad_weight, grad_bias)


@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    return torch.clamp(self, min=min)
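
For reference, a minimal sketch (not part of this diff) of how the decomposition could be sanity-checked against the eager aten kernel in training mode. The import path, test shapes, and use of torch.testing.assert_close are assumptions, not taken from the PR.

import torch
# Assumed import path for the function added in this PR
from functorch._src.decompositions import native_batch_norm_backward

x = torch.randn(4, 3, 8, 8)
weight = torch.randn(3)
grad_out = torch.randn_like(x)
eps = 1e-5

# Forward pass without running stats to obtain save_mean / save_invstd
_, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, None, None, None, True, 0.1, eps)

args = (grad_out, x, weight, None, None, save_mean, save_invstd,
        True, eps, [True, True, True])
ref = torch.ops.aten.native_batch_norm_backward(*args)
dec = native_batch_norm_backward(*args)

# Compare grad_input, grad_weight, and grad_bias against the eager kernel
for r, d in zip(ref, dec):
    torch.testing.assert_close(r, d)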