In [1]:
import math
import random

In [222]:
class Value:
    """A scalar node in a computation graph with reverse-mode autodiff.

    Attributes:
        data: the scalar held by this node.
        grad: derivative of the root (the node ``backward()`` was called on)
            with respect to this node. Starts at 0.0 and is accumulated with
            ``+=``, so reset it manually before reusing a node in a new pass.
        prev: ``None`` for a leaf; otherwise a tuple ``(op, *operands)`` where
            ``op`` is ``'+'``, ``'*'`` or ``'tanh'`` and ``operands`` are the
            ``Value`` nodes this one was computed from.
    """

    def __init__(self, data, prev=None):
        self.data = data
        self.grad = 0.0
        self.prev = prev

    def __repr__(self):
        # Shows (data, grad, prev). prev embeds the operands' reprs, so the
        # whole upstream graph is printed.
        return f"Value({self.data, self.grad, self.prev})"

    # ---------------------------------------------------
    # operators
    # ---------------------------------------------------
    @staticmethod
    def _as_value(obj):
        """Return obj unchanged if it is a Value, else wrap it as a leaf Value."""
        return obj if isinstance(obj, Value) else Value(obj)

    def __add__(self, y):
        """Return a new node for ``self + y``; y may be a Value or a plain number."""
        y = self._as_value(y)
        return Value(self.data + y.data, prev=('+', self, y))

    def __radd__(self, y):
        """Support ``number + Value`` (addition is commutative)."""
        return self + y

    def __mul__(self, y):
        """Return a new node for ``self * y``; y may be a Value or a plain number."""
        y = self._as_value(y)
        return Value(self.data * y.data, prev=('*', self, y))

    def __rmul__(self, y):
        """Support ``number * Value`` (multiplication is commutative)."""
        return self * y

    def tanh(self):
        """Return a new node holding ``tanh(self.data)``."""
        return Value(math.tanh(self.data), prev=('tanh', self))

    # ---------------------------------------------------
    # backward
    # ---------------------------------------------------
    def _topological_order(self):
        """Return every node reachable from self, each operand listed before its result.

        Shared subexpressions appear once. Uses an explicit stack instead of
        recursion so deep graphs do not hit Python's recursion limit.
        """
        order = []
        visited = set()
        # Entries are (node, operands_done). A node is appended to `order` only
        # when its second entry is popped, i.e. after all of its operands.
        stack = [(self, False)]
        while stack:
            node, operands_done = stack.pop()
            if operands_done:
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            if node.prev is not None:
                for operand in node.prev[1:]:
                    if operand not in visited:
                        stack.append((operand, False))
        return order

    def _propagate_to_operands(self):
        """Chain rule: add this node's contribution to its operands' grads.

        Raises:
            ValueError: if ``prev`` records an operation this class does not know.
        """
        if self.prev is None:
            return
        op, *operands = self.prev
        if op == '+':
            x, y = operands
            x.grad += self.grad
            y.grad += self.grad
        elif op == '*':
            # When x and y are the same node (x * x), both lines add to it,
            # giving the correct 2 * x.data * grad.
            x, y = operands
            x.grad += y.data * self.grad
            y.grad += x.data * self.grad
        elif op == 'tanh':
            (x,) = operands
            # d tanh(a)/da = 1 - tanh(a)**2, and self.data already holds tanh(a).
            x.grad += (1 - self.data**2) * self.grad
        else:
            # Explicit raise: an `assert` would be stripped under `python -O`.
            raise ValueError(f'unsupported op: {op}')

    def backward(self):
        """Backpropagate from this node, accumulating grads into every upstream node.

        Sets ``self.grad = 1.0``, then visits nodes in reverse topological
        order. Each node's grad is therefore complete (every consumer has
        contributed) before it is passed on, which is what makes shared
        subexpressions such as ``x * x`` correct. Existing grads are added
        to, not reset.
        """
        order = self._topological_order()
        self.grad = 1.0
        for node in reversed(order):
            node._propagate_to_operands()

In [216]:
# NOTE(review): execution counts are out of order (the class cell is In [222],
# cells 205-207 come after 216-221 in the file), so these outputs reflect
# hidden kernel state. Restart & Run All to confirm them.
# Leaf node for a one-op example: L = tanh(u).
u = Value(2)
u

Value((2, 0.0, None))

In [217]:
L = u.tanh()
L

Value((0.9640275800758169, 0.0, ('tanh', Value((2, 0.0, None)))))

In [218]:
# Not required before backward(): backward() sets self.grad = 1.0 on the
# root itself. This assignment only changes the repr shown below.
L.grad = 1.0
L

Value((0.9640275800758169, 1.0, ('tanh', Value((2, 0.0, None)))))

In [219]:
# The recorded (op, *operands) tuple that backward() uses to route gradients.
L.prev

('tanh', Value((2, 0.0, None)))

In [220]:
# Reverse pass: adds (1 - tanh(u)**2) * L.grad to u.grad.
L.backward()

In [221]:
# u.grad now holds 1 - tanh(2)**2 (about 0.0707).
u

Value((2, 0.07065082485316443, None))

In [205]:
# Second example: L = u * v + w.
v = Value(3.0)
v

Value((3.0, 0.0, None))

In [206]:
w = Value(5.1)
w

Value((5.1, 0.0, None))

In [207]:
# NOTE(review): this reuses `u`, which still carries grad of about 0.0707 from
# the tanh example (see the output below), and rebinds `L` to a new graph.
# backward() accumulates with `+=`, so a later L.backward() would add onto
# that stale grad. Create a fresh leaf (or set u.grad = 0.0) before reusing it.
L = u * v + w
L

Value((11.1, 0.0, ('+', Value((6.0, 0.0, ('*', Value((2, 0.07065082485316443, None)), Value((3.0, 0.0, None))))), Value((5.1, 0.0, None)))))