vllm/entrypoints/openai/api_server.py (31 changes: 22 additions & 9 deletions)
@@ -3,18 +3,18 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
+from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version
 
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -375,17 +375,27 @@ async def create_completion(raw_request: Request):
 
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, str) or isinstance(first_element, list):
+            # TODO(@wanmok): handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
 
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +415,10 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    if use_token_ids:
+        result_generator = engine.generate(None, sampling_params, request_id, prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
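For reference, a minimal client-side sketch of how the new branch might be exercised. It assumes a vLLM OpenAI-compatible server is already running at http://localhost:8000 and serving a model named facebook/opt-125m; the URL, model name, and token IDs are illustrative assumptions, not part of this diff.

# Hypothetical usage sketch, not part of this PR: send a completion request
# whose prompt is a list of token IDs rather than a string. With the change
# above, the server sees integer elements, sets use_token_ids=True, and
# forwards the list via engine.generate(..., prompt_token_ids=prompt).
import requests

payload = {
    "model": "facebook/opt-125m",   # assumed model name
    "prompt": [2, 100, 18, 27],     # token IDs instead of a string
    "max_tokens": 16,
    "temperature": 0.0,
}
response = requests.post("http://localhost:8000/v1/completions", json=payload)
print(response.json())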
vllm/entrypoints/openai/protocol.py (6 changes: 3 additions & 3 deletions)
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
@@ -99,8 +100,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):
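A small sketch of the prompt shapes the widened annotation accepts, using a stripped-down stand-in model rather than the full CompletionRequest; the field name and Union ordering mirror the diff, everything else here is illustrative.

# Illustrative only: a minimal stand-in for CompletionRequest that mirrors
# the widened `prompt` annotation from this diff, so the four accepted
# shapes can be checked without a running server.
from typing import List, Union

from pydantic import BaseModel


class MiniCompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]


# All four shapes validate:
MiniCompletionRequest(model="m", prompt="Hello, world")        # str
MiniCompletionRequest(model="m", prompt=["Hello, world"])      # List[str]
MiniCompletionRequest(model="m", prompt=[2, 100, 18])          # List[int]
MiniCompletionRequest(model="m", prompt=[[2, 100], [18, 27]])  # List[List[int]]
print("all prompt shapes validated")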