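## Generating Random Numbers

JAX uses explicit, stateless pseudo-random number generation. Every sampling function takes a PRNG key, and the same key always produces the same numbers. `random.split` derives new keys when fresh randomness is needed.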

In [119]:
import jax.numpy as jnp
from jax import random, jit, grad, jacfwd, jacrev

In [43]:
# JAX's PRNG is explicit and stateless: the key fully determines the output,
# so sampling twice with the same key reproduces exactly the same numbers.
key = random.PRNGKey(0)
print(f"Key = {key}")

sample_shape = (10, 1)
for banner in (None, "Same Random Numbers with the same key"):
    if banner:
        print(banner)
    x = random.normal(key, sample_shape)
    print(x)

# To get *different* numbers, derive child keys with `split` instead of
# reusing `key`.
print("\nSplitting Key:\n")
subkey1, subkey2 = random.split(key, 2)
print(f"Key1 = {subkey1}")
print(f"Key2 = {subkey2}")

x1, x2 = (random.normal(subkey, (5,)) for subkey in (subkey1, subkey2))
print(x1)
print(x2)

Key = [0 0]
[[-0.372111  ]
 [ 0.2642311 ]
 [-0.18252774]
 [-0.7368198 ]
 [-0.44030386]
 [-0.15214427]
 [-0.6713536 ]
 [-0.59086424]
 [ 0.73168874]
 [ 0.56730247]]
Same Random Numbers with the same key
[[-0.372111  ]
 [ 0.2642311 ]
 [-0.18252774]
 [-0.7368198 ]
 [-0.44030386]
 [-0.15214427]
 [-0.6713536 ]
 [-0.59086424]
 [ 0.73168874]
 [ 0.56730247]]

Splitting Key:

Key1 = [4146024105  967050713]
Key2 = [2718843009 1272950319]
[0.59902614 0.21721433 2.4202888  0.03266731 1.2164947 ]
[-1.4581941 -2.0470448  2.0473387  1.1684093 -0.9758365]


## ReLU

In [65]:
def relu(x):
    """Rectified linear unit, applied element-wise.

    Keeps the strictly positive entries of ``x`` and replaces every other
    entry with zero. This includes NaN, because ``NaN > 0`` is False.
    The output has the same shape as ``x``.
    """
    is_positive = jnp.greater(x, 0)
    return jnp.where(is_positive, x, 0)

# Draw a (10, 1) input for the ReLU benchmarks.
# NOTE(review): `key` was already passed to `random.split` above, and reusing
# it reproduces the very first sample. JAX advises against reusing a key;
# consider sampling from a fresh subkey (this would change the printed input).
x = random.normal(key, shape=(10, 1))
print(f"Input : {x}")

Input : [[-0.372111  ]
 [ 0.2642311 ]
 [-0.18252774]
 [-0.7368198 ]
 [-0.44030386]
 [-0.15214427]
 [-0.6713536 ]
 [-0.59086424]
 [ 0.73168874]
 [ 0.56730247]]


In [66]:
# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# timing covers the actual computation rather than just enqueueing it.
%timeit relu(x).block_until_ready()

254 µs ± 7.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### With JIT

In [67]:
# Compile `relu` with XLA. The first call triggers compilation; %timeit (with
# no -n given) calls the function while calibrating its loop count, so the
# one-time compile cost should not dominate the reported mean.
jit_relu = jit(relu)
%timeit jit_relu(x).block_until_ready()

80.8 µs ± 395 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Gradient

In [111]:
def f(x):
    """Return the sum over all elements of log(1 + x) / (2 x).

    Accepts a scalar or an array. The reduction always yields a scalar, which
    is what lets ``grad(f)`` be taken. The expression is 0/0 (NaN) at
    ``x == 0``.
    """
    numerator = jnp.log1p(x)
    denominator = 2 * x
    return (numerator / denominator).sum()

# `grad(f)` is a new function computing df/dx. JAX's `grad` needs
# floating-point inputs, which is why the derivative is evaluated at 2.0.
differentiate = grad(f)

point = 2
print(f"f(x) = {f(point)}")
print(f"Gradient = {differentiate(float(point))}")

f(x) = 0.2746530771255493
Gradient = -0.05399320274591446


In [112]:
# Float evaluation points 1.0 .. 10.0 (stop is exclusive); floats are needed
# so that `grad` can be applied to functions of `x`.
x = jnp.arange(1.0, 11.0, 1.0)
print(x)


[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]


In [113]:
%timeit f(x)

350 µs ± 10 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [114]:
%timeit differentiate(x)

2.82 ms ± 44.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### With JIT

In [115]:
# XLA-compiled versions of `f` and of its gradient, used in the timings below.
jit_f, jit_differentiate = jit(f), jit(differentiate)

In [116]:
%timeit jit_f(x)

78.4 µs ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [117]:
%timeit jit_differentiate(x)

78.8 µs ± 822 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


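## Hessian

The Hessian (the matrix of second derivatives) is the Jacobian of the gradient. It can therefore be built by composing JAX's Jacobian transforms.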

In [130]:
def hessian(f):
    """Return a function that computes the Hessian of ``f``.

    Uses forward-over-reverse differentiation, ``jacfwd(jacrev(f))``, which
    the JAX autodiff cookbook recommends as typically the most efficient way
    to build a Hessian. Reverse mode produces the gradient, and forward mode
    then differentiates that gradient. The values match ``jacrev(jacrev(f))``
    up to floating-point rounding.

    Args:
        f: function of one array argument; normally scalar-valued.

    Returns:
        A function ``x -> H``. For a scalar-valued ``f`` and 1-D ``x`` of
        length n, ``H`` has shape (n, n).
    """
    return jacfwd(jacrev(f))

hessian_f = jit(hessian(f))

%timeit hessian_f(x)

78.2 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
