Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion python_bindings/pytaco/pytensor/taco_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ def __pow__(self, power, modulo=None):
return tensor_pow(self, power, default_mode)

def __abs__(self):
return tensor_abs(self, default_mode)
return tensor_abs(self, self.format)

def __neg__(self):
return tensor_neg(self, self.format)

def __array__(self):
if not _cm.is_dense(self.format):
Expand Down Expand Up @@ -1482,6 +1485,39 @@ def tensor_logical_not(t1, out_format, dtype=None):
"""
return _compute_unary_elt_eise_op(_cm.logical_not, t1, out_format, dtype)

def tensor_neg(t1, out_format, dtype=None):
"""
Negates every value in the tensor.

The tensor class implements ``__neg__`` using this method.

Parameters
------------
t1: tensor, array_like
input tensor or array_like object

out_format: format, mode_format, optional
* If a :class:`format` is specified, the result tensor is stored in the format out_format.
* If a :class:`mode_format` is specified, the result the result tensor has a with all of the dimensions
stored in the :class:`mode_format` passed in.

dtype: Datatype
The datatype of the output tensor.


Examples
----------
>>> import pytaco as pt
>>> pt.tensor_neg([1, -2, 0], out_format=pt.dense).toarray()
array([-1, 2, 0], dtype=int64)

Returns
--------
neg: tensor
The element wise negation of the input tensor.

"""
return _compute_unary_elt_eise_op(_cm.neg, t1, out_format, dtype)

def tensor_abs(t1, out_format, dtype=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions python_bindings/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def test_mod(self):
t1[i, j] = pt.remainder(t[i, j], 2)
self.assertEqual(t1, arr % 2)

def test_neg(self):
arr = np.arange(1, 5).reshape([2, 2])
t = pt.from_array(arr)
self.assertEqual(-t, -arr)

class testParsers(unittest.TestCase):

Expand Down