# MaxText Lingvo fine-tuning

Loads a pre-existing checkpoint of the converged LG dummy model (single
attention layer), and then runs some additional training steps (fine-tuning).

### TODO:
Need to create or import all of the following:
- [x] model: load from checkpoint
- [x] tx: optimizer
- [x] config
- [x] init_rng
- [x] mesh
- [x] checkpoint_manager
- [ ] basic training function
- [ ] short version of this notebook that uses just the config and calls `train.train_loop`


## Config

In [1]:
# Output location for this fine-tuning run.
my_bucket = "gs://patflick-maxtext-lingvo"
base_output_directory = f"{my_bucket}/maxtext/dummy/20231120"

# Location of the converted checkpoints to start from.
checkpoint_path = "gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/"

# Overrides applied on top of configs/base.yml for our model.
extra_config = dict(
    # Run config
    run_name="finetuning-dummy_v1",
    base_output_directory=base_output_directory,
    save_period=100,
    steps=1000,

    # Checkpoint to load
    load_from_other_directory=checkpoint_path,
    load_from_other_directory_step=85,

    # Model config (has to match the loaded checkpoint, otherwise ERROR!)
    base_num_decoder_layers=1,
    base_num_heads=4,
    head_dim=96,
    vocab_size=50272,
    per_device_batch_size=0.5,
    base_mlp_dim=2048,
    base_emb_dim=512,

    # Dataset loader to use
    dataset_type="lg",
    file_pattern_for_train_data="gs://yejingxin-us-central2/external/lg/dummy-data/train/*.tfrecords",
    file_pattern_for_eval_data="gs://yejingxin-us-central2/external/lg/dummy-data/valid/*tfrecords",

    # Parallelism and KV config
    dcn_tensor_parallelism=1,
    ici_tensor_parallelism=4,
    enable_flash_attention=False,
)

import pyconfig

# The argument list is argv-style: a placeholder first element, the base
# config file, then one "key=value" string per override.
pyconfig.initialize(
    ["", "configs/base.yml", *(f"{k}={v}" for k, v in extra_config.items())]
)
config = pyconfig.config
# Display the resolved configuration as the cell output.
pyconfig._config.keys

OrderedDict([('run_name', 'finetuning-dummy_v1'),
             ('load_parameters_path', ''),
             ('load_from_other_directory',
              'gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/'),
             ('load_from_other_directory_step', 85),
             ('reuse_example_batch', 0),
             ('metrics_file', ''),
             ('gcs_metrics', False),
             ('dtype', dtype(bfloat16)),
             ('int8_training', False),
             ('global_parameter_scale', 1),
             ('base_emb_dim', 512),
             ('base_num_heads', 4),
             ('base_mlp_dim', 2048),
             ('base_num_decoder_layers', 1),
             ('head_dim', 96),
             ('mlp_activations', ['gelu']),
             ('dropout_rate', 0),
             ('logits_via_embedding', True),
             ('remat_policy', 'full'),
             ('scan_layers', True),
             ('param_scan_axis', 1),
             ('enable_flash_attention', False),
             ('reco

In [8]:
# Legacy config cell: builds the same kind of pyconfig argument list by hand.
# NOTE(review): this cell calls pyconfig.initialize again and rebinds `config`
# with different values (run_name, steps, parallelism, checkpoint step) than
# the config cell above -- run only one of the two.
my_bucket = "gs://patflick-maxtext-lingvo"
base_output_directory = my_bucket + "/maxtext/lingvo/20231108/1"

# Aisha's checkpoints
#load_checkpoint_dir="gs://mazumdera-test-bucket/maxtext/lg/10142023/1/1xv3-8/checkpoints/"
load_checkpoint_dir = "gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/"
#base_output_directory="base_output_directory=gs://mazumdera-test-bucket/maxtext/lg/11032023/1"

# Train/Eval data (already in "key=value" form)
file_pattern_for_train_data = "file_pattern_for_train_data=gs://yejingxin-us-central2/external/lg/dummy-data/train/*.tfrecords"
file_pattern_for_eval_data = "file_pattern_for_eval_data=gs://yejingxin-us-central2/external/lg/dummy-data/valid/*tfrecords"

# Model / dataset overrides (already in "key=value" form)
base_num_decoder_layers = "base_num_decoder_layers=1"
base_num_heads = "base_num_heads=4"
head_nums = "head_dim=96"   # note: holds the head_dim override despite its name
dataset_type = "dataset_type=lg"

commandline_args = ["dummy",
                    "configs/base.yml",
                    "run_name=1xv4-8",
                    "dcn_data_parallelism=1",
                    "save_period=5",
                    # TODO: configure parallelism!
                    "ici_data_parallelism=2",
                    "ici_tensor_parallelism=2",
                    "ici_fsdp_parallelism=1",
                    "steps=20",
                    "enable_profiler=true",   # previously listed twice; once is enough
                    "remat_policy=full",
                    "base_emb_dim=512",
                    base_num_heads,
                    head_nums,
                    "vocab_size=50272",
                    base_num_decoder_layers,
                    "per_device_batch_size=0.5",
                    "base_mlp_dim=2048",
                    # File dependencies
                    file_pattern_for_train_data,
                    file_pattern_for_eval_data,
                    "base_output_directory=" + base_output_directory,
                    "load_from_other_directory=" + load_checkpoint_dir,
                    "load_from_other_directory_step=50",
                    dataset_type,
                    "max_predict_length=512",
                    #"jax_default_prng_impl=unsafe_rbg"   # required/overwritten by train.train_step. if not set here, will cause failures later
                   ]

import pyconfig
pyconfig.initialize(commandline_args)
config = pyconfig.config
# Display the resolved configuration as the cell output.
pyconfig._config.keys

OrderedDict([('run_name', '1xv4-8'),
             ('load_parameters_path', ''),
             ('load_from_other_directory',
              'gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/'),
             ('load_from_other_directory_step', 50),
             ('reuse_example_batch', 0),
             ('metrics_file', ''),
             ('gcs_metrics', False),
             ('dtype', dtype(bfloat16)),
             ('int8_training', False),
             ('global_parameter_scale', 1),
             ('base_emb_dim', 512),
             ('base_num_heads', 4),
             ('base_mlp_dim', 2048),
             ('base_num_decoder_layers', 1),
             ('head_dim', 96),
             ('mlp_activations', ['gelu']),
             ('dropout_rate', 0),
             ('logits_via_embedding', True),
             ('remat_policy', 'full'),
             ('scan_layers', True),
             ('param_scan_axis', 1),
             ('enable_flash_attention', False),
             ('record_internal_n

## Loading pre-existing checkpoint

Uses the config to create the optimizer, model, and mesh. Then loads the checkpoint into the model/optimizer state.

In [2]:
import checkpointing

# Build the Orbax checkpoint manager from the run config.
# Checkpointing must stay enabled, otherwise checkpoints from another
# directory cannot be loaded.
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
    checkpoint_dir=config.checkpoint_dir,
    enable_checkpointing=True,
    use_async=config.async_checkpointing,
    save_interval_steps=config.save_period,
)

Creating checkpoint manager...




Checkpoint manager created!


In [3]:
import jax
# Show the devices JAX can see (displayed as the cell output).
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [4]:
# Set up the device mesh.
import max_utils
from jax.sharding import Mesh

# Arrange the devices according to the config, then name the mesh axes
# with config.mesh_axes.
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
# Display the resulting mesh as the cell output.
mesh

Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] (num_devices: 4)
Decided on mesh: [[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
   TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0)
   TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0)
   TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]]]


Mesh(device_ids=array([[[0, 2, 1, 3]]]), axis_names=('data', 'fsdp', 'tensor'))

In [5]:
import jax
from jax import random
# Initial PRNG keys.
# The PRNG implementation is selected before any key is created: train.py sets
# the same option later, and switching it afterwards causes a shape mismatch.
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')   # need to set here. train.py later sets this and then causes shape mismatch
# Split the seed key in two: init_rng for state initialization, nextrng for the training steps.
init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2)

In [6]:
# Create the model from the config and the device mesh.
from layers import Transformer
model = Transformer(config, mesh)

In [7]:
import optax
# Create the optimizer: AdamW with the learning-rate schedule and Adam
# hyperparameters taken from the config.
#TODO(from original notebook): also compare with optax.adafactor
tx = optax.adamw(
       max_utils.create_learning_rate_schedule(config),
       b1=config.adam_b1,
       b2=config.adam_b2,
       eps=config.adam_eps,
       eps_root=config.adam_eps_root,
       weight_decay=config.adam_weight_decay,
     )

In [8]:
# Load the checkpoint and initialize the model/optimizer state.
# Returns the train state together with its mesh annotations.
state_read_from_ckpt, state_mesh_annotations_read_from_ckpt = max_utils.setup_initial_state(model, tx, config, init_rng, mesh, checkpoint_manager)


restoring state from gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/ step 85


I0000 00:00:1700696741.957102 1183063 gcs_resource.cc:99] Using default AdmissionQueue with limit 32
I0000 00:00:1700696741.958659 1184610 google_auth_provider.cc:179] Running on GCE, using service account 903354779218-compute@developer.gserviceaccount.com


Print the architecture of the loaded transformer model. It is currently a simple dummy model with a single self-attention layer.

In [9]:
# Continue working with the state restored from the checkpoint.
state = state_read_from_ckpt

# NOTE(review): this is not the last expression of the cell, so its value is never displayed.
type(state.params)
def print_params_shape(params):
    """Print a nested parameter tree, one line per entry.

    Sub-trees are printed as ``name:`` followed by their entries, indented by
    two more spaces. Leaves are printed as ``name: <type>``, with the shape
    appended when the leaf has a ``shape`` attribute.

    Args:
        params: Nested mapping from parameter names to sub-mappings or leaves.
    """
    from collections.abc import Mapping

    def _print_params_dict_rec(sub_dict, ident=""):
        for key, value in sub_dict.items():
            # Non-string keys are represented by their type.
            line = key if isinstance(key, str) else str(type(key))
            if isinstance(value, Mapping):
                # Print the sub-tree header once and descend once.
                print(ident + line + ":")
                _print_params_dict_rec(value, ident + "  ")
            elif hasattr(value, "shape"):
                print(ident + line + ": " + str(type(value)) + str(value.shape))
            else:
                print(ident + line + ": " + str(type(value)))

    _print_params_dict_rec(params)
            
# Print the parameter tree of the restored state.
print_params_shape(state.params)

decoder:
  decoder:
    mlp:
      wi:
        kernel: <class 'jaxlib.xla_extension.ArrayImpl'>(512, 1, 2048)
      wo:
        kernel: <class 'jaxlib.xla_extension.ArrayImpl'>(2048, 1, 512)
    pre_mlp_layer_norm:
      scale: <class 'jaxlib.xla_extension.ArrayImpl'>(512, 1)
    pre_self_attention_layer_norm:
      scale: <class 'jaxlib.xla_extension.ArrayImpl'>(512, 1)
    relpos_bias:
      rel_embedding: <class 'jaxlib.xla_extension.ArrayImpl'>(4, 1, 32)
    self_attention:
      key:
        kernel: <class 'jaxlib.xla_extension.ArrayImpl'>(512, 1, 4, 96)
      out:
        kernel: <class 'jaxlib.xla_extension.ArrayImpl'>(4, 1, 96, 512)
      query:
        kernel: <class 'jaxlib.xla_extension.ArrayImpl'>(512, 1, 4, 96)
      value:
        kernel: <class 'jaxlib.xla_extension.ArrayImpl'>(512, 1, 4, 96)
  decoder_norm:
    scale: <class 'jaxlib.xla_extension.ArrayImpl'>(512,)
token_embedder:
  embedding: <class 'jaxlib.xla_extension.ArrayImpl'>(50272, 512)


### Create sharding

In [10]:
from jax.sharding import PartitionSpec as P

# Compute shardings by combining the checkpoint's PartitionSpecs with the config's mesh.
# jax.tree_util.tree_map is used instead of the deprecated jax.tree_map alias.
data_pspec = P(*config.data_sharding)
# One NamedSharding per PartitionSpec in the state's mesh annotations.
state_mesh_shardings_read_from_ckpt = jax.tree_util.tree_map(
  lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations_read_from_ckpt)
# Sharding for the input batches.
data_sharding = jax.tree_util.tree_map(
  lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)

## Continuing training on checkpoint

Loading input data processor/iterator

In [11]:
# Load training data
from input_pipeline import create_data_iterator_with_tokenizer

# Build the data iterator from the config and the device mesh.
data_iterator = create_data_iterator_with_tokenizer(config, mesh)

  from .autonotebook import tqdm as notebook_tqdm


Trying to read training data from path: gs://yejingxin-us-central2/external/lg/dummy-data/train/*.tfrecords
Training dataset has: 500000 entries
Trying to read eval data from path: gs://yejingxin-us-central2/external/lg/dummy-data/valid/*tfrecords
Eval dataset has: 50000 entries


Training loop

In [12]:
import train
import numpy as np
from flax.linen import partitioning as nn_partitioning

# Compile the train step.
# - Positional args 0 and 1 (model, config) are static.
# - Arg 2 (the train state) is donated; it is rebound to the returned state on every step.
# - The state goes in and comes out with the shardings derived from the checkpoint.
p_train_step = jax.jit(train.train_step,
                       in_shardings=(state_mesh_shardings_read_from_ckpt, data_sharding, None),
                       out_shardings=(state_mesh_shardings_read_from_ckpt, None, None),
                       static_argnums=(0, 1),
                       donate_argnums=2)

# Run fine-tuning: continue from the step stored in the restored state up to
# the step count configured in `config.steps` (instead of a hard-coded bound).
batch = None
for step in np.arange(train.get_first_step(state), config.steps):
    # Load the next batch.
    batch = train.load_next_batch(data_iterator, batch, config)

    # Run one training step under the config's logical axis rules.
    with nn_partitioning.axis_rules(config.logical_axis_rules):
        state, metrics, nextrng = p_train_step(
            model, config, state, batch, nextrng
        )

    learning_loss = metrics['scalar']['learning/loss']
    print("train step", step, "loss:", learning_loss)



Found 4 devices.
train step 86 loss: 12.052555
train step 87 loss: 11.998739
train step 88 loss: 11.984926
train step 89 loss: 11.960432
train step 90 loss: 11.870367
train step 91 loss: 11.914023
train step 92 loss: 11.878504
train step 93 loss: 11.812907
train step 94 loss: 11.832667
train step 95 loss: 11.78319
train step 96 loss: 11.790079
train step 97 loss: 11.736826
train step 98 loss: 11.72094
train step 99 loss: 11.696386
train step 100 loss: 11.693748
train step 101 loss: 11.651544
train step 102 loss: 11.649
train step 103 loss: 11.614431
train step 104 loss: 11.627368
train step 105 loss: 11.602772
train step 106 loss: 11.591671
train step 107 loss: 11.580237
train step 108 loss: 11.540206
train step 109 loss: 11.581579
train step 110 loss: 11.553307
train step 111 loss: 11.579695
train step 112 loss: 11.568599
train step 113 loss: 11.501836
train step 114 loss: 11.530902
train step 115 loss: 11.49118
train step 116 loss: 11.50413
train step 117 loss: 11.501064
train step 1