Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion integration_tests/rest_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ def my_model(**keyword_args):
}


@retry(stop=stop_after_attempt(300), wait=wait_fixed(2))
def ensure_launch_gateway_healthy():
assert requests.get(f"{BASE_PATH}/healthz").status_code == 200


def create_model_bundle(
create_model_bundle_request: Dict[str, Any], user_id: str, version: str
) -> Dict[str, Any]:
Expand Down Expand Up @@ -735,7 +740,7 @@ def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_p

# Wait up to 30 seconds for the tasks to be returned.
@retry(
stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)
stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)
)
def ensure_all_async_tasks_success(task_ids: List[str], user_id: str, return_pickled: bool):
responses = asyncio.run(get_async_tasks(task_ids, user_id))
Expand Down
4 changes: 4 additions & 0 deletions integration_tests/test_bundles.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import pytest
from tenacity import retry, stop_after_attempt, wait_fixed

from .rest_api_utils import (
CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE,
CREATE_MODEL_BUNDLE_REQUEST_SIMPLE,
USER_ID_0,
USER_ID_1,
create_model_bundle,
ensure_launch_gateway_healthy,
get_latest_model_bundle,
)


@pytest.fixture(scope="session")
@retry(stop=stop_after_attempt(10), wait=wait_fixed(30))
def model_bundles():
ensure_launch_gateway_healthy()
for user in [USER_ID_0, USER_ID_1]:
for create_bundle_request in [
CREATE_MODEL_BUNDLE_REQUEST_SIMPLE,
Expand Down
34 changes: 21 additions & 13 deletions integration_tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time

import pytest
from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from .rest_api_utils import (
CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE,
Expand Down Expand Up @@ -41,6 +42,23 @@ def delete_endpoints(capsys):
print("Endpoint deletion failed")


@retry(stop=stop_after_attempt(3), wait=wait_fixed(10), retry=retry_if_exception_type(RetryError))
def ensure_async_inference_works(user, create_endpoint_request, inference_payload, return_pickled):
print(
f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..."
)
task_ids = asyncio.run(
create_async_tasks(
create_endpoint_request["name"],
[inference_payload] * 3,
user,
)
)
print("Retrieving async task results...")
ensure_nonzero_available_workers(create_endpoint_request["name"], user)
ensure_all_async_tasks_success(task_ids, user, return_pickled)


@pytest.mark.parametrize(
"create_endpoint_request,update_endpoint_request,inference_requests",
[
Expand Down Expand Up @@ -89,22 +107,12 @@ def test_async_model_endpoint(
== update_endpoint_request["max_workers"]
)

time.sleep(10)
time.sleep(20)

for inference_payload, return_pickled in inference_requests:
print(
f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..."
)
task_ids = asyncio.run(
create_async_tasks(
create_endpoint_request["name"],
[inference_payload] * 3,
user,
)
ensure_async_inference_works(
user, create_endpoint_request, inference_payload, return_pickled
)
print("Retrieving async task results...")
ensure_nonzero_available_workers(create_endpoint_request["name"], user)
ensure_all_async_tasks_success(task_ids, user, return_pickled)
finally:
delete_model_endpoint(create_endpoint_request["name"], user)

Expand Down