Skip to content

Commit

Permalink
setitem in-place operator tests (#4577)
Browse files Browse the repository at this point in the history
* tests and error

* rename to in-place

* add a note

* more comments

* more comments

* disable folded advanced setitem tests for now
  • Loading branch information
geohotstan committed May 14, 2024
1 parent 0fa57b8 commit 089eeec
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/imported/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,8 @@ def test_getitem_scalars(self):
r[zero]
numpy_testing_assert_equal_helper(r, r[...])

# TODO fancy setitem
'''
def test_setitem_scalars(self):
zero = Tensor(0, dtype=dtypes.int64)
Expand Down Expand Up @@ -1121,6 +1123,7 @@ def test_setitem_scalars(self):
# TODO: weird inaccuracy Max relative difference: 3.85322971e-08
# numpy_testing_assert_equal_helper(9.9, r)
np.testing.assert_allclose(9.9, r, rtol=1e-7)
'''

def test_basic_advanced_combined(self):
# From the NumPy indexing example
Expand Down Expand Up @@ -1518,7 +1521,10 @@ def test_everything_returns_views(self):
def test_broaderrors_indexing(self):
a = Tensor.zeros(5, 5)
self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
# TODO: fancy setitem
'''
self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0)
'''

# TODO out of bound getitem does not raise error
'''
Expand Down
35 changes: 35 additions & 0 deletions test/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,41 @@ def test_setitem_into_noncontiguous(self):
assert not t.lazydata.st.contiguous
with self.assertRaises(AssertionError): t[1] = 5

def test_setitem_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])

t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])

t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] *= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])

# NOTE: have to manually cast setitem target to least_upper_float for div
t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
t[1] /= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])

t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] **= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])

t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] ^= 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])

@unittest.expectedFailure
def test_setitem_consecutive_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
t = t.contiguous()
# TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])

# TODO: implement fancy setitem
@unittest.expectedFailure
def test_fancy_setitem(self):
Expand Down
2 changes: 2 additions & 0 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,8 @@ def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
raise NotImplementedError("Advanced indexing setitem is not currently supported")

assign_to = self.realize().__getitem__(indices)
# NOTE: contiguous to prevent const folding.
Expand Down

0 comments on commit 089eeec

Please sign in to comment.