<a href="https://colab.research.google.com/github/cu-applied-math/appm-4600-numerics/blob/main/Demos/Ch1_AutoDiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automatic Differentiation demo
Using JAX and PyTorch

APPM 4600

Copyright Dept of Applied Math, University of Colorado Boulder. Released under a BSD 3-clause license

Learning objectives:
1. See how to use AutoDiff in two popular frameworks (JAX and PyTorch)
2. See that reverse mode is usually faster (than forward mode) for functions $f:\mathbb{R}^n \to \mathbb{R}$
3. Compare to symbolic differentiation

Further reading
- another [AutoDiff](https://github.com/cu-applied-math/SciML-Class/blob/main/Demos/AutomaticDifferentiation.ipynb) demo from CU
- [JAX](https://docs.jax.dev/en/latest/index.html)
- [PyTorch](https://pytorch.org/)

## Using JAX

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import jacfwd, jacrev
from jax import nn
import numpy as np

We'll make a simple function. Note that the `@` operator is matrix multiplication (in both JAX and NumPy), i.e., [jax.numpy.matmul](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html#jax.numpy.matmul).

In [2]:
n = int(1e1)
m = int(n/2)

# We want some arbitrary matrix -- e.g., we could do this randomly
# jax has some utilities for this, but if you don't want to learn them,
# just convert from numpy
A = np.random.randn(m,n)
x = np.random.randn(n,1)
A = jnp.array(A)
x = jnp.array(x)

def f(x):
    return jnp.sum( A @ x )

We can ask jax for:
- the gradient (of a function $f: \mathbb{R}^n \to \mathbb{R}$)
- the Jacobian (of a function $f: \mathbb{R}^n \to \mathbb{R}^m$)
  - if $m=1$ this *is* the gradient, up to a transpose: the Jacobian is a $1 \times n$ row vector, while the gradient is conventionally written as an $n \times 1$ column vector

In JAX, `grad` always uses **reverse mode**, but for Jacobians you can choose either **reverse** mode (`jacrev`) or **forward** mode (`jacfwd`). In general, if $n > m$ (more inputs than outputs) you want **reverse** mode, and if $m > n$ you want **forward** mode. See JAX's ["Autodiff Cookbook"](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html).

In [3]:
g  = grad(f)
J1 = jacfwd(f)
J2 = jacrev(f)

g(x), J1(x), J2(x)

(Array([[ 1.9309464 ],
        [ 1.7242033 ],
        [-0.42367828],
        [-0.02427091],
        [-2.9808059 ],
        [ 1.0146751 ],
        [-0.27772245],
        [-3.8455067 ],
        [-1.8126769 ],
        [-2.2204695 ]], dtype=float32),
 Array([[ 1.9309464 ],
        [ 1.7242033 ],
        [-0.42367828],
        [-0.02427091],
        [-2.9808059 ],
        [ 1.0146751 ],
        [-0.27772245],
        [-3.8455067 ],
        [-1.8126769 ],
        [-2.2204695 ]], dtype=float32),
 Array([[ 1.9309464 ],
        [ 1.7242033 ],
        [-0.42367828],
        [-0.02427091],
        [-2.9808059 ],
        [ 1.0146751 ],
        [-0.27772245],
        [-3.8455067 ],
        [-1.8126769 ],
        [-2.2204695 ]], dtype=float32))
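
As a sanity check, this particular function is $f(x) = \mathbf{1}^\top A x$, whose gradient is known in closed form: $\nabla f(x) = A^\top \mathbf{1}$, i.e., the column sums of $A$. The sketch below compares that formula against `grad`, `jacfwd` and `jacrev`. It builds its own small random problem; the names `A_demo`, `x_demo`, `f_demo` and the sizes are illustrative and not part of the cells above.

```python
import numpy as np
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

# Illustrative sizes and names (assumptions for this sketch)
m_demo, n_demo = 5, 10
rng = np.random.default_rng(0)
A_demo = jnp.array(rng.standard_normal((m_demo, n_demo)))
x_demo = jnp.array(rng.standard_normal((n_demo, 1)))

def f_demo(x):
    return jnp.sum(A_demo @ x)

# Closed form: the gradient of 1^T A x is A^T 1 (the column sums of A)
exact = jnp.sum(A_demo, axis=0).reshape(n_demo, 1)

errors = {}
for name, deriv in [("grad", grad(f_demo)),
                    ("jacfwd", jacfwd(f_demo)),
                    ("jacrev", jacrev(f_demo))]:
    errors[name] = float(jnp.max(jnp.abs(deriv(x_demo) - exact)))
    print(f"{name:7s} max abs error = {errors[name]:.2e}")
```

All three errors should be at the level of single-precision roundoff, since JAX uses `float32` by default.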

Let's try a slightly more interesting function

In [4]:
n = int(5e3)
m = n
k = n

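key = jax.random.key(seed=0)
# Split the key: reusing one key for two same-shaped draws would make A and B identical
kA, kB, kx = jax.random.split(key, 3)
A = jax.random.normal(kA, (m,n))
B = jax.random.normal(kB, (k,m))
x = jax.random.normal(kx, (n,1))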
# A = jnp.array(np.random.randn(m,n)) # another (slower) way to do it
# B = jnp.array(np.random.randn(k,m))
# x = jnp.array(np.random.randn(n,1))

def f(x):
    return jnp.sum( nn.sigmoid(B @ nn.sigmoid(A @ x ) ) )

# The first call is slower: it includes one-time overhead (e.g., compiling the underlying operations)
%time y = f(x)

CPU times: user 68.2 ms, sys: 1.23 ms, total: 69.5 ms
Wall time: 61.5 ms


In [5]:
%time y = f(x)

CPU times: user 840 μs, sys: 280 μs, total: 1.12 ms
Wall time: 476 μs
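
One caveat when timing JAX code: JAX dispatches work asynchronously, so a bare `%time y = f(x)` can return before the computation has actually finished. The JAX documentation recommends calling `.block_until_ready()` on the result when benchmarking. Below is a minimal sketch of that pattern on a smaller problem, using `time.perf_counter` instead of the `%time` magic so that it runs outside a notebook. The helper `timed` and the names `f_small`, `A_small`, `B_small`, `x_small` are illustrative assumptions.

```python
import time
import jax
import jax.numpy as jnp
from jax import nn

n_small = 500  # illustrative size, smaller than in the cell above
kA, kB, kx = jax.random.split(jax.random.key(0), 3)
A_small = jax.random.normal(kA, (n_small, n_small))
B_small = jax.random.normal(kB, (n_small, n_small))
x_small = jax.random.normal(kx, (n_small, 1))

def f_small(x):
    return jnp.sum(nn.sigmoid(B_small @ nn.sigmoid(A_small @ x)))

def timed(fun, *args):
    """Run fun(*args), wait for the result to be ready, and return (result, seconds)."""
    t0 = time.perf_counter()
    out = fun(*args).block_until_ready()
    return out, time.perf_counter() - t0

y_first, t_first = timed(f_small, x_small)    # includes one-time overhead
y_second, t_second = timed(f_small, x_small)  # steady-state cost
print(f"first call: {t_first:.2e} s, second call: {t_second:.2e} s")
```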


In [6]:
g  = grad(f)
J1 = jacfwd(f)
J2 = jacrev(f)

In [7]:
%%time
y = g(x)  # reverse-mode

CPU times: user 122 ms, sys: 3.22 ms, total: 126 ms
Wall time: 113 ms


In [8]:
%%time
y = J1(x) # forward-mode

CPU times: user 1.18 s, sys: 121 ms, total: 1.3 s
Wall time: 265 ms


In [9]:
%%time
y = J2(x) # reverse-mode

CPU times: user 877 ms, sys: 7 ms, total: 884 ms
Wall time: 96.4 ms


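We see that reverse mode (`J2` and `g`) is faster than forward mode (`J1`). Naively you'd expect it to be **way** faster, since forward mode needs one pass per input coordinate ($n = 5000$ here). The gap is probably modest because JAX is somewhat clever about forward mode: `jacfwd` batches those passes with `vmap` rather than running them one at a time.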
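
To make the naive cost estimate concrete: one forward-mode pass (a Jacobian-vector product, `jax.jvp`) yields a single directional derivative, so assembling a full gradient takes $n$ passes, whereas one reverse-mode pass (a vector-Jacobian product, `jax.vjp`) yields the whole gradient. The sketch below shows both on a small function. The names `f_toy`, `A_toy`, `x_toy` and the sizes are illustrative; this is not the function timed above.

```python
import jax
import jax.numpy as jnp

n_toy = 6  # illustrative size
kA, kx = jax.random.split(jax.random.key(1))
A_toy = jax.random.normal(kA, (n_toy, n_toy))
x_toy = jax.random.normal(kx, (n_toy,))

def f_toy(x):
    return jnp.sum(jax.nn.sigmoid(A_toy @ x))

# Forward mode: one JVP per input coordinate, so n_toy passes in total
basis = jnp.eye(n_toy)
grad_fwd = jnp.stack(
    [jax.jvp(f_toy, (x_toy,), (basis[i],))[1] for i in range(n_toy)]
)

# Reverse mode: a single VJP with cotangent 1 gives every partial derivative
y_toy, vjp_fun = jax.vjp(f_toy, x_toy)
(grad_rev,) = vjp_fun(jnp.ones_like(y_toy))

print(jnp.max(jnp.abs(grad_fwd - grad_rev)))
```

The two gradients should agree up to roundoff; the difference is only in how many passes were needed to get them.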

## Let's repeat the same thing in PyTorch
PyTorch is another popular autodiff framework

In [10]:
import torch
import matplotlib.pyplot as plt
import sys
import numpy as np
from torch.nn.functional import sigmoid
print("Torch version is", torch.__version__)
print("Numpy version is", np.__version__)
print("Python version is", sys.version)

torch.manual_seed(100)
# dtype = torch.float32 # the default
dtype = torch.float64

n = int(8e3)
m = n
k = n

A = torch.randn((m,n),dtype=dtype)
B = torch.randn((k,m),dtype=dtype)
x = torch.randn((n,1), dtype=dtype, requires_grad=True)

def f(x):
    return torch.sum( sigmoid(B @ sigmoid(A @ x ) ) )

y = f(x)

Torch version is 2.5.1
Numpy version is 2.3.1
Python version is 3.13.5 | packaged by Anaconda, Inc. | (main, Jun 12 2025, 11:09:21) [Clang 14.0.6 ]


In [11]:
%%time
# Gradients accumulate in x.grad across backward() calls, so reset it first
if x.grad is not None:
    x.grad.zero_()
out = f(x)
out.backward()  # reverse mode
y = x.grad

CPU times: user 1.1 s, sys: 15.3 ms, total: 1.11 s
Wall time: 118 ms
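
An alternative that avoids the bookkeeping on `x.grad` is `torch.autograd.grad`, which returns the gradient directly. The sketch below uses it on a small version of the same function and compares the result with a reverse pass written out by hand via the chain rule: with $u = Ax$, $h = \sigma(u)$, $v = Bh$, the gradient is $A^\top\big(\sigma'(u) \odot B^\top \sigma'(v)\big)$. The names `f_t`, `A_t`, `B_t`, `x_t` and the size are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
n_t = 50  # illustrative size
A_t = torch.randn((n_t, n_t), dtype=torch.float64)
B_t = torch.randn((n_t, n_t), dtype=torch.float64)
x_t = torch.randn((n_t, 1), dtype=torch.float64, requires_grad=True)

def f_t(x):
    return torch.sum(torch.sigmoid(B_t @ torch.sigmoid(A_t @ x)))

# Returns a tuple with one gradient per requested input; x_t.grad is left untouched
(g_auto,) = torch.autograd.grad(f_t(x_t), x_t)

# Hand-written reverse pass: apply the chain rule from the output back to x
with torch.no_grad():
    u = A_t @ x_t
    h = torch.sigmoid(u)
    v = B_t @ h
    s = torch.sigmoid(v)
    dv = s * (1 - s)           # d f / d v
    dh = B_t.T @ dv            # d f / d h
    du = h * (1 - h) * dh      # d f / d u
    g_manual = A_t.T @ du      # d f / d x

print(torch.max(torch.abs(g_auto - g_manual)).item())
```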


## Speelpenning Function
Taken from the longer [SciML AutoDiff](https://github.com/cu-applied-math/SciML-Class/blob/main/Demos/AutomaticDifferentiation.ipynb) example

In [23]:
import sympy
from sympy.abc import x
roots = np.linspace(0,1,10)
def g(x):
    y = 1
    for i in range(len(roots)):
        y = y * (x - roots[i])
    return y
g(x)

x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.222222222222222)*(x - 0.111111111111111)

In [24]:
gprime = sympy.diff(g(x),x)
gprime

x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.222222222222222) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.444444444444444)*(x - 0.222222222222222)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.555555555555556)*(x - 0.333333333333333)*(x - 0.222222222222222)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.666666666666667)*(x - 0.444444444444444)*(x - 0.333333333333333)*(x - 0.222222222222222)*(x - 0.111111111111111) + x*(x - 1.0)*(x - 0.888888888888889)*(x - 0.777777777777778)*(x - 0.555555555555556)*(x - 0

In [25]:
gprime.evalf(16,subs={x:.88889})

-0.0001040765432583041

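That symbolic derivative is **correct**, but it's not an efficient implementation: applying the product rule to 10 factors produces 10 terms of 9 factors each, so the cost of evaluating it grows quadratically with the number of factors. We can get an efficient implementation if we play around a bit, but it's not automatic.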

For example, we can tell sympy to expand $g(x)$ out, and *then* differentiate:

In [26]:
sympy.expand(g(x))

x**10 - 5.0*x**9 + 10.7407407407407*x**8 - 12.962962962963*x**7 + 9.64380429812528*x**6 - 4.56104252400549*x**5 + 1.36173159391165*x**4 - 0.245182437937607*x**3 + 0.0238479488367999*x**2 - 0.000936656708416885*x

In [27]:
gprime2 = sympy.diff( sympy.expand(g(x)), x )
gprime2
gprime2.evalf(16,subs={x:.88889})

-0.0001040765432661566
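
For comparison, here is what an autodiff-style evaluation of the same product can look like. This is a minimal forward-mode sketch using a hand-rolled dual-number class; `Dual`, `roots_demo` and `g_and_gprime` are illustrative helpers, not part of any library. It pushes a value and a derivative through the same loop that defines `g`, so the derivative costs only a constant factor more than evaluating `g` itself.

```python
class Dual:
    """A number val + dot*eps with eps**2 = 0; `dot` carries the derivative."""

    def __init__(self, val, dot=0.0):
        self.val = val
        self.dot = dot

    def __sub__(self, c):
        # Only subtraction of a plain constant is needed for this example
        return Dual(self.val - c, self.dot)

    def __mul__(self, other):
        # Product rule: (u v)' = u' v + u v'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

# Same points as np.linspace(0, 1, 10), up to rounding
roots_demo = [i / 9 for i in range(10)]

def g_and_gprime(x0):
    x = Dual(x0, 1.0)   # seed with dx/dx = 1
    y = Dual(1.0, 0.0)
    for r in roots_demo:
        y = y * (x - r)
    return y.val, y.dot

val, deriv = g_and_gprime(0.88889)
print(deriv)
```

The printed derivative should agree with the two `sympy` values above to many digits.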

## Showing that AutoDiff depends on the implementation

We'll define the function $f(x)=0$ but in a slow way

In [17]:
import torch
d   = int(4e3)

torch.manual_seed(100)
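A   = torch.randn( (d,d) ) / d**0.5  # scaled so the repeated products below don't overflow float32 (inf - inf would give NaN, not 0)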

def f(x, N = 100):
    """ Implements the zero function: f(x) = 0 """
    for k in range(N):
        x = A @ x

    return torch.sum(x - x)

x   = torch.randn( (d,1), requires_grad=True )

In [18]:
%%time
with torch.no_grad():
    y = f(x)

CPU times: user 4.48 s, sys: 9.02 ms, total: 4.49 s
Wall time: 489 ms


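The gradient is the all-zeros vector, but as you can see from the time it takes to execute the code, autodiff is not being that clever: it differentiates the operations as implemented (here, backpropagating through all 100 matrix products) rather than recognizing that the function is identically zero.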

In [19]:
%%time
y = f(x)
y.backward()

CPU times: user 5.28 s, sys: 15.3 ms, total: 5.3 s
Wall time: 550 ms
