# Fun with JAX

*Prepared for the Computational Economics Workshop at Hitotsubashi*

#### [John Stachurski](https://johnstachurski.net/)
September 2025

This is a super quick illustration of the power of [JAX](https://github.com/google/jax), a Python library built by Google Research.

It should be run on a machine with a GPU. For example, try Google Colab with the runtime type set to GPU.

The aim is just to give a small taste of high-performance computing in Python. Details will be covered later in the course.

We start with some imports:

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

Let's check our hardware:

In [None]:
!nvidia-smi

In [None]:
!lscpu -e
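
The shell commands above describe the machine. A minimal sketch of checking, from Python, which devices JAX itself will use is below. On a GPU runtime the device list should typically show GPU devices, and on a CPU-only machine the backend is `"cpu"`. The variable names `devices` and `backend` are just for illustration.

```python
import jax

# Devices JAX can see (GPU devices on a GPU runtime, CPU otherwise)
devices = jax.devices()
print(devices)

# Name of the platform JAX dispatches to by default
backend = jax.default_backend()
print(backend)
```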

## Transforming Data

A very common numerical task is to apply a transformation to a set of data points.

Our transformation will be the cosine function.

Here we evaluate the cosine function at 50 points.

In [None]:
x = np.linspace(0, 10, 50)
y = np.cos(x)

Let's plot.

In [None]:
fig, ax = plt.subplots()
ax.scatter(x, y)
plt.show()

Now let's evaluate the cosine function at many more points (50 million).

In [None]:
n = 50_000_000
x = np.linspace(0, 10, n)

### With NumPy

In [None]:
%time np.cos(x)

In [None]:
%time np.cos(x)

In [None]:
x = None  # Free memory

### With JAX

In [None]:
x_jax = jnp.linspace(0, 10, n)
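
One difference worth keeping in mind when comparing timings: with its default settings, JAX creates 32-bit floats, while NumPy creates 64-bit floats. The sketch below checks the dtypes on small arrays and computes how much memory `n` points would need in each case. It assumes JAX's 64-bit mode has not been enabled; the names `x_np_small`, `x_jax_small`, `mb_np` and `mb_jax` are illustrative.

```python
import numpy as np
import jax.numpy as jnp

n = 50_000_000

# Small arrays are enough to inspect the default dtypes
x_np_small = np.linspace(0, 10, 5)
x_jax_small = jnp.linspace(0, 10, 5)
print(x_np_small.dtype, x_jax_small.dtype)

# Memory needed for n points, in megabytes
mb_np = n * x_np_small.dtype.itemsize / 1e6
mb_jax = n * x_jax_small.dtype.itemsize / 1e6
print(mb_np, mb_jax)
```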

In [None]:
jnp.cos(x_jax)

Let's time it.

(JAX dispatches work asynchronously, so `block_until_ready()` makes the timer wait until the result has actually been computed. It is only needed for timing.)

In [None]:
%time jnp.cos(x_jax).block_until_ready()

In [None]:
%time jnp.cos(x_jax).block_until_ready()
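
To see what `block_until_ready()` changes, here is a minimal sketch that times the same call with and without it, using `time.perf_counter` instead of the `%time` magic (so it also works outside Jupyter). Without blocking, the timer measures only how long it takes to dispatch the work; with blocking, it measures the full computation. The actual numbers depend on the machine, and the gap is usually most visible on a GPU.

```python
import time
import jax.numpy as jnp

x = jnp.linspace(0, 10, 10_000_000)
jnp.cos(x).block_until_ready()  # warm-up: the first call includes compilation

# Without blocking: returns once the work has been dispatched
start = time.perf_counter()
y = jnp.cos(x)
dispatch_time = time.perf_counter() - start
y.block_until_ready()  # let this computation finish before the next timing

# With blocking: the timer stops only after the result exists
start = time.perf_counter()
y = jnp.cos(x).block_until_ready()
total_time = time.perf_counter() - start

print(f"dispatch only: {dispatch_time:.6f}s, full computation: {total_time:.6f}s")
```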

Here we change the input size. Can you explain why the timing changes?

In [None]:
x_jax = jnp.linspace(0, 10, n + 1)

In [None]:
%time jnp.cos(x_jax).block_until_ready()

In [None]:
%time jnp.cos(x_jax).block_until_ready()
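
A hint for the question above: JAX compiles code for a specific input shape and dtype and caches the result, so an array of length `n + 1` triggers a fresh compilation the first time it is seen. The sketch below makes this visible with a jitted function and a hypothetical counter, `trace_count`. It relies on the fact that plain Python side effects inside a jitted function run only while JAX traces it, not when cached compiled code is reused.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing

@jax.jit
def cos_traced(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return jnp.cos(x)

cos_traced(jnp.ones(100)).block_until_ready()  # new shape: trace and compile
cos_traced(jnp.ones(100)).block_until_ready()  # same shape: reuse compiled code
cos_traced(jnp.ones(101)).block_until_ready()  # new shape: trace and compile again

print(trace_count)  # 2
```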

In [None]:
x_jax = None  # Free memory

## Evaluating a More Complicated Function

In [None]:
def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2
    return y

In [None]:
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
ax.plot(x, f(x))
ax.scatter(x, f(x))
plt.show()

Now let's try with a large array.

### With NumPy

In [None]:
n = 50_000_000
x = np.linspace(0, 10, n)

In [None]:
%time f(x)

In [None]:
%time f(x)

### With JAX

In [None]:
def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y
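
Since `jax.numpy` mirrors the NumPy API, the JAX version of the function is the NumPy version with `np` swapped for `jnp`. Below is a quick sanity check, under the assumption that both versions use the coefficient `0.1` on the last term, that the two agree. The names `f_np` and `f_jnp` are illustrative. The comparison uses the small interval [0, 2] because JAX defaults to 32-bit floats. Near 10, `x**4` is large enough that float32 rounding alone would cause visible differences from the float64 NumPy result.

```python
import numpy as np
import jax.numpy as jnp

def f_np(x):
    return np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2

def f_jnp(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2

x_small = np.linspace(0, 2, 1_000)
y_np = f_np(x_small)                             # float64 NumPy result
y_jnp = np.asarray(f_jnp(jnp.asarray(x_small)))  # float32 JAX result, as NumPy array

print("max abs difference:", np.max(np.abs(y_np - y_jnp)))
```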

In [None]:
x_jax = jnp.linspace(0, 10, n)

In [None]:
%time f(x_jax).block_until_ready()

In [None]:
%time f(x_jax).block_until_ready()

### Compiling the Whole Function

In [None]:
f_jax = jax.jit(f)

In [None]:
%time f_jax(x_jax).block_until_ready()

In [None]:
%time f_jax(x_jax).block_until_ready()
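
With `jax.jit`, the first call pays for tracing and compilation, and later calls with the same input shape reuse the compiled code. Compiling the whole function also lets XLA typically fuse the elementwise operations into fewer kernels, rather than materializing every intermediate array. The sketch below separates compile time from run time with JAX's ahead-of-time API (`.lower(...).compile()`) and prints the traced program with `jax.make_jaxpr`. It uses a smaller array than the notebook so it runs quickly anywhere.

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2

x = jnp.linspace(0, 10, 1_000_000)

# Ahead of time: trace and compile for this input's shape and dtype
start = time.perf_counter()
f_compiled = jax.jit(f).lower(x).compile()
compile_time = time.perf_counter() - start

# Run the already-compiled executable
start = time.perf_counter()
y = f_compiled(x).block_until_ready()
run_time = time.perf_counter() - start

print(f"compile: {compile_time:.4f}s, run: {run_time:.4f}s")

# The traced program that XLA compiles as a single unit
print(jax.make_jaxpr(f)(x))
```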