# Generative Classifiers

In [None]:
import os
import pickle
from datetime import datetime
from functools import partial # needed to make arguments static in jit compiled code

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from flax import linen as nn
from jax import random, jit
from jax.tree_util import tree_map
from optax import adam
from torch.utils import data
from torchvision.datasets import FashionMNIST
from tqdm import tqdm, trange

In [None]:
# https://github.com/google-deepmind/dm-haiku/issues/18#issuecomment-981814403
MODELS_PATH = "../model/"
def save_models(model, path: str):
    """Pickle ``model`` to ``path``, resolved relative to ``MODELS_PATH``.

    Args:
        model: Any picklable object (e.g. a dict of parameters).
        path: File path relative to ``MODELS_PATH``. It may contain
            sub-directories; missing ones are created.

    Raises:
        OSError: If the directory or the file cannot be created or written.
    """
    target = os.path.join(MODELS_PATH, path)
    # Create the parent directory first: `open` does not create missing
    # directories, so saving on a fresh checkout would otherwise fail.
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(target, "wb") as file:
        pickle.dump(model, file)

In [None]:
# Pin JAX to the CPU backend (the original note cites problems on Apple hardware).
# `jax.default_device` is a context manager; assigning to it would only overwrite
# that function without changing where arrays are placed. The supported global
# switch is the `jax_default_device` config option.
jax.config.update("jax_default_device", jax.devices("cpu")[0])
jax.config.jax_default_device

Let's store all the parameters here:

In [None]:
# Central registry of run parameters; later cells add their own entries.
args = dict()

In [None]:
# Record the seed with the other run parameters, then derive the root PRNG key from it.
args["seed"] = 1
key = jax.random.PRNGKey(args["seed"])
key

## Dataloader

### FashionMNIST: a realistic dataset

In [None]:
# Mini-batch size handed to the training data loader.
args["batch_size"] = 10

In [None]:
#https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html#data-loading-with-pytorch

def numpy_collate(batch):
  """Collate ``batch`` with PyTorch's default collate, then convert every leaf to a NumPy array."""
  collated = data.default_collate(batch)
  return tree_map(np.asarray, collated)

class NumpyLoader(data.DataLoader):
  """``DataLoader`` that always collates batches with ``numpy_collate``.

  All constructor arguments are forwarded unchanged to
  ``torch.utils.data.DataLoader``; only ``collate_fn`` is fixed.
  """
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    # Zero-argument super() resolves against this class statically.
    # `super(self.__class__, self)` would recurse forever in any subclass,
    # because `self.__class__` is then the subclass itself.
    super().__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  """Callable transform: copy the input into a float32 array and flatten it to 1-D."""
  def __call__(self, pic):
    as_float = np.array(pic, dtype=np.float32)
    return as_float.ravel()

In [None]:
# Training split: every sample served by the dataset is flattened and cast
# to float32 by the transform.
fashion_mnist_train_ds = FashionMNIST(
    './data/',
    download=True,
    train=True,
    transform=FlattenAndCast(),
)
# NOTE(review): despite the name, `n_train_classes` (and `n_test_classes` below)
# hold the number of labels, i.e. one per image, not the number of distinct
# classes. The values are left unchanged here; confirm the intent before use.
n_train_images, n_train_classes = len(fashion_mnist_train_ds.data), len(fashion_mnist_train_ds.targets)

# Test split, configured like the training split so that samples drawn from
# either dataset go through the same preprocessing.
fashion_mnist_test_ds = FashionMNIST(
    './data/',
    download=True,
    train=False,
    transform=FlattenAndCast(),
)
n_test_images, n_test_classes = len(fashion_mnist_test_ds.data), len(fashion_mnist_test_ds.targets)

# Dataloader used for the training
fashion_mnist_train_dl = NumpyLoader(
    fashion_mnist_train_ds,
    batch_size=args['batch_size'],
    shuffle=True,
)

# Get the full train dataset (for checking accuracy while training).
# `.data` bypasses the dataset transform, so the flattening and float32 cast
# are repeated here to match what the data loader feeds the model.
train_images = jnp.array(
    fashion_mnist_train_ds.data.numpy().reshape(n_train_images, -1),
    dtype=jnp.float32,
)
train_labels = jnp.array(fashion_mnist_train_ds.targets.numpy())
print("train_images", train_images.shape, "train_labels", train_labels.shape)

# Get full test dataset (same conversion as for the training arrays)
test_images = jnp.array(
    fashion_mnist_test_ds.data.numpy().reshape(n_test_images, -1),
    dtype=jnp.float32,
)
test_labels = jnp.array(fashion_mnist_test_ds.targets.numpy())
print("test_images", test_images.shape, "test_labels", test_labels.shape)

## The model

Our model where $\theta$ and $\phi$ are learnable parameters:

In [None]:
def p_theta_xzy(theta: jnp.ndarray, x: float, z: float, y: float) -> float:
    """Placeholder for the generative model term p_theta(x, z, y).

    Not implemented yet: the function has no logic and returns ``None``.
    """
    return None

def q_phi_z_xy(phi: jnp.ndarray, z: float, x: float, y: float) -> float:
    """Placeholder for the variational term q_phi(z | x, y).

    Not implemented yet: the function has no logic and returns ``None``.
    """
    return None

Here are the estimators defined in "Auto-Encoding Variational Bayes" that approximate the true ELBO $\cal L$ and serve as our loss functions:

$$
\widetilde{\cal L}^{A}(\theta,\phi;{\bf x}^{(i)}, {\bf y}^{(i)})=\frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}({\bf x}^{(i)},{\bf z}^{(i,l)},{\bf y}^{(i)})-\log q_{\phi}({\bf z}^{(i,l)}|{\bf x}^{(i)}, {\bf y}^{(i)})
$$

$$
\widetilde{\cal L}^{B}(\theta,\phi;{\bf x}^{(i)}, {\bf y}^{(i)})=-D_{K L}(q_{\phi}({\bf z}|{\bf x}^{(i)}, {\bf y}^{(i)})||p_{\theta}({\bf z}))+\frac{1}{L}\sum_{l=1}^{L} \log p_{\theta}({\bf x}^{(i)}|{\bf z}^{(i,l)}, {\bf y}^{(i)})
$$

$$
\widetilde{\cal L}^{M}(\theta,\phi;{\bf X}^{M}, {\bf y}^{M}, {\bf \epsilon})=\frac{1}{M}\sum_{i=1}^{M}\widetilde{\cal L}^{A/B}(\theta,\phi;{\bf X}^{M}_i, {\bf y}^{M}_i)
$$

In [None]:
def loss_A(theta: jnp.ndarray, phi: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray, eps: jnp.ndarray) -> float:
    """Placeholder for the estimator labelled A in the text above.

    Not implemented yet: every argument is ignored and the constant 0 is returned.
    """
    placeholder_value = 0
    return placeholder_value

def loss_B(theta: jnp.ndarray, phi: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray, eps: jnp.ndarray) -> float:
    """Placeholder for the estimator labelled B in the text above.

    Not implemented yet: every argument is ignored and the constant 0 is returned.
    """
    placeholder_value = 0
    return placeholder_value

Noise distribution $p(\epsilon)$ used to reparameterise the latent variable $z$, i.e. $z = g_{\phi}(\epsilon, x, y)$ where $\epsilon\sim p(\epsilon)$.

In [None]:
def sample_p(keys: np.ndarray) -> np.ndarray:
    """Placeholder for drawing noise from p(epsilon) using the given PRNG key(s).

    Not implemented yet: the function has no logic and returns ``None``.
    """
    return None

In [None]:
@partial(jit, static_argnames=['learning_rate'])
def update_sgd(grad_theta: np.ndarray, grad_phi: np.ndarray, learning_rate: float):
    """Plain gradient-descent step.

    Returns the pair of increments ``(-learning_rate * grad_theta,
    -learning_rate * grad_phi)``; the caller adds them to the parameters.
    ``learning_rate`` is a static jit argument, so each distinct value
    triggers its own compilation.
    """
    step_scale = -learning_rate
    delta_theta = step_scale * grad_theta
    delta_phi = step_scale * grad_phi
    return delta_theta, delta_phi

In [None]:
# Number of full passes over the training data loader.
args["num_epochs"] = 1

In [None]:
# for the validation loss logging
n_log_validation = 100

# definition of remaining parameters
M = 10
N_trains = 10
d_eps = 10

# gradient descent method and associated parameters: use optimisation algo implemented in Optax if needed
learning_rate = 0.01
update = update_sgd

# used loss function
loss = loss_A

# initial parameter values
theta_0 = jnp.ones(10)
phi_0 = jnp.ones(10)

# optimised parameters
theta = jnp.copy(theta_0)
phi = jnp.copy(phi_0)

#sink for the loss values
training_steps, validation_steps = [], []
training_loss_values, validation_loss_values = [], []

# A single differentiation pass yields the loss value and the gradients
# w.r.t. both theta (argument 0) and phi (argument 1).
loss_and_grads = jax.value_and_grad(loss, argnums=(0, 1))

# Dedicated key for validation noise, so every validation evaluation uses
# the same noise and the logged values are comparable across steps.
key, validation_key = random.split(key)


def compute_validation_loss(theta, phi):
    """Evaluate the selected loss on the held-out arrays.

    NOTE(review): no separate validation split exists in this notebook, so
    the full test arrays are used here; confirm that this is the intent.
    """
    eps = sample_p(validation_key)
    return loss(theta, phi, test_images, test_labels, eps)


# NOTE: `loss_A`/`loss_B` and `sample_p` are still placeholders; they must be
# implemented (returning a differentiable scalar / noise sample) before this
# loop can produce meaningful gradients.

# training loop
global_step = 0  # keeps increasing across epochs so logged steps never overlap
for epoch in range(args['num_epochs']):
    for X_batch, y_batch in tqdm(fashion_mnist_train_dl, desc=f"epoch {epoch}"):
        # sample from noise distribution, using a fresh sub-key at every step
        key, sample_key = random.split(key)
        eps = sample_p(sample_key)

        #compute the loss value and grad w.r.t. phi/theta
        loss_value, (grad_theta, grad_phi) = loss_and_grads(theta, phi, X_batch, y_batch, eps)

        # log the training loss: here it's just the batch loss, need to be fixed.
        training_loss_values.append(float(loss_value))
        training_steps.append(global_step)

        # log the validation loss for some steps
        if (global_step + 1) % n_log_validation == 0:
            validation_loss_value = compute_validation_loss(theta, phi)
            validation_loss_values.append(float(validation_loss_value))
            validation_steps.append(global_step)

        # update the parameters according to chosen policy; `update` returns
        # the increments, which are added to the current parameters
        delta_theta, delta_phi = update(grad_theta, grad_phi, learning_rate=learning_rate)
        theta = theta + delta_theta
        phi = phi + delta_phi

        global_step += 1

# The timestamp avoids ':' because it is not a valid file-name character on
# every platform.
save_models({
    "learning_rate": learning_rate,
    "N_trains": N_trains,
    "M": M,
    "theta": theta,
    "phi": phi,
}, f"models/test_{datetime.today().strftime('%Y-%m-%d_%H-%M-%S')}.pkl")

Here is the resulting training/validation loss plot:

In [None]:

# Plot the logged training and validation losses against the training step.
fig, ax = plt.subplots()
# Labels are plain strings (one per curve) and only show up once a legend is drawn.
ax.plot(training_steps, training_loss_values, label="training")
ax.plot(validation_steps, validation_loss_values, label="validation")
ax.set_xlabel("iteration")
# Raw string: "\m" (like the original "\c") is not a valid Python escape sequence.
ax.set_ylabel(r"$\mathcal{L}$")
ax.legend()
fig.suptitle("Loss values during training");