Update Training Guidelines #150

Merged: 16 commits, Aug 24, 2023
1 change: 1 addition & 0 deletions README.md
@@ -245,6 +245,7 @@
Question 4: Why not perform full pre-training on the model but use LoRA instead?
Question 5: Does the second-gen model support tools that support the first-gen LLaMA?
Question 6: Is Chinese-Alpaca-2 trained from Llama-2-Chat?
Question 7: Why does fine-tuning chinese-alpaca-2-7b with 24GB of VRAM lead to an OOM error?
```


1 change: 1 addition & 0 deletions README_EN.md
@@ -228,6 +228,7 @@ Question 3: Do you accept third-party Pull Requests?
Question 4: Why not perform full pre-training but use LoRA instead?
Question 5: Does Llama-2 series support tools that support the first-gen LLaMA?
Question 6: Is Chinese-Alpaca-2 trained from Llama-2-Chat?
Question 7: Why does fine-tuning chinese-alpaca-2-7b with 24GB of VRAM lead to an OOM error?
```

For specific questions and answers, please refer to the project >>> [📚 GitHub Wiki](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/faq_en)
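The new Question 7 concerns fine-tuning memory, and the wiki entry linked above is the authoritative answer. As a hedged illustration of the kind of lever involved (our assumption, not the wiki text): the training scripts later in this PR drop the default sequence length from 1024 to 512 tokens, and a LoRA configuration that skips full embedding/LM-head training is another common way to fit a tighter VRAM budget. A hypothetical sketch of the latter:

```python
# Hypothetical illustration only; the authoritative answer to Question 7 is in
# the wiki linked above. Training embed_tokens/lm_head via modules_to_save keeps
# optimizer state for two large matrices, which is one plausible reason a 24GB
# card runs out of memory (our assumption); a slimmer LoRA config omits them.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "down_proj", "up_proj"],
    # modules_to_save=["embed_tokens", "lm_head"],  # left out when VRAM is tight
)
```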
5 changes: 3 additions & 2 deletions scripts/inference/gradio_demo.py
@@ -109,7 +109,8 @@
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
apply_attention_patch(use_memory_efficient_attention=True)
if not args.only_cpu:
apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

# Set CUDA devices if available
@@ -192,7 +193,7 @@ def setup():
args.lora_model,
torch_dtype=load_type,
device_map='auto',
)
).half()
else:
model = base_model
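The two gradio_demo.py edits (mirrored in inference_hf.py below) follow one pattern: apply the memory-efficient attention patch only on GPU runs, and cast the loaded LoRA adapter to fp16. A minimal standalone sketch under those assumptions; the helper name and paths are placeholders, and the dtype remarks in the comments are our reading rather than the PR's wording:

```python
import torch
from transformers import LlamaForCausalLM
from peft import PeftModel
from attn_and_long_ctx_patches import apply_attention_patch  # repo-local module (see diff)

def load_for_inference(base_model_dir, lora_dir=None, only_cpu=False):
    # The memory-efficient attention patch targets GPU kernels, hence the guard.
    if not only_cpu:
        apply_attention_patch(use_memory_efficient_attention=True)
    load_type = torch.float32 if only_cpu else torch.float16  # fp16 CPU inference is poorly supported
    base_model = LlamaForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=load_type,
        device_map=None if only_cpu else "auto",
    )
    if lora_dir is not None:
        # The explicit .half() mirrors the diff: depending on the peft version the
        # adapter weights can come back in fp32 even with torch_dtype=float16, so
        # the cast keeps the merged model uniformly fp16 (our assumption).
        model = PeftModel.from_pretrained(
            base_model, lora_dir, torch_dtype=load_type, device_map="auto",
        ).half()
    else:
        model = base_model
    return model.eval()
```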

5 changes: 3 additions & 2 deletions scripts/inference/inference_hf.py
@@ -63,7 +63,8 @@
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
apply_attention_patch(use_memory_efficient_attention=True)
if not args.only_cpu:
apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

if args.use_vllm:
@@ -131,7 +132,7 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT):
base_model.resize_token_embeddings(tokenizer_vocab_size)
if args.lora_model is not None:
print("loading peft model")
model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',)
model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',).half()
else:
model = base_model

2 changes: 1 addition & 1 deletion scripts/training/build_dataset.py
@@ -62,7 +62,7 @@ def tokenization(examples):

if data_cache_dir is None:
data_cache_dir = str(os.path.dirname(file))
cache_path = os.path.join(data_cache_dir,os.path.basename(file).split('.')[0])
cache_path = os.path.join(data_cache_dir,os.path.basename(file).split('.')[0]+f"_{max_seq_length}")
os.makedirs(cache_path, exist_ok=True)
try:
processed_dataset = datasets.load_from_disk(cache_path)
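The build_dataset.py change keys the tokenization cache on max_seq_length, so datasets processed at different lengths no longer reuse each other's cache. A small sketch of the resulting naming (the helper function is ours):

```python
import os

def cache_path_for(file, data_cache_dir, max_seq_length):
    # Mirrors the diff above: cache directory = <stem>_<max_seq_length>,
    # so a new length produces a new cache instead of loading a stale one.
    if data_cache_dir is None:
        data_cache_dir = str(os.path.dirname(file))
    stem = os.path.basename(file).split(".")[0]
    return os.path.join(data_cache_dir, f"{stem}_{max_seq_length}")

print(cache_path_for("data/alpaca.json", None, 512))   # data/alpaca_512
print(cache_path_for("data/alpaca.json", None, 1024))  # data/alpaca_1024
```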
45 changes: 21 additions & 24 deletions scripts/training/run_clm_pt_with_peft.py
@@ -467,13 +467,13 @@ def group_texts(examples):
for idx, file in enumerate(files):
data_file = os.path.join(path, file)
filename = ''.join(file.split(".")[:-1])
cache_path = os.path.join(data_args.data_cache_dir, filename)
cache_path = os.path.join(data_args.data_cache_dir, filename+f"_{block_size}")
os.makedirs(cache_path, exist_ok=True)
try:
processed_dataset = datasets.load_from_disk(cache_path, keep_in_memory=False)
logger.info(f'training datasets-{filename} has been loaded from disk')
except Exception:
cache_dir = os.path.join(data_args.data_cache_dir, filename+"_text")
cache_dir = os.path.join(data_args.data_cache_dir, filename+f"_text_{block_size}")
os.makedirs(cache_dir, exist_ok=True)
raw_dataset = load_dataset("text", data_files=data_file, cache_dir=cache_dir, keep_in_memory=False)
logger.info(f"{file} has been loaded")
@@ -503,7 +503,6 @@ def group_texts(examples):
else:
assert lm_datasets.features.type == processed_dataset["train"].features.type
lm_datasets = concatenate_datasets([lm_datasets, processed_dataset["train"]])

lm_datasets = lm_datasets.train_test_split(test_size = data_args.validation_split_percentage)

if training_args.do_train:
@@ -522,26 +521,24 @@ def group_texts(examples):
logger.info(f"Num eval_samples {len(eval_dataset)}")
logger.info("Evaluation example:")
logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
if model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
model = LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True
)
else:
model = AutoModelForCausalLM.from_config(config)
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
device_map = {"":int(os.environ.get("LOCAL_RANK") or 0)}
model = LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
device_map=device_map
)
model.config.use_cache = False

model_vocab_size = model.get_output_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
@@ -555,7 +552,7 @@ def group_texts(examples):

if training_args.peft_path is not None:
logger.info("Peft from pre-trained model")
model = PeftModel.from_pretrained(model, training_args.peft_path)
model = PeftModel.from_pretrained(model, training_args.peft_path, device_map=device_map)
else:
logger.info("Init new peft model")
target_modules = training_args.trainable.split(',')
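Both run_clm_*_with_peft.py scripts now drop the train-from-scratch fallback and pin the whole model to the process-local GPU. A sketch of that placement pattern with placeholder paths; the comments give our interpretation of the change:

```python
import os
import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

# Under torchrun each process owns one GPU, exposed through LOCAL_RANK; mapping
# the root module ("") to that index keeps accelerate from sharding a single
# model across all visible GPUs when several processes share a node.
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

model = LlamaForCausalLM.from_pretrained(
    "path/to/chinese-llama-2-7b",      # placeholder path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map=device_map,
)
# Passing the same map when resuming an adapter keeps its weights on that GPU too.
model = PeftModel.from_pretrained(model, "path/to/peft/adapter", device_map=device_map)
```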
43 changes: 20 additions & 23 deletions scripts/training/run_clm_sft_with_peft.py
@@ -51,7 +51,6 @@
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

IGNORE_INDEX = -100

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -295,7 +294,7 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

if (len(tokenizer))!=55296:
if (len(tokenizer)) != 55296:
raise ValueError(f"The vocab size of the tokenizer should be 55296, but found {len(tokenizer)}.\n"
"Please use Chinese-LLaMA-2 tokenizer.")

@@ -331,26 +330,24 @@ def main():
logger.info("Evaluation example:")
logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))

if model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
model = LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True
)
else:
model = AutoModelForCausalLM.from_config(config)
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
device_map = {"":int(os.environ.get("LOCAL_RANK") or 0)}
model = LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
device_map=device_map
)
model.config.use_cache = False

model_vocab_size = model.get_input_embeddings().weight.shape[0]
logger.info(f"Model vocab size: {model_vocab_size}")
@@ -361,7 +358,7 @@ def main():

if training_args.peft_path is not None:
logger.info("Peft from pre-trained model")
model = PeftModel.from_pretrained(model, training_args.peft_path)
model = PeftModel.from_pretrained(model, training_args.peft_path, device_map=device_map)
else:
logger.info("Init new peft model")
target_modules = training_args.trainable.split(',')
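For context on the vocabulary check earlier in this file: 55,296 is the extended Chinese-LLaMA-2 vocabulary (the stock Llama-2 tokenizer has 32,000 entries), so the check catches a mismatched tokenizer before training starts. A standalone version of the same check, with a placeholder path:

```python
from transformers import LlamaTokenizer

# Placeholder path; 55296 matches the check in the diff above and is the
# extended Chinese-LLaMA-2 vocabulary (the original Llama-2 vocab is 32,000).
tokenizer = LlamaTokenizer.from_pretrained("path/to/chinese-llama-2/tokenizer")
if len(tokenizer) != 55296:
    raise ValueError(
        f"The vocab size of the tokenizer should be 55296, but found {len(tokenizer)}.\n"
        "Please use the Chinese-LLaMA-2 tokenizer."
    )
```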
15 changes: 7 additions & 8 deletions scripts/training/run_pt.sh
@@ -1,3 +1,5 @@
# 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh)
# Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh) carefully before running the script
lr=2e-4
lora_rank=64
lora_alpha=128
@@ -6,12 +8,12 @@ modules_to_save="embed_tokens,lm_head"
lora_dropout=0.05

pretrained_model=path/to/hf/llama-2/dir
chinese_tokenizer_path=path/to/chinese/llama-2/tokenizer/dir
chinese_tokenizer_path=path/to/chinese-llama-2/tokenizer/dir
dataset_dir=path/to/pt/data/dir
data_cache=temp_data_cache_dir
per_device_train_batch_size=1
per_device_eval_batch_size=1
gradient_accumulation_steps=8
block_size=512
output_dir=output_dir

deepspeed_config_file=ds_zero2_no_offload.json
@@ -24,7 +26,6 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_pt_with_peft.py \
--data_cache_dir ${data_cache} \
--validation_split_percentage 0.001 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--do_train \
--seed $RANDOM \
--fp16 \
@@ -40,16 +41,14 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_pt_with_peft.py \
--save_steps 200 \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--preprocessing_num_workers 8 \
--block_size 1024 \
--block_size ${block_size} \
--output_dir ${output_dir} \
--overwrite_output_dir \
--ddp_timeout 30000 \
--logging_first_step True \
--lora_rank ${lora_rank} \
--lora_alpha ${lora_alpha} \
--trainable ${lora_trainable} \
--modules_to_save ${modules_to_save} \
--lora_dropout ${lora_dropout} \
--torch_dtype float16 \
--gradient_checkpointing \
--ddp_find_unused_parameters False
--modules_to_save ${modules_to_save} \
--torch_dtype float16
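run_pt.sh now exposes block_size as a variable (512 by default), which also matches the new length-aware cache naming above. For orientation, the effective batch implied by the defaults in this script works out as follows; this is standard Trainer arithmetic, not something added by the PR:

```python
# Effective batch size implied by the run_pt.sh defaults shown above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
world_size = 1  # torchrun --nnodes 1 --nproc_per_node 1

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * world_size
print(effective_batch)          # 8 sequences per optimizer step
print(effective_batch * 512)    # 4096 tokens per step at block_size=512
```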
18 changes: 8 additions & 10 deletions scripts/training/run_sft.sh
@@ -1,18 +1,20 @@
# 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh)
# Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh) carefully before running the script
lr=1e-4
lora_rank=64
lora_alpha=128
lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
modules_to_save="embed_tokens,lm_head"
lora_dropout=0.05

pretrained_model=path/to/hf/llama-2/or/merged/llama-2/dir/or/model_id
chinese_tokenizer_path=path/to/chinese/llama-2/tokenizer/dir
pretrained_model=path/to/hf/llama-2/or/chinese-llama-2/dir/or/model_id
chinese_tokenizer_path=path/to/chinese-llama-2/tokenizer/dir
dataset_dir=path/to/sft/data/dir
per_device_train_batch_size=1
per_device_eval_batch_size=1
gradient_accumulation_steps=8
max_seq_length=512
output_dir=output_dir
peft_model=path/to/peft/model/dir
validation_file=validation_file_name

deepspeed_config_file=ds_zero2_no_offload.json
@@ -22,7 +24,6 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
--model_name_or_path ${pretrained_model} \
--tokenizer_name_or_path ${chinese_tokenizer_path} \
--dataset_dir ${dataset_dir} \
--validation_split_percentage 0.001 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--do_train \
@@ -43,18 +44,15 @@ torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
--save_steps 200 \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--preprocessing_num_workers 8 \
--max_seq_length 1024 \
--max_seq_length ${max_seq_length} \
--output_dir ${output_dir} \
--overwrite_output_dir \
--ddp_timeout 30000 \
--logging_first_step True \
--lora_rank ${lora_rank} \
--lora_alpha ${lora_alpha} \
--trainable ${lora_trainable} \
--modules_to_save ${modules_to_save} \
--lora_dropout ${lora_dropout} \
--modules_to_save ${modules_to_save} \
--torch_dtype float16 \
--validation_file ${validation_file} \
--peft_path ${peft_model} \
Contributor (review comment on the `--peft_path` line): Please explain more on the modifications?

Collaborator (author): peft_path is mutually exclusive with the LoRA-related training arguments, so the example is written with the LoRA trainable parameters set instead.

(A sketch of this branch follows the diff below.)

--gradient_checkpointing \
--ddp_find_unused_parameters False
--validation_file ${validation_file}
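On the --peft_path exchange above: per the author's reply, resuming from peft_path and configuring a fresh LoRA adapter are mutually exclusive branches in run_clm_sft_with_peft.py. A sketch of that branch under standard peft APIs; the helper function and its default values are ours, not the script's:

```python
from peft import LoraConfig, PeftModel, TaskType, get_peft_model

def attach_adapter(model, peft_path=None, device_map=None,
                   lora_rank=64, lora_alpha=128, lora_dropout=0.05,
                   trainable="q_proj,v_proj", modules_to_save="embed_tokens,lm_head"):
    """Resume an existing adapter OR create a new one; never both."""
    if peft_path is not None:
        # Resume: rank/targets come from the saved adapter_config.json, so the
        # LoRA hyperparameters passed on the command line are not used.
        return PeftModel.from_pretrained(model, peft_path, device_map=device_map)
    # Fresh adapter: this is where --lora_rank/--lora_alpha/--trainable/... take effect.
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=trainable.split(","),
        modules_to_save=modules_to_save.split(","),
    )
    return get_peft_model(model, config)
```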