Dev #142 (Merged)

6 commits merged on Aug 2, 2023
6 changes: 3 additions & 3 deletions build_domain_tokenizer.py
@@ -14,11 +14,11 @@

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--in_file', default='data/pretrain/tianlongbabu.txt', type=str)
parser.add_argument('--in_file', default='data/pretrain/fever.txt', type=str)
parser.add_argument('--domain_sp_model_name', default='domain_sp', type=str)
parser.add_argument('--max_sentence_length', default=16384, type=int)
parser.add_argument('--pad_id', default=3, type=int)
parser.add_argument('--vocab_size', default=10000, type=int)
parser.add_argument('--vocab_size', default=2236, type=int)
parser.add_argument('--model_type', default="BPE", type=str)

args = parser.parse_args()
@@ -47,7 +47,7 @@ def main():
sp.load(model_file)

# encode: text => id
print(sp.encode_as_pieces('慕容复来到河边,this is a test'))
print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。慕容复来到河边,this is a test'))
print(sp.encode_as_ids('this is a test'))

# decode: id => text
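
For orientation, the argparse defaults above map onto a SentencePiece training call roughly like the sketch below. This is a minimal reconstruction, assuming the script feeds the defaults straight into spm.SentencePieceTrainer.train; the output prefix domain_sp and the test sentence come from the diff, everything else is illustrative.

import sentencepiece as spm

# Minimal sketch: train a small domain BPE tokenizer on the new fever.txt corpus
# using the defaults shown above. Keyword names follow the standard SentencePiece
# Python API; the call site in the real script may differ.
spm.SentencePieceTrainer.train(
    input='data/pretrain/fever.txt',     # --in_file
    model_prefix='domain_sp',            # --domain_sp_model_name
    vocab_size=2236,                     # --vocab_size (must fit the small corpus)
    model_type='BPE',                    # --model_type
    max_sentence_length=16384,           # --max_sentence_length
    pad_id=3,                            # --pad_id
)

sp = spm.SentencePieceProcessor()
sp.load('domain_sp.model')
print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。慕容复来到河边,this is a test'))
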
1,000 changes: 1,000 additions & 0 deletions data/finetune/medical_sft_1K_format.jsonl

Large diffs are not rendered by default.

996 changes: 996 additions & 0 deletions data/pretrain/fever.txt

Large diffs are not rendered by default.

38 changes: 20 additions & 18 deletions merge_tokenizers.py
@@ -42,17 +42,18 @@ def load_jieba_vocab(jieba_vocab_file):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--llama_tokenizer_dir', default=None, type=str, required=True)
parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
parser.add_argument('--jieba_word_size', default=20000, type=int)

args = parser.parse_args()
print(args)

# load
llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_tokenizer_dir)
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load(args.domain_sp_model_file)

@@ -100,21 +101,22 @@ def main():
added_set.add(piece)
print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")

word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
top_words = word_freqs[:args.jieba_word_size]
print('jieba top10 freq words:', top_words[:10])
jieba_vocab_set = set([i[0] for i in top_words if i])
print('jieba_vocab_set size:', len(jieba_vocab_set))
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
for p in jieba_vocab_set:
piece = p
if piece not in llama_spm_tokens_set and piece not in added_set:
# print('jieba piece', piece)
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
llama_spm.pieces.append(new_p)
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")
if args.add_jieba:
word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
top_words = word_freqs[:args.jieba_word_size]
print('jieba top10 freq words:', top_words[:10])
jieba_vocab_set = set([i[0] for i in top_words if i])
print('jieba_vocab_set size:', len(jieba_vocab_set))
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
for p in jieba_vocab_set:
piece = p
if piece not in llama_spm_tokens_set and piece not in added_set:
# print('jieba piece', piece)
new_p = sp_pb2_model.ModelProto().SentencePiece()
new_p.piece = piece
new_p.score = 0
llama_spm.pieces.append(new_p)
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")

# Save
output_sp_dir = 'merged_tokenizer_sp'
@@ -128,7 +130,7 @@ def main():
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

# Test
llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_tokenizer_dir)
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
print(chinese_llama_tokenizer.all_special_tokens)
print(chinese_llama_tokenizer.all_special_ids)
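
The gated jieba block above follows the same append-if-unseen pattern already used for the baichuan vocab. A condensed, self-contained sketch of that pattern is below; the tokenizer path, the candidate pieces, and the protobuf import location are assumptions (the import path for sentencepiece_model_pb2 varies across sentencepiece/transformers versions).

import os
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model  # import path may vary by version

# Load the base tokenizer's model proto.
base_sp = spm.SentencePieceProcessor()
base_sp.Load('base_tokenizer/tokenizer.model')      # illustrative path (--base_tokenizer_dir)
base_spm = sp_pb2_model.ModelProto()
base_spm.ParseFromString(base_sp.serialized_model_proto())
existing = {p.piece for p in base_spm.pieces}

# Append any unseen pieces (domain sp pieces, baichuan vocab, or jieba words).
for piece in ['潜伏性', '感染', '河边']:             # illustrative candidate pieces
    if piece not in existing:
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        base_spm.pieces.append(new_p)
        existing.add(piece)

os.makedirs('merged_tokenizer_sp', exist_ok=True)
with open('merged_tokenizer_sp/merged.model', 'wb') as f:
    f.write(base_spm.SerializeToString())

With the renamed flag, a typical invocation would presumably look like: python merge_tokenizers.py --base_tokenizer_dir <llama_dir> --domain_sp_model_file ./domain_sp.model --add_jieba. Without --add_jieba, the jieba vocabulary is now skipped entirely.
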
141 changes: 78 additions & 63 deletions pretraining.py
@@ -359,36 +359,10 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)

# Load pretrained model and tokenizer
# Load tokenizer
if not model_args.model_type:
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
if model_args.model_type and model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}

config = config_class.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
load_in_8bit=model_args.load_in_8bit,
device_map=model_args.device_map,
trust_remote_code=model_args.trust_remote_code,
)
else:
raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
@@ -400,37 +374,6 @@ def main():
tokenizer_name_or_path = model_args.model_name_or_path
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)

if training_args.use_peft:
if training_args.peft_path is not None:
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
else:
logger.info("Init new peft model")
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
if target_modules and 'all' in target_modules:
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
modules_to_save = training_args.modules_to_save
if modules_to_save is not None:
modules_to_save = modules_to_save.split(',')
logger.info(f"Peft target_modules: {target_modules}")
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=target_modules,
inference_mode=False,
r=training_args.lora_rank,
lora_alpha=training_args.lora_alpha,
lora_dropout=training_args.lora_dropout,
modules_to_save=modules_to_save)
model = get_peft_model(model, peft_config)
if model_args.load_in_8bit:
model = prepare_model_for_int8_training(model)
model.print_trainable_parameters()
else:
logger.info("Full parameters training")
model = model.float()
print_trainable_parameters(model)

# Preprocessing the datasets.
def tokenize_function(examples):
return tokenizer(examples["text"])
@@ -505,21 +448,34 @@ def group_texts(examples):
data_files = {}
dataset_args = {}
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True)
logger.info(f"train files: {', '.join(train_data_files)}")
train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
logger.info(f"train files: {train_data_files}")
# Train data files must be same type, e.g. all txt or all jsonl
types = [f.split('.')[-1] for f in train_data_files]
if len(set(types)) > 1:
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
data_files["train"] = train_data_files
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.txt', recursive=True)
logger.info(f"eval files: {', '.join(eval_data_files)}")
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.txt', recursive=True) + glob(
f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
logger.info(f"eval files: {eval_data_files}")
data_files["validation"] = eval_data_files
extension = "text"
# Eval data files must be same type, e.g. all txt or all jsonl
types = [f.split('.')[-1] for f in eval_data_files]
if len(set(types)) > 1:
raise ValueError(f"eval files must be same type, e.g. all txt or all jsonl, but got {types}")
extension = "text" if data_files["train"][0].endswith('txt') else 'json'
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
raw_datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
**dataset_args,
)

# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
@@ -600,6 +556,65 @@ def group_texts(examples):
logger.debug("Tokenized eval example:")
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))

# Load model
if model_args.model_type and model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}

config = config_class.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
load_in_8bit=model_args.load_in_8bit,
device_map=model_args.device_map,
trust_remote_code=model_args.trust_remote_code,
)
else:
raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")

if training_args.use_peft:
if training_args.peft_path is not None:
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
else:
logger.info("Init new peft model")
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
if target_modules and 'all' in target_modules:
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
modules_to_save = training_args.modules_to_save
if modules_to_save is not None:
modules_to_save = modules_to_save.split(',')
logger.info(f"Peft target_modules: {target_modules}")
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=target_modules,
inference_mode=False,
r=training_args.lora_rank,
lora_alpha=training_args.lora_alpha,
lora_dropout=training_args.lora_dropout,
modules_to_save=modules_to_save)
model = get_peft_model(model, peft_config)
if model_args.load_in_8bit:
model = prepare_model_for_int8_training(model)
model.print_trainable_parameters()
else:
logger.info("Full parameters training")
model = model.float()
print_trainable_parameters(model)

# Initialize our Trainer
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
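
The data-loading hunks above do three things: glob txt/json/jsonl files, require that they all share one type, and pick the matching datasets builder name. A standalone sketch of that logic, assuming the behavior shown in the diff; the directory path is illustrative.

import os
from glob import glob

def discover_data_files(file_dir):
    # Collect every supported pretraining file under the directory.
    files = (glob(f'{file_dir}/**/*.txt', recursive=True)
             + glob(f'{file_dir}/**/*.json', recursive=True)
             + glob(f'{file_dir}/**/*.jsonl', recursive=True))
    # Mixing plain text with json/jsonl is rejected, matching the new check.
    types = {f.split('.')[-1] for f in files}
    if len(types) > 1:
        raise ValueError(f"data files must all share one type (txt or json/jsonl), got {sorted(types)}")
    # datasets builder name: 'text' for plain text, 'json' for json/jsonl.
    extension = 'text' if files and files[0].endswith('txt') else 'json'
    return files, extension

if __name__ == '__main__':
    train_files, extension = discover_data_files('data/pretrain')   # illustrative directory
    print(extension, train_files[:3])

The model and PEFT construction block itself appears to be moved unchanged below the dataset preprocessing, so tokenization problems now surface before any model weights are loaded.
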
2 changes: 1 addition & 1 deletion run_sft.sh
@@ -24,7 +24,7 @@ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 supervised_finetuning.py \
--save_total_limit 3 \
--gradient_accumulation_steps 1 \
--preprocessing_num_workers 1 \
--model_max_length 512 \
--model_max_length 534 \
--output_dir outputs-sft-v1 \
--overwrite_output_dir \
--ddp_timeout 30000 \
29 changes: 15 additions & 14 deletions supervised_finetuning.py
@@ -619,6 +619,7 @@ def main():
if not tokenizer_name_or_path:
tokenizer_name_or_path = model_args.model_name_or_path
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
# tokenizer.padding_side = "right" # set padding side to the right, equal to label's -100 padding side
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0 # set as the <unk> token

@@ -720,48 +721,43 @@ def preprocess_function(examples):
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
add_special_tokens=False
).input_ids
targets = input_ids.clone()

# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())

turns = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
cur_len = 0
for i, turn in enumerate(turns):
if turn == "":
break
turn_len = len(tokenizer(turn).input_ids)

turn_len = len(tokenizer(turn, add_special_tokens=False).input_ids) + 1 # 1 is </s> token at the end
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
instruction_len = len(tokenizer(parts[0]).input_ids)
if model_args.model_type in ['llama']:
# "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
instruction_len = instruction_len - 2

instruction_len = len(tokenizer(parts[0], add_special_tokens=False).input_ids) - 1
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
cur_len += turn_len

target[cur_len:] = IGNORE_INDEX

if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
logger.warning(f"tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)

def filter_empty_labels(example):
"""Remove empty labels dataset."""
return not all(label == IGNORE_INDEX for label in example["labels"])

train_dataset = None
max_train_samples = 0
if training_args.do_train:
@@ -782,6 +778,7 @@ def preprocess_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
train_dataset = train_dataset.filter(filter_empty_labels)
logger.debug(f"Num train_samples: {len(train_dataset)}")
logger.debug("Tokenized training example:")
logger.debug(tokenizer.decode(train_dataset[0]['input_ids']))
Expand All @@ -806,6 +803,7 @@ def preprocess_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
eval_dataset = eval_dataset.filter(filter_empty_labels)
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
logger.debug("Tokenized eval example:")
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
@@ -904,7 +902,10 @@ def preprocess_function(examples):
logger.info("*** Train ***")
sample = next(iter(trainer.get_train_dataloader()))
logger.debug(f"Train dataloader example: {sample}")
logger.debug(f"Details: \ninput_ids: {list(sample['input_ids'])}, \nlabels: {list(sample['labels'])}")
logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
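
The reworked masking loop above keeps loss only on assistant replies: each turn's length now counts an explicit trailing </s> (+1), the instruction length drops one separator token (-1, replacing the old LLaMA-specific -2), and examples whose labels end up fully ignored are filtered out by filter_empty_labels, presumably to avoid batches that contribute no loss. A toy sketch of the bookkeeping, with made-up token counts standing in for len(tokenizer(..., add_special_tokens=False).input_ids):

IGNORE_INDEX = -100

def mask_labels(turn_token_lens, instruction_token_lens, seq_len):
    # Pretend these are the target ids copied from input_ids.
    labels = list(range(seq_len))
    cur_len = 0
    for turn_len, instr_len in zip(turn_token_lens, instruction_token_lens):
        # Ignore the user instruction (and role separator) of this turn.
        labels[cur_len:cur_len + instr_len] = [IGNORE_INDEX] * instr_len
        # Advance past the whole turn; turn_len already includes the closing </s>.
        cur_len += turn_len
    # Everything after the last turn (padding) is ignored too.
    labels[cur_len:] = [IGNORE_INDEX] * (seq_len - cur_len)
    return labels

# One turn: 6 instruction tokens followed by a 5-token reply (</s> included) in a
# length-16 sequence => loss is computed only on positions 6..10.
print(mask_labels([11], [6], 16))
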