In [1]:
import math

In [38]:
class Value:
    
    def __init__(self, data, op=None, prev=None):
        self.data = data
        self.grad = 0.0
        self.op = op
        self.prev = prev
        
    def __repr__(self):
        # return f"Value({self.data, self.grad, self.op, self.prev})"
        return f"Value({self.data, self.grad})"
        
    # ---------------------------------------------------
    # operators
    # ---------------------------------------------------
    def __add__(self, y):
        return Value(self.data + y.data, op='+', prev=(self, y))

    def __mul__(self, y):
        return Value(self.data * y.data, op='*', prev=(self, y))
    
    def tanh(self):
        return Value(math.tanh(self.data), op='tanh', prev=(self,))

    # ---------------------------------------------------
    # backward
    # ---------------------------------------------------
    def backward(self):
        def populate_gradients(out, op, prev):
            if prev is None:
                return
            match op:
                case '+':
                    x, y = prev
                    x.grad += out.grad
                    y.grad += out.grad
                case '*':
                    x, y = prev
                    x.grad += y.data * out.grad
                    y.grad += x.data * out.grad
                case 'tanh':
                    (x,) = prev
                    x.grad += (1 - out.data**2) * out.grad
                case _:
                    assert False, f'unsupported op: {op}'        

        def build_topo(v):
            if v is None:
                return
            if v not in visited:
                visited.add(v)
                if v.prev:
                    for child in set(v.prev):
                        build_topo(child)
                topo.append(v)

        topo = []
        visited = set()
        build_topo(self)

        self.grad = 1.0
        for node in reversed(topo):
            populate_gradients(node, node.op, node.prev)

In [39]:
# NOTE(review): execution counts are out of order (In [39]-[47] appear before
# In [12]-[27]); Restart Kernel -> Run All before sharing so counts match the page.
# Leaf values for the demos below.
u = Value(2)
u

Value((2, 0.0))

In [40]:
v = Value(3.0)
v

Value((3.0, 0.0))

In [41]:
# NOTE(review): w is not used until In [23], which recreates it identically.
w = Value(5.1)
w

Value((5.1, 0.0))

In [42]:
# Forward pass: L = tanh(u).
L = u.tanh()
L

Value((0.9640275800758169, 0.0))

In [43]:
# NOTE(review): redundant - L.backward() sets the root's grad to 1.0 itself.
L.grad = 1.0
L

Value((0.9640275800758169, 1.0))

In [44]:
# The tanh node records its single operand u in .prev.
L.prev

(Value((2, 0.0)),)

In [45]:
# backward() adds into leaf grads with +=, so clear u first.
u.grad = 0

In [46]:
L.backward()

In [47]:
# Expected: d tanh(u)/du = 1 - tanh(u)**2 = 1 - 0.96402758**2 ~= 0.0706508.
u

Value((2, 0.07065082485316443))

In [12]:
# Gradient of a sum: d(u+v)/du = d(u+v)/dv = 1.
L = u + v

In [13]:
# Clear grads left by earlier backward() calls, which only add (+=) into them.
u.grad = 0
v.grad = 0

In [14]:
L.backward()

In [15]:
L

Value((5.0, 1.0))

In [16]:
u, v

(Value((2, 1.0)), Value((3.0, 1.0)))

In [17]:
# u is used twice as an operand: both contributions must accumulate, so du = 2.
L = u + u
L

Value((4, 0.0))

In [18]:
# Reset leaf grads; v is not in this graph, so In [21] shows it stays at 0.
u.grad = 0
v.grad = 0

In [19]:
L.backward()

In [20]:
L

Value((4, 1.0))

In [21]:
v

Value((3.0, 0))

In [22]:
u

Value((2, 2.0))

In [23]:
# NOTE(review): w was already created identically in In [41]; this redefinition
# keeps the demo self-contained.
w = Value(5.1)
w

Value((5.1, 0.0))

In [24]:
# L = u*v + w builds a two-level graph: a '*' node feeding a '+' node.
L = u * v + w
L

Value((11.1, 0.0))

In [25]:
# backward() only adds into leaf grads, so clear all three first.
u.grad = 0
v.grad = 0
w.grad = 0

In [26]:
L.backward()

In [27]:
# Expected: dL/du = v.data = 3.0, dL/dv = u.data = 2, dL/dw = 1.
u, v, w

(Value((2, 3.0)), Value((3.0, 2.0)), Value((5.1, 1.0)))