5 changes: 5 additions & 0 deletions clients/python/llmengine/api_engine.py
@@ -44,6 +44,7 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
             os.path.join(LLM_ENGINE_BASE_PATH, resource_name),
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -60,6 +61,7 @@ def put(
             json=data,
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -73,6 +75,7 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
             os.path.join(LLM_ENGINE_BASE_PATH, resource_name),
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -87,6 +90,7 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di
             json=data,
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -103,6 +107,7 @@ def post_stream(
             json=data,
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
             stream=True,
         )
         if response.status_code != 200:
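Every hunk above adds the same keyword argument, `auth=(api_key, "")`, to the outgoing `requests` call. As a rough illustration (not code from this PR; the URL and key below are placeholders), `requests` encodes a `(username, password)` tuple as an HTTP Basic `Authorization` header, so each request now carries the API key both in the `x-api-key` header and as the Basic-auth username with an empty password:

```python
# Illustrative sketch only: placeholder URL and API key, not part of the PR.
import requests

api_key = "test-api-key"  # placeholder value
prepared = requests.Request(
    "GET",
    "https://example.com/v1/llm/model-endpoints",  # hypothetical endpoint URL
    headers={"x-api-key": api_key},
    auth=(api_key, ""),  # same pattern the PR adds to each call
).prepare()

# The prepared request carries both credentials:
print(prepared.headers["x-api-key"])      # test-api-key
print(prepared.headers["Authorization"])  # Basic dGVzdC1hcGkta2V5Og==  (base64 of "test-api-key:")
```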
23 changes: 18 additions & 5 deletions clients/python/llmengine/data_types.py
@@ -131,17 +131,27 @@ class CreateLLMEndpointRequest(BaseModel):
     # LLM specific fields
     model_name: str
     source: LLMSource = LLMSource.HUGGING_FACE
-    inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED
+    inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE
Member: Is this a breaking change? Seems like we're changing default functionality. Might be fine.

Contributor Author: I don't think people could build DeepSpeed models right now.
     inference_framework_image_tag: str
-    num_shards: int
+    num_shards: int = 1
     """
-    Number of shards to distribute the model onto GPUs.
+    Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models
     """
+
+    quantize: Optional[Quantization] = None
+    """
+    Quantization for the LLM. Only affects behavior for text-generation-inference models
+    """
+
+    checkpoint_path: Optional[str] = None
+    """
+    Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models
+    """

     # General endpoint fields
     metadata: Dict[str, Any]  # TODO: JSON type
     post_inference_hooks: Optional[List[str]]
-    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
+    endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING
Member: Does this line up with the default on the server?

Contributor Author: No, but this lines up with the default endpoint type for TGI.
     cpus: CpuSpecificationType
     gpus: int
     memory: StorageSpecificationType
@@ -156,7 +166,10 @@ class CreateLLMEndpointRequest(BaseModel):
     high_priority: Optional[bool]
     default_callback_url: Optional[HttpUrl]
     default_callback_auth: Optional[CallbackAuth]
-    public_inference: Optional[bool] = True  # LLM endpoints are public by default.
+    public_inference: Optional[bool] = True
+    """
+    Whether the endpoint can be used for inference for all users. LLM endpoints are public by default.
+    """


 class CreateLLMEndpointResponse(BaseModel):
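Taken together, these changes flip the request defaults from DeepSpeed/sync to text-generation-inference/streaming and add two optional fields. A cut-down, hypothetical stand-in (the real `CreateLLMEndpointRequest` has many more required fields, and the enum string values below are assumed, not taken from the PR) that shows how the new defaults behave:

```python
# Hypothetical stand-in model used only to illustrate the new defaults.
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class LLMInferenceFramework(str, Enum):
    DEEPSPEED = "deepspeed"                                  # assumed value
    TEXT_GENERATION_INFERENCE = "text_generation_inference"  # assumed value


class ModelEndpointType(str, Enum):
    SYNC = "sync"            # assumed value
    STREAMING = "streaming"  # assumed value


class EndpointDefaultsSketch(BaseModel):
    model_name: str
    inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE
    num_shards: int = 1
    quantize: Optional[str] = None          # the real field uses the Quantization enum
    checkpoint_path: Optional[str] = None
    endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING


req = EndpointDefaultsSketch(model_name="llama-7b")  # placeholder model name
print(req.inference_framework)  # LLMInferenceFramework.TEXT_GENERATION_INFERENCE
print(req.endpoint_type)        # ModelEndpointType.STREAMING
print(req.num_shards)           # 1
```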
22 changes: 20 additions & 2 deletions clients/python/llmengine/model.py
Expand Up @@ -12,6 +12,7 @@
LLMSource,
ModelEndpointType,
PostInferenceHooks,
Quantization,
)


@@ -28,12 +29,15 @@ class Model(APIEngine):
     @assert_self_hosted
     def create(
         cls,
+        name: str,
         # LLM specific fields
         model: str,
         inference_framework_image_tag: str,
         source: LLMSource = LLMSource.HUGGING_FACE,
         inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
         num_shards: int = 4,
+        quantize: Optional[Quantization] = None,
+        checkpoint_path: Optional[str] = None,
         # General endpoint fields
         cpus: int = 32,
         memory: str = "192Gi",
@@ -53,8 +57,11 @@ def create(
         """
         Create an LLM model. Note: This feature is only available for self-hosted users.
         Args:
+            name (`str`):
+                Name of the endpoint
+
             model (`str`):
-                Name of the model
+                Name of the base model

             inference_framework_image_tag (`str`):
                 Image tag for the inference framework
@@ -68,6 +75,15 @@
             num_shards (`int`):
                 Number of shards for the LLM. When bigger than 1, LLM will be sharded
                 to multiple GPUs. Number of GPUs must be larger than num_shards.
+                Only affects behavior for text-generation-inference models
+
+            quantize (`Optional[Quantization]`):
+                Quantization for the LLM. Only affects behavior for text-generation-inference models
+
+            checkpoint_path (`Optional[str]`):
+                Path to the checkpoint for the LLM. For now we only support loading a tar file from AWS S3.
+                Safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be slower).
+                Only affects behavior for text-generation-inference models

             cpus (`int`):
                 Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater
@@ -157,12 +173,14 @@ def create(
                 post_inference_hooks_strs.append(hook)

         request = CreateLLMEndpointRequest(
-            name=model,
+            name=name,
             model_name=model,
             source=source,
             inference_framework=inference_framework,
             inference_framework_image_tag=inference_framework_image_tag,
             num_shards=num_shards,
+            quantize=quantize,
+            checkpoint_path=checkpoint_path,
             cpus=cpus,
             endpoint_type=ModelEndpointType(endpoint_type),
             gpus=gpus,
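For reference, a hypothetical call against the updated client signature (self-hosted deployments only; the endpoint name, model, image tag, checkpoint path, quantization value, and GPU count below are placeholders, not taken from this PR):

```python
# Hypothetical usage sketch of the updated Model.create signature;
# every literal value here is a placeholder chosen for illustration.
from llmengine import Model

response = Model.create(
    name="llama-2-7b-test",                  # endpoint name (new, separate from the model name)
    model="llama-2-7b",                      # base model name
    inference_framework_image_tag="0.9.3",   # TGI image tag (placeholder)
    num_shards=2,                            # shard the model across 2 GPUs
    quantize="bitsandbytes",                 # optional; the real parameter is the Quantization enum
    checkpoint_path="s3://my-bucket/llama-2-7b.tar",  # optional tar checkpoint on S3 (placeholder)
    gpus=2,
)
print(response)
```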