diff --git a/README.md b/README.md
index 94343d7..c513235 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,15 @@ python pretrained/GPT-NeoX-20B/prepare.py
 
 The weights for this model will be in the `pretrained/GPT-NeoX-20B/EleutherAI_gpt-neox-20b` directory.
 
+If you want to fine-tune other GPT-NeoX models, e.g. [the Pythia model suite](https://huggingface.co/models?sort=downloads&search=pythia), specify the HF model name:
+
+```shell
+python pretrained/GPT-NeoX-20B/prepare.py --model-name EleutherAI/pythia-6.9b-deduped
+```
+
+The weights for this model will be in the `pretrained/GPT-NeoX-20B/EleutherAI_pythia-6.9b-deduped` directory.
+
+
 # Training and Finetuning
 
 ## (Optional) 8bit Adam
@@ -113,21 +122,38 @@ As the training loop runs, checkpoints are saved to the `model_ckpts` directory
 
 Please see [the training README](training/README.md) for more details about customizing the training run.
 
+The `training/finetune_Pythia-Chat-Base-7B.sh` script is another example, fine-tuning a 7B Pythia (GPT-NeoX) model. The script launches 8 processes with a pipeline-parallel degree of 4 and a data-parallel degree of 2.
+
 # Converting Weights to Huggingface Format
 
 Before you can use this model to perform inference, it must be converted to the Huggingface format. Run this command from the root of the repo to do so.
 
 ```shell
-mkdir huggingface_models \
-  && python tools/convert_to_hf_gptneox.py \
+mkdir huggingface_models \
+  && python tools/convert_to_hf_gptneox.py \
     --ckpt-path model_ckpts/GPT-Neo-XT-Chat-Base-20B/checkpoint_100 \
     --save-path huggingface_models/GPT-NeoXT-Chat-Base-20B \
     --n-stages 8 \
-    --n-layer-per-stage 6
+    --n-layer-per-stage 6 \
+    --fp16
 ```
+The `--fp16` flag loads and stores the model in fp16. Make sure to replace `model_ckpts/GPT-Neo-XT-Chat-Base-20B/checkpoint_100` with the latest checkpoint in the `model_ckpts/GPT-Neo-XT-Chat-Base-20B` directory.
+To convert checkpoints of other GPT-NeoX variants, make sure to specify the correct config name for your variant.
+For example, to convert a checkpoint fine-tuned from `EleutherAI/pythia-6.9b-deduped`, pass it as the config name:
+```shell
+python tools/convert_to_hf_gptneox.py \
+    --config-name EleutherAI/pythia-6.9b-deduped \
+    --ckpt-path model_ckpts/Pythia-Chat-Base-7B/checkpoint_100 \
+    --save-path huggingface_models/Pythia-Chat-Base-7B \
+    --n-stages 4 \
+    --n-layer-per-stage 8 \
+    --fp16
+```
+
+
 # Inference
 
 To help you test the model, we provide a simple command line test harness to interact with the bot.
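The README changes above hinge on three values agreeing with each other: the model chosen in `prepare.py`, `--n-stages`, and `--n-layer-per-stage`. As a quick pre-flight check before converting, something like the following sketch (not part of this PR; the helper name is made up for illustration) verifies that a proposed stage layout leaves room for every transformer layer of the chosen config:

```python
# Sketch (not part of this PR): check that an --n-stages / --n-layer-per-stage
# pair leaves room for every transformer layer of the chosen model. The helper
# name check_stage_layout is hypothetical.
from transformers import AutoConfig

def check_stage_layout(model_name: str, n_stages: int, n_layer_per_stage: int) -> None:
    config = AutoConfig.from_pretrained(model_name)
    n_layers = config.num_hidden_layers
    assert n_stages * n_layer_per_stage >= n_layers, (
        f"{n_stages} stages x {n_layer_per_stage} layers/stage "
        f"cannot hold all {n_layers} layers of {model_name}"
    )

check_stage_layout("EleutherAI/gpt-neox-20b", 8, 6)          # 48 slots >= 44 layers
check_stage_layout("EleutherAI/pythia-6.9b-deduped", 4, 8)   # 32 slots >= 32 layers
```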
diff --git a/pretrained/GPT-NeoX-20B/prepare.py b/pretrained/GPT-NeoX-20B/prepare.py
index d30381d..a63e123 100644
--- a/pretrained/GPT-NeoX-20B/prepare.py
+++ b/pretrained/GPT-NeoX-20B/prepare.py
@@ -22,11 +22,11 @@
 if not os.path.exists(save_path):
     os.mkdir(save_path)
 
+print('loading model from HF...')
 config = AutoConfig.from_pretrained(args.model_name)
 config.save_pretrained(save_path)
 tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 tokenizer.save_pretrained(save_path)
-
 # offload model from memory to disk if offload-dir is specified
 if args.offload_dir is not None:
     if not os.path.exists(args.offload_dir):
@@ -34,16 +34,23 @@
     model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, device_map="auto", offload_folder=args.offload_dir)
 else:
     model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16)
+print('loaded model from HF...')
 
+print('converting the embedding layer...')
 item = {}
 item['embed_in.weight'] = model.gpt_neox.embed_in.weight
 torch.save(item, os.path.join(save_path, 'pytorch_embs.pt'))
+print('converted the embedding layer.')
 
 for i in range(len(model.gpt_neox.layers)):
+    print(f'converting the {i}-th transformer layer...')
     torch.save(model.gpt_neox.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))
+    print(f'converted the {i}-th transformer layer.')
 
+print('converting the lm_head layer...')
 item = {}
 item['embed_out.weight'] = model.embed_out.weight
 item['final_layer_norm.weight'] = model.gpt_neox.final_layer_norm.weight
 item['final_layer_norm.bias'] = model.gpt_neox.final_layer_norm.bias
 torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))
+print('converted the lm_head layer.')
\ No newline at end of file
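For reference, `prepare.py` leaves behind one file per transformer block plus two extra shards (embeddings and lm head). A short sketch of what loading them back looks like; the `save_path` matches the Pythia example in the README, and nothing here is part of the PR itself:

```python
# Sketch: inspect the shards prepare.py writes out.
import os
import torch

save_path = "pretrained/GPT-NeoX-20B/EleutherAI_pythia-6.9b-deduped"

# Token embedding, one file: {'embed_in.weight': tensor}
embs = torch.load(os.path.join(save_path, "pytorch_embs.pt"), map_location="cpu")
print(list(embs.keys()))

# One state_dict per transformer block: pytorch_0.pt ... pytorch_31.pt for pythia-6.9b
layer0 = torch.load(os.path.join(save_path, "pytorch_0.pt"), map_location="cpu")
print(sorted(layer0.keys())[:4])

# LM head + final layer norm in one file
head = torch.load(os.path.join(save_path, "pytorch_lm_head.pt"), map_location="cpu")
print(list(head.keys()))  # embed_out.weight, final_layer_norm.{weight,bias}
```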
diff --git a/tools/convert_to_hf_gptneox.py b/tools/convert_to_hf_gptneox.py
index 07132ad..751d092 100644
--- a/tools/convert_to_hf_gptneox.py
+++ b/tools/convert_to_hf_gptneox.py
@@ -56,13 +56,14 @@ def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_pe
 
         elif i == n_stages - 1:
             for j in range(n_layer_per_stage):
-                if i*n_layer_per_stage + j == 44:
-                    break
                 _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
                 if len(_tmp) == 0:
                     break
                 # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
                 model.gpt_neox.layers[i*n_layer_per_stage + j].load_state_dict(_tmp)
+                if i*n_layer_per_stage + j == len(model.gpt_neox.layers) - 1:
+                    j += 1
+                    break
 
             _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
             if len(_tmp) == 0:
@@ -88,14 +89,17 @@ def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_pe
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Convert HF checkpoints')
+    parser.add_argument('--config-name', type=str, default='EleutherAI/gpt-neox-20b',
+                        help='config-name')
     parser.add_argument('--ckpt-path', type=str, default=None,
-                        help='model-name')
+                        help='ckpt-path')
     parser.add_argument('--save-path', type=str, default=None,
-                        help='model-name')
+                        help='save-path')
     parser.add_argument('--n-stages', type=int, default=8,
                         help='pipeline group size')
     parser.add_argument('--n-layer-per-stage', type=int, default=6,
                         help='n layers per GPU device')
+    parser.add_argument('--fp16', default=False, action='store_true')
 
     args = parser.parse_args()
     assert args.ckpt_path is not None
@@ -104,13 +108,26 @@ def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_pe
 
     if not os.path.exists(args.save_path):
         os.mkdir(args.save_path)
 
-    config = AutoConfig.from_pretrained('EleutherAI/gpt-neox-20b')
-    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
+    print('loading config...')
+    config = AutoConfig.from_pretrained(args.config_name)
+    print('loaded config.')
+    print('loading tokenizer...')
+    tokenizer = AutoTokenizer.from_pretrained(args.config_name)
+    print('loaded tokenizer.')
+    print('creating empty model...')
     model = create_empty_gptneox(config)
+    if args.fp16:
+        model = model.half()
+    print('created empty model.')
+    print('loading model ckpt...')
     load_decentralized_checkpoint(
         model, args.ckpt_path, n_stages=args.n_stages, n_layer_per_stage=args.n_layer_per_stage,
     )
+    print('loaded model ckpt.')
+    print('saving HF model...')
     model.save_pretrained(args.save_path)
+    print(f'saved HF model to `{args.save_path}`')
     config.save_pretrained(args.save_path)
     tokenizer.save_pretrained(args.save_path)
+
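The loader fix above works because of how each pipeline stage's checkpoint is keyed: transformer blocks sit under integer prefixes `0.`, `1.`, ..., and on the last stage the `embed_out`/`final_layer_norm` tensors sit under the next prefix after the final block, which is why `j` is advanced once more before breaking instead of hardcoding layer 44. A minimal illustration of that layout; the helper and the toy checkpoint are invented for this sketch, not repo code:

```python
import torch

def split_last_stage(checkpoint: dict):
    """Illustrative: peel integer-prefixed groups off a last-stage checkpoint.

    Blocks are keyed "0.", "1.", ...; the final non-empty group holds the
    lm head / final layer norm rather than a transformer block.
    """
    groups, j = [], 0
    while True:
        prefix = f"{j}."
        group = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
        if not group:
            break
        groups.append(group)
        j += 1
    return groups[:-1], groups[-1]  # (transformer blocks, head tensors)

# Toy stand-in for torch.load(...) of a last-stage file with one block:
ckpt = {
    "0.attention.dense.weight": torch.zeros(4, 4),
    "1.embed_out.weight": torch.zeros(8, 4),
    "1.final_layer_norm.weight": torch.ones(4),
}
blocks, head = split_last_stage(ckpt)
assert len(blocks) == 1 and "embed_out.weight" in head
```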
diff --git a/training/finetune_Pythia-Chat-Base-7B.sh b/training/finetune_Pythia-Chat-Base-7B.sh
new file mode 100644
index 0000000..92a92e0
--- /dev/null
+++ b/training/finetune_Pythia-Chat-Base-7B.sh
@@ -0,0 +1,83 @@
+DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+
+netif=lo
+export GLOO_SOCKET_IFNAME=${netif}
+export NCCL_SOCKET_IFNAME=${netif}
+export MODEL_NAME=Pythia-Chat-Base-7B
+
+export SHOW_DATA=0
+
+BASE_MODEL="${DIR}/../pretrained/GPT-NeoX-20B/EleutherAI_pythia-6.9b-deduped/"
+
+CHECKPOINT_STEPS=100
+
+DATASETS="\
+${DIR}/../data/OIG/files/unified_ni.jsonl:0.2,\
+${DIR}/../data/OIG/files/unified_p3.jsonl:0.5,\
+${DIR}/../data/OIG/files/unified_flan.jsonl:0.2,\
+${DIR}/../data/OIG/files/unified_chip2.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_rallio_safety_and_prosocial.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_soda_dialog.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_unifiedskg_instructions.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_merged_code_xp3.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_oscar_en_sample_dialog.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_ul2_plus_oscar_en_sample_dialog.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_multi_news.jsonl:0.05,\
+${DIR}/../data/OIG/files/unified_openai_summarize_tldr.jsonl:0.05,\
+${DIR}/../data/OIG/files/unified_squad_v2.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_nq.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_poetry_instructions.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_sqlv2.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_unnatural_instructions.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_conv_finqa.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_essays.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_plot_screenplay_books_dialog.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_grade_school_math_instructions.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_mathqa_flanv2_kojma_cot.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_joke_explanations.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_cuad.jsonl:0.01,\
+${DIR}/../data/OIG/files/unified_abstract_infill.jsonl:0.1,\
+${DIR}/../data/OIG/files/unified_image_prompts_instructions.jsonl:0.01 \
+"
+
+ARGS="--model-name ${BASE_MODEL} \
+--tokenizer-name ${BASE_MODEL} \
+--project-name together \
+--model-type gptneox \
+--optimizer adam \
+--seed 42 \
+--load-pretrained-model true \
+--task-name \
+"${DATASETS}" \
+--checkpoint-path ${DIR}/../model_ckpts/${MODEL_NAME} \
+--total-steps 20000 --warmup-steps 10 --train-warmup-steps 0 \
+--checkpoint-steps ${CHECKPOINT_STEPS} \
+--lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \
+--dist-url tcp://127.0.0.1:7033 \
+--num-layers 8 --embedding-dim 4096 \
+--world-size 8 --pipeline-group-size 4 --data-group-size 2 \
+--job-id 0 --net-interface ${netif} \
+--fp16 \
+--dp-backend nccl \
+--dp-mode allreduce \
+--pp-mode gpipe --profiling no-profiling"
+
+
+(trap 'kill 0' SIGINT; \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \
+    & \
+python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \
+    & \
+wait)
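The launcher above starts one process per GPU; with `--world-size 8 --pipeline-group-size 4 --data-group-size 2`, the eight ranks form two data-parallel replicas of a four-stage pipeline, and 4 stages x `--num-layers 8` covers Pythia-6.9B's 32 transformer layers. The actual rank assignment lives in `dist_clm_train.py`; the sketch below only illustrates one plausible mapping implied by those flags:

```python
# Sketch: one plausible rank -> (pipeline stage, data-parallel replica) layout
# for --world-size 8 --pipeline-group-size 4 --data-group-size 2. The real
# mapping is decided inside dist_clm_train.py; this only shows the arithmetic
# the flags imply.
WORLD_SIZE, PP_SIZE, DP_SIZE = 8, 4, 2
assert WORLD_SIZE == PP_SIZE * DP_SIZE

for rank in range(WORLD_SIZE):
    stage = rank % PP_SIZE          # which pipeline stage this process runs
    replica = rank // PP_SIZE       # which data-parallel replica it belongs to
    print(f"rank {rank}: pipeline stage {stage}, replica {replica}")
```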