# JAX as accelerated NumPy

In [None]:
import shutil

import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

key = jrandom.PRNGKey(42)  # fixed seed -> the random matrix is reproducible

# 30x30 matrix with entries drawn uniformly from [0, 1), and its determinant.
matrix_shape = (30, 30)
a = jrandom.uniform(key, matrix_shape)
det_a = jnp.linalg.det(a)
det_a

In [None]:
# Sanity check of det(AB) = det(A) det(B) with B = A: det(a @ a) should equal det(a)**2.
b = a @ a
det_b = jnp.linalg.det(b)
# Determinants of a 30x30 matrix can be very large or very small, and JAX works
# in float32 by default, so the raw difference alone says little; report the
# error relative to det(a)**2 as well.
abs_err = det_b - det_a * det_a
rel_err = jnp.abs(abs_err) / jnp.abs(det_a * det_a)
# Only a cell's last expression is auto-displayed (the %timeit below comes last
# in this cell), so print the result explicitly instead of leaving a bare expression.
print(f"det(a @ a) - det(a)**2 = {abs_err}, relative error = {rel_err}")

# Benchmark matrix inversion. JAX dispatches work asynchronously, so
# block_until_ready() is needed for %timeit to measure the computation itself.
%timeit jnp.linalg.inv(a).block_until_ready()

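### Chebyshev polynomials
The Chebyshev polynomials of the first kind satisfy $T_n(\cos\theta) = \cos(n\theta)$, i.e. $T_n(x) = \cos(n \arccos x)$ for $x \in [-1, 1]$. The first few are
$$T_2(x) = 2x^2 - 1$$
$$T_3(x) = 4x^3 - 3x$$
$$T_4(x) = 8x^4 - 8x^2 + 1$$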



In [None]:
import jax.random as jrandom

seed = 42
key = jrandom.PRNGKey(seed)

# Split off 11 fresh subkeys -- one for the random vector, ten for the random
# normal matrices -- and keep the advanced `key` for later use.
key, *subkeys = jrandom.split(key, num=12)
vector_key, *matrix_keys = subkeys

xs = jrandom.uniform(vector_key, (3,))
rand_matrices = [jrandom.normal(matrix_key, (20, 20)) for matrix_key in matrix_keys]

In [None]:
import jax.numpy as jnp

# 100 evenly spaced sample points on [-1, 1]; both endpoints are included.
xs = jnp.linspace(start=-1.0, stop=1.0, num=100)


def cheb_2(x):
    """Chebyshev polynomial of the first kind, T_2(x) = 2x^2 - 1.

    Evaluated in polynomial form rather than as cos(2 * arccos(x)). The two agree
    on [-1, 1], but the arccos form is NaN for |x| > 1 and its gradient is
    non-finite at x = +-1, which are points of the plotting grid.

    Args:
        x: float scalar or array.

    Returns:
        A JAX array with T_2 evaluated elementwise.
    """
    x = jnp.asarray(x)  # always return a JAX array, even for Python scalars
    return 2.0 * x**2 - 1.0


def cheb_3(x):
    """Chebyshev polynomial of the first kind, T_3(x) = 4x^3 - 3x.

    Evaluated in polynomial form rather than as cos(3 * arccos(x)). The two agree
    on [-1, 1], but the arccos form is NaN for |x| > 1 and its gradient is
    non-finite at x = +-1, which are points of the plotting grid.

    Args:
        x: float scalar or array.

    Returns:
        A JAX array with T_3 evaluated elementwise.
    """
    x = jnp.asarray(x)  # always return a JAX array, even for Python scalars
    return 4.0 * x**3 - 3.0 * x


def cheb_4(x):
    """Chebyshev polynomial of the first kind, T_4(x) = 8x^4 - 8x^2 + 1.

    Evaluated in polynomial form rather than as cos(4 * arccos(x)). The two agree
    on [-1, 1], but the arccos form is NaN for |x| > 1 and its gradient is
    non-finite at x = +-1, which are points of the plotting grid.

    Args:
        x: float scalar or array.

    Returns:
        A JAX array with T_4 evaluated elementwise.
    """
    x = jnp.asarray(x)  # always return a JAX array, even for Python scalars
    return 8.0 * x**4 - 8.0 * x**2 + 1.0


# Evaluate T_2, T_3 and T_4 on the sample grid (in that order).
ys_2, ys_3, ys_4 = (cheb(xs) for cheb in (cheb_2, cheb_3, cheb_4))

from matplotlib import pyplot as plt
from matplotlib import rc

# Serif (Times) fonts at 12pt for all figure text.
rc('font', **{'family': 'serif', 'serif': ['Times'], 'size': 12})
# text.usetex hands all text rendering to an external LaTeX toolchain (latex,
# plus dvipng for raster output); if it is missing, every figure fails to draw.
# Fall back to matplotlib's built-in mathtext, which also renders labels such as
# "$T_2$", when those executables are not on PATH.
_have_latex = shutil.which('latex') is not None and shutil.which('dvipng') is not None
rc('text', usetex=_have_latex)


# Plot T_2, T_3 and T_4 on [-1, 1]; ys_2 / ys_3 / ys_4 were computed earlier in
# this cell, so they are not recomputed here.
print(f"{xs.shape = }, {ys_2.shape = }")

# Use the object-oriented API. The previous `plt.title = "..."` assigned a string
# over the pyplot.title function: no title was drawn, and any later plt.title(...)
# call would raise TypeError.
fig, ax = plt.subplots()
ax.plot(xs, ys_2, color="red", label="$T_2$")
ax.plot(xs, ys_3, color="blue", label="$T_3$")
ax.plot(xs, ys_4, color="green", label="$T_4$")
ax.set(title="Chebyshev Polynomials", xlabel="$x$", ylabel="$T_n(x)$")
ax.legend(loc="upper left")

plt.show()

In [None]:
import jax

# jax.grad turns each scalar-to-scalar polynomial into its derivative function.
cheb_2_prime = jax.grad(cheb_2)
cheb_3_prime = jax.grad(cheb_3)
cheb_4_prime = jax.grad(cheb_4)

# jax.grad needs scalar outputs, so the naive approach is a Python loop over xs...
ys_2_prime_slow = []
for x in xs:
    ys_2_prime_slow.append(cheb_2_prime(x))

# ...whereas jax.vmap maps the derivative over the whole array in a single call.
ys_2_prime = jax.vmap(cheb_2_prime)(xs)
ys_3_prime = jax.vmap(cheb_3_prime)(xs)
ys_4_prime = jax.vmap(cheb_4_prime)(xs)

# The two approaches should agree; equal_nan keeps the check meaningful even if
# the derivative is NaN somewhere on the grid (e.g. at the endpoints +-1).
loop_matches_vmap = jnp.allclose(jnp.stack(ys_2_prime_slow), ys_2_prime, equal_nan=True)
print(f"loop and vmap derivatives agree: {bool(loop_matches_vmap)}")

# Object-oriented API: `plt.title = "..."` would overwrite the pyplot.title
# function with a string instead of setting a title.
fig, ax = plt.subplots()
ax.plot(xs, ys_2_prime, color="red", label="$T'_2$")
ax.plot(xs, ys_3_prime, color="blue", label="$T'_3$")
ax.plot(xs, ys_4_prime, color="green", label="$T'_4$")
ax.set(title="Derivatives", xlabel="$x$", ylabel="$T'_n(x)$")
ax.legend(loc="upper left")
plt.show()

## Sharp Bits

In [None]:
# `np` is not imported by any earlier cell; without this a fresh kernel raises NameError.
import numpy as np

a_jnp = jnp.arange(10)
a_np = np.arange(10)

# Sharp bit: JAX does not raise IndexError for out-of-bounds reads. Per the JAX
# "Sharp Bits" docs, the index is clamped into range, so both lookups below are
# expected to print the last element and the `except` branch is not reached.
try:
    print(a_jnp[10])
    print(a_jnp[1000])
except IndexError:
    print("Jax: Out of bounds")

# NumPy, by contrast, raises IndexError for the same lookup.
try:
    print(a_np[10])
except IndexError:
    print("NumPy: Out of bounds")

In [None]:
import jax

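# Optional: uncomment to make JAX raise an error as soon as a NaN is produced,
# instead of silently propagating it through the computation.
# jax.config.update("jax_debug_nans", True)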


def cheb_2(x):
    """T_2 written via the identity T_n(x) = cos(n * arccos(x)).

    Deliberately uses the trigonometric form. It matches 2x^2 - 1 on [-1, 1],
    but d/dx arccos(x) = -1/sqrt(1 - x^2) is unbounded at x = +-1, so jax.grad
    of this function is non-finite at the endpoints.
    """
    theta = jnp.arccos(x)
    return jnp.cos(2 * theta)


# Four evenly spaced points on [-1, 1]; the grid includes the endpoints x = +-1.
xs = jnp.linspace(-1.0, 1.0, 4)

cheb_2_prime = jax.grad(cheb_2)
# Map the scalar derivative over the grid. With the arccos form of cheb_2 defined
# in this cell, the values at the endpoints are expected to be non-finite
# (NaN/inf). That is the sharp bit this cell demonstrates.
grads_on_grid = jax.vmap(cheb_2_prime)(xs)
print(grads_on_grid)