## Setup

Install JAX (with TPU support), Flax, and Orbax first, then confirm that the TPU devices are visible to JAX.

In [1]:
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -U flax orbax
import jax
jax.devices()

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Take care of the imports.

In [2]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax
from typing import Any
import os
from collections import Counter
from dataclasses import dataclass
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import pandas as pd


In [3]:
# Arrange the devices as a 4 x 2 grid (8 devices in total); the first grid axis
# is named 'batch' and the second 'model'.
mesh = Mesh(
    mesh_utils.create_device_mesh((4, 2)),
    ('batch', 'model'),
)

In [4]:
def causal_attention_mask(seq_len):
    """Build a lower-triangular (causal) mask of shape ``(seq_len, seq_len)``.

    Entry ``[i, j]`` is 1 when ``j <= i`` and 0 otherwise. The array has the
    default dtype produced by ``jnp.ones``.
    """
    all_ones = jnp.ones((seq_len, seq_len))
    lower_triangle = jnp.tril(all_ones)
    return lower_triangle

class Model(nnx.Module):
    """Minimal module wrapping one causal multi-head self-attention layer.

    The class uses a CapWords name so that binding an instance to the name
    ``model`` (as the notebook does after calling ``create_model``) no longer
    hides the class that ``create_model`` has to call; the factory can then be
    re-run safely.

    Args:
        maxlen: Accepted for interface compatibility; not read by this module.
        embed_dim: Forwarded to the attention layer as ``in_features``.
        num_heads: Number of attention heads.
        feed_forward_dim: Accepted for interface compatibility; not read.
        num_transformer_blocks: Accepted for interface compatibility; not read.
        rngs: ``nnx.Rngs`` used to initialise the attention parameters.

    Relies on the module-level ``mesh`` when building the sharding annotations
    attached to the kernel and bias initialisers.
    """

    def __init__(self, maxlen: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        # Initialisers carry partitioning metadata: kernels are annotated with
        # P(None, 'model') and biases with P('model') on the global mesh.
        kernel_init = nnx.with_partitioning(
            nnx.initializers.xavier_uniform(),
            NamedSharding(mesh, P(None, 'model')),
        )
        bias_init = nnx.with_partitioning(
            nnx.initializers.zeros_init(),
            NamedSharding(mesh, P('model')),
        )
        self.mha = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            kernel_init=kernel_init,
            bias_init=bias_init,
            rngs=rngs,
        )

    def __call__(self, inputs, training: bool = False):
        """Apply causal self-attention to ``inputs`` and return the result.

        Args:
            inputs: Rank-3 array; the middle dimension is used as the sequence
                length (presumably ``(batch, seq_len, embed_dim)`` — confirm
                against callers).
            training: Accepted for interface compatibility; currently unused.
        """
        # The three-way unpack enforces a rank-3 input; only the sequence
        # length is needed to size the mask.
        _, seq_len, _ = inputs.shape

        return self.mha(
            inputs_q=inputs,
            mask=causal_attention_mask(seq_len),
            decode=False,
        )


# Backward-compatible alias for the original lowercase class name. The notebook
# later rebinds ``model`` to an instance, which is why ``create_model`` refers
# to ``Model`` rather than to this alias.
model = Model


def create_model(rngs):
    """Construct a ``Model`` from the module-level hyperparameters.

    Reads ``maxlen``, ``embed_dim``, ``num_heads``, ``feed_forward_dim`` and
    ``num_transformer_blocks`` from the notebook's global namespace at call
    time, so the hyperparameter cell must have been executed first.
    """
    return Model(
        maxlen,
        embed_dim,
        num_heads,
        feed_forward_dim,
        # Forward the configured value instead of a hard-coded 4 so the factory
        # agrees with the hyperparameter cell.
        num_transformer_blocks=num_transformer_blocks,
        rngs=rngs,
    )


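Set some hyperparameters. The model is much bigger, with a lot more transformer layers and attention heads. Note that the module defined above only builds a single multi-head attention layer, so `maxlen`, `feed_forward_dim`, and `num_transformer_blocks` do not affect it.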

In [5]:
# Sizes read by the attention layer.
embed_dim = 256
num_heads = 8

# Accepted by the model constructor but not read by its body.
maxlen = 256
feed_forward_dim = 256
num_transformer_blocks = 8

# Build the module with PRNG seed 0. This binds the module-level name `model`
# to the constructed instance.
model = create_model(nnx.Rngs(0))

# Saving
Saving the model weights fails on Colab TPU v2 (the cell below raises a `TypeError`). The same code works on Kaggle TPU v3.

In [6]:
# Don't do this on Colab.
# NOTE(review): the traceback recorded for this cell shows tensorstore's
# write() rejecting the `can_reference_source_data_indefinitely` keyword, which
# suggests the installed tensorstore predates the API this Orbax release
# calls — confirm before attributing the failure to the TPU generation.

# Rebinds the name `orbax` (imported earlier) to the orbax.checkpoint submodule.
import orbax.checkpoint as orbax

checkpointer = orbax.PyTreeCheckpointer()

# Pull the model's variables out as an nnx state and write them to disk.
state = nnx.state(model)
checkpointer.save('/content/save/', state)

TypeError: write(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorstore.TensorStore, source: Union[tensorstore.TensorStore, numpy.typing.ArrayLike]) -> tensorstore.WriteFutures

Invoked with: TensorStore({
  'base': {
    'assume_metadata': True,
    'driver': 'zarr',
    'dtype': 'float32',
    'kvstore': {
      'base': {
        'driver': 'file',
        'path': '/content/save.orbax-checkpoint-tmp-1728360408721411/ocdbt.process_0/',
      },
      'cache_pool': 'cache_pool#ocdbt',
      'config': {
        'max_decoded_node_bytes': 100000000,
        'max_inline_value_bytes': 1024,
      },
      'driver': 'ocdbt',
      'experimental_read_coalescing_interval': '1ms',
      'experimental_read_coalescing_merged_bytes': 500000000000,
      'experimental_read_coalescing_threshold_bytes': 1000000,
      'path': 'mha.key.bias.value/',
    },
    'metadata': {
      'chunks': [8, 32],
      'compressor': {'id': 'zstd', 'level': 1},
      'dimension_separator': '.',
      'dtype': '<f4',
      'fill_value': None,
      'filters': None,
      'order': 'C',
      'shape': [8, 32],
      'zarr_format': 2,
    },
    'recheck_cached_data': False,
    'recheck_cached_metadata': False,
  },
  'context': {
    'cache_pool': {},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'cast',
  'dtype': 'float32',
  'transform': {
    'input_exclusive_max': [[8], [32]],
    'input_inclusive_min': [0, 0],
  },
}), array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32); kwargs: can_reference_source_data_indefinitely=True