diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py
index 45b19928de841..6f90d790a2473 100644
--- a/test/test_tensor_variable.py
+++ b/test/test_tensor_variable.py
@@ -32,7 +32,6 @@ def test_symbolic_mean(self):
     ret = t.mean().item()
     assert ret == 1
 
-  @unittest.skip("symbolic var isn't supported")
   def test_symbolic_var(self):
     vv = Variable("a", 1, 10)
     vv.bind(2)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e25284105dc86..a21e1400fcf47 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -14,7 +14,7 @@
 from tinygrad.ops import LoadOps, ScheduleItem
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
-from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
+from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars
 
@@ -271,6 +271,8 @@ def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
   @staticmethod
   def from_node(y:Node, **kwargs) -> Tensor:
     if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+    if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
+    if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
     if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
     raise RuntimeError(f"unhandled Node {y}")
 
@@ -928,9 +930,9 @@ def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
     return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
   def var(self, axis=None, keepdim=False, correction=1):
-    assert all_int(self.shape), "does not support symbolic shape"
-    square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
+    squares = (self - self.mean(axis=axis, keepdim=True)).square()
+    n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
+    return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
   def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()
 
   def _softmax(self, axis):
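
For context, here is a minimal sketch of the kind of usage this change enables. The body of `test_symbolic_var` is not shown in the diff, so the tensor construction below is borrowed from `test_symbolic_mean`'s setup and is illustrative rather than the exact test:

```python
# Illustrative sketch, not the exact test body from this PR: it assumes a setup
# like test_symbolic_mean's (a ones tensor reshaped to a symbolic shape).
from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

vv = Variable("a", 1, 10)   # symbolic dimension with range [1, 10]
vv.bind(2)                  # bind it to a concrete value for execution

# reshape a concrete (2, 2) tensor to the symbolic shape (2, vv)
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)

print(t.mean().item())  # expect 1.0 -- already worked before this change
print(t.var().item())   # expect 0.0 -- previously blocked by the all_int(self.shape) assert
print(t.std().item())   # expect 0.0 -- std goes through var, so it benefits too
```

With a symbolic shape, the divisor `max(0, n-correction)` in `var` is a symbolic expression (e.g. `2*a - 1`), which `div` turns into a Tensor via `Tensor.from_node`; the new `SumNode` and `NumNode` branches are presumably what make that conversion possible for such expressions.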