# Generative Classifiers

In [176]:
import jax
from jax import random, jit, vmap
from jax.nn import one_hot
from jax.nn.initializers import glorot_uniform
from jax.tree_util import tree_map
from functools import partial # needed to make arguments static in jit compiled code
import jax.numpy as jnp
from flax import linen as nn
import numpy as np
from torchvision.datasets import FashionMNIST
from torch.utils import data
import tqdm
from tqdm import trange
from optax import adam
import matplotlib.pyplot as plt
import pickle
from datetime import datetime
import os

In [7]:
# https://github.com/google-deepmind/dm-haiku/issues/18#issuecomment-981814403
MODELS_PATH = "../model/"
def save_models(model, path: str):
    with open(os.path.join(MODELS_PATH, path), "wb") as file:
        pickle.dump(model, file)

In [8]:
# because Apple sucks
jax.default_device = jax.devices("cpu")[0]
jax.default_device

CpuDevice(id=0)

Let's store all the parameters here:

In [9]:
args = {}

In [23]:
args['seed'] = 1
key = random.PRNGKey(args['seed'])
key

Array([0, 1], dtype=uint32)

## Dataloader

- ### FashionMNIST: a realistic dataset

In [24]:
args['batch_size'] = 10

In [25]:
#https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html#data-loading-with-pytorch

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [29]:
fashion_mnist_train_ds = FashionMNIST(
    './data/', 
    download=True,
    train=True,
    transform=FlattenAndCast(),
)
n_train_images = len(fashion_mnist_train_ds.data)
classes = np.unique(fashion_mnist_train_ds.targets)
n_classes = len(classes)

fashion_mnist_test_ds = FashionMNIST(
    './data/', 
    download=True, 
    train=False,
)
n_test_images = len(fashion_mnist_test_ds.data)

# Dataloader used for the training
fashion_mnist_train_dl = NumpyLoader(
    fashion_mnist_train_ds, 
    batch_size=args['batch_size'],
    shuffle=True,
)

# Get the full train dataset (for checking accuracy while training)
train_images = jnp.array(fashion_mnist_train_ds.data).reshape(n_train_images, -1)
train_labels = jnp.array(fashion_mnist_train_ds.targets)
print("train_images", train_images.shape, "train_labels", train_labels.shape)

# Get full test dataset
test_images = jnp.array(fashion_mnist_test_ds.data.numpy().reshape(n_test_images, -1))
test_labels = jnp.array(fashion_mnist_test_ds.targets)
print("test_images", test_images.shape, "test_labels", test_labels.shape)

print("n_classes", n_classes, "classes", classes)

train_images (60000, 784) train_labels (60000,)
test_images (10000, 784) test_labels (10000,)
n_classes 10 classes [0 1 2 3 4 5 6 7 8 9]


## The model

Our model where $\theta$ and $\phi$ are learnable parameters:

- Let's take the GFZ probabilistic graphical model for a first implementation: $p(x,z,y)=p(z)p(y|z)p(x|z,y)$.

`
q(z|x,y) is the same across all VAE-based classifiers. It starts with a 3-layer
convolutional neural network with 5 ×5 filters and 64 channels, with a max-pooling operation after each convolution. Then,
the convolutional network is followed by a MLP with 2 hidden layers, each with 500 units, to produce the mean and variance
parameters of q. The label y is injected into the MLP at the first hidden layer, as a one hot encoding (i.e. for MNIST, the
first hidden layer has 500+10 units). The latent dimension is dim(z) = 64.
`

In [451]:
class Encoder(nn.Module):
    n_classes = 10
    d_epsilon = 64
    n_convolutions = 3
    n_channels = 64
    kernel_size = (5, 5)
    strides = (2, 2)
    d_hidden = 500

    @nn.compact
    def __call__(self, X, y, epsilon): # X: (height, width), y: (n_classes,), epsilon: (d_epsilon,) -> (d_epsilon,), 0
        for _ in range(self.n_convolutions):
            X = nn.Conv(
                features=self.n_channels, 
                kernel_size=self.kernel_size, 
                strides=self.strides, 
                kernel_init=glorot_uniform(),
            )(X)
            X = nn.relu(X)

        X_flatten = X.reshape(-1)
        X_flatten = nn.Dense(
            features=self.d_hidden, 
            use_bias=True,
            kernel_init=glorot_uniform(), 
        )(X_flatten)
        X_flatten = nn.relu(X_flatten)
        
        output = jnp.concatenate((X_flatten, y), axis=0)
        output = nn.Dense(
            features=self.d_hidden + self.n_classes, 
            use_bias=False,
            kernel_init=glorot_uniform(), 
        )(output)
        output = nn.relu(output)

        output = nn.Dense(
            features=2 * self.d_epsilon, 
            use_bias=True,
            kernel_init=glorot_uniform(), 
        )(output)
        # end of model

        mu, log_sigma2 = jnp.split(output, 2)
        # this should be correct: to change by recomputing
        logdet = jnp.sum(log_sigma2)
        sigma = jnp.sqrt(jax.lax.pow(10.0, logdet))
        z = mu + sigma * epsilon
        logits =  - (self.d_epsilon * logdet + jnp.dot(epsilon, epsilon)) / 2
        return z, logits

`
For p(y|z) we use a MLP with 1 hidden layer composed of 500 units. For p(x|y,z) we used an MLP with 2
hidden layers, each with 500 units, and 4 ×4 ×64 dimension output, followed by a 3-layer deconvolutional network
with 5 ×5 kernel size, stride 2 and [64, 64, 1] channels.
`

In [458]:
#TODO: implement forward pass
class Log_p_y_z(nn.Module):
    d_hidden = 500
    n_classes = 10

    @nn.compact
    def __call__(self, y, z): # y: (n_classes,), z: (d_epsilon,) -> (n_classes,)
        logits = nn.Dense(
            features=self.d_hidden, 
            use_bias=True,
            kernel_init=glorot_uniform(), 
        )(z)
        logits = nn.relu(logits)
        logits = nn.Dense(
            features=self.n_classes, 
            use_bias=True,
            kernel_init=glorot_uniform(), 
        )(logits)
        
        return logits # use cross entropy with logits

#TODO: implement
class Log_p_x_yz(nn.Module):
    d_hidden = 500
    d_epsilon = 64
    n_classes = 10
    n_channels = 64
    input_kernel = (4, 4)
    kernel_size = (5, 5)
    strides = (2, 2)
    

    @nn.compact
    def __call__(self, X, y, z): # X: (height, width), y: (n_classes,), z: (d_epsilon,) -> 0
        inputs = jnp.concatenate([y, z], 0)
        inputs = nn.Dense(
            features=self.d_hidden, 
            use_bias=True,
            kernel_init=glorot_uniform(), 
        )(inputs)
        inputs = nn.relu(inputs)
        inputs = nn.Dense(
            features=np.prod(self.input_kernel) * self.n_channels, 
            use_bias=True,
            kernel_init=glorot_uniform(), 
        )(inputs)
        inputs = nn.relu(inputs)
        inputs = inputs.reshape(self.input_kernel + (self.n_channels,))

        inputs = nn.ConvTranspose(
            features=2*self.d_epsilon,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=(2, 2),
            kernel_init=glorot_uniform(),
        )(inputs)
        inputs = nn.relu(inputs)

        inputs = nn.ConvTranspose(
            features=2*self.d_epsilon,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=((2, 3), (2, 3)),
            kernel_init=glorot_uniform(),
        )(inputs)
        inputs = nn.relu(inputs)

        inputs = nn.ConvTranspose(
            features=2,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=((2, 3), (2, 3)),
            kernel_init=glorot_uniform(),
        )(inputs)
        inputs = nn.sigmoid(inputs) #this sounds wrong! logsigma2 lives in R and not (-1, 1)

        mu, log_sigma2 = jnp.split(inputs, 2, axis=2)

        # this should be correct: to change by recomputing
        n = X.ndim
        sigma2 = jax.lax.pow(10.0, log_sigma2)
        delta = (X - mu).reshape(-1)
        logits =  - (n * log_sigma2 + jnp.dot(delta, delta) / sigma2) / 2
        return logits
    
# (GFZ) graph model
class Log_p_xzy(nn.Module):

    @nn.compact
    def __call__(self, X, y, z): # X: (height, width), y: (n_classes,), z: (d_epsilon,) -> 0
        d_epsilon = z.shape[0]
        log_prior = -(d_epsilon + jnp.dot(z, z))/2 # assuming it's standard Gaussian
        return log_prior + Log_p_y_z()(y, z) + Log_p_x_yz()(X, y, z)
    
class ModelGFZ(nn.Module):

    @nn.compact
    def __call__(self, X, y, epsilon): # X: (height, width), y: (n_classes,), epsilon: (d_epsilon,) -> 1, 1
        z, logits_q = Encoder()(X, y, epsilon)
        logits_p = Log_p_xzy()(X, y, z)
        return logits_p, logits_q

In [459]:
X = jnp.ones((28, 28, 1))
y = jnp.ones(10)
epsilon = jnp.ones(64)

model = ModelGFZ()
params = model.init(key, X, y, epsilon)
log_p, log_q = model.apply(params, X, y, epsilon)
print(log_p.shape, log_q.shape)

(28, 28, 1) ()


Here are the loss functions defined in "Auto Encoding Varational Bayes" to approximate the true ELBO $\cal L$:

$$
\widetilde{\cal L}^{A}(\theta,\phi;{\bf x}^{(i)}, {\bf y}^{(i)})=\frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}({\bf x}^{(i)},{\bf z}^{(i,l)},{\bf y}^{(i)})-\log q_{\phi}({\bf z}^{(i,l)}|{\bf x}^{(i)}, {\bf y}^{(i)})
$$

$$

\widetilde{\cal L}^{B}(\theta,\phi;{\bf x}^{(i)}, {\bf y}^{(i)})=-D_{K L}(q_{\phi}({\bf z}|{\bf x}^{(i)}, {\bf y}^{(i)})||p_{\theta}({\bf z}))+\frac{1}{L}\sum_{l=1}^{L} \log p_{\theta}({\bf x}^{(i)}|{\bf z}^{(i,l)}, {\bf y}^{(i)})

$$

$$

\widetilde{\cal L}^{M}(\theta,\phi;{\bf X}^{M}, {\bf y}^{M}, {\bf \epsilon})=\frac{N}{M}\sum_{i=1}^{M}=\widetilde{\cal L}^{A\text{ or }B}(\theta,\phi;{\bf X}^{M}_i, {\bf y}^{M}_i)

$$

In [430]:
# loss function \widetilde{\cal L}^{A} for a pair (X, y) and an array of epsilon
def loss_A_single(X: jnp.ndarray, y: jnp.ndarray, epsilon: jnp.ndarray) -> float:
    z, logits_q = log_q_z_xy(epsilon, X, y)
    logits_p = log_p_xzy(X, z, y)
    return jnp.mean(logits_p - logits_q, axis=-1)

loss_A_batch = vmap(loss_A_single, in_axes=(0, 0, 0), out_axes=0)
def loss_M_A(X_batch: jnp.ndarray, y_batch: jnp.ndarray, epsilon: jnp.ndarray):
    return jnp.sum(loss_A_batch(X_batch, y_batch, epsilon))

def loss_B(theta: jnp.ndarray, phi: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray, eps: jnp.ndarray) -> float:
    return 0

In [None]:
def update_step(apply_fn, X_batch, y_batch, epsilon, opt_state, params, state):
  def batch_loss(params):
    def loss_fn(X, y, epsilon):
      log_p, log_q, updated_state = apply_fn(
        {'params': params, **state},
        X, y, epsilon, mutable=list(state.keys())
      )
      return log_p - log_q, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, 
      in_axes=(0, 0, 0), 
      out_axes=(0, None), # Do not vmap `updated_state`.
      axis_name='batch' # Name batch dim
    )(X_batch, y_batch, epsilon)  # vmap only `X`, `y`, but not `state`.
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)

  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss

Noise distribution $p$ that parameterises latent variable $z$, i.e. $z = g_{\theta}(\epsilon, x, y)$ where $\epsilon\sim p(\epsilon)$.

In [16]:
def sample_p(key: jnp.ndarray, epsilon_shape: tuple) -> np.ndarray:
    key, sample_key = random.split(key)
    epsilon = random.normal(sample_key, epsilon_shape)
    return key, epsilon

In [None]:
@partial(jit, static_argnames=['learning_rate'])
def update_sgd(grad_theta: np.ndarray, grad_phi: np.ndarray, learning_rate: float):
    return -learning_rate * grad_theta, -learning_rate * grad_phi

In [None]:
args['num_epochs'] = 1

In [44]:

# for the validation loss logging
n_log_validation = 100

# definition of remaining parameters
dim_batch = 100
d_epsilon = 64

# gradient descent method and associated parameters: use optimisation algo implemented in Optax if needed
learning_rate = 0.01
update = update_sgd

# initial parameter values
theta_0 = jnp.ones(10)
phi_0 = jnp.ones(10)

# optimised parameters
theta = jnp.copy(theta_0)
phi = jnp.copy(phi_0)

#sink for the loss values
training_steps, validation_steps = [], []
training_loss_values, validation_loss_values = [], []

epsilon_shape = (n_batch, d_epsilon)

# training loop
for epoch in range(args['num_epochs']):
    for train_step, (X_batch, y_batch) in tqdm(enumerate(fashion_mnist_train_dl), total=n_train_images):
        # one-hot encoding
        y_batch_one_hot = one_hot(y_batch, n_classes)
        
        # sample from noise distribution
        key, epsilon = sample_p(key, epsilon_shape)

        #compute the loss value and grad w.r.t. phi/theta
        grad_theta = jax.grad(loss, argnums=0)(theta, phi, X_M, y_M, eps)
        loss_value, grad_phi = jax.value_and_grad(loss, argnums=1)(theta, phi, X_M, y_M, eps)
        
        # log the training loss: here it's just the batch loss, need to be fixed.
        training_loss_values.append(loss_value)
        training_steps.append(train_step)

        # log the validation loss for some steps
        if (train_step+1) % n_log_validation == 0:
            validation_loss_value = compute_validation_loss(theta, phi)
            validation_loss_values.append(validation_loss_value)
            validation_steps.append(train_step)

        # update the parameters according to chosen policy
        theta, phi += update(grad_theta, grad_phi, learning_rate=learning_rate)

save_models({
    "learning_rate": learning_rate,
    "N_trains": N_trains,
    "M": M,
    "theta": theta,
    "phi": phi,
}, f"models/test_{datetime.today().strftime('%Y-%m-%d_%H:%M:%S')}.pkl")

SyntaxError: 'tuple' is an illegal expression for augmented assignment (4129801499.py, line 54)

Here is the resulting training/test loss plot:

In [None]:

fig, ax = plt.subplots()
ax.plot(training_steps, training_loss_values, label=["training"])
ax.plot(validation_steps, validation_loss_values, label=["validation"])
ax.set_xlabel("iteration")
fig.suptitle("Loss values during training")
_ = ax.set_ylabel("$\cal L$")