From 263c2863439c159b636992d18cdb72e3da889c78 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 20 May 2023 13:35:00 -0700 Subject: [PATCH 01/23] Separate AsyncLLMServer --- cacheflow/entrypoints/fastapi_server.py | 97 +----------------------- cacheflow/server/async_llm_server.py | 99 +++++++++++++++++++++++++ cacheflow/utils.py | 7 +- examples/simple_server.py | 4 +- playground/http_client.py | 20 ----- playground/streaming_fastapi_worker.py | 40 ---------- 6 files changed, 109 insertions(+), 158 deletions(-) create mode 100644 cacheflow/server/async_llm_server.py delete mode 100644 playground/http_client.py delete mode 100644 playground/streaming_fastapi_worker.py diff --git a/cacheflow/entrypoints/fastapi_server.py b/cacheflow/entrypoints/fastapi_server.py index ed8afda1b79d..64ac196646ff 100644 --- a/cacheflow/entrypoints/fastapi_server.py +++ b/cacheflow/entrypoints/fastapi_server.py @@ -1,111 +1,18 @@ import argparse import asyncio -import json -import time -from typing import Any, Dict -import uuid - from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse -import ray import uvicorn -from cacheflow.outputs import RequestOutput -from cacheflow.sampling_params import SamplingParams from cacheflow.server.arg_utils import ( add_server_arguments, create_server_configs_from_args) -from cacheflow.server.llm_server import LLMServer +from cacheflow.server.async_llm_server import AsyncLLMServer from cacheflow.server.ray_utils import initialize_cluster TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds app = FastAPI() -class FastAPIServer: - - def __init__(self, server_use_ray: bool, *args, **kwargs) -> None: - if server_use_ray: - remote_server_class = ray.remote(num_cpus=0)(LLMServer) - else: - remote_server_class = ray.remote(num_gpus=1)(LLMServer) - self.server = remote_server_class.remote(*args, **kwargs) - - # Request id -> request output. - self.request_outputs: Dict[str, RequestOutput] = {} - # Request id -> event to notify that there is new output. - self.request_events: Dict[str, asyncio.Event] = {} - self.is_server_running = False - - async def server_step(self): - self.is_server_running = True - request_outputs = await self.server.step.remote() - self.is_server_running = False - # Notify the waiting coroutines that there are new outputs ready. - for request_output in request_outputs: - request_id = request_output.request_id - self.request_outputs[request_id] = request_output - self.request_events[request_id].set() - - async def generate(self, request_dict: Dict[str, Any]): - # Preprocess the request. - arrival_time = time.time() - prompt = request_dict.pop("prompt") - sampling_params = SamplingParams(**request_dict) - - # Create an event to notify us that there is new output from the - # cacheflow server. - request_id = str(uuid.uuid4().hex[:8]) - request_event = asyncio.Event() - self.request_events[request_id] = request_event - - # Add the request into the cacheflow server's waiting queue. - await self.server.add_request.remote( - request_id, prompt, sampling_params, arrival_time=arrival_time) - - # The cacheflow server does not have a background loop that keeps - # processing incoming requests. Therefore, we need to keep kicking - # the server to process the requests. - while True: - # Kick the server if the server is not running. - if not self.is_server_running: - await self.server_step() - - # Wait for new output. The group_event will be set in server_step - # when there is new output available for the sequence group. 
- # Added a timeout to prevent deadlock. - try: - await asyncio.wait_for(request_event.wait(), - timeout=TIMEOUT_TO_PREVENT_DEADLOCK) - except asyncio.TimeoutError: - continue - # Reset the event to wait for the next output. - request_event.clear() - - # Decode and return new outputs. - request_output = self.request_outputs[request_id] - prompt = request_output.prompt - text_outputs = [ - prompt + output.text - for output in request_output.outputs - ] - ret = { - "text": text_outputs, - "error": 0, - } - yield (json.dumps(ret) + "\0").encode("utf-8") - - # Once finished, release the resources of the sequence group. - if request_output.done: - del self.request_outputs[request_id] - del self.request_events[request_id] - # Kick the server if the server is not running. This is to - # prevent that there are still requests in server's waiting - # queue to be executed. - if not self.is_server_running: - await self.server_step() - break - - @app.post("/generate") async def generate_stream(request: Request): request_dict = await request.json() @@ -123,6 +30,6 @@ async def generate_stream(request: Request): parallel_config = server_configs[2] distributed_init_method, stage_devices = initialize_cluster(parallel_config) - server = FastAPIServer( + server = AsyncLLMServer( args.use_ray, *server_configs, distributed_init_method, stage_devices) uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py new file mode 100644 index 000000000000..655493568481 --- /dev/null +++ b/cacheflow/server/async_llm_server.py @@ -0,0 +1,99 @@ +import asyncio +import json +import time +from typing import Any, Dict +import uuid + +import ray + +from cacheflow.outputs import RequestOutput +from cacheflow.sampling_params import SamplingParams +from cacheflow.server.llm_server import LLMServer +from cacheflow.utils import random_uuid + +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds + + +class AsyncLLMServer: + + def __init__(self, server_use_ray: bool, *args, **kwargs) -> None: + if server_use_ray: + remote_server_class = ray.remote(num_cpus=0)(LLMServer) + else: + remote_server_class = ray.remote(num_gpus=1)(LLMServer) + self.server = remote_server_class.remote(*args, **kwargs) + + # Request id -> request output. + self.request_outputs: Dict[str, RequestOutput] = {} + # Request id -> event to notify that there is new output. + self.request_events: Dict[str, asyncio.Event] = {} + self.is_server_running = False + + async def server_step(self): + self.is_server_running = True + request_outputs = await self.server.step.remote() + self.is_server_running = False + # Notify the waiting coroutines that there are new outputs ready. + for request_output in request_outputs: + request_id = request_output.request_id + self.request_outputs[request_id] = request_output + self.request_events[request_id].set() + + async def generate(self, request_dict: Dict[str, Any]): + # Preprocess the request. + arrival_time = time.time() + prompt = request_dict.pop("prompt") + sampling_params = SamplingParams(**request_dict) + + # Create an event to notify us that there is new output from the + # cacheflow server. + request_id = random_uuid() + request_event = asyncio.Event() + self.request_events[request_id] = request_event + + # Add the request into the cacheflow server's waiting queue. 
+ await self.server.add_request.remote( + request_id, prompt, sampling_params, arrival_time=arrival_time) + + # The cacheflow server does not have a background loop that keeps + # processing incoming requests. Therefore, we need to keep kicking + # the server to process the requests. + while True: + # Kick the server if the server is not running. + if not self.is_server_running: + await self.server_step() + + # Wait for new output. The group_event will be set in server_step + # when there is new output available for the sequence group. + # Added a timeout to prevent deadlock. + try: + await asyncio.wait_for(request_event.wait(), + timeout=TIMEOUT_TO_PREVENT_DEADLOCK) + except asyncio.TimeoutError: + continue + # Reset the event to wait for the next output. + request_event.clear() + + # Decode and return new outputs. + request_output = self.request_outputs[request_id] + prompt = request_output.prompt + text_outputs = [ + prompt + output.text + for output in request_output.outputs + ] + ret = { + "text": text_outputs, + "error": 0, + } + yield (json.dumps(ret) + "\0").encode("utf-8") + + # Once finished, release the resources of the sequence group. + if request_output.done: + del self.request_outputs[request_id] + del self.request_events[request_id] + # Kick the server if the server is not running. This is to + # prevent that there are still requests in server's waiting + # queue to be executed. + if not self.is_server_running: + await self.server_step() + break diff --git a/cacheflow/utils.py b/cacheflow/utils.py index e586707cd876..a28bacfe4cf2 100644 --- a/cacheflow/utils.py +++ b/cacheflow/utils.py @@ -1,6 +1,7 @@ import enum - import psutil +import uuid + import torch @@ -29,3 +30,7 @@ def get_gpu_memory(gpu: int = 0) -> int: def get_cpu_memory() -> int: return psutil.virtual_memory().total + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) diff --git a/examples/simple_server.py b/examples/simple_server.py index 627148557fc8..2ab04a837efb 100644 --- a/examples/simple_server.py +++ b/examples/simple_server.py @@ -3,7 +3,7 @@ from cacheflow import (add_server_arguments, initialize_server_from_args, SamplingParams) - +from cacheflow.utils import random_uuid def main(args: argparse.Namespace): # Initialize the server. @@ -25,7 +25,7 @@ def main(args: argparse.Namespace): # To test iteration-level scheduling, we add one request at each step. if test_prompts: prompt, sampling_params = test_prompts.pop(0) - request_id = str(uuid.uuid4().hex[:8]) + request_id = random_uuid() server.add_request(request_id, prompt, sampling_params) request_outputs = server.step() diff --git a/playground/http_client.py b/playground/http_client.py deleted file mode 100644 index ac13ac62c4b6..000000000000 --- a/playground/http_client.py +++ /dev/null @@ -1,20 +0,0 @@ -import requests -import json - -def http_bot(): - prompt = "How are you? I'm fine." 
- - headers = {"User-Agent": "Test Client"} - pload = { - "prompt": prompt, - } - response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True) - - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode("utf-8")) - output = data["text"] - yield output - -for h in http_bot(): - print(h, end="", flush=True) \ No newline at end of file diff --git a/playground/streaming_fastapi_worker.py b/playground/streaming_fastapi_worker.py deleted file mode 100644 index 8ab087d109e6..000000000000 --- a/playground/streaming_fastapi_worker.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -import asyncio -import time -from typing import Union -import json - -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse -import uvicorn - - -app = FastAPI() - - -async def text_streamer(args): - context = args["prompt"] - words = context.split(" ") - for word in words: - await asyncio.sleep(1) - print("word:", word) - ret = { - "text": word + " ", - "error": 0, - } - yield (json.dumps(ret) + "\0").encode("utf-8") - - -@app.post("/") -async def read_root(request: Request): - args = await request.json() - return StreamingResponse(text_streamer(args)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=10002) - args = parser.parse_args() - - uvicorn.run(app, host=args.host, port=args.port, log_level="info") From 18f809761020626e7e8cd49d118f4a474506cf46 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 20 May 2023 13:39:25 -0700 Subject: [PATCH 02/23] rename fastapi frontend --- .../{fastapi_server.py => simple_fastapi_frontend.py} | 1 - cacheflow/server/async_llm_server.py | 1 - 2 files changed, 2 deletions(-) rename cacheflow/entrypoints/{fastapi_server.py => simple_fastapi_frontend.py} (98%) diff --git a/cacheflow/entrypoints/fastapi_server.py b/cacheflow/entrypoints/simple_fastapi_frontend.py similarity index 98% rename from cacheflow/entrypoints/fastapi_server.py rename to cacheflow/entrypoints/simple_fastapi_frontend.py index 64ac196646ff..44f4ee5f2f64 100644 --- a/cacheflow/entrypoints/fastapi_server.py +++ b/cacheflow/entrypoints/simple_fastapi_frontend.py @@ -1,5 +1,4 @@ import argparse -import asyncio from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse import uvicorn diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py index 655493568481..71a1905dffb6 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/server/async_llm_server.py @@ -2,7 +2,6 @@ import json import time from typing import Any, Dict -import uuid import ray From fb73a1baf297564331956c4acb1489cc010a074d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 20 May 2023 16:43:33 -0700 Subject: [PATCH 03/23] small fix --- examples/simple_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/simple_server.py b/examples/simple_server.py index 2ab04a837efb..fb1fec905ba5 100644 --- a/examples/simple_server.py +++ b/examples/simple_server.py @@ -1,5 +1,4 @@ import argparse -import uuid from cacheflow import (add_server_arguments, initialize_server_from_args, SamplingParams) From b990239ce1bc81c60c692714fb9fe682ff894225 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 20 May 2023 16:44:00 -0700 Subject: [PATCH 04/23] [WIP] add WIP openai frontend --- 
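Notes: a rough client sketch for the /v1/completions endpoint this WIP patch starts
to add. It is written against the streaming interface as it settles later in this
series, and it assumes the frontend runs with the default --host localhost --port 8000
and serves facebook/opt-125m; both values are assumptions, so adjust them to the actual
launch flags.

import json
import requests

# Placeholder model name; must match the model the frontend was started with.
MODEL = "facebook/opt-125m"

payload = {
    "model": MODEL,
    "prompt": "A robot may not injure a human being",
    "max_tokens": 32,
    "temperature": 0.8,
    "stream": True,
}

# The streaming path emits OpenAI-style SSE chunks ("data: {...}") followed by a
# final "data: [DONE]" sentinel.
with requests.post("http://localhost:8000/v1/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
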
.../entrypoints/openai/openai_frontend.py | 244 ++++++++++++++++++ cacheflow/entrypoints/openai/protocol.py | 196 ++++++++++++++ 2 files changed, 440 insertions(+) create mode 100644 cacheflow/entrypoints/openai/openai_frontend.py create mode 100644 cacheflow/entrypoints/openai/protocol.py diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py new file mode 100644 index 000000000000..3da1cfc06062 --- /dev/null +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -0,0 +1,244 @@ +import asyncio +import argparse +import asyncio +import json +from typing import Generator, Optional, Union, Dict, List, Any +from http import HTTPStatus + +import fastapi +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +from cacheflow.server.arg_utils import ( + add_server_arguments, create_server_configs_from_args) +from cacheflow.server.async_llm_server import AsyncLLMServer +from cacheflow.server.ray_utils import initialize_cluster +from cacheflow.logger import init_logger +from cacheflow.sampling_params import SamplingParams +from cacheflow.utils import random_uuid + + +from cacheflow.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorCode, + ErrorResponse, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) + + +logger = init_logger(__name__) +served_model = None +app = fastapi.FastAPI() + + +def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message).dict(), + status_code=status_code.value + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + if request.model == served_model: + return + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"The model `{request.model}` does not exist.", + ) + return ret + +@app.get("/v1/models") +async def show_available_models(): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + ret = await client.post(controller_address + "/refresh_all_workers") + ret = await client.post(controller_address + "/list_models") + models = ret.json()["models"] + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest): + try: + sampling_params = SamplingParams( + n=request.n, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + max_tokens=request.max_tokens, + logprobs=request.logprobs if request.logprobs else 0, + use_beam_search=request.use_beam_search, + # TODO: support stop, best_of, and logit_bias + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + if request.stream: + generator = generate_completion_stream_generator(payload, request.n) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for 
i in range(request.n): + content = asyncio.create_task(generate_completion(payload)) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=content.get("logprobs", None), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) + ) + + +async def generate_completion_stream_generator(payload: Dict[str, Any], n: int): + model_name = payload["model"] + id = f"cmpl-{random_uuid()}" + finish_stream_events = [] + for i in range(n): + previous_text = "" + async for content in generate_completion_stream(payload): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = decoded_unicode + + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=content.get("logprobs", None), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, object="text_completion", choices=[choice_data], model=model_name + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(payload["model"], client) + + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_completion_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + async for raw_chunk in response.aiter_raw(): + for chunk in raw_chunk.split(delimiter): + if not chunk: + continue + data = json.loads(chunk.decode()) + yield data + + +async def generate_completion(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(payload["model"], client) + + response = await client.post( + worker_addr + "/worker_generate_completion", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) + completion = response.json() + return completion + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="CacheFlow OpenAI-Compatible RESTful API server." 
+ ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser = add_server_arguments(parser) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + logger.info(f"args: {args}") + + served_model = args.model + + server_configs = create_server_configs_from_args(args) + parallel_config = server_configs[2] + distributed_init_method, stage_devices = initialize_cluster(parallel_config) + + server = AsyncLLMServer( + args.use_ray, *server_configs, distributed_init_method, stage_devices) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py new file mode 100644 index 000000000000..ebe171b1b518 --- /dev/null +++ b/cacheflow/entrypoints/openai/protocol.py @@ -0,0 +1,196 @@ +from typing import Literal, Optional, List, Dict, Any, Union +from enum import IntEnum + +import time + +from pydantic import BaseModel, Field +from cacheflow.utils import random_uuid + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: Optional[str] = None + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[Dict[str, str]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion" + created: int 
= Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class EmbeddingsRequest(BaseModel): + model: str + input: str + user: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: str + usage: UsageInfo + + +class CompletionRequest(BaseModel): + model: str + prompt: str + suffix: Optional[str] = None + max_tokens: Optional[int] = 16 + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stream: Optional[bool] = False + logprobs: Optional[int] = None + echo: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + best_of: Optional[int] = 1 + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + # Additional parameters supported by cacheflow + use_beam_search: Optional[bool] = False + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[int] = None + finish_reason: Optional[Literal["stop", "length"]] + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[float] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 From 00c84e2d6e874a9371af49d79c11c2def26eacec Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 20 May 2023 23:51:31 +0000 Subject: [PATCH 05/23] fix async_llm_server --- .../entrypoints/simple_fastapi_frontend.py | 17 ++++++++++++++++- cacheflow/server/async_llm_server.py | 12 +----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py index 44f4ee5f2f64..7c0618a70cdf 100644 --- a/cacheflow/entrypoints/simple_fastapi_frontend.py 
+++ b/cacheflow/entrypoints/simple_fastapi_frontend.py @@ -1,4 +1,5 @@ import argparse +import json from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse import uvicorn @@ -15,7 +16,21 @@ @app.post("/generate") async def generate_stream(request: Request): request_dict = await request.json() - return StreamingResponse(server.generate(request_dict)) + + async def stream_results(): + async for request_output in server.generate(request_dict): + prompt = request_output.prompt + text_outputs = [ + prompt + output.text + for output in request_output.outputs + ] + ret = { + "text": text_outputs, + "error": 0, + } + yield (json.dumps(ret) + "\0").encode("utf-8") + + return StreamingResponse(stream_results()) if __name__ == "__main__": diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py index 71a1905dffb6..c1d905f44d0c 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/server/async_llm_server.py @@ -1,5 +1,4 @@ import asyncio -import json import time from typing import Any, Dict @@ -75,16 +74,7 @@ async def generate(self, request_dict: Dict[str, Any]): # Decode and return new outputs. request_output = self.request_outputs[request_id] - prompt = request_output.prompt - text_outputs = [ - prompt + output.text - for output in request_output.outputs - ] - ret = { - "text": text_outputs, - "error": 0, - } - yield (json.dumps(ret) + "\0").encode("utf-8") + yield request_output # Once finished, release the resources of the sequence group. if request_output.done: From 1c71b88fc3b2f0ee0c7f664582236630e4404a8c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 21 May 2023 17:47:23 +0000 Subject: [PATCH 06/23] Basic support for OpenAI Completion API --- .../entrypoints/openai/openai_frontend.py | 167 ++++++------------ cacheflow/entrypoints/openai/protocol.py | 1 + .../entrypoints/simple_fastapi_frontend.py | 6 +- cacheflow/server/async_llm_server.py | 10 +- examples/openai_client.py | 22 +++ test_cli_client.py | 2 +- 6 files changed, 86 insertions(+), 122 deletions(-) create mode 100644 examples/openai_client.py diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 3da1cfc06062..6ab3287f7f1e 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -2,6 +2,7 @@ import argparse import asyncio import json +import time from typing import Generator, Optional, Union, Dict, List, Any from http import HTTPStatus @@ -56,27 +57,26 @@ async def check_model(request) -> Optional[JSONResponse]: if request.model == served_model: return ret = create_error_response( - ErrorCode.INVALID_MODEL, + HTTPStatus.NOT_FOUND, f"The model `{request.model}` does not exist.", ) return ret @app.get("/v1/models") async def show_available_models(): - controller_address = app_settings.controller_address - async with httpx.AsyncClient() as client: - ret = await client.post(controller_address + "/refresh_all_workers") - ret = await client.post(controller_address + "/list_models") - models = ret.json()["models"] - models.sort() - # TODO: return real model permission details - model_cards = [] - for m in models: - model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + model_cards = [ModelCard(id=served_model, root=served_model, permission=[ModelPermission()])] return ModelList(data=model_cards) @app.post("/v1/completions") async def create_completion(request: CompletionRequest): + error_check_ret = await 
check_model(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + prompt = request.prompt + created_time = int(time.time()) try: sampling_params = SamplingParams( n=request.n, @@ -88,117 +88,54 @@ async def create_completion(request: CompletionRequest): max_tokens=request.max_tokens, logprobs=request.logprobs if request.logprobs else 0, use_beam_search=request.use_beam_search, - # TODO: support stop, best_of, and logit_bias + # TODO(zhuohan): support stop, best_of, and logit_bias ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + result_generator = server.generate(prompt, sampling_params, + request_id=request_id) + + async def generate_completion_stream_generator(): + previous_texts = [""] * request.n + async for res in result_generator: + print("res:", res) + for i, output in enumerate(res.outputs): + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=None, + finish_reason=None, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model = model_name, + choices=[choice_data], + ) + yield f"data: {response.json(exclude_unset=True, ensure_ascii=False)}\n\n" + if res.done: + choice_data = CompletionResponseStreamChoice( + index=i, + text="", + logprobs=None, + finish_reason="stop", + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + yield f"data: {response.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" if request.stream: - generator = generate_completion_stream_generator(payload, request.n) + generator = generate_completion_stream_generator() return StreamingResponse(generator, media_type="text/event-stream") else: - text_completions = [] - for i in range(request.n): - content = asyncio.create_task(generate_completion(payload)) - text_completions.append(content) - - try: - all_tasks = await asyncio.gather(*text_completions) - except Exception as e: - return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) - - choices = [] - usage = UsageInfo() - for i, content in enumerate(all_tasks): - if content["error_code"] != 0: - return create_error_response(content["error_code"], content["text"]) - choices.append( - CompletionResponseChoice( - index=i, - text=content["text"], - logprobs=content.get("logprobs", None), - finish_reason=content.get("finish_reason", "stop"), - ) - ) - task_usage = UsageInfo.parse_obj(content["usage"]) - for usage_key, usage_value in task_usage.dict().items(): - setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) - - return CompletionResponse( - model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) - ) - - -async def generate_completion_stream_generator(payload: Dict[str, Any], n: int): - model_name = payload["model"] - id = f"cmpl-{random_uuid()}" - finish_stream_events = [] - for i in range(n): - previous_text = "" - async for content in generate_completion_stream(payload): - if content["error_code"] != 0: - yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - return - decoded_unicode = content["text"].replace("\ufffd", "") - delta_text = decoded_unicode[len(previous_text) :] - previous_text = decoded_unicode - - choice_data = CompletionResponseStreamChoice( - index=i, - text=delta_text, - 
logprobs=content.get("logprobs", None), - finish_reason=content.get("finish_reason", None), - ) - chunk = CompletionStreamResponse( - id=id, object="text_completion", choices=[choice_data], model=model_name - ) - if len(delta_text) == 0: - if content.get("finish_reason", None) is not None: - finish_stream_events.append(chunk) - continue - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" - # There is not "content" field in the last delta message, so exclude_none to exclude field "content". - for finish_chunk in finish_stream_events: - yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - - -async def generate_completion_stream(payload: Dict[str, Any]): - controller_address = app_settings.controller_address - async with httpx.AsyncClient() as client: - worker_addr = await _get_worker_address(payload["model"], client) - - delimiter = b"\0" - async with client.stream( - "POST", - worker_addr + "/worker_generate_completion_stream", - headers=headers, - json=payload, - timeout=WORKER_API_TIMEOUT, - ) as response: - # content = await response.aread() - async for raw_chunk in response.aiter_raw(): - for chunk in raw_chunk.split(delimiter): - if not chunk: - continue - data = json.loads(chunk.decode()) - yield data - - -async def generate_completion(payload: Dict[str, Any]): - controller_address = app_settings.controller_address - async with httpx.AsyncClient() as client: - worker_addr = await _get_worker_address(payload["model"], client) - - response = await client.post( - worker_addr + "/worker_generate_completion", - headers=headers, - json=payload, - timeout=WORKER_API_TIMEOUT, - ) - completion = response.json() - return completion + raise NotImplementedError("Not implemented yet.") if __name__ == "__main__": diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index ebe171b1b518..0c8e43c76c2c 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -124,6 +124,7 @@ class CompletionRequest(BaseModel): max_tokens: Optional[int] = 16 temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 n: Optional[int] = 1 stream: Optional[bool] = False logprobs: Optional[int] = None diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py index 7c0618a70cdf..bd6d7e05e299 100644 --- a/cacheflow/entrypoints/simple_fastapi_frontend.py +++ b/cacheflow/entrypoints/simple_fastapi_frontend.py @@ -4,6 +4,7 @@ from fastapi.responses import StreamingResponse import uvicorn +from cacheflow.sampling_params import SamplingParams from cacheflow.server.arg_utils import ( add_server_arguments, create_server_configs_from_args) from cacheflow.server.async_llm_server import AsyncLLMServer @@ -16,9 +17,12 @@ @app.post("/generate") async def generate_stream(request: Request): request_dict = await request.json() + prompt = request_dict.pop("prompt") + sampling_params = SamplingParams(**request_dict) + results_generator = server.generate(prompt, sampling_params) async def stream_results(): - async for request_output in server.generate(request_dict): + async for request_output in results_generator: prompt = request_output.prompt text_outputs = [ prompt + output.text diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py index c1d905f44d0c..6bd1fb283f31 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/server/async_llm_server.py 
@@ -1,6 +1,6 @@ import asyncio import time -from typing import Any, Dict +from typing import Optional, Dict import ray @@ -37,15 +37,15 @@ async def server_step(self): self.request_outputs[request_id] = request_output self.request_events[request_id].set() - async def generate(self, request_dict: Dict[str, Any]): + async def generate(self, prompt: str, sampling_params: SamplingParams, + request_id: Optional[str] = None) -> RequestOutput: # Preprocess the request. arrival_time = time.time() - prompt = request_dict.pop("prompt") - sampling_params = SamplingParams(**request_dict) # Create an event to notify us that there is new output from the # cacheflow server. - request_id = random_uuid() + if request_id is None: + request_id = random_uuid() request_event = asyncio.Event() self.request_events[request_id] = request_event diff --git a/examples/openai_client.py b/examples/openai_client.py new file mode 100644 index 000000000000..100883dd1dc0 --- /dev/null +++ b/examples/openai_client.py @@ -0,0 +1,22 @@ +import openai +openai.api_key = "EMPTY" +openai.api_base = "http://localhost:8000/v1" +model = "facebook/opt-125m" + +# create a chat completion + +stream = True + +completion = openai.Completion.create( + model=model, prompt="A robot may not injure a human being", echo=False, n=2, + stream=stream) + +print("completion:", completion) + +# print the chat completion +if stream: + for c in completion: + print(c) +else: + print("completion:", completion) + diff --git a/test_cli_client.py b/test_cli_client.py index 217f8088645a..b93f950df356 100644 --- a/test_cli_client.py +++ b/test_cli_client.py @@ -11,7 +11,7 @@ def http_request(): "use_beam_search": True, "temperature": 0.0, } - response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True) + response = requests.post("http://localhost:10003/generate", headers=headers, json=pload, stream=True) for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): if chunk: From 63c5d3cc25fcb15d3e6c7fa97b287af44cc98cf7 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 21 May 2023 23:36:53 +0000 Subject: [PATCH 07/23] Implement finsh_reason --- cacheflow/core/block_manager.py | 6 ++-- cacheflow/core/scheduler.py | 3 +- .../entrypoints/openai/openai_frontend.py | 9 +++--- cacheflow/model_executor/layers/sampler.py | 6 ++-- cacheflow/outputs.py | 30 ++++++++++++------- cacheflow/sampling_params.py | 4 +-- cacheflow/sequence.py | 13 ++++++-- cacheflow/server/async_llm_server.py | 2 +- cacheflow/server/llm_server.py | 7 +++-- examples/openai_client.py | 2 +- examples/simple_server.py | 2 +- test_cli_client.py | 2 +- 12 files changed, 54 insertions(+), 32 deletions(-) diff --git a/cacheflow/core/block_manager.py b/cacheflow/core/block_manager.py index 8f1295bcc672..07129b65b226 100644 --- a/cacheflow/core/block_manager.py +++ b/cacheflow/core/block_manager.py @@ -148,7 +148,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBl # the sequences in the same group. blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): - if seq.status == SequenceStatus.FINISHED: + if SequenceStatus.is_finished(seq.status): continue block_table = self.block_tables[seq.seq_id] for block in block_table: @@ -169,7 +169,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. 
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(): - if seq.status == SequenceStatus.FINISHED: + if SequenceStatus.is_finished(seq.status): continue new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] @@ -200,7 +200,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(): - if seq.status == SequenceStatus.FINISHED: + if SequenceStatus.is_finished(seq.status): continue new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py index 08f855730dd6..496aaa842c4d 100644 --- a/cacheflow/core/scheduler.py +++ b/cacheflow/core/scheduler.py @@ -292,10 +292,11 @@ def update( # Append a new token to the sequence. output = seq_outputs[seq.seq_id] seq.append_token(output.output_token, output.logprobs) + # Return a shallow copy of the running queue to prevent the queue + # from being modified by the caller. return self.running.copy() def free_seq(self, seq: Sequence) -> None: - seq.status = SequenceStatus.FINISHED self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 6ab3287f7f1e..81c159a6a932 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -64,7 +64,8 @@ async def check_model(request) -> Optional[JSONResponse]: @app.get("/v1/models") async def show_available_models(): - model_cards = [ModelCard(id=served_model, root=served_model, permission=[ModelPermission()])] + model_cards = [ModelCard(id=served_model, root=served_model, + permission=[ModelPermission()])] return ModelList(data=model_cards) @app.post("/v1/completions") @@ -116,12 +117,12 @@ async def generate_completion_stream_generator(): choices=[choice_data], ) yield f"data: {response.json(exclude_unset=True, ensure_ascii=False)}\n\n" - if res.done: + if output.finish_reason is not None: choice_data = CompletionResponseStreamChoice( index=i, text="", logprobs=None, - finish_reason="stop", + finish_reason=output.finish_reason, ) response = CompletionStreamResponse( id=request_id, @@ -129,7 +130,7 @@ async def generate_completion_stream_generator(): model=model_name, choices=[choice_data], ) - yield f"data: {response.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield f"data: {response.json(exclude_unset=True, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" if request.stream: generator = generate_completion_stream_generator() diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py index 1c3187c05410..baf62ba71075 100644 --- a/cacheflow/model_executor/layers/sampler.py +++ b/cacheflow/model_executor/layers/sampler.py @@ -1,5 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np import torch @@ -258,9 +258,9 @@ def _apply_top_p_top_k( def _get_topk_logprobs( logprobs: torch.Tensor, - num_logprobs: int, + num_logprobs: Optional[int], ) -> Dict[int, float]: - if num_logprobs == 0: + if num_logprobs is None or num_logprobs == 0: return {} topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs) diff --git a/cacheflow/outputs.py b/cacheflow/outputs.py index 0b4dcabe0b35..becb838f2599 100644 
--- a/cacheflow/outputs.py +++ b/cacheflow/outputs.py @@ -1,6 +1,6 @@ -from typing import Dict, List +from typing import Dict, List, Optional -from cacheflow.sequence import SequenceGroup +from cacheflow.sequence import SequenceGroup, SequenceStatus class CompletionOutput: @@ -12,19 +12,22 @@ def __init__( token_ids: List[int], cumulative_logprob: float, logprobs: List[Dict[int, float]], + finish_reason: Optional[str] = None, ) -> None: self.index = index self.text = text self.token_ids = token_ids self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs + self.finish_reason = finish_reason def __repr__(self) -> str: return (f"CompletionOutput(index={self.index}, " f"text={self.text!r}, " f"token_ids={self.token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs})") + f"logprobs={self.logprobs}," + f"finish_reason={self.finish_reason})") class RequestOutput: @@ -35,13 +38,11 @@ def __init__( prompt: str, prompt_token_ids: List[int], outputs: List[CompletionOutput], - done: bool = False, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.outputs = outputs - self.done = done @staticmethod def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput": @@ -57,25 +58,34 @@ def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput": outputs: List[CompletionOutput] = [] for seq in top_n_seqs: logprobs = seq.output_logprobs - if seq_group.sampling_params.logprobs == 0: + if seq_group.sampling_params.logprobs is None: # NOTE: We need to take care of this case because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. logprobs = {} + if seq.status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + else: + finish_reason = None output = CompletionOutput(seqs.index(seq), seq.output_text, seq.get_output_token_ids(), - seq.get_cumulative_logprob(), logprobs) + seq.get_cumulative_logprob(), logprobs, + finish_reason) outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
prompt = top_n_seqs[0].prompt prompt_token_ids = top_n_seqs[0].data.prompt_token_ids return RequestOutput(seq_group.request_id, prompt, prompt_token_ids, - outputs, seq_group.is_finished()) + outputs) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"outputs={self.outputs}, " - f"done={self.done})") + f"outputs={self.outputs})") + + def finished(self) -> bool: + return all(output.finish_reason is not None for output in self.outputs) \ No newline at end of file diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 0ce772a9e9ff..031eb820583f 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -53,7 +53,7 @@ def __init__( stop: Union[str, List[str]] = [], ignore_eos: bool = False, max_tokens: int = 16, - logprobs: int = 0, + logprobs: Optional[int] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -98,7 +98,7 @@ def _verify_args(self) -> None: if self.max_tokens < 1: raise ValueError( f"max_tokens must be at least 1, got {self.max_tokens}.") - if self.logprobs < 0: + if self.logprobs is not None and self.logprobs < 0: raise ValueError( f"logprobs must be non-negative, got {self.logprobs}.") diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index b2c19fae9c96..8a467c56b2c2 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -10,7 +10,15 @@ class SequenceStatus(enum.Enum): WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() - FINISHED = enum.auto() + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status in [ + SequenceStatus.FINISHED_STOPPED, + SequenceStatus.FINISHED_LENGTH_CAPPED, + ] class SequenceData: @@ -20,7 +28,6 @@ def __init__( prompt_token_ids: List[int], ) -> None: self.prompt_token_ids = prompt_token_ids - self.output_token_ids: List[int] = [] self.cumulative_logprob = 0.0 @@ -160,7 +167,7 @@ def find(self, seq_id: int) -> Sequence: raise ValueError(f'Sequence {seq_id} not found.') def is_finished(self) -> bool: - return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs) + return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs) def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py index 6bd1fb283f31..dca22483da3f 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/server/async_llm_server.py @@ -77,7 +77,7 @@ async def generate(self, prompt: str, sampling_params: SamplingParams, yield request_output # Once finished, release the resources of the sequence group. - if request_output.done: + if request_output.finished(): del self.request_outputs[request_id] del self.request_events[request_id] # Kick the server if the server is not running. This is to diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py index 4cc4a228accf..3313984cbb47 100644 --- a/cacheflow/server/llm_server.py +++ b/cacheflow/server/llm_server.py @@ -195,6 +195,7 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: # Truncate the output text so that the stop string is # not included in the output. 
seq.output_text = seq.output_text[:-len(stop_str)] + seq.status = SequenceStatus.FINISHED_STOPPED self.scheduler.free_seq(seq) stopped = True break @@ -203,11 +204,13 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: # Check if the sequence has reached max_tokens. if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED self.scheduler.free_seq(seq) continue # Check if the sequence has generated the EOS token. if not sampling_params.ignore_eos: if seq.get_last_token_id() == self.tokenizer.eos_token_id: + seq.status = SequenceStatus.FINISHED_STOPPED self.scheduler.free_seq(seq) continue @@ -223,10 +226,10 @@ def _run_workers( executor = getattr(worker, method) if self.parallel_config.use_ray: executor = executor.remote - + output = executor(*args, **kwargs) all_outputs.append(output) - + if self.parallel_config.use_ray: all_outputs = ray.get(all_outputs) diff --git a/examples/openai_client.py b/examples/openai_client.py index 100883dd1dc0..77ff1bf568c7 100644 --- a/examples/openai_client.py +++ b/examples/openai_client.py @@ -8,7 +8,7 @@ stream = True completion = openai.Completion.create( - model=model, prompt="A robot may not injure a human being", echo=False, n=2, + model=model, prompt="A robot may not injure a human being", echo=False, n=2, logprobs=3, stream=stream) print("completion:", completion) diff --git a/examples/simple_server.py b/examples/simple_server.py index a2f498b8eeb2..f70cd8903b2d 100644 --- a/examples/simple_server.py +++ b/examples/simple_server.py @@ -29,7 +29,7 @@ def main(args: argparse.Namespace): request_outputs = server.step() for request_output in request_outputs: - if request_output.done: + if request_output.finished(): print(request_output) if not (server.has_unfinished_requests() or test_prompts): diff --git a/test_cli_client.py b/test_cli_client.py index b93f950df356..217f8088645a 100644 --- a/test_cli_client.py +++ b/test_cli_client.py @@ -11,7 +11,7 @@ def http_request(): "use_beam_search": True, "temperature": 0.0, } - response = requests.post("http://localhost:10003/generate", headers=headers, json=pload, stream=True) + response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True) for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): if chunk: From 8321b47339b7d887d4361485452665f884128203 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 21 May 2023 23:57:18 +0000 Subject: [PATCH 08/23] support bestof and stop --- cacheflow/entrypoints/openai/openai_frontend.py | 8 +++++--- cacheflow/entrypoints/openai/protocol.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 81c159a6a932..610e45e5303a 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -81,15 +81,18 @@ async def create_completion(request: CompletionRequest): try: sampling_params = SamplingParams( n=request.n, + best_of=request.best_of, presence_penalty=request.presence_penalty, frequency_penalty=request.frequency_penalty, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, + stop=request.stop, + ignore_eos=request.ignore_eos, max_tokens=request.max_tokens, - logprobs=request.logprobs if request.logprobs else 0, + logprobs=request.logprobs, use_beam_search=request.use_beam_search, - # TODO(zhuohan): support stop, best_of, and logit_bias + # 
TODO(zhuohan): support logit_bias ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) @@ -100,7 +103,6 @@ async def create_completion(request: CompletionRequest): async def generate_completion_stream_generator(): previous_texts = [""] * request.n async for res in result_generator: - print("res:", res) for i, output in enumerate(res.outputs): delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index 0c8e43c76c2c..51a3c5dde37d 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -136,6 +136,7 @@ class CompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None # Additional parameters supported by cacheflow + ignore_eos: Optional[bool] = False use_beam_search: Optional[bool] = False From 5c82790227cbfb2c9ad25998f1f0858fa75635dc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 22 May 2023 00:40:00 +0000 Subject: [PATCH 09/23] Support non-streaming requests --- .../entrypoints/openai/openai_frontend.py | 91 +++++++++++++------ cacheflow/entrypoints/openai/protocol.py | 44 +-------- examples/openai_client.py | 7 +- 3 files changed, 69 insertions(+), 73 deletions(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 610e45e5303a..43ab446e71da 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -27,7 +27,6 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - ErrorCode, ErrorResponse, ModelCard, ModelList, @@ -41,9 +40,10 @@ app = fastapi.FastAPI() -def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: +def create_error_response(status_code: HTTPStatus, + message: str) -> JSONResponse: return JSONResponse( - ErrorResponse(message=message).dict(), + ErrorResponse(message=message, type="invalid_request_error").dict(), status_code=status_code.value ) @@ -100,45 +100,84 @@ async def create_completion(request: CompletionRequest): result_generator = server.generate(prompt, sampling_params, request_id=request_id) + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. 
+ stream = (request.stream and request.n == request.best_of + and not request.use_beam_search) + + def create_stream_response_json(index: int, text: str, + log_probs: Optional[Dict] = None, + finish_reason: Optional[str] = None): + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + logprobs=log_probs, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model = model_name, + choices=[choice_data], + ) + response_json = response.json(exclude_unset=True, ensure_ascii=False) + return response_json + async def generate_completion_stream_generator(): previous_texts = [""] * request.n async for res in result_generator: - for i, output in enumerate(res.outputs): + for output in res.outputs: + i = output.index delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text - choice_data = CompletionResponseStreamChoice( + response_json = create_stream_response_json( index=i, text=delta_text, - logprobs=None, - finish_reason=None, ) - response = CompletionStreamResponse( - id=request_id, - created=created_time, - model = model_name, - choices=[choice_data], - ) - yield f"data: {response.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield f"data: {response_json}\n\n" if output.finish_reason is not None: - choice_data = CompletionResponseStreamChoice( + response_json = create_stream_response_json( index=i, text="", - logprobs=None, finish_reason=output.finish_reason, ) - response = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - ) - yield f"data: {response.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield f"data: {response_json}\n\n" yield "data: [DONE]\n\n" - if request.stream: + + if stream: generator = generate_completion_stream_generator() return StreamingResponse(generator, media_type="text/event-stream") - else: - raise NotImplementedError("Not implemented yet.") + + # Non-streaming response + final_res = None + async for res in result_generator: + final_res = res + assert final_res is not None + choices = [] + for output in final_res.outputs: + choice_data = CompletionResponseChoice( + index=output.index, + text=output.text, + logprobs=None, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) + for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + return response if __name__ == "__main__": diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index 51a3c5dde37d..d76407cdcf36 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -104,19 +104,6 @@ class ChatCompletionStreamResponse(BaseModel): choices: List[ChatCompletionResponseStreamChoice] -class EmbeddingsRequest(BaseModel): - model: str - input: str - user: Optional[str] = None - - -class EmbeddingsResponse(BaseModel): - object: str = "list" - data: List[Dict[str, Any]] - model: str - usage: UsageInfo - - class CompletionRequest(BaseModel): model: str prompt: str @@ -129,10 +116,10 @@ class CompletionRequest(BaseModel): stream: Optional[bool] = False logprobs: Optional[int] 
= None echo: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = None + stop: Optional[Union[str, List[str]]] = [] presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 - best_of: Optional[int] = 1 + best_of: Optional[int] = None logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None # Additional parameters supported by cacheflow @@ -169,30 +156,3 @@ class CompletionStreamResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[CompletionResponseStreamChoice] - - -class ErrorCode(IntEnum): - """ - https://platform.openai.com/docs/guides/error-codes/api-errors - """ - - VALIDATION_TYPE_ERROR = 40001 - - INVALID_AUTH_KEY = 40101 - INCORRECT_AUTH_KEY = 40102 - NO_PERMISSION = 40103 - - INVALID_MODEL = 40301 - PARAM_OUT_OF_RANGE = 40302 - CONTEXT_OVERFLOW = 40303 - - RATE_LIMIT = 42901 - QUOTA_EXCEEDED = 42902 - ENGINE_OVERLOADED = 42903 - - INTERNAL_ERROR = 50001 - CUDA_OUT_OF_MEMORY = 50002 - GRADIO_REQUEST_ERROR = 50003 - GRADIO_STREAM_UNKNOWN_ERROR = 50004 - CONTROLLER_NO_WORKER = 50005 - CONTROLLER_WORKER_TIMEOUT = 50006 diff --git a/examples/openai_client.py b/examples/openai_client.py index 77ff1bf568c7..e13a1df3c12f 100644 --- a/examples/openai_client.py +++ b/examples/openai_client.py @@ -5,18 +5,15 @@ # create a chat completion -stream = True +stream = False completion = openai.Completion.create( - model=model, prompt="A robot may not injure a human being", echo=False, n=2, logprobs=3, + model=model, prompt="A robot may not injure a human being", echo=False, n=2, stream=stream) -print("completion:", completion) - # print the chat completion if stream: for c in completion: print(c) else: print("completion:", completion) - From 0e12ecb7cf963fc0a1ab44e068724d12e883f7d7 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 00:17:01 +0000 Subject: [PATCH 10/23] Support logprobs --- .../entrypoints/openai/openai_frontend.py | 96 +++++++++++++++---- cacheflow/entrypoints/openai/protocol.py | 27 ++++-- examples/openai_client.py | 10 +- 3 files changed, 104 insertions(+), 29 deletions(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 43ab446e71da..8f61fa009f91 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -1,9 +1,9 @@ -import asyncio +# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py + import argparse -import asyncio import json import time -from typing import Generator, Optional, Union, Dict, List, Any +from typing import Optional, Dict, List, AsyncGenerator from http import HTTPStatus import fastapi @@ -12,16 +12,17 @@ from fastapi.responses import StreamingResponse, JSONResponse import uvicorn +from cacheflow.outputs import RequestOutput from cacheflow.server.arg_utils import ( add_server_arguments, create_server_configs_from_args) from cacheflow.server.async_llm_server import AsyncLLMServer from cacheflow.server.ray_utils import initialize_cluster +from cacheflow.server.tokenizer_utils import get_tokenizer from cacheflow.logger import init_logger from cacheflow.sampling_params import SamplingParams from cacheflow.utils import random_uuid - - from cacheflow.entrypoints.openai.protocol import ( + LogProbs, CompletionRequest, CompletionResponse, CompletionResponseChoice, @@ -62,18 +63,58 @@ async def check_model(request) -> Optional[JSONResponse]: ) return ret + @app.get("/v1/models") 
async def show_available_models(): model_cards = [ModelCard(id=served_model, root=served_model, permission=[ModelPermission()])] return ModelList(data=model_cards) + +def create_logprobs(token_ids: List[int], + id_logprobs: List[Dict[int, float]], + initial_text_offset: int = 0) -> LogProbs: + logprobs = LogProbs() + last_token_len = 0 + for token_id, id_logprob in zip(token_ids, id_logprobs): + token = tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(id_logprob[token_id]) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) + last_token_len = len(token) + + logprobs.top_logprobs.append( + {tokenizer.convert_ids_to_tokens(i): p + for i, p in id_logprob.items()}) + return logprobs + + @app.post("/v1/completions") async def create_completion(request: CompletionRequest): + logger.info(f"Received completion request: {request}") + error_check_ret = await check_model(request) if error_check_ret is not None: return error_check_ret + if request.echo: + # We do not support echo since the cacheflow server does not + # currently support getting the logprobs of prompt tokens. + return create_error_response(HTTPStatus.BAD_REQUEST, + "echo is not currently supported") + + if request.suffix is not None: + return create_error_response(HTTPStatus.BAD_REQUEST, + "suffix is not currently supported") + + if request.logit_bias is not None: + # TODO: support logit_bias in cacheflow server. + return create_error_response(HTTPStatus.BAD_REQUEST, + "logit_bias is not currently supported") + model_name = request.model request_id = f"cmpl-{random_uuid()}" prompt = request.prompt @@ -92,7 +133,6 @@ async def create_completion(request: CompletionRequest): max_tokens=request.max_tokens, logprobs=request.logprobs, use_beam_search=request.use_beam_search, - # TODO(zhuohan): support logit_bias ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) @@ -102,43 +142,59 @@ async def create_completion(request: CompletionRequest): # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use beam search. 
- stream = (request.stream and request.n == request.best_of - and not request.use_beam_search) + stream = (request.stream and + (request.best_of is None or request.n == request.best_of) and + not request.use_beam_search) - def create_stream_response_json(index: int, text: str, - log_probs: Optional[Dict] = None, - finish_reason: Optional[str] = None): + def create_stream_response_json(index: int, + text: str, + logprobs: Optional[LogProbs] = None, + finish_reason: Optional[str] = None) -> str: choice_data = CompletionResponseStreamChoice( index=index, text=text, - logprobs=log_probs, + logprobs=logprobs, finish_reason=finish_reason, ) response = CompletionStreamResponse( id=request_id, created=created_time, - model = model_name, + model=model_name, choices=[choice_data], ) - response_json = response.json(exclude_unset=True, ensure_ascii=False) + response_json = response.json(ensure_ascii=False) + return response_json - async def generate_completion_stream_generator(): + async def generate_completion_stream_generator() -> AsyncGenerator[str, None]: previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n async for res in result_generator: + res: RequestOutput for output in res.outputs: i = output.index delta_text = output.text[len(previous_texts[i]):] + if request.logprobs is not None: + logprobs = create_logprobs( + output.token_ids[previous_num_tokens[i]:], + output.logprobs[previous_num_tokens[i]:], + len(previous_texts[i])) + else: + logprobs = None previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) response_json = create_stream_response_json( index=i, text=delta_text, + logprobs=logprobs, ) yield f"data: {response_json}\n\n" if output.finish_reason is not None: + logprobs = LogProbs() if request.logprobs is not None else None response_json = create_stream_response_json( index=i, text="", + logprobs=logprobs, finish_reason=output.finish_reason, ) yield f"data: {response_json}\n\n" @@ -155,13 +211,18 @@ async def generate_completion_stream_generator(): assert final_res is not None choices = [] for output in final_res.outputs: + if request.logprobs is not None: + logprobs = create_logprobs(output.token_ids, output.logprobs) + else: + logprobs = None choice_data = CompletionResponseChoice( index=output.index, text=output.text, - logprobs=None, + logprobs=logprobs, finish_reason=output.finish_reason, ) choices.append(choice_data) + num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) @@ -217,6 +278,9 @@ async def generate_completion_stream_generator(): parallel_config = server_configs[2] distributed_init_method, stage_devices = initialize_cluster(parallel_config) + # A separate tokenizer to map token IDs to strings. 
+ tokenizer = get_tokenizer(args.model) + server = AsyncLLMServer( args.use_ray, *server_configs, distributed_init_method, stage_devices) diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index d76407cdcf36..c493ef28c1b4 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -1,5 +1,5 @@ -from typing import Literal, Optional, List, Dict, Any, Union -from enum import IntEnum +# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py +from typing import Literal, Optional, List, Dict, Union import time @@ -22,7 +22,7 @@ class ModelPermission(BaseModel): allow_create_engine: bool = False allow_sampling: bool = True allow_logprobs: bool = True - allow_search_indices: bool = True + allow_search_indices: bool = False allow_view: bool = True allow_fine_tuning: bool = False organization: str = "*" @@ -34,15 +34,15 @@ class ModelCard(BaseModel): id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) - owned_by: str = "fastchat" + owned_by: str = "cacheflow" root: Optional[str] = None parent: Optional[str] = None - permission: List[ModelPermission] = [] + permission: List[ModelPermission] = Field(default_factory=list) class ModelList(BaseModel): object: str = "list" - data: List[ModelCard] = [] + data: List[ModelCard] = Field(default_factory=list) class UsageInfo(BaseModel): @@ -116,7 +116,7 @@ class CompletionRequest(BaseModel): stream: Optional[bool] = False logprobs: Optional[int] = None echo: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = [] + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 best_of: Optional[int] = None @@ -127,11 +127,18 @@ class CompletionRequest(BaseModel): use_beam_search: Optional[bool] = False +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + class CompletionResponseChoice(BaseModel): index: int text: str - logprobs: Optional[int] = None - finish_reason: Optional[Literal["stop", "length"]] + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None class CompletionResponse(BaseModel): @@ -146,7 +153,7 @@ class CompletionResponse(BaseModel): class CompletionResponseStreamChoice(BaseModel): index: int text: str - logprobs: Optional[float] = None + logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/examples/openai_client.py b/examples/openai_client.py index e13a1df3c12f..385d773781da 100644 --- a/examples/openai_client.py +++ b/examples/openai_client.py @@ -3,15 +3,19 @@ openai.api_base = "http://localhost:8000/v1" model = "facebook/opt-125m" -# create a chat completion +# list models +models = openai.Model.list() +print(models) + +# create a completion stream = False completion = openai.Completion.create( model=model, prompt="A robot may not injure a human being", echo=False, n=2, - stream=stream) + stream=stream, logprobs=3) -# print the chat completion +# print the completion if stream: for c in completion: print(c) From aa9e83cd4b9edf6a439cd7ef99f87a352b78f15a Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 05:34:50 +0000 Subject: [PATCH 
11/23] Fix streaming corner case. --- .../entrypoints/openai/openai_frontend.py | 23 +++++++++++++++---- examples/openai_client.py | 4 ++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 8f61fa009f91..5ee931b0cf00 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -66,6 +66,7 @@ async def check_model(request) -> Optional[JSONResponse]: @app.get("/v1/models") async def show_available_models(): + """Show available models. Right now we only have one model.""" model_cards = [ModelCard(id=served_model, root=served_model, permission=[ModelPermission()])] return ModelList(data=model_cards) @@ -74,6 +75,7 @@ async def show_available_models(): def create_logprobs(token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = 0) -> LogProbs: + """Create OpenAI-style logprobs.""" logprobs = LogProbs() last_token_len = 0 for token_id, id_logprob in zip(token_ids, id_logprobs): @@ -107,6 +109,7 @@ async def create_completion(request: CompletionRequest): "echo is not currently supported") if request.suffix is not None: + # The language models we currently support do not support suffix. return create_error_response(HTTPStatus.BAD_REQUEST, "suffix is not currently supported") @@ -166,7 +169,7 @@ def create_stream_response_json(index: int, return response_json - async def generate_completion_stream_generator() -> AsyncGenerator[str, None]: + async def completion_stream_generator() -> AsyncGenerator[str, None]: previous_texts = [""] * request.n previous_num_tokens = [0] * request.n async for res in result_generator: @@ -200,12 +203,13 @@ async def generate_completion_stream_generator() -> AsyncGenerator[str, None]: yield f"data: {response_json}\n\n" yield "data: [DONE]\n\n" + # Streaming response if stream: - generator = generate_completion_stream_generator() - return StreamingResponse(generator, media_type="text/event-stream") + return StreamingResponse(completion_stream_generator(), + media_type="text/event-stream") # Non-streaming response - final_res = None + final_res: RequestOutput = None async for res in result_generator: final_res = res assert final_res is not None @@ -238,6 +242,17 @@ async def generate_completion_stream_generator() -> AsyncGenerator[str, None]: choices=choices, usage=usage, ) + + if request.stream: + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. 
+ response_json = response.json(ensure_ascii=False) + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + return StreamingResponse(fake_stream_generator(), + media_type="text/event-stream") + return response diff --git a/examples/openai_client.py b/examples/openai_client.py index 385d773781da..28f91ef22f4a 100644 --- a/examples/openai_client.py +++ b/examples/openai_client.py @@ -9,11 +9,11 @@ # create a completion -stream = False +stream = True completion = openai.Completion.create( model=model, prompt="A robot may not injure a human being", echo=False, n=2, - stream=stream, logprobs=3) + best_of=3, stream=stream, logprobs=3) # print the completion if stream: From 8e14b2e91b6b73fa26571cf004f32e83e5290325 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 05:54:07 +0000 Subject: [PATCH 12/23] Optimize file locations --- gradio_webserver.py => examples/gradio_webserver.py | 0 .../simple_fastapi_frontend_client.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename gradio_webserver.py => examples/gradio_webserver.py (100%) rename test_cli_client.py => examples/simple_fastapi_frontend_client.py (92%) diff --git a/gradio_webserver.py b/examples/gradio_webserver.py similarity index 100% rename from gradio_webserver.py rename to examples/gradio_webserver.py diff --git a/test_cli_client.py b/examples/simple_fastapi_frontend_client.py similarity index 92% rename from test_cli_client.py rename to examples/simple_fastapi_frontend_client.py index 217f8088645a..f43296f49f58 100644 --- a/test_cli_client.py +++ b/examples/simple_fastapi_frontend_client.py @@ -2,7 +2,7 @@ import json def http_request(): - prompt = "Ion Stoica is a" + prompt = "A robot may not injure a human being" headers = {"User-Agent": "Test Client"} pload = { From 788d070fa8053a3404510478ec91f5dd3cce5382 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 19:46:37 +0000 Subject: [PATCH 13/23] Fix some review comments --- .../entrypoints/simple_fastapi_frontend.py | 8 ++-- examples/simple_fastapi_client.py | 48 +++++++++++++++++++ examples/simple_fastapi_frontend_client.py | 23 --------- 3 files changed, 53 insertions(+), 26 deletions(-) create mode 100644 examples/simple_fastapi_client.py delete mode 100644 examples/simple_fastapi_frontend_client.py diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py index f2ea9845d205..e7e1357f3849 100644 --- a/cacheflow/entrypoints/simple_fastapi_frontend.py +++ b/cacheflow/entrypoints/simple_fastapi_frontend.py @@ -1,5 +1,7 @@ import argparse import json +from typing import AsyncGenerator + from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse import uvicorn @@ -14,13 +16,13 @@ @app.post("/generate") -async def generate_stream(request: Request): +async def generate_stream(request: Request) -> StreamingResponse: request_dict = await request.json() prompt = request_dict.pop("prompt") sampling_params = SamplingParams(**request_dict) results_generator = server.generate(prompt, sampling_params) - async def stream_results(): + async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: prompt = request_output.prompt text_outputs = [ @@ -39,7 +41,7 @@ async def stream_results(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=10002) + 
parser.add_argument("--port", type=int, default=8001) parser = ServerArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/examples/simple_fastapi_client.py b/examples/simple_fastapi_client.py new file mode 100644 index 000000000000..e9bc3030ef3d --- /dev/null +++ b/examples/simple_fastapi_client.py @@ -0,0 +1,48 @@ +import argparse +import requests +import json + + +def clear_line(n=1): + LINE_UP = '\033[1A' + LINE_CLEAR = '\x1b[2K' + for i in range(n): + print(LINE_UP, end=LINE_CLEAR, flush=True) + + +def http_request(prompt: str, api_url: str, n: int = 1): + headers = {"User-Agent": "Test Client"} + pload = { + "prompt": prompt, + "n": n, + "use_beam_search": True, + "temperature": 0.0, + } + response = requests.post(api_url, headers=headers, json=pload, stream=True) + + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--n", type=int, default=4) + parser.add_argument("--prompt", type=str, default="A robot may not injure a human being") + args = parser.parse_args() + prompt = args.prompt + api_url = f"http://{args.host}:{args.port}/generate" + n = args.n + + print(f"Prompt: {prompt}\n", flush=True) + num_printed_lines = 0 + for h in http_request(prompt, api_url, n): + clear_line(num_printed_lines) + num_printed_lines = 0 + for i, line in enumerate(h): + num_printed_lines += line.count("\n") + 1 + print(f"Beam candidate {i:2}: {line}", flush=True) \ No newline at end of file diff --git a/examples/simple_fastapi_frontend_client.py b/examples/simple_fastapi_frontend_client.py deleted file mode 100644 index f43296f49f58..000000000000 --- a/examples/simple_fastapi_frontend_client.py +++ /dev/null @@ -1,23 +0,0 @@ -import requests -import json - -def http_request(): - prompt = "A robot may not injure a human being" - - headers = {"User-Agent": "Test Client"} - pload = { - "prompt": prompt, - "n": 4, - "use_beam_search": True, - "temperature": 0.0, - } - response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True) - - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode("utf-8")) - output = data["text"] - yield output - -for h in http_request(): - print(h, flush=True) From 6cc6118edc055f5bc4d054efd9f386fc2e4311a8 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 20:01:22 +0000 Subject: [PATCH 14/23] Fix client --- examples/simple_fastapi_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/simple_fastapi_client.py b/examples/simple_fastapi_client.py index e9bc3030ef3d..d7d9d355a6f2 100644 --- a/examples/simple_fastapi_client.py +++ b/examples/simple_fastapi_client.py @@ -2,7 +2,6 @@ import requests import json - def clear_line(n=1): LINE_UP = '\033[1A' LINE_CLEAR = '\x1b[2K' @@ -17,6 +16,7 @@ def http_request(prompt: str, api_url: str, n: int = 1): "n": n, "use_beam_search": True, "temperature": 0.0, + "max_tokens": 16, } response = requests.post(api_url, headers=headers, json=pload, stream=True) @@ -32,7 +32,7 @@ def http_request(prompt: str, api_url: str, n: int = 1): parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, 
default=8001) parser.add_argument("--n", type=int, default=4) - parser.add_argument("--prompt", type=str, default="A robot may not injure a human being") + parser.add_argument("--prompt", type=str, default="San Francisco is a") args = parser.parse_args() prompt = args.prompt api_url = f"http://{args.host}:{args.port}/generate" @@ -44,5 +44,5 @@ def http_request(prompt: str, api_url: str, n: int = 1): clear_line(num_printed_lines) num_printed_lines = 0 for i, line in enumerate(h): - num_printed_lines += line.count("\n") + 1 - print(f"Beam candidate {i:2}: {line}", flush=True) \ No newline at end of file + num_printed_lines += 1 + print(f"Beam candidate {i}: {line}", flush=True) From 205b7edaaa19b9b73a24a31a9bd981c59e61912d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 20:25:53 +0000 Subject: [PATCH 15/23] Fix review comments --- cacheflow/core/scheduler.py | 3 ++- cacheflow/outputs.py | 14 ++++++-------- cacheflow/sequence.py | 11 ++++++++++- cacheflow/server/llm_server.py | 12 ++++++------ cacheflow/utils.py | 2 +- examples/openai_client.py | 1 - 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py index 496aaa842c4d..c4526cf96337 100644 --- a/cacheflow/core/scheduler.py +++ b/cacheflow/core/scheduler.py @@ -296,7 +296,8 @@ def update( # from being modified by the caller. return self.running.copy() - def free_seq(self, seq: Sequence) -> None: + def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None: + seq.status = finish_status self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: diff --git a/cacheflow/outputs.py b/cacheflow/outputs.py index ba8c783c0b66..50ee9bbb8dcf 100644 --- a/cacheflow/outputs.py +++ b/cacheflow/outputs.py @@ -21,6 +21,9 @@ def __init__( self.logprobs = logprobs self.finish_reason = finish_reason + def finished(self) -> bool: + return self.finish_reason is not None + def __repr__(self) -> str: return (f"CompletionOutput(index={self.index}, " f"text={self.text!r}, " @@ -63,16 +66,11 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # always has the logprobs of the sampled tokens even if the # logprobs are not requested. logprobs = {} - if seq.status == SequenceStatus.FINISHED_STOPPED: - finish_reason = "stop" - elif seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED: - finish_reason = "length" - else: - finish_reason = None + finshed_reason = SequenceStatus.get_finished_reason(seq.status) output = CompletionOutput(seqs.index(seq), seq.output_text, seq.get_output_token_ids(), seq.get_cumulative_logprob(), logprobs, - finish_reason) + finshed_reason) outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
@@ -87,4 +85,4 @@ def __repr__(self) -> str: f"outputs={self.outputs})") def finished(self) -> bool: - return all(output.finish_reason is not None for output in self.outputs) \ No newline at end of file + return all(output.finished() for output in self.outputs) diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 8a467c56b2c2..3eb52285cb7e 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -1,6 +1,6 @@ import copy import enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from cacheflow.block import LogicalTokenBlock from cacheflow.sampling_params import SamplingParams @@ -20,6 +20,15 @@ def is_finished(status: "SequenceStatus") -> bool: SequenceStatus.FINISHED_LENGTH_CAPPED, ] + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + else: + finish_reason = None + return finish_reason class SequenceData: diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py index d5d81e2c248e..cd1e5e93f2af 100644 --- a/cacheflow/server/llm_server.py +++ b/cacheflow/server/llm_server.py @@ -209,8 +209,8 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: # Truncate the output text so that the stop string is # not included in the output. seq.output_text = seq.output_text[:-len(stop_str)] - seq.status = SequenceStatus.FINISHED_STOPPED - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, + SequenceStatus.FINISHED_STOPPED) stopped = True break if stopped: @@ -218,14 +218,14 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: # Check if the sequence has reached max_tokens. if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - self.scheduler.free_seq(seq) + self.scheduler.free_seq( + seq, SequenceStatus.FINISHED_LENGTH_CAPPED) continue # Check if the sequence has generated the EOS token. 
if not sampling_params.ignore_eos: if seq.get_last_token_id() == self.tokenizer.eos_token_id: - seq.status = SequenceStatus.FINISHED_STOPPED - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, + SequenceStatus.FINISHED_STOPPED) continue def _run_workers( diff --git a/cacheflow/utils.py b/cacheflow/utils.py index a28bacfe4cf2..59ce9dfa6010 100644 --- a/cacheflow/utils.py +++ b/cacheflow/utils.py @@ -1,7 +1,7 @@ import enum -import psutil import uuid +import psutil import torch diff --git a/examples/openai_client.py b/examples/openai_client.py index 28f91ef22f4a..9e711a8a0899 100644 --- a/examples/openai_client.py +++ b/examples/openai_client.py @@ -10,7 +10,6 @@ # create a completion stream = True - completion = openai.Completion.create( model=model, prompt="A robot may not injure a human being", echo=False, n=2, best_of=3, stream=stream, logprobs=3) From 489e55e77115a4e0484268400511c917de4c876a Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 20:26:29 +0000 Subject: [PATCH 16/23] Fix --- cacheflow/server/async_llm_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py index aefef326d979..8755b023b897 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/server/async_llm_server.py @@ -1,14 +1,14 @@ import asyncio import time -from typing import Optional, Dict +from typing import Dict, Optional import ray from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams from cacheflow.server.arg_utils import ServerArgs -from cacheflow.server.ray_utils import initialize_cluster from cacheflow.server.llm_server import LLMServer +from cacheflow.server.ray_utils import initialize_cluster from cacheflow.utils import random_uuid TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds From ee59d787f1897a906bae212f62897c6f2b9dbced Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 23:16:53 +0000 Subject: [PATCH 17/23] Fix other examples. 
--- cacheflow/entrypoints/openai/openai_frontend.py | 1 - examples/gradio_webserver.py | 11 +++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 14aefcfa4898..210778a86f20 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -15,7 +15,6 @@ from cacheflow.outputs import RequestOutput from cacheflow.server.arg_utils import ServerArgs from cacheflow.server.async_llm_server import AsyncLLMServer -from cacheflow.server.ray_utils import initialize_cluster from cacheflow.server.tokenizer_utils import get_tokenizer from cacheflow.logger import init_logger from cacheflow.sampling_params import SamplingParams diff --git a/examples/gradio_webserver.py b/examples/gradio_webserver.py index d819ecab0c62..2f768c2714ae 100644 --- a/examples/gradio_webserver.py +++ b/examples/gradio_webserver.py @@ -1,6 +1,5 @@ import argparse import json -import time import gradio as gr import requests @@ -24,9 +23,9 @@ def http_bot(prompt): def build_demo(): with gr.Blocks() as demo: gr.Markdown( - "# Cacheflow demo\n" + "# Cacheflow text completion demo\n" ) - inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False) + inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") inputbox.submit(http_bot, [inputbox], [outputbox]) return demo @@ -35,9 +34,9 @@ def build_demo(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=10003) - parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate") + parser.add_argument("--port", type=int, default=8002) + parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate") args = parser.parse_args() demo = build_demo() - demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port) \ No newline at end of file + demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port, share=True) \ No newline at end of file From dd03d97a0a70b848bd0e96f352e1db991858faba Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 23:43:36 +0000 Subject: [PATCH 18/23] Remove currently unused chat completion protocols --- cacheflow/entrypoints/openai/protocol.py | 39 ------------------------ 1 file changed, 39 deletions(-) diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index c493ef28c1b4..f765e09ef7fd 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -65,45 +65,6 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None -class ChatMessage(BaseModel): - role: str - content: str - - -class ChatCompletionResponseChoice(BaseModel): - index: int - message: ChatMessage - finish_reason: Optional[Literal["stop", "length"]] - - -class ChatCompletionResponse(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") - object: str = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseChoice] - usage: UsageInfo - - -class DeltaMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - - -class 
ChatCompletionResponseStreamChoice(BaseModel): - index: int - delta: DeltaMessage - finish_reason: Optional[Literal["stop", "length"]] - - -class ChatCompletionStreamResponse(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") - object: str = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[ChatCompletionResponseStreamChoice] - - class CompletionRequest(BaseModel): model: str prompt: str From 2cca8266538f2bbc5959cc7e83a0df53859f6230 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 23:54:36 +0000 Subject: [PATCH 19/23] add served_model_name --- cacheflow/entrypoints/openai/openai_frontend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 210778a86f20..b90e0d98276d 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -272,6 +272,10 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: parser.add_argument( "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" ) + parser.add_argument("--served-model-name", type=str, default=None, + help="The model name used in the API. If not specified, " + "the model name will be the same as the " + "huggingface name.") parser = ServerArgs.add_cli_args(parser) args = parser.parse_args() @@ -285,7 +289,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: logger.info(f"args: {args}") - served_model = args.model + served_model = args.served_model_name or args.model server_args = ServerArgs.from_cli_args(args) server = AsyncLLMServer.from_server_args(server_args) From 9fd49e5926608f7b5b7d437c3b86ba9d84a96edb Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 24 May 2023 03:04:59 +0000 Subject: [PATCH 20/23] Fix some review comments --- benchmark/benchmark_async_llm_server.py | 58 +++++++++++++++++++ .../entrypoints/openai/openai_frontend.py | 6 +- cacheflow/entrypoints/openai/protocol.py | 2 +- 3 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 benchmark/benchmark_async_llm_server.py diff --git a/benchmark/benchmark_async_llm_server.py b/benchmark/benchmark_async_llm_server.py new file mode 100644 index 000000000000..a950c16e705b --- /dev/null +++ b/benchmark/benchmark_async_llm_server.py @@ -0,0 +1,58 @@ +import argparse +import json + +import requests +import threading +import time + +def main(args: argparse.Namespace): + prompt = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words" + for i in range(args.n_thread)] + + headers = {"User-Agent": "CacheFlow Benchmark Client"} + ploads = [{ + "prompt": prompt[i], + "max_new_tokens": args.max_new_tokens, + "temperature": 0.0, + "ignore_eos": True, + } for i in range(len(prompt))] + + def send_request(results, i): + response = requests.post(args.api_url, headers=headers, + json=ploads[i], stream=True) + results[i] = response + + # use args.n_threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_thread + for i in range(args.n_thread): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + threads.append(t) + + for t in threads: + t.join() + + print(f"Time (POST): {time.time() - tik} s") + n_words = 0 + + # if streaming: + for i, response in enumerate(results): + k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) + response_new_words = 
json.loads(k[-2].decode("utf-8"))["text"] + n_words += len(response_new_words.split(" ")) - len(prompt[i].split(" ")) + + time_seconds = time.time() - tik + print(f"Time (total): {time_seconds} to finish, n threads: {args.n_thread}, " + f"throughput: {n_words / time_seconds} words/s.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--api-url", type=str, default="http://localhost:8001") + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--n-thread", type=int, default=2) + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index b90e0d98276d..1218a5ae57c7 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -1,10 +1,10 @@ -# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py +# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py import argparse +from http import HTTPStatus import json import time -from typing import Optional, Dict, List, AsyncGenerator -from http import HTTPStatus +from typing import AsyncGenerator, Dict, List, Optional import fastapi from fastapi.exceptions import RequestValidationError diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index f765e09ef7fd..8035a01d16fa 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py +# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py from typing import Literal, Optional, List, Dict, Union import time From 02f46cd9117aa8476be4f3ce3c7dd4e8c13ade1d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 24 May 2023 03:39:07 +0000 Subject: [PATCH 21/23] Use number based request ids --- examples/simple_server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/simple_server.py b/examples/simple_server.py index e9be1cf6b0d6..8d5fcaf8e472 100644 --- a/examples/simple_server.py +++ b/examples/simple_server.py @@ -1,7 +1,6 @@ import argparse from cacheflow import ServerArgs, LLMServer, SamplingParams -from cacheflow.utils import random_uuid def main(args: argparse.Namespace): @@ -20,13 +19,15 @@ def main(args: argparse.Namespace): SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)), ] + request_id = 0 + # Run the server. while True: # To test iteration-level scheduling, we add one request at each step. 
if test_prompts: prompt, sampling_params = test_prompts.pop(0) - request_id = random_uuid() - server.add_request(request_id, prompt, sampling_params) + server.add_request(str(request_id), prompt, sampling_params) + request_id += 1 request_outputs = server.step() for request_output in request_outputs: From 7d9a9c6af596efb47258f9c8f59ab2a52daf4838 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 23 May 2023 20:40:29 -0700 Subject: [PATCH 22/23] Delete benchmark_async_llm_server.py --- benchmark/benchmark_async_llm_server.py | 58 ------------------------- 1 file changed, 58 deletions(-) delete mode 100644 benchmark/benchmark_async_llm_server.py diff --git a/benchmark/benchmark_async_llm_server.py b/benchmark/benchmark_async_llm_server.py deleted file mode 100644 index a950c16e705b..000000000000 --- a/benchmark/benchmark_async_llm_server.py +++ /dev/null @@ -1,58 +0,0 @@ -import argparse -import json - -import requests -import threading -import time - -def main(args: argparse.Namespace): - prompt = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words" - for i in range(args.n_thread)] - - headers = {"User-Agent": "CacheFlow Benchmark Client"} - ploads = [{ - "prompt": prompt[i], - "max_new_tokens": args.max_new_tokens, - "temperature": 0.0, - "ignore_eos": True, - } for i in range(len(prompt))] - - def send_request(results, i): - response = requests.post(args.api_url, headers=headers, - json=ploads[i], stream=True) - results[i] = response - - # use args.n_threads to prompt the backend - tik = time.time() - threads = [] - results = [None] * args.n_thread - for i in range(args.n_thread): - t = threading.Thread(target=send_request, args=(results, i)) - t.start() - threads.append(t) - - for t in threads: - t.join() - - print(f"Time (POST): {time.time() - tik} s") - n_words = 0 - - # if streaming: - for i, response in enumerate(results): - k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) - response_new_words = json.loads(k[-2].decode("utf-8"))["text"] - n_words += len(response_new_words.split(" ")) - len(prompt[i].split(" ")) - - time_seconds = time.time() - tik - print(f"Time (total): {time_seconds} to finish, n threads: {args.n_thread}, " - f"throughput: {n_words / time_seconds} words/s.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--api-url", type=str, default="http://localhost:8001") - parser.add_argument("--max-new-tokens", type=int, default=2048) - parser.add_argument("--n-thread", type=int, default=2) - args = parser.parse_args() - - main(args) \ No newline at end of file From 83df8a0de42f5c6228dcd1efce385f715d9b79df Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 24 May 2023 04:38:59 +0000 Subject: [PATCH 23/23] Address review comments --- cacheflow/entrypoints/openai/openai_frontend.py | 2 +- cacheflow/entrypoints/openai/protocol.py | 6 +++--- examples/gradio_webserver.py | 4 +++- requirements.txt | 1 + 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 1218a5ae57c7..4d32390bade1 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -20,13 +20,13 @@ from cacheflow.sampling_params import SamplingParams from cacheflow.utils import random_uuid from cacheflow.entrypoints.openai.protocol import ( - LogProbs, CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, 
CompletionStreamResponse, ErrorResponse, + LogProbs, ModelCard, ModelList, ModelPermission, diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py index 8035a01d16fa..61ad60c24bb1 100644 --- a/cacheflow/entrypoints/openai/protocol.py +++ b/cacheflow/entrypoints/openai/protocol.py @@ -1,9 +1,9 @@ # Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py -from typing import Literal, Optional, List, Dict, Union - import time +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field + from cacheflow.utils import random_uuid @@ -72,7 +72,6 @@ class CompletionRequest(BaseModel): max_tokens: Optional[int] = 16 temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 - top_k: Optional[int] = -1 n: Optional[int] = 1 stream: Optional[bool] = False logprobs: Optional[int] = None @@ -84,6 +83,7 @@ class CompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None # Additional parameters supported by cacheflow + top_k: Optional[int] = -1 ignore_eos: Optional[bool] = False use_beam_search: Optional[bool] = False diff --git a/examples/gradio_webserver.py b/examples/gradio_webserver.py index 2f768c2714ae..e4a80c39ac4c 100644 --- a/examples/gradio_webserver.py +++ b/examples/gradio_webserver.py @@ -39,4 +39,6 @@ def build_demo(): args = parser.parse_args() demo = build_demo() - demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port, share=True) \ No newline at end of file + demo.queue(concurrency_count=100).launch(server_name=args.host, + server_port=args.port, + share=True) diff --git a/requirements.txt b/requirements.txt index bcb79da5213a..e84873eddd21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ transformers >= 4.28.0 # Required for LLaMA. xformers >= 0.0.19 fastapi uvicorn +pydantic # Required for OpenAI server.
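
Note on exercising the endpoint added by this series (editorial addition, not part of the patches): the `/v1/completions` route streams Server-Sent Events of the form `data: {CompletionStreamResponse JSON}` and terminates with `data: [DONE]`, but only when `stream=True`, `n == best_of` (or `best_of` is unset), and beam search is off; otherwise the full `CompletionResponse` is returned, wrapped in a single fake stream event if streaming was requested. The sketch below is illustrative only; it assumes the OpenAI frontend is serving `http://localhost:8000/v1` with `facebook/opt-125m`, as in examples/openai_client.py, and the helper name `stream_completion` is made up for the example.

# Rough sketch of consuming the SSE stream from the /v1/completions endpoint.
# Assumes the frontend is running at http://localhost:8000/v1 with the
# facebook/opt-125m model (same setup as examples/openai_client.py).
import json

import requests


def stream_completion(prompt: str,
                      api_url: str = "http://localhost:8000/v1/completions"):
    payload = {
        "model": "facebook/opt-125m",
        "prompt": prompt,
        "max_tokens": 16,
        "stream": True,
    }
    response = requests.post(api_url, json=payload, stream=True)
    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue  # SSE events are separated by blank lines.
        data = line[len("data: "):] if line.startswith("data: ") else line
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each event carries one choice holding the newly generated text
        # delta; its "index" field distinguishes parallel samples when n > 1.
        yield chunk["choices"][0]["text"]


if __name__ == "__main__":
    for delta in stream_completion("A robot may not injure a human being"):
        print(delta, end="", flush=True)
    print()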