In [6]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np

#### Checking if GPU is detected

In [7]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [8]:
def selu_jnp(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def sum_logistic_jnp(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def sum_logistic_np(x):
  return np.sum(1.0 / (1.0 + np.exp(-x)))

def first_finite_differences_np(f, x):
  eps = 1e-3
  return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in np.eye(len(x))])

def first_finite_differences_jnp(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

In [9]:
key = random.PRNGKey(0)
_x = random.normal(key, (10,))

In [10]:
size = 3000
x_jnp = random.normal(key, (size, size), dtype=jnp.float32)
x = np.random.normal(size=(size, size)).astype(np.float32)
x_device = device_put(x)

In [11]:
# runs on the CPU
time_np = %timeit -o np.dot(x, x.T);

# runs on the GPU
time_jnp = %timeit -o jnp.dot(x, x.T).block_until_ready();
time_jnp2 = %timeit -o jnp.dot(x_jnp, x_jnp.T).block_until_ready();
time_jnp3 = %timeit -o jnp.dot(x_device, x_device.T).block_until_ready();

# printing compute times
print(f'-------------------------------------------')
print(f'[ np array,  np.dot] T = {time_np.average:.3e} s')
print(f'[ np array, jnp.dot] T = {time_jnp.average:.3e} s')
print(f'[jnp array, jnp.dot] T = {time_jnp2.average:.3e} s')
print(f'[dev array, jnp.dot] T = {time_jnp3.average:.3e} s')

78.1 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
32.4 ms ± 393 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
6.07 ms ± 43.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.04 ms ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
-------------------------------------------
[ np array,  np.dot] T = 7.812e-02 s
[ np array, jnp.dot] T = 3.239e-02 s
[jnp array, jnp.dot] T = 6.072e-03 s
[dev array, jnp.dot] T = 6.036e-03 s


## Testing for fixed array sizes. Functions+jit version

In [12]:
# defining the jit functions
selu_np_jit = jit(selu_np)
selu_jnp_jit = jit(selu_jnp)

# runs on the CPU
print(f"Running on CPU")
time_np = %timeit -o selu_np(x);
time_np_jit = %timeit -o selu_np(x);

print(f"Running on GPU")
# runs on the GPU
time_jnp = %timeit -o selu_jnp(x).block_until_ready();
time_jnp2 = %timeit -o selu_jnp(x_jnp).block_until_ready();
time_jnp_jit = %timeit -o selu_jnp_jit(x_device).block_until_ready();
time_jnp_jit2 = %timeit -o selu_jnp_jit(x_jnp).block_until_ready();

# printing compute times
print(f'-------------------------------------------')
print(f'[ np array, selu_np      ] T = {time_np.average:.3e} s; speedup = {time_np.average/time_np.average:.1f}X')
print(f'[ np array, selu_np_jit  ] T = {time_np_jit.average:.3e} s; speedup = {time_np.average/time_np_jit.average:.1f}X')
print(f'[ np array, selu_jnp     ] T = {time_jnp.average:.3e} s; speedup = {time_np.average/time_jnp.average:.1f}X')
print(f'[jnp array, selu_jnp_jit ] T = {time_jnp2.average:.3e} s; speedup = {time_np.average/time_jnp2.average:.1f}X')
print(f'[dev array, selu_jnp_jit ] T = {time_jnp_jit.average:.3e} s; speedup = {time_np.average/time_jnp_jit.average:.1f}X')
print(f'[jnp array, selu_jnp_jit ] T = {time_jnp_jit2.average:.3e} s; speedup = {time_np.average/time_jnp_jit2.average:.1f}X')

Running on CPU
78.5 ms ± 45.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
78.5 ms ± 62.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Running on GPU
25.7 ms ± 9.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.66 ms ± 57.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
199 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
200 µs ± 666 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-------------------------------------------
[ np array, selu_np      ] T = 7.852e-02 s; speedup = 1.0X
[ np array, selu_np_jit  ] T = 7.854e-02 s; speedup = 1.0X
[ np array, selu_jnp     ] T = 2.570e-02 s; speedup = 3.1X
[jnp array, selu_jnp_jit ] T = 1.656e-03 s; speedup = 47.4X
[dev array, selu_jnp_jit ] T = 1.989e-04 s; speedup = 394.7X
[jnp array, selu_jnp_jit ] T = 1.995e-04 s; speedup = 393.5X


### Testing with variable array sizes. Functions+jit versions
jax+jit version is almost 1/10 the speed of np!!

In [16]:
size = 3000
x_jnp = random.normal(key, (size, size), dtype=jnp.float32)
x = np.random.normal(size=(size, size)).astype(np.float32)
x_device = device_put(x)

# defining the jit functions
selu_np_jit = jit(selu_np)
selu_jnp_jit = jit(selu_jnp)

# runs on the CPU
print(f"Running on CPU")
time_np = %timeit -o selu_np(x);
time_np_jit = %timeit -o selu_np(x);

print(f"Running on GPU")
# runs on the GPU
time_jnp = %timeit -o selu_jnp(np.random.randn(np.random.randint(2989, 3001), np.random.randint(2989, 3001))).block_until_ready();
time_jnp2 = %timeit -o selu_jnp(random.normal(key, (np.random.randint(2989, 3001), np.random.randint(2989, 3001)), dtype=jnp.float32)).block_until_ready();
time_jnp_jit2 = %timeit -o selu_jnp_jit(random.normal(key, (np.random.randint(2989, 3001), np.random.randint(2989, 3001)), dtype=jnp.float32)).block_until_ready();

# printing compute times
print(f'-------------------------------------------')
print(f'[ np array, selu_np      ] T = {time_np.average:.3e} s; speedup = {time_np.average/time_np.average:.1f}X')
print(f'[ np array, selu_np_jit  ] T = {time_np_jit.average:.3e} s; speedup = {time_np.average/time_np_jit.average:.1f}X')
print(f'[ np array, selu_jnp     ] T = {time_jnp.average:.3e} s; speedup = {time_np.average/time_jnp.average:.1f}X')
print(f'[jnp array, selu_jnp_jit ] T = {time_jnp2.average:.3e} s; speedup = {time_np.average/time_jnp2.average:.1f}X')
print(f'[jnp array, selu_jnp_jit ] T = {time_jnp_jit2.average:.3e} s; speedup = {time_np.average/time_jnp_jit2.average:.1f}X')

Running on CPU
78.6 ms ± 27.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
78.7 ms ± 49.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Running on GPU
765 ms ± 173 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
439 ms ± 103 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
The slowest run took 362.29 times longer than the fastest. This could mean that an intermediate result is being cached.
316 ms ± 182 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-------------------------------------------
[ np array, selu_np      ] T = 7.864e-02 s; speedup = 1.0X
[ np array, selu_np_jit  ] T = 7.867e-02 s; speedup = 1.0X
[ np array, selu_jnp     ] T = 7.655e-01 s; speedup = 0.1X
[jnp array, selu_jnp_jit ] T = 4.390e-01 s; speedup = 0.2X
[jnp array, selu_jnp_jit ] T = 3.158e-01 s; speedup = 0.2X


In [80]:
x_np = np.arange(500.)
x_jnp = jnp.arange(500.)
dfn_jnp = grad(sum_logistic_jnp)
# no autograd for numpy function
# -dfn_np = grad(sum_logistic_np)-

dfn_jnp_jit = jit(grad(jit(sum_logistic_jnp)))

print('Comparing autograd and explicit numerical differentiation')
print(first_finite_differences_jnp(sum_logistic_jnp, x_jnp))
print(first_finite_differences_np(sum_logistic_np, x_np))
print(dfn_jnp(x_jnp))

#### Speedup isn't significant for small arrays and large number of calls

In [82]:
x_np = np.arange(5.)
x_jnp = jnp.arange(5.)

print(f"Running on CPU")
time_np = %timeit -o first_finite_differences_np(sum_logistic_np, x_np);

print(f"Running on GPU")
time_jnp = %timeit -o first_finite_differences_jnp(sum_logistic_jnp, x_jnp);
time_grad = %timeit -o dfn_jnp(x_jnp);
time_grad_jit = %timeit -o dfn_jnp_jit(x_jnp);

# printing compute times
print(f'-------------------------------------------')
print(f'[ np array,  np func     ] T = {time_np.average:.3e} s; speedup = {time_np.average/time_np.average:.1f}X')
print(f'[jnp array, jnp func     ] T = {time_jnp.average:.3e} s; speedup = {time_np.average/time_jnp.average:.1f}X')
print(f'[jnp array, jnp grad     ] T = {time_grad.average:.3e} s; speedup = {time_np.average/time_grad.average:.1f}X')
print(f'[jnp array, jnp grad jit ] T = {time_grad_jit.average:.3e} s; speedup = {time_np.average/time_grad_jit.average:.1f}X')

Running on CPU
108 µs ± 244 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Running on GPU
20 ms ± 25.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.73 ms ± 4.41 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
61.4 µs ± 600 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-------------------------------------------
[ np array,  np func     ] T = 1.081e-04 s; speedup = 1.0X
[jnp array, jnp func     ] T = 1.998e-02 s; speedup = 0.0X
[jnp array, jnp grad     ] T = 4.729e-03 s; speedup = 0.0X
[jnp array, jnp grad jit ] T = 6.136e-05 s; speedup = 1.8X


In [84]:
x_np = np.arange(500.)
x_jnp = jnp.arange(500.)

print(f"Running on CPU")
time_np = %timeit -o first_finite_differences_np(sum_logistic_np, x_np);

print(f"Running on GPU")
time_jnp = %timeit -o first_finite_differences_jnp(sum_logistic_jnp, x_jnp).block_until_ready();
time_grad = %timeit -o dfn_jnp(x_jnp).block_until_ready();
time_grad_jit = %timeit -o dfn_jnp_jit(x_jnp).block_until_ready();

# printing compute times
print(f'-------------------------------------------')
print(f'[ np array,  np func     ] T = {time_np.average:.3e} s; speedup = {time_np.average/time_np.average:.1f}X')
print(f'[jnp array, jnp func     ] T = {time_jnp.average:.3e} s; speedup = {time_np.average/time_jnp.average:.1f}X')
print(f'[jnp array, jnp grad     ] T = {time_grad.average:.3e} s; speedup = {time_np.average/time_grad.average:.1f}X')
print(f'[jnp array, jnp grad jit ] T = {time_grad_jit.average:.3e} s; speedup = {time_np.average/time_grad_jit.average:.1f}X')

Running on CPU
21.9 ms ± 46.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Running on GPU
1.91 s ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.76 ms ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
66.9 µs ± 875 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-------------------------------------------
[ np array,  np func     ] T = 2.190e-02 s; speedup = 1.0X
[jnp array, jnp func     ] T = 1.910e+00 s; speedup = 0.0X
[jnp array, jnp grad     ] T = 4.759e-03 s; speedup = 4.6X
[jnp array, jnp grad jit ] T = 6.686e-05 s; speedup = 327.5X
