Add forward AD formulas for some {batch,layer,group}_norm
ghstack-source-id: cfa1e4c7408291495fd9337886c5acce6c938aab
Pull Request resolved: #70355
soulitzer committed Dec 28, 2021
1 parent 2584100 commit d158f47
Showing 4 changed files with 189 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1083,6 +1083,7 @@

- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)

- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
@@ -1091,6 +1092,7 @@

- name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape)

- name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
@@ -1100,6 +1102,8 @@

- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
output_differentiability: [true, false, false]

- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
@@ -2387,6 +2391,7 @@
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
@@ -2434,6 +2439,7 @@

- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
save_mean: not_implemented("miopen_batch_norm_backward save_mean")
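The `result0:` entries above register forward-derivative (JVP) formulas for the first output of each norm op, so a tangent attached to the input can now be propagated through the forward pass. A minimal sketch of what this enables at the Python level (shapes are illustrative and assume a build containing this change):

```python
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(4, 3, 8, 8)   # primal input
x_t = torch.randn_like(x)     # tangent (direction of the JVP)
weight = torch.randn(3)
bias = torch.randn(3)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, x_t)
    out = F.batch_norm(dual_x, running_mean=None, running_var=None,
                       weight=weight, bias=bias, training=True)
    # The tangent of `out` is computed by the batch_norm_jvp formula added below.
    primal, tangent = fwAD.unpack_dual(out)
```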
141 changes: 141 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -4569,6 +4569,147 @@ Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) {
}
}

Tensor batch_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_invstd,
bool train,
double eps) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
int64_t numel = 1;
for (const auto dim : c10::irange(view_size.size())) {
if (dim != 1) {
numel *= input_t.size(dim);
view_size[dim] = 1;
dims.push_back(dim);
}
}
Tensor mean_p;
Tensor invstd_p;
Tensor result_t;
if (train) {
mean_p = saved_mean.view(view_size);
invstd_p = saved_invstd.view(view_size);
auto invstd_t = -invstd_p.pow(3) * ((input_t - input_t.mean(dims, true)) * (input_p - mean_p)).sum(dims, true) / numel;
result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;
} else {
TORCH_INTERNAL_ASSERT(
running_mean.has_value() && running_var.has_value(),
"Expect running_mean and running_var to have value when train=true");

mean_p = running_mean.value().view(view_size);
invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
result_t = input_t * invstd_p;
}

if (weight_p.defined()) {
auto result_p = (input_p - mean_p) * invstd_p;
result_t = result_t * weight_p.view(view_size) + result_p * weight_t.view(view_size);
}

if (bias_p.defined()) {
result_t = result_t + bias_t.view(view_size);
}

return result_t;
}
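Reading off the training branch: the statistics are per channel over the batch and spatial dims, `saved_invstd` already includes eps, and with a dot denoting tangents the code computes

```latex
% JVP of y = ((x - \mu)\,\sigma^{-1})\,w + b, as implemented in the training branch above.
\begin{align*}
\dot\mu &= \operatorname{mean}(\dot x), \qquad
\sigma^{-1} := \texttt{saved\_invstd} = \bigl(\operatorname{var}(x) + \varepsilon\bigr)^{-1/2},\\
\dot{(\sigma^{-1})} &= -\sigma^{-3}\,\operatorname{mean}\!\bigl[(\dot x - \dot\mu)\,(x - \mu)\bigr],\\
\dot{\hat x} &= (\dot x - \dot\mu)\,\sigma^{-1} + (x - \mu)\,\dot{(\sigma^{-1})},\\
\dot y &= \dot{\hat x}\,w + \hat x\,\dot w + \dot b, \qquad \hat x = (x - \mu)\,\sigma^{-1}.
\end{align*}
```

In eval mode the running statistics are constants, which is why the else branch reduces to `input_t * invstd_p` before the affine terms are applied.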

Tensor batch_norm_jvp_saved_var(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_var,
bool train,
double eps) {
auto saved_invstd = (1 / at::sqrt(saved_var + at::Scalar(eps)));
return batch_norm_jvp(
input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var,
saved_mean, saved_invstd, train, eps);
}

Tensor layer_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>(normalized_shape.size());
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();

int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(i);
view_size[i] = 1;
dims[i - view_size.size() + normalized_shape.size()] = i;
}
}
auto mean_p = saved_mean.view(view_size);
auto invstd_p = saved_invstd.view(view_size);
auto invstd_t = -invstd_p.pow(3) * ((input_t - input_t.mean(dims, true)) * (input_p - mean_p)).sum(dims, true) / numel;
auto result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;

if (weight_p.defined()) {
auto result_p = (input_p - mean_p) * invstd_p;
result_t = result_t * weight_p.view(view_size_affine) + result_p * weight_t.view(view_size_affine);
}

if (bias_p.defined()) {
result_t = result_t + bias_t.view(view_size_affine);
}
return result_t;
}
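layer_norm_jvp is the same computation with the statistics reduced over `normalized_shape` (the trailing dimensions) and the affine parameters broadcast over the leading ones. One way to sanity-check the closed form is against `torch.autograd.functional.jvp`, which derives the same directional derivative from backward-mode AD; the snippet below is an illustrative check, not part of this PR's tests:

```python
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(2, 3, 4, dtype=torch.double)
v = torch.randn_like(x)          # tangent direction
normalized_shape = (3, 4)

def f(inp):
    return F.layer_norm(inp, normalized_shape)

# Reference JVP from backward-mode AD (double-backward trick).
_, jvp_reference = torch.autograd.functional.jvp(f, (x,), (v,))

# Forward-mode JVP, which should dispatch to layer_norm_jvp above.
with fwAD.dual_level():
    _, jvp_forward = fwAD.unpack_dual(f(fwAD.make_dual(x, v)))

torch.testing.assert_close(jvp_forward, jvp_reference)
```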

Tensor group_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
int64_t groups) {
auto input_shape = input_p.sizes();

int64_t N = input_p.size(0);
int64_t C = input_p.size(1);

auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});

auto result_t = batch_norm_jvp(
input_p_reshaped, input_t_reshaped,
/*weight_p=*/{}, /*weight_t=*/{},
/*bias_p=*/{}, /*bias_t=*/{},
/*running_mean=*/{}, /*running_var=*/{},
saved_mean, saved_invstd, /*train=*/true, /*eps=*/0);

result_t = result_t.view(input_shape);

std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
affine_param_shape[1] = C;

if (weight_p.defined()) {
std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
view_size[1] = input_t_reshaped.size(1);
auto result_p = (input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size);
result_p = result_p.view(input_shape);
result_t = result_t * weight_p.view(affine_param_shape) + result_p * weight_t.view(affine_param_shape);
}

if (bias_p.defined()) {
result_t = result_t + bias_t.view(affine_param_shape);
}

return result_t;
}
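group_norm_jvp delegates to batch_norm_jvp by viewing the input as (1, N·groups, -1): with `/*train=*/true` the saved per-(sample, group) statistics are used directly (so the eps argument is irrelevant and passed as 0, since `saved_invstd` already folds it in), and the affine weight and bias are applied per original channel afterwards. A small sketch of why the reshape preserves the statistics (shapes are hypothetical):

```python
import torch

# The per-(sample, group) statistics of group norm coincide with the per-"channel"
# statistics of a batch norm applied to the input viewed as (1, N * groups, -1).
N, C, H, W, groups = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W)

mean_gn = x.view(N, groups, -1).mean(dim=-1)               # what native_group_norm reduces over
var_gn = x.view(N, groups, -1).var(dim=-1, unbiased=False)

bn_view = x.view(1, N * groups, -1)                        # layout used inside group_norm_jvp
mean_bn = bn_view.mean(dim=-1).view(N, groups)
var_bn = bn_view.var(dim=-1, unbiased=False).view(N, groups)

torch.testing.assert_close(mean_gn, mean_bn)
torch.testing.assert_close(var_gn, var_bn)
```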

Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) {
auto full_indices = indices;
if (!keepdim) {
37 changes: 37 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -380,6 +380,43 @@ Tensor _lu_with_info_jvp(
const Tensor& pivs
);

Tensor batch_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_invstd,
bool train,
double eps
);

Tensor batch_norm_jvp_saved_var(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_var,
bool train,
double eps
);

Tensor layer_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
IntArrayRef normalized_shape
);

Tensor group_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
int64_t groups
);

Tensor convolution_jvp(
const Tensor& input_p, const Tensor& input_t,
5 changes: 5 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -10919,6 +10919,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@@ -10930,6 +10931,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@@ -10943,6 +10945,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
@@ -11609,6 +11612,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_batch_norm),
# This variant tests batch_norm with cuDNN disabled only on CUDA devices
OpInfo('nn.functional.batch_norm',
@@ -11617,6 +11621,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=empty_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[onlyCUDA, disablecuDNN],
sample_inputs_func=sample_inputs_batch_norm),
# We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
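Flipping `supports_forward_ad=True` opts these OpInfo entries into the generic forward-AD checks in the test suite. Roughly, those checks reduce to gradcheck with forward-mode verification enabled; the snippet below is an illustrative stand-in for the real test plumbing, not the PR's test code:

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

x = torch.randn(2, 3, 4, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

def fn(inp, weight, bias):
    return F.group_norm(inp, num_groups=3, weight=weight, bias=bias)

# check_forward_ad=True compares the analytical JVPs (the new formulas)
# against numerically estimated ones.
assert gradcheck(fn, (x, w, b), check_forward_ad=True)
```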
