The lessons learned in chainRule_2.ipynb are implemented here.

In [1]:
class Value:
    """A scalar that records how it was computed, enabling reverse-mode autodiff.

    Attributes:
        data: the scalar value held by this node.
        grad: d(output)/d(self), filled in when ``backward()`` is called on
            the output node.
        label: a display name, used only by ``__repr__``.
        _prev: the set of Values this node was computed from.
        _op: the operator symbol that produced this node ('' for leaves).
        _backward: closure that pushes this node's ``grad`` into the grads
            of its inputs; a no-op for leaf nodes.

    Gradients are accumulated with ``+=`` so that a node used in several
    places (or as both operands of one op) receives the sum of all its
    contributions. Consequently, calling ``backward()`` again on the same
    graph without first resetting every ``grad`` to 0.0 gives wrong,
    accumulated results.
    """

    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0  # starts at zero; backward() accumulates into it
        self._backward = lambda: None  # leaves have no inputs to propagate to
        self._prev = set(_children)
        self._op = _op
        self.label = label

    def __repr__(self):
        return self.label + "[" + str(self.data) + "] " + str(self.grad)

    def __add__(self, other):
        """Return a new node ``self + other``; plain numbers are wrapped in Value."""
        other = other if isinstance(other, Value) else Value(other)
        out = Value(
            _op="+",
            _children=(self, other),
            data=self.data + other.data,
            )

        def _backward():
            # d(self + other)/d(self) = d(self + other)/d(other) = 1.
            # `+=` so contributions from every path into a node are summed.
            self.grad += 1.0 * out.grad
            other.grad += 1.0 * out.grad

        # The closure must live on the *result* node: backward() visits `out`
        # once its gradient is complete and pushes it down to the inputs.
        # Storing it on `self` would be overwritten whenever `self` is the
        # left operand of another operation, silently dropping a path.
        out._backward = _backward
        return out

    def __radd__(self, other):
        """Support ``number + Value`` (and ``sum(values)``); addition commutes."""
        return self + other

    def __mul__(self, other):
        """Return a new node ``self * other``; plain numbers are wrapped in Value."""
        other = other if isinstance(other, Value) else Value(other)
        out = Value(
            _op="*",
            _children=(self, other),
            data=self.data * other.data,
            )

        def _backward():
            # Product rule: d(self*other)/d(self) = other, and vice versa.
            # `+=` (not `=`) so that e.g. `b * b` gets 2*b, and so earlier
            # contributions to a shared node are not overwritten.
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward  # see __add__ for why this is on `out`
        return out

    def __rmul__(self, other):
        """Support ``number * Value``; multiplication commutes."""
        return self * other

    def backward(self):
        """Back-propagate from this node through the whole graph beneath it.

        Sets ``self.grad`` to 1.0 (d self / d self) and then calls every
        node's ``_backward`` in reverse topological order, so each node's
        gradient is fully accumulated before it is pushed to its inputs.
        """
        topo = []
        visited = set()

        def build_topo(v):
            # Post-order DFS: a node is appended only after all its inputs.
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)

        build_topo(self)
        # The seed gradient can only be created implicitly for a scalar
        # output (e.g. the loss); without it the chain rule has nowhere to
        # start, and neither can gradient descent.
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()

In [2]:
# Graph: z = (a + b) + b. b reaches z along two paths, so its gradient is
# the sum of both contributions: dz/da = 1, dz/db = 1 + 1 = 2, dz/dc = 1.
a = Value(3.00, label="a")        # leaf
b = Value(data=-5.00, label="b")  # leaf
print(
"""
    a         b -----
    |         |     |
    ---> + <---     |
         c          |
         |          |
         ---> + <---
              |
              z
"""
)
c = a + b
c.label = "c"
z = c + b
z.label = "z"
z.backward()  # seeds z.grad = 1.0 and propagates down to c, a and b
for x in (a, b, c, z):
    print(x)


    a         b -----
    |         |     |
    ---> + <---     |
         c          |
         |          |
         ---> + <---
              |
              z

a[3.0] 1.0
b[-5.0] 2.0
c[-2.0] 1.0
z[-7.0] 1.0


In [3]:
# Graph: z = (a + b) * b. By the product rule dz/dc = b = -5, and b gets
# both the direct term (c = -2) and the path through c (-5): dz/db = -7.
a = Value(3.00, label="a")        # leaf
b = Value(data=-5.00, label="b")  # leaf
print(
"""
    a         b -----
    |         |     |
    ---> + <---     |
         c          |
         |          |
         ---> * <---
              |
              z
"""
)
c = a + b
c.label = "c"
z = c * b
z.label = "z"
z.backward()  # seeds z.grad = 1.0 and propagates down to c, a and b
for x in (a, b, c, z):
    print(x)


    a         b -----
    |         |     |
    ---> + <---     |
         c          |
         |          |
         ---> * <---
              |
              z

a[3.0] -5.0
b[-5.0] -7.0
c[-2.0] -5.0
z[10.0] 1.0
