In [None]:
from micrograd import Value
from graphviz import Digraph


def trace(root):
    """Collect every node and edge of the expression graph that ends at ``root``.

    Walks backwards from ``root`` through each node's ``_prev`` children.

    Args:
        root: output node; any hashable object exposing an iterable ``_prev``
            of child nodes (e.g. a micrograd ``Value``).

    Returns:
        ``(nodes, edges)`` where ``nodes`` is the set of all reachable nodes
        (including ``root``) and ``edges`` is the set of ``(child, parent)``
        pairs, ``child`` being an operand that fed into ``parent``.
    """
    nodes, edges = set(), set()
    # Explicit stack instead of recursion: a recursive walk hits Python's
    # recursion limit on long chains of operations (deep graphs).
    stack = [root]
    while stack:
        v = stack.pop()
        if v in nodes:
            continue  # already expanded via another path (shared subexpression)
        nodes.add(v)
        for child in v._prev:
            edges.add((child, v))
            stack.append(child)
    return nodes, edges

def draw_dot(root, format='svg', rankdir='LR'):
    """Build a graphviz ``Digraph`` of the computation graph ending at ``root``.

    Each node found by ``trace(root)`` becomes a ``record`` node labelled with
    its ``data`` and ``grad``. A node with a non-empty ``_op`` also gets a
    separate op node (labelled with the op string) that points into it, and
    every operand is wired into that op node.

    Args:
        root: output node whose ancestry is drawn.
        format: graphviz output format: png | svg | ...
        rankdir: layout direction: 'LR' (left to right) or 'TB' (top to bottom).

    Returns:
        The populated ``Digraph`` (returned so a notebook cell can display it).
    """
    assert rankdir in ('LR', 'TB')
    nodes, edges = trace(root)
    graph = Digraph(format=format, graph_attr={'rankdir': rankdir})

    def uid(value):
        # graphviz node names must be strings; object identity keeps them unique
        return str(id(value))

    for value in nodes:
        value_id = uid(value)
        graph.node(name=value_id, label="{ data %.4f grad %.4f }" % (value.data, value.grad), shape='record')
        if not value._op:
            continue  # leaf: no op node to attach
        # "<id><op>" cannot clash with the purely numeric value-node names
        op_id = value_id + value._op
        graph.node(name=op_id, label=value._op)
        graph.edge(op_id, value_id)

    # edges are (operand, result): connect each operand to the result's op node
    for operand, result in edges:
        graph.edge(uid(operand), uid(result) + result._op)

    return graph
 

In [None]:
# Smallest end-to-end example: build g = a*b + c, backpropagate from g,
# and check the gradient that reaches the leaf `a`.
a = Value(2.0)
b = Value(-3.0)
c = Value(10.0)
f = a * b   # f = a*b = 2.0 * -3.0 = -6.0
g = f + c   # g = (a*b) + c = -6.0 + 10.0 = 4.0
g.backward()
# g = a*b + c is linear in a with slope b, so dg/da = b = -3.0
assert g.data == 4.0
assert a.grad == -3.0
print(a.grad)
draw_dot(g)


In [None]:
# Larger example exercising +, *, **, relu, unary negation, subtraction and
# division, including plain Python numbers on either side of an operator
# (e.g. `c + 1`, `3 * d`, `10.0 / f`). Trailing comments track forward values.
a = Value(-4.0)
b = Value(2.0)
c = a + b                      # -4.0 + 2.0 = -2.0
d = a * b + b**3               # -8.0 + 8.0 = 0.0
# NOTE(review): `x += y` below presumably rebinds x to a new Value node for
# x + y rather than mutating x in place (assumes Value defines no __iadd__;
# confirm). The statement order therefore matters for the values.
c += c + 1                     # -2.0 + (-2.0 + 1) = -3.0
c += 1 + c + (-a)              # -3.0 + 1 + -3.0 + 4.0 = -1.0
d += d * 2 + (b + a).relu()    # 0.0 + 0.0 + relu(-2.0) -> 0.0
d += 3 * d + (b - a).relu()    # 0.0 + 0.0 + relu(6.0) -> 6.0
e = c - d                      # -1.0 - 6.0 = -7.0
f = e**2                       # 49.0
g = f / 2.0                    # 24.5
g += 10.0 / f                  # 24.5 + 0.2040816327 = 24.7040816327
print(f'{g.data:.4f}')  # prints 24.7041, the outcome of this forward pass
# backpropagate from g; draw_dot then shows each node's data and grad
g.backward()
draw_dot(g)
