In [None]:
# the pre-installed version of Jax is very old
# !pip install -U jax[tpu]==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install haliax torchvision
!pip install "haliax @ git+https://github.com/stanford-crfm/haliax.git"

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting nvidia-cublas-cu11>=11.11 (from jax[cuda11_pip])
  Downloading nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux1_x86_64.whl (417.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m417.9/417.9 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu11>=11.8 (from jax[cuda11_pip])
  Downloading nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux1_x86_64.whl (13.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-nvcc-cu11>=11.8 (from jax[cuda11_pip])
  Downloading nvidia_cuda_nvcc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (19.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.5/19.5 MB[0m [31m63.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11>=11.8 (from jax[cuda11_pip])
  Downloading nvidia_cuda_r

In [None]:
import dataclasses
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional, Type

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import PRNGKeyArray, Array

import haliax as hax
import haliax.jax_utils
import haliax.nn as hnn
from haliax import Axis, NamedArray
from haliax.jax_utils import named_call, shaped_rng_split
from haliax.nn.scan import Stacked

import torch
import torchvision
import torchvision.transforms as transforms

import optax

In [None]:
# Check which backend JAX is using (expect "gpu" on a CUDA runtime).
# `jax.default_backend()` is the public API for this. The old
# `jax.lib.xla_bridge.get_backend().platform` path is deprecated.
print(jax.default_backend())

gpu


In [None]:
# Axes

Height        = hax.Axis("height", 32)      # Height of the input image.
Width         = hax.Axis("width", 32)       # Width of the input image.
Channels      = hax.Axis("channels", 3)     # Channels of the input image.
Embed         = hax.Axis("embed_dim", 512)  # Dimension of patch embedding vector.
Heads         = hax.Axis("heads", 8)        # Number of heads in self-attention.
Mlp           = hax.Axis("mlp", 256)        # Hidden dim of MLP in transformer block.
Layers        = hax.Axis("num_layers", 6)   # Number of layers in transformer.
Classes       = hax.Axis("num_classes", 10) # Number of classes.

# Patches are patch_size x patch_size. The patch axes keep the "height"/"width" names
# because the patch-embedding rearrange splits the image's "height"/"width" axes into them.
patch_size    = 4
PatchHeight   = Height.resize(patch_size)
PatchWidth    = Width.resize(patch_size)

BatchSize     = hax.Axis("batch_size", 64)  # Examples per batch.

# Fail fast on inconsistent hyperparameters instead of deep inside a jitted call:
# the image must tile exactly into patches, and attention gives each head
# Embed.size // Heads.size features.
assert Height.size % PatchHeight.size == 0 and Width.size % PatchWidth.size == 0, \
    "image height/width must be divisible by patch_size"
assert Embed.size % Heads.size == 0, "embed_dim must be divisible by the number of heads"

In [None]:
class ViTPatchEmbeddings(eqx.Module):
    """Cut an image into non-overlapping patches and linearly project each patch to Embed."""

    proj: hnn.Linear  # (PatchHeight, PatchWidth, Channels) -> Embed, no bias

    @staticmethod
    def init(Channels,
             PatchHeight,
             PatchWidth,
             Embed,
             key):
        """Create the patch projection.

        Args:
            Channels: channel axis of the input image.
            PatchHeight, PatchWidth: axes whose sizes are the patch height/width. They must be
                named "height" and "width", because `embed` splits the image's
                "height"/"width" axes into patches of these sizes.
            Embed: output embedding axis.
            key: PRNG key for the projection weights.
        """
        proj = hnn.Linear.init(In=(PatchHeight, PatchWidth, Channels),
                               Out=Embed,
                               key=key,
                               use_bias=False)

        return ViTPatchEmbeddings(proj)

    def embed(self, x):
        """Return one Embed-sized vector per patch along a new "position" axis.

        Args:
            x: NamedArray with "height" and "width" axes divisible by the patch size,
               plus other axes (e.g. batch and channels).
        """
        # Read the patch size from the projection's own input axes, so `embed` always agrees
        # with the axes passed to `init` instead of relying on a global constant.
        patch_sizes = {ax.name: ax.size for ax in self.proj.In}

        # Rearrange into patches: (height: nh*ph, width: nw*pw) ->
        # (position: nh*nw, height: ph, width: pw), with all other axes moved to the front.
        x = hax.rearrange(x, "{ (height: nh ph) (width: nw pw)  } -> ... (position: nh nw) (height: ph) (width: pw)",
                          ph=patch_sizes["height"], pw=patch_sizes["width"])

        # Project each patch; this contracts the (height, width, channels) input axes.
        return self.proj(x)

In [None]:
class Attention(eqx.Module):
    """Unmasked multi-head self-attention over the "position" axis.

    Fields:
        Embed: model width (input and output axis).
        Heads: number of attention heads.
        HeadSize: per-head width, Embed.size // Heads.size.
        c_attn: joint projection Embed -> (qkv=3, Heads, HeadSize).
        c_proj: output projection (Heads, HeadSize) -> Embed.
    """
    Embed: hax.Axis
    Heads: hax.Axis
    HeadSize: hax.Axis

    c_attn: hnn.Linear
    c_proj: hnn.Linear

    @staticmethod
    def init(Embed, Heads, key):
        """Build the attention layer; `key` is split between the two projections."""
        head_axis = hax.Axis("head_size", Embed.size // Heads.size)
        qkv_axis = hax.Axis("qkv", size=3)  # stacks queries, keys and values in one matmul

        key_qkv, key_out = jrandom.split(key, 2)
        return Attention(
            Embed,
            Heads,
            head_axis,
            hnn.Linear.init(In=Embed, Out=(qkv_axis, Heads, head_axis), key=key_qkv),
            hnn.Linear.init(In=(Heads, head_axis), Out=Embed, key=key_out),
        )

    def __call__(self, x):
        """Attend from every position to every position of `x` and project back to Embed."""
        queries, keys, values = self.c_attn(x).unbind("qkv")

        # Give key/value positions their own axis name so the attention scores span
        # (position, position_key) rather than collapsing "position" against itself.
        to_key_pos = {"position": "position_key"}
        keys, values = keys.rename(to_key_pos), values.rename(to_key_pos)

        attn_weights = hnn.attention.dot_product_attention_weights(self.HeadSize, "position_key", queries, keys)
        return self.c_proj(hax.dot("position_key", attn_weights, values))

In [None]:
class MLP(eqx.Module):
    """Position-wise feed-forward network: Embed -> Mlp -> GELU -> Embed."""

    c_proj_up: hax.nn.Linear    # Embed -> Mlp
    c_proj_down: hax.nn.Linear  # Mlp -> Embed

    @staticmethod
    def init(Embed, Mlp, key, use_bias=True):
        """Create both projections, each initialised from its own subkey of `key`."""
        key_up, key_down = jrandom.split(key, 2)
        up = hnn.Linear.init(In=Embed, Out=Mlp, key=key_up, use_bias=use_bias)
        down = hnn.Linear.init(In=Mlp, Out=Embed, key=key_down, use_bias=use_bias)
        return MLP(up, down)

    @named_call
    def __call__(self, x):
        """Apply the up-projection, GELU, then the down-projection."""
        hidden = hnn.gelu(self.c_proj_up(x))
        return self.c_proj_down(hidden)

In [None]:
class Block(eqx.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""
    mlp: MLP
    attn: Attention
    ln1: hnn.LayerNorm  # normalises the input of `attn`
    ln2: hnn.LayerNorm  # normalises the input of `mlp`

    @staticmethod
    def init(Embed, Heads, Mlp, key):
        """Build one block; `key` is split between the MLP and the attention layer."""
        key_mlp, key_attn = jax.random.split(key, 2)
        return Block(
            mlp=MLP.init(Embed, Mlp, key=key_mlp),
            attn=Attention.init(Embed, Heads, key=key_attn),
            ln1=hnn.LayerNorm.init(Embed),
            ln2=hnn.LayerNorm.init(Embed),
        )

    def __call__(self, x):
        """Return x + attn(ln1(x)), followed by the same pattern with the MLP."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

In [None]:
class VisionTransformer(eqx.Module):
    """A stack of `Layers` transformer blocks followed by a final LayerNorm."""

    blocks: hnn.Stacked[Block]
    ln_f: hnn.LayerNorm

    @staticmethod
    def init(Embed, Heads, Mlp, Layers, key):
        """Initialise every block from its own subkey, then the final LayerNorm."""
        make_stack = hnn.Stacked.init(Layers, Block)
        layer_keys = jax.random.split(key, Layers.size)  # one key per stacked block
        stacked_blocks = make_stack(Embed, Heads, Mlp, key=layer_keys)
        return VisionTransformer(stacked_blocks, hnn.LayerNorm.init(Embed))

    def __call__(self, x):
        """Run `x` through all blocks in order, then normalise the result."""
        hidden = self.blocks.fold(x)
        return self.ln_f(hidden)

In [None]:
class ViTClassificationHeadModel(eqx.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embedding -> prepend a learned cls token -> add learned position
    embeddings -> transformer blocks -> take the cls token's output -> LayerNorm ->
    linear projection to class logits.
    """

    vision_transformer: VisionTransformer
    patch_embeddings: ViTPatchEmbeddings
    position_embeddings: hnn.Embedding
    cls_token: NamedArray   # learned token with axes (Embed, position of size 1)

    ln: hnn.LayerNorm       # applied to the cls-token output before `proj`
    proj: hnn.Linear        # Embed -> Classes

    @staticmethod
    def init(Height, Width, Channels, PatchHeight, PatchWidth, Embed, Heads, Mlp, Layers, Classes, key):
        """Build the model with freshly initialised parameters.

        Args:
            Height, Width, Channels: axes of the input image.
            PatchHeight, PatchWidth: axes whose sizes give the patch size.
            Embed, Heads, Mlp, Layers: model width, attention heads, MLP hidden size, depth.
            Classes: axis of the output logits.
            key: PRNG key, split into one subkey per randomly initialised component.
        """
        k_tr, k_pte, k_ppe, k_cls, k_proj = jax.random.split(key, 5)

        # Initialize Patch Token Embeddings.
        patch_embeddings = ViTPatchEmbeddings.init(Channels, PatchHeight, PatchWidth, Embed, k_pte)

        # ViT.
        vision_transformer = VisionTransformer.init(Embed, Heads, Mlp, Layers, k_tr)

        # Since we prepend a cls_token to the input, Pos needs to be one larger than the number of patches.
        num_patches = (Height.size // PatchHeight.size) * (Width.size // PatchWidth.size)
        Pos = hax.Axis("position", num_patches + 1)

        # Patch Position Embeddings.
        position_embeddings = hnn.Embedding.init(Pos, Embed, key=k_ppe)

        # cls_token.
        cls_token = hax.random.normal(k_cls, (Embed, Pos.resize(1)))

        # final linear projection.
        ln   = hnn.LayerNorm.init(Embed)
        proj = hnn.Linear.init(Out=Classes, In=Embed, key=k_proj, use_bias=True)

        return ViTClassificationHeadModel(vision_transformer,
                                          patch_embeddings,
                                          position_embeddings,
                                          cls_token,
                                          ln,
                                          proj)

    def __call__(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: NamedArray of images with "height", "width" and "channels" axes plus any
               batch axes (in this notebook, ``BatchSize``).

        Returns:
            NamedArray of logits over ``Classes`` for each example.
        """
        # Embed x as a sequence of patches along a "position" axis.
        x = self.patch_embeddings.embed(x)

        # Prepend cls_token. Broadcast it over the batch axes `x` actually carries (every axis
        # cls_token itself lacks) rather than the global BatchSize, so any batch size works.
        token_axis_names = {ax.name for ax in self.cls_token.axes}
        batch_axes = tuple(ax for ax in x.axes if ax.name not in token_axis_names)
        x = hax.concatenate("position", [self.cls_token.broadcast_axis(batch_axes), x])

        # Add learned position embeddings (one per position, cls token included).
        x_Pos = x.resolve_axis("position")
        pos_embeds = self.position_embeddings.embed(hax.arange(x_Pos))
        x = x + pos_embeds

        # Forward pass.
        x = self.vision_transformer(x)

        # Classify from the cls token's output. A common alternative is mean-pooling
        # over positions: x.mean("position").
        x = x[{"position": 0}]

        # Project output to number of classes.
        x = self.ln(x)
        return self.proj(x)


In [None]:
"""
Data Loaders.
From https://docs.kidger.site/equinox/examples/vision_transformer/.
"""
batch_size=64

transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_dataset = torchvision.datasets.CIFAR10(
    "CIFAR",
    train=True,
    download=True,
    transform=transform_train,
)

test_dataset = torchvision.datasets.CIFAR10(
    "CIFAR",
    train=False,
    download=True,
    transform=transform_test,
)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to CIFAR/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:18<00:00, 9044446.48it/s]


Extracting CIFAR/cifar-10-python.tar.gz to CIFAR
Files already downloaded and verified


In [None]:
@eqx.filter_value_and_grad
def compute_grads(
    model: ViTClassificationHeadModel,
    images: hax.NamedArray,
    labels: hax.NamedArray,
    key):
    """Mean softmax cross-entropy of `model` on one labelled batch.

    Wrapped by `eqx.filter_value_and_grad`, so a call returns `(loss, grads)`, with
    gradients taken with respect to the floating-point array leaves of `model`.
    `key` is accepted for interface compatibility but is not used.
    """
    # optax expects the class dimension last in the raw logits array.
    # NOTE(review): this relies on "num_classes" being the last axis of the model output.
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        model(images).array, labels.array
    )
    return per_example_loss.mean()

@eqx.filter_jit
def step_model(
    model: ViTClassificationHeadModel,
    optimizer: optax.GradientTransformation,
    state: optax.OptState,
    images: hax.NamedArray,
    labels: hax.NamedArray,
    key,
):
    """Run one jitted optimisation step on a batch.

    Returns:
        (updated_model, updated_optimizer_state, loss_before_the_update)
    """
    loss_value, gradients = compute_grads(model, images, labels, key)
    # The model is passed as `params` because weight-decaying optimisers (e.g. the AdamW
    # used in this notebook) need the current weights.
    param_updates, next_state = optimizer.update(gradients, state, model)
    updated_model = eqx.apply_updates(model, param_updates)
    return updated_model, next_state, loss_value

def train(
    model: ViTClassificationHeadModel,
    optimizer: optax.GradientTransformation,
    state: optax.OptState,
    data_loader: torch.utils.data.DataLoader,
    num_steps: int,
    print_every: int = 1000,
    key=None,
):
    """Run `num_steps` optimisation steps, cycling over `data_loader` as often as needed.

    Args:
        model: model to train.
        optimizer: optax transformation used by `step_model`.
        state: optimizer state matching `model`.
        data_loader: yields (images, labels) torch batches of exactly BatchSize.size examples.
        num_steps: number of optimisation steps to run.
        print_every: print the loss every this many steps (and on the last step).
        key: PRNG key from which per-example subkeys are derived each step.
            Defaults to `jax.random.PRNGKey(0)` when None.

    Returns:
        (model, state, losses), where `losses` holds the loss of every step.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    losses = []

    def infinite_trainloader():
        # Restart the loader every time it is exhausted.
        while True:
            yield from data_loader

    for step, (images, labels) in zip(range(num_steps), infinite_trainloader()):
        images = hax.named(images.numpy(), (BatchSize, Channels, Height, Width))
        labels = hax.named(labels.numpy(), BatchSize)

        # One split into (1 + batch) keys: row 0 is carried forward and the rest are
        # per-example subkeys. This gives the same values as unpacking into a Python list
        # and re-stacking, without dozens of tiny device ops per step.
        keys = jax.random.split(key, num=BatchSize.size + 1)
        key, subkeys = keys[0], keys[1:]

        model, state, loss = step_model(
            model, optimizer, state, images, labels, subkeys
        )

        losses.append(loss)

        if (step % print_every) == 0 or step == num_steps - 1:
            print(f"Step: {step}/{num_steps}, Loss: {loss}.")

    return model, state, losses

In [None]:
# Seed everything from one root key; `key` is kept as a global for later cells.
key = jax.random.PRNGKey(2023)
key_mdl, key_train = jax.random.split(key, 2)

model = ViTClassificationHeadModel.init(
    Height, Width, Channels,
    PatchHeight, PatchWidth,
    Embed, Heads, Mlp, Layers, Classes,
    key_mdl,
)

# AdamW with optax's default weight decay, optimising only the floating-point leaves.
optimizer = optax.adamw(learning_rate=1e-4, b1=0.9, b2=0.999)
state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

# NOTE: 1M steps is roughly 1280 passes over CIFAR-10's 50k training images at batch
# size 64. The recorded run was interrupted after about 21k steps.
NUM_TRAIN_STEPS = 1000000
PRINT_EVERY = 1000
model, state, losses = train(
    model, optimizer, state, trainloader, NUM_TRAIN_STEPS,
    print_every=PRINT_EVERY, key=key_train,
)

Step: 0/1000000, Loss: 2.3670318126678467.
Step: 1000/1000000, Loss: 1.3204742670059204.
Step: 2000/1000000, Loss: 1.392897367477417.
Step: 3000/1000000, Loss: 1.337120771408081.
Step: 4000/1000000, Loss: 1.1908444166183472.
Step: 5000/1000000, Loss: 0.9924536347389221.
Step: 6000/1000000, Loss: 0.9787511825561523.
Step: 7000/1000000, Loss: 1.06725013256073.
Step: 8000/1000000, Loss: 1.0366957187652588.
Step: 9000/1000000, Loss: 1.1439564228057861.
Step: 10000/1000000, Loss: 1.0700474977493286.
Step: 11000/1000000, Loss: 0.9560476541519165.
Step: 12000/1000000, Loss: 0.8232935667037964.
Step: 13000/1000000, Loss: 0.7857428789138794.
Step: 14000/1000000, Loss: 0.733768105506897.
Step: 15000/1000000, Loss: 0.5060813426971436.
Step: 16000/1000000, Loss: 0.6339681148529053.
Step: 17000/1000000, Loss: 0.790744423866272.
Step: 18000/1000000, Loss: 0.545936107635498.
Step: 19000/1000000, Loss: 0.732439398765564.
Step: 20000/1000000, Loss: 0.5827034711837769.
Step: 21000/1000000, Loss: 0.77476

In [None]:
@eqx.filter_jit
def predict(model, images):
    """Return the predicted class index for each example in a batch.

    Args:
        model: a trained ViTClassificationHeadModel.
        images: NamedArray with axes (BatchSize, Channels, Height, Width).

    Returns:
        Integer array of shape (BatchSize.size,) with the argmax class per example.
    """
    logits = model(images)
    # Fix the axis order explicitly (classes last) before dropping to a raw array.
    return jnp.argmax(logits.rearrange((BatchSize, Classes)).array, axis=-1)


# Evaluate on the test set by iterating the loader once. Calling `next(iter(testloader))`
# inside the loop would re-sample batches instead of covering the whole set.
accuracies = []

for images, labels in testloader:
    images = hax.named(images.numpy(), (BatchSize, Channels, Height, Width))
    predictions = predict(model, images)
    accuracies.append(jnp.mean(predictions == jnp.asarray(labels.numpy())))

# All batches have the same size (drop_last=True), so the mean of the per-batch accuracies
# equals the accuracy over every evaluated image.
print(f"Accuracy: {float(jnp.mean(jnp.stack(accuracies))) * 100}%")