In [None]:
import time

import numpy as np

import jax
import jax.numpy as jnp
import jax.random as jrandom

## VMAP


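`gaussian` is the density of a normal distribution. We differentiate it with `jax.grad`, which only accepts functions with a scalar output, so the implicit vectorization (broadcasting) that works for `gaussian` itself does not carry over to its derivative. Apply `jax.vmap` correctly, so that the code below works for a vector and for a matrix of inputs.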

In [None]:
def gaussian(x, params):
    """Evaluate the density of a normal distribution at ``x``.

    ``params`` unpacks into exactly two values, ``(mu, sigma)``: the mean
    and the standard deviation of the distribution.

    Returns ``exp(-((x - mu) / sigma)**2 / 2) / (sigma * sqrt(2 * pi))``.
    """
    mu, sigma = params
    # Normalising constant of the density.
    normaliser = sigma * jnp.sqrt(2 * jnp.pi)
    # Squared standardised distance of x from the mean.
    squared_z = ((x - mu) / sigma) ** 2
    return jnp.exp(-squared_z / 2) / normaliser


# Derivative of the density with respect to its first argument, x.
# jax.grad only differentiates scalar-output functions, so dgaussian on its
# own can only be evaluated at a single scalar x.
dgaussian = jax.grad(gaussian)

params = jnp.array([0.2, 2.0])

# generate a random vector and matrix
key = jrandom.PRNGKey(42)
key, subkey_1, subkey_2 = jrandom.split(key, 3)
xs_vec = jrandom.uniform(subkey_1, (20,))
xs_matr = jrandom.uniform(subkey_2, (30, 30))

# Calling dgaussian directly on an array makes gaussian return an array,
# which jax.grad rejects. Instead, map the scalar derivative over the leading
# axis of x while sharing the same params across all elements
# (in_axes=None for params).
dgaussian_vec = jax.vmap(dgaussian, in_axes=(0, None))
# A matrix is a batch of rows: vmap the vector version once more.
dgaussian_matr = jax.vmap(dgaussian_vec, in_axes=(0, None))

ys_vec = dgaussian_vec(xs_vec, params)
ys_matr = dgaussian_matr(xs_matr, params)

# Outputs have the same shapes as the inputs: (20,) and (30, 30).
print(ys_vec.shape, ys_matr.shape)

In [None]:
import jax
import jax.numpy as jnp

# Two 2x2 example matrices used as inputs for the batching question below.
# m_1 is diagonal with entries 1 and 100.
m_1 = jnp.array([[1.0, 0.0], [0.0, 100.0]])
m_2 = jnp.array([[1.0, 1.0], [-1.0, 1.0]])


def cond_plus_cond_sq(m):
    """Return ``c + c**2`` where ``c`` is ``jnp.linalg.cond(m)``."""
    cond_number = jnp.linalg.cond(m)
    return cond_number + cond_number**2


# Stack the two matrices along a new leading axis: m[0] is m_1, m[1] is m_2.
m = jnp.array((m_1, m_2))

# Question: how to replace this for-loop with jax.vmap?
# Reference result: apply cond_plus_cond_sq to every 2x2 slice taken along
# the LAST axis of m. (m[:, :, i] is already 2x2, so no reshape is needed.)
cs = jnp.array([cond_plus_cond_sq(m[:, :, i]) for i in range(m.shape[2])])

# NB: jax.vmap by default does something different: it maps over axis 0,
# i.e. it evaluates the function on m_1 and m_2 themselves.
print(jax.vmap(cond_plus_cond_sq)(m))

# Answer: pass in_axes=2 so that vmap slices along the last axis, exactly
# like the loop above. This line and `cs` should print the same values.
print(jax.vmap(cond_plus_cond_sq, in_axes=2)(m))
print(cs)

## Cubic regression

In [None]:
from matplotlib import pyplot as plt

# Ground-truth coefficients a, b, c, d used to generate the data; the
# regression below tries to recover them.
true_a, true_b, true_c, true_d = -0.5, 2.3, 1.0, -0.2

# The same four coefficients packed into a single array, in order a, b, c, d.
true_params = jnp.array([true_a, true_b, true_c, true_d])


# TODO-1 (solved): cubic model used both to generate the data and to fit it.
@jax.jit
def predict(params, x):
    """Evaluate the cubic polynomial ``a*x**3 + b*x**2 + c*x + d``.

    Args:
        params: four coefficients ``(a, b, c, d)``, highest degree first.
        x: point at which to evaluate the polynomial.

    Returns:
        The value of the polynomial at ``x``.
    """
    # interpret params as coefficients a, b, c, d
    a, b, c, d = params
    # Horner's scheme for a*x^3 + b*x^2 + c*x + d.
    return ((a * x + b) * x + c) * x + d


# Noise-free targets: the true polynomial evaluated on 200 evenly spaced
# points in [-2, 2]. predict_v maps predict over x while sharing the params.
xs = jnp.linspace(-2.0, 2.0, 200)
predict_v = jax.vmap(predict, in_axes=(None, 0))
ys = predict_v(true_params, xs)

# Perturb every target with uniform noise drawn from [-eps, eps].
eps = 0.2
seed = 2
noise_key = jrandom.PRNGKey(seed)
noise = jrandom.uniform(noise_key, shape=ys.shape, minval=-eps, maxval=eps)

noisy_ys = ys + noise


# TODO-2 (solved): mean squared error of the model on a data set.
@jax.jit
def loss(params, xs, ys):
    """Mean squared error of ``predict`` on the points ``(xs, ys)``.

    Computes ``1/n * sum_i (predict(params, xs[i]) - ys[i])**2``.

    Args:
        params: model coefficients, passed unchanged to ``predict``.
        xs: 1-D array of inputs.
        ys: 1-D array of targets with the same length as ``xs``.

    Returns:
        The scalar MSE, so the function can be differentiated with
        ``jax.grad`` / ``jax.value_and_grad``.
    """
    # Evaluate the model at every x with the same params.
    predicted = jax.vmap(predict, in_axes=(None, 0))(params, xs)
    return jnp.mean((predicted - ys) ** 2)


# Random initial guess for the four coefficients. PRNGKey(seed) has already
# been consumed to draw the noise, and JAX keys must not be reused, so the
# initialisation uses a different seed.
params = jrandom.uniform(jrandom.PRNGKey(seed + 1), (4,))

n_steps = 200
lr = 0.05

# Returns (loss value, gradient with respect to params) in a single call.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

# Plain gradient descent on the MSE loss.
for step in range(n_steps):
    curr_loss, params_grad = loss_and_grad(params, xs, noisy_ys)

    params = params - lr * params_grad

    if step % 10 == 0:
        # NB: curr_loss was computed before this step's update, while the
        # printed params and the plotted curve are the updated ones.
        print(f"Step {step}, loss = {curr_loss}, {params = }")
        plt.plot(xs, predict_v(params, xs), color="red", label="current fit")
        plt.plot(xs, noisy_ys, color="green", label="noisy data")
        plt.title(f"Step {step}")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        plt.show()

## JIT

How many times will `jax.jit` compile `f` in the following example?

In [None]:
# Start the exercise with JAX's 64-bit mode switched off; it is switched on
# again further down, before the float64 input is created.
jax.config.update("jax_enable_x64", False)


@jax.jit
def f(x):
    """Subtract the sum of all entries of ``x`` from every entry of ``x``.

    The ``print`` is a Python side effect, so it only runs while ``jax.jit``
    traces the function, not on calls that reuse an already compiled version.
    """
    print("Compiling: ", x)
    total = jnp.sum(x)
    return x - total


# Inputs for the exercise. They differ from each other in value, shape
# and/or dtype.
# x1 and x2: same shape (2,), different values.
x1 = jnp.ones((2,))
x2 = jnp.zeros((2,))
# x3: same rank as x1, different length.
x3 = jnp.ones((3,))
# x4, x5: two-dimensional, with different sizes.
x4 = jnp.ones((2, 2))
x5 = jnp.ones((4, 4))
# x6: three-dimensional.
x6 = jnp.ones((3, 3, 3))
# x7: same shape as x1, but an integer dtype.
x7 = jnp.ones((2,), dtype=jnp.int32)
# 64-bit mode is switched on here, presumably so that the float64 dtype
# requested for x8 is honoured.
# NOTE(review): this setting stays enabled for every cell that runs afterwards.
jax.config.update("jax_enable_x64", True)
x8 = jnp.ones((2,), dtype=jnp.float64)

# Call f on every input in order; f prints "Compiling: ..." each time it is
# traced, so the output shows which calls triggered a trace.
for x in [x1, x2, x3, x4, x5, x6, x7, x8]:
    print("f(x) = ", f(x))

## Legendre

1. For `legendre_1`, use recursive formula $P_n(x) = \frac{(2n-1)xP_{n-1}(x) - (n-1) P_{n-2}(x)}{n}$.
2. For `legendre_2`, use formula $P_n(x) = \frac{1}{2^n n!} \frac{d^n}{dx^n} (x^2 - 1)^n$.
3. Fix `jax.jit` calls so that they work.
4. Time different combinations of `jax.jit` and `jax.vmap`.

In [None]:
# Compute Legendre polynomial P_n(x) using recursive formula
# P_n(x) = \frac{(2n-1)xP_{n-1}(x) - (n-1) P_{n-2}(x)}{n}
def legendre_1(x, n):
    """Evaluate the Legendre polynomial ``P_n`` at ``x`` by recurrence.

    Starts from ``P_0(x) = 1`` and ``P_1(x) = x`` and applies the three-term
    recurrence iteratively up to degree ``n``.

    Args:
        x: evaluation point (scalar or array).
        n: polynomial degree, a non-negative Python integer. It controls a
            Python loop, so it must be a static argument under ``jax.jit``.

    Returns:
        ``P_n(x)``, with the same shape as ``x``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer")
    p_prev = jnp.ones_like(x)  # P_0
    if n == 0:
        return p_prev
    p_curr = x  # P_1
    for k in range(2, n + 1):
        # P_k from P_{k-1} (p_curr) and P_{k-2} (p_prev).
        p_prev, p_curr = p_curr, ((2 * k - 1) * x * p_curr - (k - 1) * p_prev) / k
    return p_curr


# Compute Legendre polynomial P_n(x) using formula
# P_n(x) = \frac{1}{2^n n!} \frac{d^n}{dx^n} (x^2 - 1)^n
def legendre_2(x, n):
    """Evaluate the Legendre polynomial ``P_n`` at ``x`` via Rodrigues' formula.

    Differentiates ``(y**2 - 1)**n`` ``n`` times with ``jax.grad`` and divides
    the result by ``2**n * n!``.

    Args:
        x: scalar, floating-point evaluation point (``jax.grad`` only
            differentiates scalar functions of floating-point inputs; use
            ``jax.vmap`` to evaluate on arrays).
        n: polynomial degree, a non-negative Python integer. It controls how
            many times ``jax.grad`` is applied, so it must be a static
            argument under ``jax.jit``.

    Returns:
        ``P_n(x)``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer")

    def helper(y):
        return (y**2 - 1) ** n

    nth_derivative = helper
    normaliser = 1  # accumulates 2^n * n! = (2*1) * (2*2) * ... * (2*n)
    for k in range(1, n + 1):
        nth_derivative = jax.grad(nth_derivative)
        normaliser *= 2 * k
    return nth_derivative(x) / normaliser


# Reference implementation used to check the two functions above.
def legendre_3(x, n):
    """Closed-form Legendre polynomial ``P_n(x)`` for ``n`` from 0 to 4.

    Raises ``RuntimeError`` for any other ``n``.
    """
    if n == 0:
        return 1
    if n == 1:
        return x
    if n == 2:
        return (3 * x**2 - 1) / 2
    if n == 3:
        return (5 * x**3 - 3 * x) / 2
    if n == 4:
        return (35 * x**4 - 30 * x * x + 3) / 8
    # Only the first five polynomials are tabulated.
    raise RuntimeError("not implemented")


# what can be jit-ted? How to fix the jax.jit call?
# In all three functions n steers Python-level control flow (`if n == ...`,
# `range(n)`, the number of jax.grad applications). A traced value cannot be
# used for that, so n has to be declared static: jax.jit then compiles one
# specialised version per distinct value of n and only traces x.
jlegendre_1 = jax.jit(legendre_1, static_argnums=1)
jlegendre_2 = jax.jit(legendre_2, static_argnums=1)
jlegendre_3 = jax.jit(legendre_3, static_argnums=1)


x = 0.3
n = 4

# All three implementations should agree on the printed value.
print(jlegendre_1(x, n))
print(jlegendre_2(x, n))
print(jlegendre_3(x, n))

# Timing
# Profile the combination of jit and vmap. What is faster: vmap after jit
# or jit after vmap?

# Evaluate transformed function on zs once to pre-compile it.
# Evaluate transformed function on xs to measure the runtime.
xs = np.linspace(0, 4, 100000)
zs = np.zeros_like(xs)


def _fix_degree(fn, degree):
    """Return a one-argument function ``x -> fn(x, degree)``.

    Closing over the degree keeps it a plain Python integer, so it never
    becomes a traced or batched argument of jit/vmap.
    """
    return lambda x: fn(x, degree)


def _time_second_call(fn, warmup_input, timed_input):
    """Run ``fn`` on ``warmup_input``, then time ``fn`` on ``timed_input``.

    The first call absorbs the compilation cost. JAX dispatches work
    asynchronously, so both calls wait for the result with
    ``jax.block_until_ready`` before the clock is read.

    Returns the duration of the second call in seconds.
    """
    jax.block_until_ready(fn(warmup_input))
    start = time.perf_counter()
    jax.block_until_ready(fn(timed_input))
    return time.perf_counter() - start


for f in [legendre_1, legendre_2, legendre_3]:
    # apply jit first, then vmap
    f_jv = jax.vmap(jax.jit(_fix_degree(f, n)))
    # call f_jv once on zs, once on xs, measuring the time
    # of the second call
    elapsed = _time_second_call(f_jv, zs, xs)
    print(f"{f.__name__}: vmap(jit(f)) took {elapsed * 1e3:.3f} ms")


for f in [legendre_1, legendre_2, legendre_3]:
    # apply vmap first, then jit
    f_vj = jax.jit(jax.vmap(_fix_degree(f, n)))
    # call f_vj once on zs, once on xs, measuring the time
    # of the second call
    elapsed = _time_second_call(f_vj, zs, xs)
    print(f"{f.__name__}: jit(vmap(f)) took {elapsed * 1e3:.3f} ms")