
Commit 85d322e

Merge pull request #151 from shibing624/dev-round

update tokenizer for multi round task

shibing624 committed Aug 6, 2023
2 parents b920056 + 70fe182
Showing 5 changed files with 306 additions and 473 deletions.
README.md (8 changes: 3 additions & 5 deletions)
@@ -63,7 +63,7 @@ Supervised Finetuning, Reward Modeling and Reinforcement Learning.

Start the service with the following command:
```shell
python gradio_demo.py --model_type base_model_type --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
CUDA_VISIBLE_DEVICES=0 python gradio_demo.py --model_type base_model_type --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
```

Parameter descriptions:
@@ -137,7 +137,7 @@ baichuan:
After training completes, we load the trained model and check the quality of the text it generates.

```shell
python inference.py \
CUDA_VISIBLE_DEVICES=0 python inference.py \
--model_type base_model_type \
--base_model path_to_model_hf_dir \
--tokenizer_path path_to_model_hf_dir \
@@ -153,16 +153,14 @@ python inference.py \
- `--lora_model {lora_model}`: directory of the extracted LoRA files, or a model name on the HF Model Hub. If the LoRA weights have already been merged into the pretrained model, this argument can be omitted
- `--tokenizer_path {tokenizer_path}`: directory containing the corresponding tokenizer. If not provided, it defaults to the same value as `--base_model`
- `--template_name`: prompt template name, e.g. `vicuna`, `alpaca`. If not provided, it defaults to `alpaca`
- `--interactive`: launch in interactive mode for multiple rounds of single-turn Q&A
- `--interactive`: launch interactive multi-turn Q&A with streaming inference
- `--data_file {file_name}`: in non-interactive mode, read `file_name` line by line and run prediction on each line
- `--predictions_file {file_name}`: in non-interactive mode, write the prediction results to `file_name` in JSON format (see the sketch after this list)
- `--resize_emb`: whether to resize the model token embeddings; if not set, the pretrained model's embedding size is used (default: not resized)
- `--only_cpu`: use only the CPU for inference
- `--gpus {gpu_ids}`: IDs of the GPU devices to use, default 0. For multiple GPUs, separate IDs with commas, e.g. 0,1,2
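
To make the non-interactive flags above concrete, here is a hedged Python sketch of what `--data_file` and `--predictions_file` imply: read prompts line by line, generate an answer for each, and write the results as JSON. The `generate_answer` helper, file names, and JSON keys are illustrative placeholders, not the repository's actual API.

```python
# Hypothetical sketch of batch prediction, assuming a generate_answer() helper exists.
import json


def generate_answer(prompt: str) -> str:
    # Placeholder: call your model here (tokenize, model.generate, decode).
    return "..."


def run_batch(data_file: str, predictions_file: str) -> None:
    # One prompt per line, as --data_file expects.
    with open(data_file, "r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f if line.strip()]
    results = [{"input": p, "output": generate_answer(p)} for p in prompts]
    # Results written as JSON, as --predictions_file produces.
    with open(predictions_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


# Example call (paths are placeholders): run_batch("test.txt", "predictions.json")
```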




#### Inference Examples
[shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged) inference examples:

gradio_demo.py (73 changes: 38 additions & 35 deletions)
@@ -8,6 +8,7 @@
"""
import argparse
import os
from threading import Thread

import gradio as gr
import mdtex2html
@@ -21,8 +21,9 @@
BloomTokenizerFast,
LlamaTokenizer,
LlamaForCausalLM,
GenerationConfig,
TextIteratorStreamer,
)
from transformers.generation import GenerationConfig

from supervised_finetuning import get_conv_template

@@ -36,36 +38,34 @@


@torch.inference_mode()
def generate_answer(
def stream_generate_answer(
model,
tokenizer,
prompt,
device,
max_new_tokens=512,
temperature=0.7,
top_p=0.8,
repetition_penalty=1.0,
context_len=2048
):
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer(prompt).input_ids
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
generation_config = dict(
generation_kwargs = dict(
input_ids=torch.as_tensor([input_ids]).to(device),
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
streamer=streamer,
)
generation_output = model.generate(**generation_config)
output_ids = generation_output[0]
output = tokenizer.decode(output_ids, skip_special_tokens=False).strip()
stop_str = tokenizer.eos_token or "</s>"
l_prompt = len(tokenizer.decode(input_ids, skip_special_tokens=False))
pos = output.find(stop_str, l_prompt)
if pos != -1:
output = output[l_prompt:pos]
else:
output = output[l_prompt:]
return output

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

yield from streamer


def main():
@@ -114,7 +114,10 @@ def postprocess(self, y):
device_map='auto',
trust_remote_code=True,
)
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
try:
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
except OSError:
print("Failed to load generation config, use default.")
if args.resize_emb:
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenzier_vocab_size = len(tokenizer)
@@ -139,7 +142,8 @@ def reset_user_input():
def reset_state():
return [], []

conv = get_conv_template(args.template_name)
prompt_template = get_conv_template(args.template_name)
history = []

def predict(
input,
@@ -152,24 +156,23 @@ def predict(
now_input = input
chatbot.append((input, ""))
history = history or []
conv.append_message(conv.roles[0], now_input)
conv.append_message(conv.roles[1], '')

prompt = conv.get_prompt()
output = generate_answer(
model,
tokenizer,
prompt,
device,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
output = output.strip()
conv.messages[-1][-1] = output
history.append((now_input, output))
chatbot[-1] = (now_input, output)
return chatbot, history
history.append([now_input, ''])

prompt = prompt_template.get_prompt(messages=history)
response = ""
for new_text in stream_generate_answer(
model,
tokenizer,
prompt,
device,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
):
response += new_text
new_history = history + [(now_input, response)]
chatbot[-1] = (now_input, response)
yield chatbot, new_history

with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">MedicalGPT</h1>""")
@@ -196,7 +199,7 @@ def predict(
show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=8081)
demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=8082)


if __name__ == '__main__':
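
The new `stream_generate_answer` above runs `model.generate` in a worker thread and consumes tokens through `TextIteratorStreamer`. Below is a minimal, self-contained sketch of that pattern; the `gpt2` checkpoint, prompt, and generation parameters are placeholders for illustration, not the values this repository uses.

```python
# Minimal sketch of threaded streaming generation with transformers' TextIteratorStreamer.
# The checkpoint name and prompt are placeholders; any causal LM with its tokenizer should work.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Question: What are common symptoms of the flu?\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")

# skip_prompt=True drops the echoed prompt; skip_special_tokens=True drops markers like </s>.
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=64, do_sample=True, temperature=0.7)

# model.generate blocks until finished, so it runs in a background thread
# while the main thread iterates over the streamer as chunks arrive.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for new_text in streamer:  # yields decoded text pieces as they are produced
    print(new_text, end="", flush=True)
thread.join()
```

The Gradio `predict` callback applies the same idea: it accumulates the streamed chunks into `response` and yields an updated chatbot state on every chunk, which is what makes the web demo stream token by token.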
inference.py (132 changes: 37 additions & 95 deletions)
@@ -19,8 +19,8 @@
LlamaTokenizer,
LlamaForCausalLM,
TextIteratorStreamer,
GenerationConfig,
)
from transformers.generation import GenerationConfig

from supervised_finetuning import get_conv_template

@@ -33,47 +33,6 @@
}


class SimpleChatIO:
def prompt_for_input(self, role) -> str:
return input(f"{role}: ")

def prompt_for_output(self, role: str):
print(f"{role}: ", end="", flush=True)


@torch.inference_mode()
def generate_answer(
model,
tokenizer,
prompt,
device,
max_new_tokens=512,
temperature=0.7,
repetition_penalty=1.0,
context_len=2048
):
input_ids = tokenizer(prompt).input_ids
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
generation_config = dict(
input_ids=torch.as_tensor([input_ids]).to(device),
max_new_tokens=max_new_tokens,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
generation_output = model.generate(**generation_config)
output_ids = generation_output[0]
output = tokenizer.decode(output_ids, skip_special_tokens=False).strip()
stop_str = tokenizer.eos_token or "</s>"
l_prompt = len(tokenizer.decode(input_ids, skip_special_tokens=False))
pos = output.find(stop_str, l_prompt)
if pos != -1:
output = output[l_prompt:pos]
else:
output = output[l_prompt:]
return output


@torch.inference_mode()
def stream_generate_answer(
model,
@@ -86,7 +45,7 @@ def stream_generate_answer(
repetition_penalty=1.0,
context_len=2048
):
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer(prompt).input_ids
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
@@ -133,7 +92,6 @@ def main():
parser.add_argument('--interactive', action='store_true', help="run in the instruction mode (single-turn)")
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
parser.add_argument('--use_stream', action='store_true', help='Whether to use stream generation')
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
args = parser.parse_args()
@@ -159,8 +117,10 @@ def main():
device_map='auto',
trust_remote_code=True,
)
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)

try:
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
except OSError:
print("Failed to load generation config, use default.")
if args.resize_emb:
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenzier_vocab_size = len(tokenizer)
@@ -189,78 +149,60 @@ def main():
for example in examples[:10]:
print(example)

chatio = SimpleChatIO()

# Chat
def new_chat():
return get_conv_template(args.template_name)
prompt_template = get_conv_template(args.template_name)

if args.interactive:
print("Start inference with interactive mode. command: `clear`, `exit`")
conv = new_chat()
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
history = []
while True:
try:
inp = chatio.prompt_for_input(conv.roles[0])
except EOFError:
inp = ""
query = input(f"{prompt_template.roles[0]}: ")
except UnicodeDecodeError:
print("UnicodeDecodeError, please try again.")
print("Detected decoding error at the inputs, please try again.")
continue
if inp == "":
except Exception:
raise
if query == "":
print("Please input text, try again.")
continue
if inp == "exit":
if query.strip() == "exit":
print("exit...")
break
if inp == "clear":
if query.strip() == "clear":
history = []
print("history cleared.")
conv = new_chat()
continue

conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], '')
print(f"{prompt_template.roles[1]}: ", end="", flush=True)

prompt = conv.get_prompt()
chatio.prompt_for_output(conv.roles[1])
if args.use_stream:
response = stream_generate_answer(
model,
tokenizer,
prompt,
device,
do_print=True,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty
)
else:
response = generate_answer(
model,
tokenizer,
prompt,
device,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty
)
print(response.strip(), flush=True)
# NOTE: strip is important to align with the training data.
conv.messages[-1][-1] = response.strip()
# print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
history.append([query, ''])
prompt = prompt_template.get_prompt(messages=history)
response = stream_generate_answer(
model,
tokenizer,
prompt,
device,
do_print=True,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty
)
if history:
history[-1][-1] = response.strip()
else:
print("Start inference.")
results = []
for index, example in enumerate(examples):
conv = new_chat()
conv.append_message(conv.roles[0], example)
conv.append_message(conv.roles[1], '')

prompt = conv.get_prompt()
response = generate_answer(
# Single turn inference
history = [[example, '']]
prompt = prompt_template.get_prompt(messages=history)
response = stream_generate_answer(
model,
tokenizer,
prompt,
device,
do_print=False,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty
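
Both gradio_demo.py and inference.py now build the prompt from the full `history` via `prompt_template.get_prompt(messages=history)`. The actual template strings live in `supervised_finetuning.get_conv_template` and are not shown in this diff, so the sketch below only illustrates the general idea of flattening multi-turn history into one prompt; the role tags, system text, and separator are invented for illustration.

```python
# Illustrative toy template only: flattens multi-turn history into a single prompt string.
# Role tags, system prompt, and separator are made up, not the repository's real templates.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyConvTemplate:
    system: str = "You are a helpful medical assistant."
    roles: Tuple[str, str] = ("USER", "ASSISTANT")
    sep: str = "\n"

    def get_prompt(self, messages: List[List[str]]) -> str:
        # messages is a list of [user_text, assistant_text] pairs; the last
        # assistant slot is left empty so the model continues from there.
        parts = [self.system]
        for user_text, assistant_text in messages:
            parts.append(f"{self.roles[0]}: {user_text}")
            parts.append(f"{self.roles[1]}: {assistant_text}")
        return self.sep.join(parts)


history = [["I have a headache and a fever.", "How long have you had these symptoms?"],
           ["About two days.", ""]]  # empty slot: the model generates this turn
print(ToyConvTemplate().get_prompt(messages=history))
```

After generation, inference.py writes the reply back into that empty slot (`history[-1][-1] = response.strip()`), which is how later turns stay aware of earlier ones.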
run_training_pipeline.ipynb (1 change: 0 additions & 1 deletion)
@@ -330,7 +330,6 @@
" --save_total_limit 3 \\\n",
" --gradient_accumulation_steps 1 \\\n",
" --preprocessing_num_workers 1 \\\n",
" --model_max_length 512 \\\n",
" --output_dir outputs-sft-v1 \\\n",
" --overwrite_output_dir \\\n",
" --ddp_timeout 30000 \\\n",
