Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clients/python/llmengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
ListFilesResponse,
ListFineTunesResponse,
ListLLMEndpointsResponse,
ModelDownloadRequest,
ModelDownloadResponse,
UploadFileResponse,
)
from llmengine.file import File
Expand All @@ -51,6 +53,8 @@
"CreateFineTuneResponse",
"DeleteFileResponse",
"DeleteLLMEndpointResponse",
"ModelDownloadRequest",
"ModelDownloadResponse",
"GetFileContentResponse",
"File",
"FineTune",
Expand Down
22 changes: 22 additions & 0 deletions clients/python/llmengine/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,28 @@ class GetFineTuneEventsResponse(BaseModel):
events: List[LLMFineTuneEvent] = Field(..., description="List of fine-tuning events.")


class ModelDownloadRequest(BaseModel):
    """
    Request body sent when asking the service for download links to a model's weights.
    """

    # Required: which model's weights to fetch.
    model_name: str = Field(..., description="Name of the model to download.")
    # Optional weight-serialization layout; defaults to "hugging_face".
    download_format: Optional[str] = Field(
        default="hugging_face",
        description="Desired return format for downloaded model weights (default=hugging_face).",
    )


class ModelDownloadResponse(BaseModel):
    """
    Response body returned by the model-download endpoint.
    """

    # Maps each weight file's name to a URL it can be fetched from.
    urls: Dict[str, str] = Field(
        ..., description="Dictionary of (file_name, url) pairs to download the model from."
    )


class UploadFileResponse(BaseModel):
"""Response object for uploading a file."""

Expand Down
50 changes: 50 additions & 0 deletions clients/python/llmengine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
ListLLMEndpointsResponse,
LLMInferenceFramework,
LLMSource,
ModelDownloadRequest,
ModelDownloadResponse,
ModelEndpointType,
PostInferenceHooks,
Quantization,
Expand Down Expand Up @@ -366,3 +368,51 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse:
"""
response = cls._delete(f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT)
return DeleteLLMEndpointResponse.parse_obj(response)

@classmethod
def download(
    cls,
    model_name: str,
    download_format: str = "hugging_face",
) -> ModelDownloadResponse:
    """
    Download a fine-tuned model.

    This API can be used to download the resulting model from a fine-tuning job.
    It takes the `model_name` and `download_format` as parameters and returns a
    response object containing a dictionary of (file_name, url) pairs associated
    with the fine-tuned model. The user can then download these urls to obtain
    the fine-tuned model. If called on a nonexistent model, an error will be
    thrown.

    Args:
        model_name (`str`):
            name of the fine-tuned model
        download_format (`str`):
            download format requested (default=hugging_face)

    Returns:
        ModelDownloadResponse: an object that contains a dictionary of
        (file_name, url) pairs from which to download the model weights.
        The urls are presigned urls that grant temporary access and expire
        after an hour.

    === "Downloading model in Python"
        ```python
        from llmengine import Model

        response = Model.download("llama-2-7b.suffix.2023-07-18-12-00-00", download_format="hugging_face")
        print(response.json())
        ```

    === "Response in JSON"
        ```json
        {
            "urls": {"my_model_file": 'https://url-to-my-model-weights'}
        }
        ```
    """
    # Validate/serialize the arguments through the request model before POSTing.
    request = ModelDownloadRequest(model_name=model_name, download_format=download_format)
    response = cls.post_sync(
        resource_name="v1/llm/model-endpoints/download",
        data=request.dict(),
        timeout=DEFAULT_TIMEOUT,
    )
    return ModelDownloadResponse.parse_obj(response)
4 changes: 4 additions & 0 deletions docs/api/data_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@

::: llmengine.DeleteLLMEndpointResponse

::: llmengine.ModelDownloadRequest

::: llmengine.ModelDownloadResponse

::: llmengine.UploadFileResponse

::: llmengine.GetFileResponse
Expand Down
1 change: 1 addition & 0 deletions docs/api/python_client.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- get
- list
- delete
- download

::: llmengine.File
selection:
Expand Down
Loading