From ac1560ee60e6d7a01f7e9b6d872a1fdc0be7f403 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Tue, 28 Dec 2021 14:42:24 -0500
Subject: [PATCH] Add forward AD formulas for some {batch,layer,group}_norm

ghstack-source-id: 58ba0625a82db546ff3879e388f7d5fe6c296075
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70355
---
 tools/autograd/derivatives.yaml               |   6 +
 torch/csrc/autograd/FunctionsManual.cpp       | 141 ++++++++++++++++++
 torch/csrc/autograd/FunctionsManual.h         |  37 +++++
 .../_internal/common_methods_invocations.py   |   5 +
 4 files changed, 189 insertions(+)

diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ba64fd5a7c5e9e5..e8def84d086df95 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1083,6 +1083,7 @@
 - name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)

 - name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
@@ -1091,6 +1092,7 @@
 - name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape)

 - name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
@@ -1100,6 +1102,8 @@
 - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
+  result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
+  output_differentiability: [true, false, false]

 - name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   self: zeros_like(self)
@@ -2387,6 +2391,7 @@
 # NB2: The quotes around the gradient are needed to appease YAML parsing rules.
 - name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
+  result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

 # HACK: save_mean and save_var are going to be passed in as
 # requires_grad variables (even though we'll never backprop through
@@ -2434,6 +2439,7 @@
 - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
+  result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

 - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
   save_mean: not_implemented("miopen_batch_norm_backward save_mean")
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 3048d016b5ee37a..2729c727a0e4f29 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -4569,6 +4569,147 @@ Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) {
   }
 }
 
+Tensor batch_norm_jvp(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const c10::optional<Tensor>& running_mean,
+    const c10::optional<Tensor>& running_var,
+    const Tensor& saved_mean, const Tensor& saved_invstd,
+    bool train,
+    double eps) {
+  auto dims = std::vector<int64_t>{};
+  auto view_size = input_t.sizes().vec();
+  int64_t numel = 1;
+  for (const auto dim : c10::irange(view_size.size())) {
+    if (dim != 1) {
+      numel *= input_t.size(dim);
+      view_size[dim] = 1;
+      dims.push_back(dim);
+    }
+  }
+  Tensor mean_p;
+  Tensor invstd_p;
+  Tensor result_t;
+  if (train) {
+    mean_p = saved_mean.view(view_size);
+    invstd_p = saved_invstd.view(view_size);
+    auto invstd_t = -invstd_p.pow(3) * ((input_t - input_t.mean(dims, true)) * (input_p - mean_p)).sum(dims, true) / numel;
+    result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        running_mean.has_value() && running_var.has_value(),
+        "Expect running_mean and running_var to have value when train=false");
+
+    mean_p = running_mean.value().view(view_size);
+    invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
+    result_t = input_t * invstd_p;
+  }
+
+  if (weight_p.defined()) {
+    auto result_p = (input_p - mean_p) * invstd_p;
+    result_t = result_t * weight_p.view(view_size) + result_p * weight_t.view(view_size);
+  }
+
+  if (bias_p.defined()) {
+    result_t = result_t + bias_t.view(view_size);
+  }
+
+  return result_t;
+}
+
+Tensor batch_norm_jvp_saved_var(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const c10::optional<Tensor>& running_mean,
+    const c10::optional<Tensor>& running_var,
+    const Tensor& saved_mean, const Tensor& saved_var,
+    bool train,
+    double eps) {
+  auto saved_invstd = (1 / at::sqrt(saved_var + at::Scalar(eps)));
+  return batch_norm_jvp(
+      input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var,
+      saved_mean, saved_invstd, train, eps);
+}
+
+Tensor layer_norm_jvp(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const Tensor& saved_mean, const Tensor& saved_invstd,
+    IntArrayRef normalized_shape) {
+  auto dims = std::vector<int64_t>(normalized_shape.size());
+  auto view_size = input_t.sizes().vec();
+  auto view_size_affine = input_t.sizes().vec();
+
+  int64_t numel = 1;
+  for (const auto i : c10::irange(view_size.size())) {
+    if (i < view_size.size() - normalized_shape.size()) {
+      view_size_affine[i] = 1;
+    } else {
+      numel *= input_t.size(i);
+      view_size[i] = 1;
+      dims[i - view_size.size() + normalized_shape.size()] = i;
+    }
+  }
+  auto mean_p = saved_mean.view(view_size);
+  auto invstd_p = saved_invstd.view(view_size);
+  auto invstd_t = -invstd_p.pow(3) * ((input_t - input_t.mean(dims, true)) * (input_p - mean_p)).sum(dims, true) / numel;
+  auto result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;
+
+  if (weight_p.defined()) {
+    auto result_p = (input_p - mean_p) * invstd_p;
+    result_t = result_t * weight_p.view(view_size_affine) + result_p * weight_t.view(view_size_affine);
+  }
+
+  if (bias_p.defined()) {
+    result_t = result_t + bias_t.view(view_size_affine);
+  }
+  return result_t;
+}
+
+Tensor group_norm_jvp(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const Tensor& saved_mean, const Tensor& saved_invstd,
+    int64_t groups) {
+  auto input_shape = input_p.sizes();
+
+  int64_t N = input_p.size(0);
+  int64_t C = input_p.size(1);
+
+  auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
+  auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});
+
+  auto result_t = batch_norm_jvp(
+      input_p_reshaped, input_t_reshaped,
+      /*weight_p=*/{}, /*weight_t=*/{},
+      /*bias_p=*/{}, /*bias_t=*/{},
+      /*running_mean=*/{}, /*running_var=*/{},
+      saved_mean, saved_invstd, /*train=*/true, /*eps=*/0);
+
+  result_t = result_t.view(input_shape);
+
+  std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
+  affine_param_shape[1] = C;
+
+  if (weight_p.defined()) {
+    std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
+    view_size[1] = input_t_reshaped.size(1);
+    auto result_p = (input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size);
+    result_p = result_p.view(input_shape);
+    result_t = result_t * weight_p.view(affine_param_shape) + result_p * weight_t.view(affine_param_shape);
+  }
+
+  if (bias_p.defined()) {
+    result_t = result_t + bias_t.view(affine_param_shape);
+  }
+
+  return result_t;
+}
+
 Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) {
   auto full_indices = indices;
   if (!keepdim) {
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index e1423e2cf718885..08d24e2b0763e98 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -380,6 +380,43 @@ Tensor _lu_with_info_jvp(
   const Tensor& pivs
 );
 
+Tensor batch_norm_jvp(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const c10::optional<Tensor>& running_mean,
+    const c10::optional<Tensor>& running_var,
+    const Tensor& saved_mean, const Tensor& saved_invstd,
+    bool train,
+    double eps
+);
+
+Tensor batch_norm_jvp_saved_var(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const c10::optional<Tensor>& running_mean,
+    const c10::optional<Tensor>& running_var,
+    const Tensor& saved_mean, const Tensor& saved_var,
+    bool train,
+    double eps
+);
+
+Tensor layer_norm_jvp(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const Tensor& saved_mean, const Tensor& saved_invstd,
+    IntArrayRef normalized_shape
+);
+
+Tensor group_norm_jvp(
+    const Tensor& input_p, const Tensor& input_t,
+    const Tensor& weight_p, const Tensor& weight_t,
+    const Tensor& bias_p, const Tensor& bias_t,
+    const Tensor& saved_mean, const Tensor& saved_invstd,
+    int64_t groups
+);

 Tensor convolution_jvp(
   const Tensor& input_p, const Tensor& input_t,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 53235c723ff9389..53ab537d93b90f1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10919,6 +10919,7 @@ def ref_pairwise_distance(input1, input2):
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
            decorators=[
                # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
                # Consider making it a parameter or input, or detaching the gradient
@@ -10930,6 +10931,7 @@ def ref_pairwise_distance(input1, input2):
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
            decorators=[
                # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
                # Consider making it a parameter or input, or detaching the gradient
@@ -10943,6 +10945,7 @@ def ref_pairwise_distance(input1, input2):
            dtypes=floating_types_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
            decorators=[
                DecorateInfo(
                    toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
@@ -11609,6 +11612,7 @@ def ref_pairwise_distance(input1, input2):
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_batch_norm),
    # This variant tests batch_norm with cuDNN disabled only on CUDA devices
    OpInfo('nn.functional.batch_norm',
@@ -11617,6 +11621,7 @@ def ref_pairwise_distance(input1, input2):
            dtypes=empty_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
            decorators=[onlyCUDA, disablecuDNN],
            sample_inputs_func=sample_inputs_batch_norm),
    # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
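
Note (illustrative, not part of the patch): in training mode the formula implemented by batch_norm_jvp above is

  d(out)    = (dx - mean(dx)) * invstd + (x - mean) * d(invstd)
  d(invstd) = -invstd^3 * mean((dx - mean(dx)) * (x - mean))

with the mean taken over the reduced dimensions, plus the usual weight/bias terms when affine parameters are given; layer_norm_jvp and group_norm_jvp reuse the same expression over their own reduction dims. Once the patch is applied, the new paths can be exercised with forward-mode AD via torch.autograd.forward_ad. The snippet below is a sketch; the shapes, group count, and normalized_shape are arbitrary assumptions for illustration, not taken from the patch.

    # Sketch: push a tangent through the norm ops with forward-mode AD.
    import torch
    import torch.autograd.forward_ad as fwAD
    import torch.nn.functional as F

    x = torch.randn(4, 6, 5)   # arbitrary (N, C, L) example
    t = torch.randn_like(x)    # tangent: direction of the directional derivative

    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, t)

        # native_layer_norm -> layer_norm_jvp (normalize over the last dim)
        y_p, y_t = fwAD.unpack_dual(F.layer_norm(dual_x, normalized_shape=(5,)))

        # native_group_norm -> group_norm_jvp (3 groups over 6 channels)
        z_p, z_t = fwAD.unpack_dual(F.group_norm(dual_x, num_groups=3))

        # native_batch_norm on CPU, training mode, no running stats -> batch_norm_jvp
        w_p, w_t = fwAD.unpack_dual(
            F.batch_norm(dual_x, running_mean=None, running_var=None, training=True))

        # Tangent outputs have the same shapes as the primal outputs.
        assert y_t.shape == y_p.shape and z_t.shape == z_p.shape and w_t.shape == w_p.shape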