Vision Transformer Classifier
---

In this notebook, we use the [ViT backbone](https://github.com/romaingrx/relax/blob/master/relax/models/ViT.py) implemented in the [relax](https://github.com/romaingrx/relax/) package to classify MNIST images.

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import tensorflow as tf
# Hide the GPUs from TensorFlow so that it does not grab all the GPU memory; JAX needs it.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

import jax
import einops
import optax
import haiku as hk
from relax import Trainer, TrainingConfig
from relax.models import ViT as ViTBackbone

from typing import Optional, Union, Sequence, Tuple
from dataclasses import dataclass

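def ds_to_array(ds):
    """Load a batched tfds dataset as a list of (images, labels) batches on the JAX device."""
    itr = (
        ds
        # Scale pixel values from [0, 255] to [0, 1].
        .map(lambda d: (d['image'] / 255, d['label']))
        .as_numpy_iterator()
    )
    return jax.device_put(list(itr))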

train_ds, test_ds = tfds.load("mnist", split=["train[:80%]", "test"], batch_size=128)
x = ds_to_array(train_ds)
x_test = ds_to_array(test_ds)


Let's define the ViT classifier as the ViT backbone (essentially the patch encoding followed by the transformer encoder), on top of which we add the classification layers: one hidden layer and the final output layer.

In [2]:
class ViT(ViTBackbone):
    def __init__(self, n_classes: int, *args, n_hidden: int = 512, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_classes = n_classes
        self.n_hidden = n_hidden
    
    def __call__(self, x):
        latent = super().__call__(x)
        flattened_latent = einops.rearrange(latent, 'b ... -> b (...)')
        
        z = hk.Linear(self.n_hidden, name="hidden_clf")(flattened_latent)
        z = jax.nn.relu(z)
        
        logits = hk.Linear(self.n_classes, name="clf")(z)
        return logits
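
To make the shapes in the class above concrete, here is a minimal NumPy sketch of non-overlapping patch extraction followed by the same kind of flattening. It is not the relax implementation: `extract_patches` is a hypothetical helper, the 8x8 patch size and the embedding size of 64 are assumptions, and dropping the pixels that do not fill a whole patch is only one possible way to handle 28x28 inputs.

```python
import numpy as np

def extract_patches(images, patch_size=8):
    """Split (batch, H, W, C) images into flattened, non-overlapping patches."""
    b, h, w, c = images.shape
    n_h, n_w = h // patch_size, w // patch_size
    # Assumption: rows and columns that do not fill a whole patch are dropped.
    cropped = images[:, : n_h * patch_size, : n_w * patch_size, :]
    # Separate the patch grid axes from the within-patch axes.
    patches = cropped.reshape(b, n_h, patch_size, n_w, patch_size, c)
    patches = patches.transpose(0, 1, 3, 2, 4, 5)
    return patches.reshape(b, n_h * n_w, patch_size * patch_size * c)

images = np.ones((2, 28, 28, 1), dtype=np.float32)
patches = extract_patches(images)
print(patches.shape)  # (2, 9, 64)

# Stand-in for the encoder output: one 64-dimensional token per patch (assumed size).
tokens = np.zeros((patches.shape[0], patches.shape[1], 64), dtype=np.float32)
# Same effect as einops.rearrange(latent, 'b ... -> b (...)').
flat = tokens.reshape(tokens.shape[0], -1)
print(flat.shape)  # (2, 576)
```

Under these assumptions, each image becomes 9 patches of 64 pixels, and flattening 9 tokens of size 64 gives the 576 features that feed the hidden classification layer.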

In [3]:
@hk.transform
def model(x):
    logits = ViT(10, (8, 8), 64)(x)
    return logits

def loss_fn(params, rng, data):
    x, y = data
    logits = model.apply(params, rng, x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
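
For readers unfamiliar with the loss used above, the following pure-Python sketch shows what a softmax cross-entropy with integer labels computes for each example, followed by the mean over the batch. The helper names are hypothetical and the code only illustrates the formula; it is not how optax implements it.

```python
import math

def softmax_cross_entropy_int(logits, label):
    """Cross-entropy of one logit vector against an integer class label."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    # -log softmax(logits)[label] = logsumexp(logits) - logits[label]
    return log_norm - logits[label]

def mean_loss(batch_logits, labels):
    """Average the per-example losses, as `.mean()` does in `loss_fn`."""
    losses = [softmax_cross_entropy_int(l, y) for l, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# With all-zero logits over 10 classes, every class has probability 1/10,
# so the loss equals log(10), roughly 2.3026, whatever the labels are.
uniform = [[0.0] * 10, [0.0] * 10]
print(mean_loss(uniform, [3, 7]))
```

A loss near log(10) is therefore what an untrained 10-class model that predicts uniformly would score, which gives a useful reference point when watching training progress.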

In [4]:
config = TrainingConfig(epochs=25)

optimizer = optax.adam(0.001)

trainer = Trainer(model, optimizer, config)

rng = jax.random.PRNGKey(42)
fake_img = jax.numpy.ones((1, 28, 28, 1))
init_state = trainer.init(rng, fake_img)

Let's print the total number of parameters and the shape of each parameter array:

In [5]:
from math import prod

def nb_params(params):
    # Sum the number of elements over every parameter array in the tree.
    return sum(prod(p.shape) for p in jax.tree_util.tree_leaves(params))

print(f"The model has {nb_params(init_state.params)} parameters")
jax.tree_util.tree_map(lambda l: l.shape, init_state.params)

The model has 379978 parameters


{'vi_t/TransformerBlock/MLP/linear': {'b': (256,), 'w': (64, 256)},
 'vi_t/TransformerBlock/MLP/linear_1': {'b': (128,), 'w': (256, 128)},
 'vi_t/TransformerBlock/MLP/linear_2': {'b': (64,), 'w': (128, 64)},
 'vi_t/TransformerBlock/layer_norm': {'offset': (64,), 'scale': (64,)},
 'vi_t/TransformerBlock/layer_norm_1': {'offset': (64,), 'scale': (64,)},
 'vi_t/TransformerBlock/multi_head_attention/K': {'b': (64,), 'w': (64, 64)},
 'vi_t/TransformerBlock/multi_head_attention/Q': {'b': (64,), 'w': (64, 64)},
 'vi_t/TransformerBlock/multi_head_attention/V': {'b': (64,), 'w': (64, 64)},
 'vi_t/TransformerBlock/multi_head_attention/projection': {'b': (64,),
  'w': (64, 64)},
 'vi_t/clf': {'b': (10,), 'w': (512, 10)},
 'vi_t/hidden_clf': {'b': (512,), 'w': (576, 512)},
 'vi_t/patches_encoder/~embed_positions/embed': {'embeddings': (9, 64)},
 'vi_t/patches_encoder/~project_patches/linear': {'b': (64,), 'w': (64, 64)}}
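
As a sanity check, the total can be recomputed by hand from the shapes printed above. This sketch copies those shapes into a plain dictionary (the `shapes` and `head` names are only for illustration) and also shows how much of the model sits in the two classification layers.

```python
from math import prod

# Shapes copied from the output above.
shapes = {
    'vi_t/TransformerBlock/MLP/linear': {'b': (256,), 'w': (64, 256)},
    'vi_t/TransformerBlock/MLP/linear_1': {'b': (128,), 'w': (256, 128)},
    'vi_t/TransformerBlock/MLP/linear_2': {'b': (64,), 'w': (128, 64)},
    'vi_t/TransformerBlock/layer_norm': {'offset': (64,), 'scale': (64,)},
    'vi_t/TransformerBlock/layer_norm_1': {'offset': (64,), 'scale': (64,)},
    'vi_t/TransformerBlock/multi_head_attention/K': {'b': (64,), 'w': (64, 64)},
    'vi_t/TransformerBlock/multi_head_attention/Q': {'b': (64,), 'w': (64, 64)},
    'vi_t/TransformerBlock/multi_head_attention/V': {'b': (64,), 'w': (64, 64)},
    'vi_t/TransformerBlock/multi_head_attention/projection': {'b': (64,), 'w': (64, 64)},
    'vi_t/clf': {'b': (10,), 'w': (512, 10)},
    'vi_t/hidden_clf': {'b': (512,), 'w': (576, 512)},
    'vi_t/patches_encoder/~embed_positions/embed': {'embeddings': (9, 64)},
    'vi_t/patches_encoder/~project_patches/linear': {'b': (64,), 'w': (64, 64)},
}

def count(layer):
    """Number of scalar parameters in one layer."""
    return sum(prod(shape) for shape in layer.values())

total = sum(count(layer) for layer in shapes.values())
head = count(shapes['vi_t/hidden_clf']) + count(shapes['vi_t/clf'])

print(total)                  # 379978
print(head)                   # 300554
print(f"{head / total:.1%}")  # 79.1%
```

Most of the parameters therefore live in the classification head, mainly because the hidden layer takes all 576 flattened encoder features as input.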

In [6]:
trained_state = trainer.train(init_state, loss_fn, x, jit_update_step=True)

Training:   0%|          | 0/25 [00:00<?, ?epoch/s]

In [7]:
misclassified = 0
n_obs = 0
# Use distinct loop names so the training data `x` is not overwritten.
for x_batch, y_batch in x_test:
    logits = trainer.apply(trained_state.params, jax.random.PRNGKey(0), x_batch)
    # argmax over the logits; softmax is monotonic, so it is not needed here.
    predictions = jax.numpy.argmax(logits, axis=-1)
    misclassified += (predictions != y_batch).sum()
    n_obs += len(x_batch)
print(f"Accuracy of {1-misclassified/n_obs:.2f} on the test set of {n_obs} observations.")

Accuracy of 0.98 on the test set of 10000 observations.
