diff --git a/README.md b/README.md
index 3aa8f48..7b7b083 100644
--- a/README.md
+++ b/README.md
@@ -245,6 +245,7 @@
 问题4:为什么不对模型做全量预训练而是用LoRA?
 问题5:二代模型支不支持某些支持一代LLaMA的工具?
 问题6:Chinese-Alpaca-2是Llama-2-Chat训练得到的吗?
+问题7:为什么24G显存微调chinese-alpaca-2-7b OOM?
 ```
diff --git a/README_EN.md b/README_EN.md
index 60a3a27..51cee13 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -228,6 +228,7 @@ Question 3: Do you accept third-party Pull Requests?
 Question 4: Why not perform full pre-training but use LoRA instead?
 Question 5: Does Llama-2 series support tools that support the first-gen LLaMA?
 Question 6: Is Chinese-Alpaca-2 trained from Llama-2-Chat?
+Question 7: Why does training with 24GB VRAM lead to an OOM error when fine-tuning chinese-alpaca-2-7b?
 ```
 
 For specific questions and answers, please refer to the project >>> [📚 GitHub Wiki](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/faq_en)
diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index b64af6e..d98373e 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -109,7 +109,8 @@
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
 from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
-apply_attention_patch(use_memory_efficient_attention=True)
+if not args.only_cpu:
+    apply_attention_patch(use_memory_efficient_attention=True)
 apply_ntk_scaling_patch(args.alpha)
 
 # Set CUDA devices if available
@@ -192,7 +193,7 @@ def setup():
             args.lora_model,
             torch_dtype=load_type,
             device_map='auto',
-        )
+        ).half()
     else:
         model = base_model
diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py
index a1acb7d..a7d7295 100644
--- a/scripts/inference/inference_hf.py
+++ b/scripts/inference/inference_hf.py
@@ -63,7 +63,8 @@
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
 from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
-apply_attention_patch(use_memory_efficient_attention=True)
+if not args.only_cpu:
+    apply_attention_patch(use_memory_efficient_attention=True)
 apply_ntk_scaling_patch(args.alpha)
 
 if args.use_vllm:
@@ -131,7 +132,7 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT):
         base_model.resize_token_embeddings(tokenizer_vocab_size)
     if args.lora_model is not None:
         print("loading peft model")
-        model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',)
+        model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',).half()
     else:
         model = base_model
diff --git a/scripts/training/build_dataset.py b/scripts/training/build_dataset.py
index 953a6d2..9fd1fdb 100644
--- a/scripts/training/build_dataset.py
+++ b/scripts/training/build_dataset.py
@@ -62,7 +62,7 @@ def tokenization(examples):
 
         if data_cache_dir is None:
             data_cache_dir = str(os.path.dirname(file))
-        cache_path = os.path.join(data_cache_dir,os.path.basename(file).split('.')[0])
+        cache_path = os.path.join(data_cache_dir,os.path.basename(file).split('.')[0]+f"_{max_seq_length}")
         os.makedirs(cache_path, exist_ok=True)
         try:
             processed_dataset = datasets.load_from_disk(cache_path)
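Not part of the patch itself: a minimal sketch of the loading path the two inference-script changes produce, under the assumption that the memory-efficient attention patch only helps on CUDA and that PEFT may keep some adapter weights in fp32 (hence the trailing `.half()`). The model and adapter paths are placeholders.

```python
# Sketch of the inference-side loading flow after this change (assumptions noted above).
import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

only_cpu = not torch.cuda.is_available()      # stands in for args.only_cpu
load_type = torch.float32 if only_cpu else torch.float16

if not only_cpu:
    # repo-local helper (scripts/attn_and_long_ctx_patches.py must be on sys.path)
    from attn_and_long_ctx_patches import apply_attention_patch
    apply_attention_patch(use_memory_efficient_attention=True)

base_model = LlamaForCausalLM.from_pretrained(
    "path/to/chinese-llama-2-7b",             # placeholder
    torch_dtype=load_type,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    base_model,
    "path/to/lora-adapter",                   # placeholder
    torch_dtype=load_type,
    device_map="auto",
)
if not only_cpu:
    model = model.half()                      # cast any fp32 adapter weights down to fp16
model.eval()
```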
diff --git a/scripts/training/run_clm_pt_with_peft.py b/scripts/training/run_clm_pt_with_peft.py
index 96f403c..cd36b7a 100644
--- a/scripts/training/run_clm_pt_with_peft.py
+++ b/scripts/training/run_clm_pt_with_peft.py
@@ -467,13 +467,13 @@ def group_texts(examples):
         for idx, file in enumerate(files):
             data_file = os.path.join(path, file)
             filename = ''.join(file.split(".")[:-1])
-            cache_path = os.path.join(data_args.data_cache_dir, filename)
+            cache_path = os.path.join(data_args.data_cache_dir, filename+f"_{block_size}")
             os.makedirs(cache_path, exist_ok=True)
             try:
                 processed_dataset = datasets.load_from_disk(cache_path, keep_in_memory=False)
                 logger.info(f'training datasets-{filename} has been loaded from disk')
             except Exception:
-                cache_dir = os.path.join(data_args.data_cache_dir, filename+"_text")
+                cache_dir = os.path.join(data_args.data_cache_dir, filename+f"_text_{block_size}")
                 os.makedirs(cache_dir, exist_ok=True)
                 raw_dataset = load_dataset("text", data_files=data_file, cache_dir=cache_dir, keep_in_memory=False)
                 logger.info(f"{file} has been loaded")
@@ -503,7 +503,6 @@ def group_texts(examples):
             else:
                 assert lm_datasets.features.type == processed_dataset["train"].features.type
                 lm_datasets = concatenate_datasets([lm_datasets, processed_dataset["train"]])
-        lm_datasets = lm_datasets.train_test_split(test_size = data_args.validation_split_percentage)
 
     if training_args.do_train:
@@ -522,26 +521,24 @@ def group_texts(examples):
         logger.info(f"Num eval_samples  {len(eval_dataset)}")
         logger.info("Evaluation example:")
         logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
-    if model_args.model_name_or_path:
-        torch_dtype = (
-            model_args.torch_dtype
-            if model_args.torch_dtype in ["auto", None]
-            else getattr(torch, model_args.torch_dtype)
-        )
-        model = LlamaForCausalLM.from_pretrained(
-            model_args.model_name_or_path,
-            from_tf=bool(".ckpt" in model_args.model_name_or_path),
-            config=config,
-            cache_dir=model_args.cache_dir,
-            revision=model_args.model_revision,
-            use_auth_token=True if model_args.use_auth_token else None,
-            torch_dtype=torch_dtype,
-            low_cpu_mem_usage=True
-        )
-    else:
-        model = AutoModelForCausalLM.from_config(config)
-        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
-        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+    device_map = {"":int(os.environ.get("LOCAL_RANK") or 0)}
+    model = LlamaForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        device_map=device_map
+    )
+    model.config.use_cache = False
 
     model_vocab_size = model.get_output_embeddings().weight.size(0)
     tokenizer_vocab_size = len(tokenizer)
@@ -555,7 +552,7 @@ def group_texts(examples):
 
     if training_args.peft_path is not None:
         logger.info("Peft from pre-trained model")
-        model = PeftModel.from_pretrained(model, training_args.peft_path)
+        model = PeftModel.from_pretrained(model, training_args.peft_path, device_map=device_map)
     else:
         logger.info("Init new peft model")
         target_modules = training_args.trainable.split(',')
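Not part of the patch itself: the core of the change above is loading the base model directly onto each process's own GPU rather than relying on later placement. A stripped-down sketch of that placement logic follows; the model path is a placeholder.

```python
# Sketch of the single-GPU-per-process placement used in the new loading code.
# Under torchrun, each worker sees its LOCAL_RANK; mapping the empty module name ""
# sends the whole model tree to that worker's GPU at load time.
import os
import torch
from transformers import LlamaForCausalLM

local_rank = int(os.environ.get("LOCAL_RANK") or 0)
device_map = {"": local_rank}                  # "" = the root module, i.e. the entire model

model = LlamaForCausalLM.from_pretrained(
    "path/to/hf/llama-2/dir",                  # placeholder, as in run_pt.sh
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map=device_map,
)
model.config.use_cache = False                 # KV cache is only needed for generation, not training
```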
diff --git a/scripts/training/run_clm_sft_with_peft.py b/scripts/training/run_clm_sft_with_peft.py
index 4daf208..fea0879 100644
--- a/scripts/training/run_clm_sft_with_peft.py
+++ b/scripts/training/run_clm_sft_with_peft.py
@@ -51,7 +51,6 @@
 from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
-IGNORE_INDEX = -100
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -295,7 +294,7 @@ def main():
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
 
-    if (len(tokenizer))!=55296:
+    if (len(tokenizer)) != 55296:
         raise ValueError(f"The vocab size of the tokenizer should be 55296, but found {len(tokenizer)}.\n"
                          "Please use Chinese-LLaMA-2 tokenizer.")
@@ -331,26 +330,24 @@ def main():
         logger.info("Evaluation example:")
         logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
 
-    if model_args.model_name_or_path:
-        torch_dtype = (
-            model_args.torch_dtype
-            if model_args.torch_dtype in ["auto", None]
-            else getattr(torch, model_args.torch_dtype)
-        )
-        model = LlamaForCausalLM.from_pretrained(
-            model_args.model_name_or_path,
-            from_tf=bool(".ckpt" in model_args.model_name_or_path),
-            config=config,
-            cache_dir=model_args.cache_dir,
-            revision=model_args.model_revision,
-            use_auth_token=True if model_args.use_auth_token else None,
-            torch_dtype=torch_dtype,
-            low_cpu_mem_usage=True
-        )
-    else:
-        model = AutoModelForCausalLM.from_config(config)
-        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
-        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+    device_map = {"":int(os.environ.get("LOCAL_RANK") or 0)}
+    model = LlamaForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        device_map=device_map
+    )
+    model.config.use_cache = False
 
     model_vocab_size = model.get_input_embeddings().weight.shape[0]
     logger.info(f"Model vocab size: {model_vocab_size}")
@@ -361,7 +358,7 @@ def main():
 
     if training_args.peft_path is not None:
         logger.info("Peft from pre-trained model")
-        model = PeftModel.from_pretrained(model, training_args.peft_path)
+        model = PeftModel.from_pretrained(model, training_args.peft_path, device_map=device_map)
     else:
         logger.info("Init new peft model")
         target_modules = training_args.trainable.split(',')
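Not part of the patch itself: for readers following the "Init new peft model" branch, a rough sketch of the adapter both training scripts construct, under the assumption that the `run_*.sh` variables map one-to-one onto `LoraConfig` fields; the base-model path is a placeholder.

```python
# Sketch of the LoRA setup fed by the run_*.sh variables: attention and MLP projections
# receive LoRA weights, while the resized embed_tokens/lm_head remain fully trainable
# via modules_to_save.
import torch
from transformers import LlamaForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base_model = LlamaForCausalLM.from_pretrained(
    "path/to/chinese-llama-2-7b",              # placeholder
    torch_dtype=torch.float16,
)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,                                      # lora_rank
    lora_alpha=128,                            # lora_alpha
    lora_dropout=0.05,                         # lora_dropout
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()             # LoRA + embeddings: a small fraction of the 7B weights
```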
diff --git a/scripts/training/run_pt.sh b/scripts/training/run_pt.sh
index b409eac..663a326 100644
--- a/scripts/training/run_pt.sh
+++ b/scripts/training/run_pt.sh
@@ -1,3 +1,5 @@
+# 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh)
+# Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh) carefully before running the script
 lr=2e-4
 lora_rank=64
 lora_alpha=128
@@ -6,12 +8,12 @@ modules_to_save="embed_tokens,lm_head"
 lora_dropout=0.05
 
 pretrained_model=path/to/hf/llama-2/dir
-chinese_tokenizer_path=path/to/chinese/llama-2/tokenizer/dir
+chinese_tokenizer_path=path/to/chinese-llama-2/tokenizer/dir
 dataset_dir=path/to/pt/data/dir
 data_cache=temp_data_cache_dir
 per_device_train_batch_size=1
-per_device_eval_batch_size=1
 gradient_accumulation_steps=8
+block_size=512
 output_dir=output_dir
 
 deepspeed_config_file=ds_zero2_no_offload.json
@@ -24,7 +26,6 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_pt_with_peft.py \
     --data_cache_dir ${data_cache} \
     --validation_split_percentage 0.001 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
-    --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --do_train \
     --seed $RANDOM \
     --fp16 \
@@ -40,7 +41,7 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_pt_with_peft.py \
     --save_steps 200 \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
     --preprocessing_num_workers 8 \
-    --block_size 1024 \
+    --block_size ${block_size} \
     --output_dir ${output_dir} \
     --overwrite_output_dir \
     --ddp_timeout 30000 \
@@ -48,8 +49,6 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_pt_with_peft.py \
     --lora_rank ${lora_rank} \
     --lora_alpha ${lora_alpha} \
     --trainable ${lora_trainable} \
-    --modules_to_save ${modules_to_save} \
     --lora_dropout ${lora_dropout} \
-    --torch_dtype float16 \
-    --gradient_checkpointing \
-    --ddp_find_unused_parameters False
+    --modules_to_save ${modules_to_save} \
+    --torch_dtype float16
diff --git a/scripts/training/run_sft.sh b/scripts/training/run_sft.sh
index 0c31a8b..a74986d 100644
--- a/scripts/training/run_sft.sh
+++ b/scripts/training/run_sft.sh
@@ -1,3 +1,5 @@
+# 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh)
+# Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh) carefully before running the script
 lr=1e-4
 lora_rank=64
 lora_alpha=128
@@ -5,14 +7,14 @@ lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
 modules_to_save="embed_tokens,lm_head"
 lora_dropout=0.05
 
-pretrained_model=path/to/hf/llama-2/or/merged/llama-2/dir/or/model_id
-chinese_tokenizer_path=path/to/chinese/llama-2/tokenizer/dir
+pretrained_model=path/to/hf/llama-2/or/chinese-llama-2/dir/or/model_id
+chinese_tokenizer_path=path/to/chinese-llama-2/tokenizer/dir
 dataset_dir=path/to/sft/data/dir
 per_device_train_batch_size=1
 per_device_eval_batch_size=1
 gradient_accumulation_steps=8
+max_seq_length=512
 output_dir=output_dir
-peft_model=path/to/peft/model/dir
 validation_file=validation_file_name
 
 deepspeed_config_file=ds_zero2_no_offload.json
@@ -22,7 +24,6 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
     --model_name_or_path ${pretrained_model} \
    --tokenizer_name_or_path ${chinese_tokenizer_path} \
     --dataset_dir ${dataset_dir} \
-    --validation_split_percentage 0.001 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --do_train \
@@ -43,7 +44,7 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
     --save_steps 200 \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
     --preprocessing_num_workers 8 \
-    --max_seq_length 1024 \
+    --max_seq_length ${max_seq_length} \
     --output_dir ${output_dir} \
     --overwrite_output_dir \
     --ddp_timeout 30000 \
@@ -51,10 +52,7 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
     --lora_rank ${lora_rank} \
     --lora_alpha ${lora_alpha} \
     --trainable ${lora_trainable} \
-    --modules_to_save ${modules_to_save} \
     --lora_dropout ${lora_dropout} \
+    --modules_to_save ${modules_to_save} \
     --torch_dtype float16 \
-    --validation_file ${validation_file} \
-    --peft_path ${peft_model} \
-    --gradient_checkpointing \
-    --ddp_find_unused_parameters False
+    --validation_file ${validation_file}
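Not part of the patch itself: since `block_size` and `max_seq_length` are now plain variables at the top of the shell scripts, it is worth noting how they interact with the cache-path changes earlier in this patch: the tokenized-dataset cache directory is keyed on the sequence length, so changing the length triggers re-tokenization instead of silently reusing a stale cache. A tiny sketch with placeholder names:

```python
# Sketch of the length-keyed cache naming used by build_dataset.py / run_clm_pt_with_peft.py.
import os

data_cache_dir = "temp_data_cache_dir"         # data_cache in run_pt.sh
file = "pt_sample_data.txt"                    # hypothetical input file name
block_size = 512                               # block_size / max_seq_length in the scripts

filename = "".join(file.split(".")[:-1])
cache_path = os.path.join(data_cache_dir, filename + f"_{block_size}")
os.makedirs(cache_path, exist_ok=True)
print(cache_path)                              # -> temp_data_cache_dir/pt_sample_data_512
```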