In [144]:
class Scalar:
    """A scalar value that records how it was computed, for reverse-mode autodiff.

    Arithmetic (+, -, *, /, **) between Scalars and/or plain numbers builds a
    computation graph: each result stores its operand Scalars in ``children``
    and the operator symbol in ``op``. Calling ``backward()`` on the final
    result fills ``grad`` on every reachable node with d(result)/d(node).

    Plain-number operands are wrapped in a constant Scalar leaf, so every binary
    node always has exactly two children -- ``backward`` relies on that.
    """

    def __init__(self, data, label=None, op="", children=None):
        self.data = data  # numeric value of this node
        self.grad = 0.0  # gradient accumulated by backward()
        self.op = op  # operator that produced this node ("" for leaves)
        self.children = children if children is not None else []  # operand nodes
        self.label = label  # optional display name used by the graph helpers

    def __repr__(self):
        return f"Scalar({self.data}, gradient {self.grad})"

    @staticmethod
    def _as_scalar(value):
        """Return `value` unchanged if it is a Scalar, otherwise wrap it as a constant leaf."""
        return value if isinstance(value, Scalar) else Scalar(value)

    def __add__(self, other):
        other = Scalar._as_scalar(other)
        return Scalar(self.data + other.data, op="+", children=[self, other])

    def __radd__(self, other):
        return self + other

    def __mul__(self, other):
        other = Scalar._as_scalar(other)
        return Scalar(self.data * other.data, op="*", children=[self, other])

    def __rmul__(self, other):
        return self * other

    def __sub__(self, other):
        other = Scalar._as_scalar(other)
        return Scalar(self.data - other.data, op="-", children=[self, other])

    def __rsub__(self, other):
        return Scalar(other) - self

    def __truediv__(self, other):
        other = Scalar._as_scalar(other)
        return Scalar(self.data / other.data, op="/", children=[self, other])

    def __rtruediv__(self, other):
        return Scalar(other) / self

    def __pow__(self, other):
        other = Scalar._as_scalar(other)
        return Scalar(self.data ** other.data, op="**", children=[self, other])

    def __rpow__(self, other):
        return Scalar(other) ** self

    def backward(self):
        """Backpropagate from this node, accumulating d(self)/d(node) into each node's grad.

        This node's grad is set to 1.0; every other reachable node's grad is
        increased with ``+=``, so reset the graph's gradients before calling
        this again on the same graph.

        Raises:
            ZeroDivisionError: for a "/" node whose denominator is 0, or a "**"
                node with a zero base and an exponent below 1.
        """
        import math  # needed for d(a ** b)/db = a ** b * ln(a)

        # Iterative post-order DFS: a node is appended only after all of its
        # children, so reversed(topo_order) visits parents before children.
        # Iterative (not recursive) so long chains do not hit the recursion limit.
        topo_order = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, children_done = stack.pop()
            if children_done:
                topo_order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            # Reversed so children are explored left-to-right, like the recursive form.
            stack.extend((child, False) for child in reversed(node.children))

        # Seed: d(self)/d(self) = 1.
        self.grad = 1.0

        for node in reversed(topo_order):
            # addition: gradient flows through unchanged to every operand
            if node.op == "+":
                for child in node.children:
                    child.grad += node.grad
            # subtraction: +1 for the left operand, -1 for the right operand
            elif node.op == "-":
                left, right = node.children
                left.grad += node.grad
                right.grad -= node.grad
            # multiplication: each operand receives the other operand's value
            elif node.op == "*":
                left, right = node.children
                left.grad += right.data * node.grad
                right.grad += left.data * node.grad
            # division: d(a/b)/da = 1/b, d(a/b)/db = -a/b**2
            elif node.op == "/":
                numerator, denominator = node.children
                numerator.grad += (1 / denominator.data) * node.grad
                denominator.grad += -(numerator.data / denominator.data ** 2) * node.grad
            # power: d(a**b)/da = b * a**(b-1), d(a**b)/db = a**b * ln(a)
            elif node.op == "**":
                base, exponent = node.children
                base.grad += (exponent.data * (base.data ** (exponent.data - 1))) * node.grad
                # ln(a) is only real for a > 0; skipping otherwise keeps e.g. x**2
                # with negative x working (its exponent is usually a constant anyway).
                if base.data > 0:
                    exponent.grad += (node.data * math.log(base.data)) * node.grad

In [145]:
# create a graph with graphviz to visualize the computation graph
from graphviz import Digraph

In [146]:
def draw_graph(node):
    """Build a graphviz Digraph of the computation graph rooted at `node`.

    Every node becomes a 'record' box showing its label, data and grad. A node
    with an `op` also gets a 'diamond' op box: edges run operand -> op box ->
    result. A node without an op but with children gets its child edges directly.

    Each node's op box and child edges are emitted only once, even when the node
    is reachable along several paths (shared sub-expressions). This avoids
    duplicated edges and repeated re-walking of shared subgraphs.

    Returns:
        The graphviz Digraph (displayed inline when it is a cell's last value).
    """
    def record_label(n):
        return f"{{ {n.label} | data {n.data:.4f} | grad {n.grad:.4f} }}"

    dot = Digraph()
    dot.attr(rankdir='LR')  # Set the direction of the graph to left-to-right
    dot.node(str(id(node)), label=record_label(node), shape='record')

    expanded = set()  # ids of nodes whose op box and child edges were already emitted

    def expand(n):
        if id(n) in expanded:
            return
        expanded.add(id(n))
        n_id = str(id(n))
        # Children connect to the op box when there is one, else straight to the node.
        target_id = n_id
        if n.op:
            target_id = n_id + n.op
            dot.node(target_id, label=n.op, shape='diamond')
            dot.edge(target_id, n_id)
        for child in n.children:
            child_id = str(id(child))
            dot.node(child_id, label=record_label(child), shape='record')
            dot.edge(child_id, target_id)
            expand(child)

    expand(node)
    return dot

def restart_graph(node):
    """Reset `grad` to 0.0 on `node` and on every node reachable through `children`.

    Each node is visited once (tracked by identity), so shared sub-expressions
    are not re-walked (the recursive form is exponential on e.g. repeated
    `y = y * y`). The traversal is iterative, so deep graphs do not hit
    Python's recursion limit.
    """
    visited = set()
    stack = [node]
    while stack:
        current = stack.pop()
        if id(current) in visited:
            continue
        visited.add(id(current))
        current.grad = 0.0
        stack.extend(current.children)

In [166]:
# Leaf inputs of the example expression L = f + (a + b) * d
a = Scalar(2, label='a')
b = Scalar(3, label='b')
d = Scalar(2, label='d')
f = Scalar(4, label='f')

# Intermediate and output nodes, labelled for the graph views
c = a + b
c.label = 'c'
e = c * d
e.label = 'e'
L = f + e
L.label = 'L'

In [167]:
# Display the output node; only the forward pass has run, so its gradient is still the initial 0.0.
L

Scalar(14, gradient 0.0)

In [164]:
# Reset every gradient in the graph to 0.0 before backpropagating, since
# backward() accumulates into .grad. Uses the restart_graph helper defined in
# the earlier helper cell instead of redefining (and shadowing) it here.
restart_graph(L)

In [168]:
def print_graph(node, indent=0):
    """Print the expression tree under `node`, one line per node, children indented.

    Each line shows the node's label, data and grad (4 decimals) and its op.
    A node reachable along several paths is printed once per path.

    Args:
        node: the root to print.
        indent: nesting depth of `node`; each level adds two spaces.
    """
    pending = [(node, indent)]  # explicit stack of (node, depth), preorder
    while pending:
        current, depth = pending.pop()
        prefix = "  " * depth
        print(prefix + f"{current.label} | data {current.data:.4f} | grad {current.grad:.4f} | op {current.op}")
        # Push children in reverse so the leftmost child is printed first.
        pending.extend((child, depth + 1) for child in reversed(current.children))

# Print the full expression tree rooted at L (gradients are not yet backpropagated here).
print_graph(L)

L | data 14.0000 | grad 0.0000 | op +
  f | data 4.0000 | grad 0.0000 | op 
  e | data 10.0000 | grad 0.0000 | op *
    c | data 5.0000 | grad 0.0000 | op +
      a | data 2.0000 | grad 0.0000 | op 
      b | data 3.0000 | grad 0.0000 | op 
    d | data 2.0000 | grad 0.0000 | op 


In [150]:
# Reset gradients first: backward() accumulates into .grad, so without the
# reset re-running this cell would keep adding to the previous results.
restart_graph(L)
L.backward()
# Show every node's gradient dL/d(node) after backpropagation.
print_graph(L)