Fix unexpected slowdown in gradio web demo #707
Merged
merged 3 commits on Jul 5, 2023
scripts/inference/gradio_demo.py (148 changes: 116 additions, 32 deletions)
@@ -2,15 +2,15 @@
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    TopKLogitsWarper,
    StoppingCriteria,
)
import gradio as gr
import argparse
import os
from queue import Queue
from threading import Thread
import traceback
import gc


# Parse command-line arguments
@@ -137,6 +137,79 @@ def user(user_message, history):
[[user_message, None]]


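# A stopping criterion that never stops: it is only used as a hook that fires
# a callback with the tokens generated so far on every generation step.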
class Stream(StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).

    Adapted from: https://stackoverflow.com/a/9969000
    """
    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

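        # Called with each partial result: enqueue it for the consumer, or
        # raise to abort generation once the consumer has signalled stop.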
        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except Exception:
                traceback.print_exc()

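            # Always enqueue the sentinel so the consumer's __next__ raises
            # StopIteration even if generation ended early or failed.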
            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

def __iter__(self):
return self

def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj

def __del__(self):
clear_torch_cache()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
clear_torch_cache()


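# Reclaim Python garbage and release cached CUDA memory between generations.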
def clear_torch_cache():
    gc.collect()
    if torch.cuda.device_count() > 0:
        torch.cuda.empty_cache()


# Perform prediction based on the user input and history
@torch.no_grad()
def predict(
@@ -161,33 +234,44 @@ def predict(
    prompt = generate_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    original_size = len(input_ids[0])
    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature=temperature),
        RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)),
        TopPLogitsWarper(top_p=top_p),
        TopKLogitsWarper(top_k=top_k)
    ])
    eos_token_id = tokenizer.eos_token_id
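    # This manual loop re-ran a full forward pass over the entire sequence for
    # every new token (no KV cache), the likely source of the slowdown this PR
    # fixes; it is replaced by the streaming implementation below.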
    while True:
        logits = model(input_ids).logits
        logits = logits[:, -1, :]
        logits = logits_processor(input_ids, logits)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1) \
            if do_sample else torch.argmax(probs).unsqueeze(0).unsqueeze(0)
        if next_token_id == eos_token_id:
            break
        tokens_previous = tokenizer.decode(
            input_ids[0], skip_special_tokens=True)
        input_ids = torch.cat((input_ids, next_token_id), dim=1)
        tokens = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        new_tokens = tokens[len(tokens_previous):]
        history[-1][1] += new_tokens
        yield history
        if len(input_ids[0]) >= original_size + max_new_tokens:
            break

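    # Replacement: delegate decoding to model.generate(), which caches past key
    # values between steps, and stream partial outputs through Iteratorize.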
    generate_params = {
        'input_ids': input_ids,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'temperature': temperature,
        'top_k': top_k,
        'do_sample': do_sample,
        'repetition_penalty': repetition_penalty,
    }

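    # Attach the Stream criterion so model.generate() fires the callback with
    # the tokens produced so far after every generation step.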
    def generate_with_callback(callback=None, **kwargs):
        if 'stopping_criteria' in kwargs:
            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        else:
            kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
        clear_torch_cache()
        with torch.no_grad():
            model.generate(**kwargs)

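    # Run the blocking generate call on a background thread and expose its
    # partial outputs as a context-managed iterator.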
    def generate_with_streaming(**kwargs):
        return Iteratorize(generate_with_callback, kwargs, callback=None)

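    # Each streamed output is the full sequence so far; slice off the prompt
    # and decode only the newly generated tokens.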
    with generate_with_streaming(**generate_params) as generator:
        for output in generator:
            next_token_ids = output[len(input_ids[0]):]
            if next_token_ids[0] == tokenizer.eos_token_id:
                break
            new_tokens = tokenizer.decode(
                next_token_ids, skip_special_tokens=True)
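            # SentencePiece marks a word-initial space with '▁', which decode()
            # drops for the first token, so restore the space manually.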
            if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
                if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
                    new_tokens = ' ' + new_tokens

            history[-1][1] = new_tokens
            yield history
            if len(next_token_ids) >= max_new_tokens:
                break


# Call the setup function to initialize the components