-
Notifications
You must be signed in to change notification settings - Fork 834
/
Copy pathchatbot.py
161 lines (133 loc) · 4.53 KB
/
chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple shell chatbot implemented with lmflow APIs.
"""
import logging
import json
import os
import sys
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import warnings
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from typing import Optional
from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
logging.disable(logging.ERROR)
warnings.filterwarnings("ignore")
@dataclass
class ChatbotArguments:
    """Command-line options that shape the chatbot's prompting and decoding.

    Attributes:
        prompt_structure: Template applied to each user turn before it is
            appended to the running context; must contain ``{input_text}``.
        end_string: Marker that terminates the bot's streamed output.
        num_token_per_step: How many tokens the stream inferencer emits per step.
    """

    prompt_structure: Optional[str] = field(
        default="{input_text}",
        metadata={"help": "prompt structure given user's input text"},
    )
    end_string: Optional[str] = field(
        default="\n\n",
        metadata={"help": "end string mark of the chatbot's output"},
    )
    num_token_per_step: int = field(
        default=4,
        metadata={"help": "Number of tokens per step for stream inference"},
    )
def main():
    """Run an interactive shell chatbot on top of lmflow's inferencer pipeline.

    Parses model/pipeline/chatbot arguments from the command line, loads the
    model (optionally with a LoRA adapter), then loops reading user turns from
    stdin and streaming the bot's reply to stdout until the user types "exit".
    """
    pipeline_name = "inferencer"
    # Resolve the pipeline-specific argument dataclass for the inferencer.
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    # One parser populates all three dataclasses from the same CLI invocation.
    parser = HfArgumentParser((
        ModelArguments,
        PipelineArguments,
        ChatbotArguments,
    ))
    model_args, pipeline_args, chatbot_args = (
        parser.parse_args_into_dataclasses()
    )
    inferencer_args = pipeline_args

    # Load the DeepSpeed JSON config from the path given on the command line.
    with open (pipeline_args.deepspeed, "r") as f:
        ds_config = json.load(f)

    model = AutoModel.get_model(
        model_args,
        tune_strategy='none',  # inference only — no fine-tuning strategy
        ds_config=ds_config,
        device=pipeline_args.device,
        use_accelerator=True,
    )

    # We don't need input data, we will read interactively from stdin
    data_args = DatasetArguments(dataset_path=None)
    dataset = Dataset(data_args)

    inferencer = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )

    # Chats
    model_name = model_args.model_name_or_path
    if model_args.lora_model_path is not None:
        # Show the adapter path in the banner when a LoRA model is attached.
        model_name += f" + {model_args.lora_model_path}"

    guide_message = (
        "\n"
        f"#############################################################################\n"
        f"## A {model_name} chatbot is now chatting with you!\n"
        f"#############################################################################\n"
        "\n"
    )
    print(guide_message)

    # context = (
    #     "You are a helpful assistant who follows the given instructions"
    #     " unconditionally."
    # )
    context = ""
    end_string = chatbot_args.end_string
    prompt_structure = chatbot_args.prompt_structure

    # Main REPL: "exit" quits, "reset" clears history, anything else is a turn.
    while True:
        input_text = input("User >>> ")
        if input_text == "exit":
            print("exit...")
            break
        elif input_text == "reset":
            context = ""
            print("Chat history cleared")
            continue
        if not input_text:
            # Empty input would produce an empty prompt; substitute one space.
            input_text = " "
        context += prompt_structure.format(input_text=input_text)

        # Keep only the most recent characters so the prompt fits the model.
        # NOTE(review): truncation is character-based while get_max_length()
        # presumably counts tokens — confirm against the model implementation.
        context = context[-model.get_max_length():]      # Memory of the bot

        input_dataset = dataset.from_dict({
            "type": "text_only",
            "instances": [ { "text": context } ]
        })

        print("Bot: ", end="")
        print_index = 0  # number of characters of `response` already printed

        token_per_step = chatbot_args.num_token_per_step

        # stream_inference yields (cumulative response text, stop flag) pairs.
        for response, flag_break in inferencer.stream_inference(
            context=context,
            model=model,
            max_new_tokens=inferencer_args.max_new_tokens,
            token_per_step=token_per_step,
            temperature=inferencer_args.temperature,
            end_string=end_string,
            input_dataset=input_dataset
        ):
            # Prints characters in the buffer
            # Withholds a trailing run that could be the start of `end_string`,
            # so the stop marker itself is never echoed to the user.
            new_print_index = print_index
            for char in response[print_index:]:
                if end_string is not None and char == end_string[0]:
                    if new_print_index + len(end_string) >= len(response):
                        break
                new_print_index += 1
                print(char, end="", flush=True)
            print_index = new_print_index
            if flag_break:
                break

        print("\n", end="")
        # Append the full (un-truncated) reply to the conversation history.
        context += response + "\n"
# Script entry point: run the interactive chatbot when executed directly.
if __name__ == "__main__":
    main()