Add repetition_penalty, top_k, top_p to Completion #295
Changes from all commits: b9f8766, 1847fa2, e5de486, 3d845e2, ffd9341, 63d174c, aa64673, 4b1b4fb, c541555
```diff
@@ -8,7 +8,7 @@
 import math
 import os
 from dataclasses import asdict
-from typing import Any, AsyncIterable, Dict, List, Optional
+from typing import Any, AsyncIterable, Dict, List, Optional, Union
 from uuid import uuid4

 from model_engine_server.common.config import hmi_config
```

```diff
@@ -839,6 +839,54 @@ def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]:
     return tokens


+def validate_and_update_completion_params(
+    inference_framework: LLMInferenceFramework,
+    request: Union[CompletionSyncV1Request, CompletionStreamV1Request],
```

Review thread on the typing of this helper:

**Contributor:** btw, if the type checker is still giving you trouble, maybe user-defined generic types (https://docs.python.org/3.8/library/typing.html#user-defined-generic-types) would help? At least this feels like the "proper" way to do things to me.

**Contributor (author):** I was able to pass the check if I call the function like this

**Contributor:** nit: would bias against

**Contributor:** does something like

**Contributor (author):** Doesn't work... I think the problem is that if the return type is a Union, callers will have to do some type narrowing: https://mypy.readthedocs.io/en/stable/type_narrowing.html

**Contributor:** you can use generics: https://mypy.readthedocs.io/en/stable/generics.html
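As a side note on the generics suggestion above, here is a minimal sketch of what that signature could look like, assuming `CompletionSyncV1Request` and `CompletionStreamV1Request` are the Pydantic-style request classes already imported in this module. A constrained `TypeVar` lets mypy return the same concrete type that was passed in, so the `isinstance` narrowing done later in the diff would not be needed:

```python
from typing import TypeVar

# Hypothetical sketch, not code from this PR: constrain the TypeVar to the two
# request types so the validated request keeps its concrete type at call sites.
RequestT = TypeVar("RequestT", CompletionSyncV1Request, CompletionStreamV1Request)


def validate_and_update_completion_params(
    inference_framework: LLMInferenceFramework,
    request: RequestT,
) -> RequestT:
    # ... same validation and normalization body as in the diff below ...
    return request
```

With that signature, mypy would infer `CompletionSyncV1Request` at the sync call site and `CompletionStreamV1Request` at the streaming call site, which is essentially the narrowing the author ends up doing by hand with `isinstance` further down.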
```diff
+) -> Union[CompletionSyncV1Request, CompletionStreamV1Request]:
+    # top_k, top_p
+    if inference_framework in [
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
+        LLMInferenceFramework.VLLM,
+        LLMInferenceFramework.LIGHTLLM,
+    ]:
+        if request.temperature == 0:
+            if request.top_k not in [-1, None] or request.top_p not in [1.0, None]:
+                raise ObjectHasInvalidValueException(
+                    "top_k and top_p can't be enabled when temperature is 0."
+                )
+        if request.top_k == 0:
+            raise ObjectHasInvalidValueException(
+                "top_k needs to be strictly positive, or set it to be -1 / None to disable top_k."
+            )
+        if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
+            request.top_k = None if request.top_k == -1 else request.top_k
+            request.top_p = None if request.top_p == 1.0 else request.top_p
+        if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+            request.top_k = -1 if request.top_k is None else request.top_k
+            request.top_p = 1.0 if request.top_p is None else request.top_p
+    else:
+        if request.top_k or request.top_p:
+            raise ObjectHasInvalidValueException(
+                "top_k and top_p are only supported in text-generation-inference, vllm, lightllm."
+            )
+
+    # presence_penalty, frequency_penalty
+    if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+        request.presence_penalty = (
+            0.0 if request.presence_penalty is None else request.presence_penalty
+        )
+        request.frequency_penalty = (
+            0.0 if request.frequency_penalty is None else request.frequency_penalty
+        )
+    else:
+        if request.presence_penalty or request.frequency_penalty:
+            raise ObjectHasInvalidValueException(
+                "presence_penalty and frequency_penalty are only supported in vllm, lightllm."
+            )
+
+    return request
+
+
 class CompletionSyncV1UseCase:
     """
     Use case for running a prompt completion on an LLM endpoint.
```

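To make the normalization rules above concrete, here is an illustrative sketch of how the helper behaves, assuming the request models accept these fields as keyword arguments (the calls and values are hypothetical, not tests from the PR):

```python
# vLLM / LightLLM: unset top_k / top_p / penalties are normalized to the
# frameworks' "disabled" sentinels (-1, 1.0, 0.0).
req = CompletionSyncV1Request(prompt="hi", max_new_tokens=8, temperature=0.7)
req = validate_and_update_completion_params(LLMInferenceFramework.VLLM, req)
assert req.top_k == -1 and req.top_p == 1.0 and req.presence_penalty == 0.0

# text-generation-inference: the sentinel values are translated back to None.
req = CompletionSyncV1Request(prompt="hi", max_new_tokens=8, temperature=0.7, top_k=-1, top_p=1.0)
req = validate_and_update_completion_params(LLMInferenceFramework.TEXT_GENERATION_INFERENCE, req)
assert req.top_k is None and req.top_p is None

# Greedy decoding (temperature == 0) rejects an explicit top_k / top_p.
req = CompletionSyncV1Request(prompt="hi", max_new_tokens=8, temperature=0.0, top_k=5)
validate_and_update_completion_params(LLMInferenceFramework.VLLM, req)  # raises ObjectHasInvalidValueException
```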
```diff
@@ -983,6 +1031,15 @@
                 endpoint_id=model_endpoint.record.id
             )
         endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
+        validated_request = validate_and_update_completion_params(
+            endpoint_content.inference_framework, request
+        )
+        if not isinstance(validated_request, CompletionSyncV1Request):
+            raise ValueError(
+                f"request has type {validated_request.__class__.__name__}, expected type CompletionSyncV1Request"
+            )
+        request = validated_request

         if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
             args: Any = {
                 "prompts": [request.prompt],
```

```diff
@@ -1036,6 +1093,10 @@
             if request.temperature > 0:
                 tgi_args["parameters"]["temperature"] = request.temperature
                 tgi_args["parameters"]["do_sample"] = True
+                tgi_args["parameters"]["top_k"] = request.top_k
+                tgi_args["parameters"]["top_p"] = request.top_p
+            else:
+                tgi_args["parameters"]["do_sample"] = False

             inference_request = SyncEndpointPredictV1Request(
                 args=tgi_args,
```

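For context, a sketch of the kind of `tgi_args` payload this code builds for a sampling request. The field names follow text-generation-inference's generate API; the concrete values are made up for illustration and are not from the PR:

```python
# Hypothetical example payload for a text-generation-inference endpoint
# when temperature > 0.
tgi_args = {
    "inputs": "Tell me a joke about robots.",
    "parameters": {
        "max_new_tokens": 64,
        "temperature": 0.8,
        "do_sample": True,
        "top_k": 40,    # validated to be strictly positive, or None to disable
        "top_p": 0.95,  # None disables nucleus sampling for TGI
    },
}
```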
```diff
@@ -1064,10 +1125,15 @@
             vllm_args: Any = {
                 "prompt": request.prompt,
                 "max_tokens": request.max_new_tokens,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
             }
             if request.stop_sequences is not None:
                 vllm_args["stop"] = request.stop_sequences
             vllm_args["temperature"] = request.temperature
+            if request.temperature > 0:
+                vllm_args["top_k"] = request.top_k
+                vllm_args["top_p"] = request.top_p
             if request.return_token_log_probs:
                 vllm_args["logprobs"] = 1
```

```diff
@@ -1098,12 +1164,16 @@
                 "inputs": request.prompt,
                 "parameters": {
                     "max_new_tokens": request.max_new_tokens,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
                 },
             }
             # TODO: implement stop sequences
             if request.temperature > 0:
                 lightllm_args["parameters"]["temperature"] = request.temperature
                 lightllm_args["parameters"]["do_sample"] = True
+                lightllm_args["top_k"] = request.top_k
+                lightllm_args["top_p"] = request.top_p
             else:
                 lightllm_args["parameters"]["do_sample"] = False
             if request.return_token_log_probs:
```

```diff
@@ -1172,6 +1242,7 @@
         request_id = str(uuid4())
         add_trace_request_id(request_id)

         model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints(
             owner=user.team_id, name=model_endpoint_name, order_by=None
         )
```

```diff
@@ -1209,6 +1280,14 @@
             )

         model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
+        validated_request = validate_and_update_completion_params(
+            model_content.inference_framework, request
+        )
+        if not isinstance(validated_request, CompletionStreamV1Request):
+            raise ValueError(
+                f"request has type {validated_request.__class__.__name__}, expected type CompletionStreamV1Request"
+            )
+        request = validated_request

         args: Any = None
         if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
```

```diff
@@ -1237,14 +1316,23 @@
             if request.temperature > 0:
                 args["parameters"]["temperature"] = request.temperature
                 args["parameters"]["do_sample"] = True
+                args["parameters"]["top_k"] = request.top_k
+                args["parameters"]["top_p"] = request.top_p
+            else:
+                args["parameters"]["do_sample"] = False
         elif model_content.inference_framework == LLMInferenceFramework.VLLM:
             args = {
                 "prompt": request.prompt,
                 "max_tokens": request.max_new_tokens,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
             }
             if request.stop_sequences is not None:
                 args["stop"] = request.stop_sequences
             args["temperature"] = request.temperature
+            if request.temperature > 0:
+                args["top_k"] = request.top_k
+                args["top_p"] = request.top_p
             if request.return_token_log_probs:
                 args["logprobs"] = 1
             args["stream"] = True
```

```diff
@@ -1253,12 +1341,16 @@
                 "inputs": request.prompt,
                 "parameters": {
                     "max_new_tokens": request.max_new_tokens,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
                 },
             }
             # TODO: stop sequences
             if request.temperature > 0:
                 args["parameters"]["temperature"] = request.temperature
                 args["parameters"]["do_sample"] = True
+                args["parameters"]["top_k"] = request.top_k
+                args["parameters"]["top_p"] = request.top_p
             else:
                 args["parameters"]["do_sample"] = False
             if request.return_token_log_probs:
```

A second review thread, on the documentation link for the new penalty parameters:

**Contributor:** Feels a bit weird to be linking to OpenAI docs?

**Contributor (author):** vLLM didn't link sources for them, but Google search results say presence_penalty and frequency_penalty are from OpenAI, and the range matches too, so I linked it here.
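For background on that discussion: OpenAI's API reference defines presence_penalty as a flat penalty applied once a token has appeared in the text so far, and frequency_penalty as a penalty that scales with how many times the token has appeared, both accepted in the range -2.0 to 2.0. A rough, self-contained sketch of that logit adjustment (a paraphrase of the documented formula, not code from this PR or from vLLM):

```python
from collections import Counter
from typing import Dict, List


def apply_penalties(
    logits: Dict[int, float],
    generated_token_ids: List[int],
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
) -> Dict[int, float]:
    """OpenAI-style penalty adjustment:
    logit[t] -= frequency_penalty * count(t) + presence_penalty * (count(t) > 0)
    """
    counts = Counter(generated_token_ids)
    adjusted = dict(logits)
    for token_id, count in counts.items():
        if token_id in adjusted:
            adjusted[token_id] -= frequency_penalty * count + presence_penalty
    return adjusted
```

Positive values push the model away from tokens it has already produced, reducing repetition; negative values encourage repetition.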