From 9bbb0a5e4a32780a05b06dfcda19855f4410a096 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 May 2023 00:52:03 +0000
Subject: [PATCH 1/8] Incremental tokenizer

---
 cacheflow/core/scheduler.py    |  2 +-
 cacheflow/sequence.py          | 11 ++++++++---
 cacheflow/server/llm_server.py | 15 +++++----------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py
index 08f855730dd6..3491b60b0ec7 100644
--- a/cacheflow/core/scheduler.py
+++ b/cacheflow/core/scheduler.py
@@ -291,7 +291,7 @@ def update(
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append_token(output.output_token, output.logprobs)
+                seq.append_token_id(output.output_token, output.logprobs)
         return self.running.copy()

     def free_seq(self, seq: Sequence) -> None:
diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py
index b2c19fae9c96..c2a6a73898b0 100644
--- a/cacheflow/sequence.py
+++ b/cacheflow/sequence.py
@@ -24,7 +24,7 @@ def __init__(
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

-    def append_token(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
         self.cumulative_logprob += logprob

@@ -64,6 +64,7 @@ def __init__(
         self.data = SequenceData(prompt_token_ids)

         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_tokens: List[str] = []
         self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
             last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

-    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token_id(
+        self,
+        token_id: int,
+        logprobs: Dict[int, float],
+    ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token(token_id, logprobs[token_id])
+        self.data.append_token_id(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 5e01bc691ebc..16381ce2a38f 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -185,17 +185,12 @@ def step(self) -> List[RequestOutput]:

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                token_id = seq.get_last_token_id()
+                token = self.tokenizer.convert_ids_to_tokens(token_id, skip_special_tokens=True)
+                seq.output_tokens.append(token)
+                seq.output_text = self.tokenizer.convert_tokens_to_string(seq.output_tokens)

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
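PATCH 1/8 replaces the per-step batch_decode of every sequence's full output with per-token decoding: each step converts only the newest token id to its token string, appends it to the sequence's output_tokens, and re-joins the accumulated token strings into text. Below is a minimal standalone sketch of that idea for a single sequence, assuming a Hugging Face tokenizer; the "gpt2" checkpoint, the toy prompt, and the variable names are illustrative assumptions, not part of the patch.

    # Sketch of the per-token decoding idea in PATCH 1/8 (assumed "gpt2" tokenizer).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    token_ids = tokenizer.encode("Incremental detokenization in action")
    output_tokens = []   # token strings accumulated so far for one sequence
    output_text = ""

    for token_id in token_ids:  # pretend the ids arrive one generation step at a time
        # Convert only the newest id to its token string ...
        new_token = tokenizer.convert_ids_to_tokens(token_id)
        output_tokens.append(new_token)
        # ... then re-join the accumulated token strings into text.
        output_text = tokenizer.convert_tokens_to_string(output_tokens)

    print(output_text)  # matches tokenizer.decode(token_ids) for this simple case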

From b35fa03ea7d7095bcfe5d24ac294000a6b514563 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 May 2023 01:38:08 +0000
Subject: [PATCH 2/8] Implement detokenize_incrementally in tokenizer_utils.py

---
 cacheflow/server/llm_server.py      | 15 ++++++----
 cacheflow/server/tokenizer_utils.py | 43 ++++++++++++++++++++++++++++-
 2 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 16381ce2a38f..f8ebef25e74e 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -14,7 +14,8 @@
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -187,10 +188,14 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Batch-decode the sequence outputs.
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                token_id = seq.get_last_token_id()
-                token = self.tokenizer.convert_ids_to_tokens(token_id, skip_special_tokens=True)
-                seq.output_tokens.append(token)
-                seq.output_text = self.tokenizer.convert_tokens_to_string(seq.output_tokens)
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index 9dbd67a4cea6..b27f9af3c82c 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import List, Tuple, Union

 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
@@ -19,3 +19,44 @@ def get_tokenizer(
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+
+
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prev_output_tokens: List[str],
+    new_token_id: int,
+    skip_special_tokens: bool,
+) -> Tuple[str, str]:
+    """Detokenizes the new token in conjunction with the previous output tokens.
+
+    NOTE: This function does not update prev_output_tokens.
+
+    Returns:
+        new_token: The new token as a string.
+        output_text: The new output text as a string.
+    """
+    new_token = tokenizer.convert_ids_to_tokens(
+        new_token_id, skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_output_tokens + [new_token]
+
+    # Convert the tokens to a string.
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        if skip_special_tokens and token in tokenizer.all_special_ids:
+            continue
+        if (hasattr(tokenizer, "added_tokens_encoder") and
+            token in tokenizer.added_tokens_encoder):
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
+    return new_token, output_text

From f777a342ac788956b94b036c0d6b85ee395991ee Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 May 2023 02:02:28 +0000
Subject: [PATCH 3/8] Optimize

---
 cacheflow/server/llm_server.py      |  2 +-
 cacheflow/server/tokenizer_utils.py | 24 +++++-------------------
 2 files changed, 6 insertions(+), 20 deletions(-)

diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index f8ebef25e74e..005b99652ef4 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -185,7 +185,7 @@ def step(self) -> List[RequestOutput]:
         return request_outputs

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 new_token, new_output_text = detokenize_incrementally(
diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index b27f9af3c82c..5931ee6c4b4d 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -40,23 +40,9 @@ def detokenize_incrementally(
     output_tokens = prev_output_tokens + [new_token]

     # Convert the tokens to a string.
-    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
-    sub_texts = []
-    current_sub_text = []
-    for token in output_tokens:
-        if skip_special_tokens and token in tokenizer.all_special_ids:
-            continue
-        if (hasattr(tokenizer, "added_tokens_encoder") and
-            token in tokenizer.added_tokens_encoder):
-            if current_sub_text:
-                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-                sub_texts.append(sub_text)
-                current_sub_text = []
-            sub_texts.append(token)
-        else:
-            current_sub_text.append(token)
-    if current_sub_text:
-        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-        sub_texts.append(sub_text)
-    output_text = " ".join(sub_texts)
+    # We optimize tokenizer._decode() by assuming that the tokenizer does not
+    # have added_tokens_encoder.
+    if hasattr(tokenizer, "added_tokens_encoder"):
+        assert not tokenizer.added_tokens_encoder
+    output_text = tokenizer.convert_tokens_to_string(output_tokens)
     return new_token, output_text

From 35ff77391d506219bd03a5bb60e188747a76d8e2 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 May 2023 05:50:43 +0000
Subject: [PATCH 4/8] Allow unoptimized path

---
 cacheflow/server/tokenizer_utils.py | 32 ++++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index 5931ee6c4b4d..ecdcbd125e08 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -40,9 +40,31 @@ def detokenize_incrementally(
     output_tokens = prev_output_tokens + [new_token]

     # Convert the tokens to a string.
-    # We optimize tokenizer._decode() by assuming that the tokenizer does not
-    # have added_tokens_encoder.
-    if hasattr(tokenizer, "added_tokens_encoder"):
-        assert not tokenizer.added_tokens_encoder
-    output_text = tokenizer.convert_tokens_to_string(output_tokens)
+    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
+    # then we can just use `convert_tokens_to_string`.
+    if not (hasattr(tokenizer, "added_tokens_encoder") and
+            tokenizer.added_tokens_encoder):
+        output_text = tokenizer.convert_tokens_to_string(output_tokens)
+        return new_token, output_text
+
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow.
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        if skip_special_tokens and token in tokenizer.all_special_ids:
+            continue
+        if (hasattr(tokenizer, "added_tokens_encoder") and
+            token in tokenizer.added_tokens_encoder):
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
     return new_token, output_text

From 5161505bc34f6b558d2c02ca02d93ecbaebe8333 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 May 2023 05:51:51 +0000
Subject: [PATCH 5/8] Warn about the performance drop when using the slow tokenizer

---
 cacheflow/server/tokenizer_utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index ecdcbd125e08..da435b9ce0bf 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -3,6 +3,10 @@
 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

+from cacheflow.logger import init_logger
+
+logger = init_logger(__name__)
+
 _MODEL_TYPES_WITH_SLOW_TOKENIZER = [
     # LLaMA fast tokenizer has a bug related to protobuf.
     # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,6 +21,14 @@ def get_tokenizer(
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        if kwargs.get("use_fast", False):
+            raise ValueError(
+                f"Cannot use the fast tokenizer for {config.model_type} due to "
+                "bugs in the fast tokenizer.")
+        logger.info(
+            f"Using the slow tokenizer for {config.model_type} due to bugs in "
+            "the fast tokenizer. This could potentially lead to performance "
+            "degradation.")
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)

From e5c49c438a946a12d62beeb4077d968875b50eea Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 24 May 2023 01:06:06 +0000
Subject: [PATCH 6/8] Simplify

---
 cacheflow/server/tokenizer_utils.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index da435b9ce0bf..be35bb151bf7 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -54,8 +54,7 @@ def detokenize_incrementally(
     # Convert the tokens to a string.
     # Optimization: If the tokenizer does not have `added_tokens_encoder`,
     # then we can just use `convert_tokens_to_string`.
-    if not (hasattr(tokenizer, "added_tokens_encoder") and
-            tokenizer.added_tokens_encoder):
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
         output_text = tokenizer.convert_tokens_to_string(output_tokens)
         return new_token, output_text

@@ -66,8 +65,7 @@ def detokenize_incrementally(
     for token in output_tokens:
         if skip_special_tokens and token in tokenizer.all_special_ids:
             continue
-        if (hasattr(tokenizer, "added_tokens_encoder") and
-            token in tokenizer.added_tokens_encoder):
+        if token in tokenizer.added_tokens_encoder:
             if current_sub_text:
                 sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                 sub_texts.append(sub_text)

From 4722c15e011b762ccf6bd633d55972047b21ba5c Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 24 May 2023 01:18:22 +0000
Subject: [PATCH 7/8] Add comments

---
 cacheflow/server/tokenizer_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index be35bb151bf7..f060b89a5db0 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -53,13 +53,14 @@ def detokenize_incrementally(

     # Convert the tokens to a string.
     # Optimization: If the tokenizer does not have `added_tokens_encoder`,
-    # then we can just use `convert_tokens_to_string`.
+    # then we can directly use `convert_tokens_to_string`.
     if not getattr(tokenizer, "added_tokens_encoder", {}):
         output_text = tokenizer.convert_tokens_to_string(output_tokens)
         return new_token, output_text

     # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
-    # NOTE(woosuk): The following code is slow.
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens.
     sub_texts = []
     current_sub_text = []
     for token in output_tokens:

From f8ee1a35e84d3fa79227632686a8e7f01e2eb71d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 24 May 2023 01:19:51 +0000
Subject: [PATCH 8/8] Minor

---
 cacheflow/server/tokenizer_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index f060b89a5db0..6e12249952bd 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -60,7 +60,8 @@ def detokenize_incrementally(

     # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
     # NOTE(woosuk): The following code is slow because it runs a for loop over
-    # the output_tokens.
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
     sub_texts = []
     current_sub_text = []
     for token in output_tokens:
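Taken together, after PATCH 8/8 the server decodes one new token per step through detokenize_incrementally: the helper converts the newest id to a token string, then takes the fast convert_tokens_to_string path when the tokenizer has no added_tokens_encoder entries and otherwise falls back to the slower, added-token-aware loop. Below is a hedged usage sketch that drives the helper the way LLMServer._decode_sequences does; it assumes the cacheflow package from this series is importable, and the "gpt2" model name and toy id stream are illustrative assumptions.

    # Usage sketch (assumptions: cacheflow importable, "gpt2" tokenizer, toy id stream).
    from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                                  detokenize_incrementally)

    tokenizer = get_tokenizer("gpt2")

    output_tokens = []   # mirrors Sequence.output_tokens
    output_text = ""     # mirrors Sequence.output_text

    # Feed the ids one by one, as LLMServer._decode_sequences does each step.
    for new_token_id in tokenizer.encode("Hello, incremental world!"):
        new_token, output_text = detokenize_incrementally(
            tokenizer,
            output_tokens,
            new_token_id,
            skip_special_tokens=True,
        )
        output_tokens.append(new_token)  # the helper never mutates prev_output_tokens

    print(output_text)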