# Compiler Comparison

This notebook compares the run time of a solution without JIT against the same solution compiled with two different JIT compilers:

- [JAX JIT](https://jax.readthedocs.io/en/latest/jax.html?highlight=jit)
- [Numba](https://numba.pydata.org/)

In [1]:
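# The `jax.config` submodule was removed in recent JAX releases; the config object is exposed on `jax` itself.
from jax import config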

from src import jax_proliferate, load_data, numba_proliferate, proliferate

## Setup

In [2]:
# JAX defaults to 32-bit precision. Enable 64-bit so its results can be compared with the other solutions.
config.update("jax_enable_x64", True)
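Why the precision setting matters for a comparison can be illustrated without JAX. The sketch below uses only the standard library and assumes, hypothetically, that a simulated quantity can grow past the signed 32-bit integer range; the value is invented for illustration and is not taken from `load_data`. The helpers `as_int32` and `as_int64` are illustrative names, not part of `src`.

```python
import ctypes


def as_int32(value: int) -> int:
    """Return `value` as stored in a signed 32-bit integer (wraps on overflow)."""
    return ctypes.c_int32(value).value


def as_int64(value: int) -> int:
    """Return `value` as stored in a signed 64-bit integer."""
    return ctypes.c_int64(value).value


# Hypothetical count, one past the largest signed 32-bit value (2**31 - 1).
count = 2 ** 31

print(as_int32(count))  # -2147483648: the 32-bit representation has wrapped
print(as_int64(count))  # 2147483648: the 64-bit representation is exact
```

If the 32-bit and 64-bit runs disagree like this, timing them against each other says little, which is presumably why the notebook enables `jax_enable_x64` before running any experiment.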

In [3]:
initial_state = load_data()

## Experiments

### Solution without JIT

In [4]:
%timeit proliferate(initial_state, 256)

2.52 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Solution using Numba

In [5]:
%timeit numba_proliferate(initial_state, 256)

32.2 µs ± 5.65 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Solution using JAX JIT

In [6]:
%timeit jax_proliferate(initial_state, 256)



The slowest run took 16.79 times longer than the fastest. This could mean that an intermediate result is being cached.
297 µs ± 184 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
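The warning above is consistent with the first call to a JIT-compiled function including compilation time. JAX also dispatches work asynchronously, and its documentation recommends calling `block_until_ready()` on results when benchmarking. The sketch below shows one possible way to separate the first call from steady-state calls. `time_after_warmup` and `fake_jitted` are hypothetical names; the stand-in function only imitates a one-time compile cost so that the block runs without JAX or Numba installed.

```python
import time
import timeit


def _finish(result):
    """Wait for an asynchronous result (e.g. a JAX array) if the object supports it."""
    block = getattr(result, "block_until_ready", None)
    return block() if callable(block) else result


def time_after_warmup(fn, *args, repeat=7, number=100):
    """Return (first_call_seconds, best_seconds_per_call_after_warmup)."""
    start = time.perf_counter()
    _finish(fn(*args))  # first call: may include JIT compilation
    first_call = time.perf_counter() - start

    # Later calls reuse whatever the first call compiled or cached.
    timer = timeit.Timer(lambda: _finish(fn(*args)))
    per_call = min(timer.repeat(repeat=repeat, number=number)) / number
    return first_call, per_call


# Hypothetical stand-in for a JIT-compiled function: slow once, fast afterwards.
_compiled = set()


def fake_jitted(n):
    if n not in _compiled:
        time.sleep(0.05)  # imitates a one-time compilation cost
        _compiled.add(n)
    return sum(range(n))


first, steady = time_after_warmup(fake_jitted, 256)
print(f"first call: {first * 1e3:.1f} ms, steady state: {steady * 1e6:.2f} µs")
```

With the notebook's own functions this could be called as `time_after_warmup(jax_proliferate, initial_state, 256)`, assuming the function returns either a JAX array or a plain Python value. Reporting the first call separately keeps compilation cost out of the per-call figure.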
