In [1]:
import math
import random

In [12]:
class Value:
    
    def __init__(self, data, op=None, prev=None):
        self.data = data
        self.grad = 0.0
        self.op = op
        self.prev = prev
        
    def __repr__(self):
        # return f"Value({self.data, self.grad, self.op, self.prev})"
        return f"Value({self.data, self.grad})"
        
    # ---------------------------------------------------
    # operators
    # ---------------------------------------------------
    def __add__(self, y):
        return Value(self.data + y.data, op='+', prev=(self, y))

    def __mul__(self, y):
        return Value(self.data * y.data, op='*', prev=(self, y))
    
    def tanh(self):
        return Value(math.tanh(self.data), op='tanh', prev=(self,))

    # ---------------------------------------------------
    # backward
    # ---------------------------------------------------
    def backward(self):
        def populate_gradient(out, op, prev):
            match op:
                case '+':
                    x, y = prev
                    x.grad += out.grad
                    y.grad += out.grad
                case '*':
                    x, y = prev
                    x.grad += y.data * out.grad
                    y.grad += x.data * out.grad
                case 'tanh':
                    (x,) = prev
                    x.grad += (1 - out.data**2) * out.grad
                case _:
                    assert False, f'unsupported op: {op}'        

        def backward_recursive(node):
            assert node is not None
            if node in visited:
                assert False, 'not a DAG'
            else:
                visited.add(node)
            if node.prev:
                populate_gradient(node, node.op, node.prev)
                for nd in node.prev:
                    backward_recursive(nd)
            
        self.grad = 1.0
        visited = set()
        backward_recursive(self)

In [13]:
# Leaf node: constructed without op/prev, so it has no inputs and grad starts at 0.0.
u = Value(2)
u

Value((2, 0.0))

In [14]:
# One-op graph: L = tanh(u). The new node records op='tanh' and prev=(u,).
L = u.tanh()
L

Value((0.9640275800758169, 0.0))

In [15]:
# Seed the output gradient by hand (dL/dL = 1) to inspect it.
# backward() also sets the root's grad to 1.0 itself, so this is not required.
L.grad = 1.0
L

Value((0.9640275800758169, 1.0))

In [16]:
# Inputs that produced L: a 1-tuple containing u.
L.prev

(Value((2, 0.0)),)

In [17]:
# Propagate gradients from L back to its inputs (accumulates into u.grad).
L.backward()

In [18]:
# u.grad is now (1 - tanh(2)**2) * L.grad, with L.grad == 1.0.
u

Value((2, 0.07065082485316443))

In [19]:
# Second leaf node, used as the multiplier in the expression built below.
v = Value(3.0)
v

Value((3.0, 0.0))

In [20]:
# Third leaf node, used as the additive term in the expression built below.
w = Value(5.1)
w

Value((5.1, 0.0))

In [21]:
# Rebind L to a new two-op graph: (u * v) + w.
# NOTE: u.grad still holds the value left by the earlier tanh backward pass, and
# gradients accumulate with +=, so reset u.grad to 0.0 before calling
# L.backward() on this graph if a clean dL/du is wanted.
L = u * v + w
L

Value((11.1, 0.0))