
Commit 988ad5d

[AutoParallel] add use custom op test in pipeline mode (#10653)
* add use rms norm ci test in pp mode
* fix loss base
* fix base
1 parent 062debf · commit 988ad5d

File tree

1 file changed: +60 −57 lines changed


scripts/distribute/ci_case_auto.sh

Lines changed: 60 additions & 57 deletions
@@ -472,64 +472,67 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     task_name="llama_auto_bs8_dp2mp2pp2"
     case_out_dir="output/$task_name"
     case_log_dir="output/$task_name""_log"
-    rm -rf $case_out_dir
-    rm -rf $case_log_dir

-    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
-        --model_type "llama" \
-        --model_name_or_path "facebook/llama-7b" \
-        --tokenizer_name_or_path "facebook/llama-7b" \
-        --input_dir "./data" \
-        --output_dir $case_out_dir \
-        --split 949,50,1 \
-        --max_seq_length 2048 \
-        --hidden_size 1024 \
-        --intermediate_size 3072 \
-        --num_hidden_layers 8 \
-        --num_attention_heads 32 \
-        --per_device_train_batch_size 1 \
-        --per_device_eval_batch_size 4 \
-        --gradient_accumulation_steps 4 \
-        --use_flash_attention 0 \
-        --use_fused_rms_norm 0 \
-        --fp16 0 \
-        --fp16_opt_level "O2" \
-        --scale_loss 1024 \
-        --pipeline_parallel_degree 2 \
-        --tensor_parallel_degree 2 \
-        --sharding_parallel_degree 1 \
-        --learning_rate 0.0001 \
-        --min_learning_rate 0.00001 \
-        --max_steps 10 \
-        --save_steps 5000 \
-        --weight_decay 0.01 \
-        --warmup_ratio 0.01 \
-        --logging_steps 1 \
-        --dataloader_num_workers 1 \
-        --sharding "" \
-        --eval_steps 1000000 \
-        --disable_tqdm true \
-        --continue_training 0 \
-        --recompute 0 \
-        --do_train \
-        --do_eval \
-        --device "gpu" \
-        --data_impl "mmap" \
-        --enable_auto_parallel 1 \
-        --to_static 0 \
-        --max_grad_norm 1.0 \
-        >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
-    ips=-1
-    mem=-1
-    echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.3513937
-    if [ $IS_A100 -ne 0 ];then
-        loss_base=9.39356422
-    fi
-    ips_base=-1
-    mem_base=-1
-    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    for use_fused_rms_norm in "1" "0"; do
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+
+        python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --max_seq_length 2048 \
+            --hidden_size 1024 \
+            --intermediate_size 3072 \
+            --num_hidden_layers 8 \
+            --num_attention_heads 32 \
+            --per_device_train_batch_size 1 \
+            --per_device_eval_batch_size 4 \
+            --gradient_accumulation_steps 4 \
+            --use_flash_attention 0 \
+            --use_fused_rms_norm ${use_fused_rms_norm} \
+            --fp16 0 \
+            --fp16_opt_level "O2" \
+            --scale_loss 1024 \
+            --pipeline_parallel_degree 2 \
+            --tensor_parallel_degree 2 \
+            --sharding_parallel_degree 1 \
+            --learning_rate 0.0001 \
+            --min_learning_rate 0.00001 \
+            --max_steps 10 \
+            --save_steps 5000 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --logging_steps 1 \
+            --dataloader_num_workers 1 \
+            --sharding "" \
+            --eval_steps 1000000 \
+            --disable_tqdm true \
+            --continue_training 0 \
+            --recompute 0 \
+            --do_train \
+            --do_eval \
+            --device "gpu" \
+            --data_impl "mmap" \
+            --enable_auto_parallel 1 \
+            --to_static 0 \
+            --max_grad_norm 1.0 \
+            >>${log_path}/$FUNCNAME 2>&1
+        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        ips=-1
+        mem=-1
+        echo "use_fused_rms_norm=$use_fused_rms_norm result: loss=$loss ips=$ips mem=$mem"
+        loss_base=9.3513937
+        if [ $IS_A100 -ne 0 ];then
+            loss_base=9.39356422
+        fi
+        ips_base=-1
+        mem_base=-1
+        check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    done
     echo "=========== $FUNCNAME run end ==========="
 }

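The change itself is mechanical: the single training run is wrapped in a "for use_fused_rms_norm in "1" "0"; do ... done" loop, so the DP2-MP2-PP2 case is exercised both with and without the fused RMSNorm custom op, and the loss reported at global_step: 10 is extracted from the rank-0 worker log and compared against the same baseline in both passes. Below is a minimal, standalone sketch of that pattern; the parse_loss helper, the log path, and the placeholder launch line are illustrative stand-ins rather than code from the repository.

#!/usr/bin/env bash
# Sketch: re-run one CI case per flag value and read back the step-10 loss.
set -euo pipefail

log_dir="output/llama_auto_bs8_dp2mp2pp2_log"   # assumed log location
loss_base=9.3513937                             # baseline from the CI case

# Pull the loss printed at global_step: 10 out of a worker log, mirroring the
# grep/awk pipeline used in ci_case_auto.sh. The helper name is hypothetical.
parse_loss() {
    grep 'global_step: 10' "$1" | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'
}

for use_fused_rms_norm in "1" "0"; do
    rm -rf "$log_dir"
    # Placeholder for the real `python -m paddle.distributed.launch ...` command,
    # which would receive --use_fused_rms_norm ${use_fused_rms_norm}.
    echo "would launch training with --use_fused_rms_norm ${use_fused_rms_norm}"
    # After a real run, the comparison would look like:
    # loss=$(parse_loss "$log_dir/workerlog.0")
    # echo "use_fused_rms_norm=${use_fused_rms_norm} loss=${loss} (baseline ${loss_base})"
done

Because both iterations reuse the same loss_base, the test implicitly asserts that the fused RMSNorm custom op is numerically equivalent to the unfused path under pipeline parallelism, not just that it runs.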