# JAX for Efficient Array Programs in Python

**Thom Badings**  
February 12, 2026  
Part of the *Fundamental Skills Workshop* series of the Erlangen AI Hub

---

## About this file

This file complements the slides of the corresponding workshop series. Below, you will find all the commands used during the workshop, along with some of the same material you also find in the slides.

**Credits:** The program and code are taken from the JAX course at https://github.com/matomatical/hijax. All credits for the program and the exercises go to Matthew Farrugia-Roberts.

## Elementary Cellular Automata

As our main example, we will look at the `eca0_numpy.py` script, which I took from the first lecture in the JAX course above: <https://www.youtube.com/watch?v=HJeeMnLs_Z0&list=PLjl5MxRQg5xrQagVEKk9J5eWWZf6AmYSr>

In [1]:
! python eca0_numpy.py --num-steps 80

rule: 110 (0b01101110)
bits: [0 1 1 1 0 1 1 0]
rule_table[0,0,0] = 0
rule_table[0,0,1] = 1
rule_table[0,1,0] = 1
rule_table[0,1,1] = 1
rule_table[1,0,0] = 0
rule_table[1,0,1] = 1
rule_table[1,1,0] = 1
rule_table[1,1,1] = 0
simulation complete!
time taken 0.00179 seconds
[terminal rendering of the rule 110 space-time diagram, truncated]

We are going to convert this NumPy-based Python code into a JAX-based version. For this, we first need to load the `jax.numpy` module by adding `import jax.numpy as jnp` to the top of our code. This submodule of JAX provides JAX versions of virtually all of the NumPy functions you can think of. It has good documentation, which can be found at: <https://docs.jax.dev/en/latest/jax.numpy.html>

The API of the `jax.numpy` functions is virtually always the same as for standard NumPy. To port code from NumPy to JAX, we can thus often simply replace `np` with `jnp`.
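To make the `np` to `jnp` replacement concrete, here is a minimal sketch. The function names and the wrap-around boundary are assumptions for illustration and are not taken from `eca0_numpy.py`; the point is only that the two function bodies are identical apart from the module prefix.

```python
import numpy as np
import jax.numpy as jnp


def neighbourhood_index_np(state):
    """Encode each cell's (left, centre, right) neighbourhood as a number 0..7."""
    left = np.roll(state, 1)    # left[i] == state[i - 1], wrapping around
    right = np.roll(state, -1)  # right[i] == state[i + 1], wrapping around
    return 4 * left + 2 * state + right


def neighbourhood_index_jnp(state):
    """Same computation, with every `np` replaced by `jnp`."""
    left = jnp.roll(state, 1)
    right = jnp.roll(state, -1)
    return 4 * left + 2 * state + right


cells = [0, 1, 1, 0, 1]
result_np = neighbourhood_index_np(np.array(cells, dtype=np.uint8))
result_jnp = neighbourhood_index_jnp(jnp.array(cells, dtype=jnp.uint8))
print(result_np)   # [5 3 6 5 2]
print(result_jnp)  # [5 3 6 5 2]
```

Both versions return the same values; the JAX version returns a JAX array instead of a NumPy array.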

This is exactly what we've changed in the `eca0_jax1.py` file. Let's try to run it:

In [2]:
! python eca0_jax1.py --num-steps 80

rule: 110 (0b01101110)
bits: [0 1 1 1 0 1 1 0]
rule_table[0,0,0] = 0
rule_table[0,0,1] = 1
rule_table[0,1,0] = 1
rule_table[0,1,1] = 1
rule_table[1,0,0] = 0
rule_table[1,0,1] = 1
rule_table[1,1,0] = 1
rule_table[1,1,1] = 0
Traceback (most recent call last):
  File "/Users/thobad/Documents/Teaching/Peer_learning/JAX/eca0_jax1.py", line 91, in <module>
    tyro.cli(main)
    ~~~~~~~~^^^^^^
  File "/Users/thobad/miniconda3/envs/jax-tutorial/lib/python3.14/site-packages/tyro/_cli.py", line 281, in cli
    out = run_with_args_from_cli()
  File "/Users/thobad/Documents/Teaching/Peer_learning/JAX/eca0_jax1.py", line 32, in main
    states = simulate(
        rule=rule,
        width=width,
        num_steps=num_steps,
    )
  File "/Users/thobad/Documents/Teaching/Peer_learning/JAX/eca0_jax1.py", line 73, in simulate
    state[widt

This gives an error, but why? Well, JAX arrays are immutable, but our code tries to assign a value to an array element in place anyway.

Instead, every time we want to change the array, we need to do `state = state.at[width//2].set(1)`. This takes a particular position of the array, and sets it to a new value. Instead of changing the existing array, it creates a new array, with the desired value set to 1. In a way, this is a core limitation of JAX, but it is often a beneficial restriction if we want to ensure our code is, for example, amenable to autodifferentiation.
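The difference between in-place assignment and the `.at[...].set(...)` update can be seen in a small stand-alone sketch. The array below is an assumed toy state and not the one from the workshop files.

```python
import jax.numpy as jnp

width = 8
state = jnp.zeros(width, dtype=jnp.uint8)

# In-place assignment works for NumPy arrays, but JAX arrays reject it.
try:
    state[width // 2] = 1
except TypeError as err:
    print("in-place assignment failed:", type(err).__name__)

# The functional update returns a new array and leaves `state` untouched.
new_state = state.at[width // 2].set(1)
print(state)      # still all zeros
print(new_state)  # a single 1 in the middle
```

Because the original array is never modified, we have to rebind the name (`state = state.at[...].set(...)`) if we want to keep working with the updated values.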

In [3]:
! python eca0_jax2.py --num-steps 80

rule: 110 (0b01101110)
bits: [0 1 1 1 0 1 1 0]
rule_table[0,0,0] = 0
rule_table[0,0,1] = 1
rule_table[0,1,0] = 1
rule_table[0,1,1] = 1
rule_table[1,0,0] = 0
rule_table[1,0,1] = 1
rule_table[1,1,0] = 1
rule_table[1,1,1] = 0
simulation complete!
time taken 0.24024 seconds
[terminal rendering of the rule 110 space-time diagram, truncated]

This code now runs, but it is actually much slower than the initial NumPy version (0.24024 seconds versus 0.00179 seconds). But why? Because we haven't started to use any of the JAX acceleration mechanisms.

Thus, the next step is to start using JAX's just-in-time (JIT) compilation, which will speed things up considerably after the initial compilation.

In `eca0_jax3_jit.py`, we have taken the first steps toward this. First, we have removed all print statements from the `simulate` function. This is because JIT compilation assumes "pure" functions, that is, functions without side effects such as printing: a `print` inside a jitted function runs only while JAX traces the function, not on every call. Second, we have defined a JIT-compiled version of our `simulate` function.
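To see why side effects and JIT compilation do not mix well, consider the following minimal sketch. The function `double` and the list `trace_log` are hypothetical names chosen for illustration. The side effect (appending to a list) happens only when JAX traces the Python function, not on later calls that reuse the compiled code.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records every time the Python body of `double` actually runs


def double(x):
    trace_log.append("traced")  # side effect: only happens during tracing
    return 2 * x


fast_double = jax.jit(double)

first = fast_double(jnp.arange(3))
second = fast_double(jnp.arange(3) + 10)  # same shape and dtype: compiled code is reused

print(first, second)
print("Python body ran", len(trace_log), "time(s)")
```

Although `fast_double` is called twice, the Python body runs only once, so a `print` in its place would also appear only once.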

In [4]:
import time
import numpy as np  # needed for np.asarray when saving the image
import jax
import jax.numpy as jnp
import tyro
from jaxtyping import Array, UInt8, Bool
import matthewplotlib as mp
from PIL import Image

from eca0_jax3_jit import simulate

def main(
    rule: int = 110, 
    width: int = 80,
    num_steps: int = 80,
    print_image: bool = True,
    save_image: bool = False,
):
    
    start_time = time.perf_counter()
    compiled_simulate = jax.jit(simulate)
    end_time = time.perf_counter()
    print("simulation complete!")
    print(f"time taken {end_time - start_time:.5f} seconds")

    ########
    
    start_time = time.perf_counter()
    states = compiled_simulate(
        rule=rule,
        width=width,
        num_steps=num_steps,
    )
    end_time = time.perf_counter()
    print("simulation complete!")
    print(f"time taken {end_time - start_time:.5f} seconds")

    if print_image:
        print(mp.image((1.2 - states) / 1.2))

    if save_image:
        print("rendering to 'output.png'...")
        numpy_states = np.asarray(states)
        Image.fromarray(255 - 255 * numpy_states).save('output.png')

main(num_steps=80)

simulation complete!
time taken 0.00017 seconds


TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer(~int32[]),).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function simulate at /Users/thobad/Documents/Teaching/Peer_learning/JAX/eca0_jax3_jit.py:60 for jit. This concrete value was not available in Python because it depends on the value of the argument width.

This should still give you an error. This is because the shapes of the arrays we create depend on the arguments `width` and `num_steps`. Under `jit`, these arguments are replaced by abstract tracers, but array shapes must be concrete values at trace time. Therefore, we need to declare `width` and `num_steps` as static arguments. This means that JAX compiles a separate version of the function for every distinct value of these arguments.
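The same failure and its fix can be reproduced in isolation. The helper `make_initial_state` below is a hypothetical stand-in for the part of `simulate` that allocates the state array, so treat it as a sketch rather than the workshop code.

```python
import jax
import jax.numpy as jnp


def make_initial_state(width):
    # The array shape depends on the argument `width`.
    state = jnp.zeros(width, dtype=jnp.uint8)
    return state.at[width // 2].set(1)


# Without static arguments, `width` is traced and cannot be used as a shape.
try:
    jax.jit(make_initial_state)(width=8)
except TypeError as err:
    print("jit failed:", type(err).__name__)

# Marking `width` as static makes it an ordinary Python int during tracing.
make_initial_state_jit = jax.jit(make_initial_state, static_argnames=["width"])

small = make_initial_state_jit(width=8)   # compiled for width=8
large = make_initial_state_jit(width=16)  # triggers a second compilation for width=16
print(small.shape, large.shape)
```

The price of static arguments is recompilation: each new value of `width` compiles a new version, so they are best used for values that change rarely.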

In [None]:
import time
import numpy as np  # needed for np.asarray when saving the image
import jax
import jax.numpy as jnp
import tyro
from jaxtyping import Array, UInt8, Bool
import matthewplotlib as mp
from PIL import Image

from eca0_jax3_jit import simulate

def main(
    rule: int = 110, 
    width: int = 80,
    num_steps: int = 80,
    print_image: bool = True,
    save_image: bool = False,
):
    
    start_time = time.perf_counter()
    compiled_simulate = jax.jit(
        simulate,
        static_argnames=['width', 'num_steps']
    )
    end_time = time.perf_counter()
    print("simulation complete!")
    print(f"time taken {end_time - start_time:.5f} seconds (define JIT compiled function)")

    ########
    
    start_time = time.perf_counter()
    states = compiled_simulate(
        rule=rule,
        width=width,
        num_steps=num_steps,
    ).block_until_ready()
    end_time = time.perf_counter()
    print("simulation complete!")
    print(f"time taken {end_time - start_time:.5f} seconds (first execution)")

    ########
    
    start_time = time.perf_counter()
    states = compiled_simulate(
        rule=rule,
        width=width,
        num_steps=num_steps,
    ).block_until_ready()
    end_time = time.perf_counter()
    print("simulation complete!")
    print(f"time taken {end_time - start_time:.5f} seconds (second execution)")

    if print_image:
        print(mp.image((1.2 - states) / 1.2))

    if save_image:
        print("rendering to 'output.png'...")
        numpy_states = np.asarray(states)
        Image.fromarray(255 - 255 * numpy_states).save('output.png')

main(num_steps=80)

Let's break this down:
- Wrapping `simulate` in `jax.jit` does not compile anything yet. This is because JAX uses just-in-time compilation. Thus, the first timing is very fast.
- The actual compilation is only done the first time the function is *called*. Thus, the first execution is slow.
- But look at the second execution! It reuses the compiled code, which is already a pretty good speed-up!

Here, `.block_until_ready()` was added to make the timings honest: JAX dispatches computations asynchronously, so without it the timer could stop before the result has actually been computed.
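The timing pattern used above can be wrapped in a small helper. This is a sketch under assumptions: `timed` and `sum_of_squares` are hypothetical names, and the measured times will differ from machine to machine.

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Call `fn` and wait for the result before stopping the clock."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


@jax.jit
def sum_of_squares(x):
    return jnp.sum(x * x)


x = jnp.arange(1000)

_, first_time = timed(sum_of_squares, x)        # includes tracing and compilation
result, second_time = timed(sum_of_squares, x)  # reuses the compiled code

print(f"first call:  {first_time:.5f} seconds")
print(f"second call: {second_time:.5f} seconds")
print("result:", result)
```

Here `jax.block_until_ready` plays the same role as the `.block_until_ready()` method, but it also accepts nested containers of arrays.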

Next, let's try a larger example, with a width and a number of steps of 8000:

In [None]:
! python eca0_numpy.py --num-steps 8000 --width 8000 --no-print-image --save-image

In [None]:
main(num_steps=8000, width=8000, print_image=False)

Why is this JAX version so slow? It's because JAX traces your function with abstract arrays, and tracing runs through every iteration of a Python for loop. The for loop in the `simulate` function is therefore unrolled into 8000 copies of the loop body, all of which have to be compiled.
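One way to observe this effect is to count how many operations JAX records while tracing a function that contains a Python for loop. The update rule below is a hypothetical toy step, not the rule from the workshop files; only the growth of the count matters.

```python
import jax
import jax.numpy as jnp


def step(state):
    # Toy update (hypothetical): XOR each cell with its left neighbour.
    return jnp.roll(state, 1) ^ state


def run_loop(state, num_steps):
    for _ in range(num_steps):  # plain Python loop: traced once per iteration
        state = step(state)
    return state


def count_equations(num_steps):
    """Number of top-level operations JAX records when tracing `run_loop`."""
    example = jnp.zeros(8, dtype=jnp.uint8)
    closed_jaxpr = jax.make_jaxpr(lambda s: run_loop(s, num_steps))(example)
    return len(closed_jaxpr.jaxpr.eqns)


for n in (10, 100, 1000):
    print(n, "steps ->", count_equations(n), "recorded operations")
```

The number of recorded operations grows with the number of loop iterations, and the compiler has to process all of them.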

Luckily, there is a solution to this problem: the JAX function `jax.lax.scan`. With scan, JAX traces the loop body only once instead of working through the for loop many times. This yields a much smaller "computational graph", meaning the compilation becomes simpler and faster.

In `eca0_jax5_scan.py`, you find an updated version of our code, where we have replaced the for loop with a scan.
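For readers who want to see the pattern without opening the file, the following is a minimal sketch of a scan-based simulation. It is not the contents of `eca0_jax5_scan.py`: the function name `simulate_scan`, the wrap-around boundary, and the way the rule bits are computed are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def simulate_scan(rule_bits, init_state, num_steps):
    """Run an elementary cellular automaton for `num_steps` steps using scan."""

    def step(state, _):
        left = jnp.roll(state, 1)    # left neighbour (wrap-around is an assumption)
        right = jnp.roll(state, -1)  # right neighbour
        index = (4 * left + 2 * state + right).astype(jnp.int32)
        next_state = rule_bits[index]  # look up the new value of every cell
        return next_state, next_state  # (carry for next step, row to record)

    _, history = jax.lax.scan(step, init_state, xs=None, length=num_steps)
    return jnp.concatenate([init_state[None], history])


rule = 110
rule_bits = ((rule >> jnp.arange(8)) & 1).astype(jnp.uint8)  # [0 1 1 1 0 1 1 0]
init_state = jnp.zeros(8, dtype=jnp.uint8).at[4].set(1)

simulate_jit = jax.jit(simulate_scan, static_argnames=["num_steps"])
states = simulate_jit(rule_bits, init_state, num_steps=4)
print(states)
```

The function passed to `jax.lax.scan` takes the current carry (here the state) and returns the next carry together with a value to record; scan stacks the recorded values into one array.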

In [None]:
! python eca0_jax5_scan.py --num-steps 8000 --width 8000 --no-print-image --save-image

Look at that! The compilation time went down significantly. So now, we have a version of our code that doesn't take long to compile, and is faster than the pure NumPy version.

Let us demonstrate one more powerful feature of JAX, namely automatic vectorization. In `eca0_jax6_vmap.py`, you find a version of the code where we have vectorized the function over the `rule` argument. This allows us to compute the figures for *all* possible rules at the same time!
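As a rough illustration of vectorizing over the rule, the sketch below applies a single update step for all 256 rules at once with `jax.vmap`. The function `step` and its wrap-around boundary are assumptions for this example and are not taken from `eca0_jax6_vmap.py`.

```python
import jax
import jax.numpy as jnp


def step(rule, state):
    """One update of an elementary cellular automaton for a single rule."""
    rule_bits = (rule >> jnp.arange(8)) & 1  # bit k gives the output for neighbourhood k
    left = jnp.roll(state, 1)                # wrap-around boundary is an assumption
    right = jnp.roll(state, -1)
    index = (4 * left + 2 * state + right).astype(jnp.int32)
    return rule_bits[index]


state = jnp.zeros(8, dtype=jnp.uint8).at[4].set(1)
all_rules = jnp.arange(256)

# Map over the first argument (the rule) and share the same state across rules.
step_all_rules = jax.vmap(step, in_axes=(0, None))
next_states = step_all_rules(all_rules, state)

print(next_states.shape)  # (256, 8)
print(next_states[110])   # [0 0 0 1 1 0 0 0]
```

The function `step` is written for one rule only; `jax.vmap` adds the batch dimension without any change to its body.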

Run the following commands and compare the timings:

In [None]:
! python eca0_numpy.py --num-steps 80 --width 80 --no-print-image --save-image
! python eca0_jax5_scan.py --num-steps 80 --width 80 --no-print-image --save-image
! python eca0_jax6_vmap.py --num-steps 80 --width 80 --no-print-image --save-image

- The first command runs our original NumPy version of the code.
- The second command runs the previous scan version of the code and generates a single output.
- The third command instead runs the same simulation, but *for all* 256 possible rules!

Compare the runtimes: JAX should be *much* faster than the NumPy version!