<a href="https://colab.research.google.com/github/zetaqubit/dl/blob/main/dl/colabs/jax_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This notebook implements GPT-2 and the LM training objective in pure Jax.

# Imports

In [None]:
!pip install einops
!pip install transformers  # for GPT2Tokenizer



In [None]:
import chex
from chex import dataclass
from einops import rearrange
from dataclasses import field
from functools import partial
import jax
from jax.experimental.host_callback import id_print
import jax.numpy as jnp
import numpy as np
from pprint import pprint
from typing import Any

import tensorflow_datasets as tfds
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [None]:
N_DEVICES = len(jax.local_devices())
N_DEVICES

8

# Dataset

In [None]:
BATCH_SIZE = 512
MAX_SEQ_LEN = 256

## Tokenizer

In [None]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='[PAD]')

In [None]:
def tokenizer_encode(text, pad=False, return_masks=False):
  if pad:
    out = tokenizer(text,
                    return_tensors='tf',  # pytorch not supported internally. np also fails.
                    max_length=MAX_SEQ_LEN+1,  # +1 so target is MAX_SEQ_LEN
                    padding='max_length',
                    truncation=True)
  else:
    out = tokenizer(text, return_tensors='tf')
  if return_masks:
    return out
  return out['input_ids']

def tokenizer_decode(token_ids):
  if len(token_ids.shape) == 1:
    token_ids = token_ids[None, ...]
  return tokenizer.batch_decode(tf.convert_to_tensor(token_ids))

In [None]:
VOCAB_SIZE = tokenizer.vocab_size
token_ids = tokenizer(['hello world', 'hi'], return_tensors='tf', max_length=10,
                      padding='max_length', truncation=True)['input_ids']
recovered = tokenizer.batch_decode(token_ids)
VOCAB_SIZE, token_ids, recovered

(50257,
 <tf.Tensor: shape=(2, 10), dtype=int32, numpy=
 array([[31373,   995, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257],
        [ 5303, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257]], dtype=int32)>,
 ['hello world[PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]',
  'hi[PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]'])

In [None]:
tokenizer_encode(['hello there', 'hi, how are you'], pad=True, return_masks=True)

{'input_ids': <tf.Tensor: shape=(2, 257), dtype=int32, numpy=
array([[31373,   612, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 

## Grain
https://github.com/google/grain

In [None]:
import grain.python as grain

class FilterShortText(grain.FilterTransform):
  def filter(self, element):
    text = element['text']
    return len(text) > 100

class TokenizeTransform(grain.MapTransform):
  def map(self, element):
    text = element['text']
    text = text.decode('utf-8')
    text = text[:10 * MAX_SEQ_LEN]  # optimization to not tokenize what would be truncated anyway
    output = tokenizer_encode(text, pad=True, return_masks=True)
    ids = output['input_ids'].numpy()[0]
    mask = output['attention_mask'].numpy()[0]
    ids_input, ids_target = ids[:-1], ids[1:]
    return ids_input, ids_target, mask

operations = [
    FilterShortText(),
    TokenizeTransform(),
    grain.Batch(BATCH_SIZE, drop_remainder=True),
]

In [None]:
# These options allow the TPUs to not be input-bound
# https://xprof.corp.google.com/trace_viewer/zhaoxu-3071358305372025835
def create_pygrain_loader(ds_name, split):
  ds = tfds.data_source(ds_name, split=split)
  print(f'Number of records: {len(ds)} in {split}')
  dataloader = grain.load(
      source=ds,
      num_epochs=None,
      shuffle=True,
      seed=0,
      shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
      transformations=operations,
      worker_count=64,
      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=32),
  )
  return dataloader

In [None]:
DS_NAME = (
#    'hugginface:wikitext/wikitext-103-raw-v1'
    'c4/en'
)
dataloader_train = create_pygrain_loader(DS_NAME, 'train')
dataloader_valid = create_pygrain_loader(DS_NAME, 'validation')

In [None]:
batch = next(iter(dataloader_train))

## One Batch

Debug dataset that repeats 1 batch of a real dataset. Models trained on this should reach 0 train loss quickly.

In [None]:
class OneBatchDataset:
  def __init__(self, batch, ex_ids, batch_size):
    n = len(ex_ids)
    assert n <= batch_size
    assert batch_size % n == 0
    self.batch = tuple(map(lambda x: np.tile(x[ex_ids], (batch_size // n, 1)), batch))

  def as_numpy_iterator(self):
    while True:
      yield self.batch

  def __iter__(self):
    while True:
      yield self.batch

In [None]:
BATCH_IDXS = np.arange(256)
ds_one_batch = OneBatchDataset(batch, BATCH_IDXS, BATCH_SIZE)
one_batch = next(iter(ds_one_batch))

# Jax Model

In [None]:
def large_negative_num(dtype):
  if jnp.issubdtype(dtype, jnp.inexact):
    dtype_max = jnp.finfo(dtype).max
  elif jnp.issubdtype(dtype, jnp.integer):
    dtype_max = jnp.iinfo(dtype).max
  else:
    raise ValueError(f'Unknown dtype {dtype}')
  return jnp.asarray(-0.7 * dtype_max, dtype=dtype)

large_negative_num(jnp.float32)

DeviceArray(-2.3819763e+38, dtype=float32)

In [None]:
key = jax.random.PRNGKey(1)

# dtype for parameters
w_dtype = (
    # jnp.bfloat16
    jnp.float32
)


@dataclass
class Embeddings:
  weight: jax.Array

  @staticmethod
  def create(key, vocab_size, dim):
    key, key_w = jax.random.split(key, 2)
    #scale = np.sqrt(2 / dim)
    scale = 0.02
    weight = (scale * jax.random.normal(key_w, (vocab_size, dim))).astype(w_dtype)
    return Embeddings(weight=weight)

  def __call__(self, x):  # [b, t] -> [b, t, d]
    return self.weight[x]


@dataclass
class SinusoidalPositionEmbeddings:
  weight: jax.Array

  @staticmethod
  def create(max_seq_len, dim):
    pos = jnp.arange(max_seq_len)  # [s]
    inv_freq = 1. / (10000. ** (jnp.arange(0, dim, 2) / dim))  # [d]
    pos_emb = jnp.einsum('s, d -> s d', pos, inv_freq)
    embs = jnp.zeros((1, max_seq_len, dim))
    embs = embs.at[0, :, 0::2].set(jnp.sin(pos_emb))
    embs = embs.at[0, :, 1::2].set(jnp.cos(pos_emb))
    return SinusoidalPositionEmbeddings(weight=embs)

  def __call__(self, x):
    return self.weight

  @staticmethod
  def _flatten(self):
    return ((),  # these are trainable
            (self.weight,))  # these are not trainable; tag-along data

  @staticmethod
  def _unflatten(aux_data, flat_contents):
    weight, = aux_data
    return SinusoidalPositionEmbeddings(weight=weight)

jax.tree_util.register_pytree_node(SinusoidalPositionEmbeddings,
                                   SinusoidalPositionEmbeddings._flatten,
                                   SinusoidalPositionEmbeddings._unflatten)


@dataclass
class Linear:
  weight: jax.Array
  bias: jax.Array

  @staticmethod
  def create(key, in_dim, out_dim, sigma=0.02):
    key, key_w = jax.random.split(key, 2)
    #sigma = np.sqrt(2 / in_dim)
    weight = sigma * jax.random.normal(key_w, (in_dim, out_dim), dtype=w_dtype)
    bias = jnp.zeros((out_dim,), dtype=w_dtype)
    return Linear(weight=weight, bias=bias)

  def __call__(self, x):
    return x @ self.weight + self.bias


@dataclass
class MLP:
  layers: list[Linear]

  @staticmethod
  def create(key, layer_widths, sigmas=None):
    key, *keys = jax.random.split(key, len(layer_widths))
    if sigmas is None:
      sigmas = [0.02 for _ in layer_widths[1:]]
      layers = [
          Linear.create(key=k, in_dim=in_dim, out_dim=out_dim, sigma=sigma)
          for k, in_dim, out_dim, sigma
          in zip(keys, layer_widths[:-1], layer_widths[1:], sigmas)
      ]
      return MLP(layers=layers)

  def __call__(self, x):
    for i, layer in enumerate(self.layers):
      x = layer(x)
      if i < len(self.layers) - 1:
        x = jax.nn.gelu(x)
    return x


@dataclass
class MultiHeadSelfAttention:
  to_q: Linear
  to_k: Linear
  to_v: Linear
  to_out: Linear
  heads: int
  scale: float
  causal: bool

  @staticmethod
  def create(key, dim, heads, causal=True, to_out_sigma=0.02):
    assert dim % heads == 0
    key, *keys = jax.random.split(key, 5)
    to_q = Linear.create(key=keys[0], in_dim=dim, out_dim=dim)
    to_k = Linear.create(key=keys[1], in_dim=dim, out_dim=dim)
    to_v = Linear.create(key=keys[2], in_dim=dim, out_dim=dim)
    to_out = Linear.create(key=keys[3], in_dim=dim, out_dim=dim, sigma=to_out_sigma)
    return MultiHeadSelfAttention(
        to_q=to_q, to_k=to_k, to_v=to_v, to_out=to_out,
        heads=heads,
        scale=dim ** -0.5,
        causal=causal)

  def __call__(self, x):
    q = self.to_q(x)
    k = self.to_k(x)
    v = self.to_v(x)
    q = rearrange(q, 'b i (h d) -> b h i d', h=self.heads)
    k = rearrange(k, 'b j (h d) -> b h j d', h=self.heads)
    v = rearrange(v, 'b j (h d) -> b h j d', h=self.heads)

    dots = self.scale * jnp.einsum('b h i d, b h j d -> b h i j', q, k)
    j = dots.shape[-1]
    mask = jnp.ones((j, j))
    if self.causal:
      mask = jnp.tril(mask, k=0)
      dots = jnp.where(mask, dots, large_negative_num(dots.dtype))
    dots = dots.astype(jnp.float32)
    attn = jax.nn.softmax(dots, axis=-1)
    outs = jnp.einsum('b h i j, b h j d -> b h i d', attn, v)
    outs = rearrange(outs, 'b h i d -> b i (h d)')
    out = self.to_out(outs)
    return out

  @staticmethod
  def _flatten(self):
    return (
        [self.to_q, self.to_k, self.to_v, self.to_out],  # trainable
        (self.heads, self.scale, self.causal),  # non-trainable
    )

  @staticmethod
  def _unflatten(aux_data, flat_contents):
    heads, scale, causal = aux_data
    to_q, to_k, to_v, to_out = flat_contents
    return MultiHeadSelfAttention(
        to_q=to_q, to_k=to_k, to_v=to_v, to_out=to_out,
        heads=heads, scale=scale, causal=causal)

jax.tree_util.register_pytree_node(MultiHeadSelfAttention,
                                   MultiHeadSelfAttention._flatten,
                                   MultiHeadSelfAttention._unflatten)


@dataclass
class LayerNorm:
  gamma: jax.Array  # shape []
  beta: jax.Array  # shape []

  @staticmethod
  def create():
    gamma = jnp.array([1.0])
    beta = jnp.array([0.0])
    return LayerNorm(gamma=gamma, beta=beta)

  def __call__(self, x, mask):
    reduce_axes = [1, 2]
    mu = jnp.mean(x, axis=reduce_axes, where=mask[:, :, None])
    sigma = jnp.var(x, axis=reduce_axes, where=mask[:, :, None])
    x -= jnp.expand_dims(mu, reduce_axes)
    x /= jnp.sqrt(jnp.expand_dims(sigma, reduce_axes) + 1e-5)
    x *= self.gamma
    x += self.beta
    return x


@dataclass
class TransformerBlock:
  mhsa: MultiHeadSelfAttention
  mlp: MLP
  ln1: LayerNorm
  ln2: LayerNorm

  @staticmethod
  def create(key, dim, heads, causal=True, to_out_sigma=0.02):
    key, *keys = jax.random.split(key, 3)
    mhsa = MultiHeadSelfAttention.create(key=keys[0], dim=dim, heads=heads,
                                         causal=causal, to_out_sigma=to_out_sigma)
    mlp = MLP.create(key=keys[1], layer_widths=[dim, 4*dim, dim])
    ln1 = LayerNorm.create()
    ln2 = LayerNorm.create()
    return TransformerBlock(mhsa=mhsa, mlp=mlp, ln1=ln1, ln2=ln2)

  def __call__(self, x, mask):
    x = x + self.mhsa(self.ln1(x, mask))
    x = x + self.mlp(self.ln2(x, mask))
    return x


@dataclass
class DecoderLM:
  embeddings: Embeddings
  pos_embs: SinusoidalPositionEmbeddings
  layers: list[TransformerBlock]

  @staticmethod
  def create(key, vocab_size, num_layers, dim, heads, max_seq_len, causal=True):
    key_emb, *keys = jax.random.split(key, 1 + num_layers)
    embeddings = Embeddings.create(key=key_emb, vocab_size=vocab_size, dim=dim)
    pos_embs = SinusoidalPositionEmbeddings.create(max_seq_len=max_seq_len, dim=dim)
    to_out_sigma = 0.02 / jnp.sqrt(2 * num_layers)
    layers = [
        TransformerBlock.create(key=key, dim=dim, heads=heads, causal=causal,
                                to_out_sigma=to_out_sigma)
        for key in keys
    ]
    return DecoderLM(embeddings=embeddings, pos_embs=pos_embs, layers=layers)

  def __call__(self, ids, mask):
    x = self.embeddings(ids)
    x = x + self.pos_embs(x)
    for layer in self.layers:
      x = layer(x, mask)
    logits = jnp.einsum('b t d, v d -> b t v', x, self.embeddings.weight) # weight-tied
    logits = jax.nn.log_softmax(logits, axis=-1)
    return logits

## test Embeddings

In [None]:
e = Embeddings.create(key, 10, 4)
e.weight

DeviceArray([[ 0.03750895,  0.00549101,  0.00389322,  0.00697418],
             [-0.01988814, -0.02088845,  0.00778638,  0.0374402 ],
             [ 0.00123927, -0.02986206,  0.01526473, -0.00770232],
             [ 0.01272262,  0.02110806, -0.01728175,  0.02800979],
             [ 0.02787162,  0.01041684, -0.00965307, -0.01007956],
             [ 0.01769945,  0.02692778, -0.00219425, -0.01662124],
             [ 0.00257271,  0.02842328,  0.01239705, -0.01502093],
             [ 0.0147513 ,  0.0002176 ,  0.01646794, -0.00191767],
             [-0.00169041, -0.02123914, -0.00037333,  0.00105569],
             [-0.00892678, -0.0027822 ,  0.00409518,  0.00156941]],            dtype=float32)

In [None]:
ids = jnp.array([[1, 2, 3], [4, 5, 6]])
e(ids)

DeviceArray([[[-0.01988814, -0.02088845,  0.00778638,  0.0374402 ],
              [ 0.00123927, -0.02986206,  0.01526473, -0.00770232],
              [ 0.01272262,  0.02110806, -0.01728175,  0.02800979]],

             [[ 0.02787162,  0.01041684, -0.00965307, -0.01007956],
              [ 0.01769945,  0.02692778, -0.00219425, -0.01662124],
              [ 0.00257271,  0.02842328,  0.01239705, -0.01502093]]],            dtype=float32)

In [None]:
def loss_embeddings(e, ids):
  return e(ids).mean()

jax.value_and_grad(loss_embeddings)(e, ids)

(DeviceArray(0.00419533, dtype=float32),
 Embeddings(weight=DeviceArray([[0.        , 0.        , 0.        , 0.        ],
              [0.04166667, 0.04166667, 0.04166667, 0.04166667],
              [0.04166667, 0.04166667, 0.04166667, 0.04166667],
              [0.04166667, 0.04166667, 0.04166667, 0.04166667],
              [0.04166667, 0.04166667, 0.04166667, 0.04166667],
              [0.04166667, 0.04166667, 0.04166667, 0.04166667],
              [0.04166667, 0.04166667, 0.04166667, 0.04166667],
              [0.        , 0.        , 0.        , 0.        ],
              [0.        , 0.        , 0.        , 0.        ],
              [0.        , 0.        , 0.        , 0.        ]],            dtype=float32)))

## test PositionEmbeddings

In [None]:
sin_embs = SinusoidalPositionEmbeddings.create(128, 6)
sin_embs.weight

DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ,
                0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.04639928,  0.998923  ,
                0.00215443,  0.9999977 ],
              [ 0.9092974 , -0.41614684,  0.09269861,  0.9956942 ,
                0.00430886,  0.9999907 ],
              [ 0.14112003, -0.9899925 ,  0.13879827,  0.9903207 ,
                0.00646326,  0.99997914],
              [-0.7568025 , -0.6536436 ,  0.18459894,  0.98281395,
                0.00861763,  0.99996287],
              [-0.9589243 ,  0.2836622 ,  0.23000199,  0.9731902 ,
                0.01077197,  0.999942  ],
              [-0.27941552,  0.9601703 ,  0.2749096 ,  0.9614701 ,
                0.01292625,  0.99991643],
              [ 0.6569866 ,  0.75390226,  0.319225  ,  0.947679  ,
                0.01508048,  0.9998863 ],
              [ 0.98935825, -0.14550005,  0.3628528 ,  0.9318465 ,
                0.01723463,  0.99985147],
              [ 0.4

In [None]:
def loss_sin_embs(sin_embs, ids):
  return -sin_embs(ids).mean()

jax.value_and_grad(loss_sin_embs)(sin_embs, ids)

(DeviceArray(-0.18322927, dtype=float32),
 SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ,
                 0.        ,  1.        ],
               [ 0.84147096,  0.5403023 ,  0.04639928,  0.998923  ,
                 0.00215443,  0.9999977 ],
               [ 0.9092974 , -0.41614684,  0.09269861,  0.9956942 ,
                 0.00430886,  0.9999907 ],
               [ 0.14112003, -0.9899925 ,  0.13879827,  0.9903207 ,
                 0.00646326,  0.99997914],
               [-0.7568025 , -0.6536436 ,  0.18459894,  0.98281395,
                 0.00861763,  0.99996287],
               [-0.9589243 ,  0.2836622 ,  0.23000199,  0.9731902 ,
                 0.01077197,  0.999942  ],
               [-0.27941552,  0.9601703 ,  0.2749096 ,  0.9614701 ,
                 0.01292625,  0.99991643],
               [ 0.6569866 ,  0.75390226,  0.319225  ,  0.947679  ,
                 0.01508048,  0.9998863 ],
               [ 0.98935825, -0.14

## test MultiHeadSelfAttention

In [None]:
mhsa = MultiHeadSelfAttention.create(key, dim=4, heads=2, causal=True)

In [None]:
mhsa

MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray([[-0.00089805, -0.01187019, -0.02608371, -0.00651822],
             [ 0.02117782,  0.01724769,  0.03117264, -0.01532539],
             [ 0.01588214,  0.00171302, -0.03031087,  0.01194714],
             [-0.00190441, -0.00632263, -0.00731096,  0.01891298]],            dtype=float32), bias=DeviceArray([0., 0., 0., 0.], dtype=float32)), to_k=Linear(weight=DeviceArray([[ 0.00914821,  0.01309405, -0.00554321, -0.01999797],
             [-0.01012122,  0.02966917, -0.02834937, -0.02397629],
             [ 0.00108852,  0.02056246,  0.00646167, -0.01054484],
             [-0.02017942, -0.00302269,  0.01521268,  0.00180503]],            dtype=float32), bias=DeviceArray([0., 0., 0., 0.], dtype=float32)), to_v=Linear(weight=DeviceArray([[ 0.01202872,  0.01929402, -0.02928475,  0.00571518],
             [ 0.02516833, -0.0003231 , -0.01465918,  0.03156444],
             [ 0.00365178, -0.00108856,  0.0004419 , -0.02173013],
             [ 0.033257

In [None]:
x = jax.random.normal(key, (2, 5, 4))
mask = jnp.ones((2, 5), dtype=jnp.int32)
x.shape, x.mean(), x.std()

((2, 5, 4),
 DeviceArray(-0.02352496, dtype=float32),
 DeviceArray(0.88507974, dtype=float32))

In [None]:
out = mhsa(x)
out.shape, out.mean(), out.std()

((2, 5, 4),
 DeviceArray(-0.00019259, dtype=float32),
 DeviceArray(0.00080619, dtype=float32))

In [None]:
out

DeviceArray([[[ 3.5223365e-04, -1.2261616e-03,  2.6063018e-03,
                1.1798181e-04],
              [ 4.3814257e-04, -4.4755638e-04,  8.4447488e-04,
                6.5770186e-04],
              [-6.5640733e-04, -1.1401251e-05,  3.1622685e-04,
               -7.8875409e-04],
              [-5.0159823e-04,  2.1421816e-04, -1.5031733e-04,
               -4.5840588e-04],
              [-5.8657490e-04,  2.7089007e-04, -2.0534918e-04,
               -5.0072849e-04]],

             [[-1.4229901e-03,  1.4643967e-03, -2.2975169e-03,
               -1.6099121e-03],
              [-6.9512241e-04,  4.2846799e-04, -5.1409192e-04,
               -5.7392474e-04],
              [-5.9274631e-04,  9.6970703e-05,  6.2740408e-05,
               -6.1600597e-04],
              [-5.1280484e-04,  3.0975789e-04, -3.7076324e-04,
               -5.4342020e-04],
              [-2.6065041e-04,  2.2208970e-04, -3.2328092e-04,
               -2.3981044e-04]]], dtype=float32)

In [None]:
@jax.jit
def mhsa_jitted(x):
  return mhsa(x)

mhsa_jitted(x)

DeviceArray([[[ 3.5223365e-04, -1.2261616e-03,  2.6063018e-03,
                1.1798181e-04],
              [ 4.3814257e-04, -4.4755638e-04,  8.4447488e-04,
                6.5770186e-04],
              [-6.5640733e-04, -1.1401251e-05,  3.1622685e-04,
               -7.8875409e-04],
              [-5.0159823e-04,  2.1421816e-04, -1.5031733e-04,
               -4.5840588e-04],
              [-5.8657490e-04,  2.7089007e-04, -2.0534918e-04,
               -5.0072849e-04]],

             [[-1.4229901e-03,  1.4643967e-03, -2.2975169e-03,
               -1.6099121e-03],
              [-6.9512241e-04,  4.2846799e-04, -5.1409192e-04,
               -5.7392474e-04],
              [-5.9274631e-04,  9.6970703e-05,  6.2740408e-05,
               -6.1600597e-04],
              [-5.1280484e-04,  3.0975789e-04, -3.7076324e-04,
               -5.4342020e-04],
              [-2.6065041e-04,  2.2208970e-04, -3.2328092e-04,
               -2.3981044e-04]]], dtype=float32)

In [None]:
x2 = jax.random.normal(key, (2, 6, 4))
mhsa_jitted(x2)

DeviceArray([[[ 9.2355162e-04, -1.0871887e-04, -7.6125190e-04,
                8.7211886e-04],
              [-3.3225864e-05,  1.4514849e-04, -3.4664199e-04,
                4.0661893e-04],
              [-1.5203953e-03,  6.6145509e-04, -7.1850792e-04,
               -1.7885081e-03],
              [-1.0284185e-03,  5.4180622e-04, -7.4449927e-04,
               -1.3139276e-03],
              [-6.1343610e-04,  1.6504666e-04, -1.3730675e-04,
               -8.9899823e-04],
              [-1.3208017e-04, -3.8671307e-05,  8.7017193e-05,
               -2.6368489e-04]],

             [[ 3.9270520e-03, -1.8672496e-03,  1.8774718e-03,
                4.5844503e-03],
              [ 1.5179217e-03, -3.8094819e-04, -3.2663345e-05,
                1.7614253e-03],
              [ 2.3819357e-03, -1.1696666e-03,  1.2358502e-03,
                2.6842095e-03],
              [ 1.2893528e-03, -3.9525330e-04,  1.6123429e-04,
                1.4953753e-03],
              [ 6.0330518e-04, -1.0314491e-04, -

In [None]:
def loss_mhsa(mhsa, x):
  out = mhsa(x)
  return out[:, 1:, :].mean()

jax.value_and_grad(loss_mhsa)(mhsa, x)

(DeviceArray(-0.00017775, dtype=float32),
 MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray([[ 6.46040462e-07,  6.88114596e-07, -3.13149940e-07,
               -1.47636456e-07],
              [ 2.74369086e-07,  2.08567872e-06, -1.17850504e-07,
               -8.28302291e-07],
              [-9.40899554e-07,  1.79679773e-08, -5.03086085e-07,
               -6.06317371e-07],
              [ 1.98949238e-07, -2.33972969e-07, -6.26538167e-07,
               -4.72522515e-07]], dtype=float32), bias=DeviceArray([-1.5929963e-06, -2.9211435e-07, -2.1658570e-06,
              -2.4130231e-06], dtype=float32)), to_k=Linear(weight=DeviceArray([[-5.2903907e-08, -1.5833085e-07,  1.1134828e-06,
               -6.1250307e-07],
              [ 1.0833571e-06,  8.3119915e-07, -1.6174517e-06,
                9.9819181e-07],
              [ 6.1145408e-07,  2.9176141e-07,  1.2890612e-06,
               -6.8019631e-07],
              [-3.6285272e-08, -1.8763001e-07, -1.5651433e-06,
                1.00162

## test TransformerBlock

In [None]:
t = TransformerBlock.create(key, dim=4, heads=2)

t_out = t(x, mask)
t_out.shape, t_out.mean(), t_out.std()

((2, 5, 4),
 DeviceArray(-0.02326439, dtype=float32),
 DeviceArray(0.88490736, dtype=float32))

In [None]:
def loss_transformer_block(t, x, mask):
  return t(x, mask).mean()

jax.value_and_grad(loss_transformer_block)(t, x, mask)

(DeviceArray(-0.02326439, dtype=float32),
 TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray([[-5.4389147e-08,  1.7117921e-07, -8.4261410e-07,
               -1.5639525e-06],
              [-1.0106828e-07,  2.1043543e-07, -1.4797588e-06,
               -3.3825381e-06],
              [-1.4524369e-07,  1.4143134e-07,  4.7742651e-07,
                2.2085278e-06],
              [ 1.6177825e-07, -2.3961655e-07, -1.0833901e-06,
                1.7149414e-06]], dtype=float32), bias=DeviceArray([ 5.2197697e-07, -8.2022586e-07, -8.9804644e-07,
               8.9167024e-06], dtype=float32)), to_k=Linear(weight=DeviceArray([[-1.2429228e-07,  1.5513075e-07,  2.0687435e-07,
                9.2135679e-08],
              [ 2.3636471e-07,  4.4608021e-07,  1.3423653e-06,
               -7.8048879e-07],
              [-4.8289792e-08,  2.3543703e-07,  1.3293975e-06,
               -2.2027643e-07],
              [-5.4716995e-08, -4.0551726e-07, -4.7919423e-07,
                6

## test DecoderLM

In [None]:
d = DecoderLM.create(key, vocab_size=3, num_layers=2, dim=4, heads=2, max_seq_len=3)

In [None]:
pprint(jax.tree_map(lambda p: jnp.shape(p), d))

DecoderLM(embeddings=Embeddings(weight=(3, 4)),
          pos_embs=SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32)),
          layers=[TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=(4,
                                                                                   4),
                                                                           bias=(4,)),
                                                               to_k=Linear(weight=(4,
                                                                                   4),
                                                                           bias=(4,)),
                                                               to_v=Linear(weight=(4,
                                                         

In [None]:
pprint(jax.tree_map(lambda p: jnp.std(p), d))

DecoderLM(embeddings=Embeddings(weight=DeviceArray(0.01798398, dtype=float32)),
          pos_embs=SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32)),
          layers=[TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray(0.01451548, dtype=float32),
                                                                           bias=DeviceArray(0., dtype=float32)),
                                                               to_k=Linear(weight=DeviceArray(0.02478699, dtype=float32),
                                                                           bias=DeviceArray(0., dtype=float32)),
                                                               to_v=Linear(weight=DeviceArray(0.02526536, dtype=float32),
                                       

In [None]:
ids_mask = jnp.ones_like(ids, dtype=jnp.int32)
d_out = d(ids, ids_mask)
d_out.shape, d_out.mean(), d_out.std()

((2, 3, 3),
 DeviceArray(-1.0991278, dtype=float32),
 DeviceArray(0.03028868, dtype=float32))

In [None]:
jnp.exp(d_out)  # check that exponentiating logits yields probs

DeviceArray([[[0.346286  , 0.32860494, 0.32505128],
              [0.34975418, 0.32043424, 0.32974482],
              [0.34419867, 0.3234423 , 0.3323133 ]],

             [[0.3465706 , 0.3280305 , 0.32534248],
              [0.34975418, 0.32043424, 0.32974482],
              [0.34419867, 0.3234423 , 0.3323133 ]]], dtype=float32)

In [None]:
def loss_decoder_lm(d, ids, mask):
  return 100 * d(ids, mask).mean()

jax.value_and_grad(loss_decoder_lm)(d, ids, ids_mask)

(DeviceArray(-109.91278, dtype=float32),
 DecoderLM(embeddings=Embeddings(weight=DeviceArray([[-0.8196529 , -0.5671234 , -0.03178539, -1.3359756 ],
              [ 0.67937213,  0.24673419,  0.02402913,  0.9102438 ],
              [ 0.1238312 ,  0.30758926,  0.00296418,  0.41365167]],            dtype=float32)), pos_embs=SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
               [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
               [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32)), layers=[TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray([[ 7.5875846e-08,  1.5500518e-07, -4.7552930e-08,
               -9.7359589e-08],
              [-1.4707106e-07, -2.3804250e-07,  6.5336053e-08,
                9.5900759e-08],
              [-7.9839424e-08, -1.6917284e-07,  5.2647238e-08,
                1.1143129e-07],
              [ 8.5168892e-08,  1.7968568e-07, -5.

In [None]:
grads = jax.grad(loss_decoder_lm)(d, ids, ids_mask)
d2 = jax.tree_map(lambda p, g: p - 1 * g, d, grads)  # take a step and compare d2 against d

In [None]:
d.pos_embs

SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32))

In [None]:
d2.pos_embs  # pos_embs should be frozen and same as d

SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32))

In [None]:
# trainable params should have been moved by gradient
pprint(jax.tree_map(lambda d1, d2: jnp.abs(d2 - d1).sum(), d, d2))

DecoderLM(embeddings=Embeddings(weight=DeviceArray(5.462953, dtype=float32)),
          pos_embs=SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32)),
          layers=[TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray(1.7719576e-06, dtype=float32),
                                                                           bias=DeviceArray(4.520792e-07, dtype=float32)),
                                                               to_k=Linear(weight=DeviceArray(4.7171488e-07, dtype=float32),
                                                                           bias=DeviceArray(6.320988e-11, dtype=float32)),
                                                               to_v=Linear(weight=DeviceArray(0.00270491, dtype=float32),
               

In [None]:
grads.embeddings

Embeddings(weight=DeviceArray([[-0.8196529 , -0.5671234 , -0.03178539, -1.3359756 ],
             [ 0.67937213,  0.24673419,  0.02402913,  0.9102438 ],
             [ 0.1238312 ,  0.30758926,  0.00296418,  0.41365167]],            dtype=float32))

In [None]:
d.embeddings

Embeddings(weight=DeviceArray([[ 0.03777524,  0.01509636,  0.01720601,  0.02478568],
             [-0.0172825 , -0.01549674,  0.00084439,  0.0017051 ],
             [ 0.02888897, -0.01042003,  0.01671507, -0.01324442]],            dtype=float32))

In [None]:
d2.embeddings

Embeddings(weight=DeviceArray([[ 0.85742813,  0.5822198 ,  0.0489914 ,  1.3607613 ],
             [-0.6966546 , -0.26223093, -0.02318473, -0.9085387 ],
             [-0.09494223, -0.3180093 ,  0.01375089, -0.4268961 ]],            dtype=float32))

## test LayerNorm

In [None]:
biased_x = 3 * jax.random.normal(key, (2, 3, 2)) + 7
biased_x.shape, biased_x.mean(), biased_x.std()

((2, 3, 2),
 DeviceArray(6.656968, dtype=float32),
 DeviceArray(2.634112, dtype=float32))

In [None]:
mask = jnp.ones((2, 3), dtype=jnp.int32)
#mask = mask.at[:, -1:].set(0)
ln = LayerNorm.create()
debiased_x = ln(biased_x, mask)
debiased_x.shape, debiased_x.mean(), debiased_x.std()

((2, 3, 2),
 DeviceArray(-9.934108e-08, dtype=float32),
 DeviceArray(0.9999989, dtype=float32))

# Optimizers

In [None]:
@dataclass
class SGDOptimizer:
  lr: float

  @staticmethod
  def create(params, lr=0.001):
    return SGDOptimizer(lr=lr)

  def update(self, params, grads, lr=None):
    if lr is None:
      lr = self.lr

    def apply_grad(p, g):
      return p - lr * g
    return jax.tree_map(apply_grad, params, grads)

  @staticmethod
  def _flatten(self):
    return ((), (self.lr,))

  @staticmethod
  def _unflatten(aux_data, flat_contents):
    lr, = aux_data
    return SGDOptimizer(lr=lr)

jax.tree_util.register_pytree_node(SGDOptimizer,
                                   SGDOptimizer._flatten,
                                   SGDOptimizer._unflatten)


@dataclass
class AdamOptimizer:
  m: Any  # pytree
  v: Any  # pytree
  lr: jax.Array  # shape [1]
  betas: jax.Array # shape [2]
  eps: jax.Array  # shape [1]

  @staticmethod
  def create(params, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
    m = jax.tree_map(lambda x: jnp.zeros_like(x), params)
    v = jax.tree_map(lambda x: jnp.zeros_like(x), params)
    return AdamOptimizer(m=m, v=v,
                         lr=jnp.array([lr], dtype=jnp.float32),
                         betas=jnp.array(betas, dtype=jnp.float32),
                         eps=jnp.array([eps], dtype=jnp.float32))

  def update(self, params, grads, lr=None):
    if lr is None:
      lr = self.lr

    def _update_m(m, g):
      return self.betas[0] * m + (1 - self.betas[0]) * g
    def _update_v(v, g):
      return self.betas[1] * v + (1 - self.betas[1]) * g**2
    self.m = jax.tree_map(_update_m, self.m, grads)
    self.v = jax.tree_map(_update_v, self.v, grads)

    def _compute_m_hat(m):
      return m / (1 - self.betas[0])
    def _compute_v_hat(v):
      return v / (1 - self.betas[1])
    m_hat = jax.tree_map(_compute_m_hat, self.m)
    v_hat = jax.tree_map(_compute_v_hat, self.v)

    def _update_params(p, m_hat, v_hat):
      return p - lr * m_hat / (v_hat ** 0.5 + self.eps)
    params_new = jax.tree_map(_update_params, params, m_hat, v_hat)
    return params_new

  @staticmethod
  def _flatten(self):
    return ((self.m, self.v),  # trainable
            (self.lr, self.betas, self.eps))  # non-trainable

  @staticmethod
  def _unflatten(aux_data, flat_contents):
    m, v = flat_contents
    lr, betas, eps = aux_data
    return AdamOptimizer(m=m, v=v, lr=lr, betas=betas, eps=eps)

jax.tree_util.register_pytree_node(AdamOptimizer, AdamOptimizer._flatten,
                                   AdamOptimizer._unflatten)


@dataclass
class CosineDecayLR:
  optimizer: Any  # *Optimizer
  lr: jax.Array  # shape [1]
  total_steps: int
  warmup_steps: int
  curr_lr: jax.Array = field(default_factory=lambda: jnp.array([0.], dtype=jnp.float32))
  curr_step: jax.Array = field(default_factory=lambda: jnp.array([0], dtype=jnp.int32))

  def update(self, params, grads):
    self.step()
    return self.optimizer.update(params, grads, self.curr_lr)

  def step(self):
    self.curr_step = jnp.minimum(self.curr_step + 1, self.total_steps)
    lr_during_ramp = self.lr * self.curr_step / self.warmup_steps
    decay_frac = (self.curr_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
    lr_during_decay = self.lr * 0.5 * (jnp.cos(decay_frac * jnp.pi) + 1)
    in_ramp = self.curr_step < self.warmup_steps
    self.curr_lr = in_ramp * lr_during_ramp + (1 - in_ramp) * lr_during_decay


def clip_grad_norm(grads: Any, max_norm: float):
  norms = jax.tree_map(jnp.linalg.norm, grads)
  def clip(x, norm):
    return jnp.where(norm <= max_norm, x, x * max_norm / (norm + 1e-6))
  return jax.tree_map(clip, grads, norms)

## test SGD

In [None]:
val, grads = jax.value_and_grad(loss_decoder_lm)(d, ids, ids_mask)

In [None]:
grads.embeddings

Embeddings(weight=DeviceArray([[-0.8196529 , -0.5671234 , -0.03178539, -1.3359756 ],
             [ 0.67937213,  0.24673419,  0.02402913,  0.9102438 ],
             [ 0.1238312 ,  0.30758926,  0.00296418,  0.41365167]],            dtype=float32))

In [None]:
d.embeddings

Embeddings(weight=DeviceArray([[ 0.03777524,  0.01509636,  0.01720601,  0.02478568],
             [-0.0172825 , -0.01549674,  0.00084439,  0.0017051 ],
             [ 0.02888897, -0.01042003,  0.01671507, -0.01324442]],            dtype=float32))

In [None]:
sgd_test = SGDOptimizer.create(params=d, lr=0.1)
d_new = jax.jit(sgd_test.update)(d, grads)
d_new.embeddings

Embeddings(weight=DeviceArray([[ 0.11974053,  0.0718087 ,  0.02038455,  0.15838325],
             [-0.08521973, -0.04017016, -0.00155852, -0.08931928],
             [ 0.01650585, -0.04117896,  0.01641865, -0.05460959]],            dtype=float32))

## test Adam

In [None]:
adam_test = AdamOptimizer.create(params=d, lr=0.1)
d_new = jax.jit(adam_test.update)(d, grads)
d_new.embeddings

Embeddings(weight=DeviceArray([[ 0.13777524,  0.11509636,  0.11720599,  0.12478568],
             [-0.11728252, -0.11549675, -0.09915557, -0.09829491],
             [-0.07111105, -0.11042003, -0.08328459, -0.11324443]],            dtype=float32))

## test CosineDecayLR

In [None]:
#@jax.jit
def test_cosine_decay_lr():
  cd = CosineDecayLR(optimizer=None, lr=1.0, total_steps=110, warmup_steps=10)
  for i in range(15):
    print(i, cd.curr_lr)
    cd.step()

  for i in range(15, 100):
    cd.step()

  for i in range(100, 111):
    print(i, cd.curr_lr)
    cd.step()

In [None]:
test_cosine_decay_lr()

0 [0.]
1 [0.1]
2 [0.2]
3 [0.3]
4 [0.4]
5 [0.5]
6 [0.6]
7 [0.7]
8 [0.8]
9 [0.90000004]
10 [1.]
11 [0.99975324]
12 [0.99901336]
13 [0.997781]
14 [0.9960574]
100 [0.02447173]
101 [0.01985314]
102 [0.01570839]
103 [0.0120416]
104 [0.00885636]
105 [0.00615582]
106 [0.00394264]
107 [0.00221902]
108 [0.00098664]
109 [0.00024673]
110 [0.]


## test ClipGradNorm

In [None]:
jax.tree_map(jnp.linalg.norm, grads)

DecoderLM(embeddings=Embeddings(weight=DeviceArray(2.1004543, dtype=float32)), pos_embs=SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32)), layers=[TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray(4.8988517e-07, dtype=float32), bias=DeviceArray(2.4523015e-07, dtype=float32)), to_k=Linear(weight=DeviceArray(1.842191e-07, dtype=float32), bias=DeviceArray(4.678597e-11, dtype=float32)), to_v=Linear(weight=DeviceArray(0.00094585, dtype=float32), bias=DeviceArray(0.00064803, dtype=float32)), to_out=Linear(weight=DeviceArray(0.00346797, dtype=float32), bias=DeviceArray(0.07854329, dtype=float32)), heads=2, scale=0.5, causal=True), mlp=MLP(layers=[Linear(weight=DeviceArray(0.0056897, dtype=float32), bias=DeviceArray(0.00402003, dtype=float32)), Linear(w

In [None]:
grads_clipped = clip_grad_norm(grads, max_norm=0.01)

In [None]:
jax.tree_map(jnp.linalg.norm, grads_clipped)

DecoderLM(embeddings=Embeddings(weight=DeviceArray(0.01, dtype=float32)), pos_embs=SinusoidalPositionEmbeddings(weight=DeviceArray([[[ 0.        ,  1.        ,  0.        ,  1.        ],
              [ 0.84147096,  0.5403023 ,  0.00999987,  0.99995   ],
              [ 0.9092974 , -0.41614684,  0.01999873,  0.9998    ]]],            dtype=float32)), layers=[TransformerBlock(mhsa=MultiHeadSelfAttention(to_q=Linear(weight=DeviceArray(4.8988517e-07, dtype=float32), bias=DeviceArray(2.4523015e-07, dtype=float32)), to_k=Linear(weight=DeviceArray(1.842191e-07, dtype=float32), bias=DeviceArray(4.678597e-11, dtype=float32)), to_v=Linear(weight=DeviceArray(0.00094585, dtype=float32), bias=DeviceArray(0.00064803, dtype=float32)), to_out=Linear(weight=DeviceArray(0.00346797, dtype=float32), bias=DeviceArray(0.00999987, dtype=float32)), heads=2, scale=0.5, causal=True), mlp=MLP(layers=[Linear(weight=DeviceArray(0.0056897, dtype=float32), bias=DeviceArray(0.00402003, dtype=float32)), Linear(weight

# Decoding

In [None]:
# TODO: make this batched
# TODO: replace use_argmax with temperature
# TODO: make this jittable
def sample(model, prompt=None, prompt_ids=None, max_tokens=30, print_loop=False,
           use_argmax=False):
  assert prompt is not None or prompt_ids is not None
  if prompt:
    ids = jnp.array(tokenizer_encode(prompt).numpy())
  else:
    ids = jnp.array([prompt_ids])
  start_idx = ids.shape[1]
  for i in range(max_tokens):
    padded_ids = jnp.zeros((1, MAX_SEQ_LEN), dtype=jnp.int32)
    padded_ids = padded_ids.at[:, :ids.shape[1]].set(ids)
    mask = jnp.zeros_like(padded_ids)
    mask = mask.at[:, ids.shape[1]].set(1)
    logits = model(padded_ids, mask)  # [b, t, v]
    logits = logits[0, start_idx + i - 1, :]  # [v]
    probs = jnp.exp(logits).astype(jnp.float64)
    if use_argmax:
      token = jnp.argmax(logits, axis=-1)[None, None]
    else:  # sample from distribution
      probs /= jnp.sum(probs)  # w/o this, sum is ~1, but not exactly
      token = np.random.choice(np.arange(len(probs)), size=(1, 1), p=probs)
    ids = jnp.concatenate([ids, token], axis=1)
    if print_loop:
      print(tokenizer_decode(ids))
  output = tokenizer_decode(ids)
  return output

# Loss

In [None]:
@jax.jit
def compute_loss(model, batch):
  ids, ids_target, mask = batch
  attn_mask, targets_mask = mask[:, :-1], mask[:, 1:]
  logits = model(ids, attn_mask)
  logits = rearrange(logits, 'b t v -> (b t) v')
  ids_target = ids_target.flatten()
  targets_mask = targets_mask.flatten()
  example_losses = -logits[jnp.arange(len(ids_target)), ids_target]
  loss = jnp.sum(example_losses * targets_mask) / jnp.sum(targets_mask)
  return loss

## test compute_loss

In [None]:
model = DecoderLM.create(key, vocab_size=3, num_layers=2, dim=4, heads=2, max_seq_len=3)
ids_batch = np.array(
    [[0, 1, 2, 1],
     [2, 0, 1, 2]])
ids_input, ids_target = ids_batch[:, :-1], ids_batch[:, 1:]
mask = jnp.array([[1, 1, 1, 0], [1, 1, 1, 0]])

print(compute_loss(model, (ids_input, ids_target, mask)))

1.1060945


# Train loop (data-parallel)

In [None]:
import jax.sharding
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((N_DEVICES,))
sharding = jax.sharding.PositionalSharding(devices)

In [None]:
@jax.jit
def train_on_batch(model, batch, optimizer):
  loss, grads = jax.value_and_grad(compute_loss)(model, batch)
  #id_print(grads.layers[0].bias.shape, tap_with_device=True, device_index=0)
  #id_print(grads.layers[0].bias[0], tap_with_device=True, device_index=1)

  grads = clip_grad_norm(grads, max_norm=1)
  model = optimizer.update(model, grads)  # optimizer is mutated, so have to return it
  return model, optimizer, loss


def train_parallel(model, train_dl, valid_dl, sharding, optimizer, steps,
                   n_decode=4, decode_every_steps=0):
  model = jax.device_put(model, sharding.replicate())
  optimizer = jax.device_put(optimizer, sharding.replicate())
  losses = []
  print('ground truth')
  for ids in one_batch_unique_ids[:n_decode]:
    print(tokenizer_decode(ids)[0])
  print('-' * 80)

  try:
    for i, batch in enumerate(train_dl):
      batch = jax.device_put(batch, sharding.reshape(N_DEVICES, 1))
      model, optimizer, loss = train_on_batch(model, batch, optimizer)
      if i % 10 == 0:
        losses.append(loss)
        print(f'step {i} lr: {optimizer.curr_lr[0]:0.6g} loss: {loss}')

      if decode_every_steps > 0 and i % decode_every_steps == 0:
        for ids in one_batch_unique_ids[:n_decode]:
          ids = ids[:5]
          text = sample(model, prompt_ids=ids, max_tokens=15)[0]
          print(text)

      if i == steps:
        break
  except KeyboardInterrupt:
    pass
  return model, optimizer, losses

# Train

In [None]:
jax.config.update('jax_default_matmul_precision', 'bfloat16')
#jax.config.update('jax_default_matmul_precision', 'float32')

In [None]:
model = DecoderLM.create(key, vocab_size=VOCAB_SIZE, num_layers=12, dim=768,
                         heads=12, max_seq_len=MAX_SEQ_LEN)

In [None]:
LR = 3e-4
TOTAL_STEPS = 100_000
WARMUP_STEPS = 2000
DECODE_EVERY_STEPS = 1000

#optimizer = SGDOptimizer.create(params=model, lr=LR)
optimizer = AdamOptimizer.create(params=model, lr=LR)
optimizer = CosineDecayLR(optimizer=optimizer,
                          lr=jnp.array(LR),  # Pathways requires everything to be wrapped as jax.Array
                          total_steps=jnp.array(TOTAL_STEPS),
                          warmup_steps=jnp.array(WARMUP_STEPS))

In [None]:
# Profile to see check TPU utilization and see if we're input-bound
#%xprof model, optimizer, losses = train_parallel(model, dataloader_train, None, sharding, optimizer, steps=200)

In [None]:
model, optimizer, losses = train_parallel(
    model, dataloader_train, None, sharding, optimizer,
    steps=TOTAL_STEPS, decode_every_steps=DECODE_EVERY_STEPS,
)

# Run model on examples

In [None]:
ex_id = 2
one_batch_unique_ids[ex_id, ...], tokenizer_decode(one_batch_unique_ids[ex_id, ...])

In [None]:
sample(model, prompt_ids=[1215, 15355, 1355, 23936, 286, 8099], print_loop=True)

# Debug

## Total device ram usage

In [None]:
import humanize

fmt_size = partial(humanize.naturalsize, binary=True)

def print_memory(device):
  stats = device.memory_stats()
  print(stats)
  used = stats['bytes_in_use']
  limit = stats['bytes_limit']
  print(f'Using {fmt_size(used)} / {fmt_size(limit)} ({used/miti:%}) on {device}')

In [None]:
print_memory(jax.local_devices()[0])

In [None]:
print_memory(jax.local_devices()[1])

## Clear memory

In [None]:
def jax_delete_live_arrays():
  for x in jax.live_arrays():
    x.delete()

if False:
  jax.clear_caches()
  jax_delete_live_arrays()

## Checkpointing

https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html

In [None]:
import jax.ad_checkpoint
jax.ad_checkpoint.print_saved_residuals(compute_loss, model, batch)

## print_fwd_bwd

In [None]:
from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html = _renderable_repr

In [None]:
print_fwd_bwd(compute_loss, model, batch)