Add forward AD formulas for {batch,layer,group}_norm (#70355)
Summary: Pull Request resolved: #70355

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33405362

Pulled By: soulitzer

fbshipit-source-id: 55a92e88a04e7b15a0a223025d66c14f7db2a190
soulitzer authored and facebook-github-bot committed Jan 10, 2022
1 parent 7a08030 commit 78994d1
Showing 4 changed files with 283 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1084,6 +1084,7 @@

- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)

- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
@@ -1092,6 +1093,7 @@

- name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape)

- name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
@@ -1101,6 +1103,9 @@

- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
result1: group_norm_mean_jvp(input_t, result1, group)
result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group)

- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
@@ -2381,6 +2386,7 @@
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
@@ -2428,6 +2434,7 @@

- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
save_mean: not_implemented("miopen_batch_norm_backward save_mean")
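The new result0/result1/result2 lines above are what forward-mode AD evaluates when a dual tensor reaches these ops. A minimal Python sketch of exercising the batch_norm entry and cross-checking it against a central finite difference; the shapes, tangent, and tolerance are illustrative rather than part of this PR, and it assumes a build that includes this change:

import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(4, 3, 8, 8, dtype=torch.double)
x_t = torch.randn_like(x)  # tangent: the direction in which we differentiate

with fwAD.dual_level():
    # training=True with no running stats routes to native_batch_norm,
    # whose new result0 formula computes the output tangent
    out = F.batch_norm(fwAD.make_dual(x, x_t), None, None, training=True)
    _, out_t = fwAD.unpack_dual(out)

h = 1e-6  # central finite-difference cross-check
fd = (F.batch_norm(x + h * x_t, None, None, training=True)
      - F.batch_norm(x - h * x_t, None, None, training=True)) / (2 * h)
print(torch.allclose(out_t, fd, atol=1e-5))  # expected: True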
223 changes: 223 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -4569,6 +4569,229 @@ Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) {
}
}

// Helper for {batch,layer,group}_norms below
// Computes the jvp for `1 / input.std(dims, keepdim)`
static Tensor _invstd_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
IntArrayRef dims, int64_t numel, bool keepdim) {
Tensor invstd_t;
if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) {
invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) * (input_p - mean_p);
} else {
invstd_t = input_t - input_t.mean(dims, true);
invstd_t *= input_p - mean_p;
invstd_t *= -invstd_p.pow(3);
}
invstd_t = invstd_t.sum(dims, keepdim);
invstd_t /= numel;
return invstd_t;
}
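
In other words, with invstd = 1/std the tangent is -invstd**3 times the per-`dims` mean of (input_t - mean(input_t)) * (input_p - mean_p). A small numerical sketch of that identity over the last dimension; the shapes and tolerance are illustrative:

import torch

x = torch.randn(4, 8, dtype=torch.double)   # primal
t = torch.randn_like(x)                     # tangent
h = 1e-6

invstd = lambda v: v.var(-1, unbiased=False, keepdim=True).rsqrt()  # 1 / biased std over the last dim

fd = (invstd(x + h * t) - invstd(x - h * t)) / (2 * h)              # finite-difference directional derivative

mean_p, invstd_p = x.mean(-1, keepdim=True), invstd(x)
jvp = (-invstd_p.pow(3) * (t - t.mean(-1, keepdim=True)) * (x - mean_p)).mean(-1, keepdim=True)
print(torch.allclose(fd, jvp, atol=1e-6))   # expected: True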

// Helper for {batch,layer,group}_norms below only
// Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)`
static Tensor _norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
IntArrayRef dims, int64_t numel) {
auto invstd_t = _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true);
Tensor result_t;
if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) {
result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;
} else {
result_t = input_t - input_t.mean(dims, true);
result_t *= invstd_p;
auto temp = input_p - mean_p;
temp *= invstd_t;
result_t += temp;
}
return result_t;
}
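
This is the product rule applied to (input - mean) * invstd: one term differentiates the centered input, the other reuses the invstd tangent from the helper above. A quick sketch checking the combined formula against a finite difference (illustrative shapes again):

import torch

x = torch.randn(4, 8, dtype=torch.double)
t = torch.randn_like(x)
h = 1e-6

def normalize(v):                            # (v - mean) * invstd over the last dim
    mu = v.mean(-1, keepdim=True)
    return (v - mu) * v.var(-1, unbiased=False, keepdim=True).rsqrt()

fd = (normalize(x + h * t) - normalize(x - h * t)) / (2 * h)

mean_p = x.mean(-1, keepdim=True)
invstd_p = x.var(-1, unbiased=False, keepdim=True).rsqrt()
invstd_t = (-invstd_p.pow(3) * (t - t.mean(-1, keepdim=True)) * (x - mean_p)).mean(-1, keepdim=True)
jvp = (t - t.mean(-1, keepdim=True)) * invstd_p + (x - mean_p) * invstd_t
print(torch.allclose(fd, jvp, atol=1e-6))    # expected: True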

// Helper for {batch,layer,group}_norms below only
// Computes the jvp for `input * weight + bias` where weight and bias may be undefined
// Possibly modifies the input inplace
static Tensor _affine_jvp(
const c10::optional<Tensor>& input_p, Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_t) {
// We allow input_p to be optional because if weight_p isn't defined,
// it may be possible to avoid computing input_p
TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
if (weight_p.defined()) {
if (areAnyTensorSubclassLike({input_p.value(), input_t, weight_p, weight_t}) || input_t._is_zerotensor() || weight_t._is_zerotensor()) {
input_t = input_t * weight_p + input_p.value() * weight_t;
} else {
input_t *= weight_p;
auto temp = input_p.value();
temp *= weight_t;
input_t += temp;
}
}
if (bias_t.defined()) {
if (areAnyTensorSubclassLike({input_t, bias_t}) || input_t._is_zerotensor()) {
input_t = input_t + bias_t;
} else {
input_t += bias_t;
}
}
return input_t;
}
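
The affine step itself is just the product rule on input * weight + bias, i.e. input_t * weight_p + input_p * weight_t + bias_t; the in-place branch is only an allocation-saving fast path when no tensor subclasses or zero-tensors are involved. A tiny sketch of that formula with illustrative shapes:

import torch

x_p, x_t = torch.randn(2, 5, dtype=torch.double), torch.randn(2, 5, dtype=torch.double)
w_p, w_t = torch.randn(5, dtype=torch.double), torch.randn(5, dtype=torch.double)
b_p, b_t = torch.randn(5, dtype=torch.double), torch.randn(5, dtype=torch.double)

h = 1e-6
affine = lambda x, w, b: x * w + b
fd = (affine(x_p + h * x_t, w_p + h * w_t, b_p + h * b_t)
      - affine(x_p - h * x_t, w_p - h * w_t, b_p - h * b_t)) / (2 * h)
jvp = x_t * w_p + x_p * w_t + b_t            # what _affine_jvp computes when all inputs are defined
print(torch.allclose(fd, jvp, atol=1e-6))    # expected: True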

Tensor batch_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_invstd,
bool train,
double eps) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
int64_t numel = 1;
for (const auto dim : c10::irange(view_size.size())) {
if (dim != 1) {
numel *= input_t.size(dim);
view_size[dim] = 1;
dims.push_back(dim);
}
}
Tensor mean_p;
Tensor invstd_p;
Tensor result_t;
if (train) {
mean_p = saved_mean.view(view_size);
invstd_p = saved_invstd.view(view_size);
result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
} else {
TORCH_INTERNAL_ASSERT(
running_mean.has_value() && running_var.has_value(),
"Expect running_mean and running_var to have value when train=false");
mean_p = running_mean.value().view(view_size);
invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
result_t = input_t * invstd_p;
}

c10::optional<Tensor> result_p = weight_p.defined()
? c10::optional<Tensor>((input_p - mean_p) * invstd_p) : c10::nullopt;
return _affine_jvp(
result_p, result_t,
weight_p.defined() ? weight_p.view(view_size) : weight_p,
weight_t.defined() ? weight_t.view(view_size) : weight_t,
bias_t.defined() ? bias_t.view(view_size) : bias_t);
}
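
In eval mode (train = false) the running statistics are constants with respect to the input, so the normalization part of the tangent collapses to input_t * invstd with invstd = 1/sqrt(running_var + eps), as in the else branch above. A minimal sketch of that special case; the shapes and values are illustrative and it assumes a build with this change:

import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(4, 3, 8, 8, dtype=torch.double)
x_t = torch.randn_like(x)
rm = torch.randn(3, dtype=torch.double)                  # running mean
rv = torch.rand(3, dtype=torch.double) + 0.5             # running variance (kept positive)
eps = 1e-5

with fwAD.dual_level():
    out = F.batch_norm(fwAD.make_dual(x, x_t), rm, rv, training=False, eps=eps)
    _, out_t = fwAD.unpack_dual(out)

expected = x_t * (rv + eps).rsqrt().view(1, 3, 1, 1)     # input_t * invstd, broadcast over channels
print(torch.allclose(out_t, expected))                   # expected: True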

Tensor batch_norm_jvp_saved_var(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_var,
bool train,
double eps) {
auto saved_invstd = (1 / at::sqrt(saved_var + at::Scalar(eps)));
return batch_norm_jvp(
input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var,
saved_mean, saved_invstd, train, eps);
}

Tensor layer_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();

int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(i);
view_size[i] = 1;
dims.push_back(i);
}
}
auto mean_p = saved_mean.view(view_size);
auto invstd_p = saved_invstd.view(view_size);
auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);

c10::optional<Tensor> result_p = weight_p.defined()
? c10::optional<Tensor>((input_p - mean_p) * invstd_p) : c10::nullopt;
return _affine_jvp(
result_p, result_t,
weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
}
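
The two view shapes handle different broadcasts: view_size keeps the leading (batch-like) dims and collapses the normalized dims, so the per-sample mean/rstd broadcast against the input, while view_size_affine does the opposite for weight/bias, which live on the normalized dims. A short sketch of the shape bookkeeping for a hypothetical (8, 3, 16, 16) input normalized over its last three dims:

input_shape      = [8, 3, 16, 16]
normalized_shape = [3, 16, 16]

lead = len(input_shape) - len(normalized_shape)
view_size        = input_shape[:lead] + [1] * len(normalized_shape)   # [8, 1, 1, 1]   -> mean / rstd
view_size_affine = [1] * lead + list(normalized_shape)                # [1, 3, 16, 16] -> weight / bias
print(view_size, view_size_affine)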

Tensor group_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
int64_t groups) {
auto input_shape = input_p.sizes();
int64_t N = input_p.size(0);
int64_t C = input_p.size(1);

auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});

auto result_t = batch_norm_jvp(
input_p_reshaped, input_t_reshaped,
/*weight_p=*/{}, /*weight_t=*/{},
/*bias_p=*/{}, /*bias_t=*/{},
/*running_mean=*/{}, /*running_var=*/{},
saved_mean, saved_invstd, /*train=*/true, /*eps=*/0).view(input_shape);

c10::optional<Tensor> result_p = c10::nullopt;
if (weight_p.defined()) {
std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
view_size[1] = input_t_reshaped.size(1);
result_p = ((input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size)).view(input_shape);
}
std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
affine_param_shape[1] = C;

return _affine_jvp(
result_p, result_t,
weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p,
weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t,
bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t);
}
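
group_norm_jvp reuses batch_norm_jvp by viewing the (N, C, ...) input as a (1, N*groups, -1) batch, mirroring how the group norm forward itself is computed. A short sketch of that equivalence for the forward pass; the shapes, eps=0, and default tolerance are illustrative:

import torch
import torch.nn.functional as F

N, C, H, W, G = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W, dtype=torch.double)

ref = F.group_norm(x, G, eps=0.0)

xr = x.view(1, N * G, -1)                                   # the same reshape used above
mean = xr.mean(-1, keepdim=True)
invstd = xr.var(-1, unbiased=False, keepdim=True).rsqrt()
out = ((xr - mean) * invstd).view(N, C, H, W)
print(torch.allclose(ref, out))                             # expected: True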

Tensor group_norm_mean_jvp(
const Tensor& input_t, const Tensor& mean_p, int64_t groups) {
int64_t N = input_t.size(0);
int64_t C = input_t.size(1);
std::array<int64_t, 3> view_shape = {1, N * groups, N ? -1 : 1};
auto input_t_reshaped = input_t.view(view_shape);
return input_t_reshaped.mean({2}, false).view_as(mean_p);
}

Tensor group_norm_invstd_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
int64_t groups) {
int64_t N = input_p.size(0);
int64_t C = input_p.size(1);

std::vector<int64_t> view_shape = {1, N * groups, N ? -1 : 1};

auto input_t_reshaped = input_t.view(view_shape);
auto input_p_reshaped = input_p.view(view_shape);

return _invstd_jvp(
input_t_reshaped, input_p_reshaped, mean_p.view(view_shape), invstd_p.view(view_shape),
/*dims=*/{2}, /*numel=*/input_t_reshaped.size(2), /*keepdim=*/false).view_as(invstd_p);
}

Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) {
auto full_indices = indices;
if (!keepdim) {
47 changes: 47 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -380,6 +380,53 @@ Tensor lu_factor_ex_jvp(
const Tensor& pivs
);

Tensor batch_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_invstd,
bool train,
double eps
);

Tensor batch_norm_jvp_saved_var(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_var,
bool train,
double eps
);

Tensor layer_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
IntArrayRef normalized_shape
);

Tensor group_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
int64_t groups
);
Tensor group_norm_mean_jvp(
const Tensor& input_t,
const Tensor& mean_p,
int64_t groups
);
Tensor group_norm_invstd_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
int64_t groups
);

Tensor convolution_jvp(
const Tensor& input_p, const Tensor& input_t,
6 changes: 6 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -10849,6 +10849,8 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@@ -10860,6 +10862,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@@ -10873,6 +10876,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
@@ -11539,6 +11543,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_batch_norm),
# This variant tests batch_norm with cuDNN disabled only on CUDA devices
OpInfo('nn.functional.batch_norm',
@@ -11547,6 +11552,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=empty_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[onlyCUDA, disablecuDNN],
sample_inputs_func=sample_inputs_batch_norm),
# We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
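These OpInfo flags enable the generic forward-AD tests for the affected ops. The same kind of check can also be run directly with torch.autograd.gradcheck; a minimal sketch with illustrative shapes:

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

x = torch.randn(3, 4, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(5, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)

# check_forward_ad=True compares the new JVP formulas against numerical derivatives
assert gradcheck(lambda x, w, b: F.layer_norm(x, (5,), w, b), (x, w, b),
                 check_forward_ad=True)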
