In [16]:
import jax
from jax import random, jit
from functools import partial # needed to make arguments static in jit compiled code
import jax.numpy as jnp
import numpy as np
from tqdm import trange
from optax import adam

In [12]:
# because Apple sucks
jax.default_device = jax.devices("cpu")[0]

In [13]:
seed = 1
key = random.PRNGKey(seed)
key

Array([0, 1], dtype=uint32)

In [14]:
N = 1000
d = 10
n_classes = 10

key, key_X, key_y = random.split(key, 3)

X = random.normal(key_X, (N, d))
y = jax.nn.one_hot(random.randint(key_y, (N,), 0, n_classes), n_classes)

In [15]:
y[:1]

Array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]], dtype=float32)

Our model where $\theta$ and $\phi$ are learnable parameters:

In [43]:
def p_xzy(theta: jnp.ndarray, x: float, z: float, y:float ) -> float:
    pass

def q_z_xy(phi: jnp.ndarray, z: float, x: float, y: float) -> float:
    pass

Here are the loss functions defined in "Auto Encoding Varational Bayes" to approximate the true ELBO $\cal L$:

$$
\widetilde{\cal L}^{A}(\theta,\phi;{\bf x}^{(i)}, {\bf y}^{(i)})=\frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}({\bf x}^{(i)},{\bf z}^{(i,l)},{\bf y}^{(i)})-\log q_{\phi}({\bf z}^{(i,l)}|{\bf x}^{(i)}, {\bf y}^{(i)})
$$

$$

\widetilde{\cal L}^{B}(\theta,\phi;{\bf x}^{(i)}, {\bf y}^{(i)})=-D_{K L}(q_{\phi}({\bf z}|{\bf x}^{(i)}, {\bf y}^{(i)})||p_{\theta}({\bf z}))+\frac{1}{L}\sum_{l=1}^{L} \log p_{\theta}({\bf x}^{(i)}|{\bf z}^{(i,l)}, {\bf y}^{(i)})

$$


In [46]:
def loss_A(theta: jnp.ndarray, phi: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray, eps: jnp.ndarray) -> float:
    return 0

def loss_B(theta: jnp.ndarray, phi: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray, eps: jnp.ndarray) -> float:
    return 0

Noise distribution $p$ that parameterises latent variable $z$, i.e. $z = g_{\theta}(\epsilon, x, y)$ where $\epsilon\sim p(\epsilon)$.

In [51]:
def sample_p(keys: np.ndarray) -> np.ndarray:
    pass

In [17]:
@partial(jit, static_argnames=['learning_rate'])
def update_sgd(grad_theta: np.ndarray, grad_phi: np.ndarray, learning_rate: float):
    return -learning_rate * grad_theta, -learning_rate * grad_phi

In [52]:
# definition of remaining parameters
M = 10
N_trains = 10
learning_rate = 0.01
d_eps = 10

# used loss and update function: we can use optimisation algo implemented in Optax
loss = loss_A
update = update_sgd

# initial parameter values
theta_0 = jnp.ones(10)
phi_0 = jnp.ones(10)

# optimised parameters
theta = jnp.copy(theta_0)
phi = jnp.copy(phi_0)

#sink for the loss values
loss_values = []

# training loop
for iteration in trange(N_trains):
    # sample X^M, y^M mini-batches
    key, batch_key = random.split(key)
    indexes = random.choice(batch_key, M)
    X_M = X[indexes, ...]
    y_M = y[indexes, ...]

    # sample from noise distribution
    key, *sample_keys = random.split(key, d_eps+1)
    eps = sample_p(sample_keys)

    #compute the loss value and grad w.r.t. phi/theta
    grad_theta = jax.grad(loss, argnums=0)(theta, phi, X_M, y_M, eps)
    loss_value, grad_phi = jax.value_and_grad(loss, argnums=1)(theta, phi, X_M, y_M, eps)
    loss_values.append(loss_value)

    # update the parameters according to chosen policy
    theta, phi += update(grad_theta, grad_phi, learning_rate=learning_rate)

  0%|          | 0/10 [00:00<?, ?it/s]


NameError: name 'sample_p' is not defined