diff --git a/python_bindings/pytaco/pytensor/taco_tensor.py b/python_bindings/pytaco/pytensor/taco_tensor.py index e57ea1ff9..1045e7423 100644 --- a/python_bindings/pytaco/pytensor/taco_tensor.py +++ b/python_bindings/pytaco/pytensor/taco_tensor.py @@ -374,7 +374,10 @@ def __pow__(self, power, modulo=None): return tensor_pow(self, power, default_mode) def __abs__(self): - return tensor_abs(self, default_mode) + return tensor_abs(self, self.format) + + def __neg__(self): + return tensor_neg(self, self.format) def __array__(self): if not _cm.is_dense(self.format): @@ -1482,6 +1485,39 @@ def tensor_logical_not(t1, out_format, dtype=None): """ return _compute_unary_elt_eise_op(_cm.logical_not, t1, out_format, dtype) +def tensor_neg(t1, out_format, dtype=None): + """ + Negates every value in the tensor. + + The tensor class implements ``__neg__`` using this method. + + Parameters + ------------ + t1: tensor, array_like + input tensor or array_like object + + out_format: format, mode_format, optional + * If a :class:`format` is specified, the result tensor is stored in the format out_format. + * If a :class:`mode_format` is specified, the result the result tensor has a with all of the dimensions + stored in the :class:`mode_format` passed in. + + dtype: Datatype + The datatype of the output tensor. + + + Examples + ---------- + >>> import pytaco as pt + >>> pt.tensor_neg([1, -2, 0], out_format=pt.dense).toarray() + array([-1, 2, 0], dtype=int64) + + Returns + -------- + neg: tensor + The element wise negation of the input tensor. + + """ + return _compute_unary_elt_eise_op(_cm.neg, t1, out_format, dtype) def tensor_abs(t1, out_format, dtype=None): """ diff --git a/python_bindings/unit_tests.py b/python_bindings/unit_tests.py index cafecf1b4..9bf657521 100644 --- a/python_bindings/unit_tests.py +++ b/python_bindings/unit_tests.py @@ -251,6 +251,10 @@ def test_mod(self): t1[i, j] = pt.remainder(t[i, j], 2) self.assertEqual(t1, arr % 2) + def test_neg(self): + arr = np.arange(1, 5).reshape([2, 2]) + t = pt.from_array(arr) + self.assertEqual(-t, -arr) class testParsers(unittest.TestCase):