From 98907d466fdb2a600c8ad230b42aa660917db532 Mon Sep 17 00:00:00 2001
From: GoGoJoestar
Date: Mon, 3 Jul 2023 18:13:14 +0800
Subject: [PATCH 1/3] fix output speed in gradio demo

---
 scripts/inference/gradio_demo.py | 148 ++++++++++++++++++++++++-------
 1 file changed, 117 insertions(+), 31 deletions(-)

diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index c417446..27d40c4 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -2,15 +2,15 @@
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopPLogitsWarper,
-    TopKLogitsWarper,
+    StoppingCriteria,
 )
 import gradio as gr
 import argparse
 import os
+from queue import Queue
+from threading import Thread
+import traceback
+import gc
 
 
 # Parse command-line arguments
@@ -137,6 +137,80 @@ def user(user_message, history):
         [[user_message, None]]
 
 
+class Stream(StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+
+    Adapted from: https://stackoverflow.com/a/9969000
+    """
+    def __init__(self, func, kwargs=None, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs or {}
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except:
+                traceback.print_exc()
+                pass
+
+            clear_torch_cache()
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __del__(self):
+        clear_torch_cache()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True
+        clear_torch_cache()
+
+
+def clear_torch_cache():
+    gc.collect()
+    if torch.cuda.device_count() > 0:
+        torch.cuda.empty_cache()
+
+
 # Perform prediction based on the user input and history
 @torch.no_grad()
 def predict(
@@ -162,32 +236,44 @@ def predict(
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(device)
     original_size = len(input_ids[0])
-    logits_processor = LogitsProcessorList([
-        TemperatureLogitsWarper(temperature=temperature),
-        RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)),
-        TopPLogitsWarper(top_p=top_p),
-        TopKLogitsWarper(top_k=top_k)
-    ])
-    eos_token_id = tokenizer.eos_token_id
-    while True:
-        logits = model(input_ids).logits
-        logits = logits[:, -1, :]
-        logits = logits_processor(input_ids, logits)
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_token_id = torch.multinomial(probs, num_samples=1) \
-            if do_sample else torch.argmax(probs).unsqueeze(0).unsqueeze(0)
-        if next_token_id == eos_token_id:
-            break
-        tokens_previous = tokenizer.decode(
-            input_ids[0], skip_special_tokens=True)
-        input_ids = torch.cat((input_ids, next_token_id), dim=1)
-        tokens = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-        new_tokens = tokens[len(tokens_previous) :]
-        history[-1][1] += new_tokens
-        yield history
-        input_ids = torch.cat((input_ids, next_token_id), dim=1)
-        if len(input_ids[0]) >= original_size + max_new_tokens:
-            break
+
+    generate_params = {
+        'input_ids': input_ids,
+        'max_new_tokens': max_new_tokens,
+        'top_p': top_p,
+        'temperature': temperature,
+        'top_k': top_k,
+        'do_sample': do_sample,
+        'repetition_penalty': repetition_penalty,
+    }
+
+    def generate_with_callback(callback=None, **kwargs):
+        if 'stopping_criteria' in kwargs:
+            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+        else:
+            kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
+        clear_torch_cache()
+        with torch.no_grad():
+            model.generate(**kwargs)
+
+    def generate_with_streaming(**kwargs):
+        return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+    with generate_with_streaming(**generate_params) as generator:
+        for output in generator:
+            next_token_ids = output[len(input_ids[0]):]
+            if next_token_ids[0] == tokenizer.eos_token_id:
+                break
+            new_tokens = tokenizer.decode(
+                next_token_ids, skip_special_tokens=True)
+            if type(tokenizer) is LlamaTokenizer and len(next_token_ids) > 0:
+                if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
+                    new_tokens = ' ' + new_tokens
+
+            history[-1][1] = new_tokens
+            yield history
+            if len(next_token_ids) >= max_new_tokens:
+                break
 
 
 # Call the setup function to initialize the components

From 58d1020ffb634828c0f657355697fbb52ab5ba49 Mon Sep 17 00:00:00 2001
From: GoGoJoestar
Date: Tue, 4 Jul 2023 09:21:34 +0800
Subject: [PATCH 2/3] fix Codacy issues

---
 scripts/inference/gradio_demo.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index 27d40c4..4476d46 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -172,9 +172,8 @@ def gentask():
                 ret = self.mfunc(callback=_callback, **self.kwargs)
             except ValueError:
                 pass
-            except:
+            except Exception as e:
                 traceback.print_exc()
-                pass
 
             clear_torch_cache()
             self.q.put(self.sentinel)
@@ -235,7 +234,6 @@ def predict(
     prompt = generate_prompt(input)
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(device)
-    original_size = len(input_ids[0])
 
     generate_params = {
         'input_ids': input_ids,
@@ -266,7 +264,7 @@ def generate_with_streaming(**kwargs):
                 break
             new_tokens = tokenizer.decode(
                 next_token_ids, skip_special_tokens=True)
-            if type(tokenizer) is LlamaTokenizer and len(next_token_ids) > 0:
+            if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
                 if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
                     new_tokens = ' ' + new_tokens
 

From 25c7a8ce6149da7edf0cd34d56b0ab3b10523fc1 Mon Sep 17 00:00:00 2001
From: GoGoJoestar
Date: Tue, 4 Jul 2023 09:32:19 +0800
Subject: [PATCH 3/3] fix Codacy issues

---
 scripts/inference/gradio_demo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index 4476d46..a0d9a47 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -172,7 +172,7 @@ def gentask():
                 ret = self.mfunc(callback=_callback, **self.kwargs)
             except ValueError:
                 pass
-            except Exception as e:
+            except Exception:
                 traceback.print_exc()
 
             clear_torch_cache()
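
Note (not part of the patch series): the change replaces the hand-rolled token-by-token sampling loop in predict() with a single model.generate() call whose partial outputs are streamed back to the Gradio UI. The added Stream stopping criterion fires a callback after every generated token, and Iteratorize runs generate() on a background thread, pushing each callback value into a Queue so the caller can consume the stream as a plain iterator. Below is a minimal, self-contained sketch of that callback-to-iterator pattern; the names CallbackIterator and slow_producer are illustrative only and do not appear in the patch.

    # Sketch of the callback-to-iterator pattern behind Stream/Iteratorize.
    # "CallbackIterator" and "slow_producer" are hypothetical names for
    # illustration; the real code hooks the callback into transformers'
    # StoppingCriteria so model.generate() reports the partial output after
    # every new token.
    from queue import Queue
    from threading import Thread


    class CallbackIterator:
        """Run a callback-style function on a worker thread and expose each
        callback value as an item of a plain Python iterator."""

        def __init__(self, func, **kwargs):
            self.queue = Queue()
            self.sentinel = object()  # marks the end of the stream

            def task():
                func(callback=self.queue.put, **kwargs)
                self.queue.put(self.sentinel)  # producer finished

            Thread(target=task, daemon=True).start()

        def __iter__(self):
            return self

        def __next__(self):
            item = self.queue.get()  # blocks until the producer emits something
            if item is self.sentinel:
                raise StopIteration
            return item


    def slow_producer(callback=None, n=3):
        # Stand-in for model.generate(): emits partial results via the callback.
        for i in range(n):
            callback(f"partial output {i}")


    for chunk in CallbackIterator(slow_producer, n=3):
        print(chunk)  # prints each partial output as soon as it is produced

Recent transformers releases also ship TextIteratorStreamer, which provides a similar queue-backed iterator for use with generate() running on a separate thread.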