Add scalar.conj() and update backward formulas for add and sub
ghstack-source-id: 422aeb2ba1dcdc9d526e2c9b1b73996bbd64831b
Pull Request resolved: #46596
anjali411 committed Oct 20, 2020
1 parent 5003fd1 commit 79cb7e2
Showing 7 changed files with 32 additions and 9 deletions.
10 changes: 10 additions & 0 deletions aten/src/ATen/test/scalar_test.cpp
@@ -128,3 +128,13 @@ TEST(TestScalar, TestScalar) {
ASSERT_EQ(float_one.item<int32_t>(), 1);
ASSERT_EQ(float_one.item<at::Half>(), 1);
}

+TEST(TestScalar, TestConj) {
+  Scalar int_scalar = 257;
+  Scalar float_scalar = 3.0;
+  Scalar complex_scalar = c10::complex<double>(2, 3);
+
+  ASSERT_EQ(int_scalar.conj().toInt(), 257);
+  ASSERT_EQ(float_scalar.conj().toDouble(), 3.0);
+  ASSERT_EQ(complex_scalar.conj().toComplexDouble(), c10::complex<double>(2, -3));
+}
10 changes: 10 additions & 0 deletions c10/core/Scalar.cpp
@@ -13,4 +13,14 @@ Scalar Scalar::operator-() const {
}
}

+Scalar Scalar::conj() const {
+  if (isComplex()) {
+    return Scalar(std::conj(v.z));
+  } else if (isFloatingPoint()) {
+    return Scalar(v.d);
+  } else {
+    return Scalar(v.i);
+  }
+}

} // namespace c10
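
For non-complex scalars, conj() is the identity; only complex scalars have their imaginary part negated. A rough Python-level analogy of these semantics (standard library only, not the committed code):

assert (257).conjugate() == 257                     # integral: unchanged
assert (3.0).conjugate() == 3.0                     # floating point: unchanged
assert complex(2, 3).conjugate() == complex(2, -3)  # complex: imaginary part negated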
2 changes: 1 addition & 1 deletion c10/core/Scalar.h
@@ -87,7 +87,7 @@ class C10_API Scalar {
}

Scalar operator-() const;
-
+Scalar conj() const;
ScalarType type() const {
if (isComplex()) {
return ScalarType::ComplexDouble;
3 changes: 2 additions & 1 deletion test/test_autograd.py
@@ -4918,7 +4918,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
# and only run for floating point

# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition
-separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos'] # ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']
+separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'add',
+                          'sub'] # ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']

# NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly
# for non-holomorphic functions
12 changes: 6 additions & 6 deletions tools/autograd/derivatives.yaml
@@ -165,11 +165,11 @@
self: grad * -((-self * self + 1).rsqrt())

- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
-self: grad
-other: maybe_multiply(grad, alpha)
+self: handle_r_to_c(self.scalar_type(), grad)
+other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj()))

- name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
-self: grad
+self: handle_r_to_c(self.scalar_type(), grad)

- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta)
@@ -995,11 +995,11 @@
self: std_backward(result, grad, self, dim, unbiased, keepdim)

- name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
-self: grad
-other: -grad * alpha
+self: handle_r_to_c(self.scalar_type(), grad)
+other: handle_r_to_c(other.scalar_type(), -grad * alpha.conj())

- name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
-self: grad
+self: handle_r_to_c(self.scalar_type(), grad)

- name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
self: -grad * alpha
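
The alpha.conj() in the other-input formulas follows PyTorch's complex autograd convention of propagating dL/d(conj(z)): out = self + alpha * other (and self - alpha * other) is holomorphic in other, so the gradient reaching other is grad * conj(alpha) (respectively -grad * conj(alpha)), while handle_r_to_c presumably discards the imaginary part of the gradient when the corresponding input is real. A minimal Python sketch of the intended behavior, assuming a build where complex autograd for add is enabled (shapes, dtype, and alpha are illustrative, not taken from this commit):

import torch

alpha = 2 - 3j
x = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
y = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
out = torch.add(x, y, alpha=alpha)
grad_out = torch.randn(4, dtype=torch.cdouble)
gx, gy = torch.autograd.grad(out, (x, y), grad_outputs=grad_out)
assert torch.allclose(gx, grad_out)                      # d/d(self): grad unchanged
assert torch.allclose(gy, grad_out * alpha.conjugate())  # d/d(other): grad * conj(alpha)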
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -165,7 +165,7 @@
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
-'bmm', 'diagonal'
+'bmm', 'diagonal', 'nonzero'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -516,6 +516,7 @@ def method_tests():
('add', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
('add', (S, S, S), (3.14,), 'constant', (True,)),
('add', (), (3.14,), 'scalar_constant', (True,)),
+('add', (S, S, S), (3.14j,), 'complex_scalar_constant', (True,)),
('asinh', (S, S, S), NO_ARGS, ''),
('asinh', (), NO_ARGS, 'scalar'),
('atanh', torch.rand(S, S, S), NO_ARGS, ''),
@@ -530,6 +531,7 @@ def method_tests():
('sub', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
('sub', (S, S, S), (3.14,), 'constant', (True,)),
('sub', (), (3.14,), 'scalar_constant', (True,)),
+('sub', (S, S, S), (3.14j,), 'complex_scalar_constant', (True,)),
('__rsub__', (S, S, S), (3.14,), 'constant', (True, 'aten::rsub')),
('__rsub__', (), (3.14,), 'scalar_constant', (True, 'aten::rsub')),
('mul', (S, S, S), ((S, S, S),), '', (True,)),
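
Each 'complex_scalar_constant' entry is expanded by the test harness into gradient checks for the corresponding method. Roughly, the intent is along these lines (a sketch only: S stands in for the harness's small size constant, assumed to be 5 here, and gradcheck is assumed to accept complex inputs in this build):

import torch
from torch.autograd import gradcheck

S = 5  # stand-in for the harness's size constant
x = torch.randn(S, S, S, dtype=torch.cdouble, requires_grad=True)
assert gradcheck(lambda t: t.add(3.14j), (x,))
assert gradcheck(lambda t: t.sub(3.14j), (x,))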
