In [1]:
import jax
import flax
# Sanity check: list the accelerator devices JAX can see before loading the model.
visible_devices = jax.devices()
visible_devices

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

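# Converting to Orbax Sharded Checkpoint

This notebook loads the OpenLLaMA 3B parameters from an EasyLM streaming checkpoint and re-saves them as an Orbax PyTree checkpoint. It then restores that checkpoint and shards the parameters across the local TPU devices using the LLaMA partition rules.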

## Resources

- https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#multi-host-multi-process-checkpointing
- https://orbax.readthedocs.io/en/latest/#checkpointing

In [2]:
from EasyLM.checkpoint import StreamingCheckpointer
import orbax.checkpoint

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
params = StreamingCheckpointer.load_trainstate_checkpoint(
        'params::/home/supermdguy/open_llama_3b_easylm/open_llama_3b_easylm', disallow_trainstate=True
    )

In [10]:
params = params[1]['params']

In [19]:
# Write the unwrapped params tree as an Orbax PyTree checkpoint.
# NOTE(review): Orbax generally refuses to save into an existing directory, so
# re-running this cell may fail unless the directory is removed first -- confirm.
ORBAX_CKPT_DIR = '/home/supermdguy/open_llama_3b_orbax/'
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpointer.save(ORBAX_CKPT_DIR, params)

## Load the Orbax Checkpoint and Shard the Parameters

In [9]:
from functools import partial

import numpy as np
import mlxu

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList

from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
    set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
    with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM

In [11]:
# Restore the params tree from the Orbax checkpoint directory (the same directory
# the conversion cell writes to). This replaces whatever `params` held before.
ORBAX_CKPT_DIR = '/home/supermdguy/open_llama_3b_orbax/'
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
params = checkpointer.restore(ORBAX_CKPT_DIR)

In [20]:
# Map every parameter to a PartitionSpec using the LLaMA partition rules.
llama_partition_rules = LLaMAConfig.get_partition_rules()
model_ps = match_partition_rules(llama_partition_rules, params)

# Only the shard functions are needed here; the gather functions are discarded.
# Presumably the shard fns also cast each leaf to bf16 -- confirm in EasyLM.jax_utils.
bf16_dtype = get_float_dtype_by_name('bf16')
shard_fns, _ = make_shard_and_gather_fns(model_ps, bf16_dtype)

In [21]:
# Mesh axes are (dp, fsdp, mp); -1 lets fsdp absorb all remaining devices.
# The sharding output further down shows this resolves to dp=1, fsdp=8, mp=1 on
# this 8-core host.
mesh_shape = '1,-1,1'
mesh = LLaMAConfig.get_jax_mesh(mesh_shape)

# Apply the per-leaf shard functions inside the mesh context so the arrays are
# placed according to `model_ps`.
with mesh:
    params = tree_apply(shard_fns, params)

In [26]:
# Record the sharding of every (now sharded) parameter leaf.
# `jax.tree_map` is deprecated and removed in newer JAX releases;
# `jax.tree_util.tree_map` is the supported spelling with identical semantics.
shardings = jax.tree_util.tree_map(lambda leaf: leaf.sharding, params)

# Per-leaf Orbax restore args built from those shardings; presumably intended for
# restoring the checkpoint directly into this sharded layout later -- not used in
# the cells visible here.
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(
    params, shardings
)

In [30]:
# Inspect the per-parameter NamedSharding produced by the partition rules
# (output is long and truncated in this export).
shardings

{'lm_head': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
 'transformer': {'h': {'0': {'attention': {'wk': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
     'wo': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('mp', 'fsdp'))},
     'wq': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
     'wv': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))}},
    'attention_norm': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec(None,))},
    'feed_forward': {'w1': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
     'w2': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('mp', 'fsdp'))},
     'w3': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec

In [31]:
# NOTE(review): this cell re-displays `shardings` with output identical to the
# previous cell -- candidate for removal.
shardings

{'lm_head': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
 'transformer': {'h': {'0': {'attention': {'wk': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
     'wo': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('mp', 'fsdp'))},
     'wq': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
     'wv': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))}},
    'attention_norm': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec(None,))},
    'feed_forward': {'w1': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp', 'mp'))},
     'w2': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('mp', 'fsdp'))},
     'w3': {'kernel': NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec

In [28]:
import pickle