Skip to content

Commit

Permalink
Update backward formulas (Re #44444) (#46275)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #46275

Re #44444

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D24285785

Pulled By: anjali411

fbshipit-source-id: c60ecd4fe4f144132085f2c91d3b950e92b2a491
  • Loading branch information
anjali411 authored and facebook-github-bot committed Oct 26, 2020
1 parent edbc84a commit d94bd99
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 86 deletions.
10 changes: 4 additions & 6 deletions aten/src/ATen/native/Pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
"result type ", common_dtype, "can't be cast to the desired output type ",
result.scalar_type());

if (exp.isComplex() && (exp.toComplexDouble() == 0.0) ) {
result.resize_as_(base).fill_(1);
} else if (exp.isComplex() && (exp.toComplexDouble() == 1.0) ) {
result.resize_as_(base).fill_(base);
} else if (!exp.isComplex() && (exp.toDouble() == 0.0)) {
auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();

if (exponent == 0.0) {
result.resize_as_(base).fill_(1);
} else if (!exp.isComplex() && (exp.toDouble() == 1.0)) {
} else if (exponent == 1.0) {
result.resize_as_(base).copy_(base);
} else {
auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
Expand Down
22 changes: 18 additions & 4 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,17 +589,31 @@ void logit_backward_kernel(TensorIterator& iter, Scalar eps_scalar) {
}

void tanh_backward_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
auto one_vec = Vec256<scalar_t>(scalar_t{1});
if (isComplexType(iter.dtype())) {
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
auto one_vec = Vec256<scalar_t>(scalar_t{1});
cpu_kernel_vec(
iter,
[=](scalar_t a, scalar_t b) -> scalar_t {
return a * (scalar_t{1} - b * b);
return a * std::conj(scalar_t{1} - b * b);
},
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a * (one_vec - b * b);
return a * (one_vec - b * b).conj();
});
});
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
auto one_vec = Vec256<scalar_t>(scalar_t{1});
cpu_kernel_vec(
iter,
[=](scalar_t a, scalar_t b) -> scalar_t {
return a * (scalar_t{1} - b * b);
},
[=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
return a * (one_vec - b * b);
});
});
}
}

void mse_kernel(TensorIterator& iter) {
Expand Down
16 changes: 12 additions & 4 deletions aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,21 @@ void logit_backward_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) {
}

void tanh_backward_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] {
if(isComplexType(iter.dtype())) {
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_complex_cuda", [&]() {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a * (scalar_t{1.} - b * b);
return a * std::conj(scalar_t{1.} - b * b);
});
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a * (scalar_t{1.} - b * b);
});
});
});
}
}

REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel_cuda);
Expand Down
6 changes: 3 additions & 3 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4917,8 +4917,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
# the tests for these ops which do not have 'complex' in variant should not run for complex
# and only run for floating point

# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition
separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos'] # ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']
separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'div', 'log',
'log10', 'log1p', 'log2', 'pow', 'tan', 'reciprocal', 'rsqrt']

# NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly
# for non-holomorphic functions
Expand All @@ -4930,7 +4930,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split',
'matmul', 'bmm', 'mv', 'ger', 'diagonal', ] + separate_complex_tests
'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle'] + separate_complex_tests

# this list corresponds to cases that are not currently implemented
skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
Expand Down
3 changes: 2 additions & 1 deletion test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.testing._internal.common_methods_invocations import tri_tests_args, tri_large_tests_args, \
_compare_trilu_indices, _compare_large_trilu_indices
from torch.testing._internal.common_utils import TestCase, get_gpu_type, freeze_rng_state, run_tests, \
NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_SANDCASTLE, \
NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_SANDCASTLE, IS_WINDOWS, \
slowTest, skipCUDANonDefaultStreamIf, TEST_WITH_ROCM, TEST_NUMPY
from torch.testing._internal.autocast_test_lists import AutocastTestLists

Expand Down Expand Up @@ -2151,6 +2151,7 @@ def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):

self._run_scaling_case(run, unskipped=3, skipped=1)

@unittest.skipIf(IS_WINDOWS, 'FIXME: fix this test for Windows')
def test_grad_scaling_penalty(self):
def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
for i, (input, target) in enumerate(data):
Expand Down
2 changes: 1 addition & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15631,7 +15631,7 @@ def add_autograd_test(

# Disable complex tests
# TODO: Add complex support for jit
if 'complex' in variant_name or name in ['view_as_complex', 'complex']:
if 'complex' in variant_name or name in ['view_as_complex', 'complex', 'angle']:
return

# Skips aliases, which are tested in test_op_aliases.py
Expand Down
18 changes: 6 additions & 12 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,41 +89,35 @@ def _gradgrad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')

# Tests that gradients are computed correctly
# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_fn_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, op.get_op())

# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_method_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, op.get_method())

# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_inplace_grad(self, device, dtype, op):
if not op.test_inplace_grad:
self.skipTest("Skipped! Inplace gradcheck marked to skip.")
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

# TODO(@anjali411) enable this for torch.cdouble.
# Test that gradients of gradients are computed correctly
@dtypes(torch.double)
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, op.get_op())

# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_method_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, op.get_method())

# TODO(@anjali411) enable this for torch.cdouble.
@dtypes(torch.double)
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_inplace_gradgrad(self, device, dtype, op):
if not op.test_inplace_grad:
Expand Down
30 changes: 15 additions & 15 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
self: grad * self.sgn()

- name: acos(Tensor self) -> Tensor
self: grad * -((-self * self + 1).rsqrt())
self: grad * -((-self * self + 1).rsqrt()).conj()

- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
self: grad
Expand Down Expand Up @@ -213,7 +213,7 @@
self: grad

- name: angle(Tensor self) -> Tensor
self: grad.to(self.scalar_type()) * (self*Scalar(c10::complex<double>{0.0, 1.0})).conj() / self.abs().pow(2)
self: angle_backward(grad, self)

# The four items below are necessary because TensorIterator doesn't work on
# Variables (codegen does not unwrap the input Tensor for all() and any() ).
Expand All @@ -230,19 +230,19 @@
self: not_implemented("all")

- name: acosh(Tensor self) -> Tensor
self: grad * (self.pow(2) - 1).rsqrt()
self: grad * (self.pow(2) - 1).rsqrt().conj()

- name: acosh_(Tensor(a!) self) -> Tensor(a!)
self: not_implemented("inplace version of acosh")

- name: asinh(Tensor self) -> Tensor
self: grad * (self.pow(2) + 1).rsqrt()
self: grad * (self.pow(2) + 1).rsqrt().conj()

- name: asinh_(Tensor(a!) self) -> Tensor(a!)
self: not_implemented("inplace version of asinh")

- name: atanh(Tensor self) -> Tensor
self: grad * 1 / (1 - self.pow(2))
self: grad * 1 / (1 - self.pow(2)).conj()

- name: atanh_(Tensor(a!) self) -> Tensor(a!)
self: not_implemented("inplace version of atanh")
Expand All @@ -251,10 +251,10 @@
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)

- name: asin(Tensor self) -> Tensor
self: grad * (-self * self + 1).rsqrt()
self: grad * (-self * self + 1).rsqrt().conj()

- name: atan(Tensor self) -> Tensor
self: grad / (self * self + 1)
self: grad / (self * self + 1).conj()

- name: atan2(Tensor self, Tensor other) -> Tensor
self, other: atan2_backward(grad, self, other, grad_input_mask)
Expand Down Expand Up @@ -610,16 +610,16 @@
self: grad * polygamma(n + 1, self)

- name: log(Tensor self) -> Tensor
self: grad.div(self)
self: grad.div(self.conj())

- name: log10(Tensor self) -> Tensor
self: grad / (self * 2.3025850929940456)
self: grad / (self.conj() * 2.3025850929940456)

- name: log1p(Tensor self) -> Tensor
self: log1p_backward(grad, self)

- name: log2(Tensor self) -> Tensor
self: grad / (self * 0.6931471805599453)
self: grad / (self.conj() * 0.6931471805599453)

- name: logaddexp(Tensor self, Tensor other) -> Tensor
self: grad / (1 + exp(other - self))
Expand Down Expand Up @@ -884,7 +884,7 @@
self: zeros_like(grad)

- name: reciprocal(Tensor self) -> Tensor
self: -grad * result * result
self: -grad * (result * result).conj()

- name: remainder.Scalar(Tensor self, Scalar other) -> Tensor
self: grad
Expand All @@ -909,7 +909,7 @@
self: zeros_like(grad)

- name: rsqrt(Tensor self) -> Tensor
self: -0.5 * grad * result.pow(3)
self: -0.5 * grad * result.pow(3).conj()

- name: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
self: grad.clone().scatter_(dim, index, 0)
Expand Down Expand Up @@ -1046,7 +1046,7 @@
index: non_differentiable

- name: tan(Tensor self) -> Tensor
self: grad * (1 + result.pow(2))
self: grad * (1 + result.pow(2)).conj()

- name: tanh(Tensor self) -> Tensor
self: tanh_backward(grad, result)
Expand Down Expand Up @@ -1670,8 +1670,8 @@
output: grad * grad_output * (-2 * output + 1)

- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
grad_output: tanh_backward(grad, output)
output: -2 * output * grad * grad_output
grad_output: tanh_backward(grad, output.conj())
output: grad.conj() * (-2 * output.conj() * grad_output)

# cudnn
- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
Expand Down
5 changes: 4 additions & 1 deletion tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal'
'bmm', 'diagonal', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh'
'dot', 'vdot', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
Expand Down

0 comments on commit d94bd99

Please sign in to comment.