diff --git a/inference/bot.py b/inference/bot.py index 12b23c0..aa73af5 100644 --- a/inference/bot.py +++ b/inference/bot.py @@ -11,10 +11,39 @@ import argparse import conversation as convo import retrieval.wikipedia as wp -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList from accelerate import infer_auto_device_map, init_empty_weights +class StopWordsCriteria(StoppingCriteria): + def __init__(self, tokenizer, stop_words, stream_callback): + self._tokenizer = tokenizer + self._stop_words = stop_words + self._partial_result = '' + self._stream_buffer = '' + self._stream_callback = stream_callback + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + first = not self._partial_result + text = self._tokenizer.decode(input_ids[0, -1]) + self._partial_result += text + for stop_word in self._stop_words: + if stop_word in self._partial_result: + return True + if self._stream_callback: + if first: + text = text.lstrip() + # buffer tokens if the partial result ends with a prefix of a stop word, e.g. "