Task 1: Estimating $\pi$



In [None]:
# Estimating PI
from random import random

import jax
import jax.numpy as jnp


def pi_python(npoints=100_000):
  """Estimate pi by Monte Carlo sampling in pure Python.

  Draws ``npoints`` points uniformly from the unit square [0, 1) x [0, 1)
  and counts how many fall strictly inside the unit circle. That fraction
  approximates pi / 4, so it is scaled by 4.

  Args:
    npoints: Number of random points to sample (default 100_000).

  Returns:
    The estimate of pi as a float.

  Raises:
    ValueError: If ``npoints`` is not positive.
  """
  if npoints <= 0:
    raise ValueError(f"npoints must be positive, got {npoints}")

  inside = 0
  # Draw each point on the fly rather than materializing a list of
  # npoints tuples first; x is drawn before y, as in the original.
  for _ in range(npoints):
    x, y = random(), random()
    if x**2 + y**2 < 1:
      inside += 1

  return inside * 4.0 / npoints


print(f"pi ~= {pi_python()}")

print("\nPure Python:")
%timeit -n 10 pi_python()

pi ~= 3.13908

Pure Python:
213 ms ± 8.94 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Solution:
def pi_jax(key=None, npoints=100_000):
  """Estimate pi by Monte Carlo sampling with vectorized JAX operations.

  Samples ``npoints`` points uniformly from [0, 1) x [0, 1) and counts the
  fraction that fall strictly inside the unit circle (approximately pi / 4).

  The function works under ``jax.jit`` as long as ``npoints`` is a plain
  Python int. It sets an array shape, so it must not be a traced value.

  Args:
    key: A JAX PRNG key. If None, ``jax.random.key(42)`` is used, so repeated
      calls return the same estimate. Pass a different key for a new sample.
    npoints: Number of random points to sample (default 100_000).

  Returns:
    A scalar JAX array holding the estimate of pi.

  Raises:
    ValueError: If ``npoints`` is not positive.
  """
  if npoints <= 0:
    raise ValueError(f"npoints must be positive, got {npoints}")
  if key is None:
    key = jax.random.key(42)

  # Row 0 holds the x coordinates, row 1 the y coordinates.
  xy = jax.random.uniform(key, shape=(2, npoints))
  isinside = (xy[0] ** 2 + xy[1] ** 2) < 1

  # Promote to float before scaling: `count * 4` would stay in the integer
  # dtype of jnp.sum and overflow int32 once count exceeds 2**29.
  return 4.0 * jnp.sum(isinside) / npoints


# JIT compile
fast_pi_jax = jax.jit(pi_jax)

print(f"pi ~= {pi_jax()}")

print("\nPure JAX:")
%timeit -n 10 pi_jax()

print("\nJAX JIT (with compile-time):")
%timeit -n 1 fast_pi_jax().block_until_ready()

print("\nJAX JIT (without compile-time):")
%timeit -n 10 fast_pi_jax().block_until_ready()

pi ~= 3.1400399208068848

Pure JAX:
5.14 ms ± 831 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

JAX JIT (with compile-time):
The slowest run took 184.76 times longer than the fastest. This could mean that an intermediate result is being cached.
48.4 ms ± 113 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX JIT (without compile-time):
1.89 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
