Skip to content

Commit

Permalink
Add variance of symbolic shape
Browse files Browse the repository at this point in the history
  • Loading branch information
marcellofuschi committed May 6, 2024
1 parent 603d3a3 commit c25e381
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
1 change: 0 additions & 1 deletion test/test_tensor_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_symbolic_mean(self):
ret = t.mean().item()
assert ret == 1

@unittest.skip("symbolic var isn't supported")
def test_symbolic_var(self):
vv = Variable("a", 1, 10)
vv.bind(2)
Expand Down
10 changes: 6 additions & 4 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Device
from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import create_schedule_with_vars

Expand Down Expand Up @@ -271,6 +271,8 @@ def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
@staticmethod
def from_node(y:Node, **kwargs) -> Tensor:
  """Convert a symbolic shape Node into a Tensor by recursing on its structure.

  Handles MulNode (recurse on the factor, then scale), SumNode (recurse on the
  first addend, add the rest symbolically), NumNode (constant), and a bare
  Variable. Any other Node subtype is rejected.
  """
  if isinstance(y, MulNode):
    return Tensor.from_node(y.a, **kwargs) * y.b
  if isinstance(y, SumNode):
    # only the first node is lifted to a Tensor; the remaining nodes are
    # folded symbolically and added afterwards
    return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
  if isinstance(y, NumNode):
    return Tensor(y.b, **kwargs, requires_grad=False)
  if isinstance(y, Variable):
    return Tensor(y, **kwargs, requires_grad=False)
  raise RuntimeError(f"unhandled Node {y}")

Expand Down Expand Up @@ -928,9 +930,9 @@ def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
squares = (self - self.mean(axis=axis, keepdim=True)).square()
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
def std(self, axis=None, keepdim=False, correction=1):
  """Standard deviation: the square root of `var` called with the same arguments."""
  variance = self.var(axis, keepdim, correction)
  return variance.sqrt()

def _softmax(self, axis):
Expand Down

0 comments on commit c25e381

Please sign in to comment.