# Compiler Comparison

This notebook compares solutions built with three JIT compilers:

- [JAX JIT](https://jax.readthedocs.io/en/latest/jax.html?highlight=jit)
- [Numba](https://numba.pydata.org/)
- [PyPy](https://www.pypy.org/)

Where needed, each JIT-compiled function gets a warmup run first, so compilation time is not included in the timing measurements.
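The warm-up-then-measure procedure can be sketched with the standard library alone. The helper below is hypothetical (`benchmark` is not part of this repository). It calls the function once so any JIT compilation happens outside the measurement. It then mirrors the shape of IPython's `%timeit` report: `repeat` runs of `number` loops each, summarized as the mean ± standard deviation of the per-call time.

```python
import statistics
import timeit


def benchmark(fn, *args, repeat=7, number=100):
    """Return (mean, std) seconds per call, excluding a single warmup call.

    Hypothetical helper: it mimics the layout of IPython's %timeit report
    ("mean ± std. dev. of `repeat` runs, `number` loops each").
    """
    fn(*args)  # warmup: lets a JIT compiler compile before timing starts
    runs = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    per_loop = [total / number for total in runs]  # seconds per single call
    return statistics.mean(per_loop), statistics.pstdev(per_loop)


if __name__ == "__main__":
    mean, std = benchmark(sum, range(1000))
    print(f"{mean * 1e6:.2f} µs ± {std * 1e6:.2f} µs per loop")
```

In this notebook, `benchmark(numba_proliferate, initial_state, 256)` would time the Numba variant the same way.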

```python
from jax import config

from angler import jax_main, main, numba_main
from src import jax_proliferate, load_data, numba_proliferate, proliferate
```

## Setup

```python
# JAX uses 32-bit precision by default. To compare results with the other
# solutions, 64-bit precision must be enabled.
config.update("jax_enable_x64", True)
```
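One reason 64-bit precision matters here: the largest counts in the results below (for example 267293861994) exceed the int32 maximum of 2,147,483,647, so 32-bit integers cannot hold them. A small NumPy check makes this concrete (it assumes NumPy is installed, which JAX depends on anyway):

```python
import numpy as np

# Largest value in the 64-bit results reported in this notebook.
largest = 267293861994
int32_max = np.iinfo(np.int32).max  # 2147483647

print(largest > int32_max)  # True: the value does not fit in int32

# Casting down silently wraps around (only the low 32 bits are kept).
wrapped = np.array([largest], dtype=np.int64).astype(np.int32)[0]
print(int(wrapped) == largest)  # False: the value was corrupted
```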

```python
initial_state = load_data()
```

## Experiments

### Solution without JIT

```python
# warmup
proliferate(initial_state, 256)
```

```
array([142796431093, 168998522736, 176972779430, 192434943096,
       219875808369, 220869399474, 267293861994, 118583747405,
       145733806214])
```

```python
%timeit proliferate(initial_state, 256)
```

```
2.2 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


### Solution using Numba

```python
# warmup
numba_proliferate(initial_state, 256)
```

```
array([142796431093, 168998522736, 176972779430, 192434943096,
       219875808369, 220869399474, 267293861994, 118583747405,
       145733806214])
```

```python
%timeit numba_proliferate(initial_state, 256)
```

```
24.1 µs ± 1.96 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```


### Solution using JAX JIT

```python
# warmup
jax_proliferate(initial_state, 256)
```

```
DeviceArray([142796431093, 168998522736, 176972779430, 192434943096,
             219875808369, 220869399474, 267293861994, 118583747405,
             145733806214], dtype=int64)
```

```python
%timeit jax_proliferate(initial_state, 256)
```

```
4.69 µs ± 31.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
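JAX dispatches computations asynchronously, and the JAX documentation recommends calling `block_until_ready()` on the result when timing, so the measurement covers the computation and not only its dispatch. Whether `jax_proliferate` already blocks internally is not shown here. A minimal sketch of a wrapper that waits on any result exposing that method (`blocking` is a hypothetical helper, not part of `src`):

```python
import functools


def blocking(fn):
    """Wrap fn so each call waits for asynchronously dispatched results.

    Hypothetical helper: JAX arrays provide block_until_ready(); results
    without that method (e.g. NumPy arrays) are returned unchanged.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        wait = getattr(out, "block_until_ready", None)
        if callable(wait):
            out = wait()  # returns the same array once it is ready
        return out
    return wrapper
```

In IPython this could be used as `%timeit blocking(jax_proliferate)(initial_state, 256)`, or equivalently `%timeit jax_proliferate(initial_state, 256).block_until_ready()`. Note that this only waits on a single array result, not on tuples or other containers of arrays.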


### cfbolz's Solution without JIT

```python
# warmup
main()
```

```python
%timeit main()
```

```
88.4 µs ± 2.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```


### cfbolz's Solution using Numba

```python
# warmup
numba_main()
```

```python
%timeit numba_main()
```

```
29.9 µs ± 2.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```


### cfbolz's Solution using JAX JIT

```python
# warmup
jax_main()
```

```python
%timeit jax_main()
```

```
1.35 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```


### cfbolz's Solution using PyPy

Benchmarking the pure Python version in [`angler.py`](./angler.py) (courtesy of [cfbolz](https://twitter.com/cfbolz)) with PyPy yields the following result:

```
6.16 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
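Relative speedups can be read off the mean timings reported in this notebook. The sketch below simply divides each family's non-JIT mean by each variant's mean. It ignores the standard deviations, and since the two families are different implementations, ratios are only meaningful within a family. The PyPy figure was also measured in a separate interpreter session.

```python
# Mean time per call in microseconds, copied from the measurements above.
proliferate_us = {
    "no JIT": 2200.0,  # 2.2 ms
    "Numba": 24.1,
    "JAX JIT": 4.69,
}
cfbolz_us = {
    "no JIT": 88.4,
    "Numba": 29.9,
    "JAX JIT": 1.35,
    "PyPy": 6.16,
}


def speedups(timings, baseline="no JIT"):
    """Return how many times faster each variant is than the baseline."""
    base = timings[baseline]
    return {name: base / t for name, t in timings.items() if name != baseline}


for label, timings in [("proliferate", proliferate_us), ("cfbolz", cfbolz_us)]:
    for name, factor in speedups(timings).items():
        print(f"{label:12} {name:8} {factor:6.1f}x")
```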