# Programming GPUs to speed up implementations

## 1. [JAX](https://github.com/google/jax) 

In [4]:
import jax.numpy as jnp
from jax import jit

In [5]:
from timeit import default_timer as timer # for timing the application

In [6]:
def slow_f(x):
  """Return ``x * x + x * 2.0``.

  The expression is built purely from element-wise operations, which
  (per the original note) see a large benefit from fusion.
  """
  squared = x * x
  doubled = x * 2.0
  return squared + doubled

In [10]:
# Counter-example: this timing is wrong. block_until_ready() is never called
# on the result, so (as the original note says) JAX does not block until the
# array is ready and the timer can stop before the work has actually finished.
# NOTE(review): the printed label mentions "the sum" and "gpu", but this cell
# only creates a 5000x5000 array of ones -- consider rewording the label.
start = timer()
x_no_block = jnp.ones((5000, 5000))  # may return before the array is ready
end = timer()
print("Time needed to run the sum with gpu: ", end - start)

Time needed to run the sum with gpu:  0.0011037000003852881


In [11]:
# Correct timing: block_until_ready() is called on the result, so the second
# timer() reading is only taken once the array is actually ready.
# NOTE(review): the printed label is identical to the one in the non-blocking
# cell and mentions "the sum", although only an array of ones is created here
# -- consider a distinct, accurate label so the two outputs can be told apart.
start = timer()
x_block = jnp.ones((5000, 5000)).block_until_ready()  # wait for the result
end = timer()
print("Time needed to run the sum with gpu: ", end - start)

Time needed to run the sum with gpu:  0.0021476679999068438
