diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 743454301838..c8d7164d3b4c 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -650,6 +650,9 @@ def _check_stop(self, seq: Sequence,
                 seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
+        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
 
         # Check if the sequence has reached max_model_len.
         if seq.get_len() > self.scheduler_config.max_model_len:
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index bc827b344195..1cffeed91ddb 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -217,6 +217,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             temperature=request.temperature,
             top_p=request.top_p,
             stop=request.stop,
+            stop_token_ids=request.stop_token_ids,
             max_tokens=request.max_tokens,
             best_of=request.best_of,
             top_k=request.top_k,
@@ -425,6 +426,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             top_p=request.top_p,
             top_k=request.top_k,
             stop=request.stop,
+            stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 701f704234ad..b45e4146e29f 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -70,6 +70,7 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class CompletionRequest(BaseModel):
@@ -94,6 +95,7 @@ class CompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class LogProbs(BaseModel):
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index d24404b8d3e0..85fb503aa3d8 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -45,6 +45,9 @@ class SamplingParams:
             (canonical beam search algorithm).
         stop: List of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
+        stop_token_ids: List of tokens that stop the generation when they are
+            generated. The returned output will contain the stop tokens unless
+            the stop tokens are special tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -64,6 +67,7 @@ def __init__(
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -84,6 +88,10 @@ def __init__(
             self.stop = [stop]
         else:
             self.stop = list(stop)
+        if stop_token_ids is None:
+            self.stop_token_ids = []
+        else:
+            self.stop_token_ids = list(stop_token_ids)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
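
For reference, a minimal usage sketch (not part of the diff) of how the new stop_token_ids parameter is exercised through vLLM's offline LLM API. The model name and the token id below are placeholders chosen only for illustration:

# Usage sketch, assuming the diff above is applied. Model name and token id
# are placeholders; substitute values appropriate to your tokenizer/vocabulary.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

# Stop as soon as the sampler emits any of these token ids. Unlike the `stop`
# strings, the matched stop token is kept in the returned text unless it is
# a special token (per the SamplingParams docstring above).
sampling_params = SamplingParams(
    temperature=0.8,
    max_tokens=64,
    stop_token_ids=[50118],  # hypothetical token id, e.g. a newline token
)

outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)

The same field is accepted by the OpenAI-compatible server, since both create_chat_completion and create_completion now forward request.stop_token_ids into SamplingParams.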