<a href="https://colab.research.google.com/github/ratulb/llmlite.mojo/blob/main/Tensor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
class Tensor:
    """A scalar node in a dynamically built reverse-mode autodiff graph.

    Each arithmetic operation returns a new ``Tensor`` that remembers its
    operands (``_prev``) and a closure (``_backward``) that pushes the
    result's gradient back onto those operands.

    Attributes:
        value: The wrapped value (the result of the forward computation).
        requires_grad: Whether gradients are accumulated into ``grad``.
        grad: Gradient accumulated by ``backward()``; starts at ``0.0``.
        name: Optional label, empty by default.
    """

    def __init__(self, value, requires_grad=False):
        self.value = value
        self.requires_grad = requires_grad
        self.grad = 0.0
        # Leaves have nothing to propagate, hence the no-op default.
        self._backward = lambda: None
        # Operands this node was computed from; empty for leaves.
        self._prev = set()
        self.name = ""

    def __repr__(self):
        return (
            f"{type(self).__name__}(value={self.value!r}, grad={self.grad!r}, "
            f"requires_grad={self.requires_grad!r}, name={self.name!r})"
        )

    @staticmethod
    def _wrap(operand):
        """Return *operand* as a Tensor, wrapping plain values as constants."""
        if isinstance(operand, Tensor):
            return operand
        # Constants never need a gradient of their own.
        return Tensor(operand)

    def __add__(self, other):
        """Return ``self + other``; *other* may be a Tensor or a plain value."""
        other = Tensor._wrap(other)
        out = Tensor(self.value + other.value, self.requires_grad or other.requires_grad)

        def _backward():
            # d(a + b)/da = d(a + b)/db = 1, so the gradient passes through.
            # When both operands are the same object both branches run,
            # which correctly yields a factor of two.
            if self.requires_grad:
                self.grad += out.grad
            if other.requires_grad:
                other.grad += out.grad

        out._backward = _backward
        out._prev = {self, other}
        return out

    # Addition is commutative, so ``value + tensor`` can reuse __add__.
    __radd__ = __add__

    def __mul__(self, other):
        """Return ``self * other``; *other* may be a Tensor or a plain value."""
        other = Tensor._wrap(other)
        out = Tensor(self.value * other.value, self.requires_grad or other.requires_grad)

        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a.
            if self.requires_grad:
                self.grad += other.value * out.grad
            if other.requires_grad:
                other.grad += self.value * out.grad

        out._backward = _backward
        out._prev = {self, other}
        return out

    # Multiplication is commutative, so ``value * tensor`` can reuse __mul__.
    __rmul__ = __mul__

    def backward(self):
        """Backpropagate from this node, accumulating into ``grad`` fields.

        This node's gradient is seeded with ``1.0``. Gradients of leaf
        tensors accumulate across calls; gradients of interior nodes (those
        produced by an operation) are reset at the start of every call so a
        repeated ``backward()`` does not re-propagate stale values.
        """
        # Post-order DFS with an explicit stack: same ordering as the
        # recursive formulation, but not bounded by the recursion limit.
        order = []
        seen = {self}
        stack = [(self, iter(self._prev))]
        while stack:
            node, pending = stack[-1]
            for parent in pending:
                if parent not in seen:
                    seen.add(parent)
                    stack.append((parent, iter(parent._prev)))
                    break
            else:
                # Every operand of ``node`` has been emitted already.
                order.append(node)
                stack.pop()

        # Interior gradients are per-pass scratch values; clear leftovers.
        for node in order:
            if node._prev:
                node.grad = 0.0

        self.grad = 1.0  # starting gradient
        # Reverse post-order visits each node before any of its operands.
        for node in reversed(order):
            node._backward()


In [11]:
# Leaf inputs; each one collects its own gradient.
A, B, D = (Tensor(v, requires_grad=True) for v in (2.0, 3.0, 4.0))

# Forward pass, one operation per node.
# Overall expression: G = (A + B + D + A) * A
C = A + B
E = C + D
F = E + A
G = F * A

# Label every node with its variable name so the report is readable.
nodes = (A, B, C, D, E, F, G)
for t, label in zip(nodes, "ABCDEFG"):
    t.name = label

# Backpropagate from the output node.
G.backward()

# Report value and gradient for every node.
for t in nodes:
    print(f"{t.name}: value={t.value}, grad={t.grad}, requires_grad={t.requires_grad}")


A: value=2.0, grad=15.0, requires_grad=True
B: value=3.0, grad=2.0, requires_grad=True
C: value=5.0, grad=2.0, requires_grad=True
D: value=4.0, grad=2.0, requires_grad=True
E: value=9.0, grad=2.0, requires_grad=True
F: value=11.0, grad=2.0, requires_grad=True
G: value=22.0, grad=1.0, requires_grad=True
