# MaxText Lingvo continued training

This notebook loads a pre-existing checkpoint of the converged LG dummy model
(a single attention layer) and then runs additional training steps on it
(fine-tuning).

## Config

In [4]:
import pyconfig

# GCS locations: the output directory for this run, and the directory holding
# the converted checkpoints to continue training from.
my_bucket = "gs://patflick-maxtext-lingvo"
base_output_directory = f"{my_bucket}/maxtext/dummy/20231120"
checkpoint_path = "gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/"

# Overrides layered on top of configs/base.yml to adapt it to our model.
# Insertion order is kept, so the overrides are passed in the order listed.
extra_config = dict(
    # Run config
    run_name="finetuning-dummy_v2",
    base_output_directory=base_output_directory,
    save_period=100,
    steps=1000,

    # Checkpoint to load (directory + step)
    load_from_other_directory=checkpoint_path,
    load_from_other_directory_step=85,

    # Model config (this has to match the loaded checkpoint, otherwise ERROR!)
    base_num_decoder_layers=1,
    base_num_heads=4,
    head_dim=96,
    vocab_size=50272,
    per_device_batch_size=0.5,
    base_mlp_dim=2048,
    base_emb_dim=512,

    # Dataset loader to use
    dataset_type="lg",
    file_pattern_for_train_data="gs://yejingxin-us-central2/external/lg/dummy-data/train/*.tfrecords",
    # NOTE(review): unlike the train pattern, this glob has no "." before
    # "tfrecords" -- confirm whether that is intentional.
    file_pattern_for_eval_data="gs://yejingxin-us-central2/external/lg/dummy-data/valid/*tfrecords",

    # Parallelism and KV config
    dcn_tensor_parallelism=1,
    ici_tensor_parallelism=4,
    enable_flash_attention=False,
)

# pyconfig is initialized with an argv-style list: an empty first element
# (presumably a stand-in for the program name -- confirm), the base YAML file,
# and then one "key=value" string per override.
pyconfig.initialize([
    "",
    "configs/base.yml",
    *(f"{name}={value}" for name, value in extra_config.items()),
])
config = pyconfig.config

# Last expression of the cell: display the resolved configuration.
pyconfig._config.keys

OrderedDict([('run_name', 'finetuning-dummy_v1'),
             ('load_parameters_path', ''),
             ('load_from_other_directory',
              'gs://mazumdera-test-bucket/maxtext/lg/11032023/1/1xv3-8/checkpoints/'),
             ('load_from_other_directory_step', 85),
             ('reuse_example_batch', 0),
             ('metrics_file', ''),
             ('gcs_metrics', False),
             ('dtype', dtype(bfloat16)),
             ('int8_training', False),
             ('global_parameter_scale', 1),
             ('base_emb_dim', 512),
             ('base_num_heads', 4),
             ('base_mlp_dim', 2048),
             ('base_num_decoder_layers', 1),
             ('head_dim', 96),
             ('mlp_activations', ['gelu']),
             ('dropout_rate', 0),
             ('logits_via_embedding', True),
             ('remat_policy', 'full'),
             ('scan_layers', True),
             ('param_scan_axis', 1),
             ('enable_flash_attention', False),
             ('reco

In [5]:
import train
import jax

# Run the training loop with the `config` object produced by pyconfig.
# The shutdown call is placed in `finally` so that it runs even when
# train_loop raises or the cell is interrupted, not only after a clean finish.
try:
    train.train_loop(config)
finally:
    jax.distributed.shutdown()

Creating checkpoint manager...
Checkpoint manager created!
Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] (num_devices: 4)
Decided on mesh: [[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
   TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0)
   TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0)
   TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]]]
Trying to read training data from path: gs://yejingxin-us-central2/external/lg/dummy-data/train/*.tfrecords
Training dataset has: 500000 entries
Trying to read eval data from path: gs://yejingxin-us-central2/external/lg/dummy-data/valid/*tfrecords
Eval dataset has: 50000 entries
restoring state from this run's directory latest step         900
number param