Commit

Merge pull request #142 from shibing624/dev
support baichuan, bloom, chatglm2, and llama2 SFT training
support multi-round dataset fine-tuning
add a domain tokenizer builder (build_domain_tokenizer.py)
shibing624 committed Aug 2, 2023
2 parents 0308b61 + 8277609 commit 62d654b
Showing 7 changed files with 2,113 additions and 99 deletions.
6 changes: 3 additions & 3 deletions build_domain_tokenizer.py
@@ -14,11 +14,11 @@

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--in_file', default='data/pretrain/tianlongbabu.txt', type=str)
parser.add_argument('--in_file', default='data/pretrain/fever.txt', type=str)
parser.add_argument('--domain_sp_model_name', default='domain_sp', type=str)
parser.add_argument('--max_sentence_length', default=16384, type=int)
parser.add_argument('--pad_id', default=3, type=int)
parser.add_argument('--vocab_size', default=10000, type=int)
parser.add_argument('--vocab_size', default=2236, type=int)
parser.add_argument('--model_type', default="BPE", type=str)

args = parser.parse_args()
@@ -47,7 +47,7 @@ def main():
sp.load(model_file)

# encode: text => id
print(sp.encode_as_pieces('慕容复来到河边,this is a test'))
print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。慕容复来到河边,this is a test'))
print(sp.encode_as_ids('this is a test'))

# decode: id => text
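For reference, the arguments above drive a SentencePiece training run. A minimal sketch of training and inspecting such a domain tokenizer, assuming the standard sentencepiece Python API (the exact call inside build_domain_tokenizer.py may differ):

import sentencepiece as spm

# Train a small domain BPE model from the defaults shown above.
spm.SentencePieceTrainer.train(
    input='data/pretrain/fever.txt',   # --in_file
    model_prefix='domain_sp',          # --domain_sp_model_name
    vocab_size=2236,                   # --vocab_size; must not exceed what the corpus can support
    model_type='bpe',                  # --model_type (BPE)
    max_sentence_length=16384,         # --max_sentence_length
    pad_id=3,                          # --pad_id
)

# Load the result and tokenize mixed Chinese/English text, mirroring the script's test.
sp = spm.SentencePieceProcessor()
sp.load('domain_sp.model')
print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。this is a test'))
print(sp.encode_as_ids('this is a test'))

The smaller vocab_size default (2236 versus the earlier 10000) presumably matches what the much shorter fever.txt corpus can support, since SentencePiece rejects a vocab_size larger than the corpus allows.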
1,000 changes: 1,000 additions & 0 deletions data/finetune/medical_sft_1K_format.jsonl

Large diffs are not rendered by default.
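The 1,000 added lines of data/finetune/medical_sft_1K_format.jsonl are not shown. A hedged sketch of reading such a multi-round SFT file, assuming a ShareGPT-style "conversations" schema (the field names below are illustrative; the actual schema is not visible in this diff):

import json

def iter_dialogues(path):
    """Yield the list of turns from each jsonl record (assumed schema)."""
    with open(path, encoding='utf-8') as f:
        for line in f:
            record = json.loads(line)
            # Assumed: {"conversations": [{"from": "human", "value": ...},
            #                             {"from": "gpt",   "value": ...}, ...]}
            yield record.get('conversations', [])

for turns in iter_dialogues('data/finetune/medical_sft_1K_format.jsonl'):
    # Multi-round samples alternate human/gpt turns; print the first exchange.
    print(turns[0]['value'][:60], '=>', turns[1]['value'][:60])
    break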

996 changes: 996 additions & 0 deletions data/pretrain/fever.txt

Large diffs are not rendered by default.

38 changes: 20 additions & 18 deletions merge_tokenizers.py
@@ -42,17 +42,18 @@ def load_jieba_vocab(jieba_vocab_file):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--llama_tokenizer_dir', default=None, type=str, required=True)
parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
parser.add_argument('--jieba_word_size', default=20000, type=int)

args = parser.parse_args()
print(args)

# load
llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_tokenizer_dir)
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load(args.domain_sp_model_file)

@@ -100,21 +101,22 @@ def main():
added_set.add(piece)
print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")

word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
top_words = word_freqs[:args.jieba_word_size]
print('jieba top10 freq words:', top_words[:10])
jieba_vocab_set = set([i[0] for i in top_words if i])
print('jieba_vocab_set size:', len(jieba_vocab_set))
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
for p in jieba_vocab_set:
piece = p
if piece not in llama_spm_tokens_set and piece not in added_set:
# print('jieba picec', piece)
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
llama_spm.pieces.append(new_p)
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")
if args.add_jieba:
word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
top_words = word_freqs[:args.jieba_word_size]
print('jieba top10 freq words:', top_words[:10])
jieba_vocab_set = set([i[0] for i in top_words if i])
print('jieba_vocab_set size:', len(jieba_vocab_set))
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
for p in jieba_vocab_set:
piece = p
if piece not in llama_spm_tokens_set and piece not in added_set:
# print('jieba picec', piece)
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
llama_spm.pieces.append(new_p)
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")

# Save
output_sp_dir = 'merged_tokenizer_sp'
@@ -128,7 +130,7 @@ def main():
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

# Test
llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_tokenizer_dir)
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
print(chinese_llama_tokenizer.all_special_tokens)
print(chinese_llama_tokenizer.all_special_ids)
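The new --add_jieba branch relies on load_jieba_vocab, which sits above the shown hunk. A minimal sketch of such a loader, assuming data/vocab/word_freq.txt stores one "word frequency" pair per line (the real helper may differ):

def load_jieba_vocab(jieba_vocab_file):
    """Return [(word, freq), ...] sorted by descending frequency (assumed file format)."""
    word_freqs = []
    with open(jieba_vocab_file, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            freq = int(parts[1]) if len(parts) > 1 else 0
            word_freqs.append((parts[0], freq))
    # Descending frequency so word_freqs[:jieba_word_size] picks the most common words.
    word_freqs.sort(key=lambda x: x[1], reverse=True)
    return word_freqs

Each selected word that is not already a piece is then appended to the merged SentencePiece proto with score 0, exactly as the added block above does.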
141 changes: 78 additions & 63 deletions pretraining.py
@@ -359,36 +359,10 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)

# Load pretrained model and tokenizer
# Load tokenizer
if not model_args.model_type:
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
if model_args.model_type and model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}

config = config_class.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
load_in_8bit=model_args.load_in_8bit,
device_map=model_args.device_map,
trust_remote_code=model_args.trust_remote_code,
)
else:
raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
@@ -400,37 +374,6 @@ def main():
tokenizer_name_or_path = model_args.model_name_or_path
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)

if training_args.use_peft:
if training_args.peft_path is not None:
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
else:
logger.info("Init new peft model")
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
if target_modules and 'all' in target_modules:
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
modules_to_save = training_args.modules_to_save
if modules_to_save is not None:
modules_to_save = modules_to_save.split(',')
logger.info(f"Peft target_modules: {target_modules}")
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=target_modules,
inference_mode=False,
r=training_args.lora_rank,
lora_alpha=training_args.lora_alpha,
lora_dropout=training_args.lora_dropout,
modules_to_save=modules_to_save)
model = get_peft_model(model, peft_config)
if model_args.load_in_8bit:
model = prepare_model_for_int8_training(model)
model.print_trainable_parameters()
else:
logger.info("Full parameters training")
model = model.float()
print_trainable_parameters(model)

# Preprocessing the datasets.
def tokenize_function(examples):
return tokenizer(examples["text"])
@@ -505,21 +448,34 @@ def group_texts(examples):
data_files = {}
dataset_args = {}
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True)
logger.info(f"train files: {', '.join(train_data_files)}")
train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
logger.info(f"train files: {train_data_files}")
# Train data files must be same type, e.g. all txt or all jsonl
types = [f.split('.')[-1] for f in train_data_files]
if len(set(types)) > 1:
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
data_files["train"] = train_data_files
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.txt', recursive=True)
logger.info(f"eval files: {', '.join(eval_data_files)}")
eval_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
logger.info(f"eval files: {eval_data_files}")
data_files["validation"] = eval_data_files
extension = "text"
# Train data files must be same type, e.g. all txt or all jsonl
types = [f.split('.')[-1] for f in eval_data_files]
if len(set(types)) > 1:
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
extension = "text" if data_files["train"][0].endswith('txt') else 'json'
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
raw_datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
**dataset_args,
)

# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
@@ -600,6 +556,65 @@ def group_texts(examples):
logger.debug("Tokenized eval example:")
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))

# Load model
if model_args.model_type and model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}

config = config_class.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
load_in_8bit=model_args.load_in_8bit,
device_map=model_args.device_map,
trust_remote_code=model_args.trust_remote_code,
)
else:
raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")

if training_args.use_peft:
if training_args.peft_path is not None:
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
else:
logger.info("Init new peft model")
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
if target_modules and 'all' in target_modules:
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
modules_to_save = training_args.modules_to_save
if modules_to_save is not None:
modules_to_save = modules_to_save.split(',')
logger.info(f"Peft target_modules: {target_modules}")
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=target_modules,
inference_mode=False,
r=training_args.lora_rank,
lora_alpha=training_args.lora_alpha,
lora_dropout=training_args.lora_dropout,
modules_to_save=modules_to_save)
model = get_peft_model(model, peft_config)
if model_args.load_in_8bit:
model = prepare_model_for_int8_training(model)
model.print_trainable_parameters()
else:
logger.info("Full parameters training")
model = model.float()
print_trainable_parameters(model)

# Initialize our Trainer
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
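When --target_modules contains 'all', the script expands it through find_all_linear_names, which is referenced above but defined outside this hunk. A sketch of how such a helper is commonly written, assuming bitsandbytes linear layers for the quantized cases (this is not the literal implementation from the commit):

import torch
import bitsandbytes as bnb  # only needed for the int8/int4 branches

def find_all_linear_names(model, int4=False, int8=False):
    """Collect leaf module names of all (possibly quantized) linear layers for LoRA."""
    cls = torch.nn.Linear
    if int4:
        cls = bnb.nn.Linear4bit
    elif int8:
        cls = bnb.nn.Linear8bitLt
    names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names.add(name.split('.')[-1])  # e.g. "q_proj" from "...self_attn.q_proj"
    names.discard('lm_head')  # the LM head is normally kept out of LoRA targets
    return sorted(names)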
2 changes: 1 addition & 1 deletion run_sft.sh
@@ -24,7 +24,7 @@ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 supervised_finetuning.py \
--save_total_limit 3 \
--gradient_accumulation_steps 1 \
--preprocessing_num_workers 1 \
--model_max_length 512 \
--model_max_length 534 \
--output_dir outputs-sft-v1 \
--overwrite_output_dir \
--ddp_timeout 30000 \
29 changes: 15 additions & 14 deletions supervised_finetuning.py
@@ -619,6 +619,7 @@ def main():
if not tokenizer_name_or_path:
tokenizer_name_or_path = model_args.model_name_or_path
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
# tokenizer.padding_side = "right" # set padding side to the right, equal to label's -100 padding side
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0 # set as the <unk> token

@@ -720,48 +721,43 @@ def preprocess_function(examples):
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
add_special_tokens=False
).input_ids
targets = input_ids.clone()

# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())

turns = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
cur_len = 0
for i, turn in enumerate(turns):
if turn == "":
break
turn_len = len(tokenizer(turn).input_ids)

turn_len = len(tokenizer(turn, add_special_tokens=False).input_ids) + 1 # 1 is </s> token at the end
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
instruction_len = len(tokenizer(parts[0]).input_ids)
if model_args.model_type in ['llama']:
# "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
instruction_len = instruction_len - 2

instruction_len = len(tokenizer(parts[0], add_special_tokens=False).input_ids) - 1
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
cur_len += turn_len

target[cur_len:] = IGNORE_INDEX

if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
logger.warning(f"tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)

def filter_empty_labels(example):
"""Remove empty labels dataset."""
return not all(label == IGNORE_INDEX for label in example["labels"])

train_dataset = None
max_train_samples = 0
if training_args.do_train:
@@ -782,6 +778,7 @@ def preprocess_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
train_dataset = train_dataset.filter(filter_empty_labels)
logger.debug(f"Num train_samples: {len(train_dataset)}")
logger.debug("Tokenized training example:")
logger.debug(tokenizer.decode(train_dataset[0]['input_ids']))
@@ -806,6 +803,7 @@ def preprocess_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
eval_dataset = eval_dataset.filter(filter_empty_labels)
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
logger.debug("Tokenized eval example:")
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
@@ -904,7 +902,10 @@ def preprocess_function(examples):
logger.info("*** Train ***")
sample = next(iter(trainer.get_train_dataloader()))
logger.debug(f"Train dataloader example: {sample}")
logger.debug(f"Details: \ninput_ids: {list(sample['input_ids'])}, \nlabels: {list(sample['labels'])}")
logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
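The reworked masking tokenizes every turn with add_special_tokens=False, counts one extra token for the trailing </s> of a full turn, subtracts one when measuring the instruction part, and drops the old hardcoded "-2" LLaMA offset. A toy sketch of the resulting label layout, using made-up token ids instead of a real tokenizer:

IGNORE_INDEX = -100

def mask_labels(turns, pad_to=20, pad_id=0):
    """turns: list of (instruction_ids, response_ids) pairs, one per round."""
    input_ids, labels = [], []
    for instruction_ids, response_ids in turns:
        input_ids += instruction_ids + response_ids
        # Loss is computed only on the assistant response of each round.
        labels += [IGNORE_INDEX] * len(instruction_ids) + response_ids
    pad = pad_to - len(input_ids)
    input_ids += [pad_id] * pad          # right padding
    labels += [IGNORE_INDEX] * pad       # padded positions are ignored as well
    return input_ids, labels

# Two rounds; only tokens 21, 22, 23 and 31, 32 contribute to the loss.
ids, labels = mask_labels([([11, 12, 13], [21, 22, 23]), ([14, 15], [31, 32])])
print(ids)     # [11, 12, 13, 21, 22, 23, 14, 15, 31, 32, 0, 0, ...]
print(labels)  # [-100, -100, -100, 21, 22, 23, -100, -100, 31, 32, -100, ...]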
