In [240]:
import jax
import jax.numpy as np
import math

# Computation graph

(Example borrowed from Section 8.2, "Automatic Differentiation", of *Numerical Optimization* by Nocedal & Wright.)

Suppose we want to take the derivative of $x_9 = f(x) = \dfrac{x_1 x_2 \sin x_3 + e^{x_1 x_2}}{x_3}$ with respect to the three input variables $x_1, x_2, x_3$.

![image.png](attachment:6b691209-73e3-4983-a8c6-144b1e2d404f.png)

![image.png](attachment:4b5af021-edda-4941-8f1a-461de26de18f.png)

![image.png](attachment:bf357460-c641-4e58-840b-a6b2f8155274.png)

## Core of forward mode autograd: dual numbers

Each dual number stores two things:

- Its value $x_i$ in the computation graph
- The directional derivative of $x_i$ with respect to the three input variables, taken along some direction $p$: $(\nabla x_i)^T p$

In [228]:
class DualNumber:

    """Forward-mode autodiff number: a primal value paired with its directional derivative.

    Overrides Python's +, -, *, / (and unary -) so that any expression built from
    DualNumbers propagates derivatives alongside values via the sum, difference,
    product and quotient rules. A plain number (or array) on either side of an
    operator is treated as a constant, i.e. a dual number whose derivative is zero.

    Attributes:
        value: the primal value x_i at this node of the computation graph.
        dd_wrt_input: the directional derivative (grad x_i)^T p. It may be a scalar
            or an array (e.g. one entry per seed direction when several seeds are
            propagated at once).
    """

    def __init__(self, value, dd_wrt_input):
        self.value = value
        self.dd_wrt_input = dd_wrt_input

    def __repr__(self):
        return f"Value: {self.value} | DD wrt input: {self.dd_wrt_input}"

    @staticmethod
    def _lift(other):
        """Return `other` as a DualNumber; non-dual operands become constants (zero derivative)."""
        if isinstance(other, DualNumber):
            return other
        # A scalar 0 broadcasts against whatever shape dd_wrt_input has.
        return DualNumber(other, 0)

    def __add__(self, other):
        other = DualNumber._lift(other)
        return DualNumber(
            self.value + other.value,
            self.dd_wrt_input + other.dd_wrt_input  # sum rule
        )

    # Addition is commutative, so `c + x` can reuse `x + c`.
    __radd__ = __add__

    def __sub__(self, other):
        other = DualNumber._lift(other)
        return DualNumber(
            self.value - other.value,
            self.dd_wrt_input - other.dd_wrt_input  # difference rule
        )

    def __rsub__(self, other):
        # `other - self` where `other` is not a DualNumber.
        return DualNumber._lift(other) - self

    def __neg__(self):
        return DualNumber(-self.value, -self.dd_wrt_input)

    def __mul__(self, other):
        other = DualNumber._lift(other)
        return DualNumber(
            self.value * other.value,
            self.dd_wrt_input * other.value + self.value * other.dd_wrt_input  # product rule
        )

    # Multiplication is commutative, so `c * x` can reuse `x * c`.
    __rmul__ = __mul__

    def __truediv__(self, other):
        other = DualNumber._lift(other)
        return DualNumber(
            self.value / other.value,
            (self.dd_wrt_input * other.value - self.value * other.dd_wrt_input) / (other.value ** 2)  # quotient rule
        )

    def __rtruediv__(self, other):
        # `other / self` where `other` is not a DualNumber.
        return DualNumber._lift(other) / self

In [229]:
def exp(dual_number):
    """Exponential function lifted to dual numbers.

    Since d/dx e^x = e^x, the chain rule gives an outgoing directional derivative
    of e^value times the incoming one. The exponential is therefore evaluated
    once and reused as both the primal value and the slope.

    Args:
        dual_number: a DualNumber whose `.value` is accepted by `np.exp`.

    Returns:
        DualNumber(exp(value), exp(value) * dd_wrt_input).
    """
    exp_value = np.exp(dual_number.value)
    return DualNumber(exp_value, exp_value * dual_number.dd_wrt_input)

def sin(dual_number):
    """Sine lifted to dual numbers.

    d/dx sin(x) = cos(x), so by the chain rule the outgoing directional
    derivative is cos(value) times the incoming directional derivative.

    Args:
        dual_number: a DualNumber whose `.value` is accepted by `np.sin`/`np.cos`.

    Returns:
        DualNumber(sin(value), cos(value) * dd_wrt_input).
    """
    x = dual_number.value
    slope = np.cos(x)
    return DualNumber(np.sin(x), slope * dual_number.dd_wrt_input)

## Comparing results to ground truth

In [230]:
# I'm using three seed vectors (e_1, e_2, e_3) at once, so each node's
# dd_wrt_input holds its full gradient with respect to (x1, x2, x3).
x1 = DualNumber(1, np.array([1, 0, 0]))
x2 = DualNumber(2, np.array([0, 1, 0]))
x3 = DualNumber(np.pi/2, np.array([0, 0, 1]))

# Forward sweep through the computation graph, one node at a time.
x4 = x1 * x2
x5 = sin(x3)
x6 = exp(x4)
x7 = x4 * x5
x8 = x6 + x7
x9 = x8 / x3

# Ground truth: the same function written with plain jax.numpy ops and
# differentiated by JAX's own jax.grad, which is independent of the DualNumber code.
def x9_reference(x):
    """x9 = (exp(x1*x2) + x1*x2*sin(x3)) / x3 for a length-3 input x = (x1, x2, x3)."""
    return (np.exp(x[0] * x[1]) + x[0] * x[1] * np.sin(x[2])) / x[2]

reference_inputs = np.array([1.0, 2.0, np.pi / 2])
assert np.allclose(x9.value, x9_reference(reference_inputs), rtol=1e-5)
assert np.allclose(x9.dd_wrt_input, jax.grad(x9_reference)(reference_inputs), rtol=1e-5)

# answers computed by my autograd system above
x9

Value: 5.977258682250977 | DD wrt input: [10.681278   5.340639  -3.8052413]