
Add forward AD formulas for {batch,layer,group}_norm #70355

Status: Closed

Wants to merge 16 commits (changes shown from 15 commits)

Commits
522f4c7
Add forward AD formulas for some {batch,layer,group}_norm
soulitzer Dec 23, 2021
c74cdc5
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 23, 2021
661b83c
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 23, 2021
13f3b44
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 23, 2021
67f9882
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 28, 2021
4c7fa88
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 28, 2021
8a5bf1b
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 28, 2021
bf6bfaa
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 28, 2021
a9e9c90
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 28, 2021
f5c7745
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Dec 29, 2021
4c9fb4f
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Jan 4, 2022
60b99ee
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Jan 4, 2022
cbc1e91
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Jan 5, 2022
f0f40df
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Jan 5, 2022
8a9d060
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Jan 6, 2022
5dd950a
Update on "Add forward AD formulas for {batch,layer,group}_norm"
soulitzer Jan 10, 2022
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1083,6 +1083,7 @@

- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)

- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
@@ -1091,6 +1092,7 @@

- name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape)

- name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
@@ -1100,6 +1102,9 @@

- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
result1: group_norm_mean_jvp(input_t, result1, group)
result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group)

- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
@@ -2387,6 +2392,7 @@
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
@@ -2434,6 +2440,7 @@

- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp_saved_var(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
save_mean: not_implemented("miopen_batch_norm_backward save_mean")
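For reference, the `result0`, `result1`, and `result2` entries above are what forward-mode AD dispatches to for the primal outputs. A minimal sketch of exercising them from Python via the dual-number API (it assumes a PyTorch build that includes this PR; the finite-difference reference works on any build):

```python
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8, dtype=torch.double)
t = torch.randn_like(x)  # tangent (directional perturbation) for the input

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)
    # Training-mode batch norm without running stats; layer/group norm work analogously.
    out = F.batch_norm(dual_x, None, None, training=True)
    primal, jvp = fwAD.unpack_dual(out)

# Central finite difference of the primal function as a reference.
h = 1e-6
fd = (F.batch_norm(x + h * t, None, None, training=True)
      - F.batch_norm(x - h * t, None, None, training=True)) / (2 * h)
print(torch.allclose(jvp, fd, atol=1e-6))  # expected: True
```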
223 changes: 223 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -4569,6 +4569,229 @@ Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) {
}
}

// Helper for {batch,layer,group}_norms below
// Computes the jvp for `1 / input.std(dims, keepdim)`
static Tensor _invstd_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
IntArrayRef dims, int64_t numel, bool keepdim) {
Tensor invstd_t;
if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) {
invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) * (input_p - mean_p);
} else {
invstd_t = input_t - input_t.mean(dims, true);
invstd_t *= input_p - mean_p;
invstd_t *= -invstd_p.pow(3);
}
invstd_t = invstd_t.sum(dims, keepdim);
invstd_t /= numel;
return invstd_t;
}
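
As a sanity check of this formula (a sketch, not part of the PR), the same expression can be written in Python and compared against a central finite difference of `1 / std` computed with the biased (1/N) variance that the norm ops use:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 5, dtype=torch.double)   # input_p
t = torch.randn_like(x)                     # input_t (tangent)
dims, n = (1,), x.size(1)

mean = x.mean(dims, keepdim=True)
invstd = x.var(dims, unbiased=False, keepdim=True).rsqrt()

# Mirrors _invstd_jvp: reduce -invstd^3 * (t - mean(t)) * (x - mean(x)) over dims, divide by numel
jvp = (-invstd.pow(3) * (t - t.mean(dims, keepdim=True)) * (x - mean)).sum(dims, keepdim=True) / n

f = lambda z: z.var(dims, unbiased=False, keepdim=True).rsqrt()  # 1 / std with 1/N variance
h = 1e-6
fd = (f(x + h * t) - f(x - h * t)) / (2 * h)
print(torch.allclose(jvp, fd, atol=1e-6))  # expected: True
```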

// Helper for {batch,layer,group}_norms below only
// Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)`
static Tensor _norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
IntArrayRef dims, int64_t numel) {
auto invstd_t = _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true);
Tensor result_t;
if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) {
result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;
} else {
result_t = input_t - input_t.mean(dims, true);
result_t *= invstd_p;
auto temp = input_p - mean_p;
temp *= invstd_t;
result_t += temp;
}
return result_t;
}
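
The same kind of check applies to the product rule used here; a sketch (not part of the PR) that compares against a finite difference of `(x - x.mean(dims)) * invstd(x)`:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 5, dtype=torch.double)
t = torch.randn_like(x)
dims, n = (1,), x.size(1)

def normalize(z):  # (z - mean) * invstd, with biased variance and no eps
    return (z - z.mean(dims, keepdim=True)) * z.var(dims, unbiased=False, keepdim=True).rsqrt()

mean = x.mean(dims, keepdim=True)
invstd = x.var(dims, unbiased=False, keepdim=True).rsqrt()
invstd_t = (-invstd.pow(3) * (t - t.mean(dims, keepdim=True)) * (x - mean)).sum(dims, keepdim=True) / n

# Product rule, mirroring _norm_jvp: d[(x - mean) * invstd] = (t - mean(t)) * invstd + (x - mean) * d(invstd)
jvp = (t - t.mean(dims, keepdim=True)) * invstd + (x - mean) * invstd_t

h = 1e-6
fd = (normalize(x + h * t) - normalize(x - h * t)) / (2 * h)
print(torch.allclose(jvp, fd, atol=1e-6))  # expected: True
```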

// Helper for {batch,layer,group}_norms below only
// Computes the jvp for `input * weight + bias` where weight and bias may be undefined
// Possibly modifies input_t in place
static Tensor _affine_jvp(
const c10::optional<Tensor>& input_p, Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_t) {
// We allow input_p to be optional because if weight_p isn't defined,
// it may be possible to avoid computing input_p
TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
if (weight_p.defined()) {
if (areAnyTensorSubclassLike({input_p.value(), input_t, weight_p, weight_t}) || input_t._is_zerotensor() || weight_t._is_zerotensor()) {
input_t = input_t * weight_p + input_p.value() * weight_t;
} else {
input_t *= weight_p;
auto temp = input_p.value();
temp *= weight_t;
input_t += temp;
}
}
if (bias_t.defined()) {
if (areAnyTensorSubclassLike({input_t, bias_t}) || input_t._is_zerotensor()) {
input_t = input_t + bias_t;
} else {
input_t += bias_t;
}
}
return input_t;
}
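
The affine step is just the bilinear product rule: for `y = x * w + b` the tangent is `x_t * w_p + x_p * w_t + b_t` (the bias primal drops out). A tiny illustration with hypothetical shapes:

```python
import torch

x_p, x_t = torch.randn(2, 3), torch.randn(2, 3)  # primal / tangent of the normalized input
w_p, w_t = torch.randn(3), torch.randn(3)        # primal / tangent of weight
b_t = torch.randn(3)                             # tangent of bias
jvp = x_t * w_p + x_p * w_t + b_t                # what _affine_jvp computes (here out of place)
```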

Tensor batch_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_invstd,
bool train,
double eps) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
int64_t numel = 1;
for (const auto dim : c10::irange(view_size.size())) {
if (dim != 1) {
[Inline review comment on the dimension loop above]

Collaborator: batchnorm doesn't support no-batch dim?

soulitzer (Contributor, author): Looking at #60585

  • yes, batch norm semantically is incompatible with no-batch dim
  • for group norm it is BC-breaking, so a potential plan seems to be to create a new nn.Module entirely that would handle the no-batch dim case
  • layer norm should be fine as long as we handle the case where normalized shape has the same dimension as input. I do see a test case in OpInfo handling this so we should be ok

cc @jbschlosser

numel *= input_t.size(dim);
view_size[dim] = 1;
dims.push_back(dim);
}
}
Tensor mean_p;
Tensor invstd_p;
Tensor result_t;
if (train) {
mean_p = saved_mean.view(view_size);
invstd_p = saved_invstd.view(view_size);
result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
} else {
TORCH_INTERNAL_ASSERT(
running_mean.has_value() && running_var.has_value(),
"Expect running_mean and running_var to have value when train=false");
mean_p = running_mean.value().view(view_size);
invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
result_t = input_t * invstd_p;
}

c10::optional<Tensor> result_p = weight_p.defined()
? c10::optional<Tensor>((input_p - mean_p) * invstd_p) : c10::nullopt;
return _affine_jvp(
result_p, result_t,
weight_p.defined() ? weight_p.view(view_size) : weight_p,
weight_t.defined() ? weight_t.view(view_size) : weight_t,
bias_t.defined() ? bias_t.view(view_size) : bias_t);
}
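
For the eval branch (`train=false`), the running statistics are constants, so the JVP of the normalized output reduces to `input_t * invstd` (scaled by the weight in `_affine_jvp`). That can be checked against a finite difference of `F.batch_norm` in eval mode on any PyTorch build:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, 8, dtype=torch.double)
t = torch.randn_like(x)                       # perturb the input only; weight/bias tangents are zero
running_mean = torch.randn(3, dtype=torch.double)
running_var = torch.rand(3, dtype=torch.double) + 0.5
weight = torch.randn(3, dtype=torch.double)
bias = torch.randn(3, dtype=torch.double)
eps = 1e-5

invstd = (running_var + eps).rsqrt().view(1, 3, 1)
jvp = t * invstd * weight.view(1, 3, 1)       # eval-mode JVP w.r.t. the input

h = 1e-6
f = lambda z: F.batch_norm(z, running_mean, running_var, weight, bias, training=False, eps=eps)
fd = (f(x + h * t) - f(x - h * t)) / (2 * h)
print(torch.allclose(jvp, fd, atol=1e-6))  # expected: True
```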

Tensor batch_norm_jvp_saved_var(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_var,
bool train,
double eps) {
auto saved_invstd = (1 / at::sqrt(saved_var + at::Scalar(eps)));
return batch_norm_jvp(
input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var,
saved_mean, saved_invstd, train, eps);
}

Tensor layer_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();

int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(i);
view_size[i] = 1;
dims.push_back(i);
}
}
auto mean_p = saved_mean.view(view_size);
auto invstd_p = saved_invstd.view(view_size);
auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);

c10::optional<Tensor> result_p = weight_p.defined()
? c10::optional<Tensor>((input_p - mean_p) * invstd_p) : c10::nullopt;
return _affine_jvp(
result_p, result_t,
weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
}

Tensor group_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
int64_t groups) {
auto input_shape = input_p.sizes();
int64_t N = input_p.size(0);
int64_t C = input_p.size(1);

auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});

auto result_t = batch_norm_jvp(
input_p_reshaped, input_t_reshaped,
/*weight_p=*/{}, /*weight_t=*/{},
/*bias_p=*/{}, /*bias_t=*/{},
/*running_mean=*/{}, /*running_var=*/{},
saved_mean, saved_invstd, /*train=*/true, /*eps=*/0).view(input_shape);

c10::optional<Tensor> result_p = c10::nullopt;
if (weight_p.defined()) {
std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
view_size[1] = input_t_reshaped.size(1);
result_p = ((input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size)).view(input_shape);
}
std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
affine_param_shape[1] = C;

return _affine_jvp(
result_p, result_t,
weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p,
weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t,
bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t);
}
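
The reshape above relies on the fact that group norm over an `(N, C, *)` input is batch norm (training mode, no running stats) over a `(1, N*groups, -1)` view of the same data, which is why `batch_norm_jvp` can be reused here. A quick demonstration of that equivalence for the non-affine case (not part of the PR):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 6, 4, 5, dtype=torch.double)
groups, eps = 3, 1e-5

gn = F.group_norm(x, groups, eps=eps)  # no affine parameters
xr = x.reshape(1, x.size(0) * groups, -1)
bn = F.batch_norm(xr, None, None, training=True, eps=eps).reshape_as(x)
print(torch.allclose(gn, bn))  # expected: True
```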

Tensor group_norm_mean_jvp(
const Tensor& input_t, const Tensor& mean_p, int64_t groups) {
int64_t N = input_t.size(0);
int64_t C = input_t.size(1);
std::array<int64_t, 3> view_shape = {1, N * groups, N ? -1 : 1};
auto input_t_reshaped = input_t.view(view_shape);
return input_t_reshaped.mean({2}, false).view_as(mean_p);
}

Tensor group_norm_invstd_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
int64_t groups) {
int64_t N = input_p.size(0);
int64_t C = input_p.size(1);

std::vector<int64_t> view_shape = {1, N * groups, N ? -1 : 1};

auto input_t_reshaped = input_t.view(view_shape);
auto input_p_reshaped = input_p.view(view_shape);

return _invstd_jvp(
input_t_reshaped, input_p_reshaped, mean_p.view(view_shape), invstd_p.view(view_shape),
/*dims=*/{2}, /*numel=*/input_t_reshaped.size(2), /*keepdim=*/false).view_as(invstd_p);
}

Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) {
auto full_indices = indices;
if (!keepdim) {
47 changes: 47 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -380,6 +380,53 @@ Tensor _lu_with_info_jvp(
const Tensor& pivs
);

Tensor batch_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_invstd,
bool train,
double eps
);

Tensor batch_norm_jvp_saved_var(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean, const Tensor& saved_var,
bool train,
double eps
);

Tensor layer_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
IntArrayRef normalized_shape
);

Tensor group_norm_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& weight_p, const Tensor& weight_t,
const Tensor& bias_p, const Tensor& bias_t,
const Tensor& saved_mean, const Tensor& saved_invstd,
int64_t groups
);
Tensor group_norm_mean_jvp(
const Tensor& input_t,
const Tensor& mean_p,
int64_t groups
);
Tensor group_norm_invstd_jvp(
const Tensor& input_p, const Tensor& input_t,
const Tensor& mean_p, const Tensor& invstd_p,
int64_t groups
);

Tensor convolution_jvp(
const Tensor& input_p, const Tensor& input_t,
6 changes: 6 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -10921,6 +10921,8 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@@ -10932,6 +10934,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@@ -10945,6 +10948,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
@@ -11611,6 +11615,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_batch_norm),
# This variant tests batch_norm with cuDNN disabled only on CUDA devices
OpInfo('nn.functional.batch_norm',
@@ -11619,6 +11624,7 @@ def ref_pairwise_distance(input1, input2):
dtypes=empty_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
decorators=[onlyCUDA, disablecuDNN],
sample_inputs_func=sample_inputs_batch_norm),
# We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the
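
With `supports_forward_ad=True` (and `supports_fwgrad_bwgrad=True`) set, the OpInfo-based test suite runs forward-mode gradcheck, and forward-over-reverse checks where enabled, over these sample inputs. The same kind of check can be run by hand, e.g. for layer norm; a sketch assuming a build that includes this PR:

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

x = torch.randn(3, 4, 5, dtype=torch.double, requires_grad=True)
weight = torch.randn(5, dtype=torch.double, requires_grad=True)
bias = torch.randn(5, dtype=torch.double, requires_grad=True)

fn = lambda inp, w, b: F.layer_norm(inp, (5,), w, b)
print(gradcheck(fn, (x, weight, bias), check_forward_ad=True))  # expected: True
```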