In [1]:
# Automatically reload edited project modules (e.g. mesh_transformer) before each cell runs.
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression in a cell, not only the last one
# (cells below rely on this, e.g. `w = ...; w` followed by further statements).
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import argparse
import json
import time
import gc
import random
from functools import partial
import numpy as np
# import wandb
from tqdm import tqdm

import jax
import jax.numpy as jnp
from jax.experimental.maps import thread_resources
from jax.experimental.maps import Mesh
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P

import haiku as hk
import optax

from mesh_transformer import util
from mesh_transformer.layers import EmbeddingShard, TransformerLayerShard, RelativePositionEmbs, ProjectionShard, \
    TransformerLayerShardV2, Projection, EmbeddingShardV2
from mesh_transformer.checkpoint import read_ckpt, write_ckpt, write_ckpt_v2, load_ckpt_v2
from mesh_transformer.transformer_shard import CausalTransformerShard, CausalTransformer, CausalTransformerV2, CausalTransformerV2
from mesh_transformer.util import clip_by_global_norm, additive_weight_decay, to_f32, to_bf16, maybe_shard, head_print, global_norm
# from tfrecord_loader import TFRecordNewInputs
# from smart_open import open
# from google.cloud import storage
# from google.cloud.exceptions import NotFound

In [3]:
import flax
from flax.training import train_state
from flax.training.train_state import TrainState
from flax.core.frozen_dict import freeze

In [4]:
# Number of devices visible to JAX; used below to size the device mesh.
tpu_size = jax.device_count()
print(f"jax devices: {tpu_size}")

jax devices: 4


In [5]:
# Model-parallel degree: each replica's parameters are split over this many devices.
mp_per_host = cores_per_replica = 2
# Fail early with a readable message instead of an opaque numpy reshape error.
if tpu_size % cores_per_replica != 0:
    raise ValueError(
        f"device count {tpu_size} is not divisible by cores_per_replica={cores_per_replica}")
# 2-D logical mesh: axis 'dp' (data parallel) x axis 'mp' (model parallel).
mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)
cpu_device = jax.devices('cpu')[0]  # host CPU device; not referenced elsewhere in this notebook section
mesh = Mesh(devices, ('dp', 'mp'))

In [None]:
# Toy mesh for the pjit experiments below, with axes ordered ('mp', 'dp').
# The original hard-coded shape (4, 2) assumed exactly 8 devices; derive it from the
# actual device count instead. This still gives (4, 2) on 8 devices, keeps 'dp' at 2
# when the count is even, and falls back to dp=1 otherwise.
n_devices = jax.device_count()
toy_dp_size = 2 if n_devices % 2 == 0 else 1
mesh_shape = (n_devices // toy_dp_size, toy_dp_size)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = Mesh(devices, ('mp', 'dp'))

In [58]:
# Toy weight matrix, layout (in, out) = (8, 2).
w = np.arange(16).reshape(8, 2)
w
# Toy input batch, layout (batch, in) = (4, 8).
x = np.ones((4, 8))
x


def forward(x, w):
    """Bias-free dense layer: (batch, in) @ (in, out) -> (batch, out)."""
    return x @ w


forward(x, w)

array([[ 0,  1],
       [ 2,  3],
       [ 4,  5],
       [ 6,  7],
       [ 8,  9],
       [10, 11],
       [12, 13],
       [14, 15]])

array([[1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.]])

array([[56., 64.],
       [56., 64.],
       [56., 64.],
       [56., 64.]])

In [63]:
# Identity pjit used only to place w on the mesh: its first (in) axis is split over 'mp'
# and its second axis replicated. Expects w to be the (8, 2) toy matrix.
init_pjit = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('mp', None))
with mesh: w = init_pjit(w)

In [64]:
# x: batch axis over 'dp' and contracted (in) axis over 'mp'.
# w: contracted axis over 'mp', out axis replicated.
# out_axis_resources=None asks for y fully replicated.
forward_pjit = pjit(forward, in_axis_resources=(P('dp', 'mp'), P('mp', None)), out_axis_resources=None)
with mesh: y = forward_pjit(x, w)

In [42]:
# Second toy with the shapes transposed: w is (in, out) = (2, 8) and x is (batch, in) = (4, 2).
# Rebinds the global w and x used by the pjit cells below.
w = np.arange(2 * 8).reshape(2, 8); w  # io
x = np.ones((4, 2)); x  # bi
forward(x, w)

array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15]])

array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.]])

In [None]:
# Compound mesh-axis spec: a dimension sharded with this is split over 'mp' and 'dp' jointly.
parallel = ('mp', 'dp')

In [55]:
# Place w on the mesh with its second (out) axis split over both mesh axes jointly and its
# first axis replicated. Expects w to be the (2, 8) toy matrix.
init_pjit = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P(None, parallel))
with mesh: w = init_pjit(w)

In [56]:
# x: batch over 'dp', in axis replicated. w: out axis over ('mp', 'dp').
# The output requests its out axis over 'mp' only, which differs from w's layout,
# so pjit presumably inserts a reshard of the result.
forward_pjit = pjit(forward, in_axis_resources=(P('dp', None), P(None, parallel)),
                             out_axis_resources=P('dp', 'mp'))
with mesh: y = forward_pjit(x, w)

In [60]:
# Place w with its first (in) axis split over both mesh axes jointly.
# NOTE(review): this cell ran as In [60], i.e. after the (8, 2) toy cell (In [58]), so it
# expects w of shape (8, 2). In top-to-bottom order w is (2, 8) here, and splitting a
# size-2 axis over all mesh devices likely fails to divide evenly.
init_pjit = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P(parallel, None))
with mesh: w = init_pjit(w)

In [61]:
# Contracted axis split over 'mp' on x and over ('mp', 'dp') on w; y returned replicated.
# NOTE(review): depends on execution order (ran as In [61]). It expects x of shape (4, 8)
# and w of shape (8, 2) from the first toy cell, not the (4, 2)/(2, 8) arrays that precede
# it in top-to-bottom order.
forward_pjit = pjit(forward, in_axis_resources=(P('dp', 'mp'), P(parallel, None)), out_axis_resources=None)
with mesh: y = forward_pjit(x, w)

In [6]:
from typing import Any, Callable
from flax import core, struct

class TrainState(struct.PyTreeNode):
    """Minimal training-state pytree modelled on flax's ``TrainState``.

    ``step``, ``params`` and ``opt_state`` are pytree children (traced/sharded by
    jit/pjit); ``apply_fn`` and ``tx`` are static, non-pytree metadata.
    This class shadows the flax ``TrainState`` imported earlier.
    """
    # Field declaration order fixes the dataclass constructor and pytree flatten order.
    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any]
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState

    def apply_gradients(self, *, grads, **kwargs):
        """Apply one optimizer update and return a copy with ``step`` incremented.

        Extra ``kwargs`` are forwarded to ``replace`` to overwrite further fields.
        """
        tx = self.tx
        updates, opt_state_next = tx.update(grads, self.opt_state, self.params)
        # Cast updates to float32 before applying them (gradients are computed w.r.t. bf16
        # params in train_step, so they are presumably bf16). to_f32 is from mesh_transformer.util.
        params_next = optax.apply_updates(self.params, to_f32(updates))
        return self.replace(
            step=self.step + 1,
            params=params_next,
            opt_state=opt_state_next,
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, **kwargs):
        """Build the initial state at step 0, initialising ``tx``'s state from ``params``."""
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=tx.init(params),
            **kwargs,
        )

In [4]:
# Training hyper-parameters.
# `params` is kept as an alias of `config` because the next two lines read from it;
# it is rebound to the model parameters further down.
# A context manager closes the file (the original json.load(open(...)) leaked the handle).
with open('configs/example_config.json') as config_file:
    config = params = json.load(config_file)
per_replica_batch = params["per_replica_batch"]
# Rebinds the value set in the mesh cell; the mesh itself is not rebuilt here.
cores_per_replica = params["cores_per_replica"]

# Learning-rate schedule from mesh_transformer.util, built from the warmup/anneal settings.
scheduler = util.gpt3_schedule(config["warmup_steps"], config["anneal_steps"], config["lr"], config["end_lr"])
optimizer = optax.chain(
    # train_step sums gradients over microbatches; this turns the sum into a mean.
    optax.scale(1 / config.get("gradient_accumulation_steps", 1)),
    # Global-norm clipping at 1 without a cross-device psum.
    clip_by_global_norm(1, use_psum=False),
    optax.scale_by_adam(),
    additive_weight_decay(config["weight_decay"]),
    # Negate to obtain a descent direction, then scale by the learning-rate schedule.
    optax.scale(-1),
    optax.scale_by_schedule(scheduler)
)

In [6]:
# Stateful key stream seeded with 42; every next(key) yields a fresh PRNG key.
key = hk.PRNGSequence(42)
seq = config["seq"]
vocab = config["n_vocab"]
# (batch, seq) for a single example. train_example_shape has an extra leading axis,
# presumably the microbatch axis; it is not referenced below.
example_shape = (1, seq)
train_example_shape = (1, 1, seq)
# Random token ids in [0, vocab): uniform floats truncated to uint32, layout (batch, len).
x = jax.random.uniform(next(key), example_shape, minval=0, maxval=vocab).astype(jnp.uint32)

In [115]:
def init_params(key, x):
    """Initialise CausalTransformerShard parameters by tracing its loss once.

    ``x`` is passed as both the inputs and the targets. Returns the haiku
    parameter mapping produced by ``hk.transform(...).init``.
    """
    def _loss(inputs, targets):
        model = CausalTransformerShard(config)
        return model.loss(inputs, targets)

    transformed = hk.transform(hk.experimental.optimize_rng_use(_loss))
    return transformed.init(key, x, x)


params = init_params(next(key), x)

In [116]:
def train_loss(batch):
    """Return ``(loss, last_loss)`` for a dict of ``CausalTransformerShard.loss`` kwargs.

    Used with ``jax.value_and_grad(..., has_aux=True)``: ``loss`` is differentiated and
    ``last_loss`` is carried along as the auxiliary output.
    """
    result = CausalTransformerShard(config).loss(**batch, z_loss=True)
    return result["loss"], result["last_loss"]


# apply(params, batch) without an rng argument.
loss_apply_fn = hk.without_apply_rng(hk.transform(train_loss)).apply

In [117]:
# Abstract (shape/dtype only) evaluation of params plus freshly initialised optimizer state.
# NOTE(review): other cells call `init` with two arguments (key, x); this one-argument
# definition does not match them.
def init(params): return params, optimizer.init(params)
state_shapes = jax.eval_shape(init, params)
# One PartitionSpec per parameter, sharding each weight over both mesh axes jointly.
# NOTE(review): shard_strategy is defined in a later cell (In [15]); in top-to-bottom order
# this line raises NameError.
params_spec = jax.tree_map(partial(shard_strategy, parallel=("mp", "dp")), state_shapes[0])
params_spec = freeze(params_spec)
# params = freeze(params)

# Map every dict subtree (the params and, presumably, Adam's mu/nu trees) to the params
# spec, and every other leaf (step counters, EmptyState) to None, i.e. replicated.
# The map runs over the (params, opt_state) tuple, so the first result is the params spec
# again. This assumes haiku returns plain dicts (flax FrozenDict is not a dict subclass).
def get_opt_spec(x): return params_spec if isinstance(x, dict) else None  # from run_clm_mp.py
params_spec, opt_state_spec = jax.tree_map(get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)))
# Sharding spec mirroring the TrainState structure; step is replicated (None).
state_spec = TrainState(params=params_spec, opt_state=opt_state_spec, step=None, apply_fn=loss_apply_fn, tx=optimizer)

In [118]:
# Build the sharded TrainState on the mesh (pattern from dalle-mini).
# Convert param leaves to host numpy arrays first (pattern from run_clm_mp.py).
params = jax.tree_map(lambda x: np.asarray(x), params)  # from run_clm_mp.py
def init_state(params): return TrainState.create(apply_fn=loss_apply_fn, tx=optimizer, params=params)
# with mesh: params, opt_state = pjit(init_state, None, state_shard)(params)
# Inputs are laid out per params_spec and outputs per state_spec, so the optimizer state is
# produced already sharded. donate_argnums=(0,) lets XLA reuse the input param buffers.
# freeze() gives the input the FrozenDict pytree type of the frozen params_spec.
with mesh: state = pjit(init_state, (params_spec,), state_spec, donate_argnums=(0,))(freeze(params))  # in_axis_resources should be None?
# del params  # free CPU memory

In [119]:
# Per-device placement of every parameter: a list of (buffer shape, device) per buffer.
jax.tree_map(
    lambda arr: [(buf.shape, buf.device()) for buf in arr.device_buffers],
    state.params,
)

FrozenDict({
    causal_transformer_shard/~/embedding_shard_v2/~/linear: {
        b: [((32,), GpuDevice(id=0, process_index=0)), ((32,), GpuDevice(id=1, process_index=0)), ((32,), GpuDevice(id=2, process_index=0)), ((32,), GpuDevice(id=3, process_index=0)), ((32,), GpuDevice(id=4, process_index=0)), ((32,), GpuDevice(id=5, process_index=0)), ((32,), GpuDevice(id=6, process_index=0)), ((32,), GpuDevice(id=7, process_index=0))],
        w: [((8, 32), GpuDevice(id=0, process_index=0)), ((8, 32), GpuDevice(id=1, process_index=0)), ((8, 32), GpuDevice(id=2, process_index=0)), ((8, 32), GpuDevice(id=3, process_index=0)), ((8, 32), GpuDevice(id=4, process_index=0)), ((8, 32), GpuDevice(id=5, process_index=0)), ((8, 32), GpuDevice(id=6, process_index=0)), ((8, 32), GpuDevice(id=7, process_index=0))],
    },
    causal_transformer_shard/~/layer_0/~/fc_in: {
        b: [((16,), GpuDevice(id=0, process_index=0)), ((16,), GpuDevice(id=1, process_index=0)), ((16,), GpuDevice(id=2, process_index=0)

In [34]:
# NOTE(review): stale xmap-era cell. `init` as defined earlier in the notebook takes a single
# `params` argument, but in_axes here describes two inputs (keys along 'shard', x along 'batch').
init_xmap = xmap(fun=init,
                in_axes=(["shard", ...], ["batch", ...]),
                out_axes=["shard", ...],
                # named axes -> mesh axes: 'shard' runs over 'mp', 'batch' over 'dp'
                axis_resources={'shard': 'mp', 'batch': 'dp'})

  warn("xmap is an experimental feature and probably has bugs!")


In [25]:
# One PRNG key per model-parallel shard, plus the example batch x.
# NOTE(review): requires the two-argument `init` that init_xmap's in_axes were written for.
with mesh: state = init_xmap(jnp.array(key.take(mp_per_host)), x)

In [15]:
def shard_strategy(shape_dtype, parallel, cfg=None):
    """Return the PartitionSpec used to shard one CausalTransformerShard parameter.

    Args:
        shape_dtype: object exposing ``ndim`` and ``shape`` (e.g. a ``jax.ShapeDtypeStruct``
            from ``jax.eval_shape``, or an array).
        parallel: mesh axis name, or tuple of names, over which a sharded dimension is split.
        cfg: model config mapping with ``"d_model"`` and ``"n_vocab"``. Defaults to the
            notebook-global ``config`` for backward compatibility.

    Returns:
        ``P()`` for scalars. For vectors and matrices, a spec that splits the dimension that
        is *not* ``d_model`` over ``parallel`` and replicates ``d_model``-sized dimensions.

    Raises:
        NotImplementedError: for any shape this strategy does not recognise. This replaces the
            original ``assert`` checks, which are skipped under ``python -O``.
    """
    if cfg is None:
        cfg = config
    d_model, n_vocab = cfg["d_model"], cfg["n_vocab"]
    shape = tuple(shape_dtype.shape)
    ndim = shape_dtype.ndim

    if ndim == 0:
        return P()

    if ndim == 1:
        size = shape[0]
        if size == d_model:  # layernorm or fc_out bias: replicate
            return P(None)
        if size in (n_vocab, 4 * d_model):  # projection bias / fc_in bias
            return P(parallel)
        raise NotImplementedError(f"unsupported 1-D parameter: {shape_dtype}")

    if ndim != 2:
        raise NotImplementedError(f"unsupported parameter rank: {shape_dtype}")

    # Embedding / projection layers: split the vocab dimension.
    if shape == (n_vocab, d_model):
        return P(parallel, None)
    if shape == (d_model, n_vocab):
        return P(None, parallel)

    # Transformer layer weights: split the dimension that is not d_model.
    # Checking shape[1] first keeps the original choice for (d_model, d_model) matrices.
    if shape[1] == d_model:
        return P(parallel, None)
    if shape[0] == d_model:
        return P(None, parallel)
    raise NotImplementedError(f"unsupported 2-D parameter: {shape_dtype}")

In [33]:
def eval_step(params, batch):
    """Evaluation loss for a ``(ctx, tgt, ctx_length)`` batch tuple (tuple-based API).

    Runs ``CausalTransformerShard.loss`` on bf16 copies of ``params`` with ``mask=0.``;
    ``ctx_length`` is unpacked but currently ignored.
    NOTE(review): this definition is shadowed by the dict-based ``eval_step`` in the next cell.
    """
    ctx, tgt, _ctx_length = batch

    def _loss(inputs, targets, mask):
        return CausalTransformerShard(config).loss(inputs, targets, mask=mask)

    apply_loss = hk.without_apply_rng(hk.transform(_loss)).apply
    # A length-based mask (copied from V2) would be:
    #   mask = (jnp.arange(0, ctx.shape[1])[None, :] > ctx_length[:, None]) * -1e10  # bj
    #   mask = mask[:, None, None, :]  # bj -> bnij
    return apply_loss(to_bf16(params), ctx, tgt, 0.)

In [133]:
def eval_step(state, batch):
    """Evaluation loss for a batch dict with keys ``ctx``, ``tgt`` and ``ctx_length``.

    ``ctx_length`` is dropped and ``mask=0.`` is passed in its place; the loss is computed
    with bf16 copies of ``state.params``.
    """
    def _loss(loss_kwargs):
        return CausalTransformerShard(config).loss(**loss_kwargs)

    apply_loss = hk.without_apply_rng(hk.transform(_loss)).apply
    # Length-based masking (copied from V2) is not enabled yet:
    #   mask = (jnp.arange(0, ctx.shape[1])[None, :] > ctx_length[:, None]) * -1e10  # bj
    #   mask = mask[:, None, None, :]  # bj -> bnij
    loss_kwargs = dict(batch)
    loss_kwargs.pop('ctx_length', None)
    loss_kwargs['mask'] = 0.
    return apply_loss(to_bf16(state.params), loss_kwargs)

In [28]:
def train_step(state, step, batch):
    """One optimizer step for the tuple-based ``(params, opt_state)`` state (xmap-era API).

    Args:
        state: ``(params, opt_state)`` tuple.
        step: step counter; returned incremented by one.
        batch: ``(ctx, tgt)`` token arrays with a leading microbatch axis. A single
            microbatch is used directly; otherwise gradients are summed over microbatches
            with ``lax.scan``.

    Returns:
        ``(loss, last_loss, (new_params, new_opt_state), step + 1)``, where the losses are
        cast to float32 (per-microbatch arrays when scanning) and the optimizer state is a tuple.

    Gradients are ``pmean``-ed over the named axis ``"batch"``, so this must run under a
    transform that binds that axis (the xmap cell in this notebook). The losses are not
    averaged across it.
    """
    ctx, tgt = batch
    params, opt_state = state

    def train_loss(x, y):
        transformer = CausalTransformerShard(config)
        out = transformer.loss(x, y, z_loss=True)
        # value_and_grad(has_aux=True): differentiate "loss", carry "last_loss" as aux.
        return out["loss"], out["last_loss"]

    train_loss_fn = hk.without_apply_rng(hk.transform(train_loss)).apply
    val_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)

    def microbatch(old_grad, mb):
        mb_ctx, mb_tgt = mb
        (loss, last_loss), grad = val_grad_fn(to_bf16(params), mb_ctx, mb_tgt)
        # tree_util.tree_map accepts several trees; jax.tree_multimap is deprecated.
        new_grad = jax.tree_util.tree_map(lambda a, b: a + b, old_grad, grad)
        return new_grad, (loss, last_loss)

    if ctx.shape[0] == 1:
        # `state` is a plain tuple, so use the unpacked `params` (was `state.params`).
        (loss, last_loss), grad = val_grad_fn(to_bf16(params), ctx[0], tgt[0])
    else:
        zero_grad = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x).astype(jnp.bfloat16), params)
        # microbatch emits a 2-tuple of metrics (gnorm is disabled), so unpack exactly two.
        grad, (loss, last_loss) = jax.lax.scan(microbatch, zero_grad, (ctx, tgt))
    grad = jax.lax.pmean(grad, "batch")
    updates, new_opt_state = optimizer.update(grad, opt_state, params)
    new_params = optax.apply_updates(params, to_f32(updates))
    return to_f32(loss), to_f32(last_loss), (new_params, tuple(new_opt_state)), step + 1

In [128]:
def train_step(state, batch):
    """One optimizer step on a ``TrainState`` (pjit-era API).

    Args:
        state: ``TrainState`` whose ``apply_fn(params, batch)`` returns ``(loss, last_loss)``.
            ``loss`` is differentiated w.r.t. ``params``; ``last_loss`` is auxiliary.
        batch: dict of arrays (e.g. ``ctx``, ``tgt``) with a leading microbatch axis.
            - With one microbatch, that microbatch is used directly.
            - Otherwise gradients are summed over microbatches with ``lax.scan``. The sum is
              presumably turned into a mean by the ``optax.scale(1 / gradient_accumulation_steps)``
              stage of the optimizer chain.

    Returns:
        ``(new_state, metrics)``. ``metrics`` holds float32 ``loss`` and ``last_loss``
        (per-microbatch arrays when more than one microbatch is scanned).
    """
    params, train_loss_fn = state.params, state.apply_fn
    val_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)

    def microbatch(old_grad, mb):
        (loss, last_loss), grad = val_grad_fn(to_bf16(params), mb)
        # tree_util.tree_map accepts several trees; jax.tree_multimap is deprecated.
        new_grad = jax.tree_util.tree_map(lambda a, b: a + b, old_grad, grad)
        return new_grad, (loss, last_loss)

    if batch['ctx'].shape[0] == 1:
        single = jax.tree_util.tree_map(lambda x: x[0], batch)
        (loss, last_loss), grad = val_grad_fn(to_bf16(params), single)
    else:
        zero_grad = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x).astype(jnp.bfloat16), params)
        # microbatch emits a 2-tuple of metrics; the old 3-way unpack (with gnorm) raised ValueError.
        grad, (loss, last_loss) = jax.lax.scan(microbatch, zero_grad, batch)
    # The xmap-era lax.pmean(grad, "batch") is omitted: under pjit the function operates on
    # the global (logically unsharded) batch, so no named-axis reduction is required.
    state = state.apply_gradients(grads=grad)
    metrics = {'loss': to_f32(loss), 'last_loss': to_f32(last_loss)}
    return state, metrics

In [135]:
# Next-token prediction pairs: the context is every token but the last, and the target is
# the same sequence shifted left by one.
sample = dict(obs=x[:, :-1], target=x[:, 1:])
# Every row uses its full context, so all context lengths are equal.
ctx_length = jnp.array([len(sample["obs"][0]) for _ in range(len(sample["obs"]))])
batch = dict(ctx=x[:, :-1], tgt=x[:, 1:], ctx_length=ctx_length)

In [136]:
# Eager (un-jitted) sanity check of the dict-based eval_step on the host batch.
# loss0 = eval_step(params, (sample['obs'], sample['target'], ctx_length))
eval_metrics = eval_step(state, batch)

In [20]:
# NOTE(review): stale cell. `state_shard` is not defined anywhere in this notebook
# (presumably superseded by state_spec), and the call uses the tuple-based eval_step
# signature. The pjit cell that follows is the current version.
p_eval_step = pjit(eval_step, (state_shard[0], (P('dp'), P('dp'), P('dp'))), None)
with mesh: loss = p_eval_step(params, (sample['obs'], sample['target'], ctx_length))

In [139]:
# Sharded evaluation: state laid out per state_spec, and every batch entry split along its
# first (batch) axis over 'dp' (P('dp') acts as a prefix for the whole batch dict).
# out_axis_resources=None returns replicated metrics.
p_eval_step = pjit(eval_step, (state_spec, P('dp')), None)
with mesh: eval_metrics = p_eval_step(state, batch)

In [36]:
# NOTE(review): xmap-era cell. It passes state[0] and a tuple batch, matching the tuple-based
# eval_step. The eval_step defined last takes (state, batch_dict) and reads state.params.
eval_step_xmap = xmap(fun=eval_step,
                in_axes=(["shard", ...], (["batch", ...], ["batch", ...], ["batch", ...])),
                out_axes=["batch", ...],
                axis_resources={'shard': 'mp', 'batch': 'dp'})
with mesh: outputs = eval_step_xmap(state[0], (sample['obs'], sample['target'], ctx_length))

  warn("xmap is an experimental feature and probably has bugs!")


In [102]:
# Training inputs with an extra leading microbatch axis of size 1 (train_step uses the
# single microbatch directly when that axis has size 1).
sample = {'obs': x[:, :-1][None, ...], 'target': x[:, 1:][None, ...]}
step = np.ones((mp_per_host,)).astype('int32')  # per-shard step counter for the xmap train_step
batch = {'ctx': x[:, :-1][None, ...], 'tgt': x[:, 1:][None, ...]}
# batch = freeze(batch)  # from dalle-mini freeze batch to pass safely to jax transforms

In [53]:
# NOTE(review): `opt_state` is not defined by any live cell (it appears only in a commented-out
# line). This also replaces the TrainState with a tuple, which the dict-based train_step
# cannot consume (it reads state.params). Only meaningful with the tuple-based train_step.
state = (params, opt_state)

In [129]:
# Eager (un-jitted) single training step; expects `state` to be a TrainState.
state, train_metrics = train_step(state, batch)



In [130]:
# Display the loss metrics returned by the training step.
train_metrics

{'loss': DeviceArray(4.151296, dtype=float32),
 'last_loss': DeviceArray(4.151296, dtype=float32)}

In [56]:
# NOTE(review): stale cell. `state_shard` is not defined in this notebook, and the
# (state, step, (ctx, tgt)) signature belongs to the tuple-based train_step.
# The next cell is the TrainState version.
p_train_step = pjit(train_step, (state_shard, None, (P(None, 'dp'), P(None, 'dp'))),
                                (None, None, state_shard, None))
with mesh: loss, last_loss, state, step = p_train_step(state, step, (sample['obs'], sample['target']))

In [131]:
# Sharded training step:
# - state laid out per state_spec;
# - batch arrays split over 'dp' along axis 1 (axis 0 is the microbatch axis);
# - returns the new state per state_spec and replicated metrics.
# NOTE(review): no donate_argnums here, so the old state buffers stay alive until `state` is
# rebound. Consider donate_argnums=(0,) as in the init pjit.
p_train_step = pjit(train_step, (state_spec, P(None, 'dp')), (state_spec, None))
with mesh: state, train_metrics = p_train_step(state, batch)

In [53]:
# xmap-era training step (tuple-based train_step):
# - state and step are mapped over named axis 'shard' (-> mesh axis 'mp');
# - inputs are mapped over 'batch' (-> 'dp');
# - the state argument is donated because the result replaces it.
train_step_xmap = xmap(fun=train_step,
                 in_axes=(["shard", ...], ["shard", ...], (["batch", ...], ["batch", ...])),
                 out_axes=(["batch", ...], ["batch", ...], ["shard", ...], ["shard", ...]),
                 donate_argnums=(0,),  # also needed by pjit
                 axis_resources={'shard': 'mp', 'batch': 'dp'})
with mesh: loss, last_loss, state, step = train_step_xmap(state, step, (sample['obs'], sample['target']))

  warn("xmap is an experimental feature and probably has bugs!")


In [55]:
# Inspect the xmap step's outputs; `step` is returned per shard, each incremented by one.
loss, last_loss, step

(ShardedDeviceArray([4.2005105], dtype=float32),
 ShardedDeviceArray([3.0642111], dtype=float32),
 ShardedDeviceArray([2, 2, 2, 2, 2, 2, 2, 2], dtype=int32))

In [8]:
def train_loss(x, y):
    """Plain loss (no z-loss) of CausalTransformerShard for inputs ``x`` and targets ``y``.

    NOTE(review): rebinds the global name ``train_loss``, which was earlier defined with a
    single batch-dict argument.
    """
    transformer = CausalTransformerShard(config)
    return transformer.loss(x, y)

net = hk.without_apply_rng(hk.transform(train_loss))
params = net.init(jax.random.PRNGKey(42), x, x)
# Abstract parameter shapes from the same initializer, without allocating device memory.
# The original called the one-argument `init(params)` with (key, x), which raises TypeError.
param_shapes = jax.eval_shape(net.init, jax.random.PRNGKey(42), x, x)
# Smoke test: loss on the shifted (input, target) pair.
net.apply(params, x[:, :-1], x[:, 1:])

In [25]:
# Older context-manager form of the mesh. The xmap is only constructed here, not called.
# NOTE(review): wraps the same `init` whose one-argument definition earlier in the notebook
# does not match these two in_axes.
with jax.experimental.maps.mesh(devices, ('dp', 'mp')): 
    init_xmap = jax.experimental.maps.xmap(fun=init,
                                        in_axes=(["shard", ...],
                                                 ["batch", ...]),
                                        out_axes=["shard", ...],
                                        axis_resources={'shard': 'mp', 'batch': 'dp'})

  warn("xmap is an experimental feature and probably has bugs!")


In [None]:
# NOTE(review): copied from a class method. `self` (and self.eval_xmap / self.state) is not
# defined at notebook top level, so this cell raises NameError as written.
with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
    out = self.eval_xmap(self.state, sample["obs"], sample["target"], ctx_length)