diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index c417446..a0d9a47 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -2,15 +2,15 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopPLogitsWarper,
-    TopKLogitsWarper,
+    StoppingCriteria,
 )
 import gradio as gr
 import argparse
 import os
+from queue import Queue
+from threading import Thread
+import traceback
+import gc
 
 
 # Parse command-line arguments
 
@@ -137,6 +137,79 @@ def user(user_message, history):
         [[user_message, None]]
 
 
+class Stream(StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+
+    Adapted from: https://stackoverflow.com/a/9969000
+    """
+    def __init__(self, func, kwargs=None, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs or {}
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except Exception:
+                traceback.print_exc()
+
+            clear_torch_cache()
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __del__(self):
+        clear_torch_cache()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True
+        clear_torch_cache()
+
+
+def clear_torch_cache():
+    gc.collect()
+    if torch.cuda.device_count() > 0:
+        torch.cuda.empty_cache()
+
+
 # Perform prediction based on the user input and history
 @torch.no_grad()
 def predict(
@@ -161,33 +234,44 @@ def predict(
     prompt = generate_prompt(input)
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(device)
-    original_size = len(input_ids[0])
-    logits_processor = LogitsProcessorList([
-        TemperatureLogitsWarper(temperature=temperature),
-        RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)),
-        TopPLogitsWarper(top_p=top_p),
-        TopKLogitsWarper(top_k=top_k)
-    ])
-    eos_token_id = tokenizer.eos_token_id
-    while True:
-        logits = model(input_ids).logits
-        logits = logits[:, -1, :]
-        logits = logits_processor(input_ids, logits)
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_token_id = torch.multinomial(probs, num_samples=1) \
-            if do_sample else torch.argmax(probs).unsqueeze(0).unsqueeze(0)
-        if next_token_id == eos_token_id:
-            break
-        tokens_previous = tokenizer.decode(
-            input_ids[0], skip_special_tokens=True)
-        input_ids = torch.cat((input_ids, next_token_id), dim=1)
-        tokens = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-        new_tokens = tokens[len(tokens_previous) :]
-        history[-1][1] += new_tokens
-        yield history
-        input_ids = torch.cat((input_ids, next_token_id), dim=1)
-        if len(input_ids[0]) >= original_size + max_new_tokens:
-            break
+
+    generate_params = {
+        'input_ids': input_ids,
+        'max_new_tokens': max_new_tokens,
+        'top_p': top_p,
+        'temperature': temperature,
+        'top_k': top_k,
+        'do_sample': do_sample,
+        'repetition_penalty': repetition_penalty,
+    }
+
+    def generate_with_callback(callback=None, **kwargs):
+        if 'stopping_criteria' in kwargs:
+            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+        else:
+            kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
+        clear_torch_cache()
+        with torch.no_grad():
+            model.generate(**kwargs)
+
+    def generate_with_streaming(**kwargs):
+        return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+    with generate_with_streaming(**generate_params) as generator:
+        for output in generator:
+            next_token_ids = output[len(input_ids[0]):]
+            if next_token_ids[0] == tokenizer.eos_token_id:
+                break
+            new_tokens = tokenizer.decode(
+                next_token_ids, skip_special_tokens=True)
+            if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
+                if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
+                    new_tokens = ' ' + new_tokens
+
+            history[-1][1] = new_tokens
+            yield history
+            if len(next_token_ids) >= max_new_tokens:
+                break
 
 
 # Call the setup function to initialize the components
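The streaming change above boils down to a callback-to-iterator pattern: `model.generate` invokes the `Stream` stopping criterion after each new token, and `Iteratorize` runs the generation in a worker thread, pushing each callback value onto a queue that `predict` drains as a generator. Below is a minimal standalone sketch of that pattern, not part of the patch; `CallbackIterator` and `fake_generate` are illustrative stand-ins (the latter replaces `model.generate`) so the snippet runs without torch or transformers.

```python
# Standalone sketch of the callback-to-iterator pattern used by Iteratorize.
from queue import Queue
from threading import Thread


class CallbackIterator:
    """Run `func(callback=...)` in a thread and yield each callback value."""

    def __init__(self, func, **kwargs):
        self.queue = Queue()
        self.sentinel = object()

        def worker():
            try:
                func(callback=self.queue.put, **kwargs)
            finally:
                self.queue.put(self.sentinel)  # always signal completion

        Thread(target=worker, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()  # blocks until the worker produces a value
        if item is self.sentinel:
            raise StopIteration
        return item


def fake_generate(callback=None, steps=5):
    # Mimics model.generate calling a stopping criterion after every token:
    # the callback receives the sequence generated so far.
    tokens = []
    for i in range(steps):
        tokens.append(i)
        if callback is not None:
            callback(list(tokens))


for partial in CallbackIterator(fake_generate, steps=5):
    print(partial)  # [0], [0, 1], [0, 1, 2], ...
```

The queue-plus-sentinel design is what lets the Gradio handler yield partial output while generation is still running, instead of waiting for the full sequence as the removed manual sampling loop effectively did per token.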