From 6ad0864dc6b1610d416e2b514d4fd165ae544da3 Mon Sep 17 00:00:00 2001 From: WanMok <16273544+wanmok@users.noreply.github.com> Date: Tue, 8 Aug 2023 02:17:38 -0700 Subject: [PATCH 1/5] Supports prompt_token_ids in the OpenAI completion API. (#2) Supports prompt_token_ids in the OpenAI completion API. --- vllm/entrypoints/openai/api_server.py | 25 ++++++++++++++++++------- vllm/entrypoints/openai/protocol.py | 6 +++--- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8acea787c1b7..11772f5786f6 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -376,15 +376,24 @@ async def create_completion(raw_request: Request): model_name = request.model request_id = f"cmpl-{random_uuid()}" + + use_token_ids = False if isinstance(request.prompt, list): if len(request.prompt) == 0: return create_error_response(HTTPStatus.BAD_REQUEST, "please provide at least one prompt") - if len(request.prompt) > 1: - return create_error_response( - HTTPStatus.BAD_REQUEST, - "multiple prompts in a batch is not currently supported") - prompt = request.prompt[0] + first_element = request.prompt[0] + if isinstance(first_element, int): + use_token_ids = True + prompt = request.prompt + elif isinstance(first_element, str) or isinstance(first_element, list): + # TODO: handles multiple prompt case in list[list[int]] + if len(request.prompt) > 1: + return create_error_response( + HTTPStatus.BAD_REQUEST, + "multiple prompts in a batch is not currently supported") + use_token_ids = not isinstance(first_element, str) + prompt = request.prompt[0] else: prompt = request.prompt @@ -411,8 +420,10 @@ async def create_completion(raw_request: Request): except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - result_generator = engine.generate(prompt, sampling_params, request_id, - token_ids) + if use_token_ids: + result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt) + else: + result_generator = engine.generate(prompt, sampling_params, request_id) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use beam search. 
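For context on what this first patch enables at the API surface, below is a minimal client-side sketch of the two single-prompt shapes `create_completion` is meant to accept after the change: a plain string (the existing path) and a flat list of token IDs (the new `use_token_ids` path). The base URL, model name, and token IDs are illustrative placeholders, not values taken from this PR, and the route assumes the server's usual OpenAI-compatible `/v1/completions` endpoint.

```python
# Hypothetical client-side sketch; the endpoint, model name, and token IDs are
# placeholders for illustration only.
import requests

BASE_URL = "http://localhost:8000/v1/completions"  # assumed local vLLM server

# 1) Plain-text prompt: the server tokenizes it, as before this patch.
text_payload = {
    "model": "facebook/opt-125m",
    "prompt": "Hello, my name is",
    "max_tokens": 16,
}

# 2) Pre-tokenized prompt: a flat list of token IDs, which this patch forwards
#    to the engine via `prompt_token_ids` instead of re-tokenizing server-side.
token_payload = {
    "model": "facebook/opt-125m",
    "prompt": [101, 202, 303, 404],  # placeholder token IDs
    "max_tokens": 16,
}

for payload in (text_payload, token_payload):
    resp = requests.post(BASE_URL, json=payload, timeout=60)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["text"])
```

A batch of strings or of token-ID lists with more than one element still returns the "multiple prompts in a batch is not currently supported" error; only single-element batches are unwrapped.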
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c63e7a2964fc..32da081e9f64 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel): class CompletionRequest(BaseModel): model: str - prompt: Union[str, List[str]] + # a string, array of strings, array of tokens, or array of token arrays + prompt: Union[List[int], List[List[int]], str, List[str]] suffix: Optional[str] = None max_tokens: Optional[int] = 16 temperature: Optional[float] = 1.0 @@ -99,8 +100,7 @@ class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, - float]]] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) class CompletionResponseChoice(BaseModel): From 81bd18a24383aaf1e1010549c23b735c28ad901e Mon Sep 17 00:00:00 2001 From: wanmok Date: Wed, 9 Aug 2023 01:39:06 -0700 Subject: [PATCH 2/5] Rebased on the main to adopt `check_length` --- vllm/entrypoints/openai/api_server.py | 30 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 11772f5786f6..9ead3baeaf43 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -3,18 +3,18 @@ import argparse import asyncio -from http import HTTPStatus import json import time -from typing import AsyncGenerator, Dict, List, Optional -from packaging import version +from http import HTTPStatus +from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union import fastapi +import uvicorn from fastapi import BackgroundTasks, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse -import uvicorn +from packaging import version from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -115,8 +115,19 @@ async def get_gen_prompt(request) -> str: return prompt -async def check_length(request, prompt): - input_ids = tokenizer(prompt).input_ids +async def check_length( + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None +) -> Tuple[List[int], Optional[JSONResponse]]: + assert ( + not (prompt is None and prompt_ids is None) + and not (prompt is not None and prompt_ids is not None) + ), "Either prompt or prompt_ids should be provided." 
+ if prompt_ids is not None: + input_ids = prompt_ids + else: + input_ids = tokenizer(prompt).input_ids token_num = len(input_ids) if token_num + request.max_tokens > max_model_len: @@ -191,7 +202,7 @@ async def create_chat_completion(raw_request: Request): "logit_bias is not currently supported") prompt = await get_gen_prompt(request) - token_ids, error_check_ret = await check_length(request, prompt) + token_ids, error_check_ret = await check_length(request, prompt=prompt) if error_check_ret is not None: return error_check_ret @@ -397,7 +408,10 @@ async def create_completion(raw_request: Request): else: prompt = request.prompt - token_ids, error_check_ret = await check_length(request, prompt) + if use_token_ids: + token_ids, error_check_ret = await check_length(request, prompt_ids=prompt) + else: + token_ids, error_check_ret = await check_length(request, prompt=prompt) if error_check_ret is not None: return error_check_ret From cd49cad574cd1709df9baf64061148ac631f8f78 Mon Sep 17 00:00:00 2001 From: wanmok Date: Wed, 9 Aug 2023 13:57:52 -0700 Subject: [PATCH 3/5] Reformatted to pass pylint. --- vllm/entrypoints/openai/api_server.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9ead3baeaf43..89dfec55d844 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -397,7 +397,7 @@ async def create_completion(raw_request: Request): if isinstance(first_element, int): use_token_ids = True prompt = request.prompt - elif isinstance(first_element, str) or isinstance(first_element, list): + elif isinstance(first_element, (str, list)): # TODO: handles multiple prompt case in list[list[int]] if len(request.prompt) > 1: return create_error_response( @@ -409,9 +409,13 @@ async def create_completion(raw_request: Request): prompt = request.prompt if use_token_ids: - token_ids, error_check_ret = await check_length(request, prompt_ids=prompt) + token_ids, error_check_ret = await check_length( + request, prompt_ids=prompt + ) else: - token_ids, error_check_ret = await check_length(request, prompt=prompt) + token_ids, error_check_ret = await check_length( + request, prompt=prompt + ) if error_check_ret is not None: return error_check_ret @@ -435,7 +439,9 @@ async def create_completion(raw_request: Request): return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) if use_token_ids: - result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt) + result_generator = engine.generate( + None, sampling_params, request_id, prompt_token_ids=prompt + ) else: result_generator = engine.generate(prompt, sampling_params, request_id) From 5ff4191e1d43edcaa66eee3a480e4cd17fa24e58 Mon Sep 17 00:00:00 2001 From: wanmok Date: Fri, 11 Aug 2023 11:48:03 -0700 Subject: [PATCH 4/5] Reformatted to pass pylint. 
--- vllm/entrypoints/openai/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 89dfec55d844..44a9bed4ec87 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -409,7 +409,7 @@ async def create_completion(raw_request: Request): prompt = request.prompt if use_token_ids: - token_ids, error_check_ret = await check_length( + _, error_check_ret = await check_length( request, prompt_ids=prompt ) else: From 55d39e302fea917f8099f9437674b20538ea281e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 11 Aug 2023 19:13:20 +0000 Subject: [PATCH 5/5] fix deleted token_ids and fix format --- vllm/entrypoints/openai/api_server.py | 31 ++++++++++++--------------- vllm/entrypoints/openai/protocol.py | 3 ++- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 44a9bed4ec87..97d097e60f31 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -116,14 +116,13 @@ async def get_gen_prompt(request) -> str: async def check_length( - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None ) -> Tuple[List[int], Optional[JSONResponse]]: - assert ( - not (prompt is None and prompt_ids is None) - and not (prompt is not None and prompt_ids is not None) - ), "Either prompt or prompt_ids should be provided." + assert (not (prompt is None and prompt_ids is None) + and not (prompt is not None and prompt_ids is not None) + ), "Either prompt or prompt_ids should be provided." if prompt_ids is not None: input_ids = prompt_ids else: @@ -409,13 +408,9 @@ async def create_completion(raw_request: Request): prompt = request.prompt if use_token_ids: - _, error_check_ret = await check_length( - request, prompt_ids=prompt - ) + _, error_check_ret = await check_length(request, prompt_ids=prompt) else: - token_ids, error_check_ret = await check_length( - request, prompt=prompt - ) + token_ids, error_check_ret = await check_length(request, prompt=prompt) if error_check_ret is not None: return error_check_ret @@ -439,11 +434,13 @@ async def create_completion(raw_request: Request): return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) if use_token_ids: - result_generator = engine.generate( - None, sampling_params, request_id, prompt_token_ids=prompt - ) + result_generator = engine.generate(None, + sampling_params, + request_id, + prompt_token_ids=prompt) else: - result_generator = engine.generate(prompt, sampling_params, request_id) + result_generator = engine.generate(prompt, sampling_params, request_id, + token_ids) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use beam search. 
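To make the final control flow easier to follow outside the diff context, here is a condensed, framework-free sketch of the pattern the series converges on: classify the incoming `prompt` into text vs. token IDs, then length-check either the tokenized text or the raw IDs against the model's context window. `tokenize` and `MAX_MODEL_LEN` are stand-ins for the server's tokenizer and `max_model_len`; this is an illustrative reduction, not the patched code itself, and error handling is collapsed to returned strings.

```python
# Illustrative reduction of the patched logic; `tokenize` and MAX_MODEL_LEN are
# stand-ins, and the empty-prompt check is omitted for brevity.
from typing import Callable, List, Optional, Tuple, Union

MAX_MODEL_LEN = 2048  # placeholder context window

Prompt = Union[str, List[str], List[int], List[List[int]]]


def classify_prompt(prompt: Prompt) -> Tuple[Union[str, List[int]], bool]:
    """Mirror the `use_token_ids` detection in `create_completion`."""
    if isinstance(prompt, str):
        return prompt, False
    first = prompt[0]
    if isinstance(first, int):       # flat list of token IDs
        return prompt, True
    if isinstance(first, list):      # batch of token-ID lists (one element supported)
        return first, True
    return first, False              # batch of strings (one element supported)


def check_length(max_tokens: int,
                 tokenize: Callable[[str], List[int]],
                 prompt: Optional[str] = None,
                 prompt_ids: Optional[List[int]] = None
                 ) -> Tuple[List[int], Optional[str]]:
    # Exactly one of `prompt` / `prompt_ids` must be provided, mirroring the
    # assertion added to the helper in this series.
    assert (prompt is None) != (prompt_ids is None), \
        "Either prompt or prompt_ids should be provided."
    input_ids = prompt_ids if prompt_ids is not None else tokenize(prompt)
    if len(input_ids) + max_tokens > MAX_MODEL_LEN:
        return input_ids, (f"requested {len(input_ids) + max_tokens} tokens, "
                           f"but the context window is {MAX_MODEL_LEN}")
    return input_ids, None
```

With the fix in this last patch, the token-ID path calls `engine.generate(None, ..., prompt_token_ids=prompt)` while the text path keeps passing the `token_ids` computed by `check_length`, which is the regression the commit message refers to.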
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 32da081e9f64..701f704234ad 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -100,7 +100,8 @@ class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) class CompletionResponseChoice(BaseModel):
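As a quick illustration of what the widened `prompt` union accepts, here is a minimal pydantic sketch carrying only that field (the real `CompletionRequest` has many more); the sample values are arbitrary placeholders.

```python
# Minimal sketch of the widened `prompt` field only; values are placeholders.
from typing import List, Union

from pydantic import BaseModel


class PromptOnly(BaseModel):
    prompt: Union[List[int], List[List[int]], str, List[str]]


samples = [
    "a single string prompt",          # existing text path
    ["a batch", "of string prompts"],  # batch of strings (one element supported)
    [11, 22, 33],                      # flat token IDs -> new `use_token_ids` path
    [[11, 22, 33], [44, 55]],          # batch of token-ID lists
]

for raw in samples:
    parsed = PromptOnly(prompt=raw)
    print(type(parsed.prompt).__name__, parsed.prompt)
```

Validation only guarantees the shape; the handler in `api_server.py` still decides, from the first element, whether to treat the value as text or as token IDs.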