From 0de765e2de4808e019fc3010f141643b216fb199 Mon Sep 17 00:00:00 2001
From: WanMok <16273544+wanmok@users.noreply.github.com>
Date: Tue, 8 Aug 2023 02:15:59 -0700
Subject: [PATCH] Supports prompt_token_ids in the OpenAI completion API. (#1)

Supports prompt_token_ids in the OpenAI completion API.
---
 vllm/entrypoints/openai/api_server.py | 31 +++++++++++++++++++--------
 vllm/entrypoints/openai/protocol.py   |  6 +++---
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 81004d3c6783..3206ac757c93 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
+from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional
 
-from packaging import version
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -375,17 +375,27 @@ async def create_completion(raw_request: Request):
 
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, str) or isinstance(first_element, list):
+            # TODO(@wanmok): handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +415,10 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    if use_token_ids:
+        result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c63e7a2964fc..32da081e9f64 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
@@ -99,8 +100,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):
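
Note on the protocol.py change: the widened `prompt` annotation on CompletionRequest is what lets pre-tokenized input through request validation before create_completion() dispatches on it. Below is a minimal sketch of the accepted shapes, not part of the patch; it assumes vllm (and therefore pydantic) is importable, and the model name and token IDs are placeholders.

from vllm.entrypoints.openai.protocol import CompletionRequest

# Flat token array: handed to the engine as prompt_token_ids.
req_tokens = CompletionRequest(model="facebook/opt-125m", prompt=[2, 100, 18, 10])

# Single token array wrapped in a list: unwrapped to request.prompt[0].
req_nested = CompletionRequest(model="facebook/opt-125m", prompt=[[2, 100, 18, 10]])

# Plain strings still validate exactly as before.
req_text = CompletionRequest(model="facebook/opt-125m", prompt="Hello, my name is")

print(req_tokens.prompt, req_nested.prompt, req_text.prompt)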
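
End to end, a client can now post token IDs straight to the completions route. A hedged sketch, assuming the server from api_server.py is running at its default localhost:8000 address, that the `requests` library is installed, and that the model name and token IDs below are placeholders.

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        # Pre-tokenized prompt: a flat list of ints takes the new
        # prompt_token_ids branch in create_completion().
        "prompt": [2, 100, 18, 10, 1085],
        "max_tokens": 16,
    },
)
print(response.json()["choices"][0]["text"])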