In [None]:
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
# !pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
!pip install flax
!pip install tqdm

In [None]:
!pip install -qq medmnist

In [None]:
import jax
import jax.numpy as jnp
import jax.dlpack
from jax import grad, jit, vmap, random
from jax import random
from jax.example_libraries import stax, optimizers

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow as tf

import time
import numpy.random as npr
import math

from typing import Optional

import optax
from flax.training import train_state

import numpy as np
import medmnist

In [None]:
from flax import linen as nn

## Tubelet embedding

In [None]:
class TubeletEmbedding(nn.Module):
  """Cuts a batch of videos into non-overlapping cubic tubelets and embeds each.

  A single 3-D convolution whose kernel and stride both equal ``patch_size``
  on every convolved axis maps each tubelet to one ``embed_dim`` token; the
  resulting token grid is flattened into a sequence.

  Attributes:
    patch_size: Edge length of a tubelet along each of the three convolved axes.
    embed_dim: Number of features per output token.
  """
  patch_size: int
  embed_dim: int

  @nn.compact
  def __call__(self, videos):
    """Embeds ``videos`` into a token sequence.

    The convolution output must be 5-D (batch, t, h, w, features); it is
    reshaped to (batch, t * h * w, features).
    """
    cube = (self.patch_size, self.patch_size, self.patch_size)
    conv = nn.Conv(
        features=self.embed_dim,
        kernel_size=cube,
        strides=cube,
        padding='VALID',
    )
    tokens = conv(videos)
    n_batch, n_t, n_h, n_w, n_feat = tokens.shape
    return jnp.reshape(tokens, (n_batch, n_t * n_h * n_w, n_feat))

## Test tubelet embedding

In [None]:
def test_tubelet_embedding():
  """Smoke-tests TubeletEmbedding on an all-ones dummy batch.

  Uses 8 videos of 16 frames at 32x32 with 3 channels and 4x4x4 tubelets
  embedded into 128 features. It initialises and applies the module, prints
  the output shape and asserts it is (batch, num_tubelets, embed_dim).

  Returns:
    ``(embedder, main_rng)``: the module and a PRNG key that has not been
    used for initialisation, so callers can split it further safely.
  """
  batch, frames, height, width, channels = 8, 16, 32, 32, 3
  patch_size, embed_dim = 4, 128

  main_rng = jax.random.PRNGKey(42)
  x = jnp.ones(shape=(batch, frames, height, width, channels))
  embedder = TubeletEmbedding(patch_size=patch_size, embed_dim=embed_dim)
  # Initialise with the freshly split sub-key. The previous version used (and
  # then returned) main_rng itself, which leaked an already-consumed key.
  main_rng, init_rng = random.split(main_rng)
  variables = embedder.init(init_rng, x)
  out = embedder.apply(variables, x)
  print(out.shape)

  num_tubelets = (frames // patch_size) * (height // patch_size) * (width // patch_size)
  assert out.shape == (batch, num_tubelets, embed_dim), out.shape
  return embedder, main_rng

In [None]:
# Run the TubeletEmbedding smoke test and keep the module and PRNG key it returns.
test_embedder, main_rng = test_tubelet_embedding()

## Patch encoder

In [None]:
class PatchEncoder(nn.Module):
  """Projects tokens to ``hidden_dim``, prepends a [CLS] token, adds positions.

  Attributes:
    hidden_dim: Width of the projected tokens, of the learnable [CLS] token
      and of the learnable position embedding.
  """
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    """Encodes a rank-3 (batch, seq_len, features) token sequence.

    Returns:
      Array of shape (batch, seq_len + 1, hidden_dim) whose sequence index 0
      holds the [CLS] token.
    """
    assert x.ndim == 3
    batch_size, num_tokens = x.shape[0], x.shape[1]

    # Linear projection of every token to the model width.
    tokens = nn.Dense(self.hidden_dim)(x)

    # A single zero-initialised [CLS] token, repeated once per batch element.
    cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
    tokens = jnp.concatenate([jnp.tile(cls_token, (batch_size, 1, 1)), tokens], axis=1)

    # Learnable absolute position embedding; one extra slot for [CLS].
    positions = self.param(
        'position_embedding',
        nn.initializers.normal(stddev=0.02),
        (1, num_tokens + 1, self.hidden_dim),
    )
    return tokens + positions

## Spatio-temporal attention

In [None]:
class MultiHeadSelfAttention(nn.Module):
  """Multi-head scaled dot-product self-attention over a token sequence.

  Attributes:
    hidden_dim: Width of the query/key/value projections and of the output.
      Must be divisible by ``n_heads``.
    n_heads: Number of attention heads.
    drop_p: Dropout rate applied to the attention output (after multiplying
      by the values) and to the output projection.
  """
  hidden_dim: int
  n_heads: int
  drop_p: float

  def setup(self):
    self.q_net = nn.Dense(self.hidden_dim)
    self.k_net = nn.Dense(self.hidden_dim)
    self.v_net = nn.Dense(self.hidden_dim)

    self.proj_net = nn.Dense(self.hidden_dim)

    self.att_drop = nn.Dropout(self.drop_p)
    self.proj_drop = nn.Dropout(self.drop_p)

  def __call__(self, x, train=True):
    """Applies self-attention to a rank-3 (batch, seq_len, features) input.

    Args:
      x: Input tokens. The last axis may have any width, because the q/k/v
        projections always map it to ``hidden_dim``.
      train: When True, dropout is active; when False it is deterministic.

    Returns:
      Array of shape (batch, seq_len, hidden_dim).

    Raises:
      ValueError: If ``hidden_dim`` is not divisible by ``n_heads``.
    """
    B, T, _ = x.shape  # batch_size, seq_length, input width
    N = self.n_heads
    if self.hidden_dim % N != 0:
      raise ValueError(
          f'hidden_dim ({self.hidden_dim}) must be divisible by n_heads ({N})'
      )
    # Head size comes from the projection width, not the input width: q/k/v
    # below always have hidden_dim features regardless of x.shape[-1].
    D = self.hidden_dim // N

    q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)  # (B, N, T, D)
    k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
    v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)

    # weights (B, N, T, T), scaled by 1/sqrt(head_dim)
    weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
    normalized_weights = nn.softmax(weights, axis=-1)

    # attention (B, N, T, D)
    attention = jnp.matmul(normalized_weights, v)
    attention = self.att_drop(attention, deterministic=not train)

    # gather heads back into (B, T, hidden_dim)
    attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D)

    # project
    out = self.proj_drop(self.proj_net(attention), deterministic=not train)

    return out

In [None]:
class MLP(nn.Module):
  """Two-layer feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

  Attributes:
    mlp_dim: Width of the hidden layer.
    drop_p: Dropout rate used after the hidden activation and after the output layer.
    out_dim: Output width. When None, the output keeps the input's last-axis size.
  """
  mlp_dim: int
  drop_p: float
  out_dim: Optional[int] = None

  @nn.compact
  def __call__(self, inputs, train=True):
    deterministic = not train
    if self.out_dim is None:
      width_out = inputs.shape[-1]
    else:
      width_out = self.out_dim

    hidden = nn.gelu(nn.Dense(features=self.mlp_dim)(inputs))
    hidden = nn.Dropout(rate=self.drop_p, deterministic=deterministic)(hidden)
    projected = nn.Dense(features=width_out)(hidden)
    return nn.Dropout(rate=self.drop_p, deterministic=deterministic)(projected)

In [None]:
class TransformerEncoder(nn.Module):
  """Pre-LayerNorm Transformer encoder block.

  It is made of a multi-head self-attention sub-block and an MLP sub-block.
  Each sub-block is preceded by its own LayerNorm and wrapped in a residual
  connection.

  Attributes:
    embed_dim: Not used inside this block; kept so existing constructor calls
      keep working.
    hidden_dim: Token width. Inputs must already have this width, because the
      attention output (hidden_dim wide) is added back onto them.
    n_heads: Number of attention heads.
    drop_p: Dropout rate for the attention and MLP sub-blocks.
    mlp_dim: Hidden width of the MLP sub-block.
  """
  embed_dim: int
  hidden_dim: int
  n_heads: int
  drop_p: float
  mlp_dim: int

  def setup(self):
    self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
    self.mlp = MLP(self.mlp_dim, self.drop_p)
    # Two independent normalisation layers. Reusing one LayerNorm for both
    # sub-blocks would tie their scale/bias parameters together.
    self.attn_norm = nn.LayerNorm(epsilon=1e-6)
    self.mlp_norm = nn.LayerNorm(epsilon=1e-6)

  def __call__(self, inputs, train=True):
    """Applies the block; ``train`` toggles dropout in both sub-blocks."""
    # Attention block
    x = inputs + self.mha(self.attn_norm(inputs), train)
    # MLP block
    return x + self.mlp(self.mlp_norm(x), train)

In [None]:
class ViViT(nn.Module):
  """Video Vision Transformer classifier.

  Pipeline: tubelet embedding -> linear projection + [CLS] token + position
  embedding -> dropout -> ``num_layers`` encoder blocks -> linear head on the
  [CLS] token.

  Attributes:
    patch_size: Tubelet edge length for TubeletEmbedding.
    embed_dim: Feature width produced by the tubelet embedding.
    hidden_dim: Transformer token width.
    n_heads: Attention heads per encoder block.
    drop_p: Dropout rate used throughout.
    num_layers: Number of stacked encoder blocks, each with its own parameters.
    mlp_dim: Hidden width of each encoder block's MLP.
    num_classes: Number of output logits.
  """
  patch_size: int
  embed_dim: int
  hidden_dim: int
  n_heads: int
  drop_p: float
  num_layers: int
  mlp_dim: int
  num_classes: int

  def setup(self):
    self.patch_extracter = TubeletEmbedding(self.patch_size, self.embed_dim)
    self.patch_encoder = PatchEncoder(self.hidden_dim)
    self.dropout = nn.Dropout(self.drop_p)
    # One separately parameterised block per layer. Calling a single submodule
    # num_layers times (as before) silently shared the same weights at every depth.
    self.transformer_encoders = [
        TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim)
        for _ in range(self.num_layers)
    ]
    self.cls_head = nn.Dense(features=self.num_classes)

  def __call__(self, x, train=True):
    """Returns (batch, num_classes) logits; ``train`` toggles dropout."""
    x = self.patch_extracter(x)
    x = self.patch_encoder(x)
    x = self.dropout(x, deterministic=not train)
    for encoder in self.transformer_encoders:
      x = encoder(x, train)
    # Linear classification head on the [CLS] token, which PatchEncoder
    # prepends at sequence position 0.
    x = x[:, 0]
    x = self.cls_head(x)
    return x

## Dataset (MedMNIST)

In [None]:
def download_and_prepare_dataset(data_info: dict):
    """Download a MedMNIST ``.npz`` archive and return its three splits.

    The archive is fetched with ``keras.utils.get_file``. The metadata's
    ``"MD5"`` value is passed as the expected hash.

    Arguments:
        data_info (dict): Dataset metadata; must provide ``"url"`` and ``"MD5"``.

    Returns:
        A tuple of three ``(videos, labels)`` pairs: train, validation, test.
        Labels are flattened to 1-D arrays.
    """
    archive_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])

    # (image key, label key) for each split, in the order they are returned.
    split_keys = (
        ("train_images", "train_labels"),
        ("val_images", "val_labels"),
        ("test_images", "test_labels"),
    )
    with np.load(archive_path) as archive:
        # The tuple is fully built before the archive is closed.
        return tuple(
            (archive[image_key], archive[label_key].flatten())
            for image_key, label_key in split_keys
        )

In [None]:
# Key of the MedMNIST dataset to use (looked up in medmnist.INFO below).
DATASET_NAME = "organmnist3d"
# Default mini-batch size for the tf.data loaders.
BATCH_SIZE = 32

In [None]:
# Metadata for the chosen dataset; provides the download "url" and "MD5" entries.
info = medmnist.INFO[DATASET_NAME]

# Download the archive and unpack its (videos, labels) splits in train/valid/test order.
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels), (valid_videos, valid_labels), (test_videos, test_labels) = prepared_dataset

In [None]:
# Report split sizes and check that every split has exactly one label per video.
for split_name, split_videos, split_labels in (
    ("train", train_videos, train_labels),
    ("valid", valid_videos, valid_labels),
    ("test", test_videos, test_labels),
):
  print(f'{split_name}: {len(split_videos)}, {split_name}_labels: {len(split_labels)}')
  assert len(split_videos) == len(split_labels), f'{split_name}: video/label count mismatch'

## Dataset loader

In [None]:
@tf.function
def preprocess(frames : tf.Tensor, label : tf.Tensor):
  """Prepare one ``(frames, label)`` example for the model.

  Appends a trailing channel axis to ``frames`` and converts it to float32
  with ``tf.image.convert_image_dtype``. Per the TF docs, this also rescales
  integer pixel values. The label is cast to float32.

  NOTE(review): a float32 label works with Keras losses, but integer-label
  JAX/optax losses would need an integer cast; confirm against the training loop.
  """
  frames_with_channel = tf.expand_dims(frames, axis=-1)
  float_frames = tf.image.convert_image_dtype(frames_with_channel, tf.float32)
  float_label = tf.cast(label, tf.float32)
  return float_frames, float_label

In [None]:
def prepare_dataloder(videos : np.ndarray, labels : np.ndarray, loader_type : str = "train", batch_size : int = BATCH_SIZE,
                      shuffle_buffer_size : Optional[int] = None):
  """Build a preprocessed, batched, prefetched ``tf.data`` pipeline.

  Args:
    videos: In-memory array of videos, one example per entry along axis 0.
    labels: Labels aligned with ``videos``.
    loader_type: ``"train"`` enables shuffling. Any other value (e.g.
      ``"valid"``/``"test"``) keeps the original order.
    batch_size: Number of examples per batch.
    shuffle_buffer_size: Shuffle buffer for the training loader. Defaults to
      the number of examples, which gives a full (uniform) shuffle.

  Returns:
    A ``tf.data.Dataset`` of ``(frames, label)`` batches mapped through ``preprocess``.
  """
  dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
  if loader_type == "train":
    # The previous buffer of batch_size * 2 only shuffled locally, so batches
    # stayed correlated with the on-disk order. The data is already in memory,
    # so a full-size buffer is cheap. max(..., 1) guards an empty input.
    buffer_size = len(videos) if shuffle_buffer_size is None else shuffle_buffer_size
    dataset = dataset.shuffle(max(buffer_size, 1))

  dataloader = (
      dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(batch_size)
      .prefetch(tf.data.AUTOTUNE)
  )
  return dataloader

In [None]:
# Only the "train" loader is shuffled; validation and test batches keep their order.
train_loader = prepare_dataloder(train_videos, train_labels, loader_type="train")
valid_loader = prepare_dataloder(valid_videos, valid_labels, loader_type="valid")
test_loader = prepare_dataloder(test_videos, test_labels, loader_type="test")