diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py
index 16fd3a72..b80c83eb 100644
--- a/clients/python/llmengine/api_engine.py
+++ b/clients/python/llmengine/api_engine.py
@@ -44,6 +44,7 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
             os.path.join(LLM_ENGINE_BASE_PATH, resource_name),
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -60,6 +61,7 @@ def put(
             json=data,
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -73,6 +75,7 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
             os.path.join(LLM_ENGINE_BASE_PATH, resource_name),
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -87,6 +90,7 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di
             json=data,
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
         )
         if response.status_code != 200:
             raise parse_error(response.status_code, response.content)
@@ -103,6 +107,7 @@ def post_stream(
             json=data,
             timeout=timeout,
             headers={"x-api-key": api_key},
+            auth=(api_key, ""),
             stream=True,
         )
         if response.status_code != 200:
diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py
index b3fcbf58..3242ad08 100644
--- a/clients/python/llmengine/data_types.py
+++ b/clients/python/llmengine/data_types.py
@@ -131,17 +131,27 @@ class CreateLLMEndpointRequest(BaseModel):
     # LLM specific fields
     model_name: str
     source: LLMSource = LLMSource.HUGGING_FACE
-    inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED
+    inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE
     inference_framework_image_tag: str
-    num_shards: int
+    num_shards: int = 1
     """
-    Number of shards to distribute the model onto GPUs.
+    Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models
+    """
+
+    quantize: Optional[Quantization] = None
+    """
+    Quantization for the LLM. Only affects behavior for text-generation-inference models
+    """
+
+    checkpoint_path: Optional[str] = None
+    """
+    Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models
     """

     # General endpoint fields
     metadata: Dict[str, Any]  # TODO: JSON type
     post_inference_hooks: Optional[List[str]]
-    endpoint_type: ModelEndpointType = ModelEndpointType.SYNC
+    endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING
     cpus: CpuSpecificationType
     gpus: int
     memory: StorageSpecificationType
@@ -156,7 +166,10 @@ class CreateLLMEndpointRequest(BaseModel):
     high_priority: Optional[bool]
     default_callback_url: Optional[HttpUrl]
     default_callback_auth: Optional[CallbackAuth]
-    public_inference: Optional[bool] = True  # LLM endpoints are public by default.
+    public_inference: Optional[bool] = True
+    """
+    Whether the endpoint can be used for inference for all users. LLM endpoints are public by default.
+    """


 class CreateLLMEndpointResponse(BaseModel):
diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py
index 1b6c9dba..1c854eba 100644
--- a/clients/python/llmengine/model.py
+++ b/clients/python/llmengine/model.py
@@ -12,6 +12,7 @@
     LLMSource,
     ModelEndpointType,
     PostInferenceHooks,
+    Quantization,
 )


@@ -28,12 +29,15 @@ class Model(APIEngine):
     @assert_self_hosted
     def create(
         cls,
+        name: str,
         # LLM specific fields
         model: str,
         inference_framework_image_tag: str,
         source: LLMSource = LLMSource.HUGGING_FACE,
         inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
         num_shards: int = 4,
+        quantize: Optional[Quantization] = None,
+        checkpoint_path: Optional[str] = None,
         # General endpoint fields
         cpus: int = 32,
         memory: str = "192Gi",
@@ -53,8 +57,11 @@ def create(
         """
         Create an LLM model. Note: This feature is only available for self-hosted users.
         Args:
+            name (`str`):
+                Name of the endpoint
+
             model (`str`):
-                Name of the model
+                Name of the base model

             inference_framework_image_tag (`str`):
                 Image tag for the inference framework
@@ -68,6 +75,15 @@ def create(
             num_shards (`int`):
                 Number of shards for the LLM. When bigger than 1, LLM will be sharded to multiple GPUs.
                 Number of GPUs must be larger than num_shards.
+                Only affects behavior for text-generation-inference models
+
+            quantize (`Optional[Quantization]`):
+                Quantization for the LLM. Only affects behavior for text-generation-inference models
+
+            checkpoint_path (`Optional[str]`):
+                Path to the checkpoint for the LLM. For now we only support loading a tar file from AWS S3.
+                Safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be slower).
+                Only affects behavior for text-generation-inference models

             cpus (`int`):
                 Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater
@@ -157,12 +173,14 @@ def create(
                 post_inference_hooks_strs.append(hook)

         request = CreateLLMEndpointRequest(
-            name=model,
+            name=name,
             model_name=model,
             source=source,
             inference_framework=inference_framework,
             inference_framework_image_tag=inference_framework_image_tag,
             num_shards=num_shards,
+            quantize=quantize,
+            checkpoint_path=checkpoint_path,
             cpus=cpus,
             endpoint_type=ModelEndpointType(endpoint_type),
             gpus=gpus,