In [None]:
!nvcc --version

In [None]:
!ls -l /usr/lib/x86_64-linux-gnu/libcudnn.so*

In [None]:
!cat /usr/include/cudnn_version.h | grep MAJOR -A 2

In [None]:
!echo $LD_LIBRARY_PATH

In [None]:
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
!pip install flax

In [None]:
!pip install tqdm

In [None]:
import jax
import jax.numpy as jnp
import jax.dlpack
from jax import grad, jit, vmap, random
from jax import random
from jax.example_libraries import stax, optimizers

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow as tf

import time
import numpy.random as npr
import math

from typing import Optional

import optax
from flax.training import train_state

from tqdm import tqdm

## ViT

In [None]:
from flax import linen as nn

In [None]:
class Patches(nn.Module):
  """Cut images into non-overlapping square patches and embed each one.

  The convolution's kernel size equals its stride and padding is VALID, so each
  output position sees exactly one patch: the conv is a per-patch linear
  projection to `embed_dim` features.
  """
  patch_size: int
  embed_dim: int

  def setup(self):
    window = (self.patch_size,) * 2
    self.conv = nn.Conv(
        features=self.embed_dim,
        kernel_size=window,
        strides=window,
        padding='VALID',
    )

  def __call__(self, images):
    # assumes channels-last input (batch, height, width, channels) — TODO confirm at call sites
    feature_map = self.conv(images)
    batch, grid_h, grid_w, channels = feature_map.shape
    # Flatten the 2-D grid of patches into a single sequence axis.
    return jnp.reshape(feature_map, (batch, grid_h * grid_w, channels))

In [None]:
class PatchEncoder(nn.Module):
  """Project patch embeddings to `hidden_dim`, prepend a [CLS] token, add positions.

  Input is a rank-3 array (batch, num_patches, features). Output has shape
  (batch, num_patches + 1, hidden_dim); index 0 along the sequence axis is the
  learnable [CLS] token.
  """
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    assert x.ndim == 3
    batch_size, num_patches, _ = x.shape
    tokens = nn.Dense(self.hidden_dim)(x)

    # One shared, zero-initialised [CLS] token, repeated for every example in the batch.
    cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
    cls_tokens = jnp.broadcast_to(cls_token, (batch_size, 1, self.hidden_dim))
    tokens = jnp.concatenate([cls_tokens, tokens], axis=1)

    # Learned absolute position embedding covering the [CLS] slot plus every patch.
    position_embedding = self.param(
        'position_embedding',
        nn.initializers.normal(stddev=0.02),  # From BERT
        (1, num_patches + 1, self.hidden_dim),
    )
    return tokens + position_embedding

In [None]:
class MultiHeadSelfAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Attributes:
    hidden_dim: width of the q/k/v and output projections, split evenly across heads.
    n_heads: number of attention heads; must divide `hidden_dim`.
    drop_p: dropout rate for the attention probabilities and for the projected output.
  """
  hidden_dim: int
  n_heads: int
  drop_p: float

  def setup(self):
    self.q_net = nn.Dense(self.hidden_dim)
    self.k_net = nn.Dense(self.hidden_dim)
    self.v_net = nn.Dense(self.hidden_dim)

    self.proj_net = nn.Dense(self.hidden_dim)

    self.att_drop = nn.Dropout(self.drop_p)
    self.proj_drop = nn.Dropout(self.drop_p)

  def __call__(self, x, train=True):
    """Self-attend over axis 1 of `x` (B, T, C) and return (B, T, hidden_dim).

    Raises:
      ValueError: if `hidden_dim` is not divisible by `n_heads`.
    """
    if self.hidden_dim % self.n_heads:
      raise ValueError(
          f"hidden_dim ({self.hidden_dim}) must be divisible by n_heads ({self.n_heads})")
    B, T, _ = x.shape  # batch_size, seq_length, input features
    # Head width comes from the projection size (hidden_dim), not the input width,
    # so the reshape stays valid when the input feature count differs from hidden_dim.
    N, D = self.n_heads, self.hidden_dim // self.n_heads  # num_heads, head_dim

    def split_heads(t):
      # (B, T, hidden_dim) -> (B, N, T, D)
      return t.reshape(B, T, N, D).transpose(0, 2, 1, 3)

    q = split_heads(self.q_net(x))
    k = split_heads(self.k_net(x))
    v = split_heads(self.v_net(x))

    # weights (B, N, T, T)
    weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
    normalized_weights = nn.softmax(weights, axis=-1)
    # Attention dropout acts on the attention probabilities (standard Transformer
    # attention dropout), not on the per-head outputs.
    normalized_weights = self.att_drop(normalized_weights, deterministic=not train)

    # attention (B, N, T, D), then gather heads back to (B, T, N*D)
    attention = jnp.matmul(normalized_weights, v)
    attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D)

    # project
    return self.proj_drop(self.proj_net(attention), deterministic=not train)

In [None]:
class MLP(nn.Module):
  """Two-layer GELU feed-forward network with dropout after each Dense layer.

  Attributes:
    mlp_dim: width of the hidden layer.
    drop_p: dropout rate (active only when `train=True`).
    out_dim: output width; when None, the input's last-axis size is reused.
  """
  mlp_dim: int
  drop_p: float
  out_dim: Optional[int] = None

  @nn.compact
  def __call__(self, inputs, train=True):
    deterministic = not train
    features_out = self.out_dim if self.out_dim is not None else inputs.shape[-1]

    hidden = nn.gelu(nn.Dense(features=self.mlp_dim)(inputs))
    hidden = nn.Dropout(rate=self.drop_p, deterministic=deterministic)(hidden)

    output = nn.Dense(features=features_out)(hidden)
    return nn.Dropout(rate=self.drop_p, deterministic=deterministic)(output)

In [None]:
class Transformer(nn.Module):
  """Pre-LayerNorm Transformer encoder block.

  Computes x = inputs + Attn(LN_1(inputs)), then returns x + MLP(LN_2(x)).
  The output width equals the input width, because the MLP keeps its input size.

  Attributes:
    embed_dim: NOTE(review): not used by this block; kept so existing constructor calls still work.
    hidden_dim: attention projection width.
    n_heads: number of attention heads.
    drop_p: dropout rate used inside the attention and MLP sub-blocks.
    mlp_dim: hidden width of the feed-forward sub-block.
  """
  embed_dim: int
  hidden_dim: int
  n_heads: int
  drop_p: float
  mlp_dim: int

  def setup(self):
    self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
    self.mlp = MLP(self.mlp_dim, self.drop_p)
    # Each sub-block gets its own LayerNorm. A single shared instance would tie
    # the scale/bias of the pre-attention and pre-MLP normalisations together.
    self.attn_norm = nn.LayerNorm(epsilon=1e-6)
    self.mlp_norm = nn.LayerNorm(epsilon=1e-6)

  def __call__(self, inputs, train=True):
    # Attention block. MultiHeadSelfAttention already applies dropout to its
    # projected output, so no second dropout is applied on this branch.
    x = self.attn_norm(inputs)
    x = inputs + self.mha(x, train)
    # MLP block (the MLP applies its own dropout).
    y = self.mlp_norm(x)
    y = self.mlp(y, train)

    return x + y

In [None]:
class ViT(nn.Module):
  """Vision Transformer classifier.

  Pipeline: patch embedding -> linear projection + [CLS] token + position
  embedding -> `num_layers` independent Transformer blocks -> final LayerNorm
  -> [CLS] token -> MLP head -> linear layer producing `num_classes` logits.
  """
  patch_size: int
  embed_dim: int
  hidden_dim: int
  n_heads: int
  drop_p: float
  num_layers: int
  mlp_dim: int
  num_classes: int

  def setup(self):
    self.patch_extracter = Patches(self.patch_size, self.embed_dim)
    self.patch_encoder = PatchEncoder(self.hidden_dim)
    # One Transformer instance per layer. Calling a single submodule `num_layers`
    # times would reuse the same parameters at every depth (weight sharing).
    self.transformers = [
        Transformer(self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim)
        for _ in range(self.num_layers)
    ]
    # Pre-LN blocks leave the residual stream un-normalised; ViT applies a final
    # LayerNorm to the encoder output before the classification head.
    self.encoder_norm = nn.LayerNorm(epsilon=1e-6)
    self.mlp_head = MLP(self.mlp_dim, self.drop_p)
    self.cls_head = nn.Dense(features=self.num_classes)

  def __call__(self, x, train=True):
    """Return class logits of shape (batch, num_classes) for the image batch `x`."""
    x = self.patch_extracter(x)
    x = self.patch_encoder(x)
    for block in self.transformers:
      x = block(x, train)
    x = self.encoder_norm(x)
    # MLP head
    x = x[:, 0]  # [CLS] token
    x = self.mlp_head(x, train)
    return self.cls_head(x)

## Initialize ViT

In [None]:
# Seed the global PRNG and build a dummy batch of five 32x32x3 images, used only
# to trace shapes during initialisation.
main_rng = jax.random.PRNGKey(42)
x = jnp.ones(shape=(5, 32, 32, 3))

# ViT hyper-parameters.
vit_hparams = dict(
    patch_size=4,
    embed_dim=256,
    hidden_dim=512,
    n_heads=8,
    drop_p=0.2,
    num_layers=6,
    mlp_dim=1024,
    num_classes=10,
)
model = ViT(**vit_hparams)

# Separate keys for parameter init and for the dropout layers run during init.
main_rng, init_rng, drop_rng = jax.random.split(main_rng, 3)
variables = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)
params = variables['params']

In [None]:
# Show the parameter tree with every array replaced by its shape.
# `jax.tree_map` is deprecated (and removed in recent JAX releases); `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

## Create TrainState

In [None]:
def init_train_state(model, params, learning_rate):
  """Create a fresh `TrainState` pairing `model.apply` and `params` with Adam.

  Args:
    model: Flax module whose `apply` becomes the state's `apply_fn`.
    params: parameter tree to optimise.
    learning_rate: step size passed to `optax.adam`.

  Returns:
    A `flax.training.train_state.TrainState`.
  """
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=optax.adam(learning_rate),
  )

In [None]:
# Adam optimiser with a constant learning rate of 3e-4.
state = init_train_state(model, params, learning_rate=3e-4)

## Dataset preparation

In [None]:
import tensorflow
import tensorflow_datasets as tfds

In [None]:
# Load CIFAR-10 as (image, label) pairs, together with the dataset metadata.
# shuffle_files=False keeps the example order of the 'train' split fixed across
# passes, so any later positional split (take/skip) is reproducible and
# disjoint. Example-level shuffling is done later in the input pipeline.
(full_train_set, test_dataset), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    shuffle_files=False,
    as_supervised=True,
    with_info=True
)

In [None]:
def normalize_img(image, label):
  """Cast `image` to float32 and divide by 255; return `label` unchanged.

  For uint8 pixels (assumed here — TODO confirm dataset dtype) this maps values into [0, 1].
  """
  scaled_image = tf.cast(image, tf.float32) / 255.0
  return scaled_image, label

In [None]:
# Normalise every training example; AUTOTUNE lets tf.data choose the parallelism.
full_train_set = full_train_set.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Split full_train_set into disjoint train and validation subsets.
validation_split = 0.2
num_data = tf.data.experimental.cardinality(full_train_set).numpy()
if num_data < 0:
  # cardinality() returns negative sentinels for unknown/infinite datasets.
  raise ValueError("full_train_set needs a known, finite cardinality to be split")
# take()/skip() need integer counts; computing num_train by subtraction means
# every example ends up in exactly one subset.
num_val = int(num_data * validation_split)
num_train = int(num_data) - num_val
train_dataset = full_train_set.take(num_train)
# skip() rather than take(): take() would re-read the first examples, which
# already belong to train_dataset.
val_dataset = full_train_set.skip(num_train)

In [None]:
# Training pipeline: cache examples, shuffle with a buffer as large as the
# dataset, then batch.
cached_train = train_dataset.cache()
train_buffer = tf.data.experimental.cardinality(cached_train).numpy()
train_dataset = cached_train.shuffle(train_buffer).batch(64)

In [None]:
# Validation pipeline: cache examples, shuffle with a buffer as large as the
# dataset, then batch.
cached_val = val_dataset.cache()
val_buffer = tf.data.experimental.cardinality(cached_val).numpy()
val_dataset = cached_val.shuffle(val_buffer).batch(64)

In [None]:
# Test pipeline. Test images get the same normalize_img scaling as the
# training data; without it the model would see raw integer pixels at
# evaluation time. Normalisation happens before cache() so that the scaled
# images are what gets cached.
test_dataset = test_dataset.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.cache()
test_dataset = test_dataset.shuffle(tf.data.experimental.cardinality(test_dataset).numpy())
test_dataset = test_dataset.batch(64)

## Train step

In [None]:
def cross_entropy_loss(*, logits, labels):
  """Mean softmax cross-entropy between `logits` and integer class `labels`.

  Args:
    logits: unnormalised scores; the last axis indexes classes.
    labels: integer class ids, one per row of `logits`.

  Returns:
    Scalar mean loss over the batch.
  """
  # Take the class count from the logits instead of hard-coding 10, so the loss
  # follows whatever `num_classes` the model was built with.
  labels_onehot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [None]:
def compute_metrics(*, logits, labels):
  """Return a dict with the batch's mean cross-entropy `loss` and top-1 `accuracy`.

  Args:
    logits: unnormalised class scores; the last axis indexes classes.
    labels: integer class ids.
  """
  predicted = jnp.argmax(logits, -1)
  return {
      'loss': cross_entropy_loss(logits=logits, labels=labels),
      'accuracy': jnp.mean(predicted == labels),
  }

In [None]:
@jax.jit
def train_step(state, batch, rng):
  """Run one optimiser update on `batch`; return `(new_state, metrics)`.

  Args:
    state: `TrainState` holding params, optimiser state and the step counter.
    batch: `(images, labels)` pair.
    rng: PRNG key for dropout. It is folded with `state.step`, so the dropout
      masks change from step to step even if the caller passes the same key.

  Returns:
    The updated `TrainState` and the dict produced by `compute_metrics`.
  """
  images, labels = batch
  # The original split's first output was discarded, so a constant caller key
  # gave identical dropout masks every step. Folding in the step counter fixes that.
  drop_rng = jax.random.fold_in(rng, state.step)

  def loss_fn(params):
    logits = state.apply_fn(
        {'params': params}, images, train=True, rngs={'dropout': drop_rng})
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, logits

  gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = gradient_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=labels)
  return state, metrics

In [None]:
def train(train_dataset, state, epochs, rng=None):
  """Train `state` for `epochs` passes over `train_dataset` and return it.

  Args:
    train_dataset: batched `tf.data.Dataset` yielding `(images, labels)`.
    state: `TrainState` to start from.
    epochs: number of passes over `train_dataset`.
    rng: PRNG key used to derive a fresh dropout key per step. Defaults to
      the notebook-level `main_rng`.

  Returns:
    The `TrainState` after the last step. `TrainState` is updated
    functionally, so the caller must keep this return value.
  """
  if rng is None:
    rng = main_rng
  for epoch in tqdm(range(1, epochs + 1)):
    batch_metrics = []
    # Build a new numpy iterator over the dataset every epoch.
    for batch in tfds.as_numpy(train_dataset):
      # Split a new key per step; reusing one key would repeat the dropout masks.
      rng, step_rng = random.split(rng)
      state, metrics = train_step(state, batch, step_rng)
      batch_metrics.append(metrics)
    if not batch_metrics:
      continue
    # Report the unweighted mean over batches once per epoch rather than
    # printing (and syncing with the device) after every batch.
    mean_loss = float(jnp.mean(jnp.stack([m['loss'] for m in batch_metrics])))
    mean_acc = float(jnp.mean(jnp.stack([m['accuracy'] for m in batch_metrics])))
    print(f"epoch {epoch}: acc {mean_acc:.4f}, loss {mean_loss:.4f}")
  return state

In [None]:
# Train for 10 epochs starting from `state`; bind the call's return value to `trained_state`.
trained_state = train(train_dataset, state, epochs=10)