Merge pull request #707 from GoGoJoestar/main
Fix unexpected slow down in gradio web demo
ymcui committed Jul 5, 2023
2 parents 4b3d16f + 25c7a8c commit 4ef9477
Showing 1 changed file with 116 additions and 32 deletions.
148 changes: 116 additions & 32 deletions scripts/inference/gradio_demo.py
@@ -2,15 +2,15 @@
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
    StoppingCriteria,
)
import gradio as gr
import argparse
import os
from queue import Queue
from threading import Thread
import traceback
import gc


# Parse command-line arguments
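The import changes summarize the fix: the hand-rolled logits processors (removed) give way to model.generate, which constructs the same temperature/top-p/top-k/repetition-penalty processors internally from its keyword arguments. A minimal sketch of that equivalence, using hypothetical model and input_ids variables rather than code from this file:

# Hypothetical example: generate() builds the sampling logits processors
# internally, so a manual LogitsProcessorList becomes unnecessary.
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1,
)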
@@ -137,6 +137,79 @@ def user(user_message, history):
[[user_message, None]]


class Stream(StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        # Never stops generation; only reports the tokens produced so far.
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False
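Stream is the hook that makes streaming possible: stopping criteria run after every decoding step, so a criterion that always returns False acts as a per-token callback without ever halting generation. A minimal sketch of the idea (hypothetical model and input_ids, not part of the diff):

# Hypothetical usage: the callback fires after each generated token and
# receives the whole sequence so far.
collected = []
model.generate(
    input_ids,
    max_new_tokens=16,
    stopping_criteria=[Stream(callback_func=lambda ids: collected.append(int(ids[-1])))],
)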


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    Adapted from: https://stackoverflow.com/a/9969000
    """
    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            # Raising here aborts the producer thread once the consumer leaves.
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            ret = None  # keep defined in case mfunc raises before returning
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass  # expected when stop_now aborts generation
            except Exception:
                traceback.print_exc()

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        clear_torch_cache()
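Iteratorize is what turns a blocking model.generate call into something a Gradio generator can consume: the wrapped function runs in a background thread, each callback value lands on a queue, and the consumer pulls values off as they arrive. A toy illustration with a stand-in function (hypothetical, not in the file):

# Hypothetical demo: any callback-style function becomes a lazy iterator.
def slow_fn(callback=None, n=3):
    for i in range(n):
        callback(i)  # each value surfaces in the consuming loop below

with Iteratorize(slow_fn, {"n": 3}) as it:
    for value in it:
        print(value)  # prints 0, 1, 2 as they are produced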


def clear_torch_cache():
    gc.collect()
    if torch.cuda.device_count() > 0:
        torch.cuda.empty_cache()


# Perform prediction based on the user input and history
@torch.no_grad()
def predict(
@@ -161,33 +234,44 @@ def predict(
    prompt = generate_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    original_size = len(input_ids[0])
    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature=temperature),
        RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)),
        TopPLogitsWarper(top_p=top_p),
        TopKLogitsWarper(top_k=top_k)
    ])
    eos_token_id = tokenizer.eos_token_id
    while True:
        # One full forward pass over the whole sequence per generated token.
        logits = model(input_ids).logits
        logits = logits[:, -1, :]
        logits = logits_processor(input_ids, logits)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1) \
            if do_sample else torch.argmax(probs).unsqueeze(0).unsqueeze(0)
        if next_token_id == eos_token_id:
            break
        tokens_previous = tokenizer.decode(
            input_ids[0], skip_special_tokens=True)
        input_ids = torch.cat((input_ids, next_token_id), dim=1)
        tokens = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        new_tokens = tokens[len(tokens_previous):]
        history[-1][1] += new_tokens
        yield history
        if len(input_ids[0]) >= original_size + max_new_tokens:
            break
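The removed loop above is the likely culprit behind the reported slowdown: each iteration re-runs model(input_ids) over the entire sequence, so no key/value cache is reused and every new token costs a full-length forward pass. model.generate, used in the replacement code below, caches past key/values and feeds only the newest token per step. A rough sketch of the cached pattern (an assumption about the root cause, not a statement from the commit):

# Hypothetical sketch of cached decoding: after the first pass, only the
# newest token is fed, and past_key_values carries the attention state.
past = None
next_input = input_ids
for _ in range(max_new_tokens):
    out = model(next_input, past_key_values=past, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    next_input = next_token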

    generate_params = {
        'input_ids': input_ids,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'temperature': temperature,
        'top_k': top_k,
        'do_sample': do_sample,
        'repetition_penalty': repetition_penalty,
    }

    def generate_with_callback(callback=None, **kwargs):
        if 'stopping_criteria' in kwargs:
            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        else:
            kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
        clear_torch_cache()
        with torch.no_grad():
            model.generate(**kwargs)

    def generate_with_streaming(**kwargs):
        return Iteratorize(generate_with_callback, kwargs, callback=None)

    with generate_with_streaming(**generate_params) as generator:
        for output in generator:
            next_token_ids = output[len(input_ids[0]):]
            if next_token_ids[0] == tokenizer.eos_token_id:
                break
            new_tokens = tokenizer.decode(
                next_token_ids, skip_special_tokens=True)
            if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
                # SentencePiece marks a leading space with '▁'; decode() drops
                # it on the first token, so restore the space manually.
                if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
                    new_tokens = ' ' + new_tokens

            history[-1][1] = new_tokens
            yield history
            if len(next_token_ids) >= max_new_tokens:
                break


# Call the setup function to initialize the components
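For context, a generator such as predict streams into a gr.Chatbot by chaining events and yielding partial history. The wiring below is a hypothetical, simplified version (the real demo also passes the sampling slider values into predict):

# Hypothetical, simplified wiring; not part of this commit's diff.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        predict, chatbot, chatbot  # each yield updates the chat window
    )
demo.queue().launch()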
