The book also introduces another approach to AD (automatic differentiation): forward-mode accumulation. This avoids the need to store the values and gradients for the whole graph. In this module, we will visualize the whole process, which will help you better understand the difference between forward mode and reverse mode.

In [2]:
import math

# Function: f(x, y) = x * y + sin(x)

def _forward_pass(x_val, y_val, x_dot, y_dot, trace=False):
    """Run one forward-mode sweep of f(x, y) = x * y + sin(x).

    Each node carries a (value, tangent) pair; the tangent is the derivative
    of that node with respect to whichever input was seeded with 1.0.

    Args:
        x_val: Value of the input x.
        y_val: Value of the input y.
        x_dot: Seed tangent for x (1.0 to differentiate w.r.t. x, else 0.0).
        y_dot: Seed tangent for y (1.0 to differentiate w.r.t. y, else 0.0).
        trace: When true, print the inputs and each intermediate node as it
            is computed.

    Returns:
        A ``(f_val, f_dot)`` tuple: the function value and its tangent for
        the chosen seed.
    """
    if trace:
        print(f"x: val={x_val}, dot={x_dot}")
        print(f"y: val={y_val}, dot={y_dot}")

    # u = x * y, tangent by the product rule.
    u_val = x_val * y_val
    u_dot = x_dot * y_val + x_val * y_dot
    if trace:
        print(f"u = x * y: val={u_val}, dot={u_dot}")

    # v = sin(x), tangent by the chain rule.
    v_val = math.sin(x_val)
    v_dot = math.cos(x_val) * x_dot
    if trace:
        print(f"v = sin(x): val={v_val}, dot={v_dot}")

    # f = u + v, tangents simply add.
    return u_val + v_val, u_dot + v_dot


def forward_mode_autodiff(x_val, y_val):
    """Print a forward-mode differentiation trace of f(x, y) = x * y + sin(x).

    Forward mode yields one directional derivative per sweep, so the function
    is evaluated twice: once seeded on x (with every intermediate printed)
    and once seeded on y (only the final derivative printed).

    Args:
        x_val: Point at which to evaluate x.
        y_val: Point at which to evaluate y.

    Returns:
        None. All results are written to standard output.
    """
    print("--- Forward Mode Automatic Differentiation ---")

    # Derivative w.r.t. x: seed x with 1.0, y with 0.0.
    f_val, df_dx = _forward_pass(x_val, y_val, 1.0, 0.0, trace=True)
    print(f"f = u + v: val={f_val}, df/dx={df_dx}")

    # Derivative w.r.t. y: swap the seeds and sweep again.
    _, df_dy = _forward_pass(x_val, y_val, 0.0, 1.0)
    print(f"df/dy={df_dy}")


def reverse_mode_autodiff(x_val, y_val):
    """Print a reverse-mode differentiation trace of f(x, y) = x * y + sin(x).

    A forward sweep computes every intermediate value, then a backward sweep
    pushes the adjoint of the output (seeded with 1.0) back to both inputs.

    Args:
        x_val: Point at which to evaluate x.
        y_val: Point at which to evaluate y.

    Returns:
        None. All results are written to standard output.
    """
    print("\n--- Reverse Mode Automatic Differentiation ---")

    # Forward sweep: evaluate the graph and keep the intermediates.
    product = x_val * y_val
    sine = math.sin(x_val)
    output = product + sine
    print(f"Forward pass: x={x_val}, y={y_val}, u={product}, v={sine}, f={output}")

    # Backward sweep: start from df/df = 1 and apply each local derivative.
    seed = 1.0
    grad_u = seed * 1.0  # f = u + v  ->  df/du = 1
    grad_v = seed * 1.0  # f = u + v  ->  df/dv = 1

    grad_y = grad_u * x_val                    # u = x * y   ->  du/dy = x
    grad_x_via_u = grad_u * y_val              # u = x * y   ->  du/dx = y
    grad_x_via_v = grad_v * math.cos(x_val)    # v = sin(x)  ->  dv/dx = cos(x)

    # x feeds two nodes, so its adjoint is the sum of both contributions.
    grad_x = grad_x_via_u + grad_x_via_v

    print(f"Backward pass: f_bar={seed}, u_bar={grad_u}, v_bar={grad_v}")
    print(f"x_bar from u={grad_x_via_u}, x_bar from v={grad_x_via_v}, total x_bar={grad_x}")
    print(f"y_bar={grad_y}")


def explain_differences():
    """Print a three-point comparison of forward- and reverse-mode AD.

    The summary covers computation order, memory usage, and efficiency.

    Returns:
        None. The summary is written to standard output.
    """
    summary = (
        "\n--- Differences Between Forward and Reverse Mode ---",
        "1. Computation Order:",
        "   - Forward mode: propagate derivatives from inputs → output.",
        "   - Reverse mode: propagate derivatives from output → inputs.",
        "2. Memory Usage:",
        "   - Forward mode: low memory, no need to store all intermediates.",
        "   - Reverse mode: higher memory, must store intermediates for backward pass.",
        "3. Efficiency:",
        "   - Forward mode: efficient when few inputs, many outputs.",
        "   - Reverse mode: efficient when many inputs, few outputs (e.g., scalar loss).",
    )
    for line in summary:
        print(line)


# Example run: differentiate f(x, y) = x * y + sin(x) at (x, y) = (2, 3)
# in both modes, then print the comparison summary.
x_val, y_val = 2.0, 3.0

forward_mode_autodiff(x_val, y_val)
reverse_mode_autodiff(x_val, y_val)
explain_differences()


--- Forward Mode Automatic Differentiation ---
x: val=2.0, dot=1.0
y: val=3.0, dot=0.0
u = x * y: val=6.0, dot=3.0
v = sin(x): val=0.9092974268256817, dot=-0.4161468365471424
f = u + v: val=6.909297426825682, df/dx=2.5838531634528574
df/dy=2.0

--- Reverse Mode Automatic Differentiation ---
Forward pass: x=2.0, y=3.0, u=6.0, v=0.9092974268256817, f=6.909297426825682
Backward pass: f_bar=1.0, u_bar=1.0, v_bar=1.0
x_bar from u=3.0, x_bar from v=-0.4161468365471424, total x_bar=2.5838531634528574
y_bar=2.0

--- Differences Between Forward and Reverse Mode ---
1. Computation Order:
   - Forward mode: propagate derivatives from inputs → output.
   - Reverse mode: propagate derivatives from output → inputs.
2. Memory Usage:
   - Forward mode: low memory, no need to store all intermediates.
   - Reverse mode: higher memory, must store intermediates for backward pass.
3. Efficiency:
   - Forward mode: efficient when few inputs, many outputs.
   - Reverse mode: efficient when many inputs, few o