In [1]:
import math
# import random

In [42]:
class Value:
    
    def __init__(self, data, op=None, prev=None):
        self.data = data
        self.grad = 0.0
        self.op = op
        self.prev = prev
        
    def __repr__(self):
        # return f"Value({self.data, self.grad, self.op, self.prev})"
        return f"Value({self.data, self.grad})"
        
    # ---------------------------------------------------
    # operators
    # ---------------------------------------------------
    def __add__(self, y):
        return Value(self.data + y.data, op='+', prev=(self, y))

    def __mul__(self, y):
        return Value(self.data * y.data, op='*', prev=(self, y))
    
    def tanh(self):
        return Value(math.tanh(self.data), op='tanh', prev=(self,))

    # ---------------------------------------------------
    # backward
    # ---------------------------------------------------
    def backward(self):
        def populate_gradient(out, op, prev):
            match op:
                case '+':
                    x, y = prev
                    x.grad += out.grad
                    y.grad += out.grad
                case '*':
                    x, y = prev
                    x.grad += y.data * out.grad
                    y.grad += x.data * out.grad
                case 'tanh':
                    (x,) = prev
                    x.grad += (1 - out.data**2) * out.grad
                case _:
                    assert False, f'unsupported op: {op}'        

        def backward_recursive(node, parents):
            if node in parents:
                assert False, 'not a DAG' ## ???
            assert node is not None
            if node in visited:
                return
            else:
                visited.add(node)
            if node.prev:
                populate_gradient(node, node.op, node.prev)
                for nd in node.prev:
                    p2 = parents.copy()
                    p2.add(node)
                    backward_recursive(nd, p2)
            
        self.grad = 1.0
        visited = set()
        backward_recursive(self, set())

In [31]:
# Demo 1: gradient through a single tanh node.
# NOTE(review): this cell's execution count (31) is lower than the class
# cell's (42), so the cells were not run top to bottom -- re-run from a fresh
# kernel to confirm the outputs below match the current class definition.
# Leaf node: no op/prev recorded, grad starts at 0.0.
u = Value(2)
u

Value((2, 0.0))

In [32]:
# L records op='tanh' and prev=(u,), which backward() follows later.
L = u.tanh()
L

Value((0.9640275800758169, 0.0))

In [33]:
# backward() sets the grad of the node it is called on to 1.0 itself, so this
# manual seed only affects what the display below shows.
L.grad = 1.0
L

Value((0.9640275800758169, 1.0))

In [34]:
L.prev

(Value((2, 0.0)),)

In [35]:
# Adds (1 - L.data**2) * L.grad to u.grad, i.e. d tanh(u)/du at u = 2.
L.backward()

In [36]:
u

Value((2, 0.07065082485316443))

In [37]:
# Demo 2: gradient through addition.
v = Value(3.0)
v

Value((3.0, 0.0))

In [53]:
# Rebinds L: the tanh node from the previous demo is no longer referenced here.
L = u + v
L

Value((5.0, 0.0))

In [55]:
# backward() accumulates with +=, so clear the gradient left on u by the
# earlier tanh run (and v for symmetry) before propagating again.
u.grad = 0
v.grad = 0

In [56]:
# For '+', each operand receives the output's gradient unchanged.
L.backward()

In [57]:
L

Value((5.0, 1.0))

In [58]:
v

Value((3.0, 1.0))

In [59]:
u

Value((2, 1.0))

In [60]:
# Demo 3: gradient through a product followed by a sum.
w = Value(5.1)
w

Value((5.1, 0.0))

In [61]:
# Builds two nodes: an intermediate (u * v) and the final sum with w.
L = u * v + w
L

Value((11.1, 0.0))

In [62]:
# Reset gradients accumulated by the previous backward() call.
u.grad = 0
v.grad = 0
w.grad = 0

In [63]:
L.backward()

In [64]:
# Expected: dL/du = v.data = 3.0, dL/dv = u.data = 2, dL/dw = 1.0.
u, v, w

(Value((2, 3.0)), Value((3.0, 2.0)), Value((5.1, 1.0)))