From b56800d8067485ba8541f9d91fec08cc7f01cea8 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Wed, 2 Aug 2023 00:06:17 +0800
Subject: [PATCH 1/8] add stream openai api support
---
scripts/openai_server_demo/README.md | 2 +-
.../openai_server_demo/openai_api_protocol.py | 21 +-
.../openai_server_demo/openai_api_server.py | 228 ++++++++++++++----
3 files changed, 195 insertions(+), 56 deletions(-)
diff --git a/scripts/openai_server_demo/README.md b/scripts/openai_server_demo/README.md
index aa6f2be..475f87c 100644
--- a/scripts/openai_server_demo/README.md
+++ b/scripts/openai_server_demo/README.md
@@ -8,7 +8,7 @@
Install dependencies
``` shell
-pip install fastapi uvicorn shortuuid
+pip install fastapi uvicorn shortuuid sse_starlette
```
Start the script
diff --git a/scripts/openai_server_demo/openai_api_protocol.py b/scripts/openai_server_demo/openai_api_protocol.py
index 7775f25..fee6ebe 100644
--- a/scripts/openai_server_demo/openai_api_protocol.py
+++ b/scripts/openai_server_demo/openai_api_protocol.py
@@ -1,10 +1,11 @@
-from typing import Optional, List, Dict, Any, Union
+from typing import Optional, List, Dict, Any, Union, Literal
import time
import shortuuid
from pydantic import BaseModel, Field
+
class ChatCompletionRequest(BaseModel):
model: str = "chinese-llama-alpaca-2"
messages: Union[str, List[Dict[str, str]]]
@@ -26,17 +27,30 @@ class ChatMessage(BaseModel):
content: str
+class DeltaMessage(BaseModel):
+ role: Optional[Literal["user", "assistant", "system"]] = None
+ content: Optional[str] = None
+
+
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
+class ChatCompletionResponseStreamChoice(BaseModel):
+ index: int
+ delta: DeltaMessage
+ finish_reason: Optional[Literal["stop", "length"]]
+
+
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str = "chinese-llama-alpaca-2"
- choices: List[ChatCompletionResponseChoice]
+ choices: List[
+ Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
+ ]
class EmbeddingsRequest(BaseModel):
@@ -76,6 +90,5 @@ class CompletionResponse(BaseModel):
id: Optional[str] = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
object: Optional[str] = "text_completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
- model: Optional[str] = 'chinese-llama-alpaca-2'
+ model: Optional[str] = "chinese-llama-alpaca-2"
choices: List[CompletionResponseChoice]
-
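For reference, each streamed event in this design carries one `chat.completion.chunk` built from the new `ChatCompletionResponseStreamChoice` and `DeltaMessage` models. A minimal sketch (not part of the patch) of how such a chunk is constructed and serialized, assuming the models above are importable from `openai_api_protocol`:

```python
# Illustrative only: build and serialize one streaming chunk with the new models.
from openai_api_protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    DeltaMessage,
)

chunk = ChatCompletionResponse(
    object="chat.completion.chunk",
    choices=[
        ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content="西湖"),  # one incremental piece of the reply
            finish_reason=None,
        )
    ],
)
# The server yields this JSON string as the body of a single SSE event.
print(chunk.json(exclude_unset=True, ensure_ascii=False))
```

Because the server serializes with `exclude_unset=True`, only fields that were explicitly set appear in each chunk.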
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 2f78280..9428877 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -2,15 +2,32 @@
import os
from fastapi import FastAPI
import uvicorn
+from threading import Thread
+from sse_starlette.sse import EventSourceResponse
+
parser = argparse.ArgumentParser()
-parser.add_argument('--base_model', default=None, type=str, required=True)
-parser.add_argument('--lora_model', default=None, type=str,help="If None, perform inference on the base model")
-parser.add_argument('--tokenizer_path',default=None,type=str)
-parser.add_argument('--gpus', default="0", type=str)
-parser.add_argument('--load_in_8bit',action='store_true', help='Load the model in 8bit mode')
-parser.add_argument('--only_cpu',action='store_true',help='Only use CPU for inference')
-parser.add_argument('--alpha',type=str,default="1.0", help="The scaling factor of NTK method, can be a float or 'auto'. ")
+parser.add_argument("--base_model", default=None, type=str, required=True)
+parser.add_argument(
+ "--lora_model",
+ default=None,
+ type=str,
+ help="If None, perform inference on the base model",
+)
+parser.add_argument("--tokenizer_path", default=None, type=str)
+parser.add_argument("--gpus", default="0", type=str)
+parser.add_argument(
+ "--load_in_8bit", action="store_true", help="Load the model in 8bit mode"
+)
+parser.add_argument(
+ "--only_cpu", action="store_true", help="Only use CPU for inference"
+)
+parser.add_argument(
+ "--alpha",
+ type=str,
+ default="1.0",
+ help="The scaling factor of NTK method, can be a float or 'auto'. ",
+)
args = parser.parse_args()
load_in_8bit = args.load_in_8bit
if args.only_cpu is True:
@@ -19,13 +36,20 @@
import torch
import torch.nn.functional as F
-from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
+from transformers import (
+ LlamaForCausalLM,
+ LlamaTokenizer,
+ GenerationConfig,
+ TextIteratorStreamer,
+)
from peft import PeftModel
import sys
+
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
+
apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)
@@ -39,13 +63,15 @@
CompletionResponseChoice,
EmbeddingsRequest,
EmbeddingsResponse,
+ ChatCompletionResponseStreamChoice,
+ DeltaMessage,
)
load_type = torch.float16
if torch.cuda.is_available():
device = torch.device(0)
else:
- device = torch.device('cpu')
+ device = torch.device("cpu")
if args.tokenizer_path is None:
args.tokenizer_path = args.lora_model
if args.lora_model is None:
@@ -57,23 +83,28 @@
load_in_8bit=load_in_8bit,
torch_dtype=load_type,
low_cpu_mem_usage=True,
- device_map='auto' if not args.only_cpu else None,
- )
+ device_map="auto" if not args.only_cpu else None,
+)
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenzier_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
-if model_vocab_size!=tokenzier_vocab_size:
+if model_vocab_size != tokenzier_vocab_size:
print("Resize model embeddings to fit tokenizer")
base_model.resize_token_embeddings(tokenzier_vocab_size)
if args.lora_model is not None:
print("loading peft model")
- model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',)
+ model = PeftModel.from_pretrained(
+ base_model,
+ args.lora_model,
+ torch_dtype=load_type,
+ device_map="auto",
+ )
else:
model = base_model
-if device==torch.device('cpu'):
+if device == torch.device("cpu"):
model.float()
model.eval()
@@ -81,25 +112,28 @@
DEFAULT_SYSTEM_PROMPT = """你是一个乐于助人的助手。"""
TEMPLATE_WITH_SYSTEM_PROMPT = (
- "[INST] <>\n"
- "{system_prompt}\n"
- "<>\n\n"
- "{instruction} [/INST]"
+ "[INST] <>\n" "{system_prompt}\n" "<>\n\n" "{instruction} [/INST]"
)
TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"
-def generate_prompt(instruction, response="", with_system_prompt=True, system_prompt=None):
+
+def generate_prompt(
+ instruction, response="", with_system_prompt=True, system_prompt=None
+):
if with_system_prompt is True:
if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
- prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map({'instruction': instruction,'system_prompt': system_prompt})
+ prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map(
+ {"instruction": instruction, "system_prompt": system_prompt}
+ )
else:
- prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({'instruction': instruction})
- if len(response)>0:
+ prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({"instruction": instruction})
+ if len(response) > 0:
prompt += " " + response
return prompt
+
def generate_completion_prompt(instruction: str):
"""Generate prompt for completion"""
return generate_prompt(instruction, response="", with_system_prompt=True)
@@ -110,23 +144,26 @@ def generate_chat_prompt(messages: list):
system_msg = None
for msg in messages:
- if msg.role == 'system':
+ if msg.role == "system":
system_msg = msg.content
prompt = ""
is_first_user_content = True
for msg in messages:
- if msg.role == 'system':
+ if msg.role == "system":
continue
- if msg.role == 'user':
+ if msg.role == "user":
if is_first_user_content is True:
- prompt += generate_prompt(msg.content, with_system_prompt=True, system_prompt=system_msg)
+ prompt += generate_prompt(
+ msg.content, with_system_prompt=True, system_prompt=system_msg
+ )
is_first_user_content = False
else:
- prompt += '<s>'+generate_prompt(msg.content, with_system_prompt=False)
- if msg.role == 'assistant':
- prompt += f" {msg.content}"+"</s>"
+ prompt += "<s>" + generate_prompt(msg.content, with_system_prompt=False)
+ if msg.role == "assistant":
+ prompt += f" {msg.content}" + "</s>"
return prompt
+
def predict(
input,
max_new_tokens=128,
@@ -160,7 +197,7 @@ def predict(
generation_config.return_dict_in_generate = True
generation_config.output_scores = False
generation_config.max_new_tokens = max_new_tokens
- generation_config.repetition_penalty=float(repetition_penalty)
+ generation_config.repetition_penalty = float(repetition_penalty)
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
@@ -171,17 +208,84 @@ def predict(
output = output.split("[/INST]")[-1].strip()
return output
+
+def stream_predict(
+ input,
+ max_new_tokens=128,
+ top_p=0.75,
+ temperature=0.1,
+ top_k=40,
+ num_beams=4,
+ repetition_penalty=1.0,
+ do_sample=True,
+ model_id="chinese-llama-alpaca-2",
+ **kwargs,
+):
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
+ )
+ chunk = ChatCompletionResponse(
+ model=model_id,
+ choices=[choice_data],
+ object="chat.completion.chunk",
+ )
+ yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
+ if isinstance(input, str):
+ prompt = generate_completion_prompt(input)
+ else:
+ prompt = generate_chat_prompt(input)
+ inputs = tokenizer(prompt, return_tensors="pt")
+ input_ids = inputs["input_ids"].to(device)
+ generation_config = GenerationConfig(
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ num_beams=num_beams,
+ do_sample=do_sample,
+ **kwargs,
+ )
+
+ streamer = TextIteratorStreamer(tokenizer)
+ generation_kwargs = dict(
+ streamer=streamer,
+ input_ids=input_ids,
+ generation_config=generation_config,
+ return_dict_in_generate=True,
+ output_scores=False,
+ max_new_tokens=max_new_tokens,
+ repetition_penalty=float(repetition_penalty),
+ )
+ Thread(target=model.generate, kwargs=generation_kwargs).start()
+ for new_text in streamer:
+ if new_text.startswith("") or new_text.startswith("[/INST]"):
+ continue
+ if new_text.endswith(""):
+ new_text = new_text.split("")[0]
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0, delta=DeltaMessage(content=new_text), finish_reason=None
+ )
+ chunk = ChatCompletionResponse(
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
+ )
+ yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0, delta=DeltaMessage(), finish_reason="stop"
+ )
+ chunk = ChatCompletionResponse(
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
+ )
+ yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+ yield "[DONE]"
+
+
def get_embedding(input):
"""Get embedding main function"""
with torch.no_grad():
- encoding = tokenizer(
- input, padding=True, return_tensors="pt"
- )
+ encoding = tokenizer(input, padding=True, return_tensors="pt")
input_ids = encoding["input_ids"].to(device)
attention_mask = encoding["attention_mask"].to(device)
- model_output = model(
- input_ids, attention_mask, output_hidden_states=True
- )
+ model_output = model(input_ids, attention_mask, output_hidden_states=True)
data = model_output.hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
@@ -192,16 +296,30 @@ def get_embedding(input):
ret = normalized_embeddings.squeeze(0).tolist()
return ret
+
app = FastAPI()
+
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""Creates a completion for the chat message"""
msgs = request.messages
if isinstance(msgs, str):
- msgs = [ChatMessage(role='user',content=msgs)]
+ msgs = [ChatMessage(role="user", content=msgs)]
else:
- msgs = [ChatMessage(role=x['role'],content=x['message']) for x in msgs]
+ msgs = [ChatMessage(role=x["role"], content=x["message"]) for x in msgs]
+ if request.stream:
+ generate = stream_predict(
+ input=msgs,
+ max_new_tokens=request.max_tokens,
+ top_p=request.top_p,
+ top_k=request.top_k,
+ temperature=request.temperature,
+ num_beams=request.num_beams,
+ repetition_penalty=request.repetition_penalty,
+ do_sample=request.do_sample,
+ )
+ return EventSourceResponse(generate, media_type="text/event-stream")
output = predict(
input=msgs,
max_new_tokens=request.max_tokens,
@@ -212,9 +330,16 @@ async def create_chat_completion(request: ChatCompletionRequest):
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
- choices = [ChatCompletionResponseChoice(index = i, message = msg) for i, msg in enumerate(msgs)]
- choices += [ChatCompletionResponseChoice(index = len(choices), message = ChatMessage(role='assistant',content=output))]
- return ChatCompletionResponse(choices = choices)
+ choices = [
+ ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs)
+ ]
+ choices += [
+ ChatCompletionResponseChoice(
+ index=len(choices), message=ChatMessage(role="assistant", content=output)
+ )
+ ]
+ return ChatCompletionResponse(choices=choices)
+
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
@@ -229,23 +354,24 @@ async def create_completion(request: CompletionRequest):
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
- choices = [CompletionResponseChoice(index = 0, text = output)]
- return CompletionResponse(choices = choices)
+ choices = [CompletionResponseChoice(index=0, text=output)]
+ return CompletionResponse(choices=choices)
+
@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingsRequest):
"""Creates text embedding"""
embedding = get_embedding(request.input)
- data = [{
- "object": "embedding",
- "embedding": embedding,
- "index": 0
- }]
+ data = [{"object": "embedding", "embedding": embedding, "index": 0}]
return EmbeddingsResponse(data=data)
if __name__ == "__main__":
log_config = uvicorn.config.LOGGING_CONFIG
- log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
- log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
- uvicorn.run(app, host='0.0.0.0', port=19327, workers=1, log_config=log_config)
+ log_config["formatters"]["access"][
+ "fmt"
+ ] = "%(asctime)s - %(levelname)s - %(message)s"
+ log_config["formatters"]["default"][
+ "fmt"
+ ] = "%(asctime)s - %(levelname)s - %(message)s"
+ uvicorn.run(app, host="0.0.0.0", port=19327, workers=1, log_config=log_config)
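On the wire, `EventSourceResponse` turns every string yielded by `stream_predict` into a single SSE `data:` event. Roughly (a hand-written sketch, not captured server output; exact fields depend on the request), a streamed reply looks like:

```
data: {"model": "chinese-llama-alpaca-2", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}], "object": "chat.completion.chunk"}

data: {"model": "chinese-llama-alpaca-2", "choices": [{"index": 0, "delta": {"content": "..."}, "finish_reason": null}], "object": "chat.completion.chunk"}

data: {"model": "chinese-llama-alpaca-2", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "object": "chat.completion.chunk"}

data: [DONE]
```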
From 9efdb97ea376ef30e927023c432b80c56851bb83 Mon Sep 17 00:00:00 2001
From: Ziqing Yang
Date: Thu, 3 Aug 2023 09:05:52 +0800
Subject: [PATCH 2/8] fix spelling error
---
scripts/openai_server_demo/openai_api_server.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 1a1be9f..fb916bd 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -84,8 +84,8 @@
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
-print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
-if model_vocab_size != tokenzier_vocab_size:
+print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
+if model_vocab_size != tokenizer_vocab_size:
print("Resize model embeddings to fit tokenizer")
base_model.resize_token_embeddings(tokenizer_vocab_size)
if args.lora_model is not None:
From 720581544851ee3b0fe7e3268fbd34ed15b56663 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Thu, 3 Aug 2023 14:29:31 +0800
Subject: [PATCH 3/8] Add openai stream api docs.
---
scripts/openai_server_demo/README.md | 10 ++++++----
scripts/openai_server_demo/openai_api_server.py | 10 +++++++++-
2 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/scripts/openai_server_demo/README.md b/scripts/openai_server_demo/README.md
index 475f87c..732aa37 100644
--- a/scripts/openai_server_demo/README.md
+++ b/scripts/openai_server_demo/README.md
@@ -137,7 +137,7 @@ curl http://localhost:19327/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
- {"role": "user","message": "给我讲一些有关杭州的故事吧"}
+ {"role": "user","content": "给我讲一些有关杭州的故事吧"}
],
"repetition_penalty": 1.0
}'
@@ -179,9 +179,9 @@ curl http://localhost:19327/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
- {"role": "user","message": "给我讲一些有关杭州的故事吧"},
- {"role": "assistant","message": "好的,请问您对杭州有什么特别的偏好吗?"},
- {"role": "user","message": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
+ {"role": "user","content": "给我讲一些有关杭州的故事吧"},
+ {"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"},
+ {"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
],
"repetition_penalty": 1.0
}'
@@ -246,6 +246,8 @@ json返回体:
`do_sample`: Enables random sampling. Defaults to true.
+`stream`: Streamed responses in OpenAI format. Defaults to false; when set to true, data is streamed back in OpenAI's format, allowing the server to act as a backend for any ChatGPT-based application.
+
### Text embeddings
Text embeddings have many uses, including but not limited to question answering over large documents, summarizing the contents of a book, and finding the memories most relevant to the current user input for a large language model.
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index fb916bd..fd182cc 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -1,6 +1,7 @@
import argparse
import os
from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from threading import Thread
from sse_starlette.sse import EventSourceResponse
@@ -294,6 +295,13 @@ def get_embedding(input):
app = FastAPI()
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
@@ -302,7 +310,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if isinstance(msgs, str):
msgs = [ChatMessage(role="user", content=msgs)]
else:
- msgs = [ChatMessage(role=x["role"], content=x["message"]) for x in msgs]
+ msgs = [ChatMessage(role=x["role"], content=x["content"]) for x in msgs]
if request.stream:
generate = stream_predict(
input=msgs,
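With the `stream` flag documented and the request field corrected to `content`, the endpoint can be consumed as an ordinary SSE stream. A minimal client sketch (not part of the patch series), assuming the server runs locally on the default port 19327 and the `requests` package is installed:

```python
# Hypothetical streaming client for the demo server; adjust host/port as needed.
import json
import requests

resp = requests.post(
    "http://localhost:19327/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "给我讲一些有关杭州的故事吧"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    # sse_starlette frames each yielded string as "data: <payload>"
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
```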
From 30cb8a5f50559974f0b58104b78874a888e185bc Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 00:58:34 +0800
Subject: [PATCH 4/8] fix stream api split
---
scripts/openai_server_demo/openai_api_server.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index fd182cc..ab36fed 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -74,6 +74,7 @@
args.base_model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
+ offload_folder="offload", # TODO
device_map='auto' if not args.only_cpu else None,
quantization_config=BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
@@ -254,8 +255,10 @@ def stream_predict(
)
Thread(target=model.generate, kwargs=generation_kwargs).start()
for new_text in streamer:
- if new_text.startswith("<s>") or new_text.startswith("[/INST]"):
+ if new_text.startswith("<s>"):
continue
+ if new_text.startswith("[/INST]"):
+ new_text = new_text.split("[/INST]")[-1]
if new_text.endswith("</s>"):
new_text = new_text.split("</s>")[0]
choice_data = ChatCompletionResponseStreamChoice(
From ef0e1c11758b0df9a822392627d509f546868c23 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 01:09:16 +0800
Subject: [PATCH 5/8] fix stream api bug
---
scripts/openai_server_demo/openai_api_server.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index ab36fed..73ae747 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -74,7 +74,6 @@
args.base_model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
- offload_folder="offload", # TODO
device_map='auto' if not args.only_cpu else None,
quantization_config=BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
From 0db43d87019ffc404765c15835e88a96615170a0 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 11:31:17 +0800
Subject: [PATCH 6/8] fix api extra token
---
scripts/openai_server_demo/openai_api_server.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 73ae747..05e3673 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -257,7 +257,7 @@ def stream_predict(
if new_text.startswith("<s>"):
continue
if new_text.startswith("[/INST]"):
- new_text = new_text.split("[/INST]")[-1]
+ new_text = new_text.split("[/INST]")[-1].strip()
if new_text.endswith("</s>"):
new_text = new_text.split("</s>")[0]
choice_data = ChatCompletionResponseStreamChoice(
From 01fc1dfdc1899d7f7e8e6897fd77504707d8fafd Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Fri, 4 Aug 2023 21:25:46 +0800
Subject: [PATCH 7/8] fix openai api repetition of the last word
---
scripts/openai_server_demo/openai_api_server.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 05e3673..e353a06 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -254,12 +254,13 @@ def stream_predict(
)
Thread(target=model.generate, kwargs=generation_kwargs).start()
for new_text in streamer:
+ new_text = new_text.strip()
if new_text.startswith("<s>"):
continue
if new_text.startswith("[/INST]"):
- new_text = new_text.split("[/INST]")[-1].strip()
+ new_text = new_text.split("[/INST]")[-1]
if new_text.endswith("</s>"):
- new_text = new_text.split("</s>")[0]
+ new_text = new_text.split("</s>")[0][1:]
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
)
From ebbe80e6540305a4aa6c1d96a5f864d524f6af18 Mon Sep 17 00:00:00 2001
From: "li.yunhao"
Date: Mon, 7 Aug 2023 18:24:54 +0800
Subject: [PATCH 8/8] fix missing first token in api stream
---
scripts/openai_server_demo/openai_api_server.py | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index e353a06..687ac4d 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -242,7 +242,7 @@ def stream_predict(
**kwargs,
)
- streamer = TextIteratorStreamer(tokenizer)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
streamer=streamer,
input_ids=input_ids,
@@ -254,13 +254,6 @@ def stream_predict(
)
Thread(target=model.generate, kwargs=generation_kwargs).start()
for new_text in streamer:
- new_text = new_text.strip()
- if new_text.startswith("<s>"):
- continue
- if new_text.startswith("[/INST]"):
- new_text = new_text.split("[/INST]")[-1]
- if new_text.endswith("</s>"):
- new_text = new_text.split("</s>")[0][1:]
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
)
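This last patch delegates prompt and special-token removal to `TextIteratorStreamer` itself via `skip_prompt=True` and `skip_special_tokens=True`, which is why the manual `<s>`/`[/INST]` filtering introduced in patches 4-7 can be deleted. A self-contained sketch of that pattern (illustrative only; it uses a small public model rather than the project's weights):

```python
# Stream text from a background generation thread; with skip_prompt and
# skip_special_tokens set, the streamer yields only newly generated plain text.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Tell me a story about West Lake.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32),
).start()

for new_text in streamer:
    print(new_text, end="", flush=True)
```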