inference/bot.py: 53 changes (46 additions & 7 deletions)
@@ -11,10 +11,39 @@
import argparse
import conversation as convo
import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
from accelerate import infer_auto_device_map, init_empty_weights


+class StopWordsCriteria(StoppingCriteria):
+    def __init__(self, tokenizer, stop_words, stream_callback):
+        self._tokenizer = tokenizer
+        self._stop_words = stop_words
+        self._partial_result = ''
+        self._stream_buffer = ''
+        self._stream_callback = stream_callback
+
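+    # generate() calls this after each new token; returning True stops generation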
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        first = not self._partial_result
+        text = self._tokenizer.decode(input_ids[0, -1])
+        self._partial_result += text
+        for stop_word in self._stop_words:
+            if stop_word in self._partial_result:
+                return True
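+        # no stop word matched yet; optionally stream the new token to the callback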
+        if self._stream_callback:
+            if first:
+                text = text.lstrip()
+            # buffer tokens if the partial result ends with a prefix of a stop word, e.g. "<hu"
+            for stop_word in self._stop_words:
+                for i in range(1, len(stop_word)):
+                    if self._partial_result.endswith(stop_word[0:i]):
+                        self._stream_buffer += text
+                        return False
+            self._stream_callback(self._stream_buffer + text)
+            self._stream_buffer = ''
+        return False
+
+
class ChatModel:
    human_id = "<human>"
    bot_id = "<bot>"
@@ -54,7 +83,8 @@ def __init__(self, model_name, gpu_id, max_memory):
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)

-    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
+    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
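+        # halt generation as soon as the model starts a new <human> turn;
+        # stream_callback, if given, receives decoded text incrementally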
+        stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
        inputs = (
            self._tokenizer(prompt, return_tensors='pt')
            .to(self._model.device)
@@ -65,7 +95,8 @@ def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
-            pad_token_id=self._tokenizer.eos_token_id
+            pad_token_id=self._tokenizer.eos_token_id,
+            stopping_criteria=StoppingCriteriaList([stop_criteria]),
        )
        output = self._tokenizer.batch_decode(outputs)[0]

@@ -79,7 +110,7 @@ class OpenChatKitShell(cmd.Cmd):
    intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
    prompt = ">>> "

-    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory):
+    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
        super().__init__()
        self._gpu_id = int(gpu_id)
        self._model_name_or_path = model_name_or_path
@@ -89,6 +120,7 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
        self._top_k = top_k
        self._retrieval = retrieval
        self._max_memory = max_memory
+        self._do_stream = do_stream

    def preloop(self):
        print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
@@ -120,12 +152,13 @@ def do_say(self, arg):
            self._max_tokens,
            self._sample,
            self._temperature,
-            self._top_k
+            self._top_k,
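+            # the conditional binds inside the lambda body, so a callback is always
+            # passed; it only prints when --no-stream was not given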
+            lambda x : print(x, end='', flush=True) if self._do_stream else None,
        )

        self._convo.push_model_response(output)

-        print(self._convo.get_last_turn())
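+        # when streaming, the callback has already printed each token; just end the line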
print("" if self._do_stream else self._convo.get_last_turn())

    def do_raw_say(self, arg):
        output = self._model.do_inference(
@@ -183,6 +216,11 @@ def main():
        action='store_true',
        help='indicates whether to sample'
    )
+    parser.add_argument(
+        '--no-stream',
+        action='store_true',
+        help='disable streaming of tokens as they are generated'
+    )
    parser.add_argument(
        '--temperature',
        default=0.6,
@@ -238,7 +276,8 @@ def main():
        args.temperature,
        args.top_k,
        args.retrieval,
-        max_memory
+        max_memory,
+        not args.no_stream,
    ).cmdloop()
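Note: with this change, streaming is on by default. Assuming the standard OpenChatKit setup, `python inference/bot.py` prints tokens as the model generates them, while `python inference/bot.py --no-stream` waits and prints the completed response in one piece.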

