Skip to content

launch v2 bundles API #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 59 additions & 19 deletions docs/concepts/model_bundles.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@ are created by packaging a model up into a deployable format.

## Creating Model Bundles

There are three methods for creating model bundles:
There are four methods for creating model bundles:
[`create_model_bundle_from_callable_v2`](/api/client/#launch.client.LaunchClient.create_model_bundle_from_callable_v2),
[`create_model_bundle_from_dirs_v2`](/api/client/#launch.client.LaunchClient.create_model_bundle_from_dirs_v2),
[`create_model_bundle_from_runnable_image_v2`](/api/client/#launch.client.LaunchClient.create_model_bundle_from_runnable_image_v2),
and
[`create_model_bundle_from_runnable_image_v2`](/api/client/#launch.client.LaunchClient.create_model_bundle_from_runnable_image_v2).
[`create_model_bundle_from_triton_enhanced_runnable_image_v2`](/api/client/#launch.client.LaunchClient.create_model_bundle_from_triton_enhanced_runnable_image_v2).

The first directly pickles a user-specified `load_predict_fn`, a function which
loads the model and returns a `predict_fn`, a function which takes in a request.
The second takes in directories containing a `load_predict_fn` and the
module path to the `load_predict_fn`.
The third takes a Docker image and a command that starts a process listening for
requests at port 5005 using HTTP and exposes `POST /predict` and
`GET /healthz` endpoints.
`GET /readyz` endpoints.
The fourth is a variant of the third that also starts an instance of the NVidia
Triton framework for efficient model serving.

Each of these modes of creating a model bundle is called a "Flavor".

Expand All @@ -43,6 +47,12 @@ Each of these modes of creating a model bundle is called a "Flavor".
* You are comfortable with building a web server and Docker image to serve your model.


A `TritonEnhancedRunnableImageFlavor` (a runnable image variant) is good if:

* You want to use a `RunnableImageFlavor`
* You also want to use [NVidia's `tritonserver`](https://developer.nvidia.com/nvidia-triton-inference-server) to accelerate model inference


=== "Creating From Callables"
```py
import os
Expand Down Expand Up @@ -178,36 +188,66 @@ Each of these modes of creating a model bundle is called a "Flavor".

BUNDLE_PARAMS = {
"model_bundle_name": "test-bundle",
"load_model_fn": my_load_model_fn,
"load_predict_fn": my_load_predict_fn,
"request_schema": MyRequestSchema,
"response_schema": MyResponseSchema,
"repository": "launch_rearch",
"tag": "12b9131c5a1489c76592cddd186962cce965f0f6-cpu",
"repository": "...",
"tag": "...",
"command": [
"dumb-init",
"--",
"ddtrace-run",
"run-service",
"--config",
"/install/launch_rearch/config/service--user_defined_code.yaml",
"--concurrency",
"1",
"--http",
"production",
"--port",
"5005",
...
],
"env": {
"TEST_KEY": "test_value",
},
"readiness_initial_delay_seconds": 30,
}

client = LaunchClient(api_key=os.getenv("LAUNCH_API_KEY"))
client.create_model_bundle_from_runnable_image_v2(**BUNDLE_PARAMS)
```


=== "Creating From a Triton Enhanced Runnable Image"
```py
import os
from pydantic import BaseModel
from launch import LaunchClient


class MyRequestSchema(BaseModel):
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int


BUNDLE_PARAMS = {
"model_bundle_name": "test-bundle",
"request_schema": MyRequestSchema,
"response_schema": MyResponseSchema,
"repository": "...",
"tag": "...",
"command": [
...
],
"env": {
"TEST_KEY": "test_value",
},
"readiness_initial_delay_seconds": 30,
"triton_model_repository": "...",
"triton_model_replicas": {"": ""},
"triton_num_cpu": 4.0,
"triton_commit_tag": "",
"triton_storage": "",
"triton_memory": "",
"triton_readiness_initial_delay_seconds": 300,
}

client = LaunchClient(api_key=os.getenv("LAUNCH_API_KEY"))
client.create_model_bundle_from_triton_enhanced_runnable_image_v2(**BUNDLE_PARAMS)
```


## Configuring Model Bundles

The `app_config` field of a model bundle is a dictionary that can be used to
Expand Down
20 changes: 9 additions & 11 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,20 @@ def my_load_predict_fn(model):
return returns_model_of_x_plus_len_of_y


def my_model(x):
return x * 2

ENV_PARAMS = {
"framework_type": "pytorch",
"pytorch_image_tag": "1.7.1-cuda11.0-cudnn8-runtime",
}
def my_load_model_fn():
def my_model(x):
return x * 2

return my_model

BUNDLE_PARAMS = {
"model_bundle_name": "test-bundle",
"model": my_model,
"load_predict_fn": my_load_predict_fn,
"env_params": ENV_PARAMS,
"requirements": ["pytest==7.2.1", "numpy"], # list your requirements here
"load_model_fn": my_load_model_fn,
"request_schema": MyRequestSchema,
"response_schema": MyResponseSchema,
"requirements": ["pytest==7.2.1", "numpy"], # list your requirements here
"pytorch_image_tag": "1.7.1-cuda11.0-cudnn8-runtime",
}

ENDPOINT_PARAMS = {
Expand Down Expand Up @@ -81,7 +79,7 @@ def predict_on_endpoint(request: MyRequestSchema) -> MyResponseSchema:

client = LaunchClient(api_key=os.getenv("LAUNCH_API_KEY"))

client.create_model_bundle(**BUNDLE_PARAMS)
client.create_model_bundle_from_callable_v2(**BUNDLE_PARAMS)
endpoint = client.create_model_endpoint(**ENDPOINT_PARAMS)

request = MyRequestSchema(x=5, y="hello")
Expand Down
68 changes: 7 additions & 61 deletions launch/api_client/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,7 @@ def __init__(
self.content = content

def _serialize_json(
self,
in_data: typing.Union[None, int, float, str, bool, dict, list],
eliminate_whitespace: bool = False,
self, in_data: typing.Union[None, int, float, str, bool, dict, list], eliminate_whitespace: bool = False
) -> str:
if eliminate_whitespace:
return json.dumps(in_data, separators=self._json_encoder.compact_separators)
Expand Down Expand Up @@ -483,19 +481,7 @@ def __serialize_simple(
def serialize(
self,
in_data: typing.Union[
Schema,
Decimal,
int,
float,
str,
date,
datetime,
None,
bool,
list,
tuple,
dict,
frozendict.frozendict,
Schema, Decimal, int, float, str, date, datetime, None, bool, list, tuple, dict, frozendict.frozendict
],
) -> typing.Dict[str, str]:
if self.schema:
Expand Down Expand Up @@ -611,19 +597,7 @@ def get_prefix_separator_iterator(self) -> typing.Optional[PrefixSeparatorIterat
def serialize(
self,
in_data: typing.Union[
Schema,
Decimal,
int,
float,
str,
date,
datetime,
None,
bool,
list,
tuple,
dict,
frozendict.frozendict,
Schema, Decimal, int, float, str, date, datetime, None, bool, list, tuple, dict, frozendict.frozendict
],
prefix_separator_iterator: typing.Optional[PrefixSeparatorIterator] = None,
) -> typing.Dict[str, str]:
Expand Down Expand Up @@ -691,19 +665,7 @@ def __init__(
def serialize(
self,
in_data: typing.Union[
Schema,
Decimal,
int,
float,
str,
date,
datetime,
None,
bool,
list,
tuple,
dict,
frozendict.frozendict,
Schema, Decimal, int, float, str, date, datetime, None, bool, list, tuple, dict, frozendict.frozendict
],
) -> typing.Dict[str, str]:
if self.schema:
Expand Down Expand Up @@ -770,19 +732,7 @@ def __to_headers(in_data: typing.Tuple[typing.Tuple[str, str], ...]) -> HTTPHead
def serialize(
self,
in_data: typing.Union[
Schema,
Decimal,
int,
float,
str,
date,
datetime,
None,
bool,
list,
tuple,
dict,
frozendict.frozendict,
Schema, Decimal, int, float, str, date, datetime, None, bool, list, tuple, dict, frozendict.frozendict
],
) -> HTTPHeaderDict:
if self.schema:
Expand Down Expand Up @@ -940,9 +890,7 @@ def __deserialize_application_octet_stream(
return response.data

@staticmethod
def __deserialize_multipart_form_data(
response: urllib3.HTTPResponse,
) -> typing.Dict[str, typing.Any]:
def __deserialize_multipart_form_data(response: urllib3.HTTPResponse) -> typing.Dict[str, typing.Any]:
msg = email.message_from_bytes(response.data)
return {
part.get_param("name", header="Content-Disposition"): part.get_payload(decode=True).decode(
Expand Down Expand Up @@ -1295,9 +1243,7 @@ def _verify_typed_dict_inputs_oapg(
if required_keys_with_unset_values:
raise ApiValueError(
"{} contains invalid unset values for {} required keys: {}".format(
cls.__name__,
len(required_keys_with_unset_values),
required_keys_with_unset_values,
cls.__name__, len(required_keys_with_unset_values), required_keys_with_unset_values
)
)

Expand Down
3 changes: 2 additions & 1 deletion launch/api_client/model/batch_job_serialization_format.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class BatchJobSerializationFormat(schemas.EnumBase, schemas.StrSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
3 changes: 2 additions & 1 deletion launch/api_client/model/batch_job_status.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class BatchJobStatus(schemas.EnumBase, schemas.StrSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
7 changes: 4 additions & 3 deletions launch/api_client/model/callback_auth.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CallbackAuth(
schemas.ComposedSchema,
Expand Down Expand Up @@ -100,5 +101,5 @@ class CallbackAuth(
**kwargs,
)

from launch_client.model.callback_basic_auth import CallbackBasicAuth
from launch_client.model.callbackm_tls_auth import CallbackmTLSAuth
from launch.api_client.model.callback_basic_auth import CallbackBasicAuth
from launch.api_client.model.callbackm_tls_auth import CallbackmTLSAuth
3 changes: 2 additions & 1 deletion launch/api_client/model/callback_basic_auth.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CallbackBasicAuth(schemas.DictSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
3 changes: 2 additions & 1 deletion launch/api_client/model/callbackm_tls_auth.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CallbackmTLSAuth(schemas.DictSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
3 changes: 2 additions & 1 deletion launch/api_client/model/clone_model_bundle_v1_request.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CloneModelBundleV1Request(schemas.DictSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
3 changes: 2 additions & 1 deletion launch/api_client/model/clone_model_bundle_v2_request.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CloneModelBundleV2Request(schemas.DictSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
9 changes: 5 additions & 4 deletions launch/api_client/model/cloudpickle_artifact_flavor.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CloudpickleArtifactFlavor(schemas.DictSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down Expand Up @@ -310,6 +311,6 @@ class CloudpickleArtifactFlavor(schemas.DictSchema):
**kwargs,
)

from launch_client.model.custom_framework import CustomFramework
from launch_client.model.pytorch_framework import PytorchFramework
from launch_client.model.tensorflow_framework import TensorflowFramework
from launch.api_client.model.custom_framework import CustomFramework
from launch.api_client.model.pytorch_framework import PytorchFramework
from launch.api_client.model.tensorflow_framework import TensorflowFramework
3 changes: 2 additions & 1 deletion launch/api_client/model/create_async_task_v1_response.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ from datetime import date, datetime # noqa: F401

import frozendict # noqa: F401
import typing_extensions # noqa: F401
from launch_client import schemas # noqa: F401

from launch.api_client import schemas # noqa: F401

class CreateAsyncTaskV1Response(schemas.DictSchema):
"""NOTE: This class is auto generated by OpenAPI Generator.
Expand Down
Loading