Implement stop strings and best_of (#114)
WoosukKwon committed May 21, 2023
1 parent c3442c1 commit f746ced
Showing 9 changed files with 162 additions and 116 deletions.
12 changes: 6 additions & 6 deletions cacheflow/core/block_manager.py
@@ -80,15 +80,15 @@ def __init__(
def can_allocate(self, seq_group: SequenceGroup) -> bool:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.seqs[0]
seq = seq_group.get_seqs()[0]
num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks

def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same prompt.
seq = seq_group.seqs[0]
seq = seq_group.get_seqs()[0]

# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
@@ -99,7 +99,7 @@ def allocate(self, seq_group: SequenceGroup) -> None:
block_table.append(block)

# Assign the block table for each sequence.
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
self.block_tables[seq.seq_id] = block_table.copy()

def can_append_slot(self, seq_group: SequenceGroup) -> bool:
@@ -147,7 +147,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBl
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
block_table = self.block_tables[seq.seq_id]
@@ -168,7 +168,7 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
new_block_table: BlockTable = []
@@ -199,7 +199,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
new_block_table: BlockTable = []
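The block-manager hunks above consistently replace direct `seq_group.seqs` access with `seq_group.get_seqs()`. A minimal sketch of what that accessor could look like, assuming it takes an optional status filter as suggested by the scheduler hunks below; the actual `SequenceGroup` in `cacheflow/sequence.py` may differ:

```python
from typing import List, Optional

# Hypothetical stand-ins, shown only to make the accessor's contract concrete.
class Sequence:
    def __init__(self, seq_id: int, status: str) -> None:
        self.seq_id = seq_id
        self.status = status

class SequenceGroup:
    def __init__(self, request_id: str, seqs: List[Sequence]) -> None:
        self.request_id = request_id
        self.seqs = seqs

    def get_seqs(self, status: Optional[str] = None) -> List[Sequence]:
        # With no filter, behave like the old direct attribute access.
        if status is None:
            return self.seqs
        # Otherwise return only the sequences in the requested state,
        # e.g. SequenceStatus.RUNNING in the scheduler's update() loop.
        return [seq for seq in self.seqs if seq.status == status]
```

Filtering inside the accessor lets callers ask only for sequences in a given state, although the `swap_in`/`swap_out` hunks above still skip `SequenceStatus.FINISHED` explicitly.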
63 changes: 15 additions & 48 deletions cacheflow/core/scheduler.py
@@ -73,8 +73,6 @@ def __init__(
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
# Mapping: request_id -> num_steps.
self.num_steps: Dict[str, int] = {}
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []

@@ -84,7 +82,6 @@ def __init__(

def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
assert seq_group.request_id not in self.num_steps
self.waiting.append(seq_group)

def has_unfinished_seqs(self) -> bool:
@@ -178,7 +175,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
break

# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
break
@@ -278,15 +275,8 @@ def update(
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
for seq_group in self.running:
request_id = seq_group.request_id
self.num_steps[request_id] += 1
stop_token_ids = seq_group.sampling_params.stop_token_ids

# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
@@ -297,43 +287,27 @@ def update(
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)

# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token(output.output_token, output.logprobs)
return self.running.copy()

# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
def free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)

# Check if the sequence has reached the maximum number of steps.
max_num_steps = seq_group.sampling_params.max_tokens
if self.num_steps[request_id] == max_num_steps:
self._free_seq(seq)
continue

# Update the running sequences.
updated = self.running.copy()
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running
return updated
def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
if not seq_group.is_finished()
]

def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
if seq_group.request_id not in self.num_steps:
self.num_steps[seq_group.request_id] = 0

def _append_slot(
self,
@@ -403,13 +377,6 @@ def _preempt_by_swap(
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)

def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)

def _free_seq_group(self, seq_group: SequenceGroup) -> None:
del self.num_steps[seq_group.request_id]

def _swap_in(
self,
seq_group: SequenceGroup,
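The scheduler changes above remove the per-request `num_steps` bookkeeping and the stop-token / max-token checks from `update()`, and instead expose `free_seq()` and `free_finished_seq_groups()` so that the calling layer can decide when a sequence is done — which is where stop strings can be matched against decoded text. A rough sketch of how such a caller might use the new hooks; the `stop` field on `SamplingParams` is an assumption based on this commit's title, and the import path is assumed:

```python
from cacheflow.sequence import SequenceStatus  # import path assumed

def finish_stopped_sequences(scheduler, seq_group, sampling_params) -> None:
    """Post-process one running SequenceGroup after scheduler.update()."""
    for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
        text = seq.output_text
        # Stop-string check on the decoded text (token-id checks alone
        # cannot catch strings that span token boundaries).
        for stop_str in sampling_params.stop:
            if stop_str in text:
                # Truncate at the stop string and mark the sequence finished.
                seq.output_text = text[: text.index(stop_str)]
                scheduler.free_seq(seq)
                break
        else:
            # Length check replaces the old per-request num_steps counter.
            if len(seq.get_output_token_ids()) >= sampling_params.max_tokens:
                scheduler.free_seq(seq)
    # Drop any groups whose sequences have all finished from the running queue.
    scheduler.free_finished_seq_groups()
```

Moving these checks out of the scheduler also means completion no longer depends on `stop_token_ids` alone.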
5 changes: 3 additions & 2 deletions cacheflow/entrypoints/fastapi_server.py
@@ -123,6 +123,7 @@ async def generate_stream(request: Request):
parallel_config = server_configs[2]
distributed_init_method, stage_devices = initialize_cluster(parallel_config)

server = FastAPIServer(
args.use_ray, *server_configs, distributed_init_method, stage_devices)
server = FastAPIServer(args.use_ray, *server_configs,
distributed_init_method, stage_devices,
log_stats=not args.disable_log_stats)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
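The server entrypoint now forwards a `log_stats` flag into `FastAPIServer`. A small sketch of how the corresponding CLI switch might be declared; only `--disable-log-stats`, `--host`, and `--port` are taken from the diff, and the defaults are placeholders:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")  # placeholder default
parser.add_argument("--port", type=int, default=8000)         # placeholder default
parser.add_argument("--disable-log-stats", action="store_true",
                    help="Do not log periodic server statistics.")
args = parser.parse_args()

# Stats logging stays on unless explicitly disabled, matching
# `log_stats=not args.disable_log_stats` in the constructor call above.
log_stats = not args.disable_log_stats
```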
16 changes: 8 additions & 8 deletions cacheflow/model_executor/layers/sampler.py
@@ -283,20 +283,20 @@ def _sample_from_prompt(
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
beam_width = sampling_params.best_of
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
assert sampling_params.best_of == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Random sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
prob, num_samples=num_seqs, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids

@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# NOTE(woosuk): sampling_params.best_of can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
@@ -372,7 +372,7 @@ def _sample(
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
assert len(seq_ids) == sampling_params.best_of
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
@@ -397,7 +397,7 @@

# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprobs
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
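In the sampler, `sampling_params.n` is replaced by `sampling_params.best_of`: the number of candidate sequences sampled per prompt, which can exceed the number of completions ultimately returned (see `outputs.py` below). A condensed restatement of the prompt-phase branch logic as a standalone function, not the repository's actual `_sample_from_prompt`:

```python
from typing import List
import torch

def sample_from_prompt(prob: torch.Tensor, use_beam_search: bool,
                       temperature: float, best_of: int) -> List[int]:
    """prob is a 1-D next-token distribution over the vocabulary."""
    if use_beam_search:
        # Beam search: keep the best_of highest-probability tokens.
        _, next_token_ids = torch.topk(prob, best_of)
        return next_token_ids.tolist()
    if temperature == 0.0:
        # Greedy sampling only makes sense for a single sequence.
        assert best_of == 1
        return [torch.argmax(prob).item()]
    # Random sampling: draw best_of tokens with replacement.
    return torch.multinomial(prob, num_samples=best_of, replacement=True).tolist()

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
print(sample_from_prompt(probs, use_beam_search=True,
                         temperature=1.0, best_of=2))  # [3, 2]
```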
44 changes: 23 additions & 21 deletions cacheflow/outputs.py
@@ -1,6 +1,4 @@
from typing import Dict, List, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Dict, List

from cacheflow.sequence import SequenceGroup

@@ -9,20 +7,23 @@ class CompletionOutput:

def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprobs: float,
cumulative_logprob: float,
logprobs: List[Dict[int, float]],
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprobs = cumulative_logprobs
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs

def __repr__(self) -> str:
return (f"CompletionOutput(output={self.text!r}, "
return (f"CompletionOutput(index={self.index}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprobs={self.cumulative_logprobs}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs})")


@@ -43,31 +44,32 @@ def __init__(
self.done = done

@staticmethod
def from_seq_group(
seq_group: SequenceGroup,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> "RequestOutput":
outputs: List[CompletionOutput] = []
def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
for seq in seqs:
output_token_ids = seq.data.output_token_ids
output_str = tokenizer.decode(output_token_ids,
skip_special_tokens=True)
seq_logprobs = seq.data.cumulative_logprobs
assert n <= len(seqs)
sorted_seqs = sorted(
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
top_n_seqs = sorted_seqs[:n]

# Create the outputs.
outputs: List[CompletionOutput] = []
for seq in top_n_seqs:
logprobs = seq.output_logprobs
if seq_group.sampling_params.logprobs == 0:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = {}
output = CompletionOutput(output_str, output_token_ids,
seq_logprobs, logprobs)
output = CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
seq.get_cumulative_logprob(), logprobs)
outputs.append(output)

# Every sequence in the sequence group should have the same prompt.
prompt = seqs[0].prompt
prompt_token_ids = seqs[0].data.prompt_token_ids
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
outputs, seq_group.is_finished())

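`RequestOutput.from_seq_group` no longer decodes tokens itself (the tokenizer argument is gone and `seq.output_text` is used instead); it now ranks the `best_of` candidate sequences by cumulative log-probability and returns the top `n`. A self-contained illustration of that selection step, using a toy stand-in for `Sequence`:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FakeSeq:  # stand-in for cacheflow.sequence.Sequence
    output_text: str
    cumulative_logprob: float

    def get_cumulative_logprob(self) -> float:
        return self.cumulative_logprob

# Four candidates sampled with best_of=4.
seqs: List[FakeSeq] = [
    FakeSeq("candidate a", -1.2),
    FakeSeq("candidate b", -0.4),
    FakeSeq("candidate c", -3.0),
    FakeSeq("candidate d", -0.9),
]
n = 2  # SamplingParams.n: completions actually returned
assert n <= len(seqs)  # valid because best_of >= n
top_n_seqs = sorted(seqs, key=lambda s: s.get_cumulative_logprob(),
                    reverse=True)[:n]
print([s.output_text for s in top_n_seqs])  # ['candidate b', 'candidate d']
```

The new `CompletionOutput.index` field records each candidate's position in the original (unsorted) sequence list, so callers can correlate returned completions with their source sequences.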
