64 changes: 64 additions & 0 deletions clients/python/llmengine/completion.py
@@ -33,6 +33,10 @@ async def acreate(
temperature: float = 0.2,
stop_sequences: Optional[List[str]] = None,
return_token_log_probs: Optional[bool] = False,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
@@ -72,6 +76,26 @@ async def acreate(
Whether to return the log probabilities of generated tokens.
When True, the response will include a list of tokens and their log probabilities.

presence_penalty (Optional[float]):
Only supported in vllm, lightllm
Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
https://platform.openai.com/docs/guides/gpt/parameter-details
Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.

frequency_penalty (Optional[float]):
Only supported in vllm, lightllm
Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
https://platform.openai.com/docs/guides/gpt/parameter-details

Review thread:

Member:
Feels a bit weird to be linking to OpenAI docs?

Contributor Author:
vllm didn't link sources for these parameters, but presence_penalty and frequency_penalty appear to come from OpenAI and the ranges match, so I linked the OpenAI parameter details here.

Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.

top_k (Optional[int]):
Integer that controls the number of top tokens to consider.
Range: [1, infinity). -1 means consider all tokens.

top_p (Optional[float]):
Float that controls the cumulative probability of the top tokens to consider.
Range: (0.0, 1.0]. 1.0 means consider all tokens.

timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
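
For illustration, a minimal sketch of an async streaming call that exercises the new sampling parameters documented above. The endpoint name is a placeholder, the usual llmengine client setup (API key) is assumed, and the parameters outside this hunk (model, prompt, max_new_tokens, stream) follow the existing acreate signature:

import asyncio

from llmengine import Completion


async def main() -> None:
    # "llama-2-7b" is a hypothetical endpoint name; substitute a deployed model endpoint.
    stream = await Completion.acreate(
        model="llama-2-7b",
        prompt="Why is the sky blue?",
        max_new_tokens=64,
        temperature=0.7,        # must be > 0 for top_k / top_p to take effect
        presence_penalty=0.4,   # [0.0, 2.0]; vllm and lightllm only
        frequency_penalty=0.4,  # [0.0, 2.0]; vllm and lightllm only
        top_k=50,               # -1 or None disables top-k filtering
        top_p=0.9,              # (0.0, 1.0]; 1.0 considers all tokens
        stream=True,
    )
    async for response in stream:
        if response.output:
            print(response.output.text, end="", flush=True)


asyncio.run(main())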

@@ -164,6 +188,10 @@ async def _acreate_stream(
temperature=temperature,
stop_sequences=stop_sequences,
return_token_log_probs=return_token_log_probs,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
timeout=timeout,
)

@@ -184,6 +212,10 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse:
temperature=temperature,
stop_sequences=stop_sequences,
return_token_log_probs=return_token_log_probs,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
)

@classmethod
@@ -195,6 +227,10 @@ def create(
temperature: float = 0.2,
stop_sequences: Optional[List[str]] = None,
return_token_log_probs: Optional[bool] = False,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
@@ -235,6 +271,26 @@ def create(
Whether to return the log probabilities of generated tokens.
When True, the response will include a list of tokens and their log probabilities.

presence_penalty (Optional[float]):
Only supported in vllm, lightllm
Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
https://platform.openai.com/docs/guides/gpt/parameter-details
Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.

frequency_penalty (Optional[float]):
Only supported in vllm, lightllm
Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
https://platform.openai.com/docs/guides/gpt/parameter-details
Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.

top_k (Optional[int]):
Integer that controls the number of top tokens to consider.
Range: [1, infinity). -1 means consider all tokens.

top_p (Optional[float]):
Float that controls the cumulative probability of the top tokens to consider.
Range: (0.0, 1.0]. 1.0 means consider all tokens.

timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
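
For intuition about how top_k and top_p interact, here is a toy filtering sketch over an already-normalized distribution (illustrative only; it is not the sampler any of the backends actually use):

from typing import List


def surviving_tokens(probs: List[float], top_k: int, top_p: float) -> List[int]:
    """Indices that survive top-k truncation followed by top-p (nucleus) truncation."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    kept: List[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


# Probabilities 0.5 / 0.3 / 0.1 / 0.1 with top_k=3 and top_p=0.9 keep the three most likely tokens.
print(surviving_tokens([0.5, 0.3, 0.1, 0.1], top_k=3, top_p=0.9))  # [0, 1, 2]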

@@ -317,6 +373,10 @@ def _create_stream(**kwargs):
temperature=temperature,
stop_sequences=stop_sequences,
return_token_log_probs=return_token_log_probs,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
)

else:
@@ -326,6 +386,10 @@ def _create_stream(**kwargs):
temperature=temperature,
stop_sequences=stop_sequences,
return_token_log_probs=return_token_log_probs,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
).dict()
response = cls.post_sync(
resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",
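
The synchronous path is analogous; again a sketch with a placeholder endpoint name and the standard client configuration assumed:

from llmengine import Completion

response = Completion.create(
    model="llama-2-7b",           # hypothetical endpoint name
    prompt="Summarize the theory of relativity in one sentence.",
    max_new_tokens=64,
    temperature=0.7,
    presence_penalty=0.5,
    frequency_penalty=0.5,
    top_k=40,
    top_p=0.95,
    return_token_log_probs=True,  # the response will also carry per-token log probabilities
)
print(response.output.text)
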
8 changes: 8 additions & 0 deletions clients/python/llmengine/data_types.py
@@ -269,6 +269,10 @@ class CompletionSyncV1Request(BaseModel):
temperature: float = Field(..., ge=0.0)
stop_sequences: Optional[List[str]] = Field(default=None)
return_token_log_probs: Optional[bool] = Field(default=False)
presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=-1)
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)


class TokenOutput(BaseModel):
@@ -330,6 +334,10 @@ class CompletionStreamV1Request(BaseModel):
temperature: float = Field(..., ge=0.0)
stop_sequences: Optional[List[str]] = Field(default=None)
return_token_log_probs: Optional[bool] = Field(default=False)
presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=-1)
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)


class CompletionStreamOutput(BaseModel):
40 changes: 38 additions & 2 deletions model-engine/model_engine_server/common/dtos/llms.py
@@ -104,7 +104,7 @@ class CompletionSyncV1Request(BaseModel):

prompt: str
max_new_tokens: int
temperature: float = Field(ge=0, le=1)
temperature: float = Field(ge=0.0, le=1.0)
"""
Temperature of the sampling. Setting to 0 equals to greedy sampling.
"""
@@ -116,6 +116,24 @@ class CompletionSyncV1Request(BaseModel):
"""
Whether to return the log probabilities of the tokens.
"""
presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
"""
Only supported in vllm, lightllm
Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty
"""
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
"""
Only supported in vllm, lightllm
Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty
"""
top_k: Optional[int] = Field(default=None, ge=-1)
"""
Controls the number of top tokens to consider. -1 means consider all tokens.
"""
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
"""
Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
"""


class TokenOutput(BaseModel):
Expand Down Expand Up @@ -145,7 +163,7 @@ class CompletionStreamV1Request(BaseModel):

prompt: str
max_new_tokens: int
temperature: float = Field(ge=0, le=1)
temperature: float = Field(ge=0.0, le=1.0)
"""
Temperature of the sampling. Setting to 0 equals to greedy sampling.
"""
@@ -157,6 +175,24 @@ class CompletionStreamV1Request(BaseModel):
"""
Whether to return the log probabilities of the tokens. Only affects behavior for text-generation-inference models
"""
presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
"""
Only supported in vllm, lightllm
Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty
"""
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
"""
Only supported in vllm, lightllm
Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty
"""
top_k: Optional[int] = Field(default=None, ge=-1)
"""
Controls the number of top tokens to consider. -1 means consider all tokens.
"""
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
"""
Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
"""


class CompletionStreamOutput(BaseModel):
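
The Field constraints above are enforced when the request model is constructed; a quick sketch of the failure mode for an out-of-range penalty (the import path follows this file's location in the repo):

from pydantic import ValidationError

from model_engine_server.common.dtos.llms import CompletionSyncV1Request

try:
    CompletionSyncV1Request(
        prompt="Hello",
        max_new_tokens=16,
        temperature=0.2,
        presence_penalty=2.5,  # violates le=2.0
    )
except ValidationError as err:
    # pydantic reports that presence_penalty must be less than or equal to 2.0
    print(err)
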
@@ -8,7 +8,7 @@
import math
import os
from dataclasses import asdict
from typing import Any, AsyncIterable, Dict, List, Optional
from typing import Any, AsyncIterable, Dict, List, Optional, Union
from uuid import uuid4

from model_engine_server.common.config import hmi_config
@@ -839,6 +839,54 @@ def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]:
return tokens


def validate_and_update_completion_params(
inference_framework: LLMInferenceFramework,
request: Union[CompletionSyncV1Request, CompletionStreamV1Request],

Review thread:

Contributor:
By the way, if the type checker is still giving you trouble, maybe user-defined generic types would help (https://docs.python.org/3.8/library/typing.html#user-defined-generic-types)? At least this feels like the "proper" way to do it to me.

Contributor Author (@francesy-scale, Sep 29, 2023):
I was able to pass the check by calling the function like this:

new_request = validate_and_update_completion_params(endpoint_content.inference_framework, request)
assert isinstance(new_request, CompletionSyncV1Request)
request = new_request

Contributor:
nit: I would bias against asserts in production code; let's convert it to an if-statement that raises a ValueError.

Contributor:
Does something like new_request: CompletionSyncV1Request = validate... also work? Basically another way of telling the type checker that you know it will be a CompletionSyncV1Request.

Contributor Author:
That doesn't work... I think the problem is that when the return type is a Union, some type narrowing is needed: https://mypy.readthedocs.io/en/stable/type_narrowing.html

Contributor:
You can use generics (https://mypy.readthedocs.io/en/stable/generics.html):

from typing import TypeVar

T = TypeVar('T')

def validate_and_update_completion_params(
    inference_framework: LLMInferenceFramework,
    request: T,
):
    ...

(A fuller sketch of this TypeVar approach appears after the function below.)

) -> Union[CompletionSyncV1Request, CompletionStreamV1Request]:
# top_k, top_p
if inference_framework in [
LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
LLMInferenceFramework.VLLM,
LLMInferenceFramework.LIGHTLLM,
]:
if request.temperature == 0:
if request.top_k not in [-1, None] or request.top_p not in [1.0, None]:
raise ObjectHasInvalidValueException(
"top_k and top_p can't be enabled when temperature is 0."
)
if request.top_k == 0:
raise ObjectHasInvalidValueException(
"top_k needs to be strictly positive, or set it to be -1 / None to disable top_k."
)
if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
request.top_k = None if request.top_k == -1 else request.top_k
request.top_p = None if request.top_p == 1.0 else request.top_p
if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
request.top_k = -1 if request.top_k is None else request.top_k
request.top_p = 1.0 if request.top_p is None else request.top_p
else:
if request.top_k or request.top_p:
raise ObjectHasInvalidValueException(
"top_k and top_p are only supported in text-generation-inference, vllm, lightllm."
)

# presence_penalty, frequency_penalty
if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
request.presence_penalty = (
0.0 if request.presence_penalty is None else request.presence_penalty
)
request.frequency_penalty = (
0.0 if request.frequency_penalty is None else request.frequency_penalty
)
else:
if request.presence_penalty or request.frequency_penalty:
raise ObjectHasInvalidValueException(
"presence_penalty and frequency_penalty are only supported in vllm, lightllm."
)

return request
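
To see what the helper does in practice, here is a sketch of a vllm-bound sync request before and after validation. It assumes it runs in the same module as the helper above, and the import paths for LLMInferenceFramework and ObjectHasInvalidValueException are assumptions:

from model_engine_server.common.dtos.llms import CompletionSyncV1Request
from model_engine_server.domain.entities import LLMInferenceFramework  # assumed import path
from model_engine_server.domain.exceptions import ObjectHasInvalidValueException  # assumed import path

req = CompletionSyncV1Request(prompt="Hi", max_new_tokens=8, temperature=0.5)
req = validate_and_update_completion_params(LLMInferenceFramework.VLLM, req)
print(req.top_k, req.top_p)                         # -1 1.0  (None normalized to "consider all tokens")
print(req.presence_penalty, req.frequency_penalty)  # 0.0 0.0 (no penalty)

greedy = CompletionSyncV1Request(prompt="Hi", max_new_tokens=8, temperature=0.0, top_k=40)
try:
    validate_and_update_completion_params(LLMInferenceFramework.VLLM, greedy)
except ObjectHasInvalidValueException:
    print("top_k/top_p cannot be enabled when temperature is 0")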


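To make the reviewer's generics suggestion from the thread above concrete, here is a sketch of how the helper could instead be typed with a constrained TypeVar, so mypy preserves the caller's concrete request type without isinstance narrowing (an alternative sketch, not what this PR merged):

from typing import TypeVar

from model_engine_server.common.dtos.llms import (
    CompletionStreamV1Request,
    CompletionSyncV1Request,
)

# Constrained TypeVar: the helper accepts either request model and is declared to
# return the same concrete type it was given, so callers keep their type statically.
RequestT = TypeVar("RequestT", CompletionSyncV1Request, CompletionStreamV1Request)


def validate_and_update_completion_params(
    inference_framework: "LLMInferenceFramework",  # same enum as above; import omitted
    request: RequestT,
) -> RequestT:
    # ... same validation and normalization body as above ...
    return request
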
class CompletionSyncV1UseCase:
"""
Use case for running a prompt completion on an LLM endpoint.
@@ -983,6 +1031,15 @@ async def execute(
endpoint_id=model_endpoint.record.id
)
endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
validated_request = validate_and_update_completion_params(
endpoint_content.inference_framework, request
)
if not isinstance(validated_request, CompletionSyncV1Request):
raise ValueError(
f"request has type {validated_request.__class__.__name__}, expected type CompletionSyncV1Request"
)
request = validated_request

if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
args: Any = {
"prompts": [request.prompt],
@@ -1036,6 +1093,10 @@ async def execute(
if request.temperature > 0:
tgi_args["parameters"]["temperature"] = request.temperature
tgi_args["parameters"]["do_sample"] = True
tgi_args["parameters"]["top_k"] = request.top_k
tgi_args["parameters"]["top_p"] = request.top_p
else:
tgi_args["parameters"]["do_sample"] = False

inference_request = SyncEndpointPredictV1Request(
args=tgi_args,
@@ -1064,10 +1125,15 @@ async def execute(
vllm_args: Any = {
"prompt": request.prompt,
"max_tokens": request.max_new_tokens,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
}
if request.stop_sequences is not None:
vllm_args["stop"] = request.stop_sequences
vllm_args["temperature"] = request.temperature
if request.temperature > 0:
vllm_args["top_k"] = request.top_k
vllm_args["top_p"] = request.top_p
if request.return_token_log_probs:
vllm_args["logprobs"] = 1

@@ -1098,12 +1164,16 @@ async def execute(
"inputs": request.prompt,
"parameters": {
"max_new_tokens": request.max_new_tokens,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
}
# TODO: implement stop sequences
if request.temperature > 0:
lightllm_args["parameters"]["temperature"] = request.temperature
lightllm_args["parameters"]["do_sample"] = True
lightllm_args["top_k"] = request.top_k
lightllm_args["top_p"] = request.top_p
else:
lightllm_args["parameters"]["do_sample"] = False
if request.return_token_log_probs:
@@ -1172,6 +1242,7 @@ async def execute(

request_id = str(uuid4())
add_trace_request_id(request_id)

model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints(
owner=user.team_id, name=model_endpoint_name, order_by=None
)
@@ -1209,6 +1280,14 @@ async def execute(
)

model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
validated_request = validate_and_update_completion_params(
model_content.inference_framework, request
)
if not isinstance(validated_request, CompletionStreamV1Request):
raise ValueError(
f"request has type {validated_request.__class__.__name__}, expected type CompletionStreamV1Request"
)
request = validated_request

args: Any = None
if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
@@ -1237,14 +1316,23 @@ async def execute(
if request.temperature > 0:
args["parameters"]["temperature"] = request.temperature
args["parameters"]["do_sample"] = True
args["parameters"]["top_k"] = request.top_k
args["parameters"]["top_p"] = request.top_p
else:
args["parameters"]["do_sample"] = False
elif model_content.inference_framework == LLMInferenceFramework.VLLM:
args = {
"prompt": request.prompt,
"max_tokens": request.max_new_tokens,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
}
if request.stop_sequences is not None:
args["stop"] = request.stop_sequences
args["temperature"] = request.temperature
if request.temperature > 0:
args["top_k"] = request.top_k
args["top_p"] = request.top_p
if request.return_token_log_probs:
args["logprobs"] = 1
args["stream"] = True
@@ -1253,12 +1341,16 @@
"inputs": request.prompt,
"parameters": {
"max_new_tokens": request.max_new_tokens,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
}
# TODO: stop sequences
if request.temperature > 0:
args["parameters"]["temperature"] = request.temperature
args["parameters"]["do_sample"] = True
args["parameters"]["top_k"] = request.top_k
args["parameters"]["top_p"] = request.top_p
else:
args["parameters"]["do_sample"] = False
if request.return_token_log_probs:
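
Putting the vllm streaming branches together: a request with temperature 0.7, top_k 50, top_p 0.9, both penalties at 0.5, and return_token_log_probs enabled would produce a payload along these lines (a sketch; values are illustrative, and stop is omitted because no stop_sequences were given):

vllm_stream_args = {
    "prompt": "Why is the sky blue?",
    "max_tokens": 64,
    "presence_penalty": 0.5,
    "frequency_penalty": 0.5,
    "temperature": 0.7,
    "top_k": 50,
    "top_p": 0.9,
    "logprobs": 1,    # only added when return_token_log_probs is requested
    "stream": True,
}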