
Commit e5858b1

[Auto-Parallel] Add acc=2 case for tensor_fusion and overlap in auto-dy (#10665)
* Update ci_case_auto.sh
* Update loss_base
* Update ci_case_auto.sh
* Update ci_case_auto.sh
1 parent 73118f6 commit e5858b1

1 file changed: +87 -78 lines changed

scripts/distribute/ci_case_auto.sh

Lines changed: 87 additions & 78 deletions
@@ -231,86 +231,95 @@ function llama_dygraph_auto_bs4_bf16_SD2() {
             export "$f=true"
         done
     fi
-
-    task_name="llama_dygraph_auto_bs4_bf16_SD2_$f"
-    case_out_dir="output/$task_name"
-    case_log_dir="output/$task_name""_log"
-    rm -rf $case_out_dir
-    rm -rf $case_log_dir
-
-    python -u -m paddle.distributed.launch \
-        --gpus "0,1" \
-        --log_dir "output/$task_name""_log" \
-        ./run_pretrain_auto.py \
-        --model_name_or_path "meta-llama/Llama-2-7b" \
-        --tokenizer_name_or_path "meta-llama/Llama-2-7b" \
-        --input_dir "./data" \
-        --output_dir "./output" \
-        --weight_decay 0.01 \
-        --warmup_ratio 0.01 \
-        --max_grad_norm 1.0 \
-        --learning_rate 3e-05 \
-        --min_learning_rate 3e-06 \
-        --max_steps 10 \
-        --logging_steps 10 \
-        --eval_steps 1000 \
-        --save_steps 50000 \
-        --continue_training 0 \
-        --do_train true \
-        --do_eval false \
-        --do_predict false \
-        --disable_tqdm true \
-        --skip_profile_timer true \
-        --device gpu \
-        --enable_auto_parallel 1 \
-        --per_device_train_batch_size 1 \
-        --gradient_accumulation_steps 1 \
-        --per_device_eval_batch_size 2 \
-        --recompute false \
-        --recompute_use_reentrant true \
-        --recompute_granularity full \
-        --pp_recompute_interval 0 \
-        --bf16 true \
-        --fp16_opt_level "O2" \
-        --amp_master_grad true \
-        --fuse_attention_ffn true \
-        --fuse_attention_qkv true \
-        --fused_linear_param_grad_add 1 \
-        --use_flash_attention true \
-        --use_fused_rope true \
-        --use_fused_rms_norm true \
-        --max_seq_length 4096 \
-        --sequence_parallel false \
-        --pipeline_parallel_degree 1 \
-        --tensor_parallel_degree 1 \
-        --sharding "stage1" \
-        --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
-        --sharding_parallel_config "" \
-        --to_static 0 \
-        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
-        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
-        --num_hidden_layers 4 \
-        >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
-    ips=-1
-    mem=-1
-    echo "result: loss=$loss ips=$ips mem=$mem"
-
-    if [ -z "$flag" ]; then
-        loss_base=9.23502579
-    elif [ "$flag" = "FLAGS_fuse_allreduce_in_opt" ]; then
-        loss_base=9.23502579
-    elif [ "$flag" = "FLAGS_fuse_reducescatter_in_opt" ]; then
-        loss_base=9.23504105
-    elif [ "$flag" = "FLAGS_enable_tensor_fusion FLAGS_enable_sharding_overlap" ]; then
-        loss_base=9.23504868
-    else
-        loss_base=-1
+    acc_steps=(1)
+    if [ "$flag" = "FLAGS_enable_tensor_fusion FLAGS_enable_sharding_overlap" ]; then
+        acc_steps=(1 2)
     fi
+    for acc_step in "${acc_steps[@]}"; do
+        task_name="llama_dygraph_auto_bs4_bf16_SD2_$f"
+        case_out_dir="output/$task_name"
+        case_log_dir="output/$task_name""_log"
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir

-    ips_base=-1
-    mem_base=-1
-    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1" \
+            --log_dir "output/$task_name""_log" \
+            ./run_pretrain_auto.py \
+            --model_name_or_path "meta-llama/Llama-2-7b" \
+            --tokenizer_name_or_path "meta-llama/Llama-2-7b" \
+            --input_dir "./data" \
+            --output_dir "./output" \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --max_grad_norm 1.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps 10 \
+            --logging_steps 10 \
+            --eval_steps 1000 \
+            --save_steps 50000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --skip_profile_timer true \
+            --device gpu \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps $acc_step \
+            --per_device_eval_batch_size 2 \
+            --recompute false \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --pp_recompute_interval 0 \
+            --bf16 true \
+            --fp16_opt_level "O2" \
+            --amp_master_grad true \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv true \
+            --fused_linear_param_grad_add 1 \
+            --use_flash_attention true \
+            --use_fused_rope true \
+            --use_fused_rms_norm true \
+            --max_seq_length 4096 \
+            --sequence_parallel false \
+            --pipeline_parallel_degree 1 \
+            --tensor_parallel_degree 1 \
+            --sharding "stage1" \
+            --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
+            --sharding_parallel_config "" \
+            --to_static 0 \
+            --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+            --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+            --num_hidden_layers 4 \
+            >>${log_path}/$FUNCNAME 2>&1
+        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        ips=-1
+        mem=-1
+        echo "result: loss=$loss ips=$ips mem=$mem"
+
+        if [ -z "$flag" ]; then
+            loss_base=9.23502579
+        elif [ "$flag" = "FLAGS_fuse_allreduce_in_opt" ]; then
+            loss_base=9.23502579
+        elif [ "$flag" = "FLAGS_fuse_reducescatter_in_opt" ]; then
+            loss_base=9.23504105
+        elif [ "$flag" = "FLAGS_enable_tensor_fusion FLAGS_enable_sharding_overlap" ]; then
+            if [ $acc_step -eq 1 ]; then
+                loss_base=9.23504868
+            else
+                loss_base=9.16484451
+            fi
+        else
+            loss_base=-1
+        fi
+
+        ips_base=-1
+        mem_base=-1
+        check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    done

     if [ -n "$flag" ]; then
         for f in $flag; do
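
For readers skimming the diff above, the behavioral change boils down to the loop sketched below. This is a minimal standalone sketch rather than the CI script itself: the launch of run_pretrain_auto.py and the check_result call are replaced by an echo, and the value assigned to flag is an assumed example; the baselines are the ones set in the hunk.

#!/bin/bash
# Sketch of the pattern added by this commit (illustration only).
# In the real script, each iteration launches run_pretrain_auto.py with
# --gradient_accumulation_steps $acc_step and compares the parsed loss
# against loss_base via check_result.
flag="FLAGS_enable_tensor_fusion FLAGS_enable_sharding_overlap"  # assumed example value

acc_steps=(1)
if [ "$flag" = "FLAGS_enable_tensor_fusion FLAGS_enable_sharding_overlap" ]; then
    acc_steps=(1 2)  # the new acc=2 case runs only for tensor fusion + sharding overlap
fi

for acc_step in "${acc_steps[@]}"; do
    if [ "$acc_step" -eq 1 ]; then
        loss_base=9.23504868  # existing baseline for acc=1 (from the diff)
    else
        loss_base=9.16484451  # new baseline introduced for acc=2 (from the diff)
    fi
    echo "acc_step=$acc_step -> expected loss_base=$loss_base"
done

With both flags set, this runs the acc=1 and acc=2 configurations; with any other flag combination it runs only the acc=1 case, matching the expanded loss_base selection in the hunk.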
