In [1]:
!pip3 install -q -U jax jaxlib flax wandb

In [2]:
import os
import re
import wandb
import requests
from typing import Any
from typing import Tuple
from functools import partial
from tqdm.autonotebook import tqdm
from kaggle_secrets import UserSecretsClient

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import jax
import flax
import optax
from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

import numpy as np
import jax.numpy as jnp

import tensorflow as tf
import tensorflow_datasets as tfds

In [3]:
# Point JAX at an attached TPU, if any, before any device is queried below.
if 'TPU_NAME' in os.environ:
    # Request the nightly tpu_driver from the TPU host only once per kernel session:
    # TPU_DRIVER_MODE is set as a global afterwards, so re-running this cell skips the POST.
    if 'TPU_DRIVER_MODE' not in globals():
        # Takes the second ':'-separated field of TPU_NAME as the host part.
        # Assumes TPU_NAME looks like 'grpc://<host>:<port>' — TODO confirm.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)  # the response is not inspected
        TPU_DRIVER_MODE = 1
    # Route XLA through the remote tpu_driver backend at TPU_NAME.
    # NOTE(review): this is JAX's legacy flags API; newer JAX releases may not provide
    # `from jax.config import config` / `.FLAGS` — confirm against the installed version.
    from jax.config import config as jax_config
    jax_config.FLAGS.jax_xla_backend = "tpu_driver"
    jax_config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print("TPU DETECTED!")
    print('Registered TPU:', jax_config.FLAGS.jax_backend_target)
elif "COLAB_TPU_ADDR" in os.environ:
    # Colab runtime: let JAX's helper wire up the TPU.
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
else:
    print('No TPU detected.')

# Number of local devices; later cells size the global batch (DEVICE_BATCH_SIZE) by it.
DEVICE_COUNT = len(jax.local_devices())
# Treated as "on TPU" only when exactly 8 local devices are visible (assumes an 8-core TPU).
TPU = DEVICE_COUNT == 8

if TPU:
    print("8 cores of TPU ( Local devices in Jax ):")
    print('\n'.join(map(str,jax.local_devices())))

In [4]:
SEED = 0
LATENT_DIM = 128
IMAGE_SIZE = 32
BATCH_SIZE = 64  # examples per device
# Global batch size: BATCH_SIZE examples for each of the DEVICE_COUNT local devices.
DEVICE_BATCH_SIZE = DEVICE_COUNT * BATCH_SIZE
EPOCHS = 100

# Read the W&B API key from Kaggle secrets instead of hard-coding it in the notebook.
user_secrets = UserSecretsClient()
os.environ["WANDB_API_KEY"] = user_secrets.get_secret("wandb_api_key")
# Attach the hyperparameters to the run so it can be reproduced from its W&B record.
wandb.init(
    project="jaxpiracy-autoencoder",
    entity="wandb",
    job_type="dev",
    config={
        "seed": SEED,
        "latent_dim": LATENT_DIM,
        "image_size": IMAGE_SIZE,
        "batch_size_per_device": BATCH_SIZE,
        "global_batch_size": DEVICE_BATCH_SIZE,
        "epochs": EPOCHS,
        "device_count": DEVICE_COUNT,
    },
)

In [5]:
def preprocess_data(sample):
    sample = tf.image.convert_image_dtype(sample['image'], tf.float32)
    sample = sample / 255.
    return sample


train_dataset = tfds.load("cifar10")['train']
train_size = len(train_dataset)

train_dataset = train_dataset.map(
    preprocess_data, num_parallel_calls=tf.data.AUTOTUNE
)
train_dataset = train_dataset.cache()
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(DEVICE_BATCH_SIZE)

data_generator = iter(tfds.as_numpy(train_dataset))

In [6]:
images = next(iter(train_dataset)).numpy()
fig = plt.figure(figsize=(16, 16))
grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0.1)
random_images = images[np.random.choice(np.arange(images.shape[0]), 16)]
for ax, image in zip(grid, images):
    image = image * 255.
    ax.imshow(image)
plt.title("Sample Images from CIFAR10 Dataset")
plt.show()

In [7]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an NHWC image batch to one latent vector per image.

    Attributes:
        num_channels: Base channel width; the later conv stages use twice this width.
        latent_dim: Size of the output latent vector.
        dtype: Computation dtype forwarded to every layer.
    """
    num_channels: int
    latent_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        width = self.num_channels
        # (features, strides) per 3x3 conv stage, in creation order. Each stride-2 stage
        # downsamples spatially (Flax convs default to 'SAME' padding).
        conv_stages = (
            (width, (2, 2)),
            (width, (1, 1)),
            (2 * width, (2, 2)),
            (2 * width, (1, 1)),
            (2 * width, (2, 2)),
        )
        for features, strides in conv_stages:
            conv = nn.Conv(features=features, kernel_size=(3, 3), strides=strides, dtype=self.dtype)
            x = nn.relu(conv(x))
        # Flatten everything except the batch axis before the final projection.
        flattened = x.reshape((x.shape[0], -1))
        return nn.Dense(features=self.latent_dim, dtype=self.dtype)(flattened)

In [8]:
rng = jax.random.PRNGKey(SEED)
x = jnp.ones(shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
model = Encoder(num_channels=32, latent_dim=LATENT_DIM)
# nn.tabulate returns a formatted string; print it so the table renders instead of its escaped repr.
print(nn.tabulate(model, rng)(x))

In [9]:
class Decoder(nn.Module):
    """Convolutional decoder that maps latent vectors back to 3-channel images.

    Attributes:
        num_channels: Base channel width; the first upsampling stage uses twice this width.
        latent_dim: Accepted for symmetry with ``Encoder``; this module does not read it
            (the input width is inferred by the first Dense layer).
        dtype: Computation dtype forwarded to every layer.
    """
    num_channels: int
    latent_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        width = self.num_channels
        # Project each latent vector and fold it into a 4x4 feature map.
        hidden = nn.relu(nn.Dense(features=2 * 16 * width, dtype=self.dtype)(x))
        hidden = hidden.reshape(hidden.shape[0], 4, 4, -1)
        # Each stage: stride-2 transposed conv to upsample, then a 3x3 conv to refine.
        for features in (2 * width, width):
            upsample = nn.ConvTranspose(features=features, kernel_size=(3, 3), strides=(2, 2), dtype=self.dtype)
            hidden = nn.relu(upsample(hidden))
            refine = nn.Conv(features=features, kernel_size=(3, 3), dtype=self.dtype)
            hidden = nn.relu(refine(hidden))
        # Last upsampling step straight to RGB; tanh bounds outputs to [-1, 1].
        rgb = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2, 2), dtype=self.dtype)(hidden)
        return nn.tanh(rgb)

In [10]:
rng = jax.random.PRNGKey(SEED)
x = jnp.ones(shape=(BATCH_SIZE, LATENT_DIM))
model = Decoder(num_channels=32, latent_dim=LATENT_DIM)
# nn.tabulate returns a formatted string; print it so the table renders instead of its escaped repr.
print(nn.tabulate(model, rng)(x))

In [11]:
class Autoencoder(nn.Module):
    """Encoder followed by decoder: maps an image batch to its reconstruction.

    Attributes:
        num_channels: Base channel width shared by the encoder and decoder.
        latent_dim: Size of the bottleneck vector.
        dtype: Computation dtype forwarded to both halves.
    """
    num_channels: int
    latent_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        shared_kwargs = dict(num_channels=self.num_channels, latent_dim=self.latent_dim, dtype=self.dtype)
        self.encoder = Encoder(**shared_kwargs)
        self.decoder = Decoder(**shared_kwargs)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

In [12]:
rng = jax.random.PRNGKey(SEED)
x = jnp.ones(shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
model = Autoencoder(num_channels=32, latent_dim=LATENT_DIM)
# nn.tabulate returns a formatted string; print it so the table renders instead of its escaped repr.
print(nn.tabulate(model, rng)(x))

In [14]:
def init_train_state(random_key, shape) -> train_state.TrainState:
    """Build a fresh autoencoder ``TrainState``; intended to run under ``jax.pmap``.

    Args:
        random_key: PRNG key used to initialise the parameters.
        shape: Static input shape used to trace ``model.init``,
            e.g. ``(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3)``.

    Returns:
        A ``TrainState`` whose optimizer clips gradients element-wise and runs Adam
        on a warmup + cosine-decay learning-rate schedule spanning the whole run.
    """
    model = Autoencoder(num_channels=32, latent_dim=LATENT_DIM)
    variables = model.init(random_key, jnp.ones(shape))
    warmup_steps = 100
    # optax's ``decay_steps`` is the total schedule length in optimizer steps (warmup
    # included). Each step consumes one global batch of DEVICE_BATCH_SIZE examples.
    steps_per_epoch = train_size // DEVICE_BATCH_SIZE
    # Keep the decay phase non-empty even for very short runs.
    total_steps = max(EPOCHS * steps_per_epoch, warmup_steps + 1)
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=1e-3,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=1e-5,
    )
    optimizer = optax.chain(
        optax.clip(1.0),  # clips every gradient entry to [-1, 1] (element-wise, not by norm)
        optax.adam(lr_schedule),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=variables['params'],
    )

parallelized_init_train_state = jax.pmap(init_train_state, static_broadcasted_argnums=(1,))

ae_key = jax.random.PRNGKey(SEED)
# Give every device the SAME key so all replicas start from identical parameters.
# common_utils.shard_prng_key would split it into a different key per device. Because
# train_step applies the same pmean-averaged update on every device, replicas that
# start out different would stay different for the whole run.
ae_key = jnp.broadcast_to(ae_key, (DEVICE_COUNT,) + ae_key.shape)
state = parallelized_init_train_state(ae_key, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
print(type(state))

In [15]:
def train_step(
    state: train_state.TrainState, batch: jnp.ndarray, key: jnp.ndarray
) -> Tuple[train_state.TrainState, jnp.ndarray]:
    """Run one data-parallel optimisation step; meant to run under ``jax.pmap``.

    Args:
        state: This device's replica of the training state.
        batch: This device's shard of images. The reconstruction target is the input itself.
        key: Per-device PRNG key. Unused by this step; kept in the signature for callers.

    Returns:
        The updated state and the reconstruction loss averaged across devices.
    """

    def loss_fn(params):
        reconstruction = state.apply_fn({'params': params}, batch)
        # Squared error averaged over the batch axis (0), then summed over all other axes.
        loss = ((reconstruction - batch) ** 2).mean(axis=0).sum()
        return loss, reconstruction

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, _), gradient = gradient_fn(state.params)
    # Average across replicas so every device applies the identical update.
    gradient = jax.lax.pmean(gradient, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    state = state.apply_gradients(grads=gradient)
    return state, loss


parallelized_train_step = jax.pmap(train_step, axis_name='num_devices')

In [16]:
def infer_fn(state: train_state.TrainState, batch: jnp.ndarray):
    """Apply the model held in ``state`` to ``batch`` and return its output.

    For the autoencoder, the output is the reconstruction of ``batch``. It is meant
    to be called through ``parallelized_infer_fn`` with one shard per device.
    """
    variables = {'params': state.params}
    reconstructions = state.apply_fn(variables, batch)
    return reconstructions

parallelized_infer_fn = jax.pmap(infer_fn, axis_name='num_devices')

In [17]:
def train(dataloader, epochs, batches_in_epoch, state, seed = SEED):
    """Run data-parallel training and log the mean loss of every epoch to W&B.

    Args:
        dataloader: Iterator of host batches whose leading dimension is divisible
            by the local device count (each batch is split with ``common_utils.shard``).
        epochs: Number of epochs to run.
        batches_in_epoch: Number of batches drawn from ``dataloader`` per epoch.
        state: Replicated ``TrainState`` (one copy per local device).
        seed: Seed for the per-step PRNG keys passed to ``parallelized_train_step``.

    Returns:
        The replicated ``TrainState`` after training.

    Raises:
        ValueError: If ``batches_in_epoch`` is less than 1, since no loss would exist to log.
    """
    if batches_in_epoch < 1:
        raise ValueError(f"batches_in_epoch must be >= 1, got {batches_in_epoch}")

    # Create the root key once and split it on every step. Re-creating it from
    # ``seed`` inside the loop would hand every step the same key.
    key = jax.random.PRNGKey(seed)
    for epoch in tqdm(range(epochs), desc="Epoch", position=0, leave=True):
        epoch_losses = []
        with tqdm(total=batches_in_epoch, desc="Training", leave=False) as progress_bar:
            for _ in range(batches_in_epoch):
                key, autoencoder_key = jax.random.split(key, 2)
                autoencoder_key = common_utils.shard_prng_key(autoencoder_key)
                batch = common_utils.shard(next(dataloader))
                state, loss = parallelized_train_step(state, batch, autoencoder_key)
                # Keep device arrays and fetch them once per epoch, not once per step.
                epoch_losses.append(loss)
                progress_bar.update(1)

            # Mean over every step of the epoch, not just the final batch.
            loss_metric = np.asarray(jax.device_get(epoch_losses)).mean()
            message = f"Epoch: {epoch: <2} | "
            message += f"Loss: {loss_metric}"
            wandb.log({"Loss": loss_metric}, step=epoch)
            progress_bar.write(message)

    return state

In [18]:
# Train for the configured EPOCHS instead of a separate hard-coded count. One epoch is
# floor(train_size / DEVICE_BATCH_SIZE) global batches.
state = train(data_generator, epochs=EPOCHS, batches_in_epoch=train_size // DEVICE_BATCH_SIZE, state=state)

In [19]:
# Mark the W&B run started by wandb.init as finished so its logged data is uploaded.
wandb.finish()