In [1]:
!pip3 install -q -U jax jaxlib flax



In [2]:
import os
import re
import requests
from typing import Any
from functools import partial
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

import jax
import flax
import optax
from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

import numpy as np
import jax.numpy as jnp

import tensorflow as tf
import tensorflow_datasets as tfds

  import sys
2022-06-21 13:14:00.081029: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/conda/lib


In [3]:
# Backend selection. This configures JAX's backend, so it presumably has to run
# before the first device query (jax.local_devices() below) — keep it first.
if 'TPU_NAME' in os.environ:
    # Remote TPU exposed via TPU_NAME, e.g. 'grpc://10.0.0.2:8470'.
    # The guard makes the nightly-driver request happen only once per kernel
    # session, even if this cell is re-run.
    if 'TPU_DRIVER_MODE' not in globals():
        # Build 'http://<host>:8475/...' from the host part of TPU_NAME.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # NOTE(review): the response is not checked; a failed request is silent.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    # Point JAX at the remote TPU through the (legacy) tpu_driver backend.
    from jax.config import config as jax_config
    jax_config.FLAGS.jax_xla_backend = "tpu_driver"
    jax_config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print("TPU DETECTED!")
    print('Registered TPU:', jax_config.FLAGS.jax_backend_target)
elif "COLAB_TPU_ADDR" in os.environ:
    # Colab TPU runtime: use JAX's bundled helper instead.
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
else:
    print('No TPU detected.')

# Number of local accelerator cores; pmap'd functions map over this axis.
DEVICE_COUNT = len(jax.local_devices())
# Heuristic: exactly 8 local devices is treated as an 8-core TPU.
TPU = DEVICE_COUNT == 8

# NOTE(review): not referenced by the visible cells (CIFAR-10 is loaded via tfds).
DATA_PATH = "../input/tpu-getting-started/tfrecords-jpeg-224x224"

if TPU:
    print("8 cores of TPU ( Local devices in Jax ):")
    print('\n'.join(map(str,jax.local_devices())))

TPU DETECTED!
Registered TPU: grpc://10.0.0.2:8470
8 cores of TPU ( Local devices in Jax ):
TPU_0(host=0,(0,0,0,0))
TPU_1(host=0,(0,0,0,1))
TPU_2(host=0,(1,0,0,0))
TPU_3(host=0,(1,0,0,1))
TPU_4(host=0,(0,1,0,0))
TPU_5(host=0,(0,1,0,1))
TPU_6(host=0,(1,1,0,0))
TPU_7(host=0,(1,1,0,1))


In [4]:
# Reproducibility and training length.
SEED = 0
EPOCHS = 100

# Model / data geometry.
LATENT_DIM = 128
IMAGE_SIZE = 32  # CIFAR-10 side length; NOTE(review): not referenced by the visible cells

# Batching: BATCH_SIZE is the per-device batch. DEVICE_BATCH_SIZE is, despite its
# name, the *global* batch pulled from the input pipeline and then split across
# all local devices.
BATCH_SIZE = 64
DEVICE_BATCH_SIZE = BATCH_SIZE * DEVICE_COUNT

In [5]:
def preprocess_data(sample):
    """Turn a TFDS CIFAR-10 example into a float32 image scaled to [-1, 1].

    Args:
        sample: a feature dict with an ``'image'`` entry (assumed to be the
            uint8 image tensor produced by ``tfds.load("cifar10")``).

    Returns:
        The image as float32 with values in [-1, 1], matching the ``tanh``
        output range of the decoder.
    """
    # convert_image_dtype already rescales integer images from [0, MAX] to
    # [0, 1]. The previous extra `/ 255.` squashed every pixel into roughly
    # [-1, -0.992] after the affine map below.
    image = tf.image.convert_image_dtype(sample['image'], tf.float32)
    return image * 2. - 1.


train_dataset = tfds.load("cifar10")['train']
train_size = len(train_dataset)

# Input pipeline: preprocess once, cache in memory, reshuffle on every pass,
# repeat forever, and emit global batches of DEVICE_BATCH_SIZE (every batch is
# full because repeat() comes before batch(), so it can always be sharded evenly).
train_dataset = (
    train_dataset
    .map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    # Previously the data was never shuffled, so every epoch saw the same batch
    # order. reshuffle_each_iteration gives a new order on each repeat.
    .shuffle(10_000, seed=SEED, reshuffle_each_iteration=True)
    .repeat()
    .batch(DEVICE_BATCH_SIZE)
    # Overlap host-side input preparation with device computation.
    .prefetch(tf.data.AUTOTUNE)
)

data_generator = iter(tfds.as_numpy(train_dataset))

2022-06-21 13:14:22.088178: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/conda/lib
2022-06-21 13:14:22.088263: W tensorflow/stream_executor/cuda/cuda_driver.cc:326] failed call to cuInit: UNKNOWN ERROR (303)


In [6]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a 32x32 image to a ``latent_dim`` vector.

    Attributes:
        c_hid: base channel width; later stages use ``2 * c_hid``.
        latent_dim: size of the output code.
    """
    c_hid : int
    latent_dim : int

    @nn.compact
    def __call__(self, x):
        # (output channels, stride) for each 3x3 conv, in order. Each conv is
        # followed by GELU; a stride of 2 halves the spatial size.
        conv_stack = (
            (self.c_hid, 2),      # 32x32 => 16x16
            (self.c_hid, 1),
            (2 * self.c_hid, 2),  # 16x16 => 8x8
            (2 * self.c_hid, 1),
            (2 * self.c_hid, 2),  # 8x8 => 4x4
        )
        h = x
        for channels, stride in conv_stack:
            h = nn.gelu(nn.Conv(features=channels, kernel_size=(3, 3), strides=stride)(h))
        # Flatten the spatial grid into one feature vector per example.
        h = h.reshape(h.shape[0], -1)
        return nn.Dense(features=self.latent_dim)(h)


class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector back to an image in [-1, 1].

    Attributes:
        c_out: number of output channels.
        c_hid: base channel width.
        latent_dim: latent size (not read in the body; Dense infers its input width).
    """
    c_out : int
    c_hid : int
    latent_dim : int

    @nn.compact
    def __call__(self, x):
        # Project the code and fold it into a 4x4 spatial grid of 2*c_hid channels.
        h = nn.gelu(nn.Dense(features=2 * 16 * self.c_hid)(x))
        h = h.reshape(h.shape[0], 4, 4, -1)

        # Stride-2 transposed convs upsample x2 each (assuming Flax's default
        # 'SAME' padding): 4x4 => 8x8 => 16x16 => 32x32.
        h = nn.gelu(nn.ConvTranspose(features=2 * self.c_hid, kernel_size=(3, 3), strides=(2, 2))(h))
        h = nn.gelu(nn.Conv(features=2 * self.c_hid, kernel_size=(3, 3))(h))
        h = nn.gelu(nn.ConvTranspose(features=self.c_hid, kernel_size=(3, 3), strides=(2, 2))(h))
        # NOTE(review): the encoder's matching stage uses c_hid channels; 2*c_hid
        # here breaks the mirror symmetry and may be a typo. Kept to preserve
        # parameter shapes.
        h = nn.gelu(nn.Conv(features=2 * self.c_hid, kernel_size=(3, 3))(h))
        h = nn.ConvTranspose(features=self.c_out, kernel_size=(3, 3), strides=(2, 2))(h)
        # tanh bounds the reconstruction to [-1, 1].
        return nn.tanh(h)

In [7]:
class Autoencoder(nn.Module):
    """Encoder followed by Decoder, reconstructing 3-channel images.

    Submodules are declared in ``setup`` rather than via ``@nn.compact`` so they
    are reachable as named attributes (``encoder`` / ``decoder``), e.g. to
    encode or decode on their own.
    """
    c_hid: int
    latent_dim: int

    def setup(self):
        shared = dict(c_hid=self.c_hid, latent_dim=self.latent_dim)
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(c_out=3, **shared)

    def __call__(self, x):
        # Reconstruct x by round-tripping it through the latent code.
        return self.decoder(self.encoder(x))

In [None]:
@partial(jax.pmap, static_broadcasted_argnums=(1,))
def initialize_autoencoder_state(key, input_shape):
    """Build one replicated ``TrainState`` per local device.

    Args:
        key: stacked PRNG keys, one per device (leading axis = device count).
            Pass the *same* key on every device; otherwise each replica starts
            from different weights, which breaks pmean-based data parallelism.
        input_shape: static (hashable) per-device input shape, e.g.
            ``(BATCH_SIZE, H, W, C)``, used only to trace parameter shapes.

    Returns:
        A ``TrainState`` with Adam and element-wise update clipping, driven by a
        warmup-cosine learning-rate schedule.
    """
    model = Autoencoder(c_hid=32, latent_dim=LATENT_DIM)
    variables = model.init(key, jnp.ones(input_shape))
    # The schedule is indexed by optimizer updates, not by samples. One update
    # consumes one global batch of DEVICE_BATCH_SIZE images, so the schedule
    # length is EPOCHS * steps_per_epoch. Previously it used EPOCHS * train_size,
    # under which the LR would barely decay.
    steps_per_epoch = max(1, train_size // DEVICE_BATCH_SIZE)
    tx = optax.chain(
        optax.clip(1.0),  # clip each gradient element to [-1, 1]
        optax.adam(
            optax.warmup_cosine_decay_schedule(
                init_value=0.0,
                peak_value=1e-3,
                warmup_steps=100,
                decay_steps=EPOCHS * steps_per_epoch,  # total length, incl. warmup
                end_value=1e-5
            )
        )
    )
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=tx,
        params=variables['params']
    )

In [19]:
key = jax.random.PRNGKey(seed=SEED)
autoencoder_key, key = jax.random.split(key, 2)
# Data-parallel replicas must start from identical parameters: gradients are
# pmean-ed, so replicas only stay in sync if they begin equal.
# common_utils.shard_prng_key would *split* the key and hand each device
# different initial weights, so broadcast a single key to every device instead.
autoencoder_key = jnp.broadcast_to(autoencoder_key, (DEVICE_COUNT, *autoencoder_key.shape))
# Take the per-example image shape from the dataset spec rather than pulling
# (and discarding) a real batch from the training iterator.
image_shape = tuple(train_dataset.element_spec.shape[1:].as_list())
autoencoder_state = initialize_autoencoder_state(
    autoencoder_key, (BATCH_SIZE, *image_shape)
)

In [24]:
@partial(jax.pmap, axis_name='num_devices')
def train_step(state: train_state.TrainState, image_batch: jnp.ndarray, key: jnp.ndarray):
    """Run one data-parallel optimisation step on this device's shard.

    Args:
        state: this replica's ``TrainState``.
        image_batch: this device's slice of the global batch.
        key: per-device PRNG key; accepted but not used (the model has no
            stochastic layers).

    Returns:
        ``(new_state, loss)``, where ``loss`` is averaged across devices.
    """

    def reconstruction_loss(params):
        predicted = state.apply_fn({'params': params}, image_batch)
        squared_error = (predicted - image_batch) ** 2
        # Mean over the batch axis, then sum over all remaining axes: per-image
        # summed squared error, averaged over the batch.
        return squared_error.mean(axis=0).sum(), predicted

    grad_fn = jax.value_and_grad(reconstruction_loss, has_aux=True)
    (local_loss, _), local_grads = grad_fn(state.params)
    # All-reduce across replicas so every device applies the same update.
    synced_grads = lax.pmean(local_grads, axis_name='num_devices')
    synced_loss = lax.pmean(local_loss, axis_name='num_devices')
    return state.apply_gradients(grads=synced_grads), synced_loss

In [None]:
def train(data_generator, epochs: int, batches_in_epoch: int, state: train_state.TrainState, key: jnp.ndarray):
    """Train the replicated autoencoder state with the pmapped ``train_step``.

    Args:
        data_generator: iterator of host-side global batches; each batch is
            split across local devices with ``common_utils.shard``.
        epochs: number of epochs to run.
        batches_in_epoch: optimizer steps per epoch.
        state: replicated ``TrainState`` (one copy per local device).
        key: a single, unsplit PRNG key; a fresh sub-key is derived per step.

    Returns:
        The replicated ``TrainState`` after the last step.
    """
    for epoch in tqdm(range(1, epochs + 1), desc="Epoch", position=0, leave=True):
        epoch_losses = []
        # Previously only one batch ran per epoch; iterate the full epoch.
        for _ in tqdm(range(batches_in_epoch), desc="Training", leave=False):
            key, step_key = jax.random.split(key, 2)
            # pmap maps over the leading axis, so pass one key per device
            # (previously the unsharded key was passed, causing a mismatch).
            step_keys = common_utils.shard_prng_key(step_key)
            batch_data = common_utils.shard(next(data_generator))
            # Carry the updated state forward. It was previously discarded, so
            # every step restarted from the initial parameters.
            state, loss = train_step(state, batch_data, step_keys)
            epoch_losses.append(loss)
        if epoch_losses:
            # Each entry is a per-device loss vector (identical after pmean), so
            # the overall mean is the epoch's average batch loss as a scalar.
            epoch_loss = float(np.mean(jax.device_get(epoch_losses)))
            tqdm.write(f"Epoch: {epoch: <2} | Loss: {epoch_loss:.3f}")
    return state