Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_ENABLE_RICH_TRACEBACK = "ZENML_ENABLE_RICH_TRACEBACK"
ENV_ZENML_DEFAULT_STORE_TYPE = "ZENML_DEFAULT_STORE_TYPE"
ENV_ZENML_PROFILE_NAME = "ZENML_PROFILE_NAME"
ENV_ZENML_PROFILE_CONFIGURATION = "ZENML_PROFILE_CONFIGURATION"

# Logging variables
IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False)
Expand Down Expand Up @@ -126,6 +127,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# Services
DEFAULT_SERVICE_START_STOP_TIMEOUT = 10
ZEN_SERVICE_ENTRYPOINT = "zenml.zen_service.zen_service_api:app"
ZEN_SERVICE_IP = "127.0.0.1"

# API Endpoint paths:
IS_EMPTY = "/empty"
Expand Down
10 changes: 7 additions & 3 deletions src/zenml/zen_service/zen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from pydantic import Field

from zenml.config.profile_config import ProfileConfiguration
from zenml.constants import ENV_ZENML_PROFILE_NAME
from zenml.constants import (
ENV_ZENML_PROFILE_NAME,
ZEN_SERVICE_ENTRYPOINT,
ZEN_SERVICE_IP,
)
from zenml.enums import StoreType
from zenml.logger import get_logger
from zenml.repository import Repository
Expand Down Expand Up @@ -140,8 +144,8 @@ def run(self) -> None:

try:
uvicorn.run(
"zenml.zen_service.zen_service_api:app",
host="127.0.0.1",
ZEN_SERVICE_ENTRYPOINT,
host=ZEN_SERVICE_IP,
port=self.endpoint.status.port,
log_level="info",
)
Expand Down
3 changes: 2 additions & 1 deletion src/zenml/zen_service/zen_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from zenml.config.global_config import GlobalConfiguration
from zenml.config.profile_config import ProfileConfiguration
from zenml.constants import (
ENV_ZENML_PROFILE_CONFIGURATION,
ENV_ZENML_PROFILE_NAME,
IS_EMPTY,
STACK_COMPONENTS,
Expand All @@ -32,7 +33,7 @@
from zenml.stack_stores import BaseStackStore
from zenml.stack_stores.models import StackComponentWrapper, StackWrapper

profile_configuration_json = os.environ.get("ZENML_PROFILE_CONFIGURATION")
profile_configuration_json = os.environ.get(ENV_ZENML_PROFILE_CONFIGURATION)
profile_name = os.environ.get(ENV_ZENML_PROFILE_NAME)

# Hopefully profile configuration was passed as env variable:
Expand Down
71 changes: 57 additions & 14 deletions tests/unit/stack_stores/test_stack_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
import platform
import random
import shutil
import time
from multiprocessing import Process

import pytest
import requests
import uvicorn
from requests.exceptions import ConnectionError

from zenml.config.profile_config import ProfileConfiguration
from zenml.constants import REPOSITORY_DIRECTORY_NAME
from zenml.constants import (
DEFAULT_SERVICE_START_STOP_TIMEOUT,
ENV_ZENML_PROFILE_CONFIGURATION,
REPOSITORY_DIRECTORY_NAME,
ZEN_SERVICE_ENTRYPOINT,
ZEN_SERVICE_IP,
)
from zenml.enums import StackComponentType, StoreType
from zenml.exceptions import StackComponentExistsError, StackExistsError
from zenml.logger import get_logger
from zenml.orchestrators import LocalOrchestrator
from zenml.services import ServiceState
from zenml.stack import Stack
from zenml.stack_stores import (
BaseStackStore,
Expand All @@ -33,7 +44,6 @@
SqlStackStore,
)
from zenml.stack_stores.models import StackComponentWrapper, StackWrapper
from zenml.zen_service.zen_service import ZenService, ZenServiceConfig

logger = get_logger(__name__)

Expand All @@ -55,27 +65,60 @@ def fresh_stack_store(
elif store_type == StoreType.SQL:
yield SqlStackStore().initialize(f"sqlite:///{tmp_path / 'store.db'}")
elif store_type == StoreType.REST:
port = random.randint(8003, 9000)
# create temporary stack store and profile configuration for unit tests
backing_stack_store = LocalStackStore().initialize(str(tmp_path))
store_profile = ProfileConfiguration(
name=f"test_profile_{hash(str(tmp_path))}",
store_url=backing_stack_store.url,
store_type=backing_stack_store.type,
)

zen_service = ZenService(
ZenServiceConfig(
port=port,
store_profile_configuration=store_profile,
# use environment file to pass profile into the zen service process
env_file = str(tmp_path / "environ.env")
with open(env_file, "w") as f:
f.write(
f"{ENV_ZENML_PROFILE_CONFIGURATION}='{store_profile.json()}'"
)
port = random.randint(8003, 9000)
proc = Process(
target=uvicorn.run,
args=(ZEN_SERVICE_ENTRYPOINT,),
kwargs=dict(
host=ZEN_SERVICE_IP,
port=port,
log_level="info",
env_file=env_file,
),
daemon=True,
)
zen_service.start(timeout=10)
# rest stack store can't have trailing slash on url
url = zen_service.zen_service_uri.strip("/")
url = f"http://{ZEN_SERVICE_IP}:{port}"
proc.start()

# wait 10 seconds for server to start
for t in range(DEFAULT_SERVICE_START_STOP_TIMEOUT):
try:
if requests.head(f"{url}/health").status_code == 200:
break
else:
time.sleep(1)
except ConnectionError:
time.sleep(1)
else:
proc.kill()
raise RuntimeError("Failed to start ZenService server.")

yield RestStackStore().initialize(url)
zen_service.stop(timeout=10)
assert zen_service.check_status()[0] == ServiceState.INACTIVE

# make sure there's still a server and tear down
assert proc.is_alive()
proc.kill()
# wait 10 seconds for process to be killed:
for t in range(DEFAULT_SERVICE_START_STOP_TIMEOUT):
if proc.is_alive():
time.sleep(1)
else:
break
else:
raise RuntimeError("Failed to shutdown ZenService server.")
else:
raise NotImplementedError(f"No StackStore for {store_type}")

Expand Down