In [None]:
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
# !pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
!pip install flax
!pip install tqdm

In [None]:
import jax
import jax.numpy as jnp
import jax.dlpack
from jax import grad, jit, vmap, random
from jax import random
from jax.example_libraries import stax, optimizers

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow as tf

import time
import numpy.random as npr
import math

from typing import Optional, Sequence, Union

import optax
from flax.training import train_state

In [None]:
from flax import linen as nn

## Tubelet embedding

In [None]:
class TubeletEmbedding(nn.Module):
  """Split videos into non-overlapping tubelets and embed each one linearly.

  The embedding is a 3D convolution whose kernel size equals its stride, with
  'VALID' padding. Each output position therefore covers exactly one tubelet.
  Trailing frames/pixels that do not fill a whole tubelet are dropped.

  Attributes:
    patch_size: Tubelet extent. An int uses the same size along time, height
      and width. A length-3 sequence gives the (time, height, width) sizes
      separately, e.g. (2, 16, 16).
    embed_dim: Number of output features per tubelet.
  """
  patch_size: Union[int, Sequence[int]]
  embed_dim: int

  @nn.compact
  def __call__(self, videos):
    """Embed a batch of videos into a sequence of tubelet tokens.

    Args:
      videos: Array of shape (batch, time, height, width, channels).

    Returns:
      Array of shape (batch, num_tubelets, embed_dim). num_tubelets is the
      product of the number of tubelets along time, height and width.

    Raises:
      ValueError: If `videos` is not 5-D, or if `patch_size` is a sequence
        whose length is not 3.
    """
    if videos.ndim != 5:
      raise ValueError(
          'Expected videos of shape (batch, time, height, width, channels), '
          f'got shape {videos.shape}')

    # Normalize patch_size to a (time, height, width) triple.
    if isinstance(self.patch_size, Sequence):
      tubelet = tuple(self.patch_size)
      if len(tubelet) != 3:
        raise ValueError(
            'patch_size must be an int or a (time, height, width) sequence, '
            f'got {self.patch_size!r}')
    else:
      tubelet = (self.patch_size,) * 3

    # Kernel == stride with 'VALID' padding => non-overlapping tubelets.
    patches = nn.Conv(
        features=self.embed_dim,
        kernel_size=tubelet,
        strides=tubelet,
        padding='VALID',
    )(videos)

    # Flatten the (time, height, width) tubelet grid into a single token axis.
    b, t, h, w, c = patches.shape
    return jnp.reshape(patches, (b, t * h * w, c))

## Patch encoder

In [None]:
class PatchEncoder(nn.Module):
  """Project tokens to `hidden_dim`, prepend a learnable CLS token, and add
  a learnable position embedding.

  Attributes:
    hidden_dim: Feature width of every output token.
  """
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    """Encode a batch of token sequences.

    Args:
      x: Rank-3 array (batch, seq_len, features). This is enforced by an assert.

    Returns:
      Array of shape (batch, seq_len + 1, hidden_dim). Position 0 along the
      sequence axis holds the CLS token.
    """
    assert x.ndim == 3
    batch_size = x.shape[0]
    num_tokens = x.shape[1]

    # Linear projection to the model width.
    tokens = nn.Dense(self.hidden_dim)(x)

    # One shared, zero-initialized CLS token, copied across the batch and
    # prepended to every sequence.
    cls_token = self.param(
        'cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
    cls_per_example = jnp.broadcast_to(
        cls_token, (batch_size, 1, self.hidden_dim))
    tokens = jnp.concatenate((cls_per_example, tokens), axis=1)

    # Learnable position embedding covering the CLS slot plus every token.
    position_embedding = self.param(
        'position_embedding',
        nn.initializers.normal(stddev=0.02),
        (1, num_tokens + 1, self.hidden_dim),
    )
    return tokens + position_embedding

## Test tubelet embedding

In [None]:
def test_tubelet_embedding():
  """Smoke-test TubeletEmbedding on a dummy video batch and check its shape.

  The test embeds a batch of all-ones videos using 4x4x4 tubelets, prints
  the output shape, and asserts that it equals (batch, num_tubelets, embed_dim).

  Raises:
    AssertionError: If the embedding output has an unexpected shape.
  """
  batch, frames, height, width, channels = 8, 16, 32, 32, 3
  patch_size, embed_dim = 4, 128

  main_rng = jax.random.PRNGKey(42)
  x = jnp.ones(shape=(batch, frames, height, width, channels))
  embedder = TubeletEmbedding(patch_size=patch_size, embed_dim=embed_dim)
  variables = embedder.init(main_rng, x)
  out = embedder.apply(variables, x)
  print(out.shape)

  # A 'VALID' conv with stride == kernel yields floor(dim / patch) tubelets
  # per axis.
  num_tubelets = (
      (frames // patch_size) * (height // patch_size) * (width // patch_size))
  expected_shape = (batch, num_tubelets, embed_dim)
  assert out.shape == expected_shape, (
      f'expected {expected_shape}, got {out.shape}')

In [None]:
# Run the tubelet-embedding smoke test. For the 8 x 16 x 32 x 32 x 3 dummy
# batch with 4x4x4 tubelets and embed_dim=128, it prints (8, 256, 128).
test_tubelet_embedding()