# Reverse-mode automatic differentiation

March 13, 2024

In [28]:
import jax
import jax.numpy as np
import math

## Functions and their vJP/pullback rules (rrules)

From Nocedal and Wright (slightly modified):

> Whenever an elementary operation is performed, we can form and store a new node containing the intermediate result, pointers to this new node from parents, and the partial derivatives associated with these arcs.

This implementation idea works when all nodes store scalars, but fails when nodes may store vectors and matrices. This is because chain rule for vectors and matrices are more complicated - they are sometimes referred to as *pullback rules* / *reverse-mode rules* (rrules).

My implementation makes two changes to this:
- Instead of storing a reference to children, we use closure to save children. (Idea from Julia)
- Store pullback rules

Many pullback rules can be found here: https://www.youtube.com/playlist?list=PLISXH-iEM4Jn3SEi07q8MJmDD6BaMWlJE

In [48]:
class Tensor:

    def __init__(self, value):
        self.value = value
        self.pullback_fns = []
        self.cotangent = np.zeros(value.shape)
        self.finalized = False

    @property
    def dim(self):
        return len(self.value.shape)

    def __repr__(self):
        return repr(self.value)

    def pullback(self):
        assert not self.finalized
        for pullback_fn in self.pullback_fns:
            self.cotangent += pullback_fn()
        self.finalized = True

    def __add__(x, y):
        if x.dim == 0 and y.dim == 0:  # scalar addition
            z = Tensor(x.value + y.value)
            x.pullback_fns.append(lambda: z.cotangent)
            y.pullback_fns.append(lambda: z.cotangent)
        elif x.dim == 2 and y.dim == 1:  # matrix-vector addition (involving broadcasting)
            z = Tensor(x.value + y.value.reshape(1, -1))
            x.pullback_fns.append(lambda: z.cotangent)
            y.pullback_fns.append(lambda: z.cotangent.sum(0))
        else:  
            raise TypeError()
        return z

    def __mul__(x, y):
        if x.dim == 0 and y.dim == 0:  # scalar multiplication
            z = Tensor(x.value * y.value)
            x.pullback_fns.append(lambda: z.cotangent * y.value)
            y.pullback_fns.append(lambda: z.cotangent * x.value)
        else:
            raise TypeError()
        return z

    def __matmul__(x, y):
        if x.dim == 2 and y.dim == 2:  # matrix multiplication
            z = Tensor(x.value @ y.value)
            x.pullback_fns.append(lambda: z.cotangent @ y.value.T)
            y.pullback_fns.append(lambda: x.value.T @ z.cotangent)
        else:
            raise TypeError
        return z
    
    def __truediv__(x, y):
        if x.dim == 0 and y.dim == 0:  # scalar division
            z = Tensor(x.value / y.value)
            x.pullback_fns.append(lambda: z.cotangent * (1 / y.value))
            y.pullback_fns.append(lambda: z.cotangent * - x.value / (y.value ** 2))
        else:
            raise TypeError()
        return z


def exp(x):  # elementwise exp
    z = Tensor(np.exp(x.value))
    x.pullback_fns.append(lambda: z.cotangent * z.value)
    return z


def sin(x):  # elementwise sine
    z = Tensor(np.sin(x.value))
    x.pullback_fns.append(lambda: z.cotangent * np.cos(x.value))
    return z


def relu(x):  # elementwise relu
    z = Tensor(x.value * (x.value > 0))
    x.pullback_fns.append(lambda: z.cotangent * (x.value > 0))
    return z
    

def mse(x, target):
    if x.dim == 2 and target.dim == 2:
        diff = x.value - target.value
        z = Tensor(np.mean(diff ** 2))
        x.pullback_fns.append(lambda: z.cotangent * 2 / x.value.size * diff)
        target.pullback_fns.append(lambda: z.cotangent * - 2 / x.value.size * diff)
    return z

## Computation graph with scalars only

In [49]:
x1 = Tensor(np.array(1.))
x2 = Tensor(np.array(2.))
x3 = Tensor(np.array(np.pi/2))

x4 = x1 * x2
x5 = sin(x3)
x6 = exp(x4)
x7 = x4 * x5
x8 = x6 + x7
x9 = x8 / x3
x10 = x8 / x2

# use 0 and 1 to get the cotangent from x10
x9.cotangent = np.array(1.)
x10.cotangent = np.array(0.)
tape = [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10]

for tensor in tape[::-1]: 
    tensor.pullback()

print(x1.cotangent)
print(x2.cotangent)
print(x3.cotangent)

10.681277
5.3406386
-3.8052409


In [50]:
def func(x):
    x1, x2, x3 = x
    x4 = x1 * x2
    x5 = np.sin(x3)
    x6 = np.exp(x4)
    x7 = x4 * x5
    x8 = x6 + x7
    x9 = x8 / x3
    x10 = x8 / x2
    return np.array([x9, x10])

In [51]:
func(np.array([1., 2., np.pi/2]))

Array([5.9772587, 4.694528 ], dtype=float32)

### Compare against JAX's autograd

In [52]:
jax.jacrev(func)(np.array([1., 2., np.pi/2]))

Array([[ 1.0681277e+01,  5.3406386e+00, -3.8052409e+00],
       [ 8.3890562e+00,  1.8472641e+00, -4.3711388e-08]], dtype=float32)

## Computation graph with matrices and vectors (a MLP regression model)

In [53]:
def kaiming_init(num_input_features, num_output_features, seed):
    # kaiming normal init, used to ensure similar variance in inputted and outputted data of each layer
    W = jax.random.normal(seed, shape=(num_input_features, num_output_features)) / math.sqrt(num_input_features)
    b = np.zeros(num_output_features)
    return Tensor(W), Tensor(b)

In [54]:
# dataset
x = np.linspace(-1, 1, num=100)
y = np.linspace(-1, 1, num=100)
xs, ys = np.meshgrid(x, y)
zs = (xs ** 3 - ys ** 2)
xs_f, ys_f, zs_f = xs.flatten(), ys.flatten(), zs.flatten()

In [55]:
# computation graph

tape = []  
# - left variables that don't need gradient can be left out
# - all intermediate variables must be included (unfortunately)
# - order also matters (but not unique)
# - final node also doesn't need to be in here

### left variables

X = Tensor(np.vstack([xs_f, ys_f]).T)
Y = Tensor(zs_f.reshape(-1, 1))

main_seed = jax.random.PRNGKey(42)
seeds = jax.random.split(main_seed)

W1, b1 = kaiming_init(2, 100, seeds[0])
W2, b2 = kaiming_init(100, 100, seeds[1])
W3, b3 = kaiming_init(100, 1, seeds[2])

### intermediate variables

transformed_X = X

for W, b in zip([W1, W2, W3], [b1, b2, b3]):
    
    XW = transformed_X @ W
    XW_plus_b = XW + b
    relu_XW_plus_b = relu(XW_plus_b)
    
    tape.extend([W, b, XW, XW_plus_b, relu_XW_plus_b])
    
    transformed_X = relu_XW_plus_b

mse_loss = mse(transformed_X, Y)

mse_loss

Array(0.4287767, dtype=float32)

In [56]:
mse_loss.cotangent = np.array(1.)

for tensor in tape[::-1]: 
    tensor.pullback()

In [57]:
W2_grad = W2.cotangent

### Compare against JAX's autograd

In [58]:
def mlp_func(W1, b1, W2, b2, W3, b3, X, Y):
    layer1_out = jax.nn.relu(X @ W1 + b1)
    layer2_out = jax.nn.relu(layer1_out @ W2 + b2)
    layer3_out = jax.nn.relu(layer2_out @ W3 + b3)
    mse_loss = np.mean((layer3_out - Y) ** 2)
    return mse_loss

In [59]:
mlp_func(W1.value, b1.value, W2.value, b2.value, W3.value, b3.value, X.value, Y.value)

Array(0.4287767, dtype=float32)

In [60]:
mlp_func_grad = jax.grad(mlp_func, argnums=[0, 1, 2, 3, 4, 5])
W2_grad_jax = mlp_func_grad(W1.value, b1.value, W2.value, b2.value, W3.value, b3.value, X.value, Y.value)[2]

In [61]:
np.allclose(W2_grad, W2_grad_jax)

Array(True, dtype=bool)