diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index fb0dd7c3..604b8744 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -224,6 +224,11 @@ def my_model(**keyword_args): } +@retry(stop=stop_after_attempt(300), wait=wait_fixed(2)) +def ensure_launch_gateway_healthy(): + assert requests.get(f"{BASE_PATH}/healthz").status_code == 200 + + def create_model_bundle( create_model_bundle_request: Dict[str, Any], user_id: str, version: str ) -> Dict[str, Any]: @@ -735,7 +740,7 @@ def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_p # Wait up to 30 seconds for the tasks to be returned. @retry( - stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) + stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) ) def ensure_all_async_tasks_success(task_ids: List[str], user_id: str, return_pickled: bool): responses = asyncio.run(get_async_tasks(task_ids, user_id)) diff --git a/integration_tests/test_bundles.py b/integration_tests/test_bundles.py index 5d38d80f..cb8a45e7 100644 --- a/integration_tests/test_bundles.py +++ b/integration_tests/test_bundles.py @@ -1,4 +1,5 @@ import pytest +from tenacity import retry, stop_after_attempt, wait_fixed from .rest_api_utils import ( CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, @@ -6,12 +7,15 @@ USER_ID_0, USER_ID_1, create_model_bundle, + ensure_launch_gateway_healthy, get_latest_model_bundle, ) @pytest.fixture(scope="session") +@retry(stop=stop_after_attempt(10), wait=wait_fixed(30)) def model_bundles(): + ensure_launch_gateway_healthy() for user in [USER_ID_0, USER_ID_1]: for create_bundle_request in [ CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index 2af5a257..ad40a2d9 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -2,6 +2,7 @@ import time import pytest +from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed from .rest_api_utils import ( CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, @@ -41,6 +42,23 @@ def delete_endpoints(capsys): print("Endpoint deletion failed") +@retry(stop=stop_after_attempt(3), wait=wait_fixed(10), retry=retry_if_exception_type(RetryError)) +def ensure_async_inference_works(user, create_endpoint_request, inference_payload, return_pickled): + print( + f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." + ) + task_ids = asyncio.run( + create_async_tasks( + create_endpoint_request["name"], + [inference_payload] * 3, + user, + ) + ) + print("Retrieving async task results...") + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + ensure_all_async_tasks_success(task_ids, user, return_pickled) + + @pytest.mark.parametrize( "create_endpoint_request,update_endpoint_request,inference_requests", [ @@ -89,22 +107,12 @@ def test_async_model_endpoint( == update_endpoint_request["max_workers"] ) - time.sleep(10) + time.sleep(20) for inference_payload, return_pickled in inference_requests: - print( - f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." - ) - task_ids = asyncio.run( - create_async_tasks( - create_endpoint_request["name"], - [inference_payload] * 3, - user, - ) + ensure_async_inference_works( + user, create_endpoint_request, inference_payload, return_pickled ) - print("Retrieving async task results...") - ensure_nonzero_available_workers(create_endpoint_request["name"], user) - ensure_all_async_tasks_success(task_ids, user, return_pickled) finally: delete_model_endpoint(create_endpoint_request["name"], user)