Remove requires_grad check for replace and assign #3880

Closed · wants to merge 5 commits
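In short: the patch drops the `assert not x.requires_grad` guards from `Tensor.replace` and `Tensor.assign` and instead has `replace` propagate `requires_grad` from the source tensor. A minimal sketch of what this permits (illustrative, assuming the patch is applied; only the method names come from the diff):

# Illustrative only: pre-patch, the call below tripped `assert not x.requires_grad`.
from tinygrad import Tensor

w = Tensor.ones(3, requires_grad=True)
buf = Tensor.zeros(3)

buf.assign(w)             # with the patch this succeeds (so would buf.replace(w))
print(buf.requires_grad)  # True: replace() now copies requires_grad from the source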
test/test_tensor.py: 20 additions & 0 deletions
@@ -340,6 +340,26 @@ def test_repr_with_grad(self):
     c = (a + b).mean().backward()
     print(c)

+  def test_replace_and_assign_grads(self):
+    for func in [Tensor.replace, Tensor.assign]:
+      # x_init is the numpy array defined at module scope earlier in this file
+      x = Tensor(x_init, requires_grad=True)
+      y = Tensor.zeros_like(x, requires_grad=False)
+      # equivalent to y.replace(x) or y.assign(x)
+      y = func(y, x)
+      out = y.sum()
+      out.backward()
+
+      x_torch = torch.tensor(x_init, requires_grad=True)
+      y_torch = torch.zeros_like(x_torch, requires_grad=False)
+      y_torch.copy_(x_torch)
+      out_torch = y_torch.sum()
+      out_torch.backward()
+      # grad of sum() w.r.t. its input is all ones: in tinygrad it lands on y
+      # (replace leaves y._ctx alone), in torch on the copy_ source x_torch
+      np.testing.assert_allclose(y.grad.numpy(), x_torch.grad.numpy())
+      self.assertEqual(y.requires_grad, y_torch.requires_grad)
+      self.assertEqual(out.requires_grad, out_torch.requires_grad)


 @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
 class TestMoveTensor(unittest.TestCase):
   d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
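For context on the torch half of the new test: `copy_` is a differentiable in-place op, so copying from a tensor that requires grad makes the destination participate in autograd and routes the gradient back to the copy source. A standalone sketch of that reference behavior (standard torch, independent of this PR):

# Standard torch behavior the test compares against.
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.zeros_like(x)   # grad-free leaf
y.copy_(x)                # differentiable in-place copy
print(y.requires_grad)    # True: y now participates in autograd
y.sum().backward()
print(x.grad)             # tensor([1., 1., 1.]): the gradient flows to the copy source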
tinygrad/tensor.py: 1 addition & 2 deletions
@@ -140,9 +140,9 @@ def realize(self) -> Tensor:

   def replace(self, x:Tensor) -> Tensor:
     # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
-    assert not x.requires_grad and getattr(self, '_ctx', None) is None
     assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
     self.lazydata = x.lazydata
+    self.requires_grad = x.requires_grad
     return self

   def assign(self, x) -> Tensor:
@@ -159,7 +159,6 @@ def assign(self, x) -> Tensor:
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
     assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
-    assert not x.requires_grad # self requires_grad is okay?
     if not self.lazydata.is_realized(): return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
     return self
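One subtlety in the patched `replace` (my reading of the diff, not a claim from the PR): it copies `lazydata` and `requires_grad` but leaves the destination's `_ctx` untouched, so gradients from later ops accumulate on the replaced tensor itself rather than on the source. A hedged sketch:

# Sketch, assuming the patch: where the gradient lands after replace().
from tinygrad import Tensor

src = Tensor.ones(2, 2, requires_grad=True)
dst = Tensor.zeros(2, 2)
dst.replace(src)          # dst now shares src's lazydata and requires_grad

(dst * 3).sum().backward()
print(dst.grad.numpy())   # [[3. 3.] [3. 3.]]: grad accumulates on dst
print(src.grad)           # None: replace() does not link dst back into src's graph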