From d515f5885135c04133c4fc4e505e7b63ca2f7fd7 Mon Sep 17 00:00:00 2001
From: Yunmo Chen
Date: Tue, 8 Aug 2023 00:46:20 -0700
Subject: [PATCH 1/7] Added support for token ids.

---
 vllm/entrypoints/openai/api_server.py | 11 +++++++----
 vllm/entrypoints/openai/protocol.py   | 15 +++++++++++++--
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 81004d3c6783..5e349aa90674 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
+from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional
 
-from packaging import version
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -405,7 +405,10 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    if request.use_token_ids:
+        result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c63e7a2964fc..183472f597af 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1,6 +1,7 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
+from functools import cached_property
 from typing import Dict, List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
@@ -74,7 +75,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[str, List[str], List[int], List[List[int]]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
@@ -94,13 +96,22 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
 
+    @cached_property
+    def use_token_ids(self) -> bool:
+        if isinstance(self.prompt, list):
+            if len(self.prompt) > 0:
+                if isinstance(self.prompt[0], list) or isinstance(self.prompt[0], int):
+                    return True
+            return False
+        return False
+
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
     top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+                                      float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):
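
The patch above only touches the server side. For orientation, the request shapes it enables could be exercised with a client call like the sketch below; the host, port, model name, and token IDs are illustrative assumptions and are not part of the patch series.

```python
# Hypothetical client-side check of the new prompt formats (not part of the patch).
# Assumes a vLLM OpenAI-compatible server is already running on localhost:8000
# and that the example model name below is the one it was started with.
import requests

COMPLETIONS_URL = "http://localhost:8000/v1/completions"  # assumed default host/port

# Plain-text prompt: the only form supported before this patch.
text_payload = {
    "model": "facebook/opt-125m",
    "prompt": "The capital of France is",
    "max_tokens": 8,
}

# Token-ID prompt enabled by this patch: a list of ints is detected by
# CompletionRequest.use_token_ids and forwarded via prompt_token_ids.
token_payload = {
    "model": "facebook/opt-125m",
    "prompt": [2, 133, 812, 9, 1470, 16],  # example token IDs, not a real encoding
    "max_tokens": 8,
}

for payload in (text_payload, token_payload):
    response = requests.post(COMPLETIONS_URL, json=payload, timeout=60)
    response.raise_for_status()
    print(response.json()["choices"][0]["text"])
```
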
From cd041e18407718a078b9e8facee9d85aed74b2d9 Mon Sep 17 00:00:00 2001
From: wanmok
Date: Tue, 8 Aug 2023 00:46:20 -0700
Subject: [PATCH 2/7] Added support for token ids.

---
 vllm/entrypoints/openai/api_server.py | 11 +++++++----
 vllm/entrypoints/openai/protocol.py   | 15 ++++++++++++---
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 81004d3c6783..5e349aa90674 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
+from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional
 
-from packaging import version
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -405,7 +405,10 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    if request.use_token_ids:
+        result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c63e7a2964fc..78196195e1fb 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1,6 +1,7 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
+from functools import cached_property
 from typing import Dict, List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
@@ -74,7 +75,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[str, List[str], List[int], List[List[int]]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
@@ -94,13 +96,20 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
 
+    def use_token_ids(self) -> bool:
+        if isinstance(self.prompt, list):
+            if len(self.prompt) > 0:
+                if isinstance(self.prompt[0], list) or isinstance(self.prompt[0], int):
+                    return True
+            return False
+        return False
+
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):

From ee1b3a19b5d88ffdcd83ca56d38a3d27bc13a655 Mon Sep 17 00:00:00 2001
From: wanmok
Date: Tue, 8 Aug 2023 02:02:50 -0700
Subject: [PATCH 3/7] Merge remote-tracking branch 'origin/completion-support-ids' into completion-support-ids

---
 vllm/entrypoints/openai/api_server.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index d8b1585d8faf..692bb9ca62a7 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -382,14 +382,13 @@ async def create_completion(raw_request: Request):
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
         first_element = request.prompt[0]
-        if not isinstance(first_element, int) and len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-
         if isinstance(first_element, int):
             prompt = request.prompt
-        else:
+        elif isinstance(first_element, str) or isinstance(first_element, list):
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
             prompt = request.prompt[0]
     else:
         prompt = request.prompt

From 26695cf876d93525f85f16ef1f464d8b5396bf05 Mon Sep 17 00:00:00 2001
From: wanmok
Date: Tue, 8 Aug 2023 02:06:11 -0700
Subject: [PATCH 4/7] Merge remote-tracking branch 'origin/completion-support-ids' into completion-support-ids

---
 vllm/entrypoints/openai/api_server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 692bb9ca62a7..f7cbfab89379 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -378,6 +378,7 @@ async def create_completion(raw_request: Request):
     use_token_ids = request.use_token_ids()
 
     if isinstance(request.prompt, list):
+        print(f'prompt is a list: {request.prompt}')
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")

From b403d550a7c9d1ece339fac99ffb2577dd74f44f Mon Sep 17 00:00:00 2001
From: wanmok
Date: Tue, 8 Aug 2023 02:07:21 -0700
Subject: [PATCH 5/7] Merge remote-tracking branch 'origin/completion-support-ids' into completion-support-ids

---
 vllm/entrypoints/openai/protocol.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 78196195e1fb..04e2bd4faa74 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -76,7 +76,7 @@ class ChatCompletionRequest(BaseModel):
 class CompletionRequest(BaseModel):
     model: str
     # a string, array of strings, array of tokens, or array of token arrays
-    prompt: Union[str, List[str], List[int], List[List[int]]]
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0

From 38d2754cba0b9186d352ab3e8f149865eb0247fa Mon Sep 17 00:00:00 2001
From: wanmok
Date: Tue, 8 Aug 2023 02:11:30 -0700
Subject: [PATCH 6/7] Merge remote-tracking branch 'origin/completion-support-ids' into completion-support-ids

---
 vllm/entrypoints/openai/api_server.py | 6 ++++--
 vllm/entrypoints/openai/protocol.py   | 9 ---------
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f7cbfab89379..6b6d78cfa7e5 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -375,17 +375,18 @@ async def create_completion(raw_request: Request):
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
 
-    use_token_ids = request.use_token_ids()
+    use_token_ids = False
 
     if isinstance(request.prompt, list):
-        print(f'prompt is a list: {request.prompt}')
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
         first_element = request.prompt[0]
         if isinstance(first_element, int):
+            use_token_ids = True
             prompt = request.prompt
         elif isinstance(first_element, str) or isinstance(first_element, list):
+            # TODO(@wanmok): handles multiple prompt case in list[list[int]]
             if len(request.prompt) > 1:
                 return create_error_response(
                     HTTPStatus.BAD_REQUEST,
                     "multiple prompts in a batch is not currently supported")
             prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 04e2bd4faa74..32da081e9f64 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1,7 +1,6 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
-from functools import cached_property
 from typing import Dict, List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
@@ -96,14 +95,6 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
 
-    def use_token_ids(self) -> bool:
-        if isinstance(self.prompt, list):
-            if len(self.prompt) > 0:
-                if isinstance(self.prompt[0], list) or isinstance(self.prompt[0], int):
-                    return True
-            return False
-        return False
-
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)

From f1cce3cce06534506de4fcf26399e240221faabf Mon Sep 17 00:00:00 2001
From: wanmok
Date: Tue, 8 Aug 2023 02:13:00 -0700
Subject: [PATCH 7/7] Merge remote-tracking branch 'origin/completion-support-ids' into completion-support-ids

---
 vllm/entrypoints/openai/api_server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 6b6d78cfa7e5..3206ac757c93 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -391,6 +391,7 @@ async def create_completion(raw_request: Request):
             return create_error_response(
                 HTTPStatus.BAD_REQUEST,
                 "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
             prompt = request.prompt[0]
     else:
         prompt = request.prompt
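
Taken together, patches 3 through 7 converge on a single prompt-dispatch path inside create_completion. The standalone sketch below mirrors that final logic for readability only; it is not part of the series, the function name is invented, the HTTP 400 responses are replaced by ValueError, and element types the server code leaves unhandled simply fall through to the plain-string case.

```python
from typing import List, Tuple, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]


def resolve_prompt(prompt: Prompt) -> Tuple[Union[str, List[int]], bool]:
    """Return (resolved_prompt, use_token_ids) the way the patched endpoint does.

    Raises ValueError where the server would answer with HTTP 400.
    """
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        first_element = prompt[0]
        if isinstance(first_element, int):
            # A flat list of ints is a single tokenized prompt.
            return prompt, True
        if isinstance(first_element, (str, list)):
            # Batches (a list of strings or a list of token lists) are limited
            # to a single element for now.
            if len(prompt) > 1:
                raise ValueError(
                    "multiple prompts in a batch is not currently supported")
            return prompt[0], not isinstance(first_element, str)
    # A plain string passes through as detokenized text.
    return prompt, False


if __name__ == "__main__":
    print(resolve_prompt("Hello, world"))  # ('Hello, world', False)
    print(resolve_prompt([1, 2, 3]))       # ([1, 2, 3], True)
    print(resolve_prompt([[1, 2, 3]]))     # ([1, 2, 3], True)
    print(resolve_prompt(["Hello"]))       # ('Hello', False)
```

Relative to patch 3, the only behavioral changes contributed by patches 6 and 7 are the removal of the debug print and deriving use_token_ids from the element type of the submitted prompt rather than from the request property.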