From b56800d8067485ba8541f9d91fec08cc7f01cea8 Mon Sep 17 00:00:00 2001 From: "li.yunhao" Date: Wed, 2 Aug 2023 00:06:17 +0800 Subject: [PATCH 1/8] add stream openai api support --- scripts/openai_server_demo/README.md | 2 +- .../openai_server_demo/openai_api_protocol.py | 21 +- .../openai_server_demo/openai_api_server.py | 228 ++++++++++++++---- 3 files changed, 195 insertions(+), 56 deletions(-) diff --git a/scripts/openai_server_demo/README.md b/scripts/openai_server_demo/README.md index aa6f2be..475f87c 100644 --- a/scripts/openai_server_demo/README.md +++ b/scripts/openai_server_demo/README.md @@ -8,7 +8,7 @@ 安装依赖 ``` shell -pip install fastapi uvicorn shortuuid +pip install fastapi uvicorn shortuuid sse_starlette ``` 启动脚本 diff --git a/scripts/openai_server_demo/openai_api_protocol.py b/scripts/openai_server_demo/openai_api_protocol.py index 7775f25..fee6ebe 100644 --- a/scripts/openai_server_demo/openai_api_protocol.py +++ b/scripts/openai_server_demo/openai_api_protocol.py @@ -1,10 +1,11 @@ -from typing import Optional, List, Dict, Any, Union +from typing import Optional, List, Dict, Any, Union, Literal import time import shortuuid from pydantic import BaseModel, Field + class ChatCompletionRequest(BaseModel): model: str = "chinese-llama-alpaca-2" messages: Union[str, List[Dict[str, str]]] @@ -26,17 +27,30 @@ class ChatMessage(BaseModel): content: str +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = None + + class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] + + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str = "chinese-llama-alpaca-2" - choices: List[ChatCompletionResponseChoice] + choices: List[ + Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice] + ] class EmbeddingsRequest(BaseModel): @@ -76,6 +90,5 @@ class CompletionResponse(BaseModel): id: Optional[str] = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") object: Optional[str] = "text_completion" created: Optional[int] = Field(default_factory=lambda: int(time.time())) - model: Optional[str] = 'chinese-llama-alpaca-2' + model: Optional[str] = "chinese-llama-alpaca-2" choices: List[CompletionResponseChoice] - diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py index 2f78280..9428877 100644 --- a/scripts/openai_server_demo/openai_api_server.py +++ b/scripts/openai_server_demo/openai_api_server.py @@ -2,15 +2,32 @@ import os from fastapi import FastAPI import uvicorn +from threading import Thread +from sse_starlette.sse import EventSourceResponse + parser = argparse.ArgumentParser() -parser.add_argument('--base_model', default=None, type=str, required=True) -parser.add_argument('--lora_model', default=None, type=str,help="If None, perform inference on the base model") -parser.add_argument('--tokenizer_path',default=None,type=str) -parser.add_argument('--gpus', default="0", type=str) -parser.add_argument('--load_in_8bit',action='store_true', help='Load the model in 8bit mode') -parser.add_argument('--only_cpu',action='store_true',help='Only use CPU for inference') -parser.add_argument('--alpha',type=str,default="1.0", help="The scaling 
factor of NTK method, can be a float or 'auto'. ")
+parser.add_argument("--base_model", default=None, type=str, required=True)
+parser.add_argument(
+    "--lora_model",
+    default=None,
+    type=str,
+    help="If None, perform inference on the base model",
+)
+parser.add_argument("--tokenizer_path", default=None, type=str)
+parser.add_argument("--gpus", default="0", type=str)
+parser.add_argument(
+    "--load_in_8bit", action="store_true", help="Load the model in 8bit mode"
+)
+parser.add_argument(
+    "--only_cpu", action="store_true", help="Only use CPU for inference"
+)
+parser.add_argument(
+    "--alpha",
+    type=str,
+    default="1.0",
+    help="The scaling factor of NTK method, can be a float or 'auto'. ",
+)
 args = parser.parse_args()
 load_in_8bit = args.load_in_8bit
 if args.only_cpu is True:
@@ -19,13 +36,20 @@
 import torch
 import torch.nn.functional as F
-from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
+from transformers import (
+    LlamaForCausalLM,
+    LlamaTokenizer,
+    GenerationConfig,
+    TextIteratorStreamer,
+)
 from peft import PeftModel
 import sys
+
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
 from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
+
 apply_attention_patch(use_memory_efficient_attention=True)
 apply_ntk_scaling_patch(args.alpha)
@@ -39,13 +63,15 @@
     CompletionResponseChoice,
     EmbeddingsRequest,
     EmbeddingsResponse,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
 )
 
 load_type = torch.float16
 if torch.cuda.is_available():
     device = torch.device(0)
 else:
-    device = torch.device('cpu')
+    device = torch.device("cpu")
 if args.tokenizer_path is None:
     args.tokenizer_path = args.lora_model
     if args.lora_model is None:
@@ -57,23 +83,28 @@
     load_in_8bit=load_in_8bit,
     torch_dtype=load_type,
     low_cpu_mem_usage=True,
-    device_map='auto' if not args.only_cpu else None,
-    )
+    device_map="auto" if not args.only_cpu else None,
+)
 
 model_vocab_size = base_model.get_input_embeddings().weight.size(0)
 tokenzier_vocab_size = len(tokenizer)
 print(f"Vocab of the base model: {model_vocab_size}")
 print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
-if model_vocab_size!=tokenzier_vocab_size:
+if model_vocab_size != tokenzier_vocab_size:
     print("Resize model embeddings to fit tokenizer")
     base_model.resize_token_embeddings(tokenzier_vocab_size)
 if args.lora_model is not None:
     print("loading peft model")
-    model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',)
+    model = PeftModel.from_pretrained(
+        base_model,
+        args.lora_model,
+        torch_dtype=load_type,
+        device_map="auto",
+    )
 else:
     model = base_model
 
-if device==torch.device('cpu'):
+if device == torch.device("cpu"):
     model.float()
 model.eval()
@@ -81,25 +112,28 @@
 DEFAULT_SYSTEM_PROMPT = """你是一个乐于助人的助手。"""
 
 TEMPLATE_WITH_SYSTEM_PROMPT = (
-    "[INST] <<SYS>>\n"
-    "{system_prompt}\n"
-    "<</SYS>>\n\n"
-    "{instruction} [/INST]"
+    "[INST] <<SYS>>\n" "{system_prompt}\n" "<</SYS>>\n\n" "{instruction} [/INST]"
 )
 
 TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"
 
-def generate_prompt(instruction, response="", with_system_prompt=True, system_prompt=None):
+
+def generate_prompt(
+    instruction, response="", with_system_prompt=True, system_prompt=None
+):
     if with_system_prompt is True:
         if system_prompt is None:
             system_prompt = DEFAULT_SYSTEM_PROMPT
-        prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map({'instruction': instruction,'system_prompt': system_prompt})
+        prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map(
+            {"instruction": instruction, "system_prompt": system_prompt}
+        )
     else:
-        prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({'instruction': instruction})
-    if len(response)>0:
+        prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({"instruction": instruction})
+    if len(response) > 0:
         prompt += " " + response
     return prompt
 
+
 def generate_completion_prompt(instruction: str):
     """Generate prompt for completion"""
     return generate_prompt(instruction, response="", with_system_prompt=True)
@@ -110,23 +144,26 @@ def generate_chat_prompt(messages: list):
     system_msg = None
     for msg in messages:
-        if msg.role == 'system':
+        if msg.role == "system":
             system_msg = msg.content
     prompt = ""
     is_first_user_content = True
     for msg in messages:
-        if msg.role == 'system':
+        if msg.role == "system":
             continue
-        if msg.role == 'user':
+        if msg.role == "user":
             if is_first_user_content is True:
-                prompt += generate_prompt(msg.content, with_system_prompt=True, system_prompt=system_msg)
+                prompt += generate_prompt(
+                    msg.content, with_system_prompt=True, system_prompt=system_msg
+                )
                 is_first_user_content = False
             else:
-                prompt += '<s>'+generate_prompt(msg.content, with_system_prompt=False)
-        if msg.role == 'assistant':
-            prompt += f" {msg.content}"+"</s>"
+                prompt += "<s>" + generate_prompt(msg.content, with_system_prompt=False)
+        if msg.role == "assistant":
+            prompt += f" {msg.content}" + "</s>"
     return prompt
 
+
 def predict(
     input,
     max_new_tokens=128,
@@ -160,7 +197,7 @@ def predict(
     generation_config.return_dict_in_generate = True
     generation_config.output_scores = False
     generation_config.max_new_tokens = max_new_tokens
-    generation_config.repetition_penalty=float(repetition_penalty)
+    generation_config.repetition_penalty = float(repetition_penalty)
     with torch.no_grad():
         generation_output = model.generate(
             input_ids=input_ids,
@@ -171,17 +208,84 @@ def predict(
     output = output.split("[/INST]")[-1].strip()
     return output
 
+
+def stream_predict(
+    input,
+    max_new_tokens=128,
+    top_p=0.75,
+    temperature=0.1,
+    top_k=40,
+    num_beams=4,
+    repetition_penalty=1.0,
+    do_sample=True,
+    model_id="chinese-llama-alpaca-2",
+    **kwargs,
+):
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id,
+        choices=[choice_data],
+        object="chat.completion.chunk",
+    )
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
+    if isinstance(input, str):
+        prompt = generate_completion_prompt(input)
+    else:
+        prompt = generate_chat_prompt(input)
+    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"].to(device)
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        num_beams=num_beams,
+        do_sample=do_sample,
+        **kwargs,
+    )
+
+    streamer = TextIteratorStreamer(tokenizer)
+    generation_kwargs = dict(
+        streamer=streamer,
+        input_ids=input_ids,
+        generation_config=generation_config,
+        return_dict_in_generate=True,
+        output_scores=False,
+        max_new_tokens=max_new_tokens,
+        repetition_penalty=float(repetition_penalty),
+    )
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
+    for new_text in streamer:
+        if new_text.startswith("<s>") or new_text.startswith("[/INST]"):
+            continue
+        if new_text.endswith("</s>"):
+            new_text = new_text.split("</s>")[0]
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
+        )
+        chunk = ChatCompletionResponse(
+            model=model_id, choices=[choice_data],
object="chat.completion.chunk" + ) + yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + choice_data = ChatCompletionResponseStreamChoice( + index=0, delta=DeltaMessage(), finish_reason="stop" + ) + chunk = ChatCompletionResponse( + model=model_id, choices=[choice_data], object="chat.completion.chunk" + ) + yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + yield "[DONE]" + + def get_embedding(input): """Get embedding main function""" with torch.no_grad(): - encoding = tokenizer( - input, padding=True, return_tensors="pt" - ) + encoding = tokenizer(input, padding=True, return_tensors="pt") input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) - model_output = model( - input_ids, attention_mask, output_hidden_states=True - ) + model_output = model(input_ids, attention_mask, output_hidden_states=True) data = model_output.hidden_states[-1] mask = attention_mask.unsqueeze(-1).expand(data.size()).float() masked_embeddings = data * mask @@ -192,16 +296,30 @@ def get_embedding(input): ret = normalized_embeddings.squeeze(0).tolist() return ret + app = FastAPI() + @app.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): """Creates a completion for the chat message""" msgs = request.messages if isinstance(msgs, str): - msgs = [ChatMessage(role='user',content=msgs)] + msgs = [ChatMessage(role="user", content=msgs)] else: - msgs = [ChatMessage(role=x['role'],content=x['message']) for x in msgs] + msgs = [ChatMessage(role=x["role"], content=x["message"]) for x in msgs] + if request.stream: + generate = stream_predict( + input=msgs, + max_new_tokens=request.max_tokens, + top_p=request.top_p, + top_k=request.top_k, + temperature=request.temperature, + num_beams=request.num_beams, + repetition_penalty=request.repetition_penalty, + do_sample=request.do_sample, + ) + return EventSourceResponse(generate, media_type="text/event-stream") output = predict( input=msgs, max_new_tokens=request.max_tokens, @@ -212,9 +330,16 @@ async def create_chat_completion(request: ChatCompletionRequest): repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) - choices = [ChatCompletionResponseChoice(index = i, message = msg) for i, msg in enumerate(msgs)] - choices += [ChatCompletionResponseChoice(index = len(choices), message = ChatMessage(role='assistant',content=output))] - return ChatCompletionResponse(choices = choices) + choices = [ + ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs) + ] + choices += [ + ChatCompletionResponseChoice( + index=len(choices), message=ChatMessage(role="assistant", content=output) + ) + ] + return ChatCompletionResponse(choices=choices) + @app.post("/v1/completions") async def create_completion(request: CompletionRequest): @@ -229,23 +354,24 @@ async def create_completion(request: CompletionRequest): repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) - choices = [CompletionResponseChoice(index = 0, text = output)] - return CompletionResponse(choices = choices) + choices = [CompletionResponseChoice(index=0, text=output)] + return CompletionResponse(choices=choices) + @app.post("/v1/embeddings") async def create_embeddings(request: EmbeddingsRequest): """Creates text embedding""" embedding = get_embedding(request.input) - data = [{ - "object": "embedding", - "embedding": embedding, - "index": 0 - }] + data = [{"object": "embedding", "embedding": embedding, "index": 0}] return 
EmbeddingsResponse(data=data) if __name__ == "__main__": log_config = uvicorn.config.LOGGING_CONFIG - log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" - log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" - uvicorn.run(app, host='0.0.0.0', port=19327, workers=1, log_config=log_config) + log_config["formatters"]["access"][ + "fmt" + ] = "%(asctime)s - %(levelname)s - %(message)s" + log_config["formatters"]["default"][ + "fmt" + ] = "%(asctime)s - %(levelname)s - %(message)s" + uvicorn.run(app, host="0.0.0.0", port=19327, workers=1, log_config=log_config) From 9efdb97ea376ef30e927023c432b80c56851bb83 Mon Sep 17 00:00:00 2001 From: Ziqing Yang Date: Thu, 3 Aug 2023 09:05:52 +0800 Subject: [PATCH 2/8] fix spelling error --- scripts/openai_server_demo/openai_api_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py index 1a1be9f..fb916bd 100644 --- a/scripts/openai_server_demo/openai_api_server.py +++ b/scripts/openai_server_demo/openai_api_server.py @@ -84,8 +84,8 @@ model_vocab_size = base_model.get_input_embeddings().weight.size(0) tokenizer_vocab_size = len(tokenizer) print(f"Vocab of the base model: {model_vocab_size}") -print(f"Vocab of the tokenizer: {tokenzier_vocab_size}") -if model_vocab_size != tokenzier_vocab_size: +print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") +if model_vocab_size != tokenizer_vocab_size: print("Resize model embeddings to fit tokenizer") base_model.resize_token_embeddings(tokenizer_vocab_size) if args.lora_model is not None: From 720581544851ee3b0fe7e3268fbd34ed15b56663 Mon Sep 17 00:00:00 2001 From: "li.yunhao" Date: Thu, 3 Aug 2023 14:29:31 +0800 Subject: [PATCH 3/8] Add openai stream api docs. 
---
 scripts/openai_server_demo/README.md            | 10 ++++++----
 scripts/openai_server_demo/openai_api_server.py | 10 +++++++++-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/scripts/openai_server_demo/README.md b/scripts/openai_server_demo/README.md
index 475f87c..732aa37 100644
--- a/scripts/openai_server_demo/README.md
+++ b/scripts/openai_server_demo/README.md
@@ -137,7 +137,7 @@ curl http://localhost:19327/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-      {"role": "user","message": "给我讲一些有关杭州的故事吧"}
+      {"role": "user","content": "给我讲一些有关杭州的故事吧"}
     ],
     "repetition_penalty": 1.0
   }'
@@ -179,9 +179,9 @@ curl http://localhost:19327/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-      {"role": "user","message": "给我讲一些有关杭州的故事吧"},
-      {"role": "assistant","message": "好的,请问您对杭州有什么特别的偏好吗?"},
-      {"role": "user","message": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
+      {"role": "user","content": "给我讲一些有关杭州的故事吧"},
+      {"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"},
+      {"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
     ],
     "repetition_penalty": 1.0
   }'
@@ -246,6 +246,8 @@ json返回体:
 
 `do_sample`: 启用随机采样策略。默认为true。
 
+`stream`: OpenAI格式的流式返回。默认为false,设置为true时,会按照OpenAI的格式流式返回数据,可以作为任意基于ChatGPT的应用的后端。
+
 ### 文本嵌入向量(text embedding)
 
 文本嵌入向量有很多作用,包括但不限于基于大型文档问答、总结一本书中的内容、为大语言模型找到与当前用户输入最相近的记忆等等。
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index fb916bd..fd182cc 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 from threading import Thread
 from sse_starlette.sse import EventSourceResponse
@@ -294,6 +295,13 @@ def get_embedding(input):
 
 app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
@@ -302,7 +310,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if isinstance(msgs, str):
         msgs = [ChatMessage(role="user", content=msgs)]
     else:
-        msgs = [ChatMessage(role=x["role"], content=x["message"]) for x in msgs]
+        msgs = [ChatMessage(role=x["role"], content=x["content"]) for x in msgs]
     if request.stream:
         generate = stream_predict(
             input=msgs,

From 30cb8a5f50559974f0b58104b78874a888e185bc Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 00:58:34 +0800
Subject: [PATCH 4/8] fix stream api split

---
 scripts/openai_server_demo/openai_api_server.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index fd182cc..ab36fed 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -74,6 +74,7 @@
     args.base_model,
     torch_dtype=load_type,
     low_cpu_mem_usage=True,
+    offload_folder="offload",  # TODO
     device_map='auto' if not args.only_cpu else None,
     quantization_config=BitsAndBytesConfig(
         load_in_4bit=args.load_in_4bit,
@@ -254,8 +255,10 @@ def stream_predict(
     )
     Thread(target=model.generate, kwargs=generation_kwargs).start()
     for new_text in streamer:
-        if new_text.startswith("<s>") or new_text.startswith("[/INST]"):
+        if new_text.startswith("<s>"):
             continue
+        if new_text.startswith("[/INST]"):
+            new_text = new_text.split("[/INST]")[-1]
         if new_text.endswith("</s>"):
             new_text = new_text.split("</s>")[0]
         choice_data = ChatCompletionResponseStreamChoice(

From ef0e1c11758b0df9a822392627d509f546868c23 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 01:09:16 +0800
Subject: [PATCH 5/8] fix stream api bug

---
 scripts/openai_server_demo/openai_api_server.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index ab36fed..73ae747 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -74,7 +74,6 @@
     args.base_model,
     torch_dtype=load_type,
     low_cpu_mem_usage=True,
-    offload_folder="offload",  # TODO
     device_map='auto' if not args.only_cpu else None,
     quantization_config=BitsAndBytesConfig(
         load_in_4bit=args.load_in_4bit,

From 0db43d87019ffc404765c15835e88a96615170a0 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 11:31:17 +0800
Subject: [PATCH 6/8] fix api extra token

---
 scripts/openai_server_demo/openai_api_server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 73ae747..05e3673 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -257,7 +257,7 @@ def stream_predict(
         if new_text.startswith("<s>"):
             continue
         if new_text.startswith("[/INST]"):
-            new_text = new_text.split("[/INST]")[-1]
+            new_text = new_text.split("[/INST]")[-1].strip()
         if new_text.endswith("</s>"):
             new_text = new_text.split("</s>")[0]

From 01fc1dfdc1899d7f7e8e6897fd77504707d8fafd Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 21:25:46 +0800
Subject: [PATCH 7/8] fix openai api repetition of the last word

---
 scripts/openai_server_demo/openai_api_server.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 05e3673..e353a06 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -254,12 +254,13 @@ def stream_predict(
     )
     Thread(target=model.generate, kwargs=generation_kwargs).start()
     for new_text in streamer:
+        new_text = new_text.strip()
         if new_text.startswith("<s>"):
             continue
         if new_text.startswith("[/INST]"):
-            new_text = new_text.split("[/INST]")[-1].strip()
+            new_text = new_text.split("[/INST]")[-1]
         if new_text.endswith("</s>"):
-            new_text = new_text.split("</s>")[0]
+            new_text = new_text.split("</s>")[0][1:]
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(content=new_text), finish_reason=None
         )

From ebbe80e6540305a4aa6c1d96a5f864d524f6af18 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Mon, 7 Aug 2023 18:24:54 +0800
Subject: [PATCH 8/8] fix api first token missed

---
 scripts/openai_server_demo/openai_api_server.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index e353a06..687ac4d 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -242,7 +242,7 @@ def stream_predict(
         **kwargs,
     )
 
-    streamer = TextIteratorStreamer(tokenizer)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(
         streamer=streamer,
         input_ids=input_ids,
         generation_config=generation_config,
         return_dict_in_generate=True,
         output_scores=False,
         max_new_tokens=max_new_tokens,
         repetition_penalty=float(repetition_penalty),
     )
     Thread(target=model.generate, kwargs=generation_kwargs).start()
     for new_text in streamer:
-        new_text = new_text.strip()
-        if new_text.startswith("<s>"):
-            continue
-        if new_text.startswith("[/INST]"):
-            new_text = new_text.split("[/INST]")[-1]
-        if new_text.endswith("</s>"):
-            new_text = new_text.split("</s>")[0][1:]
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(content=new_text), finish_reason=None
         )
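
For reference, here is a minimal client-side sketch (not part of the patch series) showing how the streaming endpoint added above can be consumed. It assumes the server from these patches is running locally on the default port 19327 from the README, uses the third-party `requests` package, and relies on the field names defined in openai_api_protocol.py; adjust host, port, and payload for your own setup.

```python
# stream_client_demo.py -- hedged example, not included in the patches above.
import json

import requests  # assumption: pip install requests

# With "stream": true the chat endpoint returns Server-Sent Events: each
# event's data payload is a chat.completion.chunk JSON object, and the final
# event is the literal sentinel "[DONE]" yielded by stream_predict().
payload = {
    "messages": [{"role": "user", "content": "给我讲一些有关杭州的故事吧"}],
    "stream": True,
}

with requests.post(
    "http://localhost:19327/v1/chat/completions",
    json=payload,
    stream=True,
    timeout=300,
) as resp:
    resp.raise_for_status()
    for raw_line in resp.iter_lines(decode_unicode=True):
        # Skip blank separators and SSE comment/ping lines; keep "data:" lines.
        if not raw_line or not raw_line.startswith("data:"):
            continue
        data = raw_line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        # The first chunk carries only the assistant role; later chunks carry
        # incremental content, because the server serializes with exclude_unset=True.
        print(delta.get("content", ""), end="", flush=True)
    print()
```

Because the response is plain SSE over an OpenAI-shaped schema, an OpenAI-compatible client pointed at `http://localhost:19327/v1` should also be able to consume the stream, which is what the README's note about backing ChatGPT-style applications refers to.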