In [None]:
!pip install -qq medmnist

In [None]:
import os
import io
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt


# Setting seed for reproducibility
SEED = 42
# Environment flag exported before any model is built; presumably read by
# TensorFlow to request deterministic cuDNN kernels -- confirm for the TF
# version in use.
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
# Seed the random number generators through Keras' helper.
keras.utils.set_random_seed(SEED)

In [None]:
# DATA
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)  # three volume axes followed by one channel
NUM_CLASSES = 11

# OPTIMIZER
LEARNING_RATE = 1e-4
# NOTE(review): not passed to the optimizer in the cells visible here.
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60

# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
# Tubelets are cut along all three volume axes (strided, 'valid'-padded
# Conv3D), so the token count is the product of the per-axis patch counts
# (3 * 3 * 3 = 27 for these settings), not the square of a single axis.
NUM_PATCHES = (
    (INPUT_SHAPE[0] // PATCH_SIZE[0])
    * (INPUT_SHAPE[1] // PATCH_SIZE[1])
    * (INPUT_SHAPE[2] // PATCH_SIZE[2])
)

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8

In [None]:
def download_and_prepare_dataset(data_info: dict):
  """Fetch the dataset archive and return its train/validation/test splits.

  Args:
    data_info: Metadata mapping that provides the archive location under
      'url' and its checksum under 'MD5'.

  Returns:
    Three `(videos, labels)` tuples ordered train, validation, test. Each
    label array is flattened to one dimension.
  """
  archive_path = keras.utils.get_file(
      origin=data_info['url'], md5_hash=data_info['MD5']
  )

  splits = []
  with np.load(archive_path) as archive:
    # Every array is read while the archive is still open.
    for prefix in ('train', 'val', 'test'):
      split_videos = archive[f'{prefix}_images']
      split_labels = archive[f'{prefix}_labels'].flatten()
      splits.append((split_videos, split_labels))

  return tuple(splits)

In [None]:
# Get the metadata of the dataset
# (this mapping supplies the 'url' and 'MD5' entries used for the download
# and the 'label' entry used to name classes when displaying results)
info = medmnist.INFO[DATASET_NAME]

In [None]:
# Download the archive once and unpack the three (videos, labels) splits.
prepared_dataset = download_and_prepare_dataset(info)
(
    (train_videos, train_labels),
    (valid_videos, valid_labels),
    (test_videos, test_labels),
) = prepared_dataset

In [None]:
# Report the array shapes of every split: videos first, then labels.
for split_name, split_videos, split_labels in (
    ('train', train_videos, train_labels),
    ('valid', valid_videos, valid_labels),
    ('test', test_videos, test_labels),
):
  print(f'{split_name}_videos.shape = {split_videos.shape}')
  print(f'{split_name}_labels.shape = {split_labels.shape}')

In [None]:
# Matplotlib defaults for the figures below: no axis grid, wide canvas.
plt.rcParams.update({
    'axes.grid': False,
    'figure.figsize': [30, 15],
})

In [None]:
# Draw every frame of one training sample on a 3 x 10 grid.
# A single sample (index 1, the same index whose label is printed via
# `train_labels[1]`) now supplies both the frame count and the frames drawn;
# previously the count came from sample 1 while sample 0 was displayed.
sample_video = train_videos[1]
for i, frame in enumerate(sample_video):
  plt.subplot(3, 10, i + 1)
  plt.imshow(frame)
plt.show()

In [None]:
# Label id of training sample 1, followed by the dataset's 'label' metadata
# entry (indexed elsewhere in this notebook by the stringified class id).
print(train_labels[1])
print(info['label'])

### tf.data pipeline

In [None]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
  """Append a channel axis to `frames` and convert both tensors to float32.

  Args:
    frames: Video tensor without a trailing channel axis.
    label: Class label for the video.

  Returns:
    `(frames, label)` where `frames` gained a trailing axis of size 1 and
    went through `tf.image.convert_image_dtype` (which, per the TensorFlow
    docs, rescales integer inputs to [0, 1]), and `label` is cast to float32.
  """
  frames = tf.expand_dims(frames, axis=-1)
  frames = tf.image.convert_image_dtype(frames, dtype=tf.float32)
  return frames, tf.cast(label, dtype=tf.float32)

In [None]:
def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE
):
  """Build a batched, prefetched `tf.data` pipeline over (video, label) pairs.

  Args:
    videos: Array of videos; sliced along its first axis.
    labels: Array of labels aligned with `videos`.
    loader_type: Only the value 'train' enables shuffling; any other value
      keeps the original sample order.
    batch_size: Number of samples per batch. Also determines the shuffle
      buffer size (two batches' worth).

  Returns:
    A dataset yielding preprocessed `(frames, label)` batches.
  """
  dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
  if loader_type == 'train':
    # Size the buffer from the requested batch size; the global BATCH_SIZE
    # used before silently ignored a caller-supplied `batch_size`.
    dataset = dataset.shuffle(batch_size * 2)
  dataloader = (
      dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(batch_size)
      .prefetch(tf.data.AUTOTUNE)
  )
  return dataloader

In [None]:
# One loader per split; only the 'train' loader shuffles its samples.
trainloader = prepare_dataloader(
    videos=train_videos, labels=train_labels, loader_type='train'
)
validloader = prepare_dataloader(
    videos=valid_videos, labels=valid_labels, loader_type='valid'
)
testloader = prepare_dataloader(
    videos=test_videos, labels=test_labels, loader_type='test'
)

In [None]:
class TubeletEmbedding(layers.Layer):
  """Projects a video into a sequence of tubelet embeddings.

  A Conv3D whose stride equals its kernel size embeds each tubelet, and the
  resulting grid is reshaped into a `(tokens, embed_dim)` sequence.
  """

  def __init__(self, embed_dim, patch_size, **kwargs):
    """Create the projection and reshape layers.

    Args:
      embed_dim: Number of Conv3D filters, i.e. the size of each token.
      patch_size: Used as both kernel size and stride of the Conv3D.
      **kwargs: Forwarded to `layers.Layer`.
    """
    super().__init__(**kwargs)
    conv_config = dict(
        filters=embed_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
    )
    self.projection = layers.Conv3D(**conv_config)
    self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

  def call(self, videos):
    """Return the flattened tubelet embeddings for `videos`."""
    return self.flatten(self.projection(videos))

In [None]:
class PositionalEncoder(layers.Layer):
  """Adds a learned position embedding to every token of a sequence."""

  def __init__(self, embed_dim, **kwargs):
    """Store the embedding width; sub-layers are created in `build`.

    Args:
      embed_dim: Output dimension of the position embedding.
      **kwargs: Forwarded to `layers.Layer`.
    """
    super().__init__(**kwargs)
    self.embed_dim = embed_dim

  def build(self, input_shape):
    """Create the embedding table and index vector for the token count.

    Args:
      input_shape: Three-element shape; its middle entry is the token count.
    """
    _batch, num_tokens, _width = input_shape
    self.positions = tf.range(num_tokens)
    self.position_embedding = layers.Embedding(
        input_dim=num_tokens, output_dim=self.embed_dim
    )

  def call(self, encoded_tokens):
    """Return `encoded_tokens` with the position embeddings added."""
    return encoded_tokens + self.position_embedding(self.positions)

### Video vision transformer with spatio-temporal attention

In [None]:
def create_vivit_classifier(
    tubelet_embedder,
    positional_encoder,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
    num_classes=NUM_CLASSES
):
  """Assemble the ViViT classification model.

  Args:
    tubelet_embedder: Layer turning the input video into a token sequence.
    positional_encoder: Layer applied to the token sequence before the
      transformer blocks.
    input_shape: Shape of one input sample (without the batch axis).
    transformer_layers: Number of transformer blocks to stack.
    num_heads: Attention heads per block; the per-head key size is
      `embed_dim // num_heads`.
    embed_dim: Token width; the MLP expands to `4 * embed_dim` and back.
    layer_norm_eps: Epsilon for every LayerNormalization in the model,
      including those inside the transformer blocks.
    num_classes: Number of units in the softmax output layer.

  Returns:
    An uncompiled `keras.Model` mapping videos to class probabilities.
  """
  inputs = layers.Input(shape=input_shape)
  patches = tubelet_embedder(inputs)
  encoded_patches = positional_encoder(patches)

  # Create multiple layers of transformer block
  for _ in range(transformer_layers):
    # Layer norm and multi head self attention.
    # The epsilon now follows `layer_norm_eps` instead of a hard-coded 1e-6,
    # so the parameter affects every normalization layer.
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim//num_heads,
        dropout=0.1
    )(x1, x1)

    # Skip connection
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer norm and MLP
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x3 = keras.Sequential([
                           layers.Dense(units=embed_dim*4, activation=tf.nn.gelu),
                           layers.Dense(units=embed_dim, activation=tf.nn.gelu)
    ])(x3)

    # Skip connection
    encoded_patches = layers.Add()([x3, x2])

  # Layer norm and global average pooling
  representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
  representation = layers.GlobalAvgPool1D()(representation)

  # Classify outputs
  outputs = layers.Dense(units=num_classes, activation='softmax')(representation)

  # Create the keras model
  model = keras.Model(inputs=inputs, outputs=outputs)
  return model

In [None]:
def run_experiment():
  """Build, train and evaluate the ViViT classifier.

  Trains on `trainloader` for `EPOCHS` epochs with `validloader` as
  validation data, then evaluates on `testloader` and prints the test
  accuracy and top-5 accuracy.

  Returns:
    The trained model.
  """
  embedder = TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE)
  encoder = PositionalEncoder(embed_dim=PROJECTION_DIM)
  model = create_vivit_classifier(
      tubelet_embedder=embedder, positional_encoder=encoder
  )

  optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
  tracked_metrics = [
      keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
      keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5-accuracy'),
  ]
  model.compile(
      optimizer=optimizer,
      loss='sparse_categorical_crossentropy',
      metrics=tracked_metrics,
  )

  # Train the model; the returned history is not needed.
  model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)

  # `evaluate` returns the loss followed by the metrics in compile order.
  test_results = model.evaluate(testloader)
  _, accuracy, top5_accuracy = test_results
  print(f"Test acc: {round(accuracy * 100, 2)}%")
  print(f"Test top5 acc: {round(top5_accuracy * 100, 2)}%")

  return model

In [None]:
# Train and evaluate the classifier; the returned model is reused by the
# inference cells below.
model = run_experiment()

### Inference

In [None]:
NUM_SAMPLES_VIZ = 25
# Take the first test batch and keep its leading NUM_SAMPLES_VIZ samples.
first_batch_videos, first_batch_labels = next(iter(testloader))
testsamples = first_batch_videos[:NUM_SAMPLES_VIZ]
labels = first_batch_labels[:NUM_SAMPLES_VIZ]

In [None]:
# Predict all selected samples in one call instead of invoking
# `model.predict` once per sample inside the loop; the class id is the
# arg-max over the class axis of each row.
class_probabilities = model.predict(testsamples)
preds = list(np.argmax(class_probabilities, axis=1))

ground_truths = []
videos = []
for testsample, label in zip(testsamples, labels):
  # Encode the sample as an in-memory GIF; values are scaled by 255 and cast
  # to uint8 before encoding.
  with io.BytesIO() as gif:
    imageio.mimsave(gif, (testsample.numpy() * 255).astype('uint8'), 'GIF', fps=5)
    videos.append(gif.getvalue())

  ground_truths.append(label.numpy().astype('int'))

In [None]:
def make_box_for_grid(image_widget, fit):
  """Stack a text caption above an image widget.

  Args:
    image_widget: Widget placed inside the inner box.
    fit: Caption to show. It is rendered with default string formatting, so
      `None` is displayed as the text 'None'.

  Returns:
    A VBox with centred items whose children are the caption HTML widget
    followed by a Box wrapping `image_widget`.
  """
  # A single formatting step replaces the former None / not-None branches
  # (which produced the same text) and the no-op empty-string concatenation.
  caption = ipywidgets.HTML(value='{}'.format(fit))

  image_box = ipywidgets.widgets.Box()
  image_box.children = [image_widget]

  stacked = ipywidgets.widgets.VBox()
  stacked.layout.align_items = 'center'
  stacked.children = [caption, image_box]
  return stacked

In [None]:
# Build one captioned image widget per visualised sample.
class_names = info['label']  # looked up by the stringified class id
boxes = []
for idx in range(NUM_SAMPLES_VIZ):
  true_class = class_names[str(ground_truths[idx])]
  pred_class = class_names[str(preds[idx])]
  gif_widget = ipywidgets.widgets.Image(value=videos[idx], width=100, height=100)
  boxes.append(
      make_box_for_grid(gif_widget, f'T: {true_class} | P: {pred_class}')
  )

In [None]:
# Lay the captioned samples out in five columns of 200px each; the GridBox is
# the cell's last expression, so it is what the cell displays.
grid_layout = ipywidgets.widgets.Layout(grid_template_columns='repeat(5, 200px)')
ipywidgets.widgets.GridBox(boxes, layout=grid_layout)

## UCF101 top5 dataset

In [None]:
# Settings for the UCF101 section.
UCF_BATCH_SIZE = 64  # default batch size of the UCF data loader
UCF_IMG_SIZE = 224   # default side length used when resizing video frames

In [None]:
import pandas as pd
import cv2

In [None]:
# Download the archive as ucf101_top5.tar.gz and extract it into the current
# working directory.
!wget -q https://git.io/JGc31 -O ucf101_top5.tar.gz
!tar xf ucf101_top5.tar.gz

In [None]:
# Load the split tables from the working directory (presumably extracted from
# the archive downloaded above -- verify).
train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')

print(f'Total videos for training: {len(train_df)}')
print(f'Total videos for test: {len(test_df)}')

# Show five randomly chosen rows as the cell output.
train_df.sample(5)

In [None]:
# Pull the 'video_name' column of each split.
# NOTE(review): `DataFrame.get` presumably yields None instead of raising when
# the column is missing -- confirm this is the intended failure mode.
train_video_names = train_df.get('video_name')
test_video_names = test_df.get('video_name')

In [None]:
def crop_center_square(frame):
  """Return the centred square region of `frame`.

  Args:
    frame: Array whose first two axes are height and width.

  Returns:
    A slice of `frame` whose height and width both equal the smaller of the
    two, taken around the centre of the frame.
  """
  height, width = frame.shape[:2]
  side = min(height, width)
  top = height // 2 - side // 2
  left = width // 2 - side // 2
  return frame[top:top + side, left:left + side]

In [None]:
def load_video(path, max_frames=0, resize=(UCF_IMG_SIZE, UCF_IMG_SIZE)):
  """Read a video file into an array of square, resized RGB frames.

  Args:
    path: Location of the video, handed to `cv2.VideoCapture`.
    max_frames: Stop once exactly this many frames were collected. With the
      default of 0 the condition is never met, so the whole video is read.
    resize: Target size passed to `cv2.resize` for every frame.

  Returns:
    `np.array` of the collected frames, in reading order.
  """
  capture = cv2.VideoCapture(path)
  collected = []
  try:
    while True:
      grabbed, bgr_frame = capture.read()
      if not grabbed:
        break
      resized = cv2.resize(crop_center_square(bgr_frame), resize)
      collected.append(resized[:, :, [2, 1, 0]])  # BGR2RGB
      if len(collected) == max_frames:
        break
  finally:
    # Always free the capture handle, even if reading raised.
    capture.release()
  return np.array(collected)

In [None]:
# Map tag strings to integer ids, using the distinct training tags as the
# vocabulary.
label_processor = keras.layers.StringLookup(
    num_oov_indices=0, vocabulary=np.unique(train_df['tag']) # no out-of-vocabulary (UNK) index is reserved
)
print(label_processor.get_vocabulary())

In [None]:
# Encode the 'tag' column of both splits with the lookup built from the
# training tags.
train_tags = label_processor(train_df['tag'])
test_tags = label_processor(test_df['tag'])

In [None]:
# Shapes of the encoded tag tensors for the train and test splits.
print(train_tags.shape)
print(test_tags.shape)

In [None]:
def prepare_dataloader_from_df(
    videos,
    labels,
    loader_type: str = "train",
    batch_size: int = UCF_BATCH_SIZE
):
  """Build a batched, prefetched `tf.data` pipeline for the UCF101 data.

  Args:
    videos: Per-sample video inputs; sliced along the first axis.
    labels: Labels aligned with `videos`.
    loader_type: Only the value 'train' enables shuffling.
    batch_size: Number of samples per batch. Also determines the shuffle
      buffer size (two batches' worth).

  Returns:
    A dataset yielding preprocessed `(frames, label)` batches.
  """
  dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
  if loader_type == 'train':
    # Size the buffer from this loader's own batch size; the previous
    # `BATCH_SIZE * 2` used the unrelated MedMNIST batch-size constant.
    dataset = dataset.shuffle(batch_size * 2)
  # NOTE(review): `preprocess` appends a trailing channel axis to the frames;
  # confirm that is intended for the inputs passed to this loader.
  dataloader = (
      dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(batch_size)
      .prefetch(tf.data.AUTOTUNE)
  )
  return dataloader