Implement stop strings and best_of #114

Merged
merged 22 commits on May 21, 2023
12 changes: 6 additions & 6 deletions cacheflow/core/block_manager.py
@@ -80,15 +80,15 @@ def __init__(
def can_allocate(self, seq_group: SequenceGroup) -> bool:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.seqs[0]
seq = seq_group.get_seqs()[0]
num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks

def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same prompt.
seq = seq_group.seqs[0]
seq = seq_group.get_seqs()[0]

# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
@@ -99,7 +99,7 @@ def allocate(self, seq_group: SequenceGroup) -> None:
block_table.append(block)

# Assign the block table for each sequence.
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
self.block_tables[seq.seq_id] = block_table.copy()

def can_append_slot(self, seq_group: SequenceGroup) -> bool:
@@ -147,7 +147,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBl
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
block_table = self.block_tables[seq.seq_id]
@@ -168,7 +168,7 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
new_block_table: BlockTable = []
@@ -199,7 +199,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
new_block_table: BlockTable = []
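
Throughout block_manager.py the diff replaces direct reads of seq_group.seqs with a get_seqs() accessor. The accessor itself is defined on SequenceGroup in cacheflow/sequence.py and is not part of this diff; the sketch below shows the assumed shape, including the optional status filter that scheduler.py relies on further down.

from typing import List, Optional

class SequenceGroup:
    """Sketch only; the real class lives in cacheflow/sequence.py."""

    def __init__(self, request_id: str, seqs: List["Sequence"]) -> None:
        self.request_id = request_id
        self.seqs = seqs

    def get_seqs(
        self,
        status: Optional["SequenceStatus"] = None,
    ) -> List["Sequence"]:
        # Return every sequence, or only those in the requested state.
        if status is None:
            return self.seqs
        return [seq for seq in self.seqs if seq.status == status]
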
63 changes: 15 additions & 48 deletions cacheflow/core/scheduler.py
@@ -73,8 +73,6 @@ def __init__(
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
# Mapping: request_id -> num_steps.
self.num_steps: Dict[str, int] = {}
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []

@@ -84,7 +82,6 @@ def __init__(

def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
assert seq_group.request_id not in self.num_steps
self.waiting.append(seq_group)

def has_unfinished_seqs(self) -> bool:
@@ -178,7 +175,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
break

# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
break
@@ -278,15 +275,8 @@ def update(
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
for seq_group in self.running:
request_id = seq_group.request_id
self.num_steps[request_id] += 1
stop_token_ids = seq_group.sampling_params.stop_token_ids

# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
@@ -297,43 +287,27 @@
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)

# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token(output.output_token, output.logprobs)
return self.running.copy()

# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
def free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)

# Check if the sequence has reached the maximum number of steps.
max_num_steps = seq_group.sampling_params.max_tokens
if self.num_steps[request_id] == max_num_steps:
self._free_seq(seq)
continue

# Update the running sequences.
updated = self.running.copy()
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running
return updated
def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
if not seq_group.is_finished()
]

def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
if seq_group.request_id not in self.num_steps:
self.num_steps[seq_group.request_id] = 0

def _append_slot(
self,
@@ -403,13 +377,6 @@ def _preempt_by_swap(
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)

def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)

def _free_seq_group(self, seq_group: SequenceGroup) -> None:
del self.num_steps[seq_group.request_id]

def _swap_in(
self,
seq_group: SequenceGroup,
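
With this refactor the scheduler no longer tracks num_steps or decides when a sequence is finished: update() only forks sequences and appends new tokens, while the new public free_seq() and free_finished_seq_groups() let the caller finish sequences after it has checked the stop conditions (the stop strings added by this PR, stop_token_ids, and max_tokens). The sketch below is an assumption about how that caller-side check could look; helper names such as sampling_params.stop, seq.get_last_token_id(), and seq.get_output_len() are illustrative, not taken from this diff.

def maybe_stop_sequence(seq, sampling_params, scheduler) -> None:
    # Hypothetical post-step check run by the engine rather than the scheduler.
    # Stop strings: truncate the decoded text at the first stop string.
    for stop_str in sampling_params.stop:
        if stop_str in seq.output_text:
            seq.output_text = seq.output_text[:seq.output_text.index(stop_str)]
            scheduler.free_seq(seq)
            return
    # Stop token ids (previously checked inside Scheduler.update()).
    if seq.get_last_token_id() in sampling_params.stop_token_ids:
        scheduler.free_seq(seq)
        return
    # Length limit (previously num_steps == max_tokens in the scheduler).
    if seq.get_output_len() >= sampling_params.max_tokens:
        scheduler.free_seq(seq)
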
5 changes: 3 additions & 2 deletions cacheflow/entrypoints/fastapi_server.py
@@ -123,6 +123,7 @@ async def generate_stream(request: Request):
parallel_config = server_configs[2]
distributed_init_method, stage_devices = initialize_cluster(parallel_config)

server = FastAPIServer(
args.use_ray, *server_configs, distributed_init_method, stage_devices)
server = FastAPIServer(args.use_ray, *server_configs,
distributed_init_method, stage_devices,
log_stats=not args.disable_log_stats)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
16 changes: 8 additions & 8 deletions cacheflow/model_executor/layers/sampler.py
@@ -283,20 +283,20 @@ def _sample_from_prompt(
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
beam_width = sampling_params.best_of
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
assert sampling_params.best_of == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Random sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
prob, num_samples=num_seqs, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids

@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# NOTE(woosuk): sampling_params.best_of can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
@@ -372,7 +372,7 @@ def _sample(
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
assert len(seq_ids) == sampling_params.best_of
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
@@ -397,7 +397,7 @@

# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprobs
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
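
The sampler is re-keyed from n to the new best_of parameter: best_of candidate sequences are generated (or used as the beam width) per prompt, and only the top n of them are returned to the caller, as the outputs.py change below shows. A minimal usage sketch, assuming SamplingParams exposes n, best_of, temperature, and a stop list for the new stop strings (the import path and the stop argument name are assumptions):

from cacheflow import SamplingParams  # import path assumed

params = SamplingParams(
    n=2,            # completions returned to the caller
    best_of=4,      # candidate sequences actually generated per prompt
    temperature=0.8,
    stop=["\n\n"],  # stop strings introduced by this PR; argument name assumed
)
# best_of must be >= n: the sampler produces best_of sequences and
# RequestOutput.from_seq_group keeps the n with the highest cumulative logprob.
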
44 changes: 23 additions & 21 deletions cacheflow/outputs.py
@@ -1,6 +1,4 @@
from typing import Dict, List, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Dict, List

from cacheflow.sequence import SequenceGroup

@@ -9,20 +7,23 @@ class CompletionOutput:

def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprobs: float,
cumulative_logprob: float,
logprobs: List[Dict[int, float]],
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprobs = cumulative_logprobs
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs

def __repr__(self) -> str:
return (f"CompletionOutput(output={self.text!r}, "
return (f"CompletionOutput(index={self.index}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprobs={self.cumulative_logprobs}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs})")


@@ -43,31 +44,32 @@ def __init__(
self.done = done

@staticmethod
def from_seq_group(
seq_group: SequenceGroup,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> "RequestOutput":
outputs: List[CompletionOutput] = []
def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
for seq in seqs:
output_token_ids = seq.data.output_token_ids
output_str = tokenizer.decode(output_token_ids,
skip_special_tokens=True)
seq_logprobs = seq.data.cumulative_logprobs
assert n <= len(seqs)
sorted_seqs = sorted(
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
top_n_seqs = sorted_seqs[:n]

# Create the outputs.
outputs: List[CompletionOutput] = []
for seq in top_n_seqs:
logprobs = seq.output_logprobs
if seq_group.sampling_params.logprobs == 0:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = {}
output = CompletionOutput(output_str, output_token_ids,
seq_logprobs, logprobs)
output = CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
seq.get_cumulative_logprob(), logprobs)
outputs.append(output)

# Every sequence in the sequence group should have the same prompt.
prompt = seqs[0].prompt
prompt_token_ids = seqs[0].data.prompt_token_ids
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
outputs, seq_group.is_finished())

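
CompletionOutput now records its index within the request and a per-sequence cumulative_logprob, and from_seq_group no longer takes a tokenizer because each Sequence already carries its decoded output_text. The sketch below shows how a client might read the reshaped objects; the CompletionOutput field names come from the diff above, while treating the RequestOutput constructor arguments (request_id, prompt, outputs, done) as same-named attributes is an assumption.

def print_request_output(request_output) -> None:
    # Summarize one RequestOutput and its top-n completions.
    print(f"request {request_output.request_id} "
          f"prompt={request_output.prompt!r} done={request_output.done}")
    for completion in request_output.outputs:
        print(f"  [{completion.index}] "
              f"cumulative_logprob={completion.cumulative_logprob:.3f} "
              f"text={completion.text!r}")
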