In [1]:
from flax import linen as nn
from flax.training import train_state
import jax
from jax import random 
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
import wandb

  from .autonotebook import tqdm as notebook_tqdm


## Construct the Encoder and Decoder of the Network

In [2]:
# Get datasets as dict of JAX arrays.
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
  return train_ds, test_ds


train_ds, test_ds = get_datasets()

2024-05-03 17:16:21.464641: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2024-05-03 17:16:22.081032: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# copy from vae example from the flax repo
class Encoder(nn.Module):
  """VAE Encoder."""

  latents: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(500, name='fc1')(x)
    x = nn.relu(x)
    mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
    logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
    return mean_x, logvar_x


class Decoder(nn.Module):
  """VAE Decoder."""

  @nn.compact
  def __call__(self, z):
    z = nn.Dense(15, name='fc1_dec')(z)
    z = nn.relu(z)
    z = nn.Dense(10, name='fc2_dec')(z)
    return z


class IntentionMapper(nn.Module):
  """Intention Mapper for NMIST classifications"""

  latents: int = 20

  def setup(self):
    self.encoder = Encoder(self.latents)
    self.decoder = Decoder()

  def __call__(self, x, z_rng):
    mean, logvar = self.encoder(x)
    z = reparameterize(z_rng, mean, logvar)
    recon_x = self.decoder(z)
    return recon_x #, mean, logvar


def reparameterize(rng, mean, logvar):
  std = jnp.exp(0.5 * logvar)
  eps = random.normal(rng, logvar.shape)
  return mean + eps * std


def model(latents):
  return IntentionMapper(latents=latents)

In [4]:
imgs = train_ds["image"]

In [5]:
m_key, z_key = jax.random.split(jax.random.key(0))
mapper = model(20)
# need to reformat the image to a single vector
params = mapper.init(m_key, imgs[0].reshape((1,784)), z_key)

# tree map iterate through the dictionary and apply the mapping to each element
jax.tree_util.tree_map(lambda x: x.shape, params) # check the output shape

{'params': {'decoder': {'fc1_dec': {'bias': (15,), 'kernel': (20, 15)},
   'fc2_dec': {'bias': (10,), 'kernel': (15, 10)}},
  'encoder': {'fc1': {'bias': (500,), 'kernel': (784, 500)},
   'fc2_logvar': {'bias': (20,), 'kernel': (500, 20)},
   'fc2_mean': {'bias': (20,), 'kernel': (500, 20)}}}}

In [6]:
z_key, new_key = jax.random.split(z_key)
output = mapper.apply(params, imgs[:10].reshape((10, 784)), z_key)
output.shape

(10, 10)

In [7]:
train_imgs = train_ds["image"].reshape(60000, 784)
train_label = train_ds["label"]
test_imgs = test_ds["image"].reshape(10000, 784)
test_label = test_ds["label"]

In [8]:
@jax.jit
def apply_model(state, images, labels, rng):
  """Computes gradients, loss and accuracy for a single batch."""

  def loss_fn(params, rng):
    logits = state.apply_fn({'params': params}, images, rng)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params, rng)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)


def train_epoch(state, train_ds, batch_size, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...].reshape((batch_size, 784))
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels, rng)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy

def create_train_state(rng, config):
  """Creates initial `TrainState`."""
  model_key, z_key = jax.random.split(rng)
  mapper = model(20)
  params = mapper.init(model_key, jnp.ones([1, 784]), z_key)['params']
  tx = optax.sgd(config.learning_rate, config.momentum)
  return train_state.TrainState.create(apply_fn=mapper.apply, params=params, tx=tx)

def create_logger(name, project, config, notes=None):
  return wandb.init(name=name, project=project, config=config, notes=notes)

def train_and_evaluate(
    config: ml_collections.ConfigDict
) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  train_ds, test_ds = get_datasets()
  rng = jax.random.key(0)

  run = create_logger("flax_test", "flax_train", config.to_dict(), "test")

  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, config)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(
        state, train_ds, config.batch_size, input_rng
    )
    rng, input_rng = jax.random.split(rng)
    _, test_loss, test_accuracy = apply_model(
        state, test_ds['image'].reshape((-1, 784)), test_ds['label'], input_rng,
    )
    # log the performance
    wandb.log({
      'epoch': epoch,
      'train_loss': train_loss,
      'train_accuracy': train_accuracy,
      'test_loss': test_loss,
      'test_accuracy': test_accuracy,
    })
  return state

In [9]:
from configs import default as config_lib
config = config_lib.get_config()

In [10]:
train_key = jax.random.key(42)
state = create_train_state(train_key, config)
_ = train_epoch(state, train_ds, batch_size=2048, rng=train_key)

In [11]:
train_and_evaluate(config)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myuy004[0m. Use [1m`wandb login --relogin`[0m to force relogin


TrainState(step=Array(70000, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of IntentionMapper(
    # attributes
    latents = 20
)>, params={'decoder': {'fc1_dec': {'bias': Array([ 0.16708271,  0.15266418,  0.02883372,  0.08760708,  0.28525156,
       -0.07390564, -0.05944288,  0.2435773 ,  0.09477376, -0.00794235,
       -0.01619562,  0.19898416,  0.10344347,  0.2595735 ,  0.07954495],      dtype=float32), 'kernel': Array([[-7.73114741e-01, -3.32529306e-01, -4.23463106e-01,
        -8.75369489e-01, -1.21341171e-02,  1.48929462e-01,
         7.36937761e-01, -1.11453854e-01,  1.53752476e-01,
         6.16736770e-01, -9.38747004e-02, -9.38096583e-01,
         7.84220621e-02,  7.46103749e-02,  8.04540873e-01],
       [ 2.52039164e-01, -1.06448305e+00, -4.05645370e-01,
         6.61983907e-01, -2.47851029e-01,  1.96908370e-01,
         5.73326461e-02, -1.02987774e-01, -5.60440943e-02,
        -1.10944021e+00,  3.23540807e-01,  1.67113408e-01,
        -1.18042797e-01,  5