[Frontend] Bad words sampling parameter #9717
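This PR adds a `bad_words` option to `SamplingParams`: any of the listed words (including multi-token phrases) is prevented from appearing in the generated text. A minimal usage sketch, assuming the offline `LLM` entry point; the model name, prompt, and banned words below are illustrative only, not taken from the PR:

```python
from vllm import LLM, SamplingParams

# Illustrative example: ban the single word "you" and the two-token
# phrase "years old" from ever being generated.
llm = LLM(model="openai-community/gpt2")
params = SamplingParams(
    temperature=0,
    max_tokens=16,
    bad_words=["you", "years old"],  # parameter introduced by this PR
)

outputs = llm.generate(["How old are you? I am 10"], sampling_params=params)
print(outputs[0].outputs[0].text)
```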
Merged: DarkLight1337 merged 49 commits into vllm-project:main from compressa-ai:feature/bad_words_ids2 on Oct 26, 2024.
Commits (49, all by Alvant):

- feabb97: add bad words logits processor
- f80b7d6: rename processor, add docstring
- a152e66: add test for bad words
- 9f21847: add tests (versions with and without vllm runner)
- c577c47: keep only vllm runner tests
- b372c17: Merge branch 'main' into feature/bad_words_ids
- e6bdadb: run yapf and ruff
- a69329f: run yapf and ruff again
- b9f8a5d: search for bad sequence among all generateds
- ecefbf4: run yapf and ruff
- 05671cd: fix test for two bad tokens
- 44bd494: fix unused imports and vars
- 0fc7974: fix style
- e1a18f1: fix format
- 5ab80a4: refine test for two token word
- f838046: bad_words_ids -> bad_words (sync engine)
- 32f2e20: fix bad words and test
- f93c4d7: fix for llama tokenizer, add llama in test
- 16c2dd4: run yapf and ruff
- ad0d61c: add comment about two models
- 3389770: Merge pull request #1 from compressa-ai/feature/bad_words
- 9b1a1ac: fix style
- 924ed79: Merge pull request #2 from compressa-ai/feature/bad_words
- ea4e02a: clarify comment about prefixes
- be5d5c3: Merge branch 'main' into feature/bad_words_ids
- bd86123: move logits stuff to logits file
- 308e76a: run yapf and ruff, fix import order
- aae2ac5: clarify add prefix logic
- 9724dbc: fix is match ckeck for different type sequences
- 1f0938c: add process params to async engine
- f2673a5: change type to tuple in bad words ids processor
- 8a6e88b: fix type for logits process to pass checks
- 2f2ea06: Merge branch 'main' into feature/bad_words_ids
- 7c0c60c: fix code style
- 136937b: init bad words as empty list, add in repr
- cdfce02: fix bad words post init
- 3478d6d: fix code len style
- 0187b3d: fix yapf style
- 8eeb5ad: remove async preproc params
- 4de1e23: move bad words creation logic to build_logits_processors
- f0fbadd: fix style
- b185e5b: unify logit processor creation logic (add getter for bad words)
- 0232f72: fix import for openai logits
- e37fa23: fix style (ruff or other)
- eed58a7: move all bad words logits processor creation to separate file
- 4c6f54d: fix style (one of them)
- 64a3b80: handle case of mistral tokenizer in bad words logits
- 7edca9e: simplify bad words code a bit
- 1828554: Merge branch 'main' into feature/bad_words_ids2
New file `tests/samplers/test_no_bad_words.py` (+185 lines):

```python
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.

"""
from typing import List, Optional

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[List[str]] = None,
) -> List[int]:
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # [([output_token_ids, ], [output_text, ]), ]
    output = model.generate([prompt], sampling_params=sampling_params)

    output_token_ids = output[0][0][0][num_prompt_tokens:]
    # [0] first (and only) request output
    # [0] token_ids (not text)
    # [0] first (and only) output completion

    return output_token_ids


class TestOneTokenBadWord:
    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                            add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids


class TestTwoTokenBadWord:
    # Another model (with a different tokenizer behaviour)
    MODEL = "openai-community/gpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: List[int], subsequence: List[int]) -> bool:
        searched = False

        for start in range(len(sequence)):
            end = start + len(subsequence)
            current_subsequence = sequence[start:end]

            if len(current_subsequence) < len(subsequence):
                continue

            searched = True

            assert len(current_subsequence) == len(subsequence)

            if current_subsequence == subsequence:
                return True

        assert searched, "All subsequences did not match in length..."

        return False

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids
```
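The two test classes above exercise a tokenizer subtlety: with BPE-style tokenizers, the same word usually maps to different token ids depending on whether it is preceded by a space, so a bad word has to be banned in both forms. A small illustrative sketch (the printed ids are tokenizer dependent and not asserted here):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

# "old" at the start of a text and " old" inside a sentence typically
# encode to different token ids with a BPE tokenizer.
print(tokenizer.encode("old", add_special_tokens=False))
print(tokenizer.encode(" old", add_special_tokens=False))
```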
New file implementing the bad-words logits processors (+119 lines):

```python
from typing import Callable, List, Tuple, Union

import torch

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                        Callable[[List[int], List[int], torch.Tensor],
                                 torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    bad_words_ids: List[List[int]] = list()

    for bad_word in bad_words:
        # To prohibit words both at the beginning
        # and in the middle of text
        # (related to add_prefix_space tokenizer parameter)
        for add_prefix_space in [False, True]:
            prefix = " " if add_prefix_space else ""
            prompt = prefix + bad_word.lstrip()

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers should not add special tokens
                prompt_token_ids = tokenizer.encode(prompt=prompt)
            else:
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

            # If no space at the beginning
            # or if prefix space produces a new word token
            if (not add_prefix_space) or (
                    add_prefix_space
                    and prompt_token_ids[0] != bad_words_ids[-1][0]
                    and len(prompt_token_ids) == len(bad_words_ids[-1])):
                bad_words_ids.append(prompt_token_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]


class NoBadWordsLogitsProcessor:
    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        self.bad_words_ids = bad_words_ids
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        last_token_bias = torch.zeros_like(logits)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:  # 1-token words already processed
                continue

            if len(bad_word_ids) > len(past_tokens_ids) + 1:
                continue

            prefix_length = len(bad_word_ids) - 1
            last_token_id = bad_word_ids[-1]
            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]

            assert len(actual_prefix) == len(expected_prefix)

            is_match = tuple(actual_prefix) == tuple(expected_prefix)
            last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
                                               else self._NEUTRAL_LOGIT)

        logits = logits + self.word_bias + last_token_bias

        return logits

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor  # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]

        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:
                bad_word_id = bad_word_ids[-1]
                self.word_bias[bad_word_id] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        invalid_token_ids = []

        for bad_word_ids in self.bad_words_ids:
            for token_id in bad_word_ids:
                if token_id < 0 or token_id >= vocab_size:
                    invalid_token_ids.append(token_id)

        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")
```