Merged
Commits (32)
263c286  Separate AsyncLLMServer (zhuohan123, May 20, 2023)
18f8097  rename fastapi frontend (zhuohan123, May 20, 2023)
fb73a1b  small fix (zhuohan123, May 20, 2023)
b990239  [WIP] add WIP openai frontend (zhuohan123, May 20, 2023)
00c84e2  fix async_llm_server (zhuohan123, May 20, 2023)
1c71b88  Basic support for OpenAI Completion API (zhuohan123, May 21, 2023)
5415281  Merge branch 'main' into openai-server (zhuohan123, May 21, 2023)
63c5d3c  Implement finsh_reason (zhuohan123, May 21, 2023)
8321b47  support bestof and stop (zhuohan123, May 21, 2023)
5c82790  Support non-streaming requests (zhuohan123, May 22, 2023)
0e12ecb  Support logprobs (zhuohan123, May 23, 2023)
aa9e83c  Fix streaming corner case. (zhuohan123, May 23, 2023)
abb1bf1  Merge branch 'main' into openai-server (zhuohan123, May 23, 2023)
8e14b2e  Optimize file locations (zhuohan123, May 23, 2023)
788d070  Fix some review comments (zhuohan123, May 23, 2023)
6cc6118  Fix client (zhuohan123, May 23, 2023)
205b7ed  Fix review comments (zhuohan123, May 23, 2023)
489e55e  Fix (zhuohan123, May 23, 2023)
ee59d78  Fix other examples. (zhuohan123, May 23, 2023)
dd03d97  Remove currently unused chat completion protocols (zhuohan123, May 23, 2023)
2cca826  add served_model_name (zhuohan123, May 23, 2023)
9fd49e5  Fix some review comments (zhuohan123, May 24, 2023)
02f46cd  Use number based request ids (zhuohan123, May 24, 2023)
ad1db5e  better benchmark scripts (zhuohan123, May 24, 2023)
9609d5c  Merge branch 'main' into async-server-performance (zhuohan123, May 24, 2023)
ddfd2e1  fix benchmark script arguments (zhuohan123, May 24, 2023)
b935221  Select async server mode (zhuohan123, Jun 2, 2023)
7b95efa  Support aborting requests (zhuohan123, Jun 2, 2023)
c1296ac  Support aborting request for openai frontend (zhuohan123, Jun 3, 2023)
6c76625  Merge branch 'main' into async-server-performance (zhuohan123, Jun 3, 2023)
075ed4e  Small fixes (zhuohan123, Jun 3, 2023)
8808179  Fix review comments (zhuohan123, Jun 5, 2023)
58 changes: 58 additions & 0 deletions benchmarks/benchmark_async_llm_server.py
@@ -0,0 +1,58 @@
import argparse
import json
import threading
import time

import requests


def main(args: argparse.Namespace):
    prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
               for i in range(args.n_threads)]

    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    ploads = [{
        "prompt": p,
        "max_tokens": args.max_tokens,
        "temperature": 0.0,
        "ignore_eos": True,
    } for p in prompts]

    def send_request(results, i):
        response = requests.post(args.api_url, headers=headers,
                                 json=ploads[i], stream=True)
        results[i] = response

    # use args.n_threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_threads
    for i in range(args.n_threads):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = 0

    for i, response in enumerate(results):
        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    time_seconds = time.time() - tik
    print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, "
          f"throughput: {n_words / time_seconds} words/s.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--n-threads", type=int, default=128)
    args = parser.parse_args()

    main(args)
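For context, the streaming protocol this benchmark post-processes can also be consumed incrementally. The following is a minimal single-request sketch (not part of the PR), assuming the simple FastAPI frontend is serving /generate on localhost:8001 and emitting the \0-delimited JSON chunks parsed above:

import json

import requests

# Minimal sketch: stream one completion from the /generate endpoint and print
# each partial result as it arrives.
payload = {
    "prompt": "Tell me a story about a robot",
    "max_tokens": 64,
    "temperature": 0.0,
}
response = requests.post("http://localhost:8001/generate",
                         headers={"User-Agent": "CacheFlow Example Client"},
                         json=payload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if not chunk:
        continue  # skip the empty fragment after the trailing delimiter
    data = json.loads(chunk.decode("utf-8"))
    print(data["text"][0])  # cumulative text so far, as in the benchmark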
6 changes: 3 additions & 3 deletions cacheflow/config.py
@@ -116,15 +116,15 @@ def __init__(
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray
 
         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()
 
     def _verify_args(self) -> None:
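The rename above preserves the existing fallback: any multi-GPU configuration forces Ray workers on. A standalone sketch of that behavior (an illustrative mirror, not the real ParallelConfig):

# Standalone sketch mirroring the logic above; the real class lives in
# cacheflow/config.py and carries additional validation.
class ParallelConfigSketch:
    def __init__(self, pipeline_parallel_size: int, tensor_parallel_size: int,
                 worker_use_ray: bool) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if self.world_size > 1:
            # Multi-worker setups always go through Ray workers.
            self.worker_use_ray = True

config = ParallelConfigSketch(pipeline_parallel_size=1,
                              tensor_parallel_size=2,
                              worker_use_ray=False)
assert config.worker_use_ray  # forced on because world_size == 2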
9 changes: 6 additions & 3 deletions cacheflow/core/block_manager.py
@@ -148,7 +148,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -231,6 +231,9 @@ def _free_block_table(self, block_table: BlockTable) -> None:
                 self.cpu_allocator.free(block)
 
     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
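The guard added to free() makes the call idempotent, which matters for aborted sequences that may never have been scheduled and so have no block table. A toy sketch of the same pattern with plain dictionaries (names are illustrative only):

# Sketch of the idempotent free: freeing an unknown or already-freed sequence
# is a no-op instead of a KeyError.
block_tables = {0: ["blockA", "blockB"]}  # seq_id -> list of blocks
free_blocks = []

def free(seq_id: int) -> None:
    if seq_id not in block_tables:
        # Already freed, or the sequence was aborted before being scheduled.
        return
    free_blocks.extend(block_tables.pop(seq_id))

free(0)   # releases blockA and blockB
free(0)   # second call is a no-op
free(99)  # unknown sequence is also a no-op
assert free_blocks == ["blockA", "blockB"]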
14 changes: 13 additions & 1 deletion cacheflow/core/scheduler.py
@@ -12,7 +12,7 @@
 
 logger = init_logger(__name__)
 
-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5
 
 
 class PreemptionMode(enum.Enum):
@@ -84,6 +84,18 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
 
+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
 
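The new abort_seq_group scans the waiting, running, and swapped queues, removes the first group whose request_id matches, and frees its unfinished sequences with the new FINISHED_ABORTED status. A toy mirror of that scan (plain lists and dicts instead of the real scheduler types):

# Toy mirror of the abort scan above; request IDs and queue contents are made up.
waiting = [{"request_id": "req-1"}]
running = [{"request_id": "req-2"}]
swapped = []

def abort(request_id: str) -> bool:
    for state_queue in (waiting, running, swapped):
        for group in state_queue:
            if group["request_id"] == request_id:
                state_queue.remove(group)
                # The real scheduler frees each unfinished sequence here with
                # SequenceStatus.FINISHED_ABORTED.
                return True
    return False

assert abort("req-2") is True
assert abort("req-404") is False  # unknown requests are ignored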
30 changes: 23 additions & 7 deletions cacheflow/entrypoints/openai/openai_frontend.py
@@ -7,13 +7,14 @@
 from typing import AsyncGenerator, Dict, List, Optional
 
 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 
 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@
     UsageInfo,
 )
 
+TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],
 
 
 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
 
     error_check_ret = await check_model(request)
@@ -139,14 +142,17 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
     stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)
 
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
 
     # Streaming response
    if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
@@ -276,7 +291,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     app.add_middleware(
@@ -291,10 +306,11 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
 
     served_model = args.served_model_name or args.model
 
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
 
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
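Once this frontend is running, the completions endpoint can be exercised with any HTTP client. A minimal sketch using requests follows; the host, port, and model name are assumptions and depend on the --host/--port/--model flags used to launch the server:

import requests

# Non-streaming completion request against the OpenAI-compatible endpoint.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "San Francisco is a",
        "max_tokens": 16,
        "temperature": 0.0,
        "stream": False,
    },
)
resp.raise_for_status()
completion = resp.json()
print(completion["choices"][0]["text"])
print(completion["choices"][0]["finish_reason"])  # "stop" or "length"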
25 changes: 17 additions & 8 deletions cacheflow/entrypoints/simple_fastapi_frontend.py
@@ -2,15 +2,16 @@
 import json
 from typing import AsyncGenerator
 
-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn
 
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid
 
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
 
@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)
 
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
@@ -35,17 +37,24 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)
 
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
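The abort path here relies on Starlette running a StreamingResponse's background task once the response completes or the client disconnects. A self-contained sketch of that mechanism, independent of CacheFlow (the endpoint and messages are made up):

import asyncio

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse
import uvicorn

app = FastAPI()

@app.get("/demo")
async def demo():
    async def stream():
        for i in range(5):
            await asyncio.sleep(1)
            yield f"chunk {i}\n".encode("utf-8")

    async def cleanup():
        # Runs after the stream finishes or the client disconnects;
        # this is the hook where the PR calls server.abort(request_id).
        print("cleanup: request finished or aborted")

    background_tasks = BackgroundTasks()
    background_tasks.add_task(cleanup)
    return StreamingResponse(stream(), background=background_tasks)

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8002)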
10 changes: 9 additions & 1 deletion cacheflow/sequence.py
@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()
 
     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]
 
     @staticmethod
@@ -26,10 +28,13 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason
 
+
 class SequenceData:
 
     def __init__(
@@ -137,6 +142,9 @@ def get_output_token_ids(self) -> List[int]:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
 
+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ def find(self, seq_id: int) -> Sequence:
         raise ValueError(f'Sequence {seq_id} not found.')
 
     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)
 
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
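With FINISHED_ABORTED wired into get_finished_reason, aborted requests surface as finish_reason "abort" in API responses. A standalone mirror of that mapping (the real enum lives in cacheflow/sequence.py):

import enum

# Standalone mirror of the status -> finish_reason mapping above.
class Status(enum.Enum):
    RUNNING = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()

def finish_reason(status: Status):
    return {
        Status.FINISHED_STOPPED: "stop",
        Status.FINISHED_LENGTH_CAPPED: "length",
        Status.FINISHED_ABORTED: "abort",
    }.get(status)  # None while the sequence is still running

assert finish_reason(Status.FINISHED_ABORTED) == "abort"
assert finish_reason(Status.RUNNING) is None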