diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 11b879a28353031..532536bfe22f8cf 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1084,6 +1084,7 @@ - name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) - name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask) @@ -1092,6 +1093,7 @@ - name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple()" + result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape) - name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask) @@ -1101,6 +1103,9 @@ - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" + result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) + result1: group_norm_mean_jvp(input_t, result1, group) + result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group) - name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) self: zeros_like(self) @@ -2388,6 +2393,7 @@ # NB2: The quotes around the gradient are needed to appease YAML parsing rules. - name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) # HACK: save_mean and save_var are going to be passed in as # requires_grad variables (even though we'll never backprop through @@ -2435,6 +2441,7 @@ - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) save_mean: not_implemented("miopen_batch_norm_backward save_mean") diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 42f16cd78ea7568..9d266bc0113e487 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -4569,6 +4569,229 @@ Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) { } } +// Helper for {batch,layer,group}_norms below +// Computes the jvp for `1 / input.std(dims, keepdim)` +static Tensor _invstd_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& mean_p, const Tensor& invstd_p, + IntArrayRef dims, int64_t numel, bool keepdim) { + Tensor invstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) { + invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) * (input_p - mean_p); + } else { + invstd_t = input_t - input_t.mean(dims, true); + invstd_t *= input_p - mean_p; + invstd_t *= -invstd_p.pow(3); + } + invstd_t = invstd_t.sum(dims, keepdim); + invstd_t /= numel; + return invstd_t; +} + +// Helper for {batch,layer,group}_norms below only +// Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)` +static Tensor _norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& mean_p, const Tensor& invstd_p, + IntArrayRef dims, int64_t numel) { + auto invstd_t = _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true); + Tensor result_t; + if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) { + result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t; + } else { + result_t = input_t - input_t.mean(dims, true); + result_t *= invstd_p; + auto temp = input_p - mean_p; + temp *= invstd_t; + result_t += temp; + } + return result_t; +} + +// Helper for {batch,layer,group}_norms below only +// Computes the jvp for `input * weight + bias` where weight and bias may be undefined +// Possibly modifies the input inplace +static Tensor _affine_jvp( + const c10::optional& input_p, Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_t) { + // We allow input_p to be optional because if weight_p isn't defined, + // it may be possible to avoid computing input_p + TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined()); + if (weight_p.defined()) { + if (areAnyTensorSubclassLike({input_p.value(), input_t, weight_p, weight_t}) || input_t._is_zerotensor() || weight_t._is_zerotensor()) { + input_t = input_t * weight_p + input_p.value() * weight_t; + } else { + input_t *= weight_p; + auto temp = input_p.value(); + temp *= weight_t; + input_t += temp; + } + } + if (bias_t.defined()) { + if (areAnyTensorSubclassLike({input_t, bias_t}) || input_t._is_zerotensor()) { + input_t = input_t + bias_t; + } else { + input_t += bias_t; + } + } + return input_t; +} + +Tensor batch_norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const c10::optional& running_mean, + const c10::optional& running_var, + const Tensor& saved_mean, const Tensor& saved_invstd, + bool train, + double eps) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + int64_t numel = 1; + for (const auto dim : c10::irange(view_size.size())) { + if (dim != 1) { + numel *= input_t.size(dim); + view_size[dim] = 1; + dims.push_back(dim); + } + } + Tensor mean_p; + Tensor invstd_p; + Tensor result_t; + if (train) { + mean_p = saved_mean.view(view_size); + invstd_p = saved_invstd.view(view_size); + result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel); + } else { + TORCH_INTERNAL_ASSERT( + running_mean.has_value() && running_var.has_value(), + "Expect running_mean and running_var to have value when train=false"); + mean_p = running_mean.value().view(view_size); + invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size); + result_t = input_t * invstd_p; + } + + c10::optional result_p = weight_p.defined() + ? c10::optional((input_p - mean_p) * invstd_p) : c10::nullopt; + return _affine_jvp( + result_p, result_t, + weight_p.defined() ? weight_p.view(view_size) : weight_p, + weight_t.defined() ? weight_t.view(view_size) : weight_t, + bias_t.defined() ? bias_t.view(view_size) : bias_t); +} + +Tensor batch_norm_jvp_saved_var( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const c10::optional& running_mean, + const c10::optional& running_var, + const Tensor& saved_mean, const Tensor& saved_var, + bool train, + double eps) { + auto saved_invstd = (1 / at::sqrt(saved_var + at::Scalar(eps))); + return batch_norm_jvp( + input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, + saved_mean, saved_invstd, train, eps); +} + +Tensor layer_norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const Tensor& saved_mean, const Tensor& saved_invstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(i); + view_size[i] = 1; + dims.push_back(i); + } + } + auto mean_p = saved_mean.view(view_size); + auto invstd_p = saved_invstd.view(view_size); + auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel); + + c10::optional result_p = weight_p.defined() + ? c10::optional((input_p - mean_p) * invstd_p) : c10::nullopt; + return _affine_jvp( + result_p, result_t, + weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, + weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, + bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); +} + +Tensor group_norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const Tensor& saved_mean, const Tensor& saved_invstd, + int64_t groups) { + auto input_shape = input_p.sizes(); + int64_t N = input_p.size(0); + int64_t C = input_p.size(1); + + auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1}); + auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1}); + + auto result_t = batch_norm_jvp( + input_p_reshaped, input_t_reshaped, + /*weight_p=*/{}, /*weight_t=*/{}, + /*bias_p=*/{}, /*bias_t=*/{}, + /*running_mean=*/{}, /*running_var=*/{}, + saved_mean, saved_invstd, /*train=*/true, /*eps=*/0).view(input_shape); + + c10::optional result_p = c10::nullopt; + if (weight_p.defined()) { + std::vector view_size(input_t_reshaped.dim(), 1); + view_size[1] = input_t_reshaped.size(1); + result_p = ((input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size)).view(input_shape); + } + std::vector affine_param_shape(input_p.dim(), 1); + affine_param_shape[1] = C; + + return _affine_jvp( + result_p, result_t, + weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p, + weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t, + bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t); +} + +Tensor group_norm_mean_jvp( + const Tensor& input_t, const Tensor& mean_p, int64_t groups) { + int64_t N = input_t.size(0); + int64_t C = input_t.size(1); + std::array view_shape = {1, N * groups, N ? -1 : 1}; + auto input_t_reshaped = input_t.view(view_shape); + return input_t_reshaped.mean({2}, false).view_as(mean_p); +} + +Tensor group_norm_invstd_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& mean_p, const Tensor& invstd_p, + int64_t groups) { + int64_t N = input_p.size(0); + int64_t C = input_p.size(1); + + std::vector view_shape = {1, N * groups, N ? -1 : 1}; + + auto input_t_reshaped = input_t.view(view_shape); + auto input_p_reshaped = input_p.view(view_shape); + + return _invstd_jvp( + input_t_reshaped, input_p_reshaped, mean_p.view(view_shape), invstd_p.view(view_shape), + /*dims=*/{2}, /*numel=*/input_t_reshaped.size(2), /*keepdim=*/false).view_as(invstd_p); +} + Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) { auto full_indices = indices; if (!keepdim) { diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index b3853104144c22f..818e5c9f6bef09e 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -380,6 +380,53 @@ Tensor lu_factor_ex_jvp( const Tensor& pivs ); +Tensor batch_norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const c10::optional& running_mean, + const c10::optional& running_var, + const Tensor& saved_mean, const Tensor& saved_invstd, + bool train, + double eps +); + +Tensor batch_norm_jvp_saved_var( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const c10::optional& running_mean, + const c10::optional& running_var, + const Tensor& saved_mean, const Tensor& saved_var, + bool train, + double eps +); + +Tensor layer_norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const Tensor& saved_mean, const Tensor& saved_invstd, + IntArrayRef normalized_shape +); + +Tensor group_norm_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& weight_p, const Tensor& weight_t, + const Tensor& bias_p, const Tensor& bias_t, + const Tensor& saved_mean, const Tensor& saved_invstd, + int64_t groups +); +Tensor group_norm_mean_jvp( + const Tensor& input_t, + const Tensor& mean_p, + int64_t groups +); +Tensor group_norm_invstd_jvp( + const Tensor& input_p, const Tensor& input_t, + const Tensor& mean_p, const Tensor& invstd_p, + int64_t groups +); Tensor convolution_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index fd14ca0969f8b31..123c0ce6d9a7536 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10849,6 +10849,8 @@ def ref_pairwise_distance(input1, input2): dtypes=floating_types(), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, decorators=[ # RuntimeError: Cannot insert a Tensor that requires grad as a constant. # Consider making it a parameter or input, or detaching the gradient @@ -10860,6 +10862,7 @@ def ref_pairwise_distance(input1, input2): dtypes=floating_types(), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, decorators=[ # RuntimeError: Cannot insert a Tensor that requires grad as a constant. # Consider making it a parameter or input, or detaching the gradient @@ -10873,6 +10876,7 @@ def ref_pairwise_distance(input1, input2): dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, decorators=[ DecorateInfo( toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), @@ -11539,6 +11543,7 @@ def ref_pairwise_distance(input1, input2): dtypes=floating_types(), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_batch_norm), # This variant tests batch_norm with cuDNN disabled only on CUDA devices OpInfo('nn.functional.batch_norm', @@ -11547,6 +11552,7 @@ def ref_pairwise_distance(input1, input2): dtypes=empty_types(), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, decorators=[onlyCUDA, disablecuDNN], sample_inputs_func=sample_inputs_batch_norm), # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the