diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py
index 795e576edf89..1085e839c434 100644
--- a/cacheflow/core/scheduler.py
+++ b/cacheflow/core/scheduler.py
@@ -291,7 +291,7 @@ def update(
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append_token(output.output_token, output.logprobs)
+                seq.append_token_id(output.output_token, output.logprobs)
         return self.running.copy()

     def free_seq(self, seq: Sequence) -> None:
diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py
index 61c5091bfb16..02c5970ec681 100644
--- a/cacheflow/sequence.py
+++ b/cacheflow/sequence.py
@@ -24,7 +24,7 @@ def __init__(
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

-    def append_token(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
         self.cumulative_logprob += logprob

@@ -64,6 +64,7 @@ def __init__(

         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_tokens: List[str] = []
         self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
             last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

-    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token_id(
+        self,
+        token_id: int,
+        logprobs: Dict[int, float],
+    ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token(token_id, logprobs[token_id])
+        self.data.append_token_id(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 5e01bc691ebc..005b99652ef4 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -14,7 +14,8 @@
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -184,18 +185,17 @@ def step(self) -> List[RequestOutput]:
         return request_outputs

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index 9dbd67a4cea6..6e12249952bd 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -1,8 +1,12 @@
-from typing import Union
+from typing import List, Tuple, Union

 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

+from cacheflow.logger import init_logger
+
+logger = init_logger(__name__)
+
 _MODEL_TYPES_WITH_SLOW_TOKENIZER = [
     # LLaMA fast tokenizer has a bug related to protobuf.
     # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,5 +21,62 @@ def get_tokenizer(
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        if kwargs.get("use_fast", False):
+            raise ValueError(
+                f"Cannot use the fast tokenizer for {config.model_type} due to "
+                "bugs in the fast tokenizer.")
+        logger.info(
+            f"Using the slow tokenizer for {config.model_type} due to bugs in "
+            "the fast tokenizer. This could potentially lead to performance "
+            "degradation.")
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+
+
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prev_output_tokens: List[str],
+    new_token_id: int,
+    skip_special_tokens: bool,
+) -> Tuple[str, str]:
+    """Detokenizes the new token in conjunction with the previous output tokens.
+
+    NOTE: This function does not update prev_output_tokens.
+
+    Returns:
+        new_token: The new token as a string.
+        output_text: The new output text as a string.
+    """
+    new_token = tokenizer.convert_ids_to_tokens(
+        new_token_id, skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_output_tokens + [new_token]
+
+    # Convert the tokens to a string.
+    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
+    # then we can directly use `convert_tokens_to_string`.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        output_text = tokenizer.convert_tokens_to_string(output_tokens)
+        return new_token, output_text
+
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        if skip_special_tokens and token in tokenizer.all_special_ids:
+            continue
+        if token in tokenizer.added_tokens_encoder:
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
+    return new_token, output_text
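
For illustration only (not part of the diff): a minimal sketch of how the new detokenize_incrementally helper is driven one token per step, mirroring the loop in _decode_sequences. It assumes the cacheflow package from this change is importable and a HuggingFace tokenizer can be downloaded; the model name and prompt are placeholders, and the token ids come from encoding a fixed string rather than from sampling.

    from transformers import AutoTokenizer

    from cacheflow.server.tokenizer_utils import detokenize_incrementally

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # Pretend these ids arrive one per scheduler step, like sampled tokens.
    step_token_ids = tokenizer.encode("Hello, incremental world!",
                                      add_special_tokens=False)

    output_tokens = []  # plays the role of seq.output_tokens
    output_text = ""
    for token_id in step_token_ids:
        # Only the newest token id is converted; previously decoded token
        # strings are reused instead of re-decoding the whole sequence.
        new_token, output_text = detokenize_incrementally(
            tokenizer,
            output_tokens,
            token_id,
            skip_special_tokens=True,
        )
        output_tokens.append(new_token)

    print(output_text)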