diff --git a/src/zenml/constants.py b/src/zenml/constants.py index a4263c07a3d..c3f8460b884 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -52,6 +52,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_ENABLE_RICH_TRACEBACK = "ZENML_ENABLE_RICH_TRACEBACK" ENV_ZENML_DEFAULT_STORE_TYPE = "ZENML_DEFAULT_STORE_TYPE" ENV_ZENML_PROFILE_NAME = "ZENML_PROFILE_NAME" +ENV_ZENML_PROFILE_CONFIGURATION = "ZENML_PROFILE_CONFIGURATION" # Logging variables IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False) @@ -126,6 +127,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # Services DEFAULT_SERVICE_START_STOP_TIMEOUT = 10 +ZEN_SERVICE_ENTRYPOINT = "zenml.zen_service.zen_service_api:app" +ZEN_SERVICE_IP = "127.0.0.1" # API Endpoint paths: IS_EMPTY = "/empty" diff --git a/src/zenml/zen_service/zen_service.py b/src/zenml/zen_service/zen_service.py index 733b3f914be..342e1f8bec6 100644 --- a/src/zenml/zen_service/zen_service.py +++ b/src/zenml/zen_service/zen_service.py @@ -5,7 +5,11 @@ from pydantic import Field from zenml.config.profile_config import ProfileConfiguration -from zenml.constants import ENV_ZENML_PROFILE_NAME +from zenml.constants import ( + ENV_ZENML_PROFILE_NAME, + ZEN_SERVICE_ENTRYPOINT, + ZEN_SERVICE_IP, +) from zenml.enums import StoreType from zenml.logger import get_logger from zenml.repository import Repository @@ -140,8 +144,8 @@ def run(self) -> None: try: uvicorn.run( - "zenml.zen_service.zen_service_api:app", - host="127.0.0.1", + ZEN_SERVICE_ENTRYPOINT, + host=ZEN_SERVICE_IP, port=self.endpoint.status.port, log_level="info", ) diff --git a/src/zenml/zen_service/zen_service_api.py b/src/zenml/zen_service/zen_service_api.py index 15a5749c456..1459444ddc9 100644 --- a/src/zenml/zen_service/zen_service_api.py +++ b/src/zenml/zen_service/zen_service_api.py @@ -20,6 +20,7 @@ from zenml.config.global_config import GlobalConfiguration from zenml.config.profile_config import ProfileConfiguration from zenml.constants import ( + ENV_ZENML_PROFILE_CONFIGURATION, ENV_ZENML_PROFILE_NAME, IS_EMPTY, STACK_COMPONENTS, @@ -32,7 +33,7 @@ from zenml.stack_stores import BaseStackStore from zenml.stack_stores.models import StackComponentWrapper, StackWrapper -profile_configuration_json = os.environ.get("ZENML_PROFILE_CONFIGURATION") +profile_configuration_json = os.environ.get(ENV_ZENML_PROFILE_CONFIGURATION) profile_name = os.environ.get(ENV_ZENML_PROFILE_NAME) # Hopefully profile configuration was passed as env variable: diff --git a/tests/unit/stack_stores/test_stack_stores.py b/tests/unit/stack_stores/test_stack_stores.py index 0b7f9013152..a3fcf29d36a 100644 --- a/tests/unit/stack_stores/test_stack_stores.py +++ b/tests/unit/stack_stores/test_stack_stores.py @@ -11,20 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. + import os import platform import random import shutil +import time +from multiprocessing import Process import pytest +import requests +import uvicorn +from requests.exceptions import ConnectionError from zenml.config.profile_config import ProfileConfiguration -from zenml.constants import REPOSITORY_DIRECTORY_NAME +from zenml.constants import ( + DEFAULT_SERVICE_START_STOP_TIMEOUT, + ENV_ZENML_PROFILE_CONFIGURATION, + REPOSITORY_DIRECTORY_NAME, + ZEN_SERVICE_ENTRYPOINT, + ZEN_SERVICE_IP, +) from zenml.enums import StackComponentType, StoreType from zenml.exceptions import StackComponentExistsError, StackExistsError from zenml.logger import get_logger from zenml.orchestrators import LocalOrchestrator -from zenml.services import ServiceState from zenml.stack import Stack from zenml.stack_stores import ( BaseStackStore, @@ -33,7 +44,6 @@ SqlStackStore, ) from zenml.stack_stores.models import StackComponentWrapper, StackWrapper -from zenml.zen_service.zen_service import ZenService, ZenServiceConfig logger = get_logger(__name__) @@ -55,7 +65,6 @@ def fresh_stack_store( elif store_type == StoreType.SQL: yield SqlStackStore().initialize(f"sqlite:///{tmp_path / 'store.db'}") elif store_type == StoreType.REST: - port = random.randint(8003, 9000) # create temporary stack store and profile configuration for unit tests backing_stack_store = LocalStackStore().initialize(str(tmp_path)) store_profile = ProfileConfiguration( @@ -63,19 +72,53 @@ def fresh_stack_store( store_url=backing_stack_store.url, store_type=backing_stack_store.type, ) - - zen_service = ZenService( - ZenServiceConfig( - port=port, - store_profile_configuration=store_profile, + # use environment file to pass profile into the zen service process + env_file = str(tmp_path / "environ.env") + with open(env_file, "w") as f: + f.write( + f"{ENV_ZENML_PROFILE_CONFIGURATION}='{store_profile.json()}'" ) + port = random.randint(8003, 9000) + proc = Process( + target=uvicorn.run, + args=(ZEN_SERVICE_ENTRYPOINT,), + kwargs=dict( + host=ZEN_SERVICE_IP, + port=port, + log_level="info", + env_file=env_file, + ), + daemon=True, ) - zen_service.start(timeout=10) - # rest stack store can't have trailing slash on url - url = zen_service.zen_service_uri.strip("/") + url = f"http://{ZEN_SERVICE_IP}:{port}" + proc.start() + + # wait 10 seconds for server to start + for t in range(DEFAULT_SERVICE_START_STOP_TIMEOUT): + try: + if requests.head(f"{url}/health").status_code == 200: + break + else: + time.sleep(1) + except ConnectionError: + time.sleep(1) + else: + proc.kill() + raise RuntimeError("Failed to start ZenService server.") + yield RestStackStore().initialize(url) - zen_service.stop(timeout=10) - assert zen_service.check_status()[0] == ServiceState.INACTIVE + + # make sure there's still a server and tear down + assert proc.is_alive() + proc.kill() + # wait 10 seconds for process to be killed: + for t in range(DEFAULT_SERVICE_START_STOP_TIMEOUT): + if proc.is_alive(): + time.sleep(1) + else: + break + else: + raise RuntimeError("Failed to shutdown ZenService server.") else: raise NotImplementedError(f"No StackStore for {store_type}")