# Programming GPUs to speed up numerical code

## 1. [JAX](https://github.com/google/jax) 

In [None]:
import jax.numpy as jnp
from jax import jit

In [None]:
from timeit import default_timer as timer # for timing the application

In [None]:
def slow_f(x):
    """Return ``x * x + x * 2.0``, evaluated element-wise.

    The expression is built only from ``*`` and ``+``, so it applies to
    scalars and array types alike. It is a chain of separate element-wise
    operations, which is the kind of code that benefits from operator
    fusion when compiled with ``jax.jit``.
    """
    squared = x * x
    doubled = x * 2.0
    return squared + doubled

In [None]:
# Misleading timing: JAX dispatches work asynchronously, so `jnp.ones`
# returns before the array has actually been produced on the device.
# The elapsed time below is therefore only the dispatch overhead, not the
# real cost of creating the array.
# NOTE: as the first JAX operation here, this figure may also include
# one-time backend start-up cost.
start = timer()
x_no_block = jnp.ones((5000, 5000))
end = timer()
print("Time needed to dispatch jnp.ones (not waiting for the result): ", end - start)

In [None]:
# Correct timing: `block_until_ready()` waits until the device has finished
# producing the array. The measured interval therefore covers the actual
# work, not just the asynchronous dispatch.
start = timer()
x_block = jnp.ones((5000, 5000)).block_until_ready()
end = timer()
print("Time needed to create jnp.ones (blocking until ready): ", end - start)

## 2. [CuPy](https://cupy.dev/)

In [None]:
import cupy as cp
import numpy as np

In [None]:
# Shape of the 3-D test arrays used by the CuPy/NumPy cells below.
# It is kept as a list and unpacked with `*` when passed to `rand`.
problem_size = [100] * 3

In [None]:
# CPU baseline: generate the random array with NumPy.
# `*problem_size` unpacks the list into separate positional arguments.
# With problem_size = [100, 100, 100], this is np.random.rand(100, 100, 100).
start = timer()
rand_cpu = np.random.rand(*problem_size)
end = timer()
print("Time needed to generate random numbers with cpu: ", end - start)

In [None]:
# GPU version with CuPy. Kernel launches are asynchronous with respect to
# the host, so we wait for the device to finish before stopping the timer.
# Otherwise only the launch overhead would be measured (the same pitfall
# shown for JAX above).
# NOTE: the first CuPy call may also include one-time CUDA initialisation.
# Consider re-running this cell and reading the second timing.
start = timer()
rand_gpu = cp.random.rand(*problem_size)
cp.cuda.Device().synchronize()
end = timer()
print("Time needed to generate random numbers with gpu: ", end - start)

Transferring data from the CPU to the GPU

In [None]:
# Host -> device: build the data with NumPy on the CPU, then copy it into
# GPU memory. `cp.asarray` returns a CuPy array holding the same values.
dat_cpu = np.random.rand(*problem_size)
dat_gpu = cp.asarray(dat_cpu)

Transferring data from the GPU to the CPU

In [None]:
# Device -> host: create the data directly on the GPU with CuPy, then copy
# it back to host memory as a NumPy array with `cp.asnumpy`.
dat_gpu = cp.random.rand(*problem_size)
dat_cpu = cp.asnumpy(dat_gpu)

Transferring data works much the same way as it does in PyTorch!