In [1]:
import numpy as np

Let's start with something simple.

Consider $y = x \cdot 2$. We already know that $dy/dx = 2$, but how would an autograd system arrive at that answer?

In [2]:
# Value class

class Value:
    """A scalar node in a computation graph with reverse-mode autodiff.

    Only multiplication is supported. Every operation returns a new
    ``Value`` that remembers the operation name and its operand nodes, so
    that ``backward()`` can later walk the graph and apply the chain rule.
    """

    def __init__(self, data):
        self.data = data    # the wrapped value
        self.grad = 0       # d(root)/d(self), accumulated by backward()
        self._op = None     # name of the op that produced this node; None for leaves
        self._inputs = []   # operand nodes this node was computed from

    def __mul__(self, other):
        """Return ``self * other`` as a new node; plain numbers are wrapped."""
        if not isinstance(other, Value):
            other = Value(other)
        result = Value(self.data * other.data)
        result._op = 'mul'
        result._inputs = [self, other]
        return result

    def __rmul__(self, other):
        """Support ``number * Value`` by wrapping the left-hand operand."""
        # Python only falls back to __rmul__ when the left operand cannot
        # handle the multiplication itself, so `other` is not a Value here.
        return Value(other) * self

    def _topo_order(self):
        """Return the nodes reachable from ``self``, inputs before outputs.

        Uses an explicit stack instead of recursion so that long chains of
        operations do not hit Python's recursion limit.
        """
        order = []
        seen = {self}
        stack = [(self, iter(self._inputs))]
        while stack:
            node, pending_inputs = stack[-1]
            for child in pending_inputs:
                if child not in seen:
                    seen.add(child)
                    stack.append((child, iter(child._inputs)))
                    break
            else:
                # All inputs of `node` have been emitted; emit the node itself.
                stack.pop()
                order.append(node)
        return order

    def backward(self):
        """Backpropagate from this node, filling in ``grad`` on its ancestors.

        This node's gradient is set to 1. Gradients of all other nodes are
        *added to* (``+=``), so a node used several times receives the sum of
        the contributions from every path. Existing ``grad`` values are not
        cleared first.
        """
        self.grad = 1

        # Visit outputs before inputs so that v.grad is complete when it is used.
        for v in reversed(self._topo_order()):
            if v._op == 'mul':
                a, b = v._inputs
                # Chain rule: upstream gradient times the local gradient,
                # where d(a*b)/da = b and d(a*b)/db = a.
                a.grad += v.grad * b.data
                b.grad += v.grad * a.data

In [3]:
# Build a two-step graph: y = x * 2, then z = y * 5 (so z = 10 * x).
x = Value(3)
y = x * 2
z = y * 5

In [4]:
# Backpropagate from z; this fills in .grad on the nodes z was computed from.
z.backward()

In [5]:
# dz/dy: z = y * 5, so the gradient with respect to y is 5.
print(y.grad)

5


In [6]:
# dz/dx = dz/dy * dy/dx = 5 * 2 = 10 (chain rule through y).
print(x.grad)

10


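Next, let's add support for addition. This redefines `Value` with an extra `__add__` method; in the backward pass, the local gradient of a sum with respect to each operand is simply 1.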

In [18]:
# Value class

class Value:
    """A scalar node in a computation graph with reverse-mode autodiff.

    Supports addition and multiplication. Every operation returns a new
    ``Value`` that remembers the operation name and its operand nodes, so
    that ``backward()`` can later walk the graph and apply the chain rule.
    """

    def __init__(self, data, label=''):
        self.data = data      # the wrapped value
        self.grad = 0         # d(root)/d(self), accumulated by backward()
        self.label = label    # display label; '+' or '*' on results of operations
        self._op = None       # name of the op that produced this node; None for leaves
        self._prev = []       # operand nodes this node was computed from

    def __add__(self, other):
        """Return ``self + other`` as a new node; plain numbers are wrapped."""
        if not isinstance(other, Value):
            other = Value(other)
        result = Value(self.data + other.data, label='+')
        result._op = 'add'
        result._prev = [self, other]
        return result

    def __radd__(self, other):
        """Support ``number + Value`` by wrapping the left-hand operand."""
        # Python only falls back to __radd__ when the left operand cannot
        # handle the addition itself, so `other` is not a Value here.
        return Value(other) + self

    def __mul__(self, other):
        """Return ``self * other`` as a new node; plain numbers are wrapped."""
        if not isinstance(other, Value):
            other = Value(other)
        result = Value(self.data * other.data, label='*')
        result._op = 'mul'
        result._prev = [self, other]
        return result

    def __rmul__(self, other):
        """Support ``number * Value`` by wrapping the left-hand operand."""
        return Value(other) * self

    def _topo_order(self):
        """Return the nodes reachable from ``self``, inputs before outputs.

        Uses an explicit stack instead of recursion so that long chains of
        operations do not hit Python's recursion limit.
        """
        order = []
        seen = {self}
        stack = [(self, iter(self._prev))]
        while stack:
            node, pending_inputs = stack[-1]
            for child in pending_inputs:
                if child not in seen:
                    seen.add(child)
                    stack.append((child, iter(child._prev)))
                    break
            else:
                # All inputs of `node` have been emitted; emit the node itself.
                stack.pop()
                order.append(node)
        return order

    def backward(self):
        """Backpropagate from this node, filling in ``grad`` on its ancestors.

        This node's gradient is set to 1. Gradients of all other nodes are
        *added to* (``+=``), so a node used several times receives the sum of
        the contributions from every path. Existing ``grad`` values are not
        cleared first.
        """
        self.grad = 1

        # Visit outputs before inputs so that v.grad is complete when it is used.
        for v in reversed(self._topo_order()):
            if v._op == 'mul':
                a, b = v._prev
                # v.grad is the upstream gradient; the other operand's data is
                # the local gradient, since d(a*b)/da = b and d(a*b)/db = a.
                a.grad += v.grad * b.data
                b.grad += v.grad * a.data
            elif v._op == 'add':
                a, b = v._prev
                # v.grad is the upstream gradient; the local gradient of a sum
                # with respect to either operand is 1.
                a.grad += v.grad * 1
                b.grad += v.grad * 1

In [19]:
from graphviz import Digraph

def trace(root):
    """Collect the computation graph that ends at ``root``.

    Follows ``_prev`` links depth-first and returns ``(nodes, edges)``:
    a set of every node reached and a set of ``(input, output)`` pairs, one
    for each operand feeding a node.
    """
    visited_nodes = set()
    visited_edges = set()

    def walk(node):
        # Each node is expanded only once, even if several nodes share it.
        if node in visited_nodes:
            return
        visited_nodes.add(node)
        for operand in node._prev:
            visited_edges.add((operand, node))
            walk(operand)

    walk(root)
    return visited_nodes, visited_edges

def draw_dot(root):
    """Render the computation graph ending at ``root`` as a left-to-right Digraph.

    Each value becomes a record-shaped node showing its data and grad. A value
    produced by an operation also gets a small operation node that points at
    it, and the operands are wired into that operation node.
    """
    graph = Digraph(format='svg', graph_attr={'rankdir': 'LR'})
    graph_nodes, graph_edges = trace(root)

    for value in graph_nodes:
        value_id = str(id(value))
        graph.node(
            name=value_id,
            label=f"{{ data: {value.data:.4f} | grad: {value.grad:.4f} }}",
            shape='record',
        )
        if not value._op:
            continue  # leaf value: no operation node to draw
        op_id = value_id + value._op
        graph.node(name=op_id, label=value._op)
        graph.edge(op_id, value_id)

    for source, target in graph_edges:
        # Operands connect to the target's operation node, not to the value itself.
        graph.edge(str(id(source)), str(id(target)) + target._op)

    return graph

In [20]:
# `a` feeds both b and c, so during backward() its gradient is the sum of the
# contributions from the two paths: dd/da = 1 (via b) + 2 (via c) = 3.
a = Value(2)
b = a + 3  # b = 2 + 3 = 5
c = a * 2  # c = 2 * 2 = 4
d = b + c  # d = 5 + 4 = 9

In [21]:
def draw_ascii(root):
    """Print a plain-text listing of the computation graph ending at ``root``.

    Prints a header, then one line per node (id, data, grad), then one line
    per edge (input id -> output id, with the output node's label).
    """
    graph_nodes, graph_edges = trace(root)

    print("Computation Graph:")

    for value in graph_nodes:
        summary = f"Node {id(value)} [data={value.data:.4f}, grad={value.grad:.4f}]"
        print(summary)

    for source, target in graph_edges:
        arrow = f"Node {id(source)} -> Node {id(target)} (label: {target.label})"
        print(arrow)

In [22]:
# Inspect the graph before backpropagation: every grad is still 0.
draw_ascii(d)

Computation Graph:
Node 4364205600 [data=9.0000, grad=0.0000]
Node 4365091456 [data=4.0000, grad=0.0000]
Node 4364205936 [data=2.0000, grad=0.0000]
Node 4365093280 [data=2.0000, grad=0.0000]
Node 4365437872 [data=3.0000, grad=0.0000]
Node 4365608896 [data=5.0000, grad=0.0000]
Node 4365091456 -> Node 4364205600 (label: +)
Node 4365093280 -> Node 4365608896 (label: +)
Node 4364205936 -> Node 4365091456 (label: *)
Node 4365093280 -> Node 4365091456 (label: *)
Node 4365437872 -> Node 4365608896 (label: +)
Node 4365608896 -> Node 4364205600 (label: +)


In [23]:
# NOTE: seeding d.grad by hand is redundant; backward() sets the root's
# gradient to 1 itself before propagating.
d.grad = 1
d.backward()  # backward pass
# Show the same graph again, now with gradients filled in.
draw_ascii(d)

Computation Graph:
Node 4364205600 [data=9.0000, grad=1.0000]
Node 4365091456 [data=4.0000, grad=1.0000]
Node 4364205936 [data=2.0000, grad=2.0000]
Node 4365093280 [data=2.0000, grad=3.0000]
Node 4365437872 [data=3.0000, grad=1.0000]
Node 4365608896 [data=5.0000, grad=1.0000]
Node 4365091456 -> Node 4364205600 (label: +)
Node 4365093280 -> Node 4365608896 (label: +)
Node 4364205936 -> Node 4365091456 (label: *)
Node 4365093280 -> Node 4365091456 (label: *)
Node 4365437872 -> Node 4365608896 (label: +)
Node 4365608896 -> Node 4364205600 (label: +)
