diff --git a/docs/book/getting-started/core-concepts.md b/docs/book/getting-started/core-concepts.md index 569ef025f62..f0f240b3d41 100644 --- a/docs/book/getting-started/core-concepts.md +++ b/docs/book/getting-started/core-concepts.md @@ -60,7 +60,7 @@ As seen in the image, a step might use the outputs from a previous step and thus Pipelines and steps are defined in code using Python _decorators_ or _classes_. This is where the core business logic and value of your work live, and you will spend most of your time defining these two things. -Even though pipelines are simple Python functions, you are only allowed to call steps within this function. The inputs for steps called within a pipeline can either be the outputs of previous steps or alternatively, you can pass in values directly (as long as they're JSON-serializable). +Even though pipelines are simple Python functions, you are only allowed to call steps within this function. The inputs for steps called within a pipeline can either be the outputs of previous steps or alternatively, you can pass in values directly or map them onto pipeline parameters (as long as they're JSON-serializable). Similarly, you can return values from a pipeline that are step outputs as long as they are JSON-serializable. ```python from zenml import pipeline @@ -71,11 +71,11 @@ def my_pipeline(): step_2(input_one="hello", input_two=output_step_one) @pipeline -def agent_evaluation_pipeline(): +def agent_evaluation_pipeline(query: str = "What is machine learning?") -> str: """An AI agent evaluation pipeline.""" prompt = "You are a helpful assistant. Please answer: {query}" - test_query = "What is machine learning?" - evaluation_result = evaluate_agent_response(prompt, test_query) + evaluation_result = evaluate_agent_response(prompt, query) + return evaluation_result ``` Executing the Pipeline is as easy as calling the function that you decorated with the `@pipeline` decorator. @@ -83,7 +83,7 @@ Executing the Pipeline is as easy as calling the function that you decorated wit ```python if __name__ == "__main__": my_pipeline() - agent_evaluation_pipeline() + agent_evaluation_pipeline(query="What is an LLM?") ``` #### Artifacts @@ -118,9 +118,11 @@ Once you have implemented your workflow by using the concepts described above, y #### Stacks & Components -When you want to execute a pipeline run with ZenML, **Stacks** come into play. A **Stack** is a collection of **stack components**, where each component represents the respective configuration regarding a particular function in your MLOps pipeline, such as orchestration systems, artifact repositories, and model deployment platforms. +When you want to execute a pipeline run with ZenML, **Stacks** come into play. A **Stack** is a collection of **stack components**, where each component represents the respective configuration regarding a particular function in your MLOps pipeline, such as pipeline orchestration or deployment systems, artifact repositories and container registries. -For instance, if you take a close look at the default local stack of ZenML, you will see two components that are **required** in every stack in ZenML, namely an _orchestrator_ and an _artifact store_. +Pipelines can be executed in two ways: in **batch mode** (traditional execution through an orchestrator) or in **online mode** (long-running HTTP servers that can be invoked via REST API calls). 
Deploying pipelines for online mode execution allows you to serve your ML workflows as real-time endpoints, making them accessible for live inference and interactive use cases.
+
+For instance, if you take a close look at the default local stack of ZenML, you will see two components that are **required** in every stack in ZenML, namely an _orchestrator_ and an _artifact store_. Additional components like _deployers_ can be added to enable specific functionality such as deploying pipelines as HTTP endpoints.

![ZenML running code on the Local Stack.](../.gitbook/assets/02_pipeline_local_stack.png)

@@ -130,16 +132,30 @@ Keep in mind that each one of these components is built on top of base abstracti

#### Orchestrator

-An **Orchestrator** is a workhorse that coordinates all the steps to run in a pipeline. Since pipelines can be set up with complex combinations of steps with various asynchronous dependencies between them, the orchestrator acts as the component that decides what steps to run and when to run them.
+An **Orchestrator** is a workhorse that coordinates all the steps to run in a pipeline in batch mode. Since pipelines can be set up with complex combinations of steps with various asynchronous dependencies between them, the orchestrator acts as the component that decides what steps to run and when to run them.

ZenML comes with a default _local orchestrator_ designed to run on your local machine. This is useful, especially during the exploration phase of your project. You don't have to rent a cloud instance just to try out basic things.

+#### Deployer
+
+A **Deployer** is a stack component that manages the deployment of pipelines as long-running HTTP servers for online mode execution. Unlike orchestrators, which execute pipelines in batch mode, deployers create and manage persistent services that wrap your pipeline in a web application, usually containerized, allowing it to be invoked through HTTP requests.
+
+ZenML comes with a _Docker deployer_ that can run deployments on your local machine as Docker containers, making it easy to test and develop real-time pipeline endpoints before moving to production infrastructure.
+
+#### Pipeline Run
+
+A **Pipeline Run** is a record of a pipeline execution. When you run a pipeline using an orchestrator, a pipeline run is created that tracks information about the execution, such as its status and the artifacts and metadata produced by the pipeline and all of its steps. When a pipeline is deployed for online mode execution, a pipeline run is similarly created for every HTTP request made to it.
+
#### Artifact Store

An **Artifact Store** is a component that houses all data that passes through the pipeline as inputs and outputs. Each artifact that gets stored in the artifact store is tracked and versioned and this allows for extremely useful features like data caching, which speeds up your workflows.

Similar to the orchestrator, ZenML comes with a default _local artifact store_ designed to run on your local machine. This is useful, especially during the exploration phase of your project. You don't have to set up a cloud storage system to try out basic things.

+#### Deployment
+
+A **Deployment** is a running instance of a pipeline deployed as an HTTP endpoint. When you deploy a pipeline using a deployer, it becomes a long-running service that can be invoked through REST API calls. Each HTTP request to a deployment triggers a new pipeline run, creating the same artifacts and metadata tracking as traditional batch pipeline executions. This enables real-time inference, interactive ML workflows, and seamless integration with web applications and external services.
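+
+As a concrete but simplified sketch of how these pieces fit together, a pipeline can be deployed and invoked from Python roughly as follows. The snippet mirrors what the `zenml pipeline deploy` and `zenml deployment invoke` CLI commands do under the hood, so treat the exact signatures as illustrative rather than definitive:
+
+```python
+from zenml.deployers.utils import invoke_deployment
+
+# Deploy the pipeline as a long-running HTTP endpoint. This assumes that a
+# deployer (for example the Docker deployer) is part of the active stack.
+deployment = agent_evaluation_pipeline.deploy(deployment_name="agent-eval")
+
+# Every invocation of the deployment creates a regular pipeline run.
+response = invoke_deployment(
+    deployment_name_or_id="agent-eval",
+    query="What is an LLM?",
+)
+```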
+

#### Flavor

ZenML provides a dedicated base abstraction for each stack component type. These abstractions are used to develop solutions, called **Flavors**, tailored to specific use cases/tools. With ZenML installed, you get access to a variety of built-in and integrated Flavors for each component type, but users can also leverage the base abstractions to create their own custom flavors.
diff --git a/docs/book/how-to/steps-pipelines/advanced_features.md b/docs/book/how-to/steps-pipelines/advanced_features.md
index 7cabef7c259..34ace4c5b41 100644
--- a/docs/book/how-to/steps-pipelines/advanced_features.md
+++ b/docs/book/how-to/steps-pipelines/advanced_features.md
@@ -628,8 +628,8 @@ This is particularly useful for steps that interact with external services or re

Hooks allow you to execute custom code at specific points in the pipeline or step lifecycle:

```python
-def success_hook(step_name, step_output):
-    print(f"Step {step_name} completed successfully with output: {step_output}")
+def success_hook():
+    print("Step completed successfully")

def failure_hook(exception: BaseException):
    print(f"Step failed with error: {str(exception)}")

@@ -639,6 +639,11 @@ def my_step():
    return 42
```

+The following conventions apply to hooks:
+
+* the success hook takes no arguments
+* the failure hook optionally takes a single `BaseException`-typed argument
+
You can also define hooks at the pipeline level to apply to all steps:

```python
diff --git a/docs/book/toc.md b/docs/book/toc.md
index 6d0e4f9c703..58ce38bb981 100644
--- a/docs/book/toc.md
+++ b/docs/book/toc.md
@@ -55,6 +55,7 @@
* [Templates](how-to/templates/templates.md)
* [Dashboard](how-to/dashboard/dashboard-features.md)
+
## Reference

* [Community & content](reference/community-and-content.md)
diff --git a/pyproject.toml b/pyproject.toml
index 5cd2a8034bd..9a1b42b2e55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
    "distro>=1.6.0,<2.0.0",
    "docker~=7.1.0",
    "gitpython>=3.1.18,<4.0.0",
+    "jsonref",
    "packaging>=24.1",
    "psutil>=5.0.0",
    "pydantic>=2.0,<=2.11.9",
@@ -368,5 +369,6 @@ module = [
    "numba.*",
    "uvloop.*",
    "litellm",
+    "jsonref",
]
ignore_missing_imports = true
diff --git a/scripts/install-zenml-dev.sh b/scripts/install-zenml-dev.sh
index bee6fd5e6a9..5a6b5e818dc 100755
--- a/scripts/install-zenml-dev.sh
+++ b/scripts/install-zenml-dev.sh
@@ -2,6 +2,40 @@
INTEGRATIONS=no
PIP_ARGS=
+UPGRADE_ALL=no
+
+show_help() {
+    cat << EOF
+Usage: $0 [OPTIONS]
+
+Install ZenML in development mode with optional integrations.
+
+OPTIONS:
+    -i, --integrations yes|no    Install integrations (default: no)
+    -s, --system                 Install packages system-wide instead of in a virtual environment
+    -u, --upgrade-all            Uninstall existing ZenML, clear caches, and install latest versions
+    -h, --help                   Show this help message
+
+EXAMPLES:
+    # Basic installation
+    $0
+
+    # Install with integrations
+    $0 --integrations yes
+
+    # Force reinstall with latest versions of all dependencies
+    $0 --upgrade-all --integrations yes
+
+    # System-wide installation with latest versions
+    $0 --system --upgrade-all
+
+NOTES:
+    - The --upgrade-all flag will uninstall the existing ZenML installation and clear all caches
+    - This ensures you get the latest compatible versions of all dependencies
+    - Use this when you want to refresh your environment with the newest packages
+
+EOF
+}

parse_args () {
  while [ $# -gt 0 ]; do
@@ -15,8 +49,17 @@ parse_args () {
      PIP_ARGS="--system"
      shift # past argument
      ;;
+    -u|--upgrade-all)
+      UPGRADE_ALL="yes"
+      shift # past argument
+      ;;
+    -h|--help)
+      show_help
+      exit 0
+      ;;
    -*|--*)
      echo "Unknown option $1"
+      show_help
      exit 1
      ;;
    *)
@@ -26,12 +69,39 @@
  done
}
+clean_and_uninstall() {
+    echo "🧹 Cleaning existing ZenML installation and clearing caches..."
+
+    # Uninstall ZenML (if installed) and clear pip cache
+    uv pip uninstall $PIP_ARGS zenml || true
+
+    # Clear uv cache to ensure fresh downloads
+    uv cache clean || true
+
+    # Clear pip cache as well (in case pip was used previously)
+    python -m pip cache purge 2>/dev/null || true
+
+    echo "✅ Cleanup completed"
+}
+
install_zenml() {
+    echo "📦 Installing ZenML in editable mode..."
+
+    # Build upgrade arguments based on UPGRADE_ALL flag
+    upgrade_args=""
+    if [ "$UPGRADE_ALL" = "yes" ]; then
+        upgrade_args="--upgrade --force-reinstall"
+        echo "🔄 Using --upgrade --force-reinstall to get latest versions"
+    fi
+
    # install ZenML in editable mode
-    uv pip install $PIP_ARGS -e ".[server,templates,terraform,secrets-aws,secrets-gcp,secrets-azure,secrets-hashicorp,s3fs,gcsfs,adlfs,dev,connectors-aws,connectors-gcp,connectors-azure,azureml,sagemaker,vertex]"
+    uv pip install $PIP_ARGS $upgrade_args -e ".[server,templates,terraform,secrets-aws,secrets-gcp,secrets-azure,secrets-hashicorp,s3fs,gcsfs,adlfs,dev,connectors-aws,connectors-gcp,connectors-azure,azureml,sagemaker,vertex]"
+
+    echo "✅ ZenML installation completed"
}

install_integrations() {
+    echo "🔌 Installing ZenML integrations..."
# figure out the python version
python_version=$(python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))")
@@ -54,18 +124,37 @@ install_integrations() {
        --output-file integration-requirements.txt \
        $ignore_integrations_args
-    # pin pyyaml>=6.0.1
-    echo "" >> integration-requirements.txt
-    echo "pyyaml>=6.0.1" >> integration-requirements.txt
-    echo "pyopenssl" >> integration-requirements.txt
-    echo "typing-extensions" >> integration-requirements.txt
+    # Handle package pins based on upgrade mode
+    if [ "$UPGRADE_ALL" = "yes" ]; then
+        echo "🔄 Using latest versions for integration dependencies"
+        # When upgrading, use minimum versions to allow latest compatible
+        echo "" >> integration-requirements.txt
+        echo "pyyaml>=6.0.1" >> integration-requirements.txt
+        echo "pyopenssl" >> integration-requirements.txt
+        echo "typing-extensions" >> integration-requirements.txt
+        echo "maison<2" >> integration-requirements.txt
+    else
+        # Original behavior with specific pins
+        echo "" >> integration-requirements.txt
+        echo "pyyaml>=6.0.1" >> integration-requirements.txt
+        echo "pyopenssl" >> integration-requirements.txt
+        echo "typing-extensions" >> integration-requirements.txt
+        echo "maison<2" >> integration-requirements.txt
+    fi
+
    echo "-e .[server,templates,terraform,secrets-aws,secrets-gcp,secrets-azure,secrets-hashicorp,s3fs,gcsfs,adlfs,dev,connectors-aws,connectors-gcp,connectors-azure,azureml,sagemaker,vertex]" >> integration-requirements.txt
-    # workaround to make yamlfix work
-    echo "maison<2" >> integration-requirements.txt
+    # Build upgrade arguments based on UPGRADE_ALL flag
+    upgrade_args=""
+    if [ "$UPGRADE_ALL" = "yes" ]; then
+        upgrade_args="--upgrade --force-reinstall"
+        echo "🔄 Using --upgrade --force-reinstall for integration dependencies"
+    fi
-    uv pip install $PIP_ARGS -r integration-requirements.txt
+    uv pip install $PIP_ARGS $upgrade_args -r integration-requirements.txt
    rm integration-requirements.txt
+
+    echo "✅ Integration installation completed"

    # https://github.com/Kludex/python-multipart/pull/166
    # There is an install conflict between multipart and python_multipart
@@ -83,7 +172,14 @@
export ZENML_ANALYTICS_OPT_IN=false

parse_args "$@"

-python -m pip install --upgrade wheel pip uv
+# Clean and upgrade tooling packages if upgrading all
+if [ "$UPGRADE_ALL" = "yes" ]; then
+    echo "🚀 Upgrading all dependencies to latest versions..."
+    clean_and_uninstall
+    python -m pip install --upgrade --force-reinstall wheel pip uv
+else
+    python -m pip install --upgrade wheel pip uv
+fi

install_zenml
diff --git a/src/zenml/analytics/enums.py b/src/zenml/analytics/enums.py
index e364c61f88c..c93bbb55929 100644
--- a/src/zenml/analytics/enums.py
+++ b/src/zenml/analytics/enums.py
@@ -91,3 +91,9 @@ class AnalyticsEvent(str, Enum):

    # Server Settings
    SERVER_SETTINGS_UPDATED = "Server Settings Updated"
+
+    # Deployment
+    DEPLOY_PIPELINE = "Pipeline deployed"
+    CREATE_DEPLOYMENT = "Deployment created"
+    STOP_DEPLOYMENT = "Deployment stopped"
+    DELETE_DEPLOYMENT = "Deployment deleted"
diff --git a/src/zenml/artifact_stores/base_artifact_store.py b/src/zenml/artifact_stores/base_artifact_store.py
index 3ca322d99d4..0efd0e8493e 100644
--- a/src/zenml/artifact_stores/base_artifact_store.py
+++ b/src/zenml/artifact_stores/base_artifact_store.py
@@ -106,6 +106,10 @@ def _validate_path(self, path: str) -> None:
            IllegalOperationError: If the path is a local file and the server
                is not configured to allow local file access.
""" + # Skip validation for memory:// URIs used in serving mode + if path.startswith("memory://"): + return + if not self.allow_local_file_access and not io_utils.is_remote(path): raise IllegalOperationError( "Files in a local artifact store cannot be accessed from the " @@ -139,6 +143,11 @@ def _sanitize_potential_path(self, potential_path: Any) -> Any: # Neither string nor bytes, this is not a path return potential_path + # Preserve special in-memory scheme used by serving mode as-is + # to avoid treating it as a local filesystem path. + if isinstance(path, str) and path.startswith("memory://"): + return path + if io_utils.is_remote(path): # If we have a remote path, replace windows path separators with # slashes diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index a96bb800d24..6d8d91bb9ad 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -152,7 +152,21 @@ def _store_artifact_data_and_prepare_request( Artifact version request for the artifact data that was stored. """ artifact_store = Client().active_stack.artifact_store - artifact_store.makedirs(uri) + + # Detect in-memory materializer to avoid touching the artifact store. + # Local import to minimize import-time dependencies. + from zenml.materializers.in_memory_materializer import ( + InMemoryMaterializer, + ) + + is_in_memory = issubclass(materializer_class, InMemoryMaterializer) + + if not is_in_memory: + artifact_store.makedirs(uri) + else: + # Ensure URI clearly indicates in-memory storage and not the artifact store + if not uri.startswith("memory://"): + uri = f"memory://custom_artifacts/{name}/{uuid4()}" materializer = materializer_class(uri=uri, artifact_store=artifact_store) materializer.uri = materializer.uri.replace("\\", "/") @@ -190,7 +204,7 @@ def _store_artifact_data_and_prepare_request( data_type=source_utils.resolve(data_type), content_hash=content_hash, project=Client().active_project.id, - artifact_store_id=artifact_store.id, + artifact_store_id=None if is_in_memory else artifact_store.id, visualizations=visualizations, has_custom_name=has_custom_name, save_type=save_type, diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index 03e0573d929..138f64c915f 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -2506,6 +2506,7 @@ def my_pipeline(...): from zenml.cli.base import * # noqa from zenml.cli.code_repository import * # noqa from zenml.cli.config import * # noqa +from zenml.cli.deployment import * # noqa from zenml.cli.downgrade import * # noqa from zenml.cli.feature import * # noqa from zenml.cli.integration import * # noqa diff --git a/src/zenml/cli/deployment.py b/src/zenml/cli/deployment.py new file mode 100644 index 00000000000..aaa54677e4b --- /dev/null +++ b/src/zenml/cli/deployment.py @@ -0,0 +1,701 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""CLI functionality to interact with deployments.""" + +import json +from typing import Any, List, Optional +from uuid import UUID + +import click + +from zenml.cli import utils as cli_utils +from zenml.cli.cli import TagGroup, cli +from zenml.cli.utils import fetch_snapshot, list_options +from zenml.client import Client +from zenml.console import console +from zenml.deployers.exceptions import DeploymentInvalidParametersError +from zenml.enums import CliCategories +from zenml.logger import get_logger +from zenml.models import ( + DeploymentFilter, +) +from zenml.pipelines.pipeline_definition import Pipeline +from zenml.utils import source_utils + +logger = get_logger(__name__) + + +def _import_pipeline(source: str) -> Pipeline: + """Import a pipeline. + + Args: + source: The pipeline source. + + Returns: + The pipeline. + """ + try: + pipeline_instance = source_utils.load(source) + except ModuleNotFoundError as e: + source_root = source_utils.get_source_root() + cli_utils.error( + f"Unable to import module `{e.name}`. Make sure the source path is " + f"relative to your source root `{source_root}`." + ) + except AttributeError as e: + cli_utils.error("Unable to load attribute from module: " + str(e)) + + if not isinstance(pipeline_instance, Pipeline): + cli_utils.error( + f"The given source path `{source}` does not resolve to a pipeline " + "object." + ) + + return pipeline_instance + + +@cli.group(cls=TagGroup, tag=CliCategories.MANAGEMENT_TOOLS) +def deployment() -> None: + """Interact with deployments.""" + + +@deployment.command("list", help="List all registered deployments.") +@list_options(DeploymentFilter) +def list_deployments(**kwargs: Any) -> None: + """List all registered deployments for the filter. + + Args: + **kwargs: Keyword arguments to filter deployments. + """ + client = Client() + try: + with console.status("Listing deployments...\n"): + deployments = client.list_deployments(**kwargs) + except KeyError as err: + cli_utils.error(str(err)) + else: + if not deployments.items: + cli_utils.declare("No deployments found for this filter.") + return + + cli_utils.print_deployment_table(deployments=deployments.items) + cli_utils.print_page_info(deployments) + + +@deployment.command("describe") +@click.argument("deployment_name_or_id", type=str, required=True) +@click.option( + "--show-secret", + "-s", + is_flag=True, + help="Show the secret key.", +) +@click.option( + "--show-metadata", + "-m", + is_flag=True, + help="Show the metadata.", +) +@click.option( + "--show-schema", + "-sc", + is_flag=True, + help="Show the schema.", +) +@click.option( + "--no-truncate", + "-nt", + is_flag=True, + help="Don't truncate the metadata.", +) +def describe_deployment( + deployment_name_or_id: str, + show_secret: bool = False, + show_metadata: bool = False, + no_truncate: bool = False, + show_schema: bool = False, +) -> None: + """Describe a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to describe. + show_secret: If True, show the secret key. + show_metadata: If True, show the metadata. + show_schema: If True, show the schema. + no_truncate: If True, don't truncate the metadata. 
+ """ + try: + deployment = Client().get_deployment( + name_id_or_prefix=deployment_name_or_id, + ) + except KeyError as e: + cli_utils.error(str(e)) + else: + cli_utils.pretty_print_deployment( + deployment, + show_secret=show_secret, + show_metadata=show_metadata, + show_schema=show_schema, + no_truncate=no_truncate, + ) + + +@deployment.command("provision") +@click.argument("deployment_name_or_id", type=str, required=True) +@click.option( + "--snapshot", + "-s", + "snapshot_name_or_id", + type=str, + required=False, + help="ID or name of the snapshot to use. If not provided, the current " + "snapshot configured for the deployment will be used.", +) +@click.option( + "--pipeline", + "-p", + "pipeline_name_or_id", + type=str, + required=False, + help="The name or ID of the pipeline to which the snapshot belongs.", +) +@click.option( + "--overtake", + "-o", + "overtake", + is_flag=True, + default=False, + required=False, + help="Provision the deployment with the given name even if it is " + "owned by a different user.", +) +@click.option( + "--timeout", + "-t", + "timeout", + type=int, + required=False, + default=None, + help="Maximum time in seconds to wait for the deployment to be " + "provisioned.", +) +def provision_deployment( + deployment_name_or_id: str, + snapshot_name_or_id: Optional[str] = None, + pipeline_name_or_id: Optional[str] = None, + overtake: bool = False, + timeout: Optional[int] = None, +) -> None: + """Provision a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to deploy. + snapshot_name_or_id: The ID or name of the pipeline snapshot to use. + pipeline_name_or_id: The name or ID of the pipeline to which the + snapshot belongs. + overtake: If True, provision the deployment with the given name + even if it is owned by a different user. + timeout: The maximum time in seconds to wait for the deployment + to be provisioned. + """ + snapshot_id: Optional[UUID] = None + if snapshot_name_or_id: + snapshot = fetch_snapshot(snapshot_name_or_id, pipeline_name_or_id) + snapshot_id = snapshot.id + + client = Client() + try: + deployment = client.get_deployment(deployment_name_or_id) + except KeyError: + pass + else: + if ( + deployment.user + and deployment.user.id != client.active_user.id + and not overtake + ): + confirmation = cli_utils.confirmation( + f"Deployment with name '{deployment.name}' is owned by a " + f"different user '{deployment.user.name}'.\nDo you want to " + "continue and provision it " + "(hint: use the --overtake flag to skip this check)?" + ) + if not confirmation: + cli_utils.declare("Deployment provisioning canceled.") + return + + with console.status( + f"Provisioning deployment '{deployment_name_or_id}'...\n" + ): + try: + deployment = Client().provision_deployment( + name_id_or_prefix=deployment_name_or_id, + snapshot_id=snapshot_id, + timeout=timeout, + ) + except KeyError as e: + cli_utils.error(str(e)) + else: + cli_utils.declare( + f"Provisioned deployment '{deployment_name_or_id}'." 
+ ) + cli_utils.pretty_print_deployment(deployment, show_secret=True) + + +@deployment.command("deprovision") +@click.argument("deployment_name_or_id", type=str, required=False) +@click.option( + "--all", + "-a", + is_flag=True, + default=False, + help="Deprovision all deployments.", +) +@click.option( + "--mine", + "-m", + is_flag=True, + default=False, + help="Deprovision only deployments owned by the current user.", +) +@click.option( + "--ignore-errors", + "-i", + is_flag=True, + default=False, + help="Ignore errors when deprovisioning multiple deployments.", +) +@click.option( + "--yes", + "-y", + is_flag=True, + default=False, + help="Don't ask for confirmation.", +) +@click.option( + "--max-count", + "-c", + "max_count", + type=int, + required=False, + default=10, + help="Maximum number of deployments to deprovision in one go.", +) +@click.option( + "--timeout", + "-t", + "timeout", + type=int, + required=False, + default=None, + help="Maximum time in seconds to wait for the deployment to be " + "deprovisioned.", +) +def deprovision_deployment( + deployment_name_or_id: Optional[str] = None, + all: bool = False, + mine: bool = False, + yes: bool = False, + ignore_errors: bool = False, + max_count: int = 10, + timeout: Optional[int] = None, +) -> None: + """Deprovision a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to deprovision. + all: If set, deprovision all deployments. + mine: If set, deprovision only deployments owned by the current user. + yes: If set, don't ask for confirmation. + ignore_errors: If set, ignore errors when deprovisioning multiple + deployments. + max_count: The maximum number of deployments to deprovision in one go. + timeout: The maximum time in seconds to wait for the deployment + to be deprovisioned. + """ + client = Client() + + if all: + deployments = client.list_deployments(size=max_count).items + elif mine: + deployments = client.list_deployments( + user=client.active_user.id, + size=max_count, + ).items + elif deployment_name_or_id: + deployments = [ + client.get_deployment(name_id_or_prefix=deployment_name_or_id) + ] + else: + cli_utils.error( + "Either a deployment name or ID must be provided or --all or " + "--mine must be set." + ) + + if len(deployments) == 0: + cli_utils.error("No deployments found to deprovision.") + + if len(deployments) == 1: + cli_utils.declare( + f"The following deployment will be deprovisioned: " + f"{deployments[0].name}" + ) + else: + deployment_names = [deployment.name for deployment in deployments] + cli_utils.declare( + f"The following deployments will ALL be deprovisioned: " + f"{', '.join(deployment_names)}" + ) + + other_deployments = [ + deployment + for deployment in deployments + if deployment.user and deployment.user.id != client.active_user.id + ] + + if other_deployments: + deployment_names = [ + deployment.name for deployment in other_deployments + ] + cli_utils.warning( + f"The following deployments are not owned by the current " + f"user: {', '.join(deployment_names)} !" + ) + + if not yes: + confirmation = cli_utils.confirmation( + "Are you sure you want to continue?" + ) + if not confirmation: + cli_utils.declare("Deployment deprovision canceled.") + return + + for deployment in deployments: + with console.status( + f"Deprovisioning deployment '{deployment.name}'...\n" + ): + try: + client.deprovision_deployment( + name_id_or_prefix=deployment.id, + timeout=timeout, + ) + cli_utils.declare( + f"Deprovisioned deployment '{deployment.name}'." 
+ ) + cli_utils.declare( + "Hint: to permanently delete the deployment, run `zenml " + f"deployments delete {deployment.name}`." + ) + except KeyError as e: + error_message = ( + f"Failed to deprovision deployment '{deployment.name}': " + f"{str(e)}" + ) + if all or mine and ignore_errors: + cli_utils.warning(error_message) + else: + cli_utils.error(error_message) + + +@deployment.command("delete") +@click.argument("deployment_name_or_id", type=str, required=False) +@click.option( + "--all", + "-a", + is_flag=True, + default=False, + help="Deprovision all deployments.", +) +@click.option( + "--mine", + "-m", + is_flag=True, + default=False, + help="Deprovision only deployments owned by the current user.", +) +@click.option( + "--ignore-errors", + "-i", + is_flag=True, + default=False, + help="Ignore errors when deprovisioning multiple deployments.", +) +@click.option( + "--yes", + "-y", + is_flag=True, + default=False, + help="Don't ask for confirmation.", +) +@click.option( + "--timeout", + "-t", + "timeout", + type=int, + required=False, + default=None, + help="Maximum time in seconds to wait for the deployment to be " + "deprovisioned.", +) +@click.option( + "--max-count", + "-c", + "max_count", + type=int, + required=False, + default=20, + help="Maximum number of deployments to delete in one go.", +) +@click.option( + "--force", + "-f", + is_flag=True, + default=False, + help="Force the deletion of the deployment if it cannot be deprovisioned.", +) +def delete_deployment( + deployment_name_or_id: Optional[str] = None, + all: bool = False, + mine: bool = False, + ignore_errors: bool = False, + yes: bool = False, + timeout: Optional[int] = None, + max_count: int = 20, + force: bool = False, +) -> None: + """Deprovision and delete a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to delete. + all: If set, delete all deployments. + mine: If set, delete only deployments owned by the current user. + ignore_errors: If set, ignore errors when deleting multiple deployments. + yes: If set, don't ask for confirmation. + timeout: The maximum time in seconds to wait for the deployment + to be deprovisioned. + max_count: The maximum number of deployments to delete in one go. + force: If set, force the deletion of the deployment if it cannot + be deprovisioned. + """ + client = Client() + + if all: + deployments = client.list_deployments(size=max_count).items + elif mine: + deployments = client.list_deployments( + user=client.active_user.id, + size=max_count, + ).items + elif deployment_name_or_id: + deployments = [ + client.get_deployment(name_id_or_prefix=deployment_name_or_id) + ] + else: + cli_utils.error( + "Either a deployment name or ID must be provided or --all or " + "--mine must be set." + ) + + if len(deployments) == 0: + cli_utils.error("No deployments found to delete.") + + if len(deployments) == 1: + cli_utils.declare( + f"The following deployment will be deleted: {deployments[0].name}" + ) + else: + deployment_names = [deployment.name for deployment in deployments] + cli_utils.declare( + f"The following deployments will ALL be deleted: " + f"{', '.join(deployment_names)}" + ) + + other_deployments = [ + deployment + for deployment in deployments + if deployment.user and deployment.user.id != client.active_user.id + ] + + if other_deployments: + deployment_names = [ + deployment.name for deployment in other_deployments + ] + cli_utils.warning( + f"The following deployments are not owned by the current " + f"user: {', '.join(deployment_names)} !" 
+ ) + + if not yes: + confirmation = cli_utils.confirmation( + "Are you sure you want to continue?" + ) + if not confirmation: + cli_utils.declare("Deployment deletion canceled.") + return + + for deployment in deployments: + with console.status(f"Deleting deployment '{deployment.name}'...\n"): + try: + Client().delete_deployment( + name_id_or_prefix=deployment.id, + force=force, + timeout=timeout, + ) + cli_utils.declare(f"Deleted deployment '{deployment.name}'.") + except KeyError as e: + error_message = ( + f"Failed to delete deployment '{deployment.name}': " + f"{str(e)}" + ) + if all or mine and ignore_errors: + cli_utils.warning(error_message) + else: + cli_utils.error(error_message) + + +@deployment.command("refresh") +@click.argument("deployment_name_or_id", type=str, required=True) +def refresh_deployment( + deployment_name_or_id: str, +) -> None: + """Refresh the status of a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to refresh. + """ + try: + deployment = Client().refresh_deployment( + name_id_or_prefix=deployment_name_or_id + ) + + except KeyError as e: + cli_utils.error(str(e)) + else: + cli_utils.pretty_print_deployment(deployment, show_secret=True) + + +@deployment.command( + "invoke", context_settings={"ignore_unknown_options": True} +) +@click.argument("deployment_name_or_id", type=str, required=True) +@click.option( + "--timeout", + "-t", + "timeout", + type=int, + required=False, + default=None, + help="Maximum time in seconds to wait for the deployment to be invoked.", +) +@click.argument("args", nargs=-1, type=click.UNPROCESSED) +def invoke_deployment( + deployment_name_or_id: str, + args: List[str], + timeout: Optional[int] = None, +) -> None: + """Call a deployment with arguments. + + Args: + deployment_name_or_id: The name or ID of the deployment to call. + args: The arguments to pass to the deployment call. + timeout: The maximum time in seconds to wait for the deployment + to be invoked. + """ + from zenml.deployers.utils import invoke_deployment + + # Parse the given args + args = list(args) + args.append(deployment_name_or_id) + + name_or_id, parsed_args = cli_utils.parse_name_and_extra_arguments( + args, + expand_args=True, + name_mandatory=True, + ) + assert name_or_id is not None + + try: + response = invoke_deployment( + deployment_name_or_id=name_or_id, + timeout=timeout or 300, # 5 minute timeout + project=None, + **parsed_args, + ) + except DeploymentInvalidParametersError as e: + cli_utils.error( + f"Invalid parameters for deployment '{name_or_id}': \n" + f"{str(e)}\n\n" + f"Hint: run 'zenml deployment describe --schema {name_or_id}' " + "to inspect the deployment schema." + ) + except KeyError as e: + cli_utils.error( + str(e) + + "\n" + + f"Hint: run [green]`zenml deployment logs {name_or_id}`[/green] " + "to inspect the deployment logs." + ) + else: + cli_utils.declare(f"Invoked deployment '{name_or_id}' with response:") + print(json.dumps(response, indent=2)) + if isinstance(response, dict) and not response.get("success", True): + cli_utils.declare( + f"Hint: run [green]`zenml deployment logs {name_or_id}`[/green] " + "to inspect the deployment logs." 
+ ) + + +@deployment.command("logs") +@click.argument("deployment_name_or_id", type=str, required=True) +@click.option( + "--follow", + "-f", + is_flag=True, + default=False, + help="Follow the logs.", +) +@click.option( + "--tail", + "-t", + type=int, + default=None, + help="The number of lines to show from the end of the logs.", +) +def log_deployment( + deployment_name_or_id: str, + follow: bool = False, + tail: Optional[int] = None, +) -> None: + """Get the logs of a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to get the logs of. + follow: If True, follow the logs. + tail: The number of lines to show from the end of the logs. If None, + show all logs. + """ + try: + logs = Client().get_deployment_logs( + name_id_or_prefix=deployment_name_or_id, + follow=follow, + tail=tail, + ) + except KeyError as e: + cli_utils.error(str(e)) + else: + with console.status( + f"Streaming logs for deployment '{deployment_name_or_id}'...\n" + ): + for log in logs: + print(log) diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index 0d43c27c7e1..dc4253987ce 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -16,15 +16,15 @@ import json import os from typing import Any, Dict, List, Optional, Union -from uuid import UUID import click from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options +from zenml.cli.utils import fetch_snapshot, list_options from zenml.client import Client from zenml.console import console +from zenml.deployers.base_deployer import BaseDeployer from zenml.enums import CliCategories from zenml.logger import get_logger from zenml.models import ( @@ -301,6 +301,201 @@ def run_pipeline( pipeline_instance() +@pipeline.command( + "deploy", + help="Deploy a pipeline. The SOURCE argument needs to be an " + "importable source path resolving to a ZenML pipeline instance, e.g. " + "`my_module.my_pipeline_instance`.", +) +@click.argument("source") +@click.option( + "--name", + "-n", + "deployment_name", + type=str, + required=False, + help="The name of the deployment resulted from deploying the pipeline. If " + "not provided, the name of the pipeline will be used. 
If an existing " + "deployment with the same name already exists, an error will be raised, " + "unless the --update or --overtake flag is used.", +) +@click.option( + "--config", + "-c", + "config_path", + type=click.Path(exists=True, dir_okay=False), + required=False, + help="Path to configuration file for the deployment.", +) +@click.option( + "--stack", + "-s", + "stack_name_or_id", + type=str, + required=False, + help="Name or ID of the stack to deploy on.", +) +@click.option( + "--build", + "-b", + "build_path_or_id", + type=str, + required=False, + help="ID or path of the build to use.", +) +@click.option( + "--prevent-build-reuse", + is_flag=True, + default=False, + required=False, + help="Prevent automatic build reusing.", +) +@click.option( + "--update", + "-u", + "update", + is_flag=True, + default=False, + required=False, + help="Update the deployment with the same name if it already exists.", +) +@click.option( + "--overtake", + "-o", + "overtake", + is_flag=True, + default=False, + required=False, + help="Update the deployment with the same name if it already " + "exists, even if it is owned by a different user.", +) +@click.option( + "--attach", + "-a", + "attach", + is_flag=True, + default=False, + required=False, + help="Attach to the deployment logs.", +) +@click.option( + "--timeout", + "-t", + "timeout", + type=int, + required=False, + default=None, + help="Maximum time in seconds to wait for the pipeline to be deployed.", +) +def deploy_pipeline( + source: str, + deployment_name: Optional[str] = None, + config_path: Optional[str] = None, + stack_name_or_id: Optional[str] = None, + build_path_or_id: Optional[str] = None, + prevent_build_reuse: bool = False, + update: bool = False, + overtake: bool = False, + attach: bool = False, + timeout: Optional[int] = None, +) -> None: + """Deploy a pipeline for online inference. + + Args: + source: Importable source resolving to a pipeline instance. + deployment_name: Name of the deployment used to deploy the pipeline on. + config_path: Path to pipeline configuration file. + stack_name_or_id: Name or ID of the stack on which the pipeline should + be deployed. + build_path_or_id: ID of file path of the build to use for the pipeline + deployment. + prevent_build_reuse: If True, prevents automatic reusing of previous + builds. + update: If True, update the deployment with the same name if it + already exists. + overtake: If True, update the deployment with the same name if + it already exists, even if it is owned by a different user. + attach: If True, attach to the deployment logs. + timeout: The maximum time in seconds to wait for the pipeline to be + deployed. + """ + if not Client().root: + cli_utils.warning( + "You're running the `zenml pipeline deploy` command without a " + "ZenML repository. Your current working directory will be used " + "as the source root relative to which the registered step classes " + "will be resolved. To silence this warning, run `zenml init` at " + "your source code root." + ) + + with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + + build: Union[str, PipelineBuildBase, None] = None + if build_path_or_id: + if uuid_utils.is_valid_uuid(build_path_or_id): + build = build_path_or_id + elif os.path.exists(build_path_or_id): + build = PipelineBuildBase.from_yaml(build_path_or_id) + else: + cli_utils.error( + f"The specified build {build_path_or_id} is not a valid UUID " + "or file path." 
+ ) + + pipeline_instance = pipeline_instance.with_options( + config_path=config_path, + build=build, + prevent_build_reuse=prevent_build_reuse, + ) + if not deployment_name: + deployment_name = pipeline_instance.name + client = Client() + try: + deployment = client.get_deployment(deployment_name) + except KeyError: + pass + else: + if ( + deployment.user + and deployment.user.id != client.active_user.id + and not overtake + ): + confirmation = cli_utils.confirmation( + f"Deployment with name '{deployment_name}' already exists " + f"and is owned by a different user '{deployment.user.name}'." + "\nDo you want to continue and update the existing deployment " + "(hint: use the --overtake flag to skip this check) ?" + ) + if not confirmation: + cli_utils.declare("Deployment canceled.") + return + elif not update and not overtake: + confirmation = cli_utils.confirmation( + f"Deployment with name '{deployment_name}' already exists.\n" + "Do you want to continue and update the existing " + "deployment " + "(hint: use the --update flag to skip this check) ?" + ) + if not confirmation: + cli_utils.declare("Deployment canceled.") + return + + deployment = pipeline_instance.deploy( + deployment_name=deployment_name, timeout=timeout + ) + + cli_utils.pretty_print_deployment(deployment, show_secret=False) + + if attach: + deployer = BaseDeployer.get_active_deployer() + for log in deployer.get_deployment_logs( + deployment_name_or_id=deployment.id, + follow=True, + ): + print(log) + + @pipeline.command( "create-run-template", help="Create a run template for a pipeline. The SOURCE argument needs to " @@ -877,50 +1072,10 @@ def run_snapshot( pipeline_name_or_id: The name or ID of the pipeline. config_path: Path to configuration file for the run. """ - if uuid_utils.is_valid_uuid(snapshot_name_or_id): - snapshot_id = UUID(snapshot_name_or_id) - elif pipeline_name_or_id: - try: - snapshot_id = ( - Client() - .get_snapshot( - snapshot_name_or_id, - pipeline_name_or_id=pipeline_name_or_id, - ) - .id - ) - except KeyError: - cli_utils.error( - f"There are no snapshots with name `{snapshot_name_or_id}` for " - f"pipeline `{pipeline_name_or_id}`." - ) - else: - snapshots = Client().list_snapshots( - name=snapshot_name_or_id, - ) - if snapshots.total == 0: - cli_utils.error( - f"There are no snapshots with name `{snapshot_name_or_id}`." - ) - elif snapshots.total == 1: - snapshot_id = snapshots.items[0].id - else: - snapshot_index = cli_utils.multi_choice_prompt( - object_type="snapshots", - choices=[ - [snapshot.pipeline.name, snapshot.name] - for snapshot in snapshots.items - ], - headers=["Pipeline", "Snapshot"], - prompt_text=f"There are multiple snapshots with name " - f"`{snapshot_name_or_id}`. 
Please select the snapshot to run", - ) - assert snapshot_index is not None - snapshot_id = snapshots.items[snapshot_index].id - + snapshot = fetch_snapshot(snapshot_name_or_id, pipeline_name_or_id) try: run = Client().trigger_pipeline( - snapshot_name_or_id=snapshot_id, + snapshot_name_or_id=snapshot.id, config_path=config_path, ) cli_utils.declare(f"Started snapshot run `{run.id}`.") @@ -928,6 +1083,142 @@ def run_snapshot( cli_utils.error(f"Failed to run snapshot: {e}") +@snapshot.command("deploy", help="Deploy a snapshot.") +@click.argument("snapshot_name_or_id") +@click.option( + "--pipeline", + "-p", + "pipeline_name_or_id", + type=str, + required=False, + help="The name or ID of the pipeline.", +) +@click.option( + "--deployment", + "-d", + "deployment_name_or_id", + type=str, + required=False, + help="The name or ID of the deployment to use for the pipeline. If " + "not provided, the name of the snapshot or pipeline will be used. If an " + "existing deployment with the same name already exists, an error will be " + "raised, unless the --update or --overtake flag is used.", +) +@click.option( + "--update", + "-u", + "update", + is_flag=True, + default=False, + required=False, + help="Update the deployment with the same name if it already exists.", +) +@click.option( + "--overtake", + "-o", + "overtake", + is_flag=True, + default=False, + required=False, + help="Update the deployment with the same name if it already " + "exists, even if it is owned by a different user.", +) +@click.option( + "--timeout", + "-t", + "timeout", + type=int, + required=False, + default=None, + help="Maximum time in seconds to wait for the snapshot to be deployed.", +) +def deploy_snapshot( + snapshot_name_or_id: str, + pipeline_name_or_id: Optional[str] = None, + deployment_name_or_id: Optional[str] = None, + update: bool = False, + overtake: bool = False, + timeout: Optional[int] = None, +) -> None: + """Deploy a pipeline for online inference. + + Args: + snapshot_name_or_id: The name or ID of the snapshot to deploy. + pipeline_name_or_id: The name or ID of the pipeline. + deployment_name_or_id: Name or ID of the deployment to use for the + pipeline. + update: If True, update the deployment with the same name if it + already exists. + overtake: If True, update the deployment with the same name if + it already exists, even if it is owned by a different user. + timeout: The maximum time in seconds to wait for the pipeline to be + deployed. + """ + snapshot = fetch_snapshot(snapshot_name_or_id, pipeline_name_or_id) + + if not deployment_name_or_id: + deployment_name_or_id = snapshot.name or snapshot.pipeline.name + + if not deployment_name_or_id: + cli_utils.error( + "No deployment name or ID provided. Please provide a deployment name or ID." + ) + + client = Client() + try: + deployment = client.get_deployment(deployment_name_or_id) + except KeyError: + pass + else: + if ( + deployment.user + and deployment.user.id != client.active_user.id + and not overtake + ): + confirmation = cli_utils.confirmation( + f"Deployment with name or ID '{deployment_name_or_id}' is " + f"owned by a different user '{deployment.user.name}'.\nDo you " + "want to continue and provision it " + "(hint: use the --overtake flag to skip this check)?" 
+ ) + if not confirmation: + cli_utils.declare("Deployment provisioning canceled.") + return + + elif ( + not update + and not overtake + and not uuid_utils.is_valid_uuid(deployment_name_or_id) + ): + confirmation = cli_utils.confirmation( + f"Deployment with name or ID '{deployment_name_or_id}' already " + "exists.\n" + "Do you want to continue and update the existing " + "deployment " + "(hint: use the --update flag to skip this check) ?" + ) + if not confirmation: + cli_utils.declare("Deployment canceled.") + return + + with console.status( + f"Provisioning deployment '{deployment_name_or_id}'...\n" + ): + try: + deployment = Client().provision_deployment( + name_id_or_prefix=deployment_name_or_id, + snapshot_id=snapshot.id, + timeout=timeout, + ) + except KeyError as e: + cli_utils.error(str(e)) + else: + cli_utils.declare( + f"Provisioned deployment '{deployment_name_or_id}'." + ) + cli_utils.pretty_print_deployment(deployment, show_secret=True) + + @snapshot.command("list", help="List pipeline snapshots.") @list_options(PipelineSnapshotFilter) def list_pipeline_snapshots(**kwargs: Any) -> None: diff --git a/src/zenml/cli/stack.py b/src/zenml/cli/stack.py index 12d29b95fdf..87574c68a39 100644 --- a/src/zenml/cli/stack.py +++ b/src/zenml/cli/stack.py @@ -196,6 +196,14 @@ def stack() -> None: type=str, required=False, ) +@click.option( + "-D", + "--deployer", + "deployer", + help="Name of the deployer for this stack.", + type=str, + required=False, +) @click.option( "--set", "set_stack", @@ -248,6 +256,7 @@ def register_stack( annotator: Optional[str] = None, data_validator: Optional[str] = None, image_builder: Optional[str] = None, + deployer: Optional[str] = None, set_stack: bool = False, provider: Optional[str] = None, connector: Optional[str] = None, @@ -270,6 +279,7 @@ def register_stack( annotator: Name of the annotator for this stack. data_validator: Name of the data validator for this stack. image_builder: Name of the new image builder for this stack. + deployer: Name of the deployer for this stack. set_stack: Immediately set this stack as active. provider: Name of the cloud provider for this stack. connector: Name of the service connector for this stack. @@ -519,6 +529,7 @@ def register_stack( (StackComponentType.STEP_OPERATOR, step_operator), (StackComponentType.EXPERIMENT_TRACKER, experiment_tracker), (StackComponentType.CONTAINER_REGISTRY, container_registry), + (StackComponentType.DEPLOYER, deployer), ]: if component_name_ and component_type_ not in components: components[component_type_] = [ @@ -688,6 +699,14 @@ def register_stack( type=str, required=False, ) +@click.option( + "-D", + "--deployer", + "deployer", + help="Name of the deployer for this stack.", + type=str, + required=False, +) @click.option( "--secret", "secrets", @@ -728,6 +747,7 @@ def update_stack( data_validator: Optional[str] = None, image_builder: Optional[str] = None, model_registry: Optional[str] = None, + deployer: Optional[str] = None, secrets: List[str] = [], remove_secrets: List[str] = [], environment_variables: List[str] = [], @@ -749,6 +769,7 @@ def update_stack( data_validator: Name of the new data validator for this stack. image_builder: Name of the new image builder for this stack. model_registry: Name of the new model registry for this stack. + deployer: Name of the new deployer for this stack. secrets: Secrets to attach to the stack. remove_secrets: Secrets to remove from the stack. 
environment_variables: Environment variables to set when running on this @@ -793,6 +814,8 @@ def update_stack( updates[StackComponentType.ORCHESTRATOR] = [orchestrator] if step_operator: updates[StackComponentType.STEP_OPERATOR] = [step_operator] + if deployer: + updates[StackComponentType.DEPLOYER] = [deployer] try: updated_stack = client.update_stack( @@ -898,6 +921,14 @@ def update_stack( is_flag=True, required=False, ) +@click.option( + "-D", + "--deployer", + "deployer_flag", + help="Include this to remove the deployer from this stack.", + is_flag=True, + required=False, +) def remove_stack_component( stack_name_or_id: Optional[str] = None, container_registry_flag: Optional[bool] = False, @@ -910,6 +941,7 @@ def remove_stack_component( data_validator_flag: Optional[bool] = False, image_builder_flag: Optional[bool] = False, model_registry_flag: Optional[str] = None, + deployer_flag: Optional[bool] = False, ) -> None: """Remove stack components from a stack. @@ -927,6 +959,7 @@ def remove_stack_component( data_validator_flag: To remove the data validator from this stack. image_builder_flag: To remove the image builder from this stack. model_registry_flag: To remove the model registry from this stack. + deployer_flag: To remove the deployer from this stack. """ client = Client() @@ -963,6 +996,9 @@ def remove_stack_component( if image_builder_flag: stack_component_update[StackComponentType.IMAGE_BUILDER] = [] + if deployer_flag: + stack_component_update[StackComponentType.DEPLOYER] = [] + try: updated_stack = client.update_stack( name_id_or_prefix=stack_name_or_id, diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 5ab3d6f2bcf..160e885f717 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -60,7 +60,17 @@ FILTERING_DATETIME_FORMAT, IS_DEBUG_ENV, ) -from zenml.enums import GenericFilterOps, ServiceState, StackComponentType +from zenml.deployers.utils import ( + get_deployment_input_schema, + get_deployment_invocation_example, + get_deployment_output_schema, +) +from zenml.enums import ( + DeploymentStatus, + GenericFilterOps, + ServiceState, + StackComponentType, +) from zenml.logger import get_logger from zenml.model_registries.base_model_registry import ( RegisteredModel, @@ -85,6 +95,7 @@ from zenml.utils.package_utils import requirement_installed from zenml.utils.time_utils import expires_in from zenml.utils.typing_utils import get_origin, is_union +from zenml.utils.uuid_utils import is_valid_uuid if TYPE_CHECKING: from uuid import UUID @@ -97,8 +108,10 @@ from zenml.models import ( AuthenticationMethodModel, ComponentResponse, + DeploymentResponse, FlavorResponse, PipelineRunResponse, + PipelineSnapshotResponse, ResourceTypeModel, ServiceConnectorRequest, ServiceConnectorResourcesModel, @@ -2305,6 +2318,332 @@ def print_pipeline_runs_table( print_table(runs_dicts) +def fetch_snapshot( + snapshot_name_or_id: str, + pipeline_name_or_id: Optional[str] = None, +) -> "PipelineSnapshotResponse": + """Fetch a snapshot by name or ID. + + Args: + snapshot_name_or_id: The name or ID of the snapshot. + pipeline_name_or_id: The name or ID of the pipeline. + + Returns: + The snapshot. + """ + if is_valid_uuid(snapshot_name_or_id): + return Client().get_snapshot(snapshot_name_or_id) + elif pipeline_name_or_id: + try: + return Client().get_snapshot( + snapshot_name_or_id, + pipeline_name_or_id=pipeline_name_or_id, + ) + except KeyError: + error( + f"There are no snapshots with name `{snapshot_name_or_id}` for " + f"pipeline `{pipeline_name_or_id}`." 
+ ) + else: + snapshots = Client().list_snapshots( + name=snapshot_name_or_id, + ) + if snapshots.total == 0: + error(f"There are no snapshots with name `{snapshot_name_or_id}`.") + elif snapshots.total == 1: + return snapshots.items[0] + + snapshot_index = multi_choice_prompt( + object_type="snapshots", + choices=[ + [snapshot.pipeline.name, snapshot.name] + for snapshot in snapshots.items + ], + headers=["Pipeline", "Snapshot"], + prompt_text=f"There are multiple snapshots with name " + f"`{snapshot_name_or_id}`. Please select the snapshot to run", + ) + assert snapshot_index is not None + return snapshots.items[snapshot_index] + + +def get_deployment_status_emoji( + status: Optional[str], +) -> str: + """Returns an emoji representing the given deployment status. + + Args: + status: The deployment status to get the emoji for. + + Returns: + An emoji representing the given deployment status. + """ + if status == DeploymentStatus.PENDING: + return ":hourglass_flowing_sand:" + if status == DeploymentStatus.ERROR: + return ":x:" + if status == DeploymentStatus.RUNNING: + return ":gear:" + if status == DeploymentStatus.ABSENT: + return ":stop_sign:" + + return ":question:" + + +def format_deployment_status(status: Optional[str]) -> str: + """Format deployment status with color. + + Args: + status: The deployment status. + + Returns: + Formatted status string. + """ + if status == DeploymentStatus.RUNNING: + return "[green]RUNNING[/green]" + elif status == DeploymentStatus.PENDING: + return "[yellow]PENDING[/yellow]" + elif status == DeploymentStatus.ERROR: + return "[red]ERROR[/red]" + elif status == DeploymentStatus.ABSENT: + return "[dim]ABSENT[/dim]" + + return "[dim]UNKNOWN[/dim]" + + +def print_deployment_table( + deployments: Sequence["DeploymentResponse"], +) -> None: + """Print a prettified list of all deployments supplied to this method. + + Args: + deployments: List of deployments + """ + deployment_dicts = [] + for deployment in deployments: + if deployment.user: + user_name = deployment.user.name + else: + user_name = "-" + + if deployment.snapshot is None or deployment.snapshot.pipeline is None: + pipeline_name = "unlisted" + else: + pipeline_name = deployment.snapshot.pipeline.name + if deployment.snapshot is None or deployment.snapshot.stack is None: + stack_name = "[DELETED]" + else: + stack_name = deployment.snapshot.stack.name + status = deployment.status or DeploymentStatus.UNKNOWN.value + status_emoji = get_deployment_status_emoji(status) + run_dict = { + "NAME": deployment.name, + "PIPELINE": pipeline_name, + "SNAPSHOT": deployment.snapshot.name or "" + if deployment.snapshot + else "N/A", + "URL": deployment.url or "N/A", + "STATUS": f"{status_emoji} {status.upper()}", + "STACK": stack_name, + "OWNER": user_name, + } + deployment_dicts.append(run_dict) + print_table(deployment_dicts) + + +def pretty_print_deployment( + deployment: "DeploymentResponse", + show_secret: bool = False, + show_metadata: bool = False, + show_schema: bool = False, + no_truncate: bool = False, +) -> None: + """Print a prettified deployment with organized sections. + + Args: + deployment: The deployment to print. + show_secret: Whether to show the auth key or mask it. + show_metadata: Whether to show the metadata. + show_schema: Whether to show the schema. + no_truncate: Whether to truncate the metadata. 
+ """ + # Header section + status = format_deployment_status(deployment.status) + status_emoji = get_deployment_status_emoji(deployment.status) + declare( + f"\n๐Ÿš€ Deployment: [bold cyan]{deployment.name}[/bold cyan] is: {status} {status_emoji}" + ) + if deployment.snapshot is None: + pipeline_name = "N/A" + snapshot_name = "N/A" + else: + pipeline_name = deployment.snapshot.pipeline.name + snapshot_name = deployment.snapshot.name or str(deployment.snapshot.id) + if deployment.snapshot is None or deployment.snapshot.stack is None: + stack_name = "[DELETED]" + else: + stack_name = deployment.snapshot.stack.name + declare(f"\n[bold]Pipeline:[/bold] [bold cyan]{pipeline_name}[/bold cyan]") + declare(f"[bold]Snapshot:[/bold] [bold cyan]{snapshot_name}[/bold cyan]") + declare(f"[bold]Stack:[/bold] [bold cyan]{stack_name}[/bold cyan]") + + # Connection section + if deployment.url: + declare("\n๐Ÿ“ก [bold]Connection Information:[/bold]") + + declare(f"\n[bold]Endpoint URL:[/bold] [link]{deployment.url}[/link]") + declare( + f"[bold]Swagger URL:[/bold] [link]{deployment.url.rstrip('/')}/docs[/link]" + ) + + # Auth key handling with proper security + auth_key = deployment.auth_key + if auth_key: + if show_secret: + declare(f"[bold]Auth Key:[/bold] [yellow]{auth_key}[/yellow]") + else: + masked_key = ( + f"{auth_key[:8]}***" if len(auth_key) > 8 else "***" + ) + declare( + f"[bold]Auth Key:[/bold] [yellow]{masked_key}[/yellow] " + f"[dim](run [green]`zenml deployment describe {deployment.name} " + "--show-secret`[/green] to reveal)[/dim]" + ) + + example = get_deployment_invocation_example(deployment) + + # CLI invoke command + cli_args = " ".join( + [ + f"--{k}=" + + ( + f"'{json.dumps(v)}'" + if isinstance(v, (dict, list)) + else json.dumps(v) + ) + for k, v in example.items() + ] + ) + cli_command = f"zenml deployment invoke {deployment.name} {cli_args}" + + declare("[bold]CLI Command Example:[/bold]") + console.print(f" [green]{cli_command}[/green]") + + # cURL example + declare("\n[bold]cURL Example:[/bold]") + curl_headers = [] + if auth_key: + if show_secret: + curl_headers.append(f'-H "Authorization: Bearer {auth_key}"') + else: + curl_headers.append( + '-H "Authorization: Bearer "' + ) + + curl_params = json.dumps(example, indent=2).replace("\n", "\n ") + + curl_headers.append('-H "Content-Type: application/json"') + headers_str = " \\\n ".join(curl_headers) + + curl_command = f"""curl -X POST {deployment.url}/invoke \\ + {headers_str} \\ + -d '{{ + "parameters": {curl_params} + }}'""" + + console.print(f" [green]{curl_command}[/green]") + + if show_schema: + input_schema = get_deployment_input_schema(deployment) + output_schema = get_deployment_output_schema(deployment) + declare("\n๐Ÿ“‹ [bold]Deployment JSON Schemas:[/bold]") + declare("\n[bold]Input Schema:[/bold]") + schema_json = json.dumps(input_schema, indent=2) + console.print(f"[green]{schema_json}[/green]") + declare("\n[bold]Output Schema:[/bold]") + schema_json = json.dumps(output_schema, indent=2) + console.print(f"[green]{schema_json}[/green]") + + if show_metadata: + declare("\n๐Ÿ“‹ [bold]Deployment Metadata[/bold]") + + # Get the metadata - it could be from deployment_metadata property or metadata + metadata = deployment.deployment_metadata + + if metadata: + # Recursively format nested dictionaries and lists + def format_value(value: Any, indent_level: int = 0) -> str: + if isinstance(value, dict): + if not value: + return "[dim]{}[/dim]" + formatted_items = [] + for k, v in value.items(): + formatted_v = format_value(v, 
indent_level + 1) + formatted_items.append( + f" {' ' * indent_level}[bold]{k}[/bold]: {formatted_v}" + ) + return "\n" + "\n".join(formatted_items) + elif isinstance(value, list): + if not value: + return "[dim][][/dim]" + formatted_items = [] + for i, item in enumerate(value): + formatted_item = format_value(item, indent_level + 1) + formatted_items.append( + f" {' ' * indent_level}[{i}]: {formatted_item}" + ) + return "\n" + "\n".join(formatted_items) + elif isinstance(value, str): + # Handle long strings by truncating if needed + if len(value) > 100 and not no_truncate: + return f"[green]{value[:97]}...[/green]" + return f"[green]{value}[/green]" + elif isinstance(value, bool): + return f"[yellow]{value}[/yellow]" + elif isinstance(value, (int, float)): + return f"[blue]{value}[/blue]" + elif value is None: + return "[dim]null[/dim]" + else: + return f"[white]{str(value)}[/white]" + + formatted_metadata = format_value(metadata) + console.print(formatted_metadata) + else: + declare(" [dim]No metadata available[/dim]") + + # Management section + declare("\nโš™๏ธ [bold]Management Commands[/bold]") + + mgmt_table = table.Table( + box=box.ROUNDED, + show_header=False, + border_style="dim", + padding=(0, 1), + ) + mgmt_table.add_column("Command", style="bold") + mgmt_table.add_column("Description") + + mgmt_table.add_row( + f"zenml deployment logs {deployment.name} -f", + "Follow deployment logs in real-time", + ) + mgmt_table.add_row( + f"zenml deployment describe {deployment.name}", + "Show detailed deployment information", + ) + mgmt_table.add_row( + f"zenml deployment deprovision {deployment.name}", + "Deprovision this deployment and keep a record of it", + ) + mgmt_table.add_row( + f"zenml deployment delete {deployment.name}", + "Deprovision and delete this deployment", + ) + console.print(mgmt_table) + + def check_zenml_pro_project_availability() -> None: """Check if the ZenML Pro project feature is available.""" client = Client() diff --git a/src/zenml/client.py b/src/zenml/client.py index 4004d7da2d3..63e2a28d19b 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -24,6 +24,7 @@ Any, Callable, Dict, + Generator, List, Mapping, Optional, @@ -60,6 +61,7 @@ from zenml.enums import ( ArtifactType, ColorVariants, + DeploymentStatus, LogicalOperators, ModelStages, OAuthDeviceStatus, @@ -106,6 +108,8 @@ ComponentRequest, ComponentResponse, ComponentUpdate, + DeploymentFilter, + DeploymentResponse, EventSourceFilter, EventSourceRequest, EventSourceResponse, @@ -3389,6 +3393,8 @@ def list_snapshots( schedule_id: Optional[Union[str, UUID]] = None, source_snapshot_id: Optional[Union[str, UUID]] = None, runnable: Optional[bool] = None, + deployable: Optional[bool] = None, + deployed: Optional[bool] = None, tag: Optional[str] = None, tags: Optional[List[str]] = None, hydrate: bool = False, @@ -3414,6 +3420,8 @@ def list_snapshots( schedule_id: The ID of the schedule to filter by. source_snapshot_id: The ID of the source snapshot to filter by. runnable: Whether the snapshot is runnable. + deployable: Whether the snapshot is deployable. + deployed: Whether the snapshot is deployed. tag: Filter by tag. tags: Filter by tags. 
hydrate: Flag deciding whether to hydrate the output model(s) @@ -3440,6 +3448,8 @@ def list_snapshots( schedule_id=schedule_id, source_snapshot_id=source_snapshot_id, runnable=runnable, + deployable=deployable, + deployed=deployed, tag=tag, tags=tags, ) @@ -3697,6 +3707,456 @@ def trigger_pipeline( return run + # ------------------------------ Deployments ----------------------------- + + def get_deployment( + self, + name_id_or_prefix: Union[str, UUID], + project: Optional[Union[str, UUID]] = None, + hydrate: bool = True, + ) -> DeploymentResponse: + """Get a deployment. + + Args: + name_id_or_prefix: Name/ID/ID prefix of the deployment to get. + project: The project name/ID to filter by. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The deployment. + """ + return self._get_entity_by_id_or_name_or_prefix( + get_method=self.zen_store.get_deployment, + list_method=self.list_deployments, + name_id_or_prefix=name_id_or_prefix, + allow_name_prefix_match=False, + project=project, + hydrate=hydrate, + ) + + def list_deployments( + self, + sort_by: str = "created", + page: int = PAGINATION_STARTING_PAGE, + size: int = PAGE_SIZE_DEFAULT, + logical_operator: LogicalOperators = LogicalOperators.AND, + id: Optional[Union[UUID, str]] = None, + created: Optional[Union[datetime, str]] = None, + updated: Optional[Union[datetime, str]] = None, + name: Optional[str] = None, + snapshot_id: Optional[Union[str, UUID]] = None, + deployer_id: Optional[Union[str, UUID]] = None, + project: Optional[Union[str, UUID]] = None, + status: Optional[DeploymentStatus] = None, + url: Optional[str] = None, + user: Optional[Union[UUID, str]] = None, + pipeline: Optional[Union[UUID, str]] = None, + tag: Optional[str] = None, + tags: Optional[List[str]] = None, + hydrate: bool = False, + ) -> Page[DeploymentResponse]: + """List deployments. + + Args: + sort_by: The column to sort by. + page: The page of items. + size: The maximum size of all pages. + logical_operator: Which logical operator to use [and, or]. + id: Use the id of deployments to filter by. + created: Use to filter by time of creation. + updated: Use the last updated date for filtering. + name: The name of the deployment to filter by. + project: The project name/ID to filter by. + snapshot_id: The id of the snapshot to filter by. + deployer_id: The id of the deployer to filter by. + status: The status of the deployment to filter by. + url: The url of the deployment to filter by. + user: Filter by user name/ID. + pipeline: Filter by pipeline name/ID. + tag: Tag to filter by. + tags: Tags to filter by. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A page of deployments. + """ + return self.zen_store.list_deployments( + deployment_filter_model=DeploymentFilter( + sort_by=sort_by, + page=page, + size=size, + logical_operator=logical_operator, + id=id, + created=created, + updated=updated, + project=project or self.active_project.id, + user=user, + name=name, + snapshot_id=snapshot_id, + deployer_id=deployer_id, + status=status, + url=url, + pipeline=pipeline, + tag=tag, + tags=tags, + ), + hydrate=hydrate, + ) + + def provision_deployment( + self, + name_id_or_prefix: Union[str, UUID], + project: Optional[Union[str, UUID]] = None, + snapshot_id: Optional[Union[str, UUID]] = None, + timeout: Optional[int] = None, + ) -> DeploymentResponse: + """Provision a deployment. 
+
+        Args:
+            name_id_or_prefix: Name/ID/ID prefix of the deployment to provision.
+            project: The project name/ID to filter by.
+            snapshot_id: The ID of the snapshot to use. If not provided,
+                the previous snapshot configured for the deployment will be
+                used.
+            timeout: The maximum time in seconds to wait for the pipeline
+                deployment to be provisioned.
+
+        Returns:
+            The provisioned deployment.
+
+        Raises:
+            NotImplementedError: If the deployer cannot be instantiated.
+            ValueError: If the existing deployment has no associated
+                snapshot or if no deployer can be found in the deployment's
+                stack or the active stack.
+            KeyError: If the deployment is not found and no snapshot
+                ID was provided.
+        """
+        from zenml.deployers.base_deployer import (
+            BaseDeployer,
+        )
+        from zenml.stack.stack import Stack
+        from zenml.stack.stack_component import StackComponent
+
+        deployment: Optional[DeploymentResponse] = None
+        deployment_name_or_id = name_id_or_prefix
+        try:
+            deployment = self.get_deployment(
+                name_id_or_prefix=name_id_or_prefix,
+                project=project,
+                hydrate=True,
+            )
+            deployment_name_or_id = deployment.id
+        except KeyError:
+            if isinstance(name_id_or_prefix, UUID):
+                raise
+
+        stack = Client().active_stack
+        deployer: Optional[BaseDeployer] = None
+
+        if snapshot_id:
+            snapshot = self.get_snapshot(
+                name_id_or_prefix=snapshot_id,
+                project=project,
+                hydrate=True,
+            )
+        elif not deployment:
+            raise KeyError(
+                f"Deployment with name '{name_id_or_prefix}' was not "
+                "found and no snapshot ID was provided."
+            )
+        else:
+            # Use the current snapshot
+            if not deployment.snapshot:
+                raise ValueError(
+                    f"Deployment '{deployment.name}' has no associated "
+                    "snapshot."
+                )
+            snapshot = deployment.snapshot
+
+        if deployment and deployment.deployer:
+            try:
+                deployer = cast(
+                    BaseDeployer,
+                    StackComponent.from_model(deployment.deployer),
+                )
+            except ImportError:
+                raise NotImplementedError(
+                    f"Deployer '{deployment.deployer.name}' could "
+                    f"not be instantiated. This is likely because the "
+                    f"deployer's dependencies are not installed."
+                )
+
+        if snapshot.stack and snapshot.stack.id != stack.id:
+            # We really need to use the original stack that the deployment
+            # was created for to provision it, otherwise the deployment
+            # might not have the correct dependencies installed.
+            stack = Stack.from_model(snapshot.stack)
+
+        if not deployer:
+            if stack.deployer:
+                deployer = stack.deployer
+            else:
+                raise ValueError(
+                    f"No deployer was found in the deployment's stack "
+                    f"'{stack.name}' or in your active stack. Please add a "
+                    "deployer to your stack to be able to provision a "
+                    "deployment."
+                )
+
+        # Provision the endpoint through the deployer
+        deployment = deployer.provision_deployment(
+            snapshot=snapshot,
+            stack=stack,
+            deployment_name_or_id=deployment_name_or_id,
+            replace=True,
+            timeout=timeout,
+        )
+        logger.info(
+            f"Provisioned deployment with name '{deployment.name}'.",
+        )
+
+        return deployment
+
+    def deprovision_deployment(
+        self,
+        name_id_or_prefix: Union[str, UUID],
+        project: Optional[Union[str, UUID]] = None,
+        timeout: Optional[int] = None,
+    ) -> None:
+        """Deprovision a deployment.
+
+        Args:
+            name_id_or_prefix: Name/ID/ID prefix of the deployment to deprovision.
+            project: The project name/ID to filter by.
+            timeout: The maximum time in seconds to wait for the deployment to
+                be deprovisioned.
+
+        Raises:
+            NotImplementedError: If the deployer cannot be instantiated.
+ """ + from zenml.deployers.base_deployer import ( + BaseDeployer, + ) + from zenml.stack.stack_component import StackComponent + + deployment = self.get_deployment( + name_id_or_prefix=name_id_or_prefix, + project=project, + hydrate=False, + ) + if deployment.deployer: + # Instantiate and deprovision the deployment through the pipeline + # server + + try: + deployer = cast( + BaseDeployer, + StackComponent.from_model(deployment.deployer), + ) + except ImportError: + raise NotImplementedError( + f"Deployer '{deployment.deployer.name}' could " + f"not be instantiated. This is likely because the " + f"deployer's dependencies are not installed." + ) + deployer.deprovision_deployment( + deployment_name_or_id=deployment.id, + timeout=timeout, + ) + logger.info( + "Deprovisioned deployment with name '%s'.", + deployment.name, + ) + else: + logger.info( + f"Deployment with name '{deployment.name}' is no longer " + "managed by a deployer. This is likely because the deployer " + "was deleted. Please delete the deployment instead.", + ) + + def delete_deployment( + self, + name_id_or_prefix: Union[str, UUID], + project: Optional[Union[str, UUID]] = None, + force: bool = False, + timeout: Optional[int] = None, + ) -> None: + """Deprovision and delete a deployment. + + Args: + name_id_or_prefix: Name/ID/ID prefix of the deployment to delete. + project: The project name/ID to filter by. + force: If True, force the deletion even if the deployment cannot be + deprovisioned. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be deprovisioned. + + Raises: + NotImplementedError: If the deployer cannot be instantiated. + """ + from zenml.deployers.base_deployer import ( + BaseDeployer, + ) + from zenml.stack.stack_component import StackComponent + + deployment = self.get_deployment( + name_id_or_prefix=name_id_or_prefix, + project=project, + hydrate=False, + ) + if deployment.deployer: + # Instantiate and deprovision the deployment through the pipeline + # server + + try: + deployer = cast( + BaseDeployer, + StackComponent.from_model(deployment.deployer), + ) + except ImportError as e: + msg = ( + f"Deployer '{deployment.deployer.name}' could " + f"not be instantiated. This is likely because the " + f"deployer's dependencies are not installed: {e}" + ) + if force: + logger.warning(msg + " Forcing deletion.") + self.zen_store.delete_deployment( + deployment_id=deployment.id + ) + else: + raise NotImplementedError(msg) + except Exception as e: + msg = ( + f"Failed to instantiate deployer '{deployment.deployer.name}'." + f"Error: {e}" + ) + if force: + logger.warning(msg + " Forcing deletion.") + self.zen_store.delete_deployment( + deployment_id=deployment.id + ) + else: + raise NotImplementedError(msg) + else: + deployer.delete_deployment( + deployment_name_or_id=deployment.id, + force=force, + timeout=timeout, + ) + else: + self.zen_store.delete_deployment(deployment_id=deployment.id) + logger.info("Deleted deployment with name '%s'.", deployment.name) + + def refresh_deployment( + self, + name_id_or_prefix: Union[str, UUID], + project: Optional[Union[str, UUID]] = None, + ) -> DeploymentResponse: + """Refresh the status of a deployment. + + Args: + name_id_or_prefix: Name/ID/ID prefix of the deployment to refresh. + project: The project name/ID to filter by. + + Returns: + The refreshed deployment. + + Raises: + NotImplementedError: If the deployer cannot be instantiated or if + the deployment is no longer managed by a deployer. 
+ """ + from zenml.deployers.base_deployer import ( + BaseDeployer, + ) + from zenml.stack.stack_component import StackComponent + + deployment = self.get_deployment( + name_id_or_prefix=name_id_or_prefix, + project=project, + hydrate=False, + ) + if deployment.deployer: + try: + deployer = cast( + BaseDeployer, + StackComponent.from_model(deployment.deployer), + ) + except ImportError: + raise NotImplementedError( + f"Deployer '{deployment.deployer.name}' could " + f"not be instantiated. This is likely because the " + f"deployer's dependencies are not installed." + ) + return deployer.refresh_deployment( + deployment_name_or_id=deployment.id + ) + else: + raise NotImplementedError( + f"Deployment '{deployment.name}' is no longer managed by " + "a deployer. This is likely because the deployer " + "was deleted. Please delete the deployment instead." + ) + + def get_deployment_logs( + self, + name_id_or_prefix: Union[str, UUID], + project: Optional[Union[str, UUID]] = None, + follow: bool = False, + tail: Optional[int] = None, + ) -> Generator[str, bool, None]: + """Get the logs of a deployment. + + Args: + name_id_or_prefix: Name/ID/ID prefix of the deployment to get the logs + of. + project: The project name/ID to filter by. + follow: If True, follow the logs. + tail: The number of lines to show from the end of the logs. + + Yields: + The logs of the deployment. + + Raises: + NotImplementedError: If the deployer cannot be instantiated or if + the deployment is no longer managed by a deployer. + """ + from zenml.deployers.base_deployer import ( + BaseDeployer, + ) + from zenml.stack.stack_component import StackComponent + + deployment = self.get_deployment( + name_id_or_prefix=name_id_or_prefix, + project=project, + hydrate=False, + ) + if deployment.deployer: + try: + deployer = cast( + BaseDeployer, + StackComponent.from_model(deployment.deployer), + ) + except ImportError: + raise NotImplementedError( + f"Deployer '{deployment.deployer.name}' could " + f"not be instantiated. This is likely because the " + f"deployer's dependencies are not installed." + ) + yield from deployer.get_deployment_logs( + deployment_name_or_id=deployment.id, + follow=follow, + tail=tail, + ) + else: + raise NotImplementedError( + f"Deployment '{deployment.name}' is no longer managed by " + "a deployer. This is likely because the deployer " + "was deleted. Please delete the deployment instead." + ) + # ------------------------------ Run templates ----------------------------- def create_run_template( @@ -4188,6 +4648,7 @@ def list_pipeline_runs( hydrate: bool = False, include_full_metadata: bool = False, triggered_by_step_run_id: Optional[Union[UUID, str]] = None, + triggered_by_deployment_id: Optional[Union[UUID, str]] = None, ) -> Page[PipelineRunResponse]: """List all pipeline runs. @@ -4239,6 +4700,8 @@ def list_pipeline_runs( the response. triggered_by_step_run_id: The ID of the step run that triggered the pipeline run. + triggered_by_deployment_id: The ID of the deployment that triggered + the pipeline run. 
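Taken together, the `Client` methods introduced above (`provision_deployment`, `refresh_deployment`, `get_deployment_logs`, `deprovision_deployment`, `delete_deployment`) and the new `triggered_by_deployment_id` run filter form a small deployment lifecycle API. The following is a minimal sketch of how they might be combined from user code; the deployment and snapshot names are hypothetical placeholders, not values taken from this change:

```python
from zenml.client import Client

client = Client()

# Provision (or update) a deployment from an existing pipeline snapshot.
# "weather-agent" and "weather-agent-snapshot" are placeholder names.
snapshot = client.get_snapshot("weather-agent-snapshot")
deployment = client.provision_deployment(
    name_id_or_prefix="weather-agent",
    snapshot_id=snapshot.id,
    timeout=300,
)

# Refresh the operational state reported by the deployer and inspect it.
deployment = client.refresh_deployment("weather-agent")
print(deployment.status, deployment.url)

# Tail the last 100 log lines of the deployment.
for line in client.get_deployment_logs("weather-agent", tail=100):
    print(line)

# List the pipeline runs created by HTTP requests to this deployment,
# using the new triggered_by_deployment_id filter.
runs = client.list_pipeline_runs(triggered_by_deployment_id=deployment.id)

# Deprovision and delete the deployment again.
client.delete_deployment("weather-agent")
```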
Returns: A page with Pipeline Runs fitting the filter description @@ -4281,6 +4744,7 @@ def list_pipeline_runs( in_progress=in_progress, templatable=templatable, triggered_by_step_run_id=triggered_by_step_run_id, + triggered_by_deployment_id=triggered_by_deployment_id, ) return self.zen_store.list_runs( runs_filter_model=runs_filter_model, diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index 9cb85ae6392..dc2f140ca68 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -650,11 +650,8 @@ def _compute_pipeline_spec( ) for output_artifact in pipeline._output_artifacts ] - try: - output_schema = pipeline._compute_output_schema() - except Exception as e: - logger.warning("Failed to compute pipeline output schema: %s", e) - output_schema = None + input_schema = pipeline._compute_input_schema() + output_schema = pipeline._compute_output_schema() return PipelineSpec( steps=step_specs, @@ -662,6 +659,7 @@ def _compute_pipeline_spec( output_schema=output_schema, source=pipeline.resolve(), parameters=pipeline._parameters, + input_schema=input_schema, ) diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 8eb21534157..1c059a47d5e 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -20,7 +20,7 @@ from pydantic import SerializeAsAny, field_validator from zenml.config.cache_policy import CachePolicyWithValidator -from zenml.config.constants import DOCKER_SETTINGS_KEY +from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY from zenml.config.frozen_base_model import FrozenBaseModel from zenml.config.retry_config import StepRetryConfig from zenml.config.source import SourceWithValidator @@ -30,7 +30,7 @@ from zenml.utils.time_utils import utc_now if TYPE_CHECKING: - from zenml.config import DockerSettings + from zenml.config import DockerSettings, ResourceSettings from zenml.config.base_settings import BaseSettings, SettingsOrDict @@ -53,6 +53,9 @@ class PipelineConfigurationUpdate(FrozenBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[SourceWithValidator] = None success_hook_source: Optional[SourceWithValidator] = None + init_hook_source: Optional[SourceWithValidator] = None + init_hook_kwargs: Optional[Dict[str, Any]] = None + cleanup_hook_source: Optional[SourceWithValidator] = None model: Optional[Model] = None parameters: Optional[Dict[str, Any]] = None retry: Optional[StepRetryConfig] = None @@ -125,3 +128,20 @@ def docker_settings(self) -> "DockerSettings": DOCKER_SETTINGS_KEY, {} ) return DockerSettings.model_validate(model_or_dict) + + @property + def resource_settings(self) -> "ResourceSettings": + """Resource settings of this step configuration. + + Returns: + The resource settings of this step configuration. 
+ """ + from zenml.config import ResourceSettings + + model_or_dict: SettingsOrDict = self.settings.get( + RESOURCE_SETTINGS_KEY, {} + ) + + if isinstance(model_or_dict, BaseSettings): + model_or_dict = model_or_dict.model_dump() + return ResourceSettings.model_validate(model_or_dict) diff --git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index 55febe14226..65167eadf26 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -106,6 +106,18 @@ class PipelineRunConfiguration( default=None, description="The failure hook source for all steps of the pipeline run.", ) + init_hook_source: Optional[SourceWithValidator] = Field( + default=None, + description="The init hook source for the pipeline run.", + ) + init_hook_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description="The init hook args for the pipeline run.", + ) + cleanup_hook_source: Optional[SourceWithValidator] = Field( + default=None, + description="The cleanup hook source for the pipeline run.", + ) success_hook_source: Optional[SourceWithValidator] = Field( default=None, description="The success hook source for all steps of the pipeline run.", diff --git a/src/zenml/config/pipeline_spec.py b/src/zenml/config/pipeline_spec.py index a1f9fa39cd0..4b6b4e29b72 100644 --- a/src/zenml/config/pipeline_spec.py +++ b/src/zenml/config/pipeline_spec.py @@ -42,10 +42,17 @@ class PipelineSpec(FrozenBaseModel): # inputs in the step specs refer to the pipeline parameter names # - 0.4: New Pipeline class, the upstream steps and # inputs in the step specs refer to the pipeline parameter names - # - 0.5: Adds outputs and output schema + # - 0.5: Adds input schema, outputs and output schema version: str = "0.5" source: Optional[SourceWithValidator] = None parameters: Dict[str, Any] = {} + input_schema: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema of the pipeline inputs. This is only set " + "for pipeline specs with version >= 0.5. If the value is None, the " + "schema generation failed, which is most likely because some of the " + "pipeline inputs are not JSON serializable.", + ) steps: List[StepSpec] outputs: List[OutputSpec] = [] output_schema: Optional[Dict[str, Any]] = Field( diff --git a/src/zenml/config/resource_settings.py b/src/zenml/config/resource_settings.py index 2cd3c50f469..4482c3bf261 100644 --- a/src/zenml/config/resource_settings.py +++ b/src/zenml/config/resource_settings.py @@ -14,9 +14,15 @@ """Resource settings class used to specify resources for a step.""" from enum import Enum -from typing import Optional, Union +from typing import Literal, Optional, Union -from pydantic import ConfigDict, Field, NonNegativeInt, PositiveFloat +from pydantic import ( + ConfigDict, + Field, + NonNegativeInt, + PositiveFloat, + PositiveInt, +) from zenml.config.base_settings import BaseSettings @@ -62,16 +68,79 @@ def byte_value(self) -> int: class ResourceSettings(BaseSettings): """Hardware resource settings. + Deployers and deployed pipelines can also use the following settings: + + * min_replicas and max_replicas allow expressing both fixed scaling and + autoscaling range. For a fixed number of instances, set both to the same + value (e.g. 3). If min_replicas=0, it indicates the service can scale down + to zero instances when idle (if the platform supports it) most serverless + platforms do. 
+    If max_replicas is None or 0, it will be interpreted as
+    "no specific limit" (use platform default behavior, which might be unlimited
+    or a high default cap). Otherwise, max_replicas puts an upper bound on
+    scaling.
+
+    * autoscaling_metric and autoscaling_target describe when to scale. For
+    example, autoscaling_metric="cpu", autoscaling_target=75.0 means keep CPU
+    around 75% - a Kubernetes integration would create an HPA with target CPU
+    75%, whereas a Knative integration might ignore this if it's using
+    concurrency-based autoscaling. Similarly, "concurrency" with a target of,
+    say, 50 means the system should try to ensure each instance handles ~50
+    concurrent requests before spawning a new instance. The integration code for
+    each platform will translate these generically: e.g. Cloud Run doesn't allow
+    changing the CPU threshold (fixed ~60%), so it might ignore a custom CPU
+    target; Knative supports concurrency and RPS metrics via annotations
+    so those would be applied if specified.
+
+    * max_concurrency is a per-instance concurrency limit. This is particularly
+    useful for platforms that allow configuring a concurrency cap (Knative's
+    containerConcurrency, Cloud Run's concurrency setting, App Runner's max
+    concurrent requests, Modal's max_inputs in the concurrent decorator). If
+    set, this indicates "do not send more than this many simultaneous requests
+    to one instance." The autoscaler will then naturally create more instances
+    once this limit is hit on each existing instance. If
+    autoscaling_metric="concurrency" and no explicit target is given, one could
+    assume the target is equal to max_concurrency (or a certain utilization of
+    it). In some cases, platforms use a utilization factor - for simplicity, we
+    let the integration decide (some might use 80% of max concurrency as the
+    trigger to scale, for example, if that platform does so internally). If
+    max_concurrency is not set, it implies no fixed limit per instance (the
+    service instance will take as many requests as it can, scaling purely on
+    the chosen metric like CPU or an internal default concurrency).
+
     Attributes:
         cpu_count: The amount of CPU cores that should be configured.
         gpu_count: The amount of GPUs that should be configured.
         memory: The amount of memory that should be configured.
+        min_replicas: Minimum number of container instances (replicas).
+            Use 0 to allow scale-to-zero on idle. Only relevant to
+            deployed pipelines.
+        max_replicas: Maximum number of container instances for autoscaling.
+            Set to 0 to imply "no explicit limit". Only relevant to deployed
+            pipelines.
+        autoscaling_metric: Metric to use for autoscaling triggers.
+            Options: "cpu", "memory", "concurrency", or "rps". Only relevant
+            to deployed pipelines.
+        autoscaling_target: Target value for the autoscaling metric (e.g. 70.0
+            for 70% CPU or 20 for concurrency). Only relevant to deployed
+            pipelines.
+        max_concurrency: Maximum concurrent requests per instance (if supported
+            by the platform). Defines a concurrency limit for each container.
+            Only relevant to deployed pipelines.
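To make the new scaling fields concrete, here is a minimal sketch of a `ResourceSettings` object for a deployed pipeline. The `"resources"` settings key and the decorated pipeline are assumptions for illustration, not part of this change:

```python
from zenml import pipeline
from zenml.config import ResourceSettings

# Sketch: one CPU and 2 GB per replica, between 1 and 5 replicas, scale on
# concurrency with ~20 in-flight requests per replica, and at most 50
# concurrent requests handled by any single replica.
serving_resources = ResourceSettings(
    cpu_count=1,
    memory="2GB",
    min_replicas=1,
    max_replicas=5,
    autoscaling_metric="concurrency",
    autoscaling_target=20,
    max_concurrency=50,
)


# The "resources" settings key is an assumption (see RESOURCE_SETTINGS_KEY);
# adjust it to whatever key your ZenML version expects.
@pipeline(settings={"resources": serving_resources})
def my_served_pipeline() -> None:
    ...
```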
""" cpu_count: Optional[PositiveFloat] = None gpu_count: Optional[NonNegativeInt] = None memory: Optional[str] = Field(pattern=MEMORY_REGEX, default=None) + # Settings only applicable for deployers and deployed pipelines + min_replicas: Optional[NonNegativeInt] = None + max_replicas: Optional[NonNegativeInt] = None + autoscaling_metric: Optional[ + Literal["cpu", "memory", "concurrency", "rps"] + ] = None + autoscaling_target: Optional[PositiveFloat] = None + max_concurrency: Optional[PositiveInt] = None + @property def empty(self) -> bool: """Returns if this object is "empty" (=no values configured) or not. diff --git a/src/zenml/constants.py b/src/zenml/constants.py index af0007763e6..8fb67eadab1 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -216,6 +216,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY = ( "ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY" ) + # Logging variables IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False) @@ -402,6 +403,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: PIPELINE_BUILDS = "/pipeline_builds" PIPELINE_CONFIGURATION = "/pipeline-configuration" PIPELINE_DEPLOYMENTS = "/pipeline_deployments" +DEPLOYMENTS = "/deployments" PIPELINE_SNAPSHOTS = "/pipeline_snapshots" PIPELINES = "/pipelines" PIPELINE_SPEC = "/pipeline-spec" @@ -457,6 +459,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # orchestrator constants ORCHESTRATOR_DOCKER_IMAGE_KEY = "orchestrator" +# deployer constants +DEPLOYER_DOCKER_IMAGE_KEY = "deployer" + # Secret constants SECRET_VALUES = "values" diff --git a/src/zenml/deployers/__init__.py b/src/zenml/deployers/__init__.py new file mode 100644 index 00000000000..a6c83fced0f --- /dev/null +++ b/src/zenml/deployers/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Deployers are stack components responsible for deploying pipelines as HTTP services. + +Deploying pipelines is the process of hosting and running machine-learning +pipelines as part of a managed web service and providing access to pipeline +execution through an API endpoint like HTTP or GRPC. Once deployed, you can send +execution requests to the pipeline through the web service's API and receive +responses containing the pipeline results or execution status. + +Add a deployer to your ZenML stack to be able to provision pipelines deployments +that transform your ML pipelines into long-running HTTP services +for real-time, on-demand execution instead of traditional batch processing. + +When present in a stack, the deployer also acts as a registry for pipeline +endpoints that are deployed with ZenML. You can use the deployer to list all +deployments that are currently provisioned for online execution or filtered +according to a particular snapshot or configuration, or to delete an external +deployment managed through ZenML. 
+""" + +from zenml.deployers.base_deployer import ( + BaseDeployer, + BaseDeployerFlavor, + BaseDeployerConfig, +) +from zenml.deployers.containerized_deployer import ( + ContainerizedDeployer, +) +from zenml.deployers.docker.docker_deployer import ( + DockerDeployer, + DockerDeployerFlavor, +) + +__all__ = [ + "BaseDeployer", + "BaseDeployerFlavor", + "BaseDeployerConfig", + "ContainerizedDeployer", + "DockerDeployer", + "DockerDeployerFlavor", +] diff --git a/src/zenml/deployers/base_deployer.py b/src/zenml/deployers/base_deployer.py new file mode 100644 index 00000000000..911bf9a1173 --- /dev/null +++ b/src/zenml/deployers/base_deployer.py @@ -0,0 +1,1072 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Base class for all ZenML deployers.""" + +import secrets +import string +import time +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Optional, + Tuple, + Type, + Union, + cast, +) +from uuid import UUID + +from zenml.analytics.enums import AnalyticsEvent +from zenml.analytics.utils import track_handler +from zenml.client import Client +from zenml.config.base_settings import BaseSettings +from zenml.constants import ( + ENV_ZENML_ACTIVE_PROJECT_ID, + ENV_ZENML_ACTIVE_STACK_ID, +) +from zenml.deployers.exceptions import ( + DeployerError, + DeploymentAlreadyExistsError, + DeploymentDeployerMismatchError, + DeploymentDeprovisionError, + DeploymentNotFoundError, + DeploymentProvisionError, + DeploymentSnapshotMismatchError, + DeploymentTimeoutError, +) +from zenml.enums import DeploymentStatus, StackComponentType +from zenml.logger import get_logger +from zenml.models import ( + DeploymentOperationalState, + DeploymentRequest, + DeploymentResponse, + DeploymentUpdate, + PipelineSnapshotResponse, +) +from zenml.orchestrators.utils import get_config_environment_vars +from zenml.stack import StackComponent +from zenml.stack.flavor import Flavor +from zenml.stack.stack_component import StackComponentConfig +from zenml.utils.uuid_utils import is_valid_uuid + +if TYPE_CHECKING: + from zenml.stack import Stack + +logger = get_logger(__name__) + +DEFAULT_DEPLOYMENT_LCM_TIMEOUT = 600 + + +class BaseDeployerSettings(BaseSettings): + """Base settings for all deployers.""" + + auth_key: Optional[str] = None + generate_auth_key: bool = False + lcm_timeout: int = DEFAULT_DEPLOYMENT_LCM_TIMEOUT + + +class BaseDeployerConfig(StackComponentConfig): + """Base config for all deployers.""" + + +class BaseDeployer(StackComponent, ABC): + """Base class for all ZenML deployers. + + The deployer serves three major purposes: + + 1. It contains all the stack related configuration attributes required to + interact with the remote pipeline deployment tool, service or platform (e.g. + hostnames, URLs, references to credentials, other client related + configuration parameters). + + 2. It implements the life-cycle management for deployments, including + discovery, creation, deletion and updating. + + 3. 
It acts as a ZenML deployment registry, where every pipeline + deployment is stored as a database entity through the ZenML Client. This + allows the deployer to keep track of all externally running pipeline + deployments and to manage their lifecycle. + """ + + @property + def config(self) -> BaseDeployerConfig: + """Returns the `BaseDeployerConfig` config. + + Returns: + The configuration. + """ + return cast(BaseDeployerConfig, self._config) + + @classmethod + def get_active_deployer(cls) -> "BaseDeployer": + """Get the deployer registered in the active stack. + + Returns: + The deployer registered in the active stack. + + Raises: + TypeError: if a deployer is not part of the + active stack. + """ + client = Client() + deployer = client.active_stack.deployer + if not deployer or not isinstance(deployer, cls): + raise TypeError( + "The active stack needs to have a deployer " + "component registered to be able to deploy pipelines. " + "You can create a new stack with a deployer component " + "or update your active stack to add this component, e.g.:\n\n" + " `zenml deployer register ...`\n" + " `zenml stack register -D ...`\n" + " or:\n" + " `zenml stack update -D ...`\n\n" + ) + + return deployer + + def _update_deployment( + self, + deployment: DeploymentResponse, + operational_state: DeploymentOperationalState, + ) -> DeploymentResponse: + """Update an existing deployment instance with the operational state. + + Args: + deployment: The deployment to update. + operational_state: The operational state of the deployment. + + Returns: + The updated deployment. + """ + client = Client() + if operational_state.status == DeploymentStatus.ABSENT: + # Erase the URL and metadata for absent deployments + operational_state.url = "" + operational_state.metadata = {} + + return client.zen_store.update_deployment( + deployment.id, + DeploymentUpdate.from_operational_state(operational_state), + ) + + def _check_deployment_inputs_outputs( + self, + snapshot: PipelineSnapshotResponse, + ) -> None: + """Check if the deployment has compiled schemas for the pipeline inputs and outputs. + + Args: + snapshot: The pipeline snapshot to check. + + Raises: + DeploymentProvisionError: if the deployment has no compiled schemas + for the pipeline inputs and outputs. + """ + if ( + not snapshot.pipeline_spec + or not snapshot.pipeline_spec.input_schema + or not snapshot.pipeline_spec.output_schema + ): + raise DeploymentProvisionError( + f"The pipeline with name '{snapshot.pipeline.name}' referenced " + f"by the deployment with name or ID " + f"'{snapshot.name or snapshot.id}' " + "is missing the compiled schemas for the pipeline inputs or " + "outputs. This is most likely because some of the pipeline " + "inputs or outputs are not JSON serializable. Please check that " + "all the pipeline input arguments and return values have data " + "types that are JSON serializable." + ) + + def _check_deployment_deployer( + self, + deployment: DeploymentResponse, + ) -> None: + """Check if the deployment is managed by this deployer. + + Args: + deployment: The deployment to check. + + Raises: + DeploymentDeployerMismatchError: if the deployment is + not managed by this deployer. 
+ """ + if deployment.deployer_id and deployment.deployer_id != self.id: + deployer = deployment.deployer + assert deployer, "Deployer not found" + raise DeploymentDeployerMismatchError( + f"The deployment with name '{deployment.name}' was provisioned " + f"with a deployer stack component ({deployer.name}) that is " + f"different from the active one: {self.name}. " + f"You can try one of the following:\n" + f"1. Use a different name for the deployment\n" + f"2. Delete the existing '{deployment.name}' deployment\n" + f"3. Use a stack that contains the '{self.name}' deployer stack " + "component\n" + ) + + def _check_deployment_snapshot( + self, snapshot: Optional[PipelineSnapshotResponse] = None + ) -> None: + """Check if the snapshot was created for this deployer. + + Args: + snapshot: The pipeline snapshot to check. + + Raises: + DeploymentSnapshotMismatchError: if the pipeline snapshot is + not built for this deployer. + """ + if not snapshot: + return + + if snapshot.stack and snapshot.stack.components.get( + StackComponentType.DEPLOYER + ): + deployer = snapshot.stack.components[StackComponentType.DEPLOYER][ + 0 + ] + if deployer.id != self.id: + raise DeploymentSnapshotMismatchError( + f"The pipeline snapshot with ID '{snapshot.id}' " + f"was not created for the deployer {self.name}. This will " + "lead to unexpected behavior and is not allowed." + ) + + def _check_snapshot_already_deployed( + self, + snapshot: PipelineSnapshotResponse, + new_deployment_id_or_name: Union[str, UUID], + ) -> None: + """Check if the snapshot is already deployed to another deployment. + + Args: + snapshot: The pipeline snapshot to check. + new_deployment_id_or_name: The ID or name of the deployment that is + being provisioned. + + Raises: + DeploymentAlreadyExistsError: if the snapshot is already deployed to + another deployment. + """ + if snapshot.deployment and ( + isinstance(snapshot.deployment.id, UUID) + and snapshot.deployment.id != new_deployment_id_or_name + or ( + isinstance(snapshot.deployment.id, str) + and snapshot.deployment.name != new_deployment_id_or_name + ) + ): + raise DeploymentAlreadyExistsError( + f"The pipeline snapshot with name or ID " + f"'{snapshot.name or snapshot.id}' " + f"already has an associated deployment: " + f"'{snapshot.deployment.name or snapshot.deployment.id}'. " + "You can try one of the following:\n" + "1. Delete the existing deployment before provisioning " + f"a new one: 'zenml deployment delete " + f"{snapshot.deployment.name or snapshot.deployment.id}'\n" + "2. Update the existing deployment with the snapshot: 'zenml " + f"pipeline snapshot deploy {snapshot.name or snapshot.id} " + f"--deployment {snapshot.deployment.name or snapshot.deployment.id}'\n" + "3. Create and deploy a different snapshot: 'zenml snapshot " + "create ...'\n" + ) + + def _generate_auth_key(self, key_length: int = 32) -> str: + """Generate an authentication key. + + Args: + key_length: The length of the authentication key. + + Returns: + The generated authentication key. + """ + # Generate a secure random string with letters, digits and special chars + alphabet = string.ascii_letters + string.digits + return "".join(secrets.choice(alphabet) for _ in range(key_length)) + + def _poll_deployment( + self, + deployment: DeploymentResponse, + desired_status: DeploymentStatus, + timeout: int, + ) -> Tuple[DeploymentResponse, DeploymentOperationalState]: + """Poll the deployment until it reaches the desired status, an error occurs or times out. + + Args: + deployment: The deployment to poll. 
+ desired_status: The desired status of the deployment. + timeout: The maximum time in seconds to wait for the pipeline + deployment to reach the desired status. + + Returns: + The updated deployment and the operational state of the + deployment. + + Raises: + DeploymentTimeoutError: if the deployment + deployment times out while waiting to reach the desired status. + """ + logger.info( + f"Waiting for the deployment {deployment.name} to reach " + f"desired state '{desired_status}' for max {timeout} seconds..." + ) + + start_time = time.time() + sleep_time = 5 + while True: + deployment_state = DeploymentOperationalState( + status=DeploymentStatus.ERROR, + ) + try: + deployment_state = self.do_get_deployment_state(deployment) + except DeploymentNotFoundError: + deployment_state = DeploymentOperationalState( + status=DeploymentStatus.ABSENT + ) + except DeployerError as e: + logger.exception( + f"Failed to get deployment {deployment.name}: {e}" + ) + finally: + deployment = self._update_deployment( + deployment, deployment_state + ) + + if deployment.status in [ + desired_status, + DeploymentStatus.ERROR, + ]: + break + + elapsed_time = int(time.time() - start_time) + if elapsed_time > timeout: + raise DeploymentTimeoutError( + f"Timed out waiting for deployment {deployment.name} " + f"to reach desired state '{desired_status}' after {timeout} " + "seconds" + ) + logger.debug( + f"The deployment {deployment.name} state is still " + f"'{deployment.status}' after {elapsed_time} seconds. Waiting for " + f"max {timeout - elapsed_time} seconds..." + ) + time.sleep(sleep_time) + + return deployment, deployment_state + + def _get_deployment_analytics_metadata( + self, + deployment: "DeploymentResponse", + stack: Optional["Stack"] = None, + ) -> Dict[str, Any]: + """Returns the deployment metadata. + + Args: + deployment: The deployment to track. + stack: The stack on which the pipeline is deployed. + + Returns: + the metadata about the deployment + """ + snapshot = deployment.snapshot + stack_metadata = {} + if stack: + stack_metadata = { + component_type.value: component.flavor + for component_type, component in stack.components.items() + } + return { + "project_id": deployment.project_id, + "store_type": Client().zen_store.type.value, + **stack_metadata, + "deployment_id": str(deployment.id), + "snapshot_id": str(snapshot.id) if snapshot else None, + "deployer_id": str(self.id), + "deployer_flavor": self.flavor, + "deployment_status": deployment.status, + } + + def provision_deployment( + self, + snapshot: PipelineSnapshotResponse, + stack: "Stack", + deployment_name_or_id: Union[str, UUID], + replace: bool = True, + timeout: Optional[int] = None, + ) -> DeploymentResponse: + """Provision a deployment. + + The provision_deployment method is the main entry point for + provisioning deployments using the deployer. It is used to deploy + a pipeline snapshot as an HTTP deployment, or update an existing + deployment instance with the same name. The method returns a + DeploymentResponse object that is a representation of the + external deployment instance. + + Args: + snapshot: The pipeline snapshot to deploy as an HTTP deployment. + stack: The stack the pipeline will be deployed on. + deployment_name_or_id: Unique name or ID for the deployment. + This name must be unique at the project level. + replace: If True, it will update in-place any existing pipeline + deployment instance with the same name. 
If False, and the pipeline + deployment instance already exists, it will raise a + DeploymentAlreadyExistsError. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be provisioned. If provided, will override the + deployer's default timeout. + + Raises: + DeploymentAlreadyExistsError: if the deployment already + exists and replace is False. + DeploymentProvisionError: if the deployment fails. + DeploymentSnapshotMismatchError: if the pipeline snapshot + was not created for this deployer. + DeploymentNotFoundError: if the deployment with the + given ID is not found. + DeployerError: if an unexpected error occurs. + + Returns: + The DeploymentResponse object representing the provisioned + deployment. + """ + if not replace and is_valid_uuid(deployment_name_or_id): + raise DeploymentAlreadyExistsError( + f"A deployment with ID '{deployment_name_or_id}' " + "already exists" + ) + + self._check_deployment_inputs_outputs(snapshot) + + client = Client() + + settings = cast( + BaseDeployerSettings, + self.get_settings(snapshot), + ) + + timeout = timeout or settings.lcm_timeout + auth_key = settings.auth_key + if not auth_key and settings.generate_auth_key: + auth_key = self._generate_auth_key() + + if snapshot.stack and snapshot.stack.id != stack.id: + # When a different stack is used then the one the snapshot was + # created for, the container image may not have the correct + # dependencies installed, which leads to unexpected errors during + # deployment. To avoid this, we raise an error here. + raise DeploymentSnapshotMismatchError( + f"The pipeline snapshot with ID '{snapshot.id}' " + f"was not created for the stack {stack.name} and might not " + "have the correct dependencies installed. This may " + "lead to unexpected behavior during deployment. Please switch " + f"to the correct active stack '{snapshot.stack.name}' or use " + "a different snapshot." 
+ ) + + try: + # Get the existing deployment + deployment = client.get_deployment( + deployment_name_or_id, project=snapshot.project_id + ) + + self._check_snapshot_already_deployed(snapshot, deployment.id) + + logger.debug( + f"Existing deployment found with name '{deployment.name}'" + ) + except KeyError: + if isinstance(deployment_name_or_id, UUID): + raise DeploymentNotFoundError( + f"Deployment with ID '{deployment_name_or_id}' not found" + ) + + self._check_snapshot_already_deployed( + snapshot, deployment_name_or_id + ) + + logger.debug( + f"Creating new deployment {deployment_name_or_id} with " + f"snapshot ID: {snapshot.id}" + ) + + # Create the deployment request + deployment_request = DeploymentRequest( + name=deployment_name_or_id, + project=snapshot.project_id, + snapshot_id=snapshot.id, + deployer_id=self.id, # This deployer's ID + auth_key=auth_key, + ) + + deployment = client.zen_store.create_deployment(deployment_request) + logger.debug( + f"Created new deployment with name '{deployment.name}' " + f"and ID: {deployment.id}" + ) + else: + if not replace: + raise DeploymentAlreadyExistsError( + f"A deployment with name '{deployment.name}' " + "already exists" + ) + + self._check_deployment_deployer(deployment) + self._check_deployment_snapshot(snapshot) + + deployment_update = DeploymentUpdate( + snapshot_id=snapshot.id, + ) + if ( + deployment.auth_key + and not auth_key + or not deployment.auth_key + and auth_key + ): + # Key was either added or removed + deployment_update.auth_key = auth_key + elif deployment.auth_key != auth_key and ( + settings.auth_key or not settings.generate_auth_key + ): + # Key was changed and not because of re-generation + deployment_update.auth_key = auth_key + + # The deployment has been updated + deployment = client.zen_store.update_deployment( + deployment.id, + deployment_update, + ) + + logger.info( + f"Provisioning deployment {deployment.name} with " + f"snapshot ID: {snapshot.id}" + ) + + environment, secrets = get_config_environment_vars( + deployment_id=deployment.id, + ) + + # Make sure to use the correct active stack/project which correspond + # to the supplied stack and snapshot, which may be different from the + # active stack/project + environment[ENV_ZENML_ACTIVE_STACK_ID] = str(stack.id) + environment[ENV_ZENML_ACTIVE_PROJECT_ID] = str(snapshot.project_id) + + start_time = time.time() + deployment_state = DeploymentOperationalState( + status=DeploymentStatus.ERROR, + ) + with track_handler( + AnalyticsEvent.DEPLOY_PIPELINE + ) as analytics_handler: + try: + deployment_state = self.do_provision_deployment( + deployment, + stack=stack, + environment=environment, + secrets=secrets, + timeout=timeout, + ) + except DeploymentProvisionError as e: + raise DeploymentProvisionError( + f"Failed to provision deployment {deployment.name}: {e}" + ) from e + except DeployerError as e: + raise DeployerError( + f"Failed to provision deployment {deployment.name}: {e}" + ) from e + except Exception as e: + raise DeployerError( + f"Unexpected error while provisioning deployment for " + f"{deployment.name}: {e}" + ) from e + finally: + deployment = self._update_deployment( + deployment, deployment_state + ) + + logger.info( + f"Provisioned deployment {deployment.name} with " + f"snapshot ID: {snapshot.id}. 
Operational state is: " + f"{deployment_state.status}" + ) + + try: + if deployment_state.status == DeploymentStatus.RUNNING: + return deployment + + # Subtract the time spent deploying the deployment from the + # timeout + timeout = timeout - int(time.time() - start_time) + deployment, _ = self._poll_deployment( + deployment, DeploymentStatus.RUNNING, timeout + ) + + if deployment.status != DeploymentStatus.RUNNING: + raise DeploymentProvisionError( + f"Failed to provision deployment {deployment.name}: " + f"The deployment's operational state is " + f"{deployment.status}. Please check the status or logs " + "of the deployment for more information." + ) + + finally: + analytics_handler.metadata = ( + self._get_deployment_analytics_metadata( + deployment=deployment, + stack=stack, + ) + ) + + return deployment + + def refresh_deployment( + self, + deployment_name_or_id: Union[str, UUID], + project: Optional[UUID] = None, + ) -> DeploymentResponse: + """Refresh the status of a deployment by name or ID. + + Call this to refresh the operational state of a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to get. + project: The project ID of the deployment to get. Required + if a name is provided. + + Returns: + The deployment. + + Raises: + DeploymentNotFoundError: if the deployment is not found. + DeployerError: if an unexpected error occurs. + """ + client = Client() + try: + deployment = client.get_deployment( + deployment_name_or_id, project=project + ) + except KeyError: + raise DeploymentNotFoundError( + f"Deployment with name or ID '{deployment_name_or_id}' " + f"not found" + ) + + self._check_deployment_deployer(deployment) + + deployment_state = DeploymentOperationalState( + status=DeploymentStatus.ERROR, + ) + try: + deployment_state = self.do_get_deployment_state(deployment) + except DeploymentNotFoundError: + deployment_state.status = DeploymentStatus.ABSENT + except DeployerError as e: + raise DeployerError( + f"Failed to refresh deployment {deployment_name_or_id}: {e}" + ) from e + except Exception as e: + raise DeployerError( + f"Unexpected error while refreshing deployment for " + f"{deployment_name_or_id}: {e}" + ) from e + finally: + deployment = self._update_deployment(deployment, deployment_state) + + return deployment + + def deprovision_deployment( + self, + deployment_name_or_id: Union[str, UUID], + project: Optional[UUID] = None, + timeout: Optional[int] = None, + ) -> DeploymentResponse: + """Deprovision a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to + deprovision. + project: The project ID of the deployment to deprovision. + Required if a name is provided. + timeout: The maximum time in seconds to wait for the pipeline + deployment to deprovision. If provided, will override the + deployer's default timeout. + + Returns: + The deployment. + + Raises: + DeploymentNotFoundError: if the deployment is not found + or is not managed by this deployer. + DeploymentDeprovisionError: if the deployment + deprovision fails. + DeployerError: if an unexpected error occurs. 
+ """ + client = Client() + try: + deployment = client.get_deployment( + deployment_name_or_id, project=project + ) + except KeyError: + raise DeploymentNotFoundError( + f"Deployment with name or ID '{deployment_name_or_id}' " + f"not found" + ) + + self._check_deployment_deployer(deployment) + + if not timeout and deployment.snapshot: + settings = cast( + BaseDeployerSettings, + self.get_settings(deployment.snapshot), + ) + + timeout = settings.lcm_timeout + + timeout = timeout or DEFAULT_DEPLOYMENT_LCM_TIMEOUT + + start_time = time.time() + deployment_state = DeploymentOperationalState( + status=DeploymentStatus.ERROR, + ) + with track_handler( + AnalyticsEvent.STOP_DEPLOYMENT + ) as analytics_handler: + try: + deleted_deployment_state = self.do_deprovision_deployment( + deployment, timeout + ) + if not deleted_deployment_state: + # When do_delete_deployment returns a None value, this + # is to signal that the deployment is already fully deprovisioned. + deployment_state.status = DeploymentStatus.ABSENT + except DeploymentNotFoundError: + deployment_state.status = DeploymentStatus.ABSENT + except DeployerError as e: + raise DeployerError( + f"Failed to delete deployment {deployment_name_or_id}: {e}" + ) from e + except Exception as e: + raise DeployerError( + f"Unexpected error while deleting deployment for " + f"{deployment_name_or_id}: {e}" + ) from e + finally: + deployment = self._update_deployment( + deployment, deployment_state + ) + + try: + if deployment_state.status == DeploymentStatus.ABSENT: + return deployment + + # Subtract the time spent deprovisioning the deployment from the timeout + timeout = timeout - int(time.time() - start_time) + deployment, _ = self._poll_deployment( + deployment, DeploymentStatus.ABSENT, timeout + ) + + if deployment.status != DeploymentStatus.ABSENT: + raise DeploymentDeprovisionError( + f"Failed to deprovision deployment {deployment_name_or_id}: " + f"Operational state: {deployment.status}" + ) + + finally: + analytics_handler.metadata = ( + self._get_deployment_analytics_metadata( + deployment=deployment, + stack=None, + ) + ) + + return deployment + + def delete_deployment( + self, + deployment_name_or_id: Union[str, UUID], + project: Optional[UUID] = None, + force: bool = False, + timeout: Optional[int] = None, + ) -> None: + """Deprovision and delete a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to + delete. + project: The project ID of the deployment to deprovision. + Required if a name is provided. + force: if True, force the deployment to delete even if it + cannot be deprovisioned. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be deprovisioned. If provided, will override the + deployer's default timeout. + + Raises: + DeployerError: if an unexpected error occurs. + """ + client = Client() + try: + deployment = self.deprovision_deployment( + deployment_name_or_id, project, timeout + ) + except DeploymentNotFoundError: + # The deployment was already deleted + return + except DeployerError as e: + if force: + logger.warning( + f"Failed to deprovision deployment " + f"{deployment_name_or_id}: {e}. Forcing deletion." 
+ ) + deployment = client.get_deployment( + deployment_name_or_id, project=project + ) + client.zen_store.delete_deployment(deployment_id=deployment.id) + else: + raise + else: + client.zen_store.delete_deployment(deployment_id=deployment.id) + + def get_deployment_logs( + self, + deployment_name_or_id: Union[str, UUID], + project: Optional[UUID] = None, + follow: bool = False, + tail: Optional[int] = None, + ) -> Generator[str, bool, None]: + """Get the logs of a deployment. + + Args: + deployment_name_or_id: The name or ID of the deployment to get + the logs of. + project: The project ID of the deployment to get the logs of. + Required if a name is provided. + follow: if True, the logs will be streamed as they are written. + tail: only retrieve the last NUM lines of log output. + + Returns: + A generator that yields the logs of the deployment. + + Raises: + DeploymentNotFoundError: if the deployment is not found. + DeployerError: if an unexpected error occurs. + """ + client = Client() + try: + deployment = client.get_deployment( + deployment_name_or_id, project=project + ) + except KeyError: + raise DeploymentNotFoundError( + f"Deployment with name or ID '{deployment_name_or_id}' " + f"not found" + ) + + self._check_deployment_deployer(deployment) + + try: + return self.do_get_deployment_state_logs(deployment, follow, tail) + except DeployerError as e: + raise DeployerError( + f"Failed to get logs for deployment {deployment_name_or_id}: {e}" + ) from e + except Exception as e: + raise DeployerError( + f"Unexpected error while getting logs for deployment for " + f"{deployment_name_or_id}: {e}" + ) from e + + # ------------------ Abstract Methods ------------------ + + @abstractmethod + def do_provision_deployment( + self, + deployment: DeploymentResponse, + stack: "Stack", + environment: Dict[str, str], + secrets: Dict[str, str], + timeout: int, + ) -> DeploymentOperationalState: + """Abstract method to deploy a pipeline as an HTTP deployment. + + Concrete deployer subclasses must implement the following + functionality in this method: + + - Create the actual deployment infrastructure (e.g., + FastAPI server, Kubernetes deployment, cloud function, etc.) based on + the information in the deployment response, particularly the + pipeline snapshot. When determining how to name the external + resources, do not rely on the deployment name as being immutable + or unique. + + - If the deployment infrastructure is already provisioned, update + it to match the information in the deployment response. + + - Return a DeploymentOperationalState representing the operational + state of the provisioned deployment. + + Note that the deployment infrastructure is not required to be + deployed immediately. The deployer can return a + DeploymentOperationalState with a status of + DeploymentStatus.PENDING, and the base deployer will poll + the deployment infrastructure by calling the + `do_get_deployment_state` method until it is ready or it times out. + + Args: + deployment: The deployment to deploy as an HTTP deployment. + stack: The stack the pipeline will be deployed on. + environment: A dictionary of environment variables to set on the + deployment. + secrets: A dictionary of secret environment variables to set + on the deployment. These secret environment variables + should not be exposed as regular environment variables on the + deployer. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be provisioned. 
+ + Returns: + The DeploymentOperationalState object representing the + operational state of the provisioned deployment. + + Raises: + DeploymentProvisionError: if provisioning the deployment + fails. + DeployerError: if an unexpected error occurs. + """ + + @abstractmethod + def do_get_deployment_state( + self, + deployment: DeploymentResponse, + ) -> DeploymentOperationalState: + """Abstract method to get information about a deployment. + + Args: + deployment: The deployment to get information about. + + Returns: + The DeploymentOperationalState object representing the + updated operational state of the deployment. + + Raises: + DeploymentNotFoundError: if no deployment is found + corresponding to the provided DeploymentResponse. + DeployerError: if the deployment information cannot + be retrieved for any other reason or if an unexpected error + occurs. + """ + + @abstractmethod + def do_get_deployment_state_logs( + self, + deployment: DeploymentResponse, + follow: bool = False, + tail: Optional[int] = None, + ) -> Generator[str, bool, None]: + """Abstract method to get the logs of a deployment. + + Args: + deployment: The deployment to get the logs of. + follow: if True, the logs will be streamed as they are written + tail: only retrieve the last NUM lines of log output. + + Yields: + The logs of the deployment. + + Raises: + DeploymentNotFoundError: if no deployment is found + corresponding to the provided DeploymentResponse. + DeploymentLogsNotFoundError: if the deployment logs are not + found. + DeployerError: if the deployment logs cannot + be retrieved for any other reason or if an unexpected error + occurs. + """ + + @abstractmethod + def do_deprovision_deployment( + self, + deployment: DeploymentResponse, + timeout: int, + ) -> Optional[DeploymentOperationalState]: + """Abstract method to deprovision a deployment. + + Concrete deployer subclasses must implement the following + functionality in this method: + + - Deprovision the actual deployment infrastructure (e.g., + FastAPI server, Kubernetes deployment, cloud function, etc.) based on + the information in the deployment response. + + - Return a DeploymentOperationalState representing the operational + state of the deleted deployment, or None if the deletion is + completed before the call returns. + + Note that the deployment infrastructure is not required to be + deleted immediately. The deployer can return a + DeploymentOperationalState with a status of + DeploymentStatus.PENDING, and the base deployer will poll + the deployment infrastructure by calling the + `do_get_deployment_state` method until it is deleted or it times out. + + Args: + deployment: The deployment to delete. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be deprovisioned. + + Returns: + The DeploymentOperationalState object representing the + operational state of the deprovisioned deployment, or None + if the deprovision is completed before the call returns. + + Raises: + DeploymentNotFoundError: if no deployment is found + corresponding to the provided DeploymentResponse. + DeploymentDeprovisionError: if the deployment + deprovision fails. + DeployerError: if an unexpected error occurs. + """ + + +class BaseDeployerFlavor(Flavor): + """Base class for deployer flavors.""" + + @property + def type(self) -> StackComponentType: + """Returns the flavor type. + + Returns: + The flavor type. 
+ """ + return StackComponentType.DEPLOYER + + @property + def config_class(self) -> Type[BaseDeployerConfig]: + """Returns `BaseDeployerConfig` config class. + + Returns: + The config class. + """ + return BaseDeployerConfig + + @property + @abstractmethod + def implementation_class(self) -> Type[BaseDeployer]: + """The class that implements the deployer.""" diff --git a/src/zenml/deployers/containerized_deployer.py b/src/zenml/deployers/containerized_deployer.py new file mode 100644 index 00000000000..675fa0a3e7f --- /dev/null +++ b/src/zenml/deployers/containerized_deployer.py @@ -0,0 +1,98 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Base class for all ZenML deployers.""" + +from abc import ABC +from typing import ( + List, + Set, +) + +import zenml +from zenml.config.build_configuration import BuildConfiguration +from zenml.config.global_config import GlobalConfiguration +from zenml.constants import ( + DEPLOYER_DOCKER_IMAGE_KEY, +) +from zenml.deployers.base_deployer import BaseDeployer +from zenml.logger import get_logger +from zenml.models import ( + PipelineSnapshotBase, + PipelineSnapshotResponse, +) + +logger = get_logger(__name__) + + +class ContainerizedDeployer(BaseDeployer, ABC): + """Base class for all containerized deployers.""" + + CONTAINER_REQUIREMENTS: List[str] = [] + + @staticmethod + def get_image(snapshot: PipelineSnapshotResponse) -> str: + """Get the docker image used to deploy a pipeline snapshot. + + Args: + snapshot: The pipeline snapshot to get the image for. + + Returns: + The docker image used to deploy the pipeline snapshot. + + Raises: + RuntimeError: if the pipeline snapshot does not have a build or + if the deployer image is not in the build. + """ + if snapshot.build is None: + raise RuntimeError("Pipeline snapshot does not have a build. ") + if DEPLOYER_DOCKER_IMAGE_KEY not in snapshot.build.images: + raise RuntimeError( + "Pipeline snapshot build does not have a deployer image. " + ) + return snapshot.build.images[DEPLOYER_DOCKER_IMAGE_KEY].image + + @property + def requirements(self) -> Set[str]: + """Set of PyPI requirements for the deployer. + + Returns: + A set of PyPI requirements for the deployer. + """ + requirements = super().requirements + requirements.update(self.CONTAINER_REQUIREMENTS) + + if self.config.is_local and GlobalConfiguration().uses_sql_store: + # If we're directly connected to a DB, we need to install the + # `local` extra in the Docker image to include the DB dependencies. + requirements.add(f"'zenml[local]=={zenml.__version__}'") + + return requirements + + def get_docker_builds( + self, snapshot: "PipelineSnapshotBase" + ) -> List["BuildConfiguration"]: + """Gets the Docker builds required for the component. + + Args: + snapshot: The pipeline snapshot for which to get the builds. + + Returns: + The required Docker builds. 
+ """ + return [ + BuildConfiguration( + key=DEPLOYER_DOCKER_IMAGE_KEY, + settings=snapshot.pipeline_configuration.docker_settings, + ) + ] diff --git a/src/zenml/deployers/docker/__init__.py b/src/zenml/deployers/docker/__init__.py new file mode 100644 index 00000000000..c822a6b0510 --- /dev/null +++ b/src/zenml/deployers/docker/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation for the local Docker deployer.""" diff --git a/src/zenml/deployers/docker/docker_deployer.py b/src/zenml/deployers/docker/docker_deployer.py new file mode 100644 index 00000000000..59f39ad07f2 --- /dev/null +++ b/src/zenml/deployers/docker/docker_deployer.py @@ -0,0 +1,680 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the ZenML Docker deployer.""" + +import copy +import os +import sys +from typing import ( + Any, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + cast, +) + +import docker.errors as docker_errors +from docker.client import DockerClient +from docker.models.containers import Container +from pydantic import BaseModel + +from zenml.config.base_settings import BaseSettings +from zenml.config.global_config import GlobalConfiguration +from zenml.constants import ( + ENV_ZENML_LOCAL_STORES_PATH, +) +from zenml.deployers.base_deployer import ( + BaseDeployerConfig, + BaseDeployerFlavor, + BaseDeployerSettings, +) +from zenml.deployers.containerized_deployer import ( + ContainerizedDeployer, +) +from zenml.deployers.exceptions import ( + DeployerError, + DeploymentDeprovisionError, + DeploymentLogsNotFoundError, + DeploymentNotFoundError, + DeploymentProvisionError, +) +from zenml.deployers.server.entrypoint_configuration import ( + AUTH_KEY_OPTION, + DEPLOYMENT_ID_OPTION, + PORT_OPTION, + DeploymentEntrypointConfiguration, +) +from zenml.enums import DeploymentStatus, StackComponentType +from zenml.logger import get_logger +from zenml.models import ( + DeploymentOperationalState, + DeploymentResponse, +) +from zenml.stack import Stack, StackValidator +from zenml.utils import docker_utils +from zenml.utils.networking_utils import ( + lookup_preferred_or_free_port, +) + +logger = get_logger(__name__) + + +class DockerDeploymentMetadata(BaseModel): + """Metadata for a Docker deployment.""" + + port: Optional[int] = None + container_id: Optional[str] = None + container_name: Optional[str] = None + container_image_id: Optional[str] = None + container_image_uri: Optional[str] = None + container_status: Optional[str] = None + + @classmethod + def from_container( + cls, container: Container + ) -> "DockerDeploymentMetadata": + """Create a DockerDeploymentMetadata from a docker container. + + Args: + container: The docker container to get the metadata for. + + Returns: + The metadata for the docker container. + """ + image = container.image + if image is not None: + image_url = image.attrs["RepoTags"][0] + image_id = image.attrs["Id"] + else: + image_url = None + image_id = None + if container.ports: + ports = list(container.ports.values()) + if len(ports) > 0: + port = int(ports[0][0]["HostPort"]) + else: + port = None + else: + port = None + return cls( + port=port, + container_id=container.id, + container_name=container.name, + container_image_uri=image_url, + container_image_id=image_id, + container_status=container.status, + ) + + @classmethod + def from_deployment( + cls, deployment: DeploymentResponse + ) -> "DockerDeploymentMetadata": + """Create a DockerDeploymentMetadata from a deployment. + + Args: + deployment: The deployment to get the metadata for. + + Returns: + The metadata for the deployment. + """ + return cls.model_validate(deployment.deployment_metadata) + + +class DockerDeployer(ContainerizedDeployer): + """Deployer responsible for deploying pipelines locally using Docker.""" + + CONTAINER_REQUIREMENTS: List[str] = ["uvicorn", "fastapi"] + _docker_client: Optional[DockerClient] = None + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Settings class for the Docker deployer. + + Returns: + The settings class. + """ + return DockerDeployerSettings + + @property + def config(self) -> "DockerDeployerConfig": + """Returns the `DockerDeployerConfig` config. + + Returns: + The configuration. 
+ """ + return cast(DockerDeployerConfig, self._config) + + @property + def validator(self) -> Optional[StackValidator]: + """Ensures there is an image builder in the stack. + + Returns: + A `StackValidator` instance. + """ + return StackValidator( + required_components={StackComponentType.IMAGE_BUILDER} + ) + + @property + def docker_client(self) -> DockerClient: + """Initialize and/or return the docker client. + + Returns: + The docker client. + """ + if self._docker_client is None: + self._docker_client = ( + docker_utils._try_get_docker_client_from_env() + ) + return self._docker_client + + def _get_container_id(self, deployment: DeploymentResponse) -> str: + """Get the docker container id associated with a deployment. + + Args: + deployment: The deployment to get the container id for. + + Returns: + The docker container id for the deployment. + """ + return f"zenml-deployment-{deployment.id}" + + def _get_container( + self, deployment: DeploymentResponse + ) -> Optional[Container]: + """Get the docker container associated with a deployment. + + Args: + deployment: The deployment to get the container for. + + Returns: + The docker container for the service, or None if the container + does not exist. + """ + try: + return self.docker_client.containers.get( + self._get_container_id(deployment) + ) + except docker_errors.NotFound: + # container doesn't exist yet or was removed + return None + + def _get_container_operational_state( + self, container: Container + ) -> DeploymentOperationalState: + """Get the operational state of a docker container running a deployment. + + Args: + container: The docker container to get the operational state of. + + Returns: + The operational state of the docker container running the pipeline + deployment. + """ + metadata = DockerDeploymentMetadata.from_container(container) + state = DeploymentOperationalState( + status=DeploymentStatus.UNKNOWN, + metadata=metadata.model_dump(exclude_none=True), + ) + if metadata.container_status == "running": + state.status = DeploymentStatus.RUNNING + elif metadata.container_status == "exited": + state.status = DeploymentStatus.ERROR + elif metadata.container_status in ["created", "restarting", "paused"]: + state.status = DeploymentStatus.PENDING + elif metadata.container_status == "dead": + state.status = DeploymentStatus.ERROR + elif metadata.container_status == "removing": + state.status = DeploymentStatus.PENDING + elif metadata.container_status == "exited": + state.status = DeploymentStatus.ABSENT + elif metadata.container_status == "dead": + state.status = DeploymentStatus.ERROR + + if state.status == DeploymentStatus.RUNNING: + state.url = "http://localhost" + if metadata.port: + state.url += f":{metadata.port}" + + return state + + def do_provision_deployment( + self, + deployment: DeploymentResponse, + stack: "Stack", + environment: Dict[str, str], + secrets: Dict[str, str], + timeout: int, + ) -> DeploymentOperationalState: + """Deploy a pipeline as a Docker container. + + Args: + deployment: The deployment to run as a Docker container. + stack: The stack the pipeline will be deployed on. + environment: A dictionary of environment variables to set on the + deployment. + secrets: A dictionary of secret environment variables to set + on the deployment. These secret environment variables + should not be exposed as regular environment variables on the + deployer. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be provisioned. 
+ + Returns: + The DeploymentOperationalState object representing the + operational state of the provisioned deployment. + + Raises: + DeploymentProvisionError: if provisioning the deployment + fails. + """ + assert deployment.snapshot, "Pipeline snapshot not found" + snapshot = deployment.snapshot + + # Currently, there is no safe way to pass secrets to a docker + # container, so we simply merge them into the environment variables. + environment.update(secrets) + + settings = cast( + DockerDeployerSettings, + self.get_settings(snapshot), + ) + + existing_metadata = DockerDeploymentMetadata.from_deployment( + deployment + ) + + entrypoint = DeploymentEntrypointConfiguration.get_entrypoint_command() + + entrypoint_kwargs = { + DEPLOYMENT_ID_OPTION: deployment.id, + PORT_OPTION: 8000, + } + if deployment.auth_key: + entrypoint_kwargs[AUTH_KEY_OPTION] = deployment.auth_key + + arguments = DeploymentEntrypointConfiguration.get_entrypoint_arguments( + **entrypoint_kwargs + ) + + # Add the local stores path as a volume mount + stack.check_local_paths() + local_stores_path = GlobalConfiguration().local_stores_path + volumes = { + local_stores_path: { + "bind": local_stores_path, + "mode": "rw", + } + } + environment[ENV_ZENML_LOCAL_STORES_PATH] = local_stores_path + + # check if a container already exists for the deployment + container = self._get_container(deployment) + + if container: + # the container exists, check if it is running + if container.status == "running": + logger.debug( + f"Container for deployment '{deployment.name}' is " + "already running", + ) + container.stop(timeout=timeout) + + # the container is stopped or in an error state, remove it + logger.debug( + f"Removing previous container for deployment " + f"'{deployment.name}'", + ) + container.remove(force=True) + + logger.debug( + f"Starting container for deployment '{deployment.name}'..." + ) + + image = self.get_image(deployment.snapshot) + + try: + self.docker_client.images.get(image) + except docker_errors.ImageNotFound: + logger.debug( + f"Pulling container image '{image}' for deployment " + f"'{deployment.name}'...", + ) + self.docker_client.images.pull(image) + + preferred_ports: List[int] = [] + if settings.port: + preferred_ports.append(settings.port) + if existing_metadata.port: + preferred_ports.append(existing_metadata.port) + port = lookup_preferred_or_free_port( + preferred_ports=preferred_ports, + allocate_port_if_busy=settings.allocate_port_if_busy, + range=settings.port_range, + ) + ports: Dict[str, Optional[int]] = {"8000/tcp": port} + + uid_args: Dict[str, Any] = {} + if sys.platform == "win32": + # File permissions are not checked on Windows. This if clause + # prevents mypy from complaining about unused 'type: ignore' + # statements + pass + else: + # Run the container in the context of the local UID/GID + # to ensure that the local database can be shared + # with the container. + logger.debug( + "Setting UID and GID to local user/group in container." 
+ ) + uid_args = dict( + user=os.getuid(), + group_add=[os.getgid()], + ) + + run_args = copy.deepcopy(settings.run_args) + docker_environment = run_args.pop("environment", {}) + docker_environment.update(environment) + + docker_volumes = run_args.pop("volumes", {}) + docker_volumes.update(volumes) + + extra_hosts = run_args.pop("extra_hosts", {}) + extra_hosts["host.docker.internal"] = "host-gateway" + + run_args.update(uid_args) + + try: + container = self.docker_client.containers.run( + image=image, + name=self._get_container_id(deployment), + entrypoint=entrypoint, + command=arguments, + detach=True, + volumes=docker_volumes, + environment=docker_environment, + remove=False, + auto_remove=False, + ports=ports, + labels={ + "zenml-deployment-id": str(deployment.id), + "zenml-deployment-name": deployment.name, + "zenml-deployer-name": str(self.name), + "zenml-deployer-id": str(self.id), + "managed-by": "zenml", + }, + extra_hosts=extra_hosts, + **run_args, + ) + + logger.debug( + f"Docker container for deployment '{deployment.name}' " + f"started with ID {self._get_container_id(deployment)}", + ) + + except docker_errors.DockerException as e: + raise DeploymentProvisionError( + f"Docker container for deployment '{deployment.name}' " + f"failed to start: {e}" + ) + + return self._get_container_operational_state(container) + + def do_get_deployment_state( + self, + deployment: DeploymentResponse, + ) -> DeploymentOperationalState: + """Get information about a docker deployment. + + Args: + deployment: The deployment to get information about. + + Returns: + The DeploymentOperationalState object representing the + updated operational state of the deployment. + + Raises: + DeploymentNotFoundError: if no deployment is found + corresponding to the provided DeploymentResponse. + """ + container = self._get_container(deployment) + if container is None: + raise DeploymentNotFoundError( + f"Docker container for deployment '{deployment.name}' " + "not found" + ) + + return self._get_container_operational_state(container) + + def do_get_deployment_state_logs( + self, + deployment: DeploymentResponse, + follow: bool = False, + tail: Optional[int] = None, + ) -> Generator[str, bool, None]: + """Get the logs of a Docker deployment. + + Args: + deployment: The deployment to get the logs of. + follow: if True, the logs will be streamed as they are written + tail: only retrieve the last NUM lines of log output. + + Yields: + The logs of the deployment. + + Raises: + DeploymentNotFoundError: if no deployment is found + corresponding to the provided DeploymentResponse. + DeploymentLogsNotFoundError: if the deployment logs are not + found. + DeployerError: if the deployment logs cannot + be retrieved for any other reason or if an unexpected error + occurs. 
+ """ + container = self._get_container(deployment) + if container is None: + raise DeploymentNotFoundError( + f"Docker container for deployment '{deployment.name}' " + "not found" + ) + + try: + log_kwargs: Dict[str, Any] = { + "stdout": True, + "stderr": True, + "stream": follow, + "follow": follow, + "timestamps": True, + } + + if tail is not None and tail > 0: + log_kwargs["tail"] = tail + + log_stream = container.logs(**log_kwargs) + + if follow: + for log_line in log_stream: + if isinstance(log_line, bytes): + yield log_line.decode( + "utf-8", errors="replace" + ).rstrip() + else: + yield str(log_line).rstrip() + else: + if isinstance(log_stream, bytes): + log_text = log_stream.decode("utf-8", errors="replace") + for line in log_text.splitlines(): + yield line + else: + for log_line in log_stream: + if isinstance(log_line, bytes): + yield log_line.decode( + "utf-8", errors="replace" + ).rstrip() + else: + yield str(log_line).rstrip() + + except docker_errors.NotFound as e: + raise DeploymentLogsNotFoundError( + f"Logs for deployment '{deployment.name}' not found: {e}" + ) + except docker_errors.APIError as e: + raise DeployerError( + f"Docker API error while retrieving logs for deployment " + f"'{deployment.name}': {e}" + ) + except docker_errors.DockerException as e: + raise DeployerError( + f"Docker error while retrieving logs for deployment " + f"'{deployment.name}': {e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while retrieving logs for deployment " + f"'{deployment.name}': {e}" + ) + + def do_deprovision_deployment( + self, + deployment: DeploymentResponse, + timeout: int, + ) -> Optional[DeploymentOperationalState]: + """Deprovision a docker deployment. + + Args: + deployment: The deployment to deprovision. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be deprovisioned. + + Returns: + The DeploymentOperationalState object representing the + operational state of the deleted deployment, or None if the + deletion is completed before the call returns. + + Raises: + DeploymentNotFoundError: if no deployment is found + corresponding to the provided DeploymentResponse. + DeploymentDeprovisionError: if the deployment + deprovision fails. + """ + container = self._get_container(deployment) + if container is None: + raise DeploymentNotFoundError( + f"Docker container for deployment '{deployment.name}' " + "not found" + ) + + try: + container.stop(timeout=timeout) + container.remove() + except docker_errors.DockerException as e: + raise DeploymentDeprovisionError( + f"Docker container for deployment '{deployment.name}' " + f"failed to delete: {e}" + ) + + return None + + +class DockerDeployerSettings(BaseDeployerSettings): + """Docker deployer settings. + + Attributes: + port: The port to expose the deployment on. + allocate_port_if_busy: If True, allocate a free port if the configured + port is busy. + port_range: The range of ports to search for a free port. + run_args: Arguments to pass to the `docker run` call. (See + https://docker-py.readthedocs.io/en/stable/containers.html for a list + of what can be passed.) + """ + + port: Optional[int] = None + allocate_port_if_busy: bool = True + port_range: Tuple[int, int] = (8000, 65535) + run_args: Dict[str, Any] = {} + + +class DockerDeployerConfig(BaseDeployerConfig, DockerDeployerSettings): + """Docker deployer config.""" + + @property + def is_local(self) -> bool: + """Checks if this stack component is running locally. 
+ + Returns: + True if this config is for a local component, False otherwise. + """ + return True + + +class DockerDeployerFlavor(BaseDeployerFlavor): + """Flavor for the Docker deployer.""" + + @property + def name(self) -> str: + """Name of the deployer flavor. + + Returns: + Name of the deployer flavor. + """ + return "docker" + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/docker.png" + + @property + def config_class(self) -> Type[BaseDeployerConfig]: + """Config class for the base deployer flavor. + + Returns: + The config class. + """ + return DockerDeployerConfig + + @property + def implementation_class(self) -> Type["DockerDeployer"]: + """Implementation class for this flavor. + + Returns: + Implementation class for this flavor. + """ + return DockerDeployer diff --git a/src/zenml/deployers/exceptions.py b/src/zenml/deployers/exceptions.py new file mode 100644 index 00000000000..3fe71607a10 --- /dev/null +++ b/src/zenml/deployers/exceptions.py @@ -0,0 +1,63 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Base class for all ZenML deployers.""" + +from zenml.exceptions import EntityExistsError +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +class DeployerError(Exception): + """Base class for deployer errors.""" + + +class DeploymentAlreadyExistsError(EntityExistsError, DeployerError): + """Error raised when a deployment already exists.""" + + +class DeploymentNotFoundError(KeyError, DeployerError): + """Error raised when a deployment is not found.""" + + +class DeploymentProvisionError(DeployerError): + """Error raised when a deployment provisioning fails.""" + + +class DeploymentTimeoutError(DeployerError): + """Error raised when a deployment provisioning or deprovisioning times out.""" + + +class DeploymentDeprovisionError(DeployerError): + """Error raised when a deployment deprovisioning fails.""" + + +class DeploymentLogsNotFoundError(KeyError, DeployerError): + """Error raised when pipeline logs are not found.""" + + +class DeploymentDeployerMismatchError(DeployerError): + """Error raised when a deployment is not managed by this deployer.""" + + +class DeploymentSnapshotMismatchError(DeployerError): + """Error raised when a deployment snapshot does not match the current deployer.""" + + +class DeploymentHTTPError(DeployerError): + """Error raised when an HTTP request to a deployment fails.""" + + +class DeploymentInvalidParametersError(DeployerError): + """Error raised when the parameters for a deployment are invalid.""" diff --git a/src/zenml/deployers/server/app.py b/src/zenml/deployers/server/app.py new file mode 100644 index 00000000000..4717516d1e3 --- /dev/null +++ b/src/zenml/deployers/server/app.py @@ -0,0 +1,381 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""FastAPI application for running ZenML pipeline deployments.""" + +import os +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Literal, Optional + +from fastapi import ( + APIRouter, + Depends, + FastAPI, + HTTPException, + Request, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from zenml.deployers.server.models import ( + ExecutionMetrics, + ServiceInfo, + get_pipeline_invoke_models, +) +from zenml.deployers.server.service import PipelineDeploymentService +from zenml.logger import get_logger + +logger = get_logger(__name__) + +_service: Optional[PipelineDeploymentService] = None + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Manage application lifespan. + + Args: + app: The FastAPI application instance being deployed. + + Yields: + None: Control is handed back to FastAPI once initialization completes. + + Raises: + ValueError: If no deployment identifier is configured. + Exception: If initialization or cleanup fails. 
+ """ + # Check for test mode + if os.getenv("ZENML_DEPLOYMENT_TEST_MODE", "false").lower() == "true": + logger.info("๐Ÿงช Running in test mode - skipping initialization") + yield + return + + # Startup + logger.info("๐Ÿš€ Starting ZenML Pipeline Serving service...") + + deployment_id = os.getenv("ZENML_DEPLOYMENT_ID") + if not deployment_id: + raise ValueError( + "ZENML_DEPLOYMENT_ID environment variable is required" + ) + + try: + global _service + _service = PipelineDeploymentService(deployment_id) + _service.initialize() + app.include_router(_build_invoke_router(_service)) + logger.info("โœ… Pipeline deployment service initialized successfully") + except Exception as e: + logger.error(f"โŒ Failed to initialize: {e}") + raise + + yield + + # Shutdown + logger.info("๐Ÿ›‘ Shutting down ZenML Pipeline Deployment service...") + try: + if _service: + _service.cleanup() + logger.info( + "โœ… Pipeline deployment service cleaned up successfully" + ) + except Exception as e: + logger.error(f"โŒ Error during service cleanup: {e}") + finally: + # Ensure globals are reset to avoid stale references across lifecycles + _service = None + + +# Create FastAPI application with OpenAPI security scheme +app = FastAPI( + title=f"ZenML Pipeline Deployment {os.getenv('ZENML_DEPLOYMENT_ID')}", + description="deploy ZenML pipelines as FastAPI endpoints", + version="0.2.0", + lifespan=lifespan, + docs_url="/docs", + redoc_url="/redoc", +) + +# Define security scheme for OpenAPI documentation +security = HTTPBearer( + scheme_name="Bearer Token", + description="Enter your API key as a Bearer token", + auto_error=False, # We handle errors in our dependency +) + + +def _build_invoke_router(service: PipelineDeploymentService) -> APIRouter: + """Create an idiomatic APIRouter that exposes /invoke. + + Args: + service: The deployment service used to execute pipeline runs. + + Returns: + A router exposing the `/invoke` endpoint wired to the service. + """ + router = APIRouter() + + PipelineInvokeRequest, PipelineInvokeResponse = get_pipeline_invoke_models( + service + ) + + @router.post( + "/invoke", + name="invoke_pipeline", + summary="Invoke the pipeline with validated parameters", + response_model=PipelineInvokeResponse, + ) + def _( + request: PipelineInvokeRequest, # type: ignore[valid-type] + _: None = Depends(verify_token), + ) -> PipelineInvokeResponse: # type: ignore[valid-type] + return service.execute_pipeline(request) + + return router + + +def get_pipeline_service() -> PipelineDeploymentService: + """Get the pipeline deployment service. + + Returns: + The initialized pipeline deployment service instance. + """ + assert _service is not None + return _service + + +def verify_token( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), +) -> None: + """Verify the provided Bearer token for authentication. + + This dependency function integrates with FastAPI's security system + to provide proper OpenAPI documentation and authentication UI. 
+ + Args: + credentials: HTTP Bearer credentials from the request + + Raises: + HTTPException: If authentication is required but token is invalid + """ + auth_key = os.getenv("ZENML_DEPLOYMENT_AUTH_KEY", "").strip() + auth_enabled = auth_key and auth_key != "" + + # If authentication is not enabled, allow all requests + if not auth_enabled: + return + + # If authentication is enabled, validate the token + if not credentials: + raise HTTPException( + status_code=401, + detail="Authorization header required", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if credentials.credentials != auth_key: + raise HTTPException( + status_code=401, + detail="Invalid authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Token is valid, authentication successful + return + + +# Add CORS middleware to allow frontend access +# TODO: In production, restrict allow_origins to specific domains for security +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allow all origins - restrict in production + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + + +@app.get("/", response_class=HTMLResponse) +async def root( + service: PipelineDeploymentService = Depends(get_pipeline_service), +) -> str: + """Root endpoint with service information. + + Args: + service: The pipeline serving service dependency. + + Returns: + An HTML page describing the serving deployment. + """ + info = service.get_service_info() + + html_content = f""" + + + + ZenML Pipeline Deployment + + + +
+        <h1>🚀 ZenML Pipeline Deployment</h1>
+        <h2>Service Status</h2>
+        <p>Status: Running</p>
+        <p>Pipeline: {info.pipeline.name}</p>
+ + + + """ + return html_content + + +@app.get("/health") +async def health_check( + service: PipelineDeploymentService = Depends(get_pipeline_service), +) -> Literal["OK"]: + """Service health check endpoint. + + Args: + service: The pipeline serving service dependency. + + Returns: + "OK" if the service is healthy, otherwise raises an HTTPException. + + Raises: + HTTPException: If the service is not healthy. + """ + if not service.is_healthy(): + raise HTTPException(503, "Service is unhealthy") + + return "OK" + + +@app.get("/info") +async def info( + service: PipelineDeploymentService = Depends(get_pipeline_service), +) -> ServiceInfo: + """Get detailed information about the service, including pipeline metadata and schema. + + Args: + service: The pipeline serving service dependency. + + Returns: + Service info. + """ + return service.get_service_info() + + +@app.get("/metrics") +async def execution_metrics( + service: PipelineDeploymentService = Depends(get_pipeline_service), +) -> ExecutionMetrics: + """Get pipeline execution metrics and statistics. + + Args: + service: The pipeline serving service dependency. + + Returns: + Aggregated execution metrics. + """ + return service.get_execution_metrics() + + +# Custom exception handlers +@app.exception_handler(ValueError) +def value_error_handler(request: Request, exc: ValueError) -> JSONResponse: + """Handle ValueError exceptions (synchronous for unit tests). + + Args: + request: The request. + exc: The exception. + + Returns: + The error response. + """ + logger.error("ValueError in request: %s", exc) + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + +@app.exception_handler(RuntimeError) +def runtime_error_handler(request: Request, exc: RuntimeError) -> JSONResponse: + """Handle RuntimeError exceptions (synchronous for unit tests). + + Args: + request: The request. + exc: The exception. + + Returns: + The error response. + """ + logger.error("RuntimeError in request: %s", exc) + return JSONResponse(status_code=500, content={"detail": str(exc)}) + + +if __name__ == "__main__": + import argparse + + import uvicorn + + parser = argparse.ArgumentParser() + parser.add_argument( + "--deployment_id", + default=os.getenv("ZENML_DEPLOYMENT_ID"), + help="Pipeline snapshot ID", + ) + parser.add_argument( + "--host", + default=os.getenv("ZENML_SERVICE_HOST", "0.0.0.0"), # nosec + ) + parser.add_argument( + "--port", + type=int, + default=int(os.getenv("ZENML_SERVICE_PORT", "8001")), + ) + parser.add_argument( + "--workers", + type=int, + default=int(os.getenv("ZENML_SERVICE_WORKERS", "1")), + ) + parser.add_argument( + "--log_level", default=os.getenv("ZENML_LOG_LEVEL", "info").lower() + ) + parser.add_argument( + "--auth_key", default=os.getenv("ZENML_DEPLOYMENT_AUTH_KEY", "") + ) + args = parser.parse_args() + + if args.deployment_id: + os.environ["ZENML_DEPLOYMENT_ID"] = args.deployment_id + if args.auth_key: + os.environ["ZENML_DEPLOYMENT_AUTH_KEY"] = args.auth_key + + logger.info(f"Starting FastAPI server on {args.host}:{args.port}") + + uvicorn.run( + "zenml.deployers.server.app:app", + host=args.host, + port=args.port, + workers=args.workers, + log_level=args.log_level, + reload=False, + ) diff --git a/src/zenml/deployers/server/entrypoint_configuration.py b/src/zenml/deployers/server/entrypoint_configuration.py new file mode 100644 index 00000000000..792ca535bc9 --- /dev/null +++ b/src/zenml/deployers/server/entrypoint_configuration.py @@ -0,0 +1,192 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML Pipeline Deployment Entrypoint Configuration.""" + +import os +from typing import Any, List, Set +from uuid import UUID + +from zenml.client import Client +from zenml.entrypoints.base_entrypoint_configuration import ( + BaseEntrypointConfiguration, +) +from zenml.integrations.registry import integration_registry +from zenml.logger import get_logger +from zenml.models.v2.core.pipeline_snapshot import PipelineSnapshotResponse +from zenml.utils import uuid_utils + +logger = get_logger(__name__) + +# Deployment-specific entrypoint options +DEPLOYMENT_ID_OPTION = "deployment_id" +HOST_OPTION = "host" +PORT_OPTION = "port" +WORKERS_OPTION = "workers" +LOG_LEVEL_OPTION = "log_level" +CREATE_RUNS_OPTION = "create_runs" +AUTH_KEY_OPTION = "auth_key" + + +class DeploymentEntrypointConfiguration(BaseEntrypointConfiguration): + """Entrypoint configuration for ZenML Pipeline Deployment. + + This entrypoint configuration handles the startup and configuration + of the ZenML pipeline deployment FastAPI application. + """ + + @classmethod + def get_entrypoint_options(cls) -> Set[str]: + """Gets all options required for the deployment entrypoint. + + Returns: + Set of required option names + """ + return { + DEPLOYMENT_ID_OPTION, + HOST_OPTION, + PORT_OPTION, + WORKERS_OPTION, + LOG_LEVEL_OPTION, + CREATE_RUNS_OPTION, + AUTH_KEY_OPTION, + } + + @classmethod + def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]: + """Gets arguments for the deployment entrypoint command. + + Args: + **kwargs: Keyword arguments containing deployment configuration + + Returns: + List of command-line arguments + + Raises: + ValueError: If the deployment ID is not a valid UUID. + """ + # Get base arguments (snapshot_id, etc.) + base_args = super().get_entrypoint_arguments(**kwargs) + + deployment_id = kwargs.get(DEPLOYMENT_ID_OPTION) + if not uuid_utils.is_valid_uuid(deployment_id): + raise ValueError( + f"Missing or invalid deployment ID as argument for entrypoint " + f"configuration. Please make sure to pass a valid UUID to " + f"`{cls.__name__}.{cls.get_entrypoint_arguments.__name__}" + f"({DEPLOYMENT_ID_OPTION}=)`." + ) + + # Add deployment-specific arguments with defaults + deployment_args = [ + f"--{DEPLOYMENT_ID_OPTION}", + str(kwargs.get(DEPLOYMENT_ID_OPTION, "")), + f"--{HOST_OPTION}", + str(kwargs.get(HOST_OPTION, "0.0.0.0")), # nosec + f"--{PORT_OPTION}", + str(kwargs.get(PORT_OPTION, 8001)), + f"--{WORKERS_OPTION}", + str(kwargs.get(WORKERS_OPTION, 1)), + f"--{LOG_LEVEL_OPTION}", + str(kwargs.get(LOG_LEVEL_OPTION, "info")), + f"--{CREATE_RUNS_OPTION}", + str(kwargs.get(CREATE_RUNS_OPTION, "false")), + f"--{AUTH_KEY_OPTION}", + str(kwargs.get(AUTH_KEY_OPTION, "")), + ] + + return base_args + deployment_args + + def load_snapshot(self) -> "PipelineSnapshotResponse": + """Loads the deployment snapshot. + + Returns: + The deployment snapshot. + + Raises: + RuntimeError: If the deployment has no snapshot. 
+ """ + deployment_id = UUID(self.entrypoint_args[DEPLOYMENT_ID_OPTION]) + deployment = Client().zen_store.get_deployment( + deployment_id=deployment_id + ) + if deployment.snapshot is None: + raise RuntimeError("Deployment has no snapshot") + return deployment.snapshot + + def run(self) -> None: + """Run the ZenML pipeline deployment application. + + This method starts the FastAPI server with the configured parameters + and the specified pipeline deployment. + + Raises: + Exception: If the server fails to start. + """ + import uvicorn + + # Activate integrations to ensure all components are available + integration_registry.activate_integrations() + + # Extract configuration from entrypoint args + deployment_id = self.entrypoint_args[DEPLOYMENT_ID_OPTION] + host = self.entrypoint_args.get(HOST_OPTION, "0.0.0.0") # nosec + port = int(self.entrypoint_args.get(PORT_OPTION, 8001)) + workers = int(self.entrypoint_args.get(WORKERS_OPTION, 1)) + log_level = self.entrypoint_args.get(LOG_LEVEL_OPTION, "info") + create_runs = ( + self.entrypoint_args.get(CREATE_RUNS_OPTION, "false").lower() + == "true" + ) + auth_key = self.entrypoint_args.get(AUTH_KEY_OPTION, None) + + snapshot = self.load_snapshot() + + # Download code if necessary (for remote execution environments) + self.download_code_if_necessary(snapshot=snapshot) + + # Set environment variables for the deployment application + os.environ["ZENML_DEPLOYMENT_ID"] = deployment_id + if create_runs: + os.environ["ZENML_DEPLOYMENT_CREATE_RUNS"] = "true" + if auth_key: + os.environ["ZENML_DEPLOYMENT_AUTH_KEY"] = auth_key + + logger.info("๐Ÿš€ Starting ZenML Pipeline Deployment...") + logger.info(f" Deployment ID: {deployment_id}") + logger.info(f" Snapshot ID: {snapshot.id}") + logger.info(f" Host: {host}") + logger.info(f" Port: {port}") + logger.info(f" Workers: {workers}") + logger.info(f" Log Level: {log_level}") + logger.info(f" Create Runs: {create_runs}") + logger.info("") + logger.info(f"๐Ÿ“– API Documentation: http://{host}:{port}/docs") + logger.info(f"๐Ÿ” Health Check: http://{host}:{port}/health") + logger.info("") + + try: + # Start the FastAPI server + uvicorn.run( + "zenml.deployers.server.app:app", + host=host, + port=port, + workers=workers, + log_level=log_level.lower(), + access_log=True, + ) + except KeyboardInterrupt: + logger.info("\n๐Ÿ›‘ Deployment stopped by user") + except Exception as e: + logger.error(f"โŒ Failed to start deployment: {str(e)}") + raise diff --git a/src/zenml/deployers/server/models.py b/src/zenml/deployers/server/models.py new file mode 100644 index 00000000000..6b341d0de7e --- /dev/null +++ b/src/zenml/deployers/server/models.py @@ -0,0 +1,187 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""FastAPI application models.""" + +from datetime import datetime +from typing import TYPE_CHECKING, Annotated, Any, Dict, Optional, Tuple, Type +from uuid import UUID + +from pydantic import BaseModel, Field, WithJsonSchema + +from zenml.logger import get_logger + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from zenml.deployers.server.service import PipelineDeploymentService + + +class DeploymentInvocationResponseMetadata(BaseModel): + """Pipeline invoke response metadata model.""" + + deployment_id: UUID = Field(title="The ID of the deployment.") + deployment_name: str = Field(title="The name of the deployment.") + snapshot_id: UUID = Field(title="The ID of the snapshot.") + snapshot_name: Optional[str] = Field( + default=None, title="The name of the snapshot." + ) + pipeline_name: str = Field(title="The name of the pipeline.") + run_id: Optional[UUID] = Field( + default=None, title="The ID of the pipeline run." + ) + run_name: Optional[str] = Field( + default=None, title="The name of the pipeline run." + ) + parameters_used: Dict[str, Any] = Field( + title="The parameters used for the pipeline execution." + ) + + +class BaseDeploymentInvocationRequest(BaseModel): + """Base pipeline invoke request model.""" + + parameters: BaseModel = Field( + title="The parameters for the pipeline execution." + ) + run_name: Optional[str] = Field( + default=None, title="Custom name for the pipeline run." + ) + timeout: int = Field( + default=300, title="The timeout for the pipeline execution." + ) + skip_artifact_materialization: bool = Field( + default=False, + title="Whether to keep outputs in memory for fast access instead of " + "storing them as artifacts.", + ) + + +class BaseDeploymentInvocationResponse(BaseModel): + """Base pipeline invoke response model.""" + + success: bool = Field( + title="Whether the pipeline execution was successful." + ) + outputs: Optional[Dict[str, Any]] = Field( + default=None, + title="The outputs of the pipeline execution, if the pipeline execution " + "was successful.", + ) + execution_time: float = Field( + title="The time taken to execute the pipeline." + ) + metadata: DeploymentInvocationResponseMetadata = Field( + title="The metadata of the pipeline execution." + ) + error: Optional[str] = Field( + default=None, + title="The error that occurred, if the pipeline invocation failed.", + ) + + +class PipelineInfo(BaseModel): + """Pipeline info model.""" + + name: str = Field(title="The name of the pipeline.") + parameters: Optional[Dict[str, Any]] = Field( + default=None, title="The parameters of the pipeline." + ) + input_schema: Optional[Dict[str, Any]] = Field( + default=None, title="The input schema of the pipeline." + ) + output_schema: Optional[Dict[str, Any]] = Field( + default=None, title="The output schema of the pipeline." + ) + + +class DeploymentInfo(BaseModel): + """Deployment info model.""" + + id: UUID = Field(title="The ID of the deployment.") + name: str = Field(title="The name of the deployment.") + + +class SnapshotInfo(BaseModel): + """Snapshot info model.""" + + id: UUID = Field(title="The ID of the snapshot.") + name: Optional[str] = Field( + default=None, title="The name of the snapshot." + ) + + +class ServiceInfo(BaseModel): + """Service info model.""" + + deployment: DeploymentInfo = Field( + title="The deployment of the pipeline service." + ) + snapshot: SnapshotInfo = Field( + title="The snapshot of the pipeline service." + ) + pipeline: PipelineInfo = Field( + title="The pipeline of the pipeline service." 
+ ) + total_executions: int = Field( + title="The total number of pipeline executions." + ) + last_execution_time: Optional[datetime] = Field( + default=None, title="The time of the last pipeline execution." + ) + status: str = Field(title="The status of the pipeline service.") + uptime: float = Field(title="The uptime of the pipeline service.") + + +class ExecutionMetrics(BaseModel): + """Execution metrics model.""" + + total_executions: int = Field( + title="The total number of pipeline executions." + ) + last_execution_time: Optional[datetime] = Field( + default=None, title="The time of the last pipeline execution." + ) + + +def get_pipeline_invoke_models( + service: "PipelineDeploymentService", +) -> Tuple[Type[BaseModel], Type[BaseModel]]: + """Generate the request and response models for the pipeline invoke endpoint. + + Args: + service: The pipeline deployment service. + + Returns: + A tuple containing the request and response models. + """ + if TYPE_CHECKING: + # mypy has a difficult time with dynamic models, so we return something + # static for mypy to use + return BaseModel, BaseModel + + else: + + class PipelineInvokeRequest(BaseDeploymentInvocationRequest): + parameters: Annotated[ + service.input_model, + WithJsonSchema(service.input_schema, mode="validation"), + ] + + class PipelineInvokeResponse(BaseDeploymentInvocationResponse): + outputs: Annotated[ + Optional[Dict[str, Any]], + WithJsonSchema(service.output_schema, mode="serialization"), + ] + + return PipelineInvokeRequest, PipelineInvokeResponse diff --git a/src/zenml/deployers/server/runtime.py b/src/zenml/deployers/server/runtime.py new file mode 100644 index 00000000000..cd0bfd1e1ef --- /dev/null +++ b/src/zenml/deployers/server/runtime.py @@ -0,0 +1,170 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Thread-safe runtime context for deployments. + +This module provides request-scoped state for deployment invocations using +contextvars to ensure thread safety and proper request isolation. Each +deployment request gets its own isolated context that doesn't interfere +with concurrent requests. + +It also provides parameter override functionality for the orchestrator +to access deployment parameters without tight coupling. 
+""" + +import contextvars +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from zenml.models import PipelineSnapshotResponse + + +class _DeploymentState(BaseModel): + model_config = {"extra": "forbid"} + + active: bool = False + skip_artifact_materialization: bool = False + request_id: Optional[str] = None + snapshot_id: Optional[str] = None + pipeline_parameters: Dict[str, Any] = Field(default_factory=dict) + outputs: Dict[str, Dict[str, Any]] = Field(default_factory=dict) + + # In-memory data storage for artifacts + in_memory_data: Dict[str, Any] = Field(default_factory=dict) + + def reset(self) -> None: + """Reset the deployment state.""" + self.active = False + self.request_id = None + self.snapshot_id = None + self.pipeline_parameters.clear() + self.outputs.clear() + self.skip_artifact_materialization = False + self.in_memory_data.clear() + + +_deployment_context: contextvars.ContextVar[_DeploymentState] = ( + contextvars.ContextVar("deployment_context", default=_DeploymentState()) +) + + +def _get_context() -> _DeploymentState: + """Get the current deployment context state. + + Returns: + The current deployment context state. + """ + return _deployment_context.get() + + +def start( + request_id: str, + snapshot: PipelineSnapshotResponse, + parameters: Dict[str, Any], + skip_artifact_materialization: bool = False, +) -> None: + """Initialize deployment state for the current request context. + + Args: + request_id: The ID of the request. + snapshot: The snapshot to deploy. + parameters: The parameters to deploy. + skip_artifact_materialization: Whether to skip artifact materialization. + """ + state = _DeploymentState() + state.active = True + state.request_id = request_id + state.snapshot_id = str(snapshot.id) + state.pipeline_parameters = parameters + state.outputs = {} + state.skip_artifact_materialization = skip_artifact_materialization + _deployment_context.set(state) + + +def stop() -> None: + """Clear the deployment state for the current request context.""" + state = _get_context() + state.reset() + + +def is_active() -> bool: + """Return whether deployment state is active in the current context. + + Returns: + True if the deployment state is active in the current context, False otherwise. + """ + return _get_context().active + + +def record_step_outputs(step_name: str, outputs: Dict[str, Any]) -> None: + """Record raw outputs for a step by invocation id. + + Args: + step_name: The name of the step to record the outputs for. + outputs: A dictionary of outputs to record. + """ + state = _get_context() + if not state.active: + return + if not outputs: + return + state.outputs.setdefault(step_name, {}).update(outputs) + + +def get_outputs() -> Dict[str, Dict[str, Any]]: + """Return the outputs for all steps in the current context. + + Returns: + A dictionary of outputs for all steps. + """ + return dict(_get_context().outputs) + + +def should_skip_artifact_materialization() -> bool: + """Check if the current request should skip artifact materialization. + + Returns: + True if artifact materialization is skipped for this request. + """ + if is_active(): + state = _get_context() + return state.skip_artifact_materialization + return False + + +def put_in_memory_data(uri: str, data: Any) -> None: + """Store data in memory for the given URI. + + Args: + uri: The artifact URI to store data for. + data: The data to store in memory. 
+ """ + if is_active(): + state = _get_context() + state.in_memory_data[uri] = data + + +def get_in_memory_data(uri: str) -> Any: + """Get data from memory for the given URI. + + Args: + uri: The artifact URI to retrieve data for. + + Returns: + The stored data, or None if not found. + """ + if is_active(): + state = _get_context() + return state.in_memory_data.get(uri) + return None diff --git a/src/zenml/deployers/server/service.py b/src/zenml/deployers/server/service.py new file mode 100644 index 00000000000..5e79b5fc7b1 --- /dev/null +++ b/src/zenml/deployers/server/service.py @@ -0,0 +1,610 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Pipeline deployment service.""" + +import contextvars +import time +import traceback +from datetime import datetime, timezone +from typing import Any, Dict, Optional, Tuple, Type, Union +from uuid import UUID, uuid4 + +from pydantic import BaseModel + +import zenml.pipelines.run_utils as run_utils +from zenml.client import Client +from zenml.deployers.server import runtime +from zenml.deployers.server.models import ( + BaseDeploymentInvocationRequest, + BaseDeploymentInvocationResponse, + DeploymentInfo, + DeploymentInvocationResponseMetadata, + ExecutionMetrics, + PipelineInfo, + ServiceInfo, + SnapshotInfo, +) +from zenml.enums import StackComponentType +from zenml.hooks.hook_validators import load_and_run_hook +from zenml.logger import get_logger +from zenml.models import ( + PipelineRunResponse, + PipelineRunTriggerInfo, + PipelineSnapshotResponse, +) +from zenml.orchestrators.base_orchestrator import BaseOrchestrator +from zenml.orchestrators.local.local_orchestrator import ( + LocalOrchestrator, + LocalOrchestratorConfig, +) +from zenml.pipelines.pipeline_definition import Pipeline +from zenml.stack import Stack +from zenml.steps.utils import get_unique_step_output_names +from zenml.utils import env_utils, source_utils + +logger = get_logger(__name__) + + +class SharedLocalOrchestrator(LocalOrchestrator): + """Local orchestrator tweaked for deployments. + + This is a slight modification of the LocalOrchestrator: + - uses request-scoped orchestrator run ids by storing them in contextvars + - bypasses the init/cleanup hook execution because they are run globally by + the deployment service + """ + + # Use contextvars for thread-safe, request-scoped state + _shared_orchestrator_run_id: contextvars.ContextVar[Optional[str]] = ( + contextvars.ContextVar("orchestrator_run_id", default=None) + ) + + def get_orchestrator_run_id(self) -> str: + """Get the orchestrator run id. + + Returns: + The orchestrator run id. + """ + run_id = self._shared_orchestrator_run_id.get() + if run_id is None: + run_id = str(uuid4()) + self._shared_orchestrator_run_id.set(run_id) + return run_id + + @classmethod + def run_init_hook(cls, snapshot: "PipelineSnapshotResponse") -> None: + """Runs the init hook. + + Args: + snapshot: The snapshot to run the init hook for. 
+ """ + # Bypass the init hook execution because it is run globally by + # the deployment service + pass + + @classmethod + def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None: + """Runs the cleanup hook. + + Args: + snapshot: The snapshot to run the cleanup hook for. + """ + # Bypass the cleanup hook execution because it is run globally by + # the deployment service + pass + + +class PipelineDeploymentService: + """Pipeline deployment service.""" + + def __init__(self, deployment_id: Union[str, UUID]) -> None: + """Initialize service with minimal state. + + Args: + deployment_id: The ID of the running deployment. + + Raises: + RuntimeError: If the deployment or snapshot cannot be loaded. + """ + # Accept both str and UUID for flexibility + if isinstance(deployment_id, str): + deployment_id = UUID(deployment_id) + + self._client = Client() + + # Execution tracking + self.service_start_time = time.time() + self.last_execution_time: Optional[datetime] = None + self.total_executions = 0 + + # Cache a local orchestrator instance to avoid per-request construction + self._orchestrator = SharedLocalOrchestrator( + name="deployment-local", + id=uuid4(), + config=LocalOrchestratorConfig(), + flavor="local", + type=StackComponentType.ORCHESTRATOR, + user=uuid4(), + created=datetime.now(), + updated=datetime.now(), + ) + + logger.info("Loading pipeline snapshot configuration...") + + try: + self.deployment = self._client.zen_store.get_deployment( + deployment_id=deployment_id + ) + except Exception as e: + raise RuntimeError(f"Failed to load deployment: {e}") from e + + if self.deployment.snapshot is None: + raise RuntimeError("Deployment has no snapshot") + self.snapshot = self.deployment.snapshot + + def initialize(self) -> None: + """Initialize service with proper error handling. + + Raises: + Exception: If the service cannot be initialized. + """ + try: + # Execute init hook + BaseOrchestrator.run_init_hook(self.snapshot) + + # Log success + self._log_initialization_success() + + except Exception as e: + logger.error(f"โŒ Failed to initialize service: {e}") + logger.error(f" Traceback: {traceback.format_exc()}") + raise + + def cleanup(self) -> None: + """Execute cleanup hook if present.""" + BaseOrchestrator.run_cleanup_hook(self.snapshot) + + def execute_pipeline( + self, + request: BaseDeploymentInvocationRequest, + ) -> BaseDeploymentInvocationResponse: + """Execute the deployment with the given parameters. + + Args: + request: Runtime parameters supplied by the caller. + + Returns: + A BaseDeploymentInvocationResponse describing the execution result. + """ + # Unused parameters for future implementation + _ = request.run_name, request.timeout + parameters = request.parameters.model_dump() + start_time = time.time() + logger.info("Starting pipeline execution") + + placeholder_run: Optional[PipelineRunResponse] = None + try: + # Create a placeholder run separately from the actual execution, + # so that we have a run ID to include in the response even if the + # pipeline execution fails. 
+ placeholder_run, deployment_snapshot = ( + self._prepare_execute_with_orchestrator( + resolved_params=parameters, + ) + ) + + captured_outputs = self._execute_with_orchestrator( + placeholder_run=placeholder_run, + deployment_snapshot=deployment_snapshot, + resolved_params=parameters, + skip_artifact_materialization=request.skip_artifact_materialization, + ) + + # Map outputs using fast (in-memory) or slow (artifact) path + mapped_outputs = self._map_outputs(captured_outputs) + + return self._build_response( + placeholder_run=placeholder_run, + mapped_outputs=mapped_outputs, + start_time=start_time, + resolved_params=parameters, + ) + + except Exception as e: + logger.error(f"โŒ Pipeline execution failed: {e}") + return self._build_response( + placeholder_run=placeholder_run, + mapped_outputs=None, + start_time=start_time, + resolved_params=parameters, + error=e, + ) + + def get_service_info(self) -> ServiceInfo: + """Get service information. + + Returns: + A dictionary containing service information. + """ + uptime = time.time() - self.service_start_time + return ServiceInfo( + deployment=DeploymentInfo( + id=self.deployment.id, + name=self.deployment.name, + ), + snapshot=SnapshotInfo( + id=self.snapshot.id, + name=self.snapshot.name, + ), + pipeline=PipelineInfo( + name=self.snapshot.pipeline_configuration.name, + parameters=self.snapshot.pipeline_spec.parameters + if self.snapshot.pipeline_spec + else None, + input_schema=self.input_schema, + output_schema=self.output_schema, + ), + total_executions=self.total_executions, + last_execution_time=self.last_execution_time, + status="healthy", + uptime=uptime, + ) + + def get_execution_metrics(self) -> ExecutionMetrics: + """Return lightweight execution metrics for observability. + + Returns: + Aggregated execution metrics. + """ + return ExecutionMetrics( + total_executions=self.total_executions, + last_execution_time=self.last_execution_time, + ) + + def is_healthy(self) -> bool: + """Check service health. + + Returns: + True if the service is healthy, otherwise False. + """ + return True + + def _map_outputs( + self, + runtime_outputs: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + """Map pipeline outputs using centralized runtime processing. + + Args: + runtime_outputs: Optional in-memory outputs captured from runtime. + + Returns: + A dictionary containing outputs and any warnings from filtering. 
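A minimal usage sketch of the service defined above; the deployment ID is a placeholder and must refer to an existing deployment with a snapshot:

```python
from zenml.deployers.server.service import PipelineDeploymentService

# Placeholder UUID; a real ID of a running deployment is required.
service = PipelineDeploymentService("00000000-0000-0000-0000-000000000000")
service.initialize()  # runs the pipeline's init hook once for the service
try:
    info = service.get_service_info()
    print(info.pipeline.name, info.status, info.total_executions)
finally:
    service.cleanup()  # runs the cleanup hook
```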
+ """ + filtered_outputs = {} + if runtime_outputs and self.snapshot.pipeline_spec: + # Filter outputs based on pipeline schema (raises RuntimeError if missing) + output_mappings = self.snapshot.pipeline_spec.outputs + + unique_step_output_mapping = get_unique_step_output_names( + {(o.step_name, o.output_name): o for o in output_mappings} + ) + + for output_mapping in output_mappings: + unique_step_output_name = unique_step_output_mapping[ + ( + output_mapping.step_name, + output_mapping.output_name, + ) + ][1] + if output_mapping.step_name in runtime_outputs.keys(): + filtered_outputs[unique_step_output_name] = ( + runtime_outputs[output_mapping.step_name].get( + output_mapping.output_name, None + ) + ) + else: + logger.warning( + f"Output {output_mapping.output_name} not found in " + f"runtime outputs for step {output_mapping.step_name}" + ) + filtered_outputs[unique_step_output_name] = None + else: + logger.debug("No output mappings found, returning empty outputs") + + return filtered_outputs + + def _prepare_execute_with_orchestrator( + self, + resolved_params: Dict[str, Any], + ) -> Tuple[PipelineRunResponse, PipelineSnapshotResponse]: + """Prepare the execution with the orchestrator. + + Args: + resolved_params: The resolved parameters. + + Returns: + A tuple of (placeholder_run, deployment_snapshot). + """ + # Create a new snapshot with deployment-specific parameters and settings + from zenml.orchestrators.utils import ( + deployment_snapshot_request_from_source_snapshot, + ) + + deployment_snapshot_request = ( + deployment_snapshot_request_from_source_snapshot( + source_snapshot=self.snapshot, + deployment_parameters=resolved_params, + ) + ) + + # Create the new snapshot in the store + deployment_snapshot = self._client.zen_store.create_snapshot( + deployment_snapshot_request + ) + + # Create a placeholder run using the new deployment snapshot + placeholder_run = run_utils.create_placeholder_run( + snapshot=deployment_snapshot, + logs=None, + trigger_info=PipelineRunTriggerInfo( + deployment_id=self.deployment.id, + ), + ) + + return placeholder_run, deployment_snapshot + + def _execute_with_orchestrator( + self, + placeholder_run: PipelineRunResponse, + deployment_snapshot: PipelineSnapshotResponse, + resolved_params: Dict[str, Any], + skip_artifact_materialization: bool, + ) -> Optional[Dict[str, Dict[str, Any]]]: + """Run the snapshot via the orchestrator and return the concrete run. + + Args: + placeholder_run: The placeholder run to execute the pipeline on. + deployment_snapshot: The deployment snapshot to execute the pipeline + on. + resolved_params: Normalized pipeline parameters. + skip_artifact_materialization: Whether runtime should skip artifact + materialization. + + Returns: + The in-memory outputs of the execution. + + Raises: + RuntimeError: If the orchestrator has not been initialized. + RuntimeError: If the pipeline cannot be executed. 
+ + """ + active_stack: Stack = self._client.active_stack + + if self._orchestrator is None: + raise RuntimeError("Orchestrator not initialized") + + # Start deployment runtime context with parameters (still needed for + # in-memory materializer) + runtime.start( + request_id=str(uuid4()), + snapshot=deployment_snapshot, + parameters=resolved_params, + skip_artifact_materialization=skip_artifact_materialization, + ) + + captured_outputs: Optional[Dict[str, Dict[str, Any]]] = None + try: + # Use the new deployment snapshot with pre-configured settings + self._orchestrator.run( + snapshot=deployment_snapshot, + stack=active_stack, + placeholder_run=placeholder_run, + ) + + # Capture in-memory outputs before stopping the runtime context + if runtime.is_active(): + captured_outputs = runtime.get_outputs() + except Exception as e: + logger.exception(f"Failed to execute pipeline: {e}") + raise RuntimeError(f"Failed to execute pipeline: {e}") + finally: + # Always stop deployment runtime context + runtime.stop() + + return captured_outputs + + def _execute_init_hook(self) -> None: + """Execute init hook if present. + + Raises: + Exception: If executing the hook fails. + """ + init_hook_source = ( + self.snapshot.pipeline_configuration.init_hook_source + ) + init_hook_kwargs = ( + self.snapshot.pipeline_configuration.init_hook_kwargs + ) + + if not init_hook_source: + return + + logger.info("Executing pipeline's init hook...") + try: + with env_utils.temporary_environment( + self.snapshot.pipeline_configuration.environment + ): + self.pipeline_state = load_and_run_hook( + init_hook_source, init_hook_kwargs + ) + except Exception as e: + logger.exception(f"Failed to execute init hook: {e}") + raise + + def _log_initialization_success(self) -> None: + """Log successful initialization.""" + pipeline_name = self.snapshot.pipeline_configuration.name + step_count = len(self.snapshot.step_configurations) + stack_name = ( + self.snapshot.stack.name if self.snapshot.stack else "unknown" + ) + + logger.info("โœ… Service initialized successfully:") + logger.info(f" Deployment: {self.deployment.name}") + logger.info(f" Pipeline: {pipeline_name}") + logger.info(f" Steps: {step_count}") + logger.info(f" Stack: {stack_name}") + + def _build_response( + self, + resolved_params: Dict[str, Any], + start_time: float, + mapped_outputs: Optional[Dict[str, Any]] = None, + placeholder_run: Optional[PipelineRunResponse] = None, + error: Optional[Exception] = None, + ) -> BaseDeploymentInvocationResponse: + """Build success response with execution tracking. + + Args: + resolved_params: The resolved parameters. + start_time: The start time of the execution. + mapped_outputs: The mapped outputs. + placeholder_run: The placeholder run that was executed. + error: The error that occurred. + + Returns: + A BaseDeploymentInvocationResponse describing the execution. 
+ """ + execution_time = time.time() - start_time + self.total_executions += 1 + self.last_execution_time = datetime.now(timezone.utc) + + run: Optional[PipelineRunResponse] = placeholder_run + if placeholder_run: + try: + # Fetch the concrete run via its id + run = self._client.get_pipeline_run( + name_id_or_prefix=placeholder_run.id, + hydrate=True, + include_full_metadata=True, + ) + except Exception: + logger.exception( + f"Failed to fetch concrete run: {placeholder_run.id}" + ) + run = placeholder_run + + return BaseDeploymentInvocationResponse( + success=(error is None), + outputs=mapped_outputs, + error=str(error) if error else None, + execution_time=execution_time, + metadata=DeploymentInvocationResponseMetadata( + deployment_id=self.deployment.id, + deployment_name=self.deployment.name, + pipeline_name=self.snapshot.pipeline_configuration.name, + run_id=run.id if run else None, + run_name=run.name if run else None, + parameters_used=resolved_params, + snapshot_id=self.snapshot.id, + snapshot_name=self.snapshot.name, + ), + ) + + # ---------- + # Schemas and models for OpenAPI enrichment + # ---------- + + @property + def input_model( + self, + ) -> Type[BaseModel]: + """Construct a Pydantic model representing pipeline input parameters. + + Load the pipeline class from `pipeline_spec.source` and derive the + entrypoint signature types to create a dynamic Pydantic model + (extra='forbid') to use for parameter validation. + + Returns: + A Pydantic `BaseModel` subclass that validates the pipeline input + parameters. + + Raises: + RuntimeError: If the pipeline class cannot be loaded or if no + parameters model can be constructed for the pipeline. + """ + if ( + not self.snapshot.pipeline_spec + or not self.snapshot.pipeline_spec.source + ): + raise RuntimeError( + f"Snapshot `{self.snapshot.id}` is missing a " + "pipeline_spec.source; cannot build input model." + ) + + try: + pipeline_class: Pipeline = source_utils.load( + self.snapshot.pipeline_spec.source + ) + except Exception as e: + raise RuntimeError( + "Failed to load pipeline class from snapshot" + ) from e + + model = pipeline_class._compute_input_model() + if not model: + raise RuntimeError( + f"Failed to construct input model from pipeline " + f"`{self.snapshot.pipeline_configuration.name}`." + ) + return model + + @property + def input_schema(self) -> Dict[str, Any]: + """Return the JSON schema for pipeline input parameters. + + Returns: + The JSON schema for pipeline parameters. + + Raises: + RuntimeError: If the pipeline input schema is not available. + """ + if ( + self.snapshot.pipeline_spec + and self.snapshot.pipeline_spec.input_schema + ): + return self.snapshot.pipeline_spec.input_schema + # This should never happen, given that we check for this in the + # base deployer. + raise RuntimeError("The pipeline input schema is not available.") + + @property + def output_schema(self) -> Dict[str, Any]: + """Return the JSON schema for the pipeline outputs. + + Returns: + The JSON schema for the pipeline outputs. + + Raises: + RuntimeError: If the pipeline output schema is not available. + """ + if ( + self.snapshot.pipeline_spec + and self.snapshot.pipeline_spec.output_schema + ): + return self.snapshot.pipeline_spec.output_schema + # This should never happen, given that we check for this in the + # base deployer. 
+ raise RuntimeError("The pipeline output schema is not available.") diff --git a/src/zenml/deployers/utils.py b/src/zenml/deployers/utils.py new file mode 100644 index 00000000000..c7055b55926 --- /dev/null +++ b/src/zenml/deployers/utils.py @@ -0,0 +1,276 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML deployers utilities.""" + +import json +from typing import Any, Dict, Optional, Union +from uuid import UUID + +import jsonref +import requests + +from zenml.client import Client +from zenml.deployers.exceptions import ( + DeploymentHTTPError, + DeploymentNotFoundError, + DeploymentProvisionError, +) +from zenml.enums import DeploymentStatus +from zenml.models import DeploymentResponse +from zenml.steps.step_context import get_step_context +from zenml.utils.json_utils import pydantic_encoder + + +def get_deployment_input_schema( + deployment: DeploymentResponse, +) -> Dict[str, Any]: + """Get the schema for a deployment's input parameters. + + Args: + deployment: The deployment for which to get the schema. + + Returns: + The schema for the deployment's input parameters. + + Raises: + RuntimeError: If the deployment has no associated input schema. + """ + if ( + deployment.snapshot + and deployment.snapshot.pipeline_spec + and deployment.snapshot.pipeline_spec.input_schema + ): + return deployment.snapshot.pipeline_spec.input_schema + + raise RuntimeError( + f"Deployment {deployment.name} has no associated input schema." + ) + + +def get_deployment_output_schema( + deployment: DeploymentResponse, +) -> Dict[str, Any]: + """Get the schema for a deployment's output parameters. + + Args: + deployment: The deployment for which to get the schema. + + Returns: + The schema for the deployment's output parameters. + + Raises: + RuntimeError: If the deployment has no associated output schema. + """ + if ( + deployment.snapshot + and deployment.snapshot.pipeline_spec + and deployment.snapshot.pipeline_spec.output_schema + ): + return deployment.snapshot.pipeline_spec.output_schema + + raise RuntimeError( + f"Deployment {deployment.name} has no associated output schema." + ) + + +def get_deployment_invocation_example( + deployment: DeploymentResponse, +) -> Dict[str, Any]: + """Generate an example invocation command for a deployment. + + Args: + deployment: The deployment for which to generate an example invocation. + + Returns: + A dictionary containing the example invocation parameters. 
+ """ + parameters_schema = get_deployment_input_schema(deployment) + + properties = parameters_schema.get("properties", {}) + + if not properties: + return {} + + parameters = {} + + for attr_name, attr_schema in properties.items(): + parameters[attr_name] = "" + if not isinstance(attr_schema, dict): + continue + + default_value = None + + if "default" in attr_schema: + default_value = attr_schema["default"] + elif "const" in attr_schema: + default_value = attr_schema["const"] + + parameters[attr_name] = default_value or "" + + return parameters + + +def invoke_deployment( + deployment_name_or_id: Union[str, UUID], + project: Optional[UUID] = None, + timeout: int = 300, # 5 minute timeout + **kwargs: Any, +) -> Any: + """Call a deployment and return the result. + + Args: + deployment_name_or_id: The name or ID of the deployment to call. + project: The project ID of the deployment to call. + timeout: The timeout for the HTTP request to the deployment. + **kwargs: Keyword arguments to pass to the deployment. + + Returns: + The response from the deployment, parsed as JSON if possible, + otherwise returned as text. + + Raises: + DeploymentNotFoundError: If the deployment is not found. + DeploymentProvisionError: If the deployment is not running + or has no URL. + DeploymentHTTPError: If the HTTP request to the endpoint fails. + """ + client = Client() + try: + deployment = client.get_deployment( + deployment_name_or_id, project=project + ) + except KeyError: + raise DeploymentNotFoundError( + f"Deployment with name or ID '{deployment_name_or_id}' not found" + ) + + if deployment.status != DeploymentStatus.RUNNING: + raise DeploymentProvisionError( + f"Deployment {deployment_name_or_id} is not running. Please " + "refresh or re-deploy the deployment or check its logs for " + "more details." + ) + + if not deployment.url: + raise DeploymentProvisionError( + f"Deployment {deployment_name_or_id} has no URL. Please " + "refresh the deployment or check its logs for more " + "details." + ) + + input_schema = None + if deployment.snapshot and deployment.snapshot.pipeline_spec: + input_schema = deployment.snapshot.pipeline_spec.input_schema + + if input_schema: + # Resolve the references in the schema first, otherwise we won't be able + # to access the data types for object-typed parameters. + input_schema = jsonref.replace_refs(input_schema) + assert isinstance(input_schema, dict) + + properties = input_schema.get("properties", {}) + + # Some kwargs having one of the collection data types (list, dict) in + # the schema may be supplied as a JSON string. We need to unpack + # them before we construct the final JSON payload. + # + # We ignore all errors here because they will be better handled by the + # deployment itself server side. 
+ for key in kwargs.keys(): + if key not in properties: + continue + value = kwargs[key] + if not isinstance(value, str): + continue + attr_schema = properties[key] + try: + if attr_schema.get("type") == "object": + value = json.loads(value) + if isinstance(value, dict): + kwargs[key] = value + elif attr_schema.get("type") == "array": + value = json.loads(value) + if isinstance(value, list): + kwargs[key] = value + except (json.JSONDecodeError, ValueError): + pass + + # Serialize kwargs to JSON + params = dict(parameters=kwargs) + try: + payload = json.dumps(params, default=pydantic_encoder) + except (TypeError, ValueError) as e: + raise DeploymentHTTPError( + f"Failed to serialize request data to JSON: {e}" + ) + + # Construct the invoke endpoint URL + invoke_url = deployment.url.rstrip("/") + "/invoke" + + # Prepare headers + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + # Add authorization header if auth_key is present + if deployment.auth_key: + headers["Authorization"] = f"Bearer {deployment.auth_key}" + + try: + step_context = get_step_context() + except RuntimeError: + step_context = None + + if step_context: + # Include these so that the deployment can identify the step + # and pipeline run that called it, if called from a step. + headers["ZenML-Step-Name"] = step_context.step_name + headers["ZenML-Pipeline-Name"] = step_context.pipeline.name + headers["ZenML-Pipeline-Run-ID"] = str(step_context.pipeline_run.id) + headers["ZenML-Pipeline-Run-Name"] = step_context.pipeline_run.name + + # Make the HTTP request + try: + response = requests.post( + invoke_url, + data=payload, + headers=headers, + timeout=timeout, + ) + response.raise_for_status() + + # Try to parse JSON response, fallback to text if not JSON + try: + return response.json() + except ValueError: + return response.text + + except requests.exceptions.HTTPError as e: + raise DeploymentHTTPError( + f"HTTP {e.response.status_code} error calling deployment " + f"{deployment_name_or_id}: {e.response.text}" + ) + except requests.exceptions.ConnectionError as e: + raise DeploymentHTTPError( + f"Failed to connect to deployment {deployment_name_or_id}: {e}" + ) + except requests.exceptions.Timeout as e: + raise DeploymentHTTPError( + f"Timeout calling deployment {deployment_name_or_id}: {e}" + ) + except requests.exceptions.RequestException as e: + raise DeploymentHTTPError( + f"Request failed for deployment {deployment_name_or_id}: {e}" + ) diff --git a/src/zenml/entrypoints/base_entrypoint_configuration.py b/src/zenml/entrypoints/base_entrypoint_configuration.py index ef132e53d4a..c8eb1a2d396 100644 --- a/src/zenml/entrypoints/base_entrypoint_configuration.py +++ b/src/zenml/entrypoints/base_entrypoint_configuration.py @@ -119,22 +119,28 @@ def get_entrypoint_arguments( Raises: ValueError: If no valid snapshot ID is passed. """ - snapshot_id = kwargs.get(SNAPSHOT_ID_OPTION) - if not uuid_utils.is_valid_uuid(snapshot_id): - raise ValueError( - f"Missing or invalid snapshot ID as argument for entrypoint " - f"configuration. Please make sure to pass a valid UUID to " - f"`{cls.__name__}.{cls.get_entrypoint_arguments.__name__}" - f"({SNAPSHOT_ID_OPTION}=)`." 
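`invoke_deployment` wraps the schema handling, authentication headers, and HTTP call shown above into a single helper. A minimal sketch with a hypothetical deployment name and parameters (the keyword arguments must match the deployment's input schema):

```python
from zenml.deployers.utils import invoke_deployment

result = invoke_deployment("weather-agent", city="Berlin", units="metric")
print(result)
```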
- ) - arguments = [ f"--{ENTRYPOINT_CONFIG_SOURCE_OPTION}", source_utils.resolve(cls).import_path, - f"--{SNAPSHOT_ID_OPTION}", - str(snapshot_id), ] + if SNAPSHOT_ID_OPTION in cls.get_entrypoint_options(): + snapshot_id = kwargs.get(SNAPSHOT_ID_OPTION) + if not uuid_utils.is_valid_uuid(snapshot_id): + raise ValueError( + f"Missing or invalid snapshot ID as argument for entrypoint " + f"configuration. Please make sure to pass a valid UUID to " + f"`{cls.__name__}.{cls.get_entrypoint_arguments.__name__}" + f"({SNAPSHOT_ID_OPTION}=)`." + ) + + arguments.extend( + [ + f"--{SNAPSHOT_ID_OPTION}", + str(snapshot_id), + ] + ) + return arguments @classmethod diff --git a/src/zenml/enums.py b/src/zenml/enums.py index eed6f8adcb6..93edbba278f 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -161,6 +161,7 @@ class StackComponentType(StrEnum): ORCHESTRATOR = "orchestrator" STEP_OPERATOR = "step_operator" MODEL_REGISTRY = "model_registry" + DEPLOYER = "deployer" @property def plural(self) -> str: @@ -396,6 +397,7 @@ class TaggableResourceTypes(StrEnum): PIPELINE_RUN = "pipeline_run" RUN_TEMPLATE = "run_template" PIPELINE_SNAPSHOT = "pipeline_snapshot" + DEPLOYMENT = "deployment" class ResponseUpdateStrategy(StrEnum): @@ -493,7 +495,18 @@ class ServiceState(StrEnum): SCALED_TO_ZERO = "scaled_to_zero" +class DeploymentStatus(StrEnum): + """Status of a deployment.""" + + UNKNOWN = "unknown" + PENDING = "pending" + RUNNING = "running" + ABSENT = "absent" + ERROR = "error" + + class PipelineRunTriggeredByType(StrEnum): """All possible types that can trigger a pipeline run.""" STEP_RUN = "step_run" + DEPLOYMENT = "deployment" diff --git a/src/zenml/exceptions.py b/src/zenml/exceptions.py index 1c077d2ddf8..9988d0cef7a 100644 --- a/src/zenml/exceptions.py +++ b/src/zenml/exceptions.py @@ -244,3 +244,11 @@ def __init__( monitoring the pipeline run. """ self.original_exception = original_exception + + +class HookValidationException(ZenMLBaseException): + """Exception raised when hook validation fails.""" + + +class HookExecutionException(ZenMLBaseException): + """Exception raised when hook execution fails.""" diff --git a/src/zenml/hooks/hook_validators.py b/src/zenml/hooks/hook_validators.py index 50f091e74cb..cdf7d5f55b6 100644 --- a/src/zenml/hooks/hook_validators.py +++ b/src/zenml/hooks/hook_validators.py @@ -13,28 +13,110 @@ # permissions and limitations under the License. """Validation functions for hooks.""" -import inspect -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Tuple, + Union, +) + +from pydantic import ConfigDict, ValidationError from zenml.config.source import Source +from zenml.exceptions import HookValidationException +from zenml.logger import get_logger from zenml.utils import source_utils +from zenml.utils.pydantic_utils import validate_function_args + +logger = get_logger(__name__) + if TYPE_CHECKING: - from zenml.types import HookSpecification + from zenml.types import HookSpecification, InitHookSpecification -def resolve_and_validate_hook(hook: "HookSpecification") -> Source: - """Resolves and validates a hook callback. +def _validate_hook_arguments( + _func: Callable[..., Any], + hook_kwargs: Dict[str, Any], + exception_arg: Union[BaseException, bool] = False, +) -> Dict[str, Any]: + """Validates hook arguments. + + Args: + _func: The hook function to validate. + hook_kwargs: The hook keyword arguments to validate. + exception_arg: The exception argument to validate. 
+ + Returns: + The validated hook arguments. + + Raises: + HookValidationException: If the hook arguments are not valid. + """ + # Validate hook arguments + try: + hook_args: Tuple[Any, ...] = () + if isinstance(exception_arg, BaseException): + hook_args = (exception_arg,) + elif exception_arg is True: + hook_args = (Exception(),) + config = ConfigDict(arbitrary_types_allowed=len(hook_args) > 0) + validated_kwargs = validate_function_args( + _func, config, *hook_args, **hook_kwargs + ) + except (ValidationError, TypeError) as e: + exc_msg = ( + "Failed to validate hook arguments for {func}: {e}\n" + "Please observe the following guidelines:\n" + "- the success hook takes no arguments\n" + "- the failure hook optionally takes a single `BaseException` " + "typed argument\n" + "- the init hook takes any number of JSON-safe arguments\n" + "- the cleanup hook takes no arguments\n" + ) + + if not hook_args: + raise HookValidationException(exc_msg.format(func=_func, e=e)) + + # If we have an exception argument, we try again without it. This is + # to account for the case where the hook function does not expect an + # exception argument. + hook_args = () + config = ConfigDict(arbitrary_types_allowed=False) + try: + validated_kwargs = validate_function_args( + _func, config, *hook_args, **hook_kwargs + ) + except (ValidationError, TypeError) as e: + raise HookValidationException(exc_msg.format(func=_func, e=e)) + + return validated_kwargs + + +def resolve_and_validate_hook( + hook: Union["HookSpecification", "InitHookSpecification"], + hook_kwargs: Optional[Dict[str, Any]] = None, + allow_exception_arg: bool = False, +) -> Tuple[Source, Optional[Dict[str, Any]]]: + """Resolves and validates a hook callback and its arguments. Args: hook: Hook function or source. + hook_kwargs: The arguments to pass to the hook. + allow_exception_arg: Whether to allow an implicit exception argument + to be passed to the hook. Returns: - Hook source. + Tuple of hook source and validated hook arguments converted to JSON-safe + values. Raises: ValueError: If `hook_func` is not a valid callable. """ + # Resolve the hook function if isinstance(hook, (str, Source)): func = source_utils.load(hook) else: @@ -43,34 +125,65 @@ def resolve_and_validate_hook(hook: "HookSpecification") -> Source: if not callable(func): raise ValueError(f"{func} is not a valid function.") - sig = inspect.getfullargspec(inspect.unwrap(func)) - sig_annotations = sig.annotations - if "return" in sig_annotations: - sig_annotations.pop("return") + # Validate hook arguments + validated_kwargs = _validate_hook_arguments( + func, hook_kwargs or {}, allow_exception_arg + ) + + return source_utils.resolve(func), validated_kwargs - if sig.args and len(sig.args) != len(sig_annotations): - raise ValueError( - "You can only pass arguments to a hook that are annotated with a " - "`BaseException` type." + +def load_and_run_hook( + hook_source: "Source", + hook_parameters: Optional[Dict[str, Any]] = None, + step_exception: Optional[BaseException] = None, + raise_on_error: bool = False, +) -> Any: + """Loads hook source and runs the hook. + + Args: + hook_source: The source of the hook function. + hook_parameters: The parameters of the hook function. + step_exception: The exception of the original step. + raise_on_error: Whether to raise an error if the hook fails. + + Returns: + The return value of the hook function. + + Raises: + HookValidationException: If hook validation fails. + RuntimeError: If the hook fails and raise_on_error is True. 
+ """ + try: + hook = source_utils.load(hook_source) + except Exception as e: + msg = f"Failed to load hook source '{hook_source}' with exception: {e}" + if raise_on_error: + raise RuntimeError(msg) from e + else: + logger.error(msg) + return None + try: + validated_kwargs = _validate_hook_arguments( + hook, hook_parameters or {}, step_exception or False ) + except HookValidationException as e: + if raise_on_error: + raise + else: + logger.error(e) + return None - if sig_annotations: - annotations = sig_annotations.values() - seen_annotations = set() - for annotation in annotations: - if annotation: - if annotation not in (BaseException,): - raise ValueError( - "Hook arguments must be of type `BaseException`, not " - f"`{annotation}`." - ) - - if annotation in seen_annotations: - raise ValueError( - "You can only pass one `BaseException` type to a hook." - "Currently your function has the following" - f"annotations: {sig_annotations}" - ) - seen_annotations.add(annotation) - - return source_utils.resolve(func) + try: + logger.debug(f"Running hook {hook} with params: {validated_kwargs}") + return hook(**validated_kwargs) + except Exception as e: + msg = ( + f"Failed to run hook '{hook_source}' with params: " + f"{validated_kwargs} with exception: '{e}'" + ) + if raise_on_error: + raise RuntimeError(msg) from e + else: + logger.error(msg) + return None diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index d6cf5866fa0..28cd8a82a85 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -28,6 +28,7 @@ AWS_CONTAINER_REGISTRY_FLAVOR = "aws" AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR = "sagemaker" AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR = "sagemaker" +AWS_DEPLOYER_FLAVOR = "aws" # Service connector constants AWS_CONNECTOR_TYPE = "aws" @@ -61,6 +62,7 @@ def flavors(cls) -> List[Type[Flavor]]: """ from zenml.integrations.aws.flavors import ( AWSContainerRegistryFlavor, + AWSDeployerFlavor, AWSImageBuilderFlavor, SagemakerOrchestratorFlavor, SagemakerStepOperatorFlavor, @@ -68,6 +70,7 @@ def flavors(cls) -> List[Type[Flavor]]: return [ AWSContainerRegistryFlavor, + AWSDeployerFlavor, AWSImageBuilderFlavor, SagemakerStepOperatorFlavor, SagemakerOrchestratorFlavor, diff --git a/src/zenml/integrations/aws/deployers/__init__.py b/src/zenml/integrations/aws/deployers/__init__.py new file mode 100644 index 00000000000..8feb9ad45c9 --- /dev/null +++ b/src/zenml/integrations/aws/deployers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Initialization for the AWS deployer.""" + +from zenml.integrations.aws.deployers.aws_deployer import ( # noqa + AWSDeployer, +) + +__all__ = ["AWSDeployer"] diff --git a/src/zenml/integrations/aws/deployers/aws_deployer.py b/src/zenml/integrations/aws/deployers/aws_deployer.py new file mode 100644 index 00000000000..84365ef2fa3 --- /dev/null +++ b/src/zenml/integrations/aws/deployers/aws_deployer.py @@ -0,0 +1,1777 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the AWS App Runner deployer.""" + +import datetime +import json +import re +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + cast, +) +from uuid import UUID + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError +from pydantic import BaseModel + +from zenml.config.base_settings import BaseSettings +from zenml.config.resource_settings import ResourceSettings +from zenml.deployers.containerized_deployer import ContainerizedDeployer +from zenml.deployers.exceptions import ( + DeployerError, + DeploymentDeprovisionError, + DeploymentLogsNotFoundError, + DeploymentNotFoundError, + DeploymentProvisionError, +) +from zenml.deployers.server.entrypoint_configuration import ( + AUTH_KEY_OPTION, + DEPLOYMENT_ID_OPTION, + PORT_OPTION, + DeploymentEntrypointConfiguration, +) +from zenml.enums import DeploymentStatus, StackComponentType +from zenml.integrations.aws.flavors.aws_deployer_flavor import ( + AWSDeployerConfig, + AWSDeployerSettings, +) +from zenml.logger import get_logger +from zenml.models import ( + DeploymentOperationalState, + DeploymentResponse, +) +from zenml.stack import StackValidator + +if TYPE_CHECKING: + from zenml.stack import Stack + +logger = get_logger(__name__) + +# Default resource and scaling configuration constants +DEFAULT_CPU = 0.25 # vCPU +DEFAULT_MEMORY = 0.5 # GB +DEFAULT_MIN_REPLICAS = 1 +DEFAULT_MAX_REPLICAS = 25 +DEFAULT_MAX_CONCURRENCY = 100 + +# AWS App Runner built-in limits +AWS_APP_RUNNER_MAX_SIZE = 1000 +AWS_APP_RUNNER_MAX_CONCURRENCY = 1000 + + +class AppRunnerDeploymentMetadata(BaseModel): + """Metadata for an App Runner deployment.""" + + service_name: Optional[str] = None + service_arn: Optional[str] = None + service_url: Optional[str] = None + region: Optional[str] = None + service_id: Optional[str] = None + status: Optional[str] = None + source_configuration: Optional[Dict[str, Any]] = None + instance_configuration: Optional[Dict[str, Any]] = None + auto_scaling_configuration_summary: Optional[Dict[str, Any]] = None + auto_scaling_configuration_arn: Optional[str] = None + health_check_configuration: Optional[Dict[str, Any]] = None + network_configuration: Optional[Dict[str, Any]] = None + observability_configuration: Optional[Dict[str, Any]] = None + encryption_configuration: Optional[Dict[str, Any]] = None + cpu: Optional[str] = None + memory: Optional[str] = None + port: Optional[int] = None + auto_scaling_max_concurrency: Optional[int] 
= None + auto_scaling_max_size: Optional[int] = None + auto_scaling_min_size: Optional[int] = None + is_publicly_accessible: Optional[bool] = None + health_check_grace_period_seconds: Optional[int] = None + health_check_interval_seconds: Optional[int] = None + health_check_path: Optional[str] = None + health_check_protocol: Optional[str] = None + health_check_timeout_seconds: Optional[int] = None + health_check_healthy_threshold: Optional[int] = None + health_check_unhealthy_threshold: Optional[int] = None + tags: Optional[Dict[str, str]] = None + traffic_allocation: Optional[Dict[str, int]] = None + created_at: Optional[str] = None + updated_at: Optional[str] = None + deleted_at: Optional[str] = None + secret_arn: Optional[str] = None + + @classmethod + def from_app_runner_service( + cls, + service: Dict[str, Any], + region: str, + secret_arn: Optional[str] = None, + ) -> "AppRunnerDeploymentMetadata": + """Create metadata from an App Runner service. + + Args: + service: The App Runner service dictionary from describe_service. + region: The AWS region. + secret_arn: The AWS Secrets Manager secret ARN for the deployment. + + Returns: + The metadata for the App Runner service. + """ + instance_config = service.get("InstanceConfiguration", {}) + cpu = instance_config.get("Cpu") + memory = instance_config.get("Memory") + + auto_scaling_config = service.get( + "AutoScalingConfigurationSummary", {} + ) + auto_scaling_configuration_arn = auto_scaling_config.get( + "AutoScalingConfigurationArn" + ) + auto_scaling_max_concurrency = auto_scaling_config.get( + "MaxConcurrency" + ) + auto_scaling_max_size = auto_scaling_config.get("MaxSize") + auto_scaling_min_size = auto_scaling_config.get("MinSize") + + health_check_config = service.get("HealthCheckConfiguration", {}) + health_check_grace_period = health_check_config.get( + "HealthCheckGracePeriodSeconds" + ) + health_check_interval = health_check_config.get("Interval") + health_check_path = health_check_config.get("Path") + health_check_protocol = health_check_config.get("Protocol") + health_check_timeout = health_check_config.get("Timeout") + health_check_healthy_threshold = health_check_config.get( + "HealthyThreshold" + ) + health_check_unhealthy_threshold = health_check_config.get( + "UnhealthyThreshold" + ) + + network_config = service.get("NetworkConfiguration", {}) + is_publicly_accessible = network_config.get( + "IngressConfiguration", {} + ).get("IsPubliclyAccessible") + + source_config = service.get("SourceConfiguration", {}) + image_repo = source_config.get("ImageRepository", {}) + image_config = image_repo.get("ImageConfiguration", {}) + + port = None + if image_config: + port = image_config.get("Port") + + traffic_allocation = {} + traffic_config = service.get("TrafficConfiguration", []) + for traffic in traffic_config: + if traffic.get("Type") == "LATEST": + traffic_allocation["LATEST"] = traffic.get("Percent", 0) + elif traffic.get("Revision"): + traffic_allocation[traffic["Revision"]] = traffic.get( + "Percent", 0 + ) + elif traffic.get("Tag"): + traffic_allocation[f"tag:{traffic['Tag']}"] = traffic.get( + "Percent", 0 + ) + + # Extract timestamps + created_at = service.get("CreatedAt") + updated_at = service.get("UpdatedAt") + deleted_at = service.get("DeletedAt") + + return cls( + service_name=service.get("ServiceName"), + service_arn=service.get("ServiceArn"), + service_url=service.get("ServiceUrl"), + region=region, + service_id=service.get("ServiceId"), + status=service.get("Status"), + source_configuration=source_config, + 
instance_configuration=instance_config, + auto_scaling_configuration_summary=auto_scaling_config, + auto_scaling_configuration_arn=auto_scaling_configuration_arn, + health_check_configuration=health_check_config, + network_configuration=network_config, + observability_configuration=service.get( + "ObservabilityConfiguration" + ), + encryption_configuration=service.get("EncryptionConfiguration"), + cpu=cpu, + memory=memory, + port=port, + auto_scaling_max_concurrency=auto_scaling_max_concurrency, + auto_scaling_max_size=auto_scaling_max_size, + auto_scaling_min_size=auto_scaling_min_size, + is_publicly_accessible=is_publicly_accessible, + health_check_grace_period_seconds=health_check_grace_period, + health_check_interval_seconds=health_check_interval, + health_check_path=health_check_path, + health_check_protocol=health_check_protocol, + health_check_timeout_seconds=health_check_timeout, + health_check_healthy_threshold=health_check_healthy_threshold, + health_check_unhealthy_threshold=health_check_unhealthy_threshold, + tags=dict(service.get("Tags", {})), + traffic_allocation=traffic_allocation + if traffic_allocation + else None, + created_at=created_at.isoformat() if created_at else None, + updated_at=updated_at.isoformat() if updated_at else None, + deleted_at=deleted_at.isoformat() if deleted_at else None, + secret_arn=secret_arn, + ) + + @classmethod + def from_deployment( + cls, deployment: DeploymentResponse + ) -> "AppRunnerDeploymentMetadata": + """Create metadata from a deployment. + + Args: + deployment: The deployment to get the metadata for. + + Returns: + The metadata for the deployment. + """ + return cls.model_validate(deployment.deployment_metadata) + + +class AWSDeployer(ContainerizedDeployer): + """Deployer responsible for deploying pipelines on AWS App Runner.""" + + CONTAINER_REQUIREMENTS: List[str] = ["uvicorn", "fastapi"] + + _boto_session: Optional[boto3.Session] = None + _region: Optional[str] = None + _app_runner_client: Optional[Any] = None + _secrets_manager_client: Optional[Any] = None + _logs_client: Optional[Any] = None + + @property + def config(self) -> AWSDeployerConfig: + """Returns the `AWSDeployerConfig` config. + + Returns: + The configuration. + """ + return cast(AWSDeployerConfig, self._config) + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Settings class for the AWS deployer. + + Returns: + The settings class. + """ + return AWSDeployerSettings + + @property + def validator(self) -> Optional[StackValidator]: + """Ensures there is an image builder in the stack. + + Returns: + A `StackValidator` instance. + """ + return StackValidator( + required_components={ + StackComponentType.IMAGE_BUILDER, + StackComponentType.CONTAINER_REGISTRY, + } + ) + + def _get_boto_session_and_region(self) -> Tuple[boto3.Session, str]: + """Get an authenticated boto3 session and determine the region. + + Returns: + A tuple containing the boto3 session and the AWS region. + + Raises: + RuntimeError: If the service connector returns an unexpected type. + """ + if ( + self._boto_session is not None + and self._region is not None + and not self.connector_has_expired() + ): + return self._boto_session, self._region + + # Option 1: Service connector + if connector := self.get_connector(): + boto_session = connector.connect() + if not isinstance(boto_session, boto3.Session): + raise RuntimeError( + f"Expected to receive a `boto3.Session` object from the " + f"linked connector, but got type `{type(boto_session)}`." 
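For illustration, the metadata extraction above can be exercised against a truncated, made-up `describe_service` payload:

```python
from zenml.integrations.aws.deployers.aws_deployer import (
    AppRunnerDeploymentMetadata,
)

# Truncated, hypothetical App Runner describe_service response.
service = {
    "ServiceName": "zenml-weather-agent-ab12cd34",
    "ServiceArn": (
        "arn:aws:apprunner:eu-central-1:123456789012:"
        "service/zenml-weather-agent-ab12cd34/xyz"
    ),
    "ServiceUrl": "abc123.eu-central-1.awsapprunner.com",
    "Status": "RUNNING",
    "InstanceConfiguration": {"Cpu": "1024", "Memory": "2048"},
    "AutoScalingConfigurationSummary": {
        "MaxConcurrency": 100,
        "MaxSize": 25,
        "MinSize": 1,
    },
}

metadata = AppRunnerDeploymentMetadata.from_app_runner_service(
    service, region="eu-central-1"
)
print(metadata.status, metadata.cpu, metadata.memory, metadata.auto_scaling_max_size)
```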
+ ) + + region = boto_session.region_name + if not region: + # Fallback to config region or default + region = self.config.region or "us-east-1" + logger.warning( + f"No region found in boto3 session, using {region}" + ) + + # Option 2: Implicit configuration + else: + boto_session = boto3.Session(region_name=self.config.region) + + self._boto_session = boto_session + self._region = region + return boto_session, region + + @property + def app_runner_client(self) -> Any: + """Get the App Runner client. + + Returns: + The App Runner client. + """ + if self._app_runner_client is None or self.connector_has_expired(): + session, region = self._get_boto_session_and_region() + self._app_runner_client = session.client( + "apprunner", region_name=region + ) + return self._app_runner_client + + @property + def secrets_manager_client(self) -> Any: + """Get the Secrets Manager client. + + Returns: + The Secrets Manager client. + """ + if ( + self._secrets_manager_client is None + or self.connector_has_expired() + ): + session, region = self._get_boto_session_and_region() + self._secrets_manager_client = session.client( + "secretsmanager", region_name=region + ) + return self._secrets_manager_client + + @property + def logs_client(self) -> Any: + """Get the CloudWatch Logs client. + + Returns: + The CloudWatch Logs client. + """ + if self._logs_client is None or self.connector_has_expired(): + session, region = self._get_boto_session_and_region() + self._logs_client = session.client("logs", region_name=region) + return self._logs_client + + @property + def region(self) -> str: + """Get the AWS region. + + Returns: + The AWS region. + """ + _, region = self._get_boto_session_and_region() + return region + + def get_tags( + self, + deployment: DeploymentResponse, + settings: AWSDeployerSettings, + ) -> List[Dict[str, str]]: + """Get the tags for a deployment to be used for AWS resources. + + Args: + deployment: The deployment. + settings: The deployer settings. + + Returns: + The tags for the deployment. + """ + tags = { + **settings.tags, + "zenml-deployment-id": str(deployment.id), + "zenml-deployment-name": deployment.name, + "zenml-deployer-name": str(self.name), + "zenml-deployer-id": str(self.id), + "managed-by": "zenml", + } + + return [{"Key": k, "Value": v} for k, v in tags.items()] + + def _sanitize_name( + self, + name: str, + random_suffix: str, + max_length: int = 32, + extra_allowed_characters: str = "-_", + ) -> str: + """Sanitize a name to comply with AWS naming requirements. + + Common AWS naming requirements: + - Length: 4-max_length characters + - Characters: letters (a-z, A-Z), numbers (0-9), configured extra + allowed characters (e.g. dashes and underscores) + - Must start and end with a letter or number + - Cannot contain consecutive extra_allowed_characters + + Args: + name: The raw name to sanitize. + random_suffix: A random suffix to add to the name to ensure + uniqueness. + max_length: The maximum length of the name. + extra_allowed_characters: Extra allowed characters in the name. + + Returns: + A sanitized name that complies with AWS requirements. + + Raises: + RuntimeError: If the random suffix is invalid. + ValueError: If the service name is invalid. + """ + if ( + not re.match(r"^[a-zA-Z0-9]+$", random_suffix) + or len(random_suffix) < 4 + ): + raise RuntimeError( + f"Invalid random suffix: {random_suffix}. Must contain only " + "letters and numbers and be at least 4 characters long." 
+ ) + + # Use the first extra allowed character as the separator + separator = extra_allowed_characters[0] + + # Replace all disallowed characters with the separator + sanitized = re.sub( + rf"[^a-zA-Z0-9{extra_allowed_characters}]", + separator, + name, + ) + + # Remove consecutive extra allowed characters + for char in extra_allowed_characters: + sanitized = re.sub( + rf"[{char}]+", + char, + sanitized, + ) + + # Remove leading and trailing extra allowed characters before truncating + sanitized = re.sub( + r"^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$", + "", + sanitized, + ) + + # Truncate to fit within max_length character limit including suffix + max_base_length = ( + max_length - len(random_suffix) - 1 + ) # -1 for the separator + if len(sanitized) > max_base_length: + sanitized = sanitized[:max_base_length] + + # Ensure it starts and ends with alphanumeric characters + sanitized = re.sub( + r"^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$", + "", + sanitized, + ) + + # Ensure we have at least one character after cleanup + if not sanitized: + raise ValueError( + f"Invalid name: {name}. Must contain at least one " + "alphanumeric character." + ) + + return f"{sanitized}{separator}{random_suffix}" + + def _get_service_name( + self, deployment_name: str, deployment_id: UUID, prefix: str + ) -> str: + """Get the App Runner service name for a deployment. + + Args: + deployment_name: The deployment name. + deployment_id: The deployment ID. + prefix: The prefix to use for the service name. + + Returns: + The App Runner service name that complies with all naming + requirements. + """ + # We use the first 8 characters of the deployment UUID as a random + # suffix to ensure uniqueness. + deployment_id_short = str(deployment_id)[:8] + raw_name = f"{prefix}{deployment_name}" + + return self._sanitize_name( + raw_name, + random_suffix=deployment_id_short, + max_length=40, + extra_allowed_characters="-_", + ) + + def _get_secret_name( + self, + deployment_name: str, + deployment_id: UUID, + prefix: str, + ) -> str: + """Get the Secrets Manager secret name for a deployment. + + Args: + deployment_name: The deployment name. + deployment_id: The deployment ID. + prefix: The prefix to use for the secret name. + + Returns: + The Secrets Manager secret name. + """ + # We use the first 8 characters of the deployment UUID as a random + # suffix to ensure uniqueness. + deployment_id_short = str(deployment_id)[:8] + raw_name = f"{prefix}{deployment_name}" + + return self._sanitize_name( + raw_name, + random_suffix=deployment_id_short, + max_length=512, + extra_allowed_characters="-_./", + ) + + def _create_or_update_secret( + self, + secret_name: str, + secret_value: str, + deployment: DeploymentResponse, + settings: AWSDeployerSettings, + ) -> str: + """Create or update a secret in Secrets Manager. + + Args: + secret_name: The name of the secret. + secret_value: The value to store. + deployment: The deployment. + settings: The deployer settings. + + Returns: + The secret ARN. + + Raises: + ClientError: If the secret cannot be updated. + DeployerError: If secret creation/update fails. 
+ """ + try: + try: + response = self.secrets_manager_client.update_secret( + SecretId=secret_name, + SecretString=secret_value, + Tags=self.get_tags(deployment, settings), + ) + logger.debug(f"Updated existing secret {secret_name}") + return response["ARN"] # type: ignore[no-any-return] + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + logger.debug(f"Creating new secret {secret_name}") + response = self.secrets_manager_client.create_secret( + Name=secret_name, + SecretString=secret_value, + Description=f"ZenML deployment secret for {deployment.name}", + Tags=self.get_tags(deployment, settings), + ) + logger.debug(f"Created new secret {secret_name}") + return response["ARN"] # type: ignore[no-any-return] + else: + raise + + except (ClientError, BotoCoreError) as e: + raise DeployerError( + f"Failed to create/update secret {secret_name}: {e}" + ) + + def _get_secret_arn(self, deployment: DeploymentResponse) -> Optional[str]: + """Get the existing AWS Secrets Manager secret ARN for a deployment. + + Args: + deployment: The deployment. + + Returns: + The existing AWS Secrets Manager secret ARN for the deployment, + or None if no secret exists. + """ + metadata = AppRunnerDeploymentMetadata.from_deployment(deployment) + + if not metadata.secret_arn: + return None + + try: + self.secrets_manager_client.describe_secret( + SecretId=metadata.secret_arn + ) + return metadata.secret_arn + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + return None + logger.exception(f"Failed to verify secret {metadata.secret_arn}") + return None + + def _delete_secret(self, secret_arn: str) -> None: + """Delete a secret from Secrets Manager. + + Args: + secret_arn: The ARN of the secret to delete. + """ + try: + self.secrets_manager_client.delete_secret( + SecretId=secret_arn, + ForceDeleteWithoutRecovery=True, + ) + logger.debug(f"Deleted secret {secret_arn}") + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + logger.debug( + f"Secret {secret_arn} not found, skipping deletion" + ) + else: + logger.exception(f"Failed to delete secret {secret_arn}") + + def _cleanup_deployment_secrets( + self, + deployment: DeploymentResponse, + ) -> None: + """Clean up the secret associated with a deployment. + + Args: + deployment: The deployment. + """ + secret_arn = self._get_secret_arn(deployment) + + if secret_arn: + self._delete_secret(secret_arn) + + def _get_auto_scaling_config_name( + self, deployment_name: str, deployment_id: UUID + ) -> str: + """Get the auto-scaling configuration name for a deployment. + + Args: + deployment_name: The deployment name. + deployment_id: The deployment ID. + + Returns: + The auto-scaling configuration name. + """ + # We use the first 8 characters of the deployment UUID as a random + # suffix to ensure uniqueness. + deployment_id_short = str(deployment_id)[:8] + raw_name = f"zenml-{deployment_name}-{deployment_id_short}" + + return self._sanitize_name( + raw_name, + random_suffix=deployment_id_short, + max_length=32, + extra_allowed_characters="-_", + ) + + def _create_or_update_auto_scaling_config( + self, + config_name: str, + min_size: int, + max_size: int, + max_concurrency: int, + deployment: DeploymentResponse, + settings: AWSDeployerSettings, + ) -> str: + """Create or update an auto-scaling configuration for App Runner. + + Args: + config_name: The name for the auto-scaling configuration. + min_size: Minimum number of instances. 
+ max_size: Maximum number of instances. + max_concurrency: Maximum concurrent requests per instance. + deployment: The deployment. + settings: The deployer settings. + + Returns: + The ARN of the created/updated auto-scaling configuration. + + Raises: + ClientError: If the auto-scaling configuration cannot be described. + DeployerError: If auto-scaling configuration creation/update fails. + """ + try: + metadata = AppRunnerDeploymentMetadata.from_deployment(deployment) + existing_arn = metadata.auto_scaling_configuration_arn + + if existing_arn: + try: + response = self.app_runner_client.describe_auto_scaling_configuration( + AutoScalingConfigurationArn=existing_arn + ) + existing_config = response["AutoScalingConfiguration"] + + if ( + existing_config["MaxConcurrency"] == max_concurrency + and existing_config["MaxSize"] == max_size + and existing_config["MinSize"] == min_size + ): + logger.debug( + f"Auto-scaling configuration {existing_arn} is up " + "to date" + ) + return existing_arn + + logger.debug( + f"Auto-scaling configuration {existing_arn} is out of " + "date, updating it" + ) + + except ClientError as e: + if ( + e.response["Error"]["Code"] + != "InvalidRequestException" + ): + raise + logger.debug( + f"Existing auto-scaling configuration {existing_arn} " + "not found, creating new one" + ) + else: + logger.debug( + f"Creating auto-scaling configuration {config_name}" + ) + + # The create_auto_scaling_configuration call is used to both create + # a new auto-scaling configuration and update an existing one. + # It is possible to create multiple revisions of the same + # configuration by calling create_auto_scaling_configuration + # multiple times using the same AutoScalingConfigurationName. + response = ( + self.app_runner_client.create_auto_scaling_configuration( + AutoScalingConfigurationName=config_name, + MaxConcurrency=max_concurrency, + MaxSize=max_size, + MinSize=min_size, + Tags=self.get_tags(deployment, settings), + ) + ) + + return response["AutoScalingConfiguration"][ # type: ignore[no-any-return] + "AutoScalingConfigurationArn" + ] + + except (ClientError, BotoCoreError) as e: + raise DeployerError( + f"Failed to create/update auto-scaling configuration " + f"{config_name}: {e}" + ) + + def _cleanup_deployment_auto_scaling_config( + self, deployment: DeploymentResponse + ) -> None: + """Clean up the auto-scaling configuration associated with a deployment. + + Args: + deployment: The deployment. + """ + metadata = AppRunnerDeploymentMetadata.from_deployment(deployment) + config_arn = metadata.auto_scaling_configuration_arn + if not config_arn: + return + + try: + logger.debug(f"Deleting auto-scaling configuration {config_arn}") + self.app_runner_client.delete_auto_scaling_configuration( + AutoScalingConfigurationArn=config_arn + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + logger.debug( + f"Auto-scaling configuration {config_arn} not found, " + "skipping deletion" + ) + else: + logger.warning( + f"Failed to delete auto-scaling configuration {config_arn}: " + f"{e}" + ) + except Exception as e: + logger.warning( + f"Failed to delete auto-scaling configuration {config_arn}: " + f"{e}" + ) + + def _prepare_environment_variables( + self, + deployment: DeploymentResponse, + environment: Dict[str, str], + secrets: Dict[str, str], + settings: AWSDeployerSettings, + ) -> Tuple[Dict[str, str], Dict[str, str], Optional[str]]: + """Prepare environment variables for App Runner, handling secrets appropriately. 
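+
+        When ``use_secrets_manager`` is enabled, all sensitive values are
+        bundled into a single JSON object stored in one Secrets Manager
+        secret, and each key is referenced from the App Runner service
+        configuration using the ``<secret-arn>:<json-key>::`` reference
+        format. If the secret cannot be created or updated, the sensitive
+        values fall back to plain environment variables and a warning is
+        logged. A previously created secret whose ARN no longer matches the
+        active one is deleted.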
+ + Args: + deployment: The deployment. + environment: Regular environment variables. + secrets: Sensitive environment variables. + settings: The deployer settings. + + Returns: + Tuple containing: + - Dictionary of regular environment variables. + - Dictionary of secret environment variables (key -> secret ARN). + - Optional secret ARN (None if no secrets or fallback to env vars). + """ + secret_refs = {} + active_secret_arn: Optional[str] = None + + env_vars = {**settings.environment_variables, **environment} + + if secrets: + if settings.use_secrets_manager: + # Always store secrets as single JSON secret and reference their + # keys in the App Runner service configuration environment + # variables. + + secret_name = self._get_secret_name( + deployment.name, deployment.id, settings.secret_name_prefix + ) + + try: + secret_value = json.dumps(secrets) + secret_arn = self._create_or_update_secret( + secret_name, secret_value, deployment, settings + ) + active_secret_arn = secret_arn + + for key in secrets.keys(): + secret_refs[key] = f"{secret_arn}:{key}::" + + logger.debug( + f"Secret {secret_name} stored with ARN {secret_arn} " + f"containing {len(secrets)} secret(s)" + ) + + except Exception as e: + logger.warning( + f"Failed to create secret, falling back " + f"to direct env vars: {e}" + ) + env_vars.update(secrets) + + existing_secret_arn = self._get_secret_arn(deployment) + if ( + existing_secret_arn + and existing_secret_arn != active_secret_arn + ): + # Sometimes the previous secret resource is different from + # the new secret resource, e.g. if the secret name changed. + # In this case, we need to delete the old secret resource. + self._delete_secret(existing_secret_arn) + else: + logger.warning( + "Storing secrets directly in environment variables. " + "Consider enabling use_secrets_manager for better security." + ) + env_vars.update(secrets) + + return env_vars, secret_refs, active_secret_arn + + def _get_app_runner_service( + self, deployment: DeploymentResponse + ) -> Optional[Dict[str, Any]]: + """Get an existing App Runner service for a deployment. + + Args: + deployment: The deployment. + + Returns: + The App Runner service dictionary, or None if it doesn't exist. + + Raises: + ClientError: If the App Runner service cannot be described. + """ + existing_metadata = AppRunnerDeploymentMetadata.from_deployment( + deployment + ) + + if not existing_metadata.service_arn: + return None + + try: + response = self.app_runner_client.describe_service( + ServiceArn=existing_metadata.service_arn + ) + return response["Service"] # type: ignore[no-any-return] + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + return None + raise + + def _get_service_operational_state( + self, + service: Dict[str, Any], + region: str, + secret_arn: Optional[str] = None, + ) -> DeploymentOperationalState: + """Get the operational state of an App Runner service. + + Args: + service: The App Runner service dictionary. + region: The AWS region. + secret_arn: The active Secrets Manager secret ARN. + + Returns: + The operational state of the App Runner service. + """ + metadata = AppRunnerDeploymentMetadata.from_app_runner_service( + service, region, secret_arn + ) + + state = DeploymentOperationalState( + status=DeploymentStatus.UNKNOWN, + metadata=metadata.model_dump(exclude_none=True), + ) + + # Map App Runner service status to ZenML status. 
Valid values are: + # - CREATE_FAILED + # - DELETE_FAILED + # - RUNNING + # - DELETED + # - PAUSED + # - OPERATION_IN_PROGRESS + service_status = service.get("Status", "").upper() + + if service_status in [ + "CREATE_FAILED", + "DELETE_FAILED", + ]: + state.status = DeploymentStatus.ERROR + elif service_status == "OPERATION_IN_PROGRESS": + state.status = DeploymentStatus.PENDING + elif service_status == "RUNNING": + state.status = DeploymentStatus.RUNNING + state.url = service.get("ServiceUrl") + if state.url and not state.url.startswith("https://"): + state.url = f"https://{state.url}" + elif service_status == "DELETED": + state.status = DeploymentStatus.ABSENT + elif service_status == "PAUSED": + state.status = ( + DeploymentStatus.PENDING + ) # Treat paused as pending for now + else: + state.status = DeploymentStatus.UNKNOWN + + return state + + def _requires_service_replacement( + self, + existing_service: Dict[str, Any], + settings: AWSDeployerSettings, + ) -> bool: + """Check if the service configuration requires replacement. + + App Runner only requires service replacement for fundamental service-level + changes that cannot be handled through revisions: + + - Network access configuration + - VPC configuration + - Encryption configuration + - Observability configuration + + All other configuration changes (image, resources, environment, scaling) + can be handled as updates. + + Args: + existing_service: The existing App Runner service. + settings: The new deployer settings. + + Returns: + True if the service needs to be replaced, False if it can be updated. + """ + network_config = existing_service.get("NetworkConfiguration", {}) + ingress_config = network_config.get("IngressConfiguration", {}) + current_public_access = ingress_config.get("IsPubliclyAccessible") + if current_public_access != settings.is_publicly_accessible: + return True + + current_vpc_config = network_config.get("EgressConfiguration", {}) + has_current_vpc = bool(current_vpc_config.get("VpcConnectorArn")) + will_have_vpc = bool(settings.ingress_vpc_configuration) + if has_current_vpc != will_have_vpc: + return True + + current_encryption = existing_service.get( + "EncryptionConfiguration", {} + ) + current_kms_key = current_encryption.get("KmsKey") + if current_kms_key != settings.encryption_kms_key: + return True + + return False + + def _convert_resource_settings_to_aws_format( + self, + resource_settings: ResourceSettings, + resource_combinations: List[Tuple[float, float]], + strict_resource_matching: bool = False, + ) -> Tuple[str, str]: + """Convert ResourceSettings to AWS App Runner resource format. + + AWS App Runner only supports specific CPU-memory combinations. + This method selects the best combination that meets the requirements. + + Args: + resource_settings: The resource settings from pipeline configuration. + resource_combinations: List of supported CPU (vCPU) and memory (GB) + combinations. + strict_resource_matching: Whether to enforce strict matching of + resource requirements to AWS App Runner supported CPU and + memory combinations or approximate the closest matching + supported combination. + + Returns: + Tuple of (cpu, memory) in AWS App Runner format. 
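+
+        Example:
+            An illustrative call, assuming the default resource
+            combinations and ``strict_resource_matching=False``::
+
+                settings = ResourceSettings(cpu_count=1.5, memory="3GB")
+                cpu, memory = self._convert_resource_settings_to_aws_format(
+                    settings, self.config.resource_combinations
+                )
+                # No exact 1.5 vCPU / 3 GB pairing is supported, so the
+                # closest matching combination is selected:
+                # cpu == "2 vCPU", memory == "4 GB"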
+ """ + requested_cpu = resource_settings.cpu_count + requested_memory_gb = None + if resource_settings.memory is not None: + requested_memory_gb = resource_settings.get_memory(unit="GB") + + cpu, memory = self._select_aws_cpu_memory_combination( + requested_cpu, + requested_memory_gb, + resource_combinations, + strict_resource_matching, + ) + + return cpu, memory + + def _select_aws_cpu_memory_combination( + self, + requested_cpu: Optional[float], + requested_memory_gb: Optional[float], + resource_combinations: List[Tuple[float, float]], + strict_resource_matching: bool = False, + ) -> Tuple[str, str]: + """Select the best AWS App Runner CPU-memory combination. + + AWS App Runner only supports specific CPU and memory combinations, e.g.: + - 0.25 vCPU: 0.5 GB, 1 GB + - 0.5 vCPU: 1 GB + - 1 vCPU: 2 GB, 3 GB, 4 GB + - 2 vCPU: 4 GB, 6 GB + - 4 vCPU: 8 GB, 10 GB, 12 GB + + This method selects the best combination that meets the requirements. + + Args: + requested_cpu: Requested CPU count (can be None) + requested_memory_gb: Requested memory in GB (can be None) + resource_combinations: List of supported CPU (vCPU) and memory (GB) + combinations. + strict_resource_matching: Whether to enforce strict matching of + resource requirements to AWS App Runner supported CPU and + memory combinations or approximate the closest matching + supported combination. + + Returns: + Tuple of (cpu, memory) that best matches requirements, in AWS App + Runner format. + + Raises: + ValueError: If the requested resource requirements cannot be matched + to any of the supported combinations for the AWS App Runner + service and strict_resource_matching is True. + """ + if requested_cpu is None and requested_memory_gb is None: + return f"{DEFAULT_CPU:g} vCPU", f"{DEFAULT_MEMORY:g} GB" + + sorted_combinations = sorted(resource_combinations) + + best_combination = None + exact_match = False + best_score = float("inf") # Lower is better + + for cpu_val, mem_val in sorted_combinations: + cpu_ok = requested_cpu is None or cpu_val >= requested_cpu + mem_ok = ( + requested_memory_gb is None or mem_val >= requested_memory_gb + ) + exact_match = ( + cpu_val == requested_cpu and mem_val == requested_memory_gb + ) + if exact_match: + best_combination = (cpu_val, mem_val) + break + + if cpu_ok and mem_ok: + # Calculate "waste" score (how much over-provisioning) + cpu_waste = ( + 0 if requested_cpu is None else (cpu_val - requested_cpu) + ) + mem_waste = ( + 0 + if requested_memory_gb is None + else (mem_val - requested_memory_gb) + ) + + # Prioritize CPU requirements, then memory + score = cpu_waste * 10 + mem_waste + + if score < best_score: + best_score = score + best_combination = (cpu_val, mem_val) + + # If no combination satisfies requirements, use the highest available + if best_combination is None: + best_combination = sorted_combinations[-1] + + result = ( + f"{best_combination[0]:g} vCPU", + f"{best_combination[1]:g} GB", + ) + + if strict_resource_matching and not exact_match: + raise ValueError( + f"Requested resource requirements ({requested_cpu} vCPU, " + f"{requested_memory_gb} GB) cannot be matched to any of the " + f"supported combinations for the AWS App Runner service. " + f"The closest matching combination is {result[0]} and " + f"{result[1]}." + ) + + return result + + def _convert_scaling_settings_to_aws_format( + self, + resource_settings: ResourceSettings, + ) -> Tuple[int, int, int]: + """Convert ResourceSettings scaling to AWS App Runner format. 
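+
+        The mapping mirrors the App Runner scaling model: ``min_replicas``
+        is clamped to at least 1 (App Runner does not scale to zero),
+        ``max_replicas=0`` is treated as "no limit" and mapped to the
+        platform maximum, and ``max_concurrency`` is capped at the App
+        Runner per-instance maximum.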
+ + Args: + resource_settings: The resource settings from pipeline configuration. + + Returns: + Tuple of (min_replicas, max_replicas, max_concurrency) for AWS App + Runner. + """ + min_replicas = DEFAULT_MIN_REPLICAS + if resource_settings.min_replicas is not None: + min_replicas = max( + 1, resource_settings.min_replicas + ) # AWS App Runner min is 1 + + max_replicas = DEFAULT_MAX_REPLICAS + if resource_settings.max_replicas is not None: + # ResourceSettings uses 0 to mean "no limit" + # AWS App Runner needs a specific value, so we use the platform maximum + if resource_settings.max_replicas == 0: + max_replicas = AWS_APP_RUNNER_MAX_SIZE + else: + max_replicas = min( + resource_settings.max_replicas, AWS_APP_RUNNER_MAX_SIZE + ) + + max_concurrency = DEFAULT_MAX_CONCURRENCY + if resource_settings.max_concurrency is not None: + max_concurrency = min( + resource_settings.max_concurrency, + AWS_APP_RUNNER_MAX_CONCURRENCY, + ) + + return min_replicas, max_replicas, max_concurrency + + def do_provision_deployment( + self, + deployment: DeploymentResponse, + stack: "Stack", + environment: Dict[str, str], + secrets: Dict[str, str], + timeout: int, + ) -> DeploymentOperationalState: + """Serve a pipeline as an App Runner service. + + Args: + deployment: The deployment to serve. + stack: The stack the pipeline will be deployed on. + environment: Environment variables to set. + secrets: Secret environment variables to set. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be provisioned. + + Returns: + The operational state of the provisioned deployment. + + Raises: + DeploymentProvisionError: If the deployment fails. + DeploymentDeprovisionError: If the previous deployment fails to + deprovision. + DeployerError: If an unexpected error occurs. + """ + snapshot = deployment.snapshot + assert snapshot, "Pipeline snapshot not found" + + settings = cast( + AWSDeployerSettings, + self.get_settings(snapshot), + ) + + resource_settings = snapshot.pipeline_configuration.resource_settings + + cpu, memory = self._convert_resource_settings_to_aws_format( + resource_settings, + self.config.resource_combinations, + settings.strict_resource_matching, + ) + min_size, max_size, max_concurrency = ( + self._convert_scaling_settings_to_aws_format( + resource_settings, + ) + ) + + client = self.app_runner_client + + service_name = self._get_service_name( + deployment.name, deployment.id, settings.service_name_prefix + ) + + existing_service = self._get_app_runner_service(deployment) + image = self.get_image(snapshot) + region = self.region + + if existing_service and self._requires_service_replacement( + existing_service, settings + ): + try: + self.do_deprovision_deployment(deployment, timeout) + except DeploymentNotFoundError: + logger.warning( + f"Deployment '{deployment.name}' not found, " + f"skipping deprovision of existing App Runner service" + ) + except DeployerError as e: + raise DeploymentDeprovisionError( + f"Failed to deprovision existing App Runner service for " + f"deployment '{deployment.name}': {e}\n" + "Bailing out to avoid leaving orphaned resources." + "You might need to manually delete the existing App Runner " + "service instance to continue or forcefully delete the " + "deployment." 
+ ) + existing_service = None + + entrypoint = DeploymentEntrypointConfiguration.get_entrypoint_command() + arguments = DeploymentEntrypointConfiguration.get_entrypoint_arguments( + **{ + DEPLOYMENT_ID_OPTION: deployment.id, + PORT_OPTION: settings.port, + AUTH_KEY_OPTION: deployment.auth_key, + } + ) + + env_vars, secret_refs, active_secret_arn = ( + self._prepare_environment_variables( + deployment, environment, secrets, settings + ) + ) + + # AWS App Runner only supports ECR repositories. + if "public.ecr.aws" in image: + image_repo_type = "ECR_PUBLIC" + elif "amazonaws.com" in image: + image_repo_type = "ECR" + else: + raise DeploymentProvisionError( + f"AWS App Runner only supports Amazon ECR and ECR Public " + f"repositories. The container image '{image}' does not appear " + f"to be hosted on either platform. Supported image repositories:\n" + f"- ECR Public: public.ecr.aws/...\n" + f"- ECR Private: *.amazonaws.com/...\n" + f"Please push your image to one of these registries before " + f"deploying to App Runner." + ) + + image_config: Dict[str, Any] = { + "Port": str(settings.port), + "StartCommand": " ".join(entrypoint + arguments), + } + + if env_vars: + image_config["RuntimeEnvironmentVariables"] = env_vars + + if secret_refs: + image_config["RuntimeEnvironmentSecrets"] = secret_refs + + image_repository_config = { + "ImageIdentifier": image, + "ImageConfiguration": image_config, + "ImageRepositoryType": image_repo_type, + } + + source_configuration = { + "ImageRepository": image_repository_config, + # We don't want to automatically deploy new revisions when new + # container images are pushed to the repository. + "AutoDeploymentsEnabled": False, + } + + if settings.access_role_arn: + source_configuration["AuthenticationConfiguration"] = { + "AccessRoleArn": settings.access_role_arn + } + elif image_repo_type == "ECR": + logger.warning( + "Using private ECR repository without explicit access_role_arn. " + "Ensure the default App Runner service role has permissions to " + f"pull the '{image}' image from the repository, or specify " + "access_role_arn in deployer settings." + ) + + instance_configuration = { + "Cpu": cpu, + "Memory": memory, + } + if settings.instance_role_arn: + instance_configuration["InstanceRoleArn"] = ( + settings.instance_role_arn + ) + elif secret_refs: + logger.warning( + "Storing secrets in AWS Secrets Manager is enabled but no " + "explicit instance role is provided. Ensure the default " + "App Runner service role has secretsmanager:GetSecretValue " + "permissions, provide an explicit instance role or disable " + "'use_secrets_manager' in deployer settings." 
+ ) + + auto_scaling_config_name = self._get_auto_scaling_config_name( + deployment.name, deployment.id + ) + auto_scaling_config_arn = self._create_or_update_auto_scaling_config( + auto_scaling_config_name, + min_size, + max_size, + max_concurrency, + deployment, + settings, + ) + + health_check_configuration = { + "Protocol": settings.health_check_protocol, + "Interval": settings.health_check_interval_seconds, + "Timeout": settings.health_check_timeout_seconds, + "HealthyThreshold": settings.health_check_healthy_threshold, + "UnhealthyThreshold": settings.health_check_unhealthy_threshold, + } + + if settings.health_check_protocol.upper() == "HTTP": + health_check_configuration["Path"] = settings.health_check_path + + network_configuration = { + "IngressConfiguration": { + "IsPubliclyAccessible": settings.is_publicly_accessible, + } + } + + traffic_configurations = [] + for revision, percent in settings.traffic_allocation.items(): + if revision == "LATEST": + traffic_configurations.append( + { + "Type": "LATEST", + "Percent": percent, + } + ) + else: + if revision.startswith("tag:"): + traffic_configurations.append( + { + "Tag": revision[4:], # Remove "tag:" prefix + "Percent": percent, + } + ) + else: + traffic_configurations.append( + { + "Revision": revision, + "Percent": percent, + } + ) + + if settings.ingress_vpc_configuration: + vpc_config = json.loads(settings.ingress_vpc_configuration) + network_configuration["IngressConfiguration"][ + "VpcIngressConnectionConfiguration" + ] = vpc_config + + encryption_configuration = None + if settings.encryption_kms_key: + encryption_configuration = { + "KmsKey": settings.encryption_kms_key, + } + + observability_configuration = None + if settings.observability_configuration_arn: + observability_configuration = { + "ObservabilityEnabled": True, + "ObservabilityConfigurationArn": settings.observability_configuration_arn, + } + + service_tags = self.get_tags(deployment, settings) + + try: + if existing_service: + logger.debug( + f"Updating existing App Runner service for pipeline " + f"deployment '{deployment.name}'" + ) + + update_request = { + "ServiceArn": existing_service["ServiceArn"], + "SourceConfiguration": source_configuration, + "InstanceConfiguration": instance_configuration, + "AutoScalingConfigurationArn": auto_scaling_config_arn, + "HealthCheckConfiguration": health_check_configuration, + "NetworkConfiguration": network_configuration, + } + + if not ( + len(traffic_configurations) == 1 + and traffic_configurations[0].get("Type") == "LATEST" + and traffic_configurations[0].get("Percent") == 100 + ): + update_request["TrafficConfiguration"] = ( + traffic_configurations + ) + + if encryption_configuration: + update_request["EncryptionConfiguration"] = ( + encryption_configuration + ) + + if observability_configuration: + update_request["ObservabilityConfiguration"] = ( + observability_configuration + ) + + response = client.update_service(**update_request) + service_arn = response["Service"]["ServiceArn"] + + # Update tags separately + client.tag_resource( + ResourceArn=service_arn, + Tags=service_tags, + ) + + updated_service = response["Service"] + else: + logger.debug( + f"Creating new App Runner service for deployment " + f"'{deployment.name}' in region {region}" + ) + + create_request = { + "ServiceName": service_name, + "SourceConfiguration": source_configuration, + "InstanceConfiguration": instance_configuration, + "AutoScalingConfigurationArn": auto_scaling_config_arn, + "Tags": service_tags, + "HealthCheckConfiguration": 
health_check_configuration, + "NetworkConfiguration": network_configuration, + } + + if encryption_configuration: + create_request["EncryptionConfiguration"] = ( + encryption_configuration + ) + + if observability_configuration: + create_request["ObservabilityConfiguration"] = ( + observability_configuration + ) + + # Only add traffic configuration if it's not the default + # (100% LATEST) + if not ( + len(traffic_configurations) == 1 + and traffic_configurations[0].get("Type") == "LATEST" + and traffic_configurations[0].get("Percent") == 100 + ): + create_request["TrafficConfiguration"] = ( + traffic_configurations + ) + + response = client.create_service(**create_request) + updated_service = response["Service"] + + return self._get_service_operational_state( + updated_service, region, active_secret_arn + ) + + except (ClientError, BotoCoreError) as e: + raise DeploymentProvisionError( + f"Failed to deploy App Runner service for deployment " + f"'{deployment.name}': {e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while provisioning deployment " + f"'{deployment.name}': {e}" + ) + + def do_get_deployment_state( + self, + deployment: DeploymentResponse, + ) -> DeploymentOperationalState: + """Get information about an App Runner deployment. + + Args: + deployment: The deployment to get information about. + + Returns: + The operational state of the deployment. + + Raises: + DeploymentNotFoundError: If the deployment is not found. + RuntimeError: If the service ARN is not found in the deployment metadata. + """ + service = self._get_app_runner_service(deployment) + + if service is None: + raise DeploymentNotFoundError( + f"App Runner service for deployment '{deployment.name}' " + "not found" + ) + + existing_metadata = AppRunnerDeploymentMetadata.from_deployment( + deployment + ) + + if not existing_metadata.region: + raise RuntimeError( + f"Region not found in deployment metadata for " + f"deployment '{deployment.name}'" + ) + + existing_secret_arn = self._get_secret_arn(deployment) + + return self._get_service_operational_state( + service, + existing_metadata.region, + existing_secret_arn, + ) + + def do_get_deployment_state_logs( + self, + deployment: DeploymentResponse, + follow: bool = False, + tail: Optional[int] = None, + ) -> Generator[str, bool, None]: + """Get the logs of an App Runner deployment. + + Args: + deployment: The deployment to get the logs of. + follow: If True, stream logs as they are written. + tail: Only retrieve the last NUM lines of log output. + + Yields: + The logs of the deployment. + + Raises: + NotImplementedError: If log following is requested. + DeploymentNotFoundError: If the deployment is not found. + DeploymentLogsNotFoundError: If the logs are not found. + DeployerError: If an unexpected error occurs. + RuntimeError: If the service name is not found in the deployment metadata. 
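+
+        Note:
+            Logs are read from the CloudWatch log group that App Runner
+            creates for the service
+            (``/aws/apprunner/<service-name>/service``), sorted by
+            timestamp and optionally truncated to the last ``tail`` lines.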
+ """ + if follow: + raise NotImplementedError( + "Log following is not yet implemented for App Runner deployer" + ) + + service = self._get_app_runner_service(deployment) + if service is None: + raise DeploymentNotFoundError( + f"App Runner service for deployment '{deployment.name}' not " + "found" + ) + + try: + existing_metadata = AppRunnerDeploymentMetadata.from_deployment( + deployment + ) + service_name = existing_metadata.service_name + if not service_name: + raise RuntimeError( + f"Service name not found in deployment metadata for " + f"deployment '{deployment.name}'" + ) + + # App Runner automatically creates CloudWatch log groups + log_group_name = f"/aws/apprunner/{service_name}/service" + + try: + streams_response = self.logs_client.describe_log_streams( + logGroupName=log_group_name, + orderBy="LastEventTime", + descending=True, + ) + + log_lines = [] + for stream in streams_response.get("logStreams", []): + stream_name = stream["logStreamName"] + + events_response = self.logs_client.get_log_events( + logGroupName=log_group_name, + logStreamName=stream_name, + startFromHead=False, # Get most recent first + ) + + for event in events_response.get("events", []): + timestamp = event.get("timestamp", 0) + message = event.get("message", "") + + dt = datetime.datetime.fromtimestamp( + timestamp / 1000.0 + ) + formatted_time = dt.isoformat() + + log_line = f"[{formatted_time}] {message}" + log_lines.append(log_line) + + # Sort by timestamp (most recent last for tail to work correctly) + log_lines.sort() + + if tail is not None and tail > 0: + log_lines = log_lines[-tail:] + + for log_line in log_lines: + yield log_line + + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + raise DeploymentLogsNotFoundError( + f"Log group not found for App Runner service " + f"'{service_name}'" + ) + raise + + except (ClientError, BotoCoreError) as e: + raise DeploymentLogsNotFoundError( + f"Failed to retrieve logs for deployment '{deployment.name}': " + f"{e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while retrieving logs for deployment " + f"'{deployment.name}': {e}" + ) + + def do_deprovision_deployment( + self, + deployment: DeploymentResponse, + timeout: int, + ) -> Optional[DeploymentOperationalState]: + """Deprovision an App Runner deployment. + + Args: + deployment: The deployment to deprovision. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be deprovisioned. + + Returns: + The operational state of the deprovisioned deployment, or None if + deletion is completed immediately. + + Raises: + DeploymentNotFoundError: If the deployment is not found. + DeploymentDeprovisionError: If the deprovision fails. + DeployerError: If an unexpected error occurs. + RuntimeError: If the service ARN is not found in the deployment metadata. 
+ """ + service = self._get_app_runner_service(deployment) + if service is None: + raise DeploymentNotFoundError( + f"App Runner service for deployment '{deployment.name}' not " + "found" + ) + + try: + existing_metadata = AppRunnerDeploymentMetadata.from_deployment( + deployment + ) + if not existing_metadata.service_arn: + raise RuntimeError( + f"Service ARN not found in deployment metadata for " + f"deployment '{deployment.name}'" + ) + + logger.debug( + f"Deleting App Runner service for deployment " + f"'{deployment.name}'" + ) + + # Delete the service + self.app_runner_client.delete_service( + ServiceArn=existing_metadata.service_arn + ) + + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + raise DeploymentNotFoundError( + f"App Runner service for deployment '{deployment.name}' " + "not found" + ) + raise DeploymentDeprovisionError( + f"Failed to delete App Runner service for deployment " + f"'{deployment.name}': {e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while deleting deployment " + f"'{deployment.name}': {e}" + ) + + deployment_before_deletion = deployment + + # App Runner deletion is asynchronous and the auto-scaling configuration + # and secrets need to be cleaned up after the service is deleted. So we + # poll the service here instead of doing it in the base deployer class. + deployment, deployment_state = self._poll_deployment( + deployment, DeploymentStatus.ABSENT, timeout + ) + + if deployment_state.status != DeploymentStatus.ABSENT: + return deployment_state + + try: + self._cleanup_deployment_secrets(deployment_before_deletion) + + self._cleanup_deployment_auto_scaling_config( + deployment_before_deletion + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while cleaning up resources for pipeline " + f"deployment '{deployment.name}': {e}" + ) + + return None diff --git a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py index a2cdc428add..823c08bbcdf 100644 --- a/src/zenml/integrations/aws/flavors/__init__.py +++ b/src/zenml/integrations/aws/flavors/__init__.py @@ -17,6 +17,10 @@ AWSContainerRegistryConfig, AWSContainerRegistryFlavor, ) +from zenml.integrations.aws.flavors.aws_deployer_flavor import ( + AWSDeployerConfig, + AWSDeployerFlavor, +) from zenml.integrations.aws.flavors.aws_image_builder_flavor import ( AWSImageBuilderConfig, AWSImageBuilderFlavor, @@ -33,6 +37,8 @@ __all__ = [ "AWSContainerRegistryFlavor", "AWSContainerRegistryConfig", + "AWSDeployerConfig", + "AWSDeployerFlavor", "AWSImageBuilderConfig", "AWSImageBuilderFlavor", "SagemakerStepOperatorFlavor", diff --git a/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py b/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py new file mode 100644 index 00000000000..b21e5841b3d --- /dev/null +++ b/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py @@ -0,0 +1,344 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""AWS App Runner deployer flavor.""" + +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type + +from pydantic import Field + +from zenml.deployers.base_deployer import ( + BaseDeployerConfig, + BaseDeployerFlavor, + BaseDeployerSettings, +) +from zenml.integrations.aws import ( + AWS_CONNECTOR_TYPE, + AWS_DEPLOYER_FLAVOR, + AWS_RESOURCE_TYPE, +) +from zenml.models import ServiceConnectorRequirements + +if TYPE_CHECKING: + from zenml.integrations.aws.deployers import AWSDeployer + + +class AWSDeployerSettings(BaseDeployerSettings): + """Settings for the AWS App Runner deployer.""" + + region: Optional[str] = Field( + default=None, + description="AWS region where the App Runner service will be deployed. " + "If not specified, the region will be determined from the authenticated " + "session (service connector or implicit credentials). " + "App Runner is available in specific regions: " + "https://docs.aws.amazon.com/apprunner/latest/dg/regions.html", + ) + service_name_prefix: str = Field( + default="zenml-", + description="Prefix for service names in App Runner to avoid naming " + "conflicts.", + ) + + # Health check configuration + health_check_grace_period_seconds: int = Field( + default=20, + ge=0, + le=20, + description="Grace period for health checks in seconds. Range: 0-20.", + ) + + health_check_interval_seconds: int = Field( + default=10, + ge=1, + le=20, + description="Interval between health checks in seconds. Range: 1-20.", + ) + + health_check_path: str = Field( + default="/health", + description="Health check path for the App Runner service.", + ) + + health_check_protocol: str = Field( + default="TCP", + description="Health check protocol. Options: 'TCP', 'HTTP'.", + pattern="^TCP|HTTP$", + ) + + health_check_timeout_seconds: int = Field( + default=2, + ge=1, + le=20, + description="Timeout for health checks in seconds. Range: 1-20.", + ) + + health_check_healthy_threshold: int = Field( + default=1, + ge=1, + le=20, + description="Number of consecutive successful health checks required.", + ) + + health_check_unhealthy_threshold: int = Field( + default=5, + ge=1, + le=20, + description="Number of consecutive failed health checks before unhealthy.", + ) + + # Networking and security + is_publicly_accessible: bool = Field( + default=True, + description="Whether the App Runner service is publicly accessible.", + ) + + ingress_vpc_configuration: Optional[str] = Field( + default=None, + description="VPC configuration for private App Runner services. 
" + "JSON string with VpcId, VpcEndpointId, and VpcIngressConnectionName.", + ) + + # Environment and configuration + environment_variables: Dict[str, str] = Field( + default_factory=dict, + description="Environment variables to set in the App Runner service.", + ) + + # Tags + tags: Dict[str, str] = Field( + default_factory=dict, + description="Tags to apply to the App Runner service.", + ) + + # App Runner specific settings + port: int = Field( + default=8080, + ge=1, + le=65535, + description="Port on which the container listens for requests.", + ) + + # Secret management configuration + use_secrets_manager: bool = Field( + default=True, + description="Whether to store sensitive environment variables in AWS " + "Secrets Manager instead of directly in the App Runner service " + "configuration.", + ) + + secret_name_prefix: str = Field( + default="zenml-", + description="Prefix for secret names in Secrets Manager to avoid naming " + "conflicts.", + ) + + # Observability + observability_configuration_arn: Optional[str] = Field( + default=None, + description="ARN of the observability configuration to associate with " + "the App Runner service.", + ) + + # Encryption + encryption_kms_key: Optional[str] = Field( + default=None, + description="KMS key ARN for encrypting App Runner service data.", + ) + + # IAM Roles + instance_role_arn: Optional[str] = Field( + default=None, + description="ARN of the IAM role to assign to the App Runner service instances.", + ) + + access_role_arn: Optional[str] = Field( + default=None, + description="ARN of the IAM role that App Runner uses to access the " + "image repository (ECR). Required for private ECR repositories. If not " + "specified, App Runner will attempt to use the default service role, " + "which may not have ECR access permissions.", + ) + + # Traffic allocation for A/B testing and gradual rollouts + traffic_allocation: Dict[str, int] = Field( + default_factory=lambda: {"LATEST": 100}, + description="Traffic allocation between revisions for A/B testing and " + "gradual rollouts. Keys can be revision names, tags, or 'LATEST' for " + "the most recent revision. Values are percentages that must sum to 100. " + "Example: {'LATEST': 80, 'my-stable-revision': 20}", + ) + + # Resource matching + strict_resource_matching: bool = Field( + default=False, + description="Whether to enforce strict matching of resource requirements " + "to AWS App Runner supported CPU (vCPU) and memory (GB) combinations. " + "When True, raises an error if no exact match is found. When False, " + "automatically selects the closest matching supported combination. " + "See https://docs.aws.amazon.com/apprunner/latest/dg/architecture.html#architecture.vcpu-memory " + "for more details.", + ) + + +# AWS App Runner supported CPU (vCPU) and memory (GB) combinations +DEFAULT_RESOURCE_COMBINATIONS = [ + ( + 0.25, + 0.5, + ), + ( + 0.25, + 1.0, + ), + ( + 0.5, + 1.0, + ), + ( + 1.0, + 2.0, + ), + ( + 1.0, + 3.0, + ), + ( + 1.0, + 4.0, + ), + ( + 2.0, + 4.0, + ), + ( + 2.0, + 6.0, + ), + ( + 4.0, + 8.0, + ), + ( + 4.0, + 10.0, + ), + ( + 4.0, + 12.0, + ), +] + + +class AWSDeployerConfig( + BaseDeployerConfig, + AWSDeployerSettings, +): + """Configuration for the AWS App Runner deployer.""" + + resource_combinations: List[Tuple[float, float]] = Field( + default=DEFAULT_RESOURCE_COMBINATIONS, + description="AWS App Runner supported CPU (vCPU), memory (GB) " + "combinations.", + ) + + @property + def is_remote(self) -> bool: + """Checks if this stack component is running remotely. 
+ + This designation is used to determine if the stack component can be + used with a local ZenML database or if it requires a remote ZenML + server. + + Returns: + True if this config is for a remote component, False otherwise. + """ + return True + + +class AWSDeployerFlavor(BaseDeployerFlavor): + """AWS App Runner deployer flavor.""" + + @property + def name(self) -> str: + """Name of the deployer flavor. + + Returns: + Name of the deployer flavor. + """ + return AWS_DEPLOYER_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + connector_type=AWS_CONNECTOR_TYPE, + resource_type=AWS_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/aws-app-runner.png" + + @property + def config_class(self) -> Type[AWSDeployerConfig]: + """Returns the AWSDeployerConfig config class. + + Returns: + The config class. + """ + return AWSDeployerConfig + + @property + def implementation_class(self) -> Type["AWSDeployer"]: + """Implementation class for this flavor. + + Returns: + Implementation class for this flavor. + """ + from zenml.integrations.aws.deployers import AWSDeployer + + return AWSDeployer diff --git a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py index 782099a62f3..9acf3cac29c 100644 --- a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +++ b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py @@ -314,8 +314,8 @@ def _construct_databricks_pipeline( databricks_client = self._get_databricks_client() # Create an empty folder in a volume. 
- deployment_name = snapshot.pipeline.name - databricks_directory = f"{DATABRICKS_WHEELS_DIRECTORY_PREFIX}/{deployment_name}/{orchestrator_run_name}" + snapshot_name = snapshot.pipeline.name + databricks_directory = f"{DATABRICKS_WHEELS_DIRECTORY_PREFIX}/{snapshot_name}/{orchestrator_run_name}" databricks_wheel_path = ( f"{databricks_directory}/{wheel_path.rsplit('/', 1)[-1]}" ) diff --git a/src/zenml/integrations/gcp/__init__.py b/src/zenml/integrations/gcp/__init__.py index 526013124b4..bfd37ea82ce 100644 --- a/src/zenml/integrations/gcp/__init__.py +++ b/src/zenml/integrations/gcp/__init__.py @@ -33,6 +33,7 @@ GCP_VERTEX_EXPERIMENT_TRACKER_FLAVOR = "vertex" GCP_VERTEX_ORCHESTRATOR_FLAVOR = "vertex" GCP_VERTEX_STEP_OPERATOR_FLAVOR = "vertex" +GCP_DEPLOYER_FLAVOR = "gcp" # Service connector constants GCP_CONNECTOR_TYPE = "gcp" @@ -56,6 +57,8 @@ class GcpIntegration(Integration): "google-cloud-aiplatform>=1.34.0", # includes shapely pin fix "google-cloud-build>=3.11.0", "google-cloud-pipeline-components>=2.19.0", + "google-cloud-run>=0.10.0", + "google-cloud-logging>=3.8.0", "kubernetes", ] REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes","kfp"] @@ -74,6 +77,7 @@ def flavors(cls) -> List[Type[Flavor]]: """ from zenml.integrations.gcp.flavors import ( GCPArtifactStoreFlavor, + GCPDeployerFlavor, GCPImageBuilderFlavor, VertexExperimentTrackerFlavor, VertexOrchestratorFlavor, @@ -82,6 +86,7 @@ def flavors(cls) -> List[Type[Flavor]]: return [ GCPArtifactStoreFlavor, + GCPDeployerFlavor, GCPImageBuilderFlavor, VertexExperimentTrackerFlavor, VertexOrchestratorFlavor, diff --git a/src/zenml/integrations/gcp/deployers/__init__.py b/src/zenml/integrations/gcp/deployers/__init__.py new file mode 100644 index 00000000000..a7b7039e9a8 --- /dev/null +++ b/src/zenml/integrations/gcp/deployers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization for the GCP deployer.""" + +from zenml.integrations.gcp.deployers.gcp_deployer import ( # noqa + GCPDeployer, +) + +__all__ = ["GCPDeployer"] diff --git a/src/zenml/integrations/gcp/deployers/gcp_deployer.py b/src/zenml/integrations/gcp/deployers/gcp_deployer.py new file mode 100644 index 00000000000..ce711f26bbb --- /dev/null +++ b/src/zenml/integrations/gcp/deployers/gcp_deployer.py @@ -0,0 +1,1381 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the GCP Cloud Run deployer.""" + +import math +import re +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + cast, +) +from uuid import UUID + +from google.api_core import exceptions as google_exceptions +from google.cloud import run_v2, secretmanager +from google.cloud.logging_v2 import Client as LoggingClient +from google.protobuf.json_format import MessageToDict +from pydantic import BaseModel + +from zenml.config.base_settings import BaseSettings +from zenml.config.resource_settings import ResourceSettings +from zenml.deployers.containerized_deployer import ContainerizedDeployer +from zenml.deployers.exceptions import ( + DeployerError, + DeploymentDeprovisionError, + DeploymentLogsNotFoundError, + DeploymentNotFoundError, + DeploymentProvisionError, +) +from zenml.deployers.server.entrypoint_configuration import ( + AUTH_KEY_OPTION, + DEPLOYMENT_ID_OPTION, + PORT_OPTION, + DeploymentEntrypointConfiguration, +) +from zenml.enums import DeploymentStatus, StackComponentType +from zenml.integrations.gcp.flavors.gcp_deployer_flavor import ( + GCPDeployerConfig, + GCPDeployerSettings, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsMixin, +) +from zenml.logger import get_logger +from zenml.models import ( + DeploymentOperationalState, + DeploymentResponse, +) +from zenml.stack import StackValidator + +if TYPE_CHECKING: + from zenml.stack import Stack + +logger = get_logger(__name__) + +# Default resource and scaling configuration constants +DEFAULT_CPU = "1" +DEFAULT_MEMORY = "2Gi" +DEFAULT_MIN_INSTANCES = 1 +DEFAULT_MAX_INSTANCES = 100 +DEFAULT_CONCURRENCY = 80 + +# GCP Cloud Run built-in limits +GCP_CLOUD_RUN_MAX_INSTANCES = 1000 + + +class CloudRunDeploymentMetadata(BaseModel): + """Metadata for a Cloud Run deployment.""" + + service_name: Optional[str] = None + service_url: Optional[str] = None + project_id: Optional[str] = None + location: Optional[str] = None + revision_name: Optional[str] = None + reconciling: Optional[bool] = None + service_status: Optional[Dict[str, Any]] = None + cpu: Optional[str] = None + memory: Optional[str] = None + min_instances: Optional[int] = None + max_instances: Optional[int] = None + concurrency: Optional[int] = None + timeout_seconds: Optional[int] = None + ingress: Optional[str] = None + vpc_connector: Optional[str] = None + service_account: Optional[str] = None + execution_environment: Optional[str] = None + port: Optional[int] = None + allow_unauthenticated: Optional[bool] = None + labels: Optional[Dict[str, str]] = None + annotations: Optional[Dict[str, str]] = None + traffic_allocation: Optional[Dict[str, int]] = None + created_time: Optional[str] = None + updated_time: Optional[str] = None + secrets: List[str] = [] + + @classmethod + def from_cloud_run_service( + cls, + service: run_v2.Service, + project_id: str, + location: str, + secrets: List[secretmanager.Secret], + ) -> "CloudRunDeploymentMetadata": + """Create metadata from a Cloud Run service. + + Args: + service: The Cloud Run service object. + project_id: The GCP project ID. + location: The GCP location. + secrets: The list of existing GCP Secret Manager secrets for the + deployment. + + Returns: + The metadata for the Cloud Run service. 
+ """ + container = None + if service.template and service.template.containers: + container = service.template.containers[0] + + cpu = None + memory = None + if container and container.resources and container.resources.limits: + cpu = container.resources.limits.get("cpu") + memory = container.resources.limits.get("memory") + + min_instances = None + max_instances = None + if service.template and service.template.scaling: + scaling = service.template.scaling + min_instances = scaling.min_instance_count + max_instances = scaling.max_instance_count + + concurrency = None + if service.template: + concurrency = service.template.max_instance_request_concurrency + + timeout_seconds = None + if service.template and service.template.timeout: + timeout_seconds = service.template.timeout.seconds + + ingress = None + if service.ingress: + ingress = str(service.ingress) + + vpc_connector = None + if service.template and service.template.vpc_access: + vpc_connector = service.template.vpc_access.connector + + service_account = None + if service.template: + service_account = service.template.service_account + + execution_environment = None + if service.template and service.template.execution_environment: + execution_environment = str(service.template.execution_environment) + + port = None + if container and container.ports: + port = container.ports[0].container_port + + traffic_allocation = {} + if service.traffic: + for traffic in service.traffic: + if traffic.revision: + traffic_allocation[traffic.revision] = traffic.percent + elif traffic.tag: + traffic_allocation[traffic.tag] = traffic.percent + else: + traffic_allocation["LATEST"] = traffic.percent + + return cls( + service_name=service.name.split("/")[-1] if service.name else None, + service_url=service.uri if hasattr(service, "uri") else None, + project_id=project_id, + location=location, + revision_name=( + service.template.revision + if service.template and service.template.revision + else None + ), + reconciling=service.reconciling, + service_status=MessageToDict( + service.terminal_condition._pb, + ) + if service.terminal_condition + else None, + cpu=cpu, + memory=memory, + min_instances=min_instances, + max_instances=max_instances, + concurrency=concurrency, + timeout_seconds=timeout_seconds, + ingress=ingress, + vpc_connector=vpc_connector, + service_account=service_account, + execution_environment=execution_environment, + port=port, + allow_unauthenticated=True, + labels=dict(service.labels) if service.labels else {}, + annotations=dict(service.annotations) + if service.annotations + else {}, + traffic_allocation=traffic_allocation, + created_time=( + service.create_time.isoformat() + if service.create_time + else None + ), + updated_time=( + service.update_time.isoformat() + if service.update_time + else None + ), + secrets=[secret.name for secret in secrets], + ) + + @classmethod + def from_deployment( + cls, deployment: DeploymentResponse + ) -> "CloudRunDeploymentMetadata": + """Create metadata from a deployment. + + Args: + deployment: The deployment to get the metadata for. + + Returns: + The metadata for the deployment. 
+ """ + return cls.model_validate(deployment.deployment_metadata) + + +class GCPDeployer(ContainerizedDeployer, GoogleCredentialsMixin): + """Deployer responsible for deploying pipelines on GCP Cloud Run.""" + + CONTAINER_REQUIREMENTS: List[str] = ["uvicorn", "fastapi"] + + _credentials: Optional[Any] = None + _project_id: Optional[str] = None + _cloud_run_client: Optional[run_v2.ServicesClient] = None + _logging_client: Optional[LoggingClient] = None + _secret_manager_client: Optional[ + secretmanager.SecretManagerServiceClient + ] = None + + @property + def config(self) -> GCPDeployerConfig: + """Returns the `GCPDeployerConfig` config. + + Returns: + The configuration. + """ + return cast(GCPDeployerConfig, self._config) + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Settings class for the GCP deployer. + + Returns: + The settings class. + """ + return GCPDeployerSettings + + @property + def validator(self) -> Optional[StackValidator]: + """Ensures there is an image builder in the stack. + + Returns: + A `StackValidator` instance. + """ + return StackValidator( + required_components={ + StackComponentType.IMAGE_BUILDER, + StackComponentType.CONTAINER_REGISTRY, + } + ) + + def _get_credentials_and_project_id(self) -> Tuple[Any, str]: + """Get GCP credentials and project ID. + + Returns: + A tuple containing the credentials and project ID. + """ + if ( + self._credentials is not None + and self._project_id is not None + and not self.connector_has_expired() + ): + return self._credentials, self._project_id + + credentials, project_id = self._get_authentication() + + self._credentials = credentials + self._project_id = project_id + return credentials, project_id + + @property + def project_id(self) -> str: + """Get the GCP project ID. + + Returns: + The GCP project ID. + """ + _, project_id = self._get_credentials_and_project_id() + return project_id + + @property + def cloud_run_client(self) -> run_v2.ServicesClient: + """Get the Cloud Run client. + + Returns: + The Cloud Run client. + """ + if self._cloud_run_client is None or self.connector_has_expired(): + credentials, _ = self._get_credentials_and_project_id() + self._cloud_run_client = run_v2.ServicesClient( + credentials=credentials + ) + return self._cloud_run_client + + @property + def logging_client(self) -> LoggingClient: + """Get the Cloud Logging client. + + Returns: + The Cloud Logging client. + """ + if self._logging_client is None or self.connector_has_expired(): + credentials, project_id = self._get_credentials_and_project_id() + self._logging_client = LoggingClient( + project=project_id, credentials=credentials + ) + return self._logging_client + + @property + def secret_manager_client( + self, + ) -> secretmanager.SecretManagerServiceClient: + """Get the Secret Manager client. + + Returns: + The Secret Manager client. + """ + if self._secret_manager_client is None or self.connector_has_expired(): + credentials, _ = self._get_credentials_and_project_id() + self._secret_manager_client = ( + secretmanager.SecretManagerServiceClient( + credentials=credentials + ) + ) + return self._secret_manager_client + + def get_labels( + self, deployment: DeploymentResponse, settings: GCPDeployerSettings + ) -> Dict[str, str]: + """Get the labels for a deployment. + + Args: + deployment: The deployment. + settings: The deployer settings. + + Returns: + The labels for the deployment. 
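+
+        Example:
+            An illustrative result, with placeholder values standing in
+            for the actual deployment and deployer identifiers::
+
+                {
+                    # user-provided labels from the settings come first
+                    "zenml-deployment-id": "<deployment-uuid>",
+                    "zenml-deployment-name": "<deployment-name>",
+                    "zenml-deployer-name": "<deployer-name>",
+                    "zenml-deployer-id": "<deployer-uuid>",
+                    "managed-by": "zenml",
+                }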
+ """ + return { + **settings.labels, + "zenml-deployment-id": str(deployment.id), + "zenml-deployment-name": deployment.name, + "zenml-deployer-name": str(self.name), + "zenml-deployer-id": str(self.id), + "managed-by": "zenml", + } + + def _sanitize_name( + self, + name: str, + random_suffix: str, + max_length: int = 63, + ) -> str: + """Sanitize a name to comply with GCP naming requirements. + + Common GCP naming requirements: + - Length: 1-max_length characters + - Characters: lowercase letters (a-z), numbers (0-9), hyphens (-) + - Must start and end with a letter or number + + Args: + name: The raw name to sanitize. + random_suffix: A random suffix to add to the name to ensure + uniqueness. + max_length: The maximum length of the name. + + Returns: + A sanitized name that complies with GCP requirements. + + Raises: + RuntimeError: If the random suffix is invalid. + ValueError: If the service name is invalid. + """ + if ( + not re.match(r"^[a-z0-9]+$", random_suffix) + or len(random_suffix) < 1 + ): + raise RuntimeError( + f"Invalid random suffix: {random_suffix}. Must contain only " + "lowercase letters and numbers and be at least 1 character " + "long." + ) + + # Convert to lowercase and replace all disallowed characters with + # hyphens + sanitized = re.sub(r"[^a-z0-9-]", "-", name.lower()) + + # Remove consecutive hyphens + sanitized = re.sub(r"-+", "-", sanitized) + + # Remove leading and trailing hyphens before truncating + sanitized = re.sub( + r"^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$", + "", + sanitized, + ) + + # Truncate to fit within max_length character limit including suffix + max_base_length = ( + max_length - len(random_suffix) - 1 # -1 for the hyphen + ) + if len(sanitized) > max_base_length: + sanitized = sanitized[:max_base_length] + + # Ensure it starts and ends with alphanumeric characters + sanitized = re.sub( + r"^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$", + "", + sanitized, + ) + + # Ensure we have at least one character after cleanup + if not sanitized: + raise ValueError( + f"Invalid name: {name}. Must contain at least one " + "alphanumeric character." + ) + + return f"{sanitized}-{random_suffix}" + + def _get_service_name( + self, deployment_name: str, deployment_id: UUID, prefix: str + ) -> str: + """Get the Cloud Run service name for a deployment. + + Args: + deployment_id: The deployment ID. + deployment_name: The deployment name. + prefix: The prefix to use for the service name. + + Returns: + The Cloud Run service name that complies with all naming requirements. + """ + deployment_id_short = str(deployment_id)[:8] + raw_name = f"{prefix}{deployment_name}" + + return self._sanitize_name( + raw_name, deployment_id_short, max_length=63 + ) + + def _get_secret_name( + self, + deployment_id: UUID, + env_var_name: str, + prefix: str, + ) -> str: + """Get the Secret Manager secret name for an environment variable. + + Args: + deployment_id: The deployment ID. + env_var_name: The environment variable name. + prefix: The prefix to use for the secret name. + + Returns: + The Secret Manager secret name. + """ + deployment_id_short = str(deployment_id)[:8] + raw_name = f"{prefix}{env_var_name}" + + return self._sanitize_name( + raw_name, deployment_id_short, max_length=255 + ) + + def _create_or_update_secret( + self, + secret_name: str, + secret_value: str, + project_id: str, + deployment: DeploymentResponse, + settings: GCPDeployerSettings, + ) -> secretmanager.Secret: + """Create or update a secret in Secret Manager. + + Args: + secret_name: The name of the secret. 
+ secret_value: The value to store. + project_id: The GCP project ID. + deployment: The deployment. + settings: The deployer settings. + + Returns: + The full secret. + + Raises: + DeployerError: If secret creation/update fails. + """ + parent = f"projects/{project_id}" + secret_id = secret_name + secret_path = f"{parent}/secrets/{secret_id}" + + try: + try: + secret = self.secret_manager_client.get_secret( + name=secret_path + ) + logger.debug( + f"Secret {secret_name} already exists, adding new version" + ) + except google_exceptions.NotFound: + logger.debug(f"Creating new secret {secret_name}") + secret = secretmanager.Secret( + replication=secretmanager.Replication( + automatic=secretmanager.Replication.Automatic() + ), + labels=self.get_labels(deployment, settings), + ) + secret = self.secret_manager_client.create_secret( + parent=parent, secret_id=secret_id, secret=secret + ) + + payload = secretmanager.SecretPayload( + data=secret_value.encode("utf-8") + ) + version_response = self.secret_manager_client.add_secret_version( + parent=secret_path, payload=payload + ) + + logger.debug(f"Created secret version: {version_response.name}") + return secret + + except google_exceptions.GoogleAPICallError as e: + raise DeployerError( + f"Failed to create/update secret {secret_name}: {e}" + ) + + def _get_secrets( + self, deployment: DeploymentResponse + ) -> List[secretmanager.Secret]: + """Get the existing GCP Secret Manager secrets for a deployment. + + Args: + deployment: The deployment. + + Returns: + The list of existing GCP Secret Manager secrets for the + deployment. + """ + metadata = CloudRunDeploymentMetadata.from_deployment(deployment) + secrets: List[secretmanager.Secret] = [] + for secret_name in metadata.secrets: + try: + secret = self.secret_manager_client.get_secret( + name=secret_name + ) + secrets.append(secret) + except google_exceptions.NotFound: + continue + except google_exceptions.GoogleAPICallError: + logger.exception(f"Failed to get secret {secret_name}") + continue + return secrets + + def _delete_secret(self, secret_name: str, project_id: str) -> None: + """Delete a secret from Secret Manager. + + Args: + secret_name: The name of the secret to delete. + project_id: The GCP project ID. + """ + secret_path = f"projects/{project_id}/secrets/{secret_name}" + try: + self.secret_manager_client.delete_secret(name=secret_path) + logger.debug(f"Deleted secret {secret_path}") + except google_exceptions.NotFound: + logger.debug(f"Secret {secret_path} not found, skipping deletion") + except google_exceptions.GoogleAPICallError: + logger.exception(f"Failed to delete secret {secret_path}") + + def _cleanup_deployment_secrets( + self, + deployment: DeploymentResponse, + ) -> None: + """Clean up all secrets associated with a deployment. + + Args: + deployment: The deployment. + """ + secrets = self._get_secrets(deployment) + + for secret in secrets: + _, project_id, _, secret_name = secret.name.split("/") + self._delete_secret(secret_name, project_id) + + def _prepare_environment_variables( + self, + deployment: DeploymentResponse, + environment: Dict[str, str], + secrets: Dict[str, str], + settings: GCPDeployerSettings, + project_id: str, + ) -> Tuple[List[run_v2.EnvVar], List[secretmanager.Secret]]: + """Prepare environment variables for Cloud Run, handling secrets appropriately. + + Args: + deployment: The deployment. + environment: Regular environment variables. + secrets: Sensitive environment variables. + settings: The deployer settings. + project_id: The GCP project ID. 
+ + Returns: + Tuple containing: + - List of Cloud Run environment variable configurations. + - List of active Secret Manager secrets. + """ + env_vars = [] + + merged_env = {**settings.environment_variables, **environment} + for key, value in merged_env.items(): + env_vars.append(run_v2.EnvVar(name=key, value=value)) + + active_secrets: List[secretmanager.Secret] = [] + if secrets: + if settings.use_secret_manager: + for key, value in secrets.items(): + secret_name = self._get_secret_name( + deployment.id, key.lower(), settings.secret_name_prefix + ) + + try: + active_secret = self._create_or_update_secret( + secret_name, + value, + project_id, + deployment, + settings, + ) + + # Create environment variable that references the secret + env_var = run_v2.EnvVar( + name=key, + value_source=run_v2.EnvVarSource( + secret_key_ref=run_v2.SecretKeySelector( + secret=secret_name, version="latest" + ) + ), + ) + env_vars.append(env_var) + active_secrets.append(active_secret) + + except Exception as e: + logger.warning( + f"Failed to create secret for {key}, falling back " + f"to direct env var: {e}" + ) + env_vars.append(run_v2.EnvVar(name=key, value=value)) + + metadata = CloudRunDeploymentMetadata.from_deployment( + deployment + ) + # Delete GCP secrets that are no longer needed + active_secret_names = [ + secret.name for secret in active_secrets + ] + for existing_secret_name in metadata.secrets: + if existing_secret_name not in active_secret_names: + _, project_id, _, secret_name = ( + existing_secret_name.split("/") + ) + self._delete_secret(secret_name, project_id) + else: + logger.warning( + "Storing secrets directly in environment variables. " + "Consider enabling use_secret_manager for better security." + ) + for key, value in secrets.items(): + env_vars.append(run_v2.EnvVar(name=key, value=value)) + + return env_vars, active_secrets + + def _get_service_path( + self, + service_name: str, + project_id: str, + location: str, + ) -> str: + """Get the full Cloud Run service path. + + Args: + service_name: The name of the Cloud Run service. + project_id: The GCP project ID. + location: The GCP location. + + Returns: + The full Cloud Run service path. + """ + return f"projects/{project_id}/locations/{location}/services/{service_name}" + + def _get_cloud_run_service( + self, deployment: DeploymentResponse + ) -> Optional[run_v2.Service]: + """Get an existing Cloud Run service for a deployment. + + Args: + deployment: The deployment. + + Returns: + The Cloud Run service, or None if it doesn't exist. + """ + existing_metadata = CloudRunDeploymentMetadata.from_deployment( + deployment + ) + + if ( + not existing_metadata.service_name + or not existing_metadata.location + or not existing_metadata.project_id + ): + return None + + service_path = self._get_service_path( + existing_metadata.service_name, + existing_metadata.project_id, + existing_metadata.location, + ) + + try: + return self.cloud_run_client.get_service(name=service_path) + except google_exceptions.NotFound: + return None + + def _get_service_operational_state( + self, + service: run_v2.Service, + project_id: str, + location: str, + secrets: List[secretmanager.Secret], + ) -> DeploymentOperationalState: + """Get the operational state of a Cloud Run service. + + Args: + service: The Cloud Run service. + project_id: The GCP project ID. + location: The GCP location. + secrets: The list of active Secret Manager secrets. + + Returns: + The operational state of the Cloud Run service. 
+ """ + metadata = CloudRunDeploymentMetadata.from_cloud_run_service( + service, project_id, location, secrets + ) + + state = DeploymentOperationalState( + status=DeploymentStatus.UNKNOWN, + metadata=metadata.model_dump(exclude_none=True), + ) + + # This flag is set while the service is being reconciled + if service.reconciling: + state.status = DeploymentStatus.PENDING + else: + if ( + service.terminal_condition.state + == run_v2.Condition.State.CONDITION_SUCCEEDED + ): + state.status = DeploymentStatus.RUNNING + state.url = service.uri + elif ( + service.terminal_condition.state + == run_v2.Condition.State.CONDITION_FAILED + ): + state.status = DeploymentStatus.ERROR + elif service.terminal_condition.state in [ + run_v2.Condition.State.CONDITION_PENDING, + run_v2.Condition.State.CONDITION_RECONCILING, + ]: + state.status = DeploymentStatus.PENDING + else: + state.status = DeploymentStatus.UNKNOWN + + return state + + def _convert_resource_settings_to_gcp_format( + self, + resource_settings: ResourceSettings, + ) -> Tuple[str, str]: + """Convert ResourceSettings to GCP Cloud Run resource format. + + GCP Cloud Run CPU constraints: + - Fractional CPUs: 0.08 to < 1.0 (in increments of 0.01) + - Integer CPUs: 1, 2, 4, 6, or 8 (no fractional values allowed >= 1.0) + + Args: + resource_settings: The resource settings from pipeline configuration. + + Returns: + Tuple of (cpu, memory) in GCP Cloud Run format. + """ + cpu = DEFAULT_CPU + if resource_settings.cpu_count is not None: + cpu_count = resource_settings.cpu_count + + if cpu_count < 1.0: + # For values < 1.0, allow fractional CPUs + # Ensure minimum is 0.08 and round to 2 decimal places + cpu_count = max(0.08, round(cpu_count, 2)) + cpu = str(cpu_count) + else: + # For values >= 1.0, round up to the nearest valid integer + valid_cpu_values = [1, 2, 4, 6, 8] + rounded_cpu = math.ceil(cpu_count) + + # Find the smallest valid CPU value that satisfies the requirement + for valid_cpu in valid_cpu_values: + if valid_cpu >= rounded_cpu: + cpu = str(valid_cpu) + break + else: + # If requested CPU exceeds maximum, use maximum + cpu = str(valid_cpu_values[-1]) + + memory = DEFAULT_MEMORY + memory_value_gib = None + + if resource_settings.memory is not None: + memory_value_gib = resource_settings.get_memory(unit="GiB") + + final_memory_gib = self._validate_memory_for_cpu(cpu, memory_value_gib) + + if final_memory_gib is not None: + if final_memory_gib == int(final_memory_gib): + memory = f"{int(final_memory_gib)}Gi" + else: + memory = f"{final_memory_gib:.1f}Gi" + + return str(cpu), memory + + def _validate_memory_for_cpu( + self, cpu: str, memory_gib: Optional[float] + ) -> Optional[float]: + """Validate and adjust memory allocation based on CPU requirements. 
+ + GCP Cloud Run has minimum memory requirements per CPU configuration: + - 1 CPU: 128 MiB minimum (0.125 GiB) + - 2 CPU: 128 MiB minimum (0.125 GiB) + - 4 CPU: 2 GiB minimum + - 6 CPU: 4 GiB minimum + - 8 CPU: 4 GiB minimum + + Args: + cpu: CPU allocation as string (e.g., "1", "2", "4") + memory_gib: Memory allocation in GiB (e.g., 2.0, 0.5, None) + + Returns: + Adjusted memory allocation in GiB that meets minimum requirements, or None if no memory specified + """ + if memory_gib is None: + return None + + min_memory_per_cpu_gib = { + 1: 0.125, # 128 MiB = 0.125 GiB + 2: 0.125, # 128 MiB = 0.125 GiB + 4: 2.0, # 2 GiB + 6: 4.0, # 4 GiB + 8: 4.0, # 4 GiB + } + + # Handle fractional CPUs (< 1.0) - use minimum for 1 CPU + cpu_float = float(cpu) + if cpu_float < 1.0: + cpu_int = 1 + else: + cpu_int = int(cpu_float) + + required_memory_gib = min_memory_per_cpu_gib.get(cpu_int, 0.125) + + return max(memory_gib, required_memory_gib) + + def _convert_scaling_settings_to_gcp_format( + self, + resource_settings: ResourceSettings, + ) -> Tuple[int, int]: + """Convert ResourceSettings scaling to GCP Cloud Run format. + + Args: + resource_settings: The resource settings from pipeline configuration. + + Returns: + Tuple of (min_instances, max_instances) for GCP Cloud Run. + """ + min_instances = DEFAULT_MIN_INSTANCES + if resource_settings.min_replicas is not None: + min_instances = resource_settings.min_replicas + + max_instances = DEFAULT_MAX_INSTANCES + if resource_settings.max_replicas is not None: + # ResourceSettings uses 0 to mean "no limit" + # GCP Cloud Run needs a specific value, so we use the platform maximum + if resource_settings.max_replicas == 0: + max_instances = GCP_CLOUD_RUN_MAX_INSTANCES + else: + max_instances = resource_settings.max_replicas + + return min_instances, max_instances + + def _convert_concurrency_settings_to_gcp_format( + self, + resource_settings: ResourceSettings, + ) -> int: + """Convert ResourceSettings concurrency to GCP Cloud Run format. + + Args: + resource_settings: The resource settings from pipeline configuration. + + Returns: + The concurrency setting for GCP Cloud Run. + """ + concurrency = DEFAULT_CONCURRENCY + if resource_settings.max_concurrency is not None: + concurrency = resource_settings.max_concurrency + + return concurrency + + def do_provision_deployment( + self, + deployment: DeploymentResponse, + stack: "Stack", + environment: Dict[str, str], + secrets: Dict[str, str], + timeout: int, + ) -> DeploymentOperationalState: + """Serve a pipeline as a Cloud Run service. + + Args: + deployment: The deployment to serve. + stack: The stack the pipeline will be deployed on. + environment: Environment variables to set. + secrets: Secret environment variables to set. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be provisioned. + + Returns: + The operational state of the provisioned deployment. + + Raises: + DeploymentProvisionError: If the deployment fails. + DeployerError: If an unexpected error occurs. 
+ """ + snapshot = deployment.snapshot + assert snapshot, "Pipeline snapshot not found" + + settings = cast( + GCPDeployerSettings, + self.get_settings(snapshot), + ) + + resource_settings = snapshot.pipeline_configuration.resource_settings + + cpu, memory = self._convert_resource_settings_to_gcp_format( + resource_settings, + ) + min_instances, max_instances = ( + self._convert_scaling_settings_to_gcp_format( + resource_settings, + ) + ) + concurrency = self._convert_concurrency_settings_to_gcp_format( + resource_settings, + ) + + project_id = self.project_id + + service_name = self._get_service_name( + deployment.name, deployment.id, settings.service_name_prefix + ) + + service_path = self._get_service_path( + service_name, project_id, settings.location + ) + + # If a previous deployment of the same deployment exists but with + # a different service name, location, or project, we need to clean up + # the old service. + existing_metadata = CloudRunDeploymentMetadata.from_deployment( + deployment + ) + + if ( + existing_metadata.service_name + and existing_metadata.location + and existing_metadata.project_id + ): + existing_service_path = self._get_service_path( + existing_metadata.service_name, + existing_metadata.project_id, + existing_metadata.location, + ) + if existing_service_path != service_path: + try: + self.do_deprovision_deployment(deployment, timeout) + except DeploymentNotFoundError: + logger.warning( + f"Deployment '{deployment.name}' not found, " + f"skipping deprovision of existing Cloud Run service" + ) + except DeployerError as e: + logger.warning( + f"Failed to deprovision existing Cloud Run service for " + f"deployment '{deployment.name}': {e}" + ) + + image = self.get_image(snapshot) + + entrypoint = DeploymentEntrypointConfiguration.get_entrypoint_command() + arguments = DeploymentEntrypointConfiguration.get_entrypoint_arguments( + **{ + DEPLOYMENT_ID_OPTION: deployment.id, + PORT_OPTION: settings.port, + AUTH_KEY_OPTION: deployment.auth_key, + } + ) + + env_vars, active_secrets = self._prepare_environment_variables( + deployment, environment, secrets, settings, project_id + ) + + resources = run_v2.ResourceRequirements( + limits={ + "cpu": cpu, + "memory": memory, + } + ) + + scaling = run_v2.RevisionScaling( + min_instance_count=min_instances, + max_instance_count=max_instances, + ) + + vpc_access = None + if settings.vpc_connector: + vpc_access = run_v2.VpcAccess(connector=settings.vpc_connector) + + container = run_v2.Container( + image=image, + command=entrypoint, + args=arguments, + env=env_vars, + resources=resources, + ports=[run_v2.ContainerPort(container_port=settings.port)], + ) + + template = run_v2.RevisionTemplate( + labels=settings.labels, + annotations=settings.annotations, + scaling=scaling, + vpc_access=vpc_access, + max_instance_request_concurrency=concurrency, + timeout=f"{settings.timeout_seconds}s", + service_account=settings.service_account, + containers=[container], + execution_environment=( + run_v2.ExecutionEnvironment.EXECUTION_ENVIRONMENT_GEN2 + if settings.execution_environment == "gen2" + else run_v2.ExecutionEnvironment.EXECUTION_ENVIRONMENT_GEN1 + ), + ) + + traffic = [] + for revision, percent in settings.traffic_allocation.items(): + if revision == "LATEST": + traffic.append( + run_v2.TrafficTarget( + type_=run_v2.TrafficTargetAllocationType.TRAFFIC_TARGET_ALLOCATION_TYPE_LATEST, + percent=percent, + ) + ) + else: + traffic.append( + run_v2.TrafficTarget( + revision=revision, + percent=percent, + ) + ) + + ingress_mapping = { + "all": 
run_v2.IngressTraffic.INGRESS_TRAFFIC_ALL, + "internal": run_v2.IngressTraffic.INGRESS_TRAFFIC_INTERNAL_ONLY, + "internal-and-cloud-load-balancing": run_v2.IngressTraffic.INGRESS_TRAFFIC_INTERNAL_LOAD_BALANCER, + } + ingress = ingress_mapping.get( + settings.ingress, run_v2.IngressTraffic.INGRESS_TRAFFIC_ALL + ) + + # Create the service (name should NOT be set for CreateServiceRequest) + service = run_v2.Service( + labels=self.get_labels(deployment, settings), + annotations=settings.annotations, + template=template, + traffic=traffic, + ingress=ingress, + invoker_iam_disabled=settings.allow_unauthenticated, + ) + + try: + existing_service = None + try: + existing_service = self.cloud_run_client.get_service( + name=service_path + ) + except google_exceptions.NotFound: + pass + + if existing_service: + # Update existing service - need to set the name in the + # CreateServiceRequest for updates + service.name = service_path + logger.debug( + f"Updating existing Cloud Run service for pipeline " + f"deployment '{deployment.name}'" + ) + self.cloud_run_client.update_service(service=service) + else: + logger.debug( + f"Creating new Cloud Run service for deployment " + f"'{deployment.name}'" + ) + parent = f"projects/{project_id}/locations/{settings.location}" + # Create new service - name must not be set in the + # CreateServiceRequest, using service_id instead + self.cloud_run_client.create_service( + parent=parent, service=service, service_id=service_name + ) + # Adding the name here for the operational state retrieval + service.name = service_path + + return self._get_service_operational_state( + service, project_id, settings.location, active_secrets + ) + + except google_exceptions.GoogleAPICallError as e: + raise DeploymentProvisionError( + f"Failed to deploy Cloud Run service for deployment " + f"'{deployment.name}': {e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while provisioning deployment " + f"'{deployment.name}': {e}" + ) + + def do_get_deployment_state( + self, + deployment: DeploymentResponse, + ) -> DeploymentOperationalState: + """Get information about a Cloud Run deployment. + + Args: + deployment: The deployment to get information about. + + Returns: + The operational state of the deployment. + + Raises: + DeploymentNotFoundError: If the deployment is not found. + RuntimeError: If the project ID or location is not found in the + deployment metadata. + """ + service = self._get_cloud_run_service(deployment) + + if service is None: + raise DeploymentNotFoundError( + f"Cloud Run service for deployment '{deployment.name}' " + "not found" + ) + + existing_metadata = CloudRunDeploymentMetadata.from_deployment( + deployment + ) + + if not existing_metadata.project_id or not existing_metadata.location: + raise RuntimeError( + f"Project ID or location not found in deployment metadata for " + f"deployment '{deployment.name}'" + ) + + existing_secrets = self._get_secrets(deployment) + + return self._get_service_operational_state( + service, + existing_metadata.project_id, + existing_metadata.location, + existing_secrets, + ) + + def do_get_deployment_state_logs( + self, + deployment: DeploymentResponse, + follow: bool = False, + tail: Optional[int] = None, + ) -> Generator[str, bool, None]: + """Get the logs of a Cloud Run deployment. + + Args: + deployment: The deployment to get the logs of. + follow: If True, stream logs as they are written. + tail: Only retrieve the last NUM lines of log output. + + Yields: + The logs of the deployment. 
+ + Raises: + NotImplementedError: If log following is requested. + DeploymentLogsNotFoundError: If the logs are not found. + DeployerError: If an unexpected error occurs. + """ + if follow: + raise NotImplementedError( + "Log following is not yet implemented for Cloud Run deployer" + ) + + try: + existing_metadata = CloudRunDeploymentMetadata.from_deployment( + deployment + ) + service_name = existing_metadata.service_name + if not service_name: + assert deployment.snapshot, ( + "Pipeline snapshot not set for deployment" + ) + settings = cast( + GCPDeployerSettings, + self.get_settings(deployment.snapshot), + ) + # We rely on the running service name, if a service is currently + # active. If not, we fall back to the service name generated + # from the current configuration. + service_name = self._get_service_name( + deployment.name, + deployment.id, + settings.service_name_prefix, + ) + + filter_str = ( + 'resource.type="cloud_run_revision" AND ' + f'resource.labels.service_name="{service_name}"' + ) + + entries = self.logging_client.list_entries(filter_=filter_str) + + log_lines = [] + for entry in entries: + if hasattr(entry, "payload") and entry.payload: + timestamp = ( + entry.timestamp.isoformat() if entry.timestamp else "" + ) + log_line = f"[{timestamp}] {entry.payload}" + log_lines.append(log_line) + + if tail is not None and tail > 0: + log_lines = log_lines[-tail:] + + for log_line in log_lines: + yield log_line + + except google_exceptions.GoogleAPICallError as e: + raise DeploymentLogsNotFoundError( + f"Failed to retrieve logs for deployment " + f"'{deployment.name}': {e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while retrieving logs for deployment " + f"'{deployment.name}': {e}" + ) + + def do_deprovision_deployment( + self, + deployment: DeploymentResponse, + timeout: int, + ) -> Optional[DeploymentOperationalState]: + """Deprovision a Cloud Run deployment. + + Args: + deployment: The deployment to deprovision. + timeout: The maximum time in seconds to wait for the pipeline + deployment to be deprovisioned. + + Returns: + The operational state of the deprovisioned deployment, or None if + deletion is completed immediately. + + Raises: + DeploymentNotFoundError: If the deployment is not found. + DeploymentDeprovisionError: If the deprovision fails. + DeployerError: If an unexpected error occurs. + RuntimeError: If the service name, project ID or location is not + found in the deployment metadata. 
+ """ + service = self._get_cloud_run_service(deployment) + if service is None: + raise DeploymentNotFoundError( + f"Cloud Run service for deployment '{deployment.name}' not found" + ) + + try: + existing_metadata = CloudRunDeploymentMetadata.from_deployment( + deployment + ) + if ( + not existing_metadata.service_name + or not existing_metadata.project_id + or not existing_metadata.location + ): + raise RuntimeError( + f"Service name, project ID or location not found in " + f"deployment metadata for deployment '{deployment.name}'" + ) + + service_path = self._get_service_path( + existing_metadata.service_name, + existing_metadata.project_id, + existing_metadata.location, + ) + + logger.debug( + f"Deleting Cloud Run service for deployment '{deployment.name}'" + ) + + self.cloud_run_client.delete_service(name=service_path) + + self._cleanup_deployment_secrets(deployment) + + except google_exceptions.NotFound: + raise DeploymentNotFoundError( + f"Cloud Run service for deployment '{deployment.name}' not found" + ) + except google_exceptions.GoogleAPICallError as e: + raise DeploymentDeprovisionError( + f"Failed to delete Cloud Run service for deployment " + f"'{deployment.name}': {e}" + ) + except Exception as e: + raise DeployerError( + f"Unexpected error while deleting deployment " + f"'{deployment.name}': {e}" + ) + + return self.do_get_deployment_state(deployment) diff --git a/src/zenml/integrations/gcp/flavors/__init__.py b/src/zenml/integrations/gcp/flavors/__init__.py index e70f4937594..a920aa14242 100644 --- a/src/zenml/integrations/gcp/flavors/__init__.py +++ b/src/zenml/integrations/gcp/flavors/__init__.py @@ -17,6 +17,10 @@ GCPArtifactStoreConfig, GCPArtifactStoreFlavor, ) +from zenml.integrations.gcp.flavors.gcp_deployer_flavor import ( + GCPDeployerConfig, + GCPDeployerFlavor, +) from zenml.integrations.gcp.flavors.gcp_image_builder_flavor import ( GCPImageBuilderConfig, GCPImageBuilderFlavor, @@ -37,6 +41,8 @@ __all__ = [ "GCPArtifactStoreFlavor", "GCPArtifactStoreConfig", + "GCPDeployerFlavor", + "GCPDeployerConfig", "GCPImageBuilderFlavor", "GCPImageBuilderConfig", "VertexExperimentTrackerFlavor", diff --git a/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py b/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py new file mode 100644 index 00000000000..d8c4c660b37 --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py @@ -0,0 +1,233 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""GCP Cloud Run deployer flavor.""" + +from typing import TYPE_CHECKING, Dict, Optional, Type + +from pydantic import Field + +from zenml.deployers.base_deployer import ( + BaseDeployerConfig, + BaseDeployerFlavor, + BaseDeployerSettings, +) +from zenml.integrations.gcp import ( + GCP_DEPLOYER_FLAVOR, + GCP_RESOURCE_TYPE, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsConfigMixin, +) +from zenml.models import ServiceConnectorRequirements + +if TYPE_CHECKING: + from zenml.integrations.gcp.deployers import GCPDeployer + + +class GCPDeployerSettings(BaseDeployerSettings): + """Settings for the GCP Cloud Run deployer.""" + + location: str = Field( + default="europe-west3", + description="Name of GCP region where the pipeline will be deployed. " + "Cloud Run is available in specific regions: " + "https://cloud.google.com/run/docs/locations", + ) + service_name_prefix: str = Field( + default="zenml-", + description="Prefix for service names in Cloud Run to avoid naming " + "conflicts.", + ) + + # Timeout and execution configuration + timeout_seconds: int = Field( + default=300, + ge=1, + le=3600, + description="Request timeout in seconds. Maximum is 3600 (1 hour).", + ) + + # Networking and security + ingress: str = Field( + default="all", + description="Ingress settings for the service. " + "Options: 'all', 'internal', 'internal-and-cloud-load-balancing'.", + ) + + vpc_connector: Optional[str] = Field( + default=None, + description="VPC connector for private networking. " + "Format: projects/PROJECT_ID/locations/LOCATION/connectors/CONNECTOR_NAME", + ) + + # Service account and IAM + service_account: Optional[str] = Field( + default=None, + description="Service account email to run the Cloud Run service. " + "If not specified, uses the default Compute Engine service account.", + ) + + # Environment and configuration + environment_variables: Dict[str, str] = Field( + default_factory=dict, + description="Environment variables to set in the Cloud Run service.", + ) + + # Labels and annotations + labels: Dict[str, str] = Field( + default_factory=dict, + description="Labels to apply to the Cloud Run service.", + ) + + annotations: Dict[str, str] = Field( + default_factory=dict, + description="Annotations to apply to the Cloud Run service.", + ) + + # Cloud Run specific settings + execution_environment: str = Field( + default="gen2", + description="Execution environment generation. Options: 'gen1', 'gen2'.", + ) + + port: int = Field( + default=8080, + ge=1, + le=65535, + description="Port on which the container listens for requests.", + ) + + # Deployment configuration + traffic_allocation: Dict[str, int] = Field( + default_factory=lambda: {"LATEST": 100}, + description="Traffic allocation between revisions. 
" + "Keys are revision names or 'LATEST', values are percentages.", + ) + + allow_unauthenticated: bool = Field( + default=True, + description="Whether to allow unauthenticated requests to the service.", + ) + + # Secret management configuration + use_secret_manager: bool = Field( + default=True, + description="Whether to store sensitive environment variables in GCP " + "Secret Manager instead of directly in the Cloud Run service " + "configuration.", + ) + + secret_name_prefix: str = Field( + default="zenml-", + description="Prefix for secret names in Secret Manager to avoid naming " + "conflicts.", + ) + + +class GCPDeployerConfig( + BaseDeployerConfig, + GoogleCredentialsConfigMixin, + GCPDeployerSettings, +): + """Configuration for the GCP Cloud Run deployer.""" + + @property + def is_remote(self) -> bool: + """Checks if this stack component is running remotely. + + This designation is used to determine if the stack component can be + used with a local ZenML database or if it requires a remote ZenML + server. + + Returns: + True if this config is for a remote component, False otherwise. + """ + return True + + +class GCPDeployerFlavor(BaseDeployerFlavor): + """GCP Cloud Run deployer flavor.""" + + @property + def name(self) -> str: + """Name of the orchestrator flavor. + + Returns: + Name of the orchestrator flavor. + """ + return GCP_DEPLOYER_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + resource_type=GCP_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/google-cloud-run.svg" + + @property + def config_class(self) -> Type[GCPDeployerConfig]: + """Returns the GCPDeployerConfig config class. + + Returns: + The config class. + """ + return GCPDeployerConfig + + @property + def implementation_class(self) -> Type["GCPDeployer"]: + """Implementation class for this flavor. + + Returns: + Implementation class for this flavor. 
+ """ + from zenml.integrations.gcp.deployers import GCPDeployer + + return GCPDeployer diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 22b3b796421..f3c05d976bc 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -300,7 +300,7 @@ def main() -> None: step_command = StepEntrypointConfiguration.get_entrypoint_command() mount_local_stores = active_stack.orchestrator.config.is_local - shared_env = get_config_environment_vars() + shared_env, secrets = get_config_environment_vars() shared_env[ENV_ZENML_KUBERNETES_RUN_ID] = orchestrator_run_id try: @@ -426,6 +426,8 @@ def start_step_job(node: Node) -> NodeStatus: }, } ) + else: + step_env.update(secrets) pod_manifest = build_pod_manifest( pod_name=None, diff --git a/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py b/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py index 25bbd596e22..54f7e2c7d20 100644 --- a/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py +++ b/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py @@ -171,7 +171,7 @@ def main() -> None: logger.info("Fetching pipeline run: %s", run.id) - shared_env = get_config_environment_vars() + shared_env, secrets = get_config_environment_vars() shared_env[ENV_ZENML_SKYPILOT_ORCHESTRATOR_RUN_ID] = ( orchestrator_run_id ) @@ -211,6 +211,8 @@ def run_step_on_skypilot_vm(step_name: str) -> None: step_config=step.config, stack=active_stack ) ) + # For now, we don't support separating secrets from environment + step_env.update(secrets) # Create the Docker run command run_command = create_docker_run_command( diff --git a/src/zenml/materializers/__init__.py b/src/zenml/materializers/__init__.py index 317903e262a..3fad68fd0af 100644 --- a/src/zenml/materializers/__init__.py +++ b/src/zenml/materializers/__init__.py @@ -26,6 +26,7 @@ from zenml.materializers.cloudpickle_materializer import ( CloudpickleMaterializer, ) +from zenml.materializers.in_memory_materializer import InMemoryMaterializer from zenml.materializers.path_materializer import PathMaterializer from zenml.materializers.pydantic_materializer import PydanticMaterializer from zenml.materializers.service_materializer import ServiceMaterializer @@ -39,6 +40,7 @@ "BuiltInMaterializer", "BytesMaterializer", "CloudpickleMaterializer", + "InMemoryMaterializer", "StructuredStringMaterializer", "PydanticMaterializer", "ServiceMaterializer", diff --git a/src/zenml/materializers/in_memory_materializer.py b/src/zenml/materializers/in_memory_materializer.py new file mode 100644 index 00000000000..5d9f2ceb470 --- /dev/null +++ b/src/zenml/materializers/in_memory_materializer.py @@ -0,0 +1,104 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""InMemory materializer that stores all artifacts in serving runtime memory.""" + +from typing import ( + Any, + ClassVar, + Dict, + Optional, + Tuple, + Type, +) + +from zenml.enums import ArtifactType +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.metadata.metadata_types import MetadataType + + +class InMemoryMaterializer(BaseMaterializer): + """Simple materializer for serving mode - stores everything in memory. + + This materializer works for any data type and stores/loads data + to/from the serving runtime's in-memory storage using the original URI as the key. + No metadata extraction or visualizations in serving mode. + """ + + ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (object,) + ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA + SKIP_REGISTRATION: ClassVar[bool] = True + + def save(self, data: Any) -> None: + """Store data in serving runtime memory using original URI as key. + + Args: + data: The data to save. + """ + from zenml.deployers.server import runtime + + self.validate_save_type_compatibility(type(data)) + runtime.put_in_memory_data(self.uri, data) + + def load(self, data_type: Type[Any]) -> Any: + """Load data from serving runtime memory using original URI as key. + + Args: + data_type: The type of the data to load. + + Returns: + The loaded data. + + Raises: + RuntimeError: If no data is available in memory. + """ + from zenml.deployers.server import runtime + + self.validate_load_type_compatibility(data_type) + value = runtime.get_in_memory_data(self.uri) + if value is None: + raise RuntimeError(f"No data available for URI `{self.uri}`") + return value + + def extract_full_metadata(self, data: Any) -> Dict[str, MetadataType]: + """No metadata extraction in serving mode. + + Args: + data: The data to extract metadata from. + + Returns: + Empty metadata dictionary. + """ + return {} + + def save_visualizations(self, data: Any) -> Dict[str, Any]: + """No visualizations in serving mode. + + Args: + data: The data to save visualizations for. + + Returns: + Empty visualizations dictionary. + """ + return {} + + def compute_content_hash(self, data: Any) -> Optional[str]: + """No content hash computation in serving mode. + + Args: + data: The data to compute the content hash of. + + Returns: + None. 
+ """ + return None diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 6e742b5a489..aea28ae24a0 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -154,6 +154,16 @@ EventSourceFlavorResponseMetadata, EventSourceFlavorResponseResources, ) +from zenml.models.v2.core.deployment import ( + DeploymentRequest, + DeploymentUpdate, + DeploymentFilter, + DeploymentOperationalState, + DeploymentResponse, + DeploymentResponseBody, + DeploymentResponseMetadata, + DeploymentResponseResources, +) from zenml.models.v2.core.device import ( OAuthDeviceUpdate, OAuthDeviceFilter, @@ -464,6 +474,10 @@ ComponentResponseBody.model_rebuild() ComponentResponseMetadata.model_rebuild() ComponentResponseResources.model_rebuild() +DeploymentRequest.model_rebuild() +DeploymentResponseBody.model_rebuild() +DeploymentResponseMetadata.model_rebuild() +DeploymentResponseResources.model_rebuild() EventSourceResponseBody.model_rebuild() EventSourceResponseMetadata.model_rebuild() EventSourceResponseResources.model_rebuild() @@ -631,6 +645,14 @@ "ComponentResponseResources", "DefaultComponentRequest", "DefaultStackRequest", + "DeploymentRequest", + "DeploymentUpdate", + "DeploymentFilter", + "DeploymentOperationalState", + "DeploymentResponse", + "DeploymentResponseBody", + "DeploymentResponseMetadata", + "DeploymentResponseResources", "EventSourceFlavorResponse", "EventSourceFlavorResponseBody", "EventSourceFlavorResponseMetadata", diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 2eea88ae56b..8f15f7049a8 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -555,6 +555,7 @@ def apply_sorting( from zenml.zen_stores.schemas import ( ArtifactSchema, ArtifactVersionSchema, + DeploymentSchema, ModelSchema, ModelVersionSchema, PipelineRunSchema, @@ -574,6 +575,7 @@ def apply_sorting( PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, PipelineSnapshotSchema: TaggableResourceTypes.PIPELINE_SNAPSHOT, + DeploymentSchema: TaggableResourceTypes.DEPLOYMENT, } sorted_tags = ( diff --git a/src/zenml/models/v2/core/deployment.py b/src/zenml/models/v2/core/deployment.py new file mode 100644 index 00000000000..872f0a9b431 --- /dev/null +++ b/src/zenml/models/v2/core/deployment.py @@ -0,0 +1,411 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Models representing Pipeline Deployments.""" + +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) +from uuid import UUID + +from pydantic import BaseModel, Field + +from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.enums import DeploymentStatus +from zenml.models.v2.base.base import BaseUpdate +from zenml.models.v2.base.scoped import ( + ProjectScopedFilter, + ProjectScopedRequest, + ProjectScopedResponse, + ProjectScopedResponseBody, + ProjectScopedResponseMetadata, + ProjectScopedResponseResources, + TaggableFilter, +) +from zenml.utils.tag_utils import Tag + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + + from zenml.models.v2.core.component import ComponentResponse + from zenml.models.v2.core.pipeline import PipelineResponse + from zenml.models.v2.core.pipeline_snapshot import ( + PipelineSnapshotResponse, + ) + from zenml.models.v2.core.tag import TagResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + +class DeploymentOperationalState(BaseModel): + """Operational state of a deployment.""" + + status: DeploymentStatus = Field(default=DeploymentStatus.UNKNOWN) + url: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +# ------------------ Request Model ------------------ + + +class DeploymentRequest(ProjectScopedRequest): + """Request model for deployments.""" + + name: str = Field( + title="The name of the deployment.", + description="A unique name for the deployment within the project.", + max_length=STR_FIELD_MAX_LENGTH, + ) + snapshot_id: UUID = Field( + title="The pipeline snapshot ID.", + description="The ID of the pipeline snapshot associated with the " + "deployment.", + ) + deployer_id: UUID = Field( + title="The deployer ID.", + description="The ID of the deployer component managing this deployment.", + ) + auth_key: Optional[str] = Field( + default=None, + title="The auth key of the deployment.", + description="The auth key of the deployment.", + ) + tags: Optional[List[Union[str, Tag]]] = Field( + default=None, + title="Tags of the deployment.", + ) + + +# ------------------ Update Model ------------------ + + +class DeploymentUpdate(BaseUpdate): + """Update model for deployments.""" + + name: Optional[str] = Field( + default=None, + title="The new name of the deployment.", + max_length=STR_FIELD_MAX_LENGTH, + ) + snapshot_id: Optional[UUID] = Field( + default=None, + title="New pipeline snapshot ID.", + ) + url: Optional[str] = Field( + default=None, + title="The new URL of the deployment.", + ) + status: Optional[DeploymentStatus] = Field( + default=None, + title="The new status of the deployment.", + ) + deployment_metadata: Optional[Dict[str, Any]] = Field( + default=None, + title="The new metadata of the deployment.", + ) + auth_key: Optional[str] = Field( + default=None, + title="The new auth key of the deployment.", + ) + add_tags: Optional[List[str]] = Field( + default=None, title="New tags to add to the deployment." + ) + remove_tags: Optional[List[str]] = Field( + default=None, title="Tags to remove from the deployment." + ) + + @classmethod + def from_operational_state( + cls, operational_state: DeploymentOperationalState + ) -> "DeploymentUpdate": + """Create an update from an operational state. + + Args: + operational_state: The operational state to create an update from. + + Returns: + The update. 
+ """ + return cls( + status=operational_state.status, + url=operational_state.url, + deployment_metadata=operational_state.metadata, + ) + + +# ------------------ Response Model ------------------ + + +class DeploymentResponseBody(ProjectScopedResponseBody): + """Response body for deployments.""" + + url: Optional[str] = Field( + default=None, + title="The URL of the deployment.", + description="The HTTP URL where the deployment can be accessed.", + ) + status: Optional[DeploymentStatus] = Field( + default=None, + title="The status of the deployment.", + description="Current operational status of the deployment.", + ) + + +class DeploymentResponseMetadata(ProjectScopedResponseMetadata): + """Response metadata for deployments.""" + + deployment_metadata: Dict[str, Any] = Field( + title="The metadata of the deployment.", + ) + auth_key: Optional[str] = Field( + default=None, + title="The auth key of the deployment.", + description="The auth key of the deployment.", + ) + + +class DeploymentResponseResources(ProjectScopedResponseResources): + """Response resources for deployments.""" + + snapshot: Optional["PipelineSnapshotResponse"] = Field( + default=None, + title="The pipeline snapshot.", + description="The pipeline snapshot being deployed.", + ) + deployer: Optional["ComponentResponse"] = Field( + default=None, + title="The deployer.", + description="The deployer component managing this deployment.", + ) + pipeline: Optional["PipelineResponse"] = Field( + default=None, + title="The pipeline.", + description="The pipeline being deployed.", + ) + tags: List["TagResponse"] = Field( + title="Tags associated with the deployment.", + ) + + +class DeploymentResponse( + ProjectScopedResponse[ + DeploymentResponseBody, + DeploymentResponseMetadata, + DeploymentResponseResources, + ] +): + """Response model for deployments.""" + + name: str = Field( + title="The name of the deployment.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + def get_hydrated_version(self) -> "DeploymentResponse": + """Get the hydrated version of this deployment. + + Returns: + an instance of the same entity with the metadata and resources fields + attached. + """ + from zenml.client import Client + + client = Client() + return client.get_deployment(self.id) + + # Helper properties + @property + def url(self) -> Optional[str]: + """The URL of the deployment. + + Returns: + The URL of the deployment. + """ + return self.get_body().url + + @property + def status(self) -> Optional[DeploymentStatus]: + """The status of the deployment. + + Returns: + The status of the deployment. + """ + return self.get_body().status + + @property + def deployment_metadata(self) -> Dict[str, Any]: + """The metadata of the deployment. + + Returns: + The metadata of the deployment. + """ + return self.get_metadata().deployment_metadata + + @property + def auth_key(self) -> Optional[str]: + """The auth key of the deployment. + + Returns: + The auth key of the deployment. + """ + return self.get_metadata().auth_key + + @property + def snapshot(self) -> Optional["PipelineSnapshotResponse"]: + """The pipeline snapshot. + + Returns: + The pipeline snapshot. + """ + return self.get_resources().snapshot + + @property + def deployer(self) -> Optional["ComponentResponse"]: + """The deployer. + + Returns: + The deployer. + """ + return self.get_resources().deployer + + @property + def pipeline(self) -> Optional["PipelineResponse"]: + """The pipeline. + + Returns: + The pipeline. 
+        """
+        return self.get_resources().pipeline
+
+    @property
+    def tags(self) -> List["TagResponse"]:
+        """The tags of the deployment.
+
+        Returns:
+            The tags of the deployment.
+        """
+        return self.get_resources().tags
+
+    @property
+    def snapshot_id(self) -> Optional[UUID]:
+        """The pipeline snapshot ID.
+
+        Returns:
+            The pipeline snapshot ID.
+        """
+        snapshot = self.get_resources().snapshot
+        if snapshot:
+            return snapshot.id
+        return None
+
+    @property
+    def deployer_id(self) -> Optional[UUID]:
+        """The deployer ID.
+
+        Returns:
+            The deployer ID.
+        """
+        deployer = self.get_resources().deployer
+        if deployer:
+            return deployer.id
+        return None
+
+
+# ------------------ Filter Model ------------------
+
+
+class DeploymentFilter(ProjectScopedFilter, TaggableFilter):
+    """Model to enable advanced filtering of deployments."""
+
+    CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
+        *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS,
+        *TaggableFilter.CUSTOM_SORTING_OPTIONS,
+    ]
+    FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
+        *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS,
+        *TaggableFilter.FILTER_EXCLUDE_FIELDS,
+        "pipeline",
+    ]
+    CLI_EXCLUDE_FIELDS = [
+        *ProjectScopedFilter.CLI_EXCLUDE_FIELDS,
+        *TaggableFilter.CLI_EXCLUDE_FIELDS,
+    ]
+
+    name: Optional[str] = Field(
+        default=None,
+        description="Name of the deployment.",
+    )
+    url: Optional[str] = Field(
+        default=None,
+        description="URL of the deployment.",
+    )
+    status: Optional[str] = Field(
+        default=None,
+        description="Status of the deployment.",
+    )
+    pipeline: Optional[Union[UUID, str]] = Field(
+        default=None,
+        description="Pipeline associated with the deployment.",
+        union_mode="left_to_right",
+    )
+    snapshot_id: Optional[Union[UUID, str]] = Field(
+        default=None,
+        description="Pipeline snapshot ID associated with the deployment.",
+        union_mode="left_to_right",
+    )
+    deployer_id: Optional[Union[UUID, str]] = Field(
+        default=None,
+        description="Deployer ID managing the deployment.",
+        union_mode="left_to_right",
+    )
+
+    def get_custom_filters(
+        self, table: Type["AnySchema"]
+    ) -> List["ColumnElement[bool]"]:
+        """Get custom filters.
+
+        Args:
+            table: The query table.
+
+        Returns:
+            A list of custom filters.
+ """ + from sqlmodel import and_ + + from zenml.zen_stores.schemas import ( + DeploymentSchema, + PipelineSchema, + PipelineSnapshotSchema, + ) + + custom_filters = super().get_custom_filters(table) + + if self.pipeline: + pipeline_filter = and_( + DeploymentSchema.snapshot_id == PipelineSnapshotSchema.id, + PipelineSnapshotSchema.pipeline_id == PipelineSchema.id, + self.generate_name_or_id_query_conditions( + value=self.pipeline, table=PipelineSchema + ), + ) + custom_filters.append(pipeline_filter) + + return custom_filters diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index b0695c2b84d..d064e60b08b 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -82,6 +82,10 @@ class PipelineRunTriggerInfo(BaseZenModel): default=None, title="The ID of the step run that triggered the pipeline run.", ) + deployment_id: Optional[UUID] = Field( + default=None, + title="The ID of the deployment that triggered the pipeline run.", + ) class PipelineRunRequest(ProjectScopedRequest): @@ -636,6 +640,7 @@ class PipelineRunFilter( "pipeline_name", "templatable", "triggered_by_step_run_id", + "triggered_by_deployment_id", "linked_to_model_version_id", ] CLI_EXCLUDE_FIELDS = [ @@ -764,6 +769,11 @@ class PipelineRunFilter( description="The ID of the step run that triggered this pipeline run.", union_mode="left_to_right", ) + triggered_by_deployment_id: Optional[Union[UUID, str]] = Field( + default=None, + description="The ID of the deployment that triggered this pipeline run.", + union_mode="left_to_right", + ) model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( @@ -785,6 +795,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( CodeReferenceSchema, CodeRepositorySchema, + DeploymentSchema, ModelSchema, ModelVersionPipelineRunSchema, ModelVersionSchema, @@ -970,6 +981,19 @@ def get_custom_filters( ) custom_filters.append(trigger_filter) + if self.triggered_by_deployment_id: + trigger_filter = and_( + PipelineRunSchema.triggered_by == DeploymentSchema.id, + PipelineRunSchema.triggered_by_type + == PipelineRunTriggeredByType.DEPLOYMENT.value, + self.generate_custom_query_conditions_for_column( + value=self.triggered_by_deployment_id, + table=DeploymentSchema, + column="id", + ), + ) + custom_filters.append(trigger_filter) + if self.linked_to_model_version_id: linked_to_model_version_filter = and_( PipelineRunSchema.id diff --git a/src/zenml/models/v2/core/pipeline_snapshot.py b/src/zenml/models/v2/core/pipeline_snapshot.py index c290cda6a23..09b0e9162dd 100644 --- a/src/zenml/models/v2/core/pipeline_snapshot.py +++ b/src/zenml/models/v2/core/pipeline_snapshot.py @@ -33,7 +33,7 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.config.step_configurations import Step from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH -from zenml.enums import ExecutionStatus +from zenml.enums import ExecutionStatus, StackComponentType from zenml.models.v2.base.base import BaseUpdate, BaseZenModel from zenml.models.v2.base.scoped import ( ProjectScopedFilter, @@ -48,6 +48,7 @@ CodeReferenceRequest, CodeReferenceResponse, ) +from zenml.models.v2.core.deployment import DeploymentResponse from zenml.models.v2.core.pipeline import PipelineResponse from zenml.models.v2.core.pipeline_build import ( PipelineBuildResponse, @@ -153,7 +154,6 @@ class PipelineSnapshotRequest(PipelineSnapshotBase, ProjectScopedRequest): template: Optional[UUID] = Field( default=None, 
description="DEPRECATED: Template used for the snapshot.", - deprecated=True, ) source_snapshot: Optional[UUID] = Field( default=None, @@ -230,6 +230,9 @@ class PipelineSnapshotResponseBody(ProjectScopedResponseBody): runnable: bool = Field( title="If a run can be started from the snapshot.", ) + deployable: bool = Field( + title="If the snapshot can be deployed.", + ) class PipelineSnapshotResponseMetadata(ProjectScopedResponseMetadata): @@ -311,6 +314,10 @@ class PipelineSnapshotResponseResources(ProjectScopedResponseResources): default=None, title="The code reference associated with the snapshot.", ) + deployment: Optional[DeploymentResponse] = Field( + default=None, + title="The deployment associated with the snapshot.", + ) tags: List[TagResponse] = Field( default=[], title="Tags associated with the snapshot.", @@ -365,6 +372,15 @@ def runnable(self) -> bool: """ return self.get_body().runnable + @property + def deployable(self) -> bool: + """The `deployable` property. + + Returns: + the value of the property. + """ + return self.get_body().deployable + @property def description(self) -> Optional[str]: """The `description` property. @@ -536,6 +552,15 @@ def code_reference(self) -> Optional[CodeReferenceResponse]: """ return self.get_resources().code_reference + @property + def deployment(self) -> Optional[DeploymentResponse]: + """The `deployment` property. + + Returns: + the value of the property. + """ + return self.get_resources().deployment + @property def tags(self) -> List[TagResponse]: """The `tags` property. @@ -586,6 +611,8 @@ class PipelineSnapshotFilter(ProjectScopedFilter, TaggableFilter): "pipeline", "stack", "runnable", + "deployable", + "deployed", ] CUSTOM_SORTING_OPTIONS = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, @@ -633,6 +660,14 @@ class PipelineSnapshotFilter(ProjectScopedFilter, TaggableFilter): default=None, description="Whether the snapshot is runnable.", ) + deployable: Optional[bool] = Field( + default=None, + description="Whether the snapshot is deployable.", + ) + deployed: Optional[bool] = Field( + default=None, + description="Whether the snapshot is deployed.", + ) def get_custom_filters( self, table: Type["AnySchema"] @@ -645,12 +680,15 @@ def get_custom_filters( Returns: A list of custom filters. 
""" - from sqlmodel import and_, col + from sqlmodel import and_, col, not_, select from zenml.zen_stores.schemas import ( + DeploymentSchema, PipelineBuildSchema, PipelineSchema, PipelineSnapshotSchema, + StackComponentSchema, + StackCompositionSchema, StackSchema, ) @@ -693,6 +731,44 @@ def get_custom_filters( custom_filters.append(runnable_filter) + if self.deployable is True: + deployer_exists = ( + select(StackComponentSchema.id) + .where( + StackComponentSchema.type + == StackComponentType.DEPLOYER.value + ) + .where( + StackCompositionSchema.component_id + == StackComponentSchema.id + ) + .where( + StackCompositionSchema.stack_id + == PipelineSnapshotSchema.stack_id + ) + .exists() + ) + deployable_filter = and_( + col(PipelineSnapshotSchema.build_id).is_not(None), + deployer_exists, + ) + + custom_filters.append(deployable_filter) + + if self.deployed is not None: + deployment_exists = ( + select(DeploymentSchema.id) + .where( + DeploymentSchema.snapshot_id == PipelineSnapshotSchema.id + ) + .exists() + ) + if self.deployed is True: + deployed_filter = and_(deployment_exists) + else: + deployed_filter = and_(not_(deployment_exists)) + custom_filters.append(deployed_filter) + return custom_filters diff --git a/src/zenml/orchestrators/base_orchestrator.py b/src/zenml/orchestrators/base_orchestrator.py index 65090e94478..5e7422ff6ce 100644 --- a/src/zenml/orchestrators/base_orchestrator.py +++ b/src/zenml/orchestrators/base_orchestrator.py @@ -37,10 +37,12 @@ ) from zenml.enums import ExecutionMode, ExecutionStatus, StackComponentType from zenml.exceptions import ( + HookExecutionException, IllegalOperationError, RunMonitoringError, RunStoppedException, ) +from zenml.hooks.hook_validators import load_and_run_hook from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType from zenml.orchestrators.publish_utils import ( @@ -51,6 +53,8 @@ from zenml.orchestrators.step_launcher import StepLauncher from zenml.orchestrators.utils import get_config_environment_vars from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig +from zenml.steps.step_context import RunContext, get_or_create_run_context +from zenml.utils.env_utils import temporary_environment from zenml.utils.pydantic_utils import before_validator_handler if TYPE_CHECKING: @@ -214,7 +218,7 @@ def prepare_or_run_pipeline( """DEPRECATED: Prepare or run a pipeline. Args: - deployment: The pipeline deployment to prepare or run. + deployment: The deployment to prepare or run. stack: The stack the pipeline will run on. environment: Environment variables to set in the orchestration environment. These don't need to be set if running locally. @@ -232,8 +236,8 @@ def run( Args: snapshot: The pipeline snapshot. stack: The stack on which to run the pipeline. - placeholder_run: An optional placeholder run for the deployment. - This will be deleted in case the pipeline deployment failed. + placeholder_run: An optional placeholder run for the snapshot. + This will be deleted in case the pipeline run failed. 
        Raises:
            KeyboardInterrupt: If the orchestrator is synchronous and the
@@ -250,11 +254,15 @@ def run(
         if placeholder_run:
             pipeline_run_id = placeholder_run.id

-        base_environment = get_config_environment_vars(
+        base_environment, secrets = get_config_environment_vars(
             schedule_id=schedule_id,
             pipeline_run_id=pipeline_run_id,
         )
+
+        # TODO: for now, we don't support separate secrets from environment
+        # in the orchestrator environment
+        base_environment.update(secrets)
+
         prevent_client_side_caching = handle_bool_env_var(
             ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, default=False
         )
@@ -385,7 +393,10 @@ def run(
         finally:
             self._cleanup_run()

-    def run_step(self, step: "Step") -> None:
+    def run_step(
+        self,
+        step: "Step",
+    ) -> None:
         """Runs the given step.

         Args:
@@ -487,6 +498,96 @@ def supported_execution_modes(self) -> List[ExecutionMode]:
         """
         return [ExecutionMode.CONTINUE_ON_FAILURE]

+    @property
+    def run_init_cleanup_at_step_level(self) -> bool:
+        """Whether the orchestrator runs the init and cleanup hooks at step level.
+
+        For orchestrators that run their steps in isolated step environments,
+        the run context cannot be shared between steps. In this case, the init
+        and cleanup hooks need to be run at step level for each individual
+        step.
+
+        For orchestrators that run their steps in a shared environment with
+        shared memory (e.g. the local orchestrator), the init and cleanup hooks
+        can be run at run level and this property should be overridden to
+        return False.
+
+        Returns:
+            Whether the orchestrator runs the init and cleanup hooks at step
+            level.
+        """
+        return True
+
+    @classmethod
+    def run_init_hook(cls, snapshot: "PipelineSnapshotResponse") -> None:
+        """Runs the init hook.
+
+        Args:
+            snapshot: The snapshot to run the init hook for.
+
+        Raises:
+            HookExecutionException: If the init hook fails.
+        """
+        # The lifetime of the run context starts when the init hook is executed
+        # and ends when the cleanup hook is executed
+        run_context = get_or_create_run_context()
+        init_hook_source = snapshot.pipeline_configuration.init_hook_source
+        init_hook_kwargs = snapshot.pipeline_configuration.init_hook_kwargs
+
+        # We only run the init hook once, if the (thread-local) run context
+        # associated with the current run has not been initialized yet. This
+        # allows us to run the init hook only once per run per execution
+        # environment (process, container, etc.).
+        if not run_context.initialized:
+            if not init_hook_source:
+                run_context.initialize(None)
+                return
+
+            logger.info("Executing the pipeline's init hook...")
+            try:
+                with temporary_environment(
+                    snapshot.pipeline_configuration.environment
+                ):
+                    run_state = load_and_run_hook(
+                        init_hook_source,
+                        hook_parameters=init_hook_kwargs,
+                        raise_on_error=True,
+                    )
+            except Exception as e:
+                raise HookExecutionException(
+                    f"Failed to execute init hook for pipeline "
+                    f"{snapshot.pipeline_configuration.name}"
+                ) from e
+
+            run_context.initialize(run_state)
+
+    @classmethod
+    def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None:
+        """Runs the cleanup hook.
+
+        Args:
+            snapshot: The snapshot to run the cleanup hook for.
+        """
+        # The lifetime of the run context starts when the init hook is executed
+        # and ends when the cleanup hook is executed
+        if not RunContext._exists():
+            return
+
+        if (
+            cleanup_hook_source
+            := snapshot.pipeline_configuration.cleanup_hook_source
+        ):
+            logger.info("Executing the pipeline's cleanup hook...")
+            with temporary_environment(
+                snapshot.pipeline_configuration.environment
+            ):
+                load_and_run_hook(
+                    cleanup_hook_source,
+                    raise_on_error=False,
+                )
+
+        # Destroy the run context, so it's created anew for the next run
+        RunContext._clear()
+
     def _validate_execution_mode(
         self, snapshot: "PipelineSnapshotResponse"
     ) -> None:
diff --git a/src/zenml/orchestrators/local/local_orchestrator.py b/src/zenml/orchestrators/local/local_orchestrator.py
index f43d82a63b9..e45c4c1918d 100644
--- a/src/zenml/orchestrators/local/local_orchestrator.py
+++ b/src/zenml/orchestrators/local/local_orchestrator.py
@@ -44,6 +44,25 @@ class LocalOrchestrator(BaseOrchestrator):

     _orchestrator_run_id: Optional[str] = None

+    @property
+    def run_init_cleanup_at_step_level(self) -> bool:
+        """Whether the orchestrator runs the init and cleanup hooks at step level.
+
+        For orchestrators that run their steps in isolated step environments,
+        the run context cannot be shared between steps. In this case, the init
+        and cleanup hooks need to be run at step level for each individual
+        step.
+
+        For orchestrators that run their steps in a shared environment with
+        shared memory (e.g. the local orchestrator), the init and cleanup hooks
+        can be run at run level and this property should be overridden to
+        return False.
+
+        Returns:
+            Whether the orchestrator runs the init and cleanup hooks at step
+            level.
+        """
+        return False
+
     def submit_pipeline(
         self,
         snapshot: "PipelineSnapshotResponse",
@@ -73,7 +92,8 @@ def submit_pipeline(
             Optional submission result.

         Raises:
-            Exception: If the pipeline run fails.
+            Exception: The exception that occurred while running a failed
+                step, if the execution mode is fail-fast.
             RuntimeError: If the pipeline run fails.
         """
         if snapshot.schedule:
@@ -89,8 +109,11 @@ def submit_pipeline(
         execution_mode = snapshot.pipeline_configuration.execution_mode

         failed_steps: List[str] = []
+        step_exception: Optional[Exception] = None
         skipped_steps: List[str] = []

+        self.run_init_hook(snapshot=snapshot)
+
         # Run each step
         for step_name, step in snapshot.step_configurations.items():
             if (
@@ -142,12 +165,19 @@ def submit_pipeline(
             try:
                 with temporary_environment(step_environment):
                     self.run_step(step=step)
-            except Exception:
+            except Exception as e:
                 failed_steps.append(step_name)
                 logger.exception("Step %s failed.", step_name)
                 if execution_mode == ExecutionMode.FAIL_FAST:
-                    raise
+                    step_exception = e
+                    break
+
+        self.run_cleanup_hook(snapshot=snapshot)
+
+        if execution_mode == ExecutionMode.FAIL_FAST and failed_steps:
+            assert step_exception is not None
+            raise step_exception

         if failed_steps:
             raise RuntimeError(
diff --git a/src/zenml/orchestrators/output_utils.py b/src/zenml/orchestrators/output_utils.py
index 0d2cb400683..e0ff7f8544b 100644
--- a/src/zenml/orchestrators/output_utils.py
+++ b/src/zenml/orchestrators/output_utils.py
@@ -58,7 +58,11 @@ def generate_artifact_uri(


 def prepare_output_artifact_uris(
-    step_run: "StepRunResponse", stack: "Stack", step: "Step"
+    step_run: "StepRunResponse",
+    stack: "Stack",
+    step: "Step",
+    *,
+    create_dirs: bool = True,
 ) -> Dict[str, str]:
     """Prepares the output artifact URIs to run the current step.
@@ -66,6 +70,7 @@ def prepare_output_artifact_uris( step_run: The step run for which to prepare the artifact URIs. stack: The stack on which the pipeline is running. step: The step configuration. + create_dirs: Whether to pre-create directories in the artifact store. Raises: RuntimeError: If an artifact URI already exists. @@ -75,18 +80,43 @@ def prepare_output_artifact_uris( """ artifact_store = stack.artifact_store output_artifact_uris: Dict[str, str] = {} + for output_name in step.config.outputs.keys(): substituted_output_name = string_utils.format_name_template( output_name, substitutions=step_run.config.substitutions ) - artifact_uri = generate_artifact_uri( - artifact_store=stack.artifact_store, - step_run=step_run, - output_name=substituted_output_name, - ) - if artifact_store.exists(artifact_uri): - raise RuntimeError("Artifact already exists") - artifact_store.makedirs(artifact_uri) + if create_dirs: + artifact_uri = generate_artifact_uri( + artifact_store=artifact_store, + step_run=step_run, + output_name=substituted_output_name, + ) + else: + # Produce a clear in-memory URI that doesn't point to the store. + sanitized_output = substituted_output_name + for banned_character in [ + "<", + ">", + ":", + '"', + "/", + "\\", + "|", + "?", + "*", + ]: + sanitized_output = sanitized_output.replace( + banned_character, "_" + ) + artifact_uri = ( + f"memory://{step_run.name}/{sanitized_output}/" + f"{step_run.id}/{str(uuid4())[:8]}" + ) + + if create_dirs: + if artifact_store.exists(artifact_uri): + raise RuntimeError("Artifact already exists") + artifact_store.makedirs(artifact_uri) output_artifact_uris[output_name] = artifact_uri return output_artifact_uris diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index d284cfa5230..29d818fde65 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -26,6 +26,7 @@ ENV_ZENML_STEP_OPERATOR, handle_bool_env_var, ) +from zenml.deployers.server import runtime from zenml.enums import ExecutionMode, ExecutionStatus from zenml.environment import get_run_environment_dict from zenml.exceptions import RunInterruptedException, RunStoppedException @@ -137,9 +138,16 @@ def __init__( def _setup_signal_handlers(self) -> None: """Set up signal handlers for graceful shutdown, chaining previous handlers.""" - # Save previous handlers - self._prev_sigterm_handler = signal.getsignal(signal.SIGTERM) - self._prev_sigint_handler = signal.getsignal(signal.SIGINT) + try: + # Save previous handlers + self._prev_sigterm_handler = signal.getsignal(signal.SIGTERM) + self._prev_sigint_handler = signal.getsignal(signal.SIGINT) + except ValueError as e: + # This happens when not in the main thread + logger.debug(f"Cannot set up signal handlers: {e}") + self._prev_sigterm_handler = None + self._prev_sigint_handler = None + return def signal_handler(signum: int, frame: Any) -> None: """Handle shutdown signals gracefully. 
@@ -231,8 +239,13 @@ def signal_handler(signum: int, frame: Any) -> None: self._prev_sigint_handler(signum, frame) # Register handlers for common termination signals - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) + try: + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + except ValueError as e: + # This happens when not in the main thread + logger.debug(f"Cannot register signal handlers: {e}") + # Continue without signal handling - the step will still run def launch(self) -> None: """Launches the step. @@ -411,7 +424,6 @@ def _run_step( step_run: The model of the current step run. force_write_logs: The context for the step logs. """ - # Prepare step run information. step_run_info = StepRunInfo( config=self._step.config, pipeline=self._snapshot.pipeline_configuration, @@ -423,10 +435,12 @@ def _run_step( ) output_artifact_uris = output_utils.prepare_output_artifact_uris( - step_run=step_run, stack=self._stack, step=self._step + step_run=step_run, + stack=self._stack, + step=self._step, + create_dirs=not runtime.should_skip_artifact_materialization(), ) - # Run the step. start_time = time.time() try: if self._step.config.step_operator: @@ -482,9 +496,13 @@ def _run_step_with_step_operator( step_run_id=str(step_run_info.step_run_id), ) ) - environment = orchestrator_utils.get_config_environment_vars( + environment, secrets = orchestrator_utils.get_config_environment_vars( pipeline_run_id=step_run_info.run_id, ) + # TODO: for now, we don't support separate secrets from environment + # in the step operator environment + environment.update(secrets) + environment.update( env_utils.get_step_environment( step_config=step_run_info.config, diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index eb926fc5607..54930520118 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -23,7 +23,6 @@ Any, Dict, List, - Optional, Tuple, Type, ) @@ -38,11 +37,14 @@ ENV_ZENML_STEP_OPERATOR, handle_bool_env_var, ) +from zenml.deployers.server import runtime from zenml.enums import ArtifactSaveType from zenml.exceptions import StepInterfaceError +from zenml.hooks.hook_validators import load_and_run_hook from zenml.logger import get_logger from zenml.logging.step_logging import PipelineLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer +from zenml.materializers.in_memory_materializer import InMemoryMaterializer from zenml.models.v2.core.step_run import ( StepRunInputResponse, StepRunUpdate, @@ -55,7 +57,10 @@ from zenml.orchestrators.utils import ( is_setting_enabled, ) -from zenml.steps.step_context import StepContext, get_step_context +from zenml.steps.step_context import ( + StepContext, + get_step_context, +) from zenml.steps.utils import ( OutputSignature, parse_return_type_annotations, @@ -73,7 +78,6 @@ if TYPE_CHECKING: from zenml.artifact_stores import BaseArtifactStore - from zenml.config.source import Source from zenml.config.step_configurations import Step from zenml.models import ( ArtifactVersionResponse, @@ -90,7 +94,11 @@ class StepRunner: """Class to run steps.""" - def __init__(self, step: "Step", stack: "Stack"): + def __init__( + self, + step: "Step", + stack: "Stack", + ): """Initializes the step runner. 
Args: @@ -209,6 +217,14 @@ def run( step_failed = False try: + if ( + pipeline_run.snapshot + and self._stack.orchestrator.run_init_cleanup_at_step_level + ): + self._stack.orchestrator.run_init_hook( + snapshot=pipeline_run.snapshot + ) + with env_utils.temporary_environment(step_environment): return_values = step_instance.call_entrypoint( **function_params @@ -239,10 +255,11 @@ def run( := self.configuration.failure_hook_source ): logger.info("Detected failure hook. Running...") - self.load_and_run_hook( - failure_hook_source, - step_exception=step_exception, - ) + with env_utils.temporary_environment(step_environment): + load_and_run_hook( + failure_hook_source, + step_exception=step_exception, + ) raise finally: try: @@ -262,15 +279,23 @@ def run( := self.configuration.success_hook_source ): logger.info("Detected success hook. Running...") - self.load_and_run_hook( - success_hook_source, - step_exception=None, - ) + with env_utils.temporary_environment( + step_environment + ): + load_and_run_hook( + success_hook_source, + step_exception=None, + ) # Store and publish the output artifacts of the step function. output_data = self._validate_outputs( return_values, output_annotations ) + # Record outputs in serving context for fast access + if runtime.is_active(): + runtime.record_step_outputs( + step_run.name, output_data + ) artifact_metadata_enabled = is_setting_enabled( is_enabled_on_step=step_run_info.config.enable_artifact_metadata, is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata, @@ -300,6 +325,17 @@ def run( }, model_version=model_version, ) + + # We run the cleanup hook at step level if we're not in an + # environment that supports a shared run context + if ( + pipeline_run.snapshot + and self._stack.orchestrator.run_init_cleanup_at_step_level + ): + self._stack.orchestrator.run_cleanup_hook( + snapshot=pipeline_run.snapshot + ) + finally: step_context._cleanup_registry.execute_callbacks( raise_on_exception=False @@ -422,45 +458,6 @@ def _parse_inputs( return function_params - def _parse_hook_inputs( - self, - args: List[str], - annotations: Dict[str, Any], - step_exception: Optional[BaseException], - ) -> Dict[str, Any]: - """Parses the inputs for a hook function. - - Args: - args: The arguments of the hook function. - annotations: The annotations of the hook function. - step_exception: The exception of the original step. - - Returns: - The parsed inputs for the hook function. - - Raises: - TypeError: If hook function is passed a wrong parameter type. 
- """ - function_params: Dict[str, Any] = {} - - if args and args[0] == "self": - args.pop(0) - - for arg in args: - arg_type = annotations.get(arg, None) - arg_type = resolve_type_annotation(arg_type) - - if issubclass(arg_type, BaseException): - function_params[arg] = step_exception - else: - # It should not be of any other type - raise TypeError( - "Hook functions can only take arguments of type " - f"`BaseException`, not {arg_type}" - ) - - return function_params - def _load_input_artifact( self, artifact: "ArtifactVersionResponse", data_type: Type[Any] ) -> Any: @@ -652,6 +649,9 @@ def _store_output_artifacts( materializer_class = materializer_registry[data_type] + # Choose materializer class upfront based on serving mode + if runtime.should_skip_artifact_materialization(): + materializer_class = InMemoryMaterializer uri = output_artifact_uris[output_name] artifact_config = output_annotations[output_name].artifact_config @@ -704,31 +704,3 @@ def _store_output_artifacts( artifact_requests ) return dict(zip(output_data.keys(), responses)) - - def load_and_run_hook( - self, - hook_source: "Source", - step_exception: Optional[BaseException], - ) -> None: - """Loads hook source and runs the hook. - - Args: - hook_source: The source of the hook function. - step_exception: The exception of the original step. - """ - try: - hook = source_utils.load(hook_source) - hook_spec = inspect.getfullargspec(inspect.unwrap(hook)) - - function_params = self._parse_hook_inputs( - args=hook_spec.args, - annotations=hook_spec.annotations, - step_exception=step_exception, - ) - logger.debug(f"Running hook {hook} with params: {function_params}") - hook(**function_params) - except Exception as e: - logger.error( - f"Failed to load hook source with exception: '{hook_source}': " - f"{e}" - ) diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index 916b52f1b4c..3e884d7b428 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -15,13 +15,14 @@ import os import random -from typing import TYPE_CHECKING, Any, Dict, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast from uuid import UUID from zenml.client import Client from zenml.config.global_config import ( GlobalConfiguration, ) +from zenml.config.step_configurations import Step from zenml.constants import ( ENV_ZENML_ACTIVE_PROJECT_ID, ENV_ZENML_ACTIVE_STACK_ID, @@ -33,7 +34,13 @@ ) from zenml.enums import APITokenType, AuthScheme, StackComponentType, StoreType from zenml.logger import get_logger +from zenml.models import ( + CodeReferenceRequest, + PipelineSnapshotRequest, + PipelineSnapshotResponse, +) from zenml.stack import StackComponent +from zenml.utils import pydantic_utils logger = get_logger(__name__) @@ -105,7 +112,8 @@ def is_setting_enabled( def get_config_environment_vars( schedule_id: Optional[UUID] = None, pipeline_run_id: Optional[UUID] = None, -) -> Dict[str, str]: + deployment_id: Optional[UUID] = None, +) -> Tuple[Dict[str, str], Dict[str, str]]: """Gets environment variables to set for mirroring the active config. If a schedule ID, pipeline run ID or step run ID is given, and the current @@ -118,27 +126,30 @@ def get_config_environment_vars( schedule_id: Optional schedule ID to use to generate a new API token. pipeline_run_id: Optional pipeline run ID to use to generate a new API token. + deployment_id: Optional deployment ID to use to generate a new API + token. Returns: - Environment variable dict. + Environment variable dict and secrets dict. 
""" from zenml.login.credentials_store import get_credentials_store from zenml.zen_stores.rest_zen_store import RestZenStore global_config = GlobalConfiguration() environment_vars = global_config.get_config_environment_vars() + secrets: Dict[str, str] = {} if ( global_config.store_configuration.type == StoreType.REST and global_config.zen_store.get_store_info().auth_scheme != AuthScheme.NO_AUTH ): + assert isinstance(global_config.zen_store, RestZenStore) + credentials_store = get_credentials_store() url = global_config.store_configuration.url api_token = credentials_store.get_token(url, allow_expired=False) if schedule_id or pipeline_run_id: - assert isinstance(global_config.zen_store, RestZenStore) - # The user has the option to manually set an expiration for the API # token generated for a pipeline run. In this case, we generate a new # generic API token that will be valid for the indicated duration. @@ -194,13 +205,17 @@ def get_config_environment_vars( pipeline_run_id=pipeline_run_id, ) - environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( - new_api_token + secrets[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = new_api_token + elif deployment_id: + new_api_token = global_config.zen_store.get_api_token( + token_type=APITokenType.WORKLOAD, + deployment_id=deployment_id, ) + secrets[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = new_api_token elif api_token: # For all other cases, the pipeline run environment is configured # with the current access token. - environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( + secrets[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( api_token.access_token ) @@ -217,7 +232,7 @@ def get_config_environment_vars( Client().active_project.id ) - return environment_vars + return environment_vars, secrets class register_artifact_store_filesystem: @@ -316,3 +331,147 @@ def __exit__( # that belongs to the active artifact store as it may have been # overwritten. Client().active_stack.artifact_store._register() + + +def deployment_snapshot_request_from_source_snapshot( + source_snapshot: PipelineSnapshotResponse, + deployment_parameters: Dict[str, Any], + template_id: Optional[UUID] = None, +) -> PipelineSnapshotRequest: + """Generate a snapshot request for deployment execution. + + This is a simplified version optimized for deployment execution that: + - Applies hardcoded deployment-specific overrides (disabled caching, no step operators) + - Only handles parameter overrides (no secrets handling) + - Removes zen_store dependency for secrets + + Args: + source_snapshot: The source snapshot from which to create the + snapshot request. + deployment_parameters: Parameters to override for deployment execution. + template_id: The ID of the template from which to create the snapshot + request. + + Raises: + ValueError: If there are missing/extra step parameters. + + Returns: + The generated snapshot request. 
+ """ + # Deployment-specific pipeline configuration (disabled caching) + pipeline_configuration = pydantic_utils.update_model( + source_snapshot.pipeline_configuration, {"enable_cache": False} + ) + + steps = {} + for invocation_id, step in source_snapshot.step_configurations.items(): + # Apply deployment parameters to steps that need them + step_parameters = {} + for param_name in step.config.parameters: + if param_name in deployment_parameters: + step_parameters[param_name] = deployment_parameters[param_name] + elif param_name in step.config.parameters: + step_parameters[param_name] = step.config.parameters[ + param_name + ] + else: + raise ValueError( + f"Missing required parameter '{param_name}' for step " + f"'{invocation_id}' in deployment execution" + ) + + # Deployment-specific step overrides + step_update = { + "enable_cache": False, # Disable caching for all steps + "step_operator": None, # Remove step operators for deployments + "retry": None, # Remove retry configuration + "parameters": step_parameters, + } + + step_config = pydantic_utils.update_model( + step.step_config_overrides, step_update + ) + merged_step_config = step_config.apply_pipeline_configuration( + pipeline_configuration + ) + + # Validate parameters + required_parameters = set(step.config.parameters) + configured_parameters = set(step_parameters) + + unknown_parameters = configured_parameters - required_parameters + if unknown_parameters: + raise ValueError( + "Deployment parameters contain the following unknown " + f"parameters for step {invocation_id}: {unknown_parameters}." + ) + + missing_parameters = required_parameters - configured_parameters + if missing_parameters: + raise ValueError( + "Deployment execution is missing the following required " + f"parameters for step {invocation_id}: {missing_parameters}." 
+ ) + + steps[invocation_id] = Step( + spec=step.spec, + config=merged_step_config, + step_config_overrides=step_config, + ) + + code_reference_request = None + if source_snapshot.code_reference: + code_reference_request = CodeReferenceRequest( + commit=source_snapshot.code_reference.commit, + subdirectory=source_snapshot.code_reference.subdirectory, + code_repository=source_snapshot.code_reference.code_repository.id, + ) + + zenml_version = Client().zen_store.get_store_info().version + + # Compute the source snapshot ID + source_snapshot_id = source_snapshot.source_snapshot_id + + if source_snapshot.stack is None: + raise ValueError("Source snapshot stack is None") + + # Update the pipeline spec parameters by overriding only known keys + updated_pipeline_spec = source_snapshot.pipeline_spec + try: + if ( + source_snapshot.pipeline_spec + and source_snapshot.pipeline_spec.parameters is not None + ): + original_params: Dict[str, Any] = dict( + source_snapshot.pipeline_spec.parameters + ) + merged_params: Dict[str, Any] = original_params.copy() + for k, v in deployment_parameters.items(): + if k in original_params: + merged_params[k] = v + updated_pipeline_spec = pydantic_utils.update_model( + source_snapshot.pipeline_spec, {"parameters": merged_params} + ) + except Exception: + # In case of any unforeseen errors, fall back to the original spec + updated_pipeline_spec = source_snapshot.pipeline_spec + + return PipelineSnapshotRequest( + project=source_snapshot.project_id, + run_name_template=source_snapshot.run_name_template, + pipeline_configuration=pipeline_configuration, + step_configurations=steps, + client_environment={}, + client_version=zenml_version, + server_version=zenml_version, + stack=source_snapshot.stack.id, + pipeline=source_snapshot.pipeline.id, + schedule=None, + code_reference=code_reference_request, + code_path=source_snapshot.code_path, + build=source_snapshot.build.id if source_snapshot.build else None, + template=template_id, + source_snapshot=source_snapshot_id, + pipeline_version_hash=source_snapshot.pipeline_version_hash, + pipeline_spec=updated_pipeline_spec, + ) diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index f7dded64adc..6db807bf0db 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -35,7 +35,7 @@ from zenml.config.retry_config import StepRetryConfig from zenml.model.model import Model from zenml.pipelines.pipeline_definition import Pipeline - from zenml.types import HookSpecification + from zenml.types import HookSpecification, InitHookSpecification from zenml.utils.tag_utils import Tag F = TypeVar("F", bound=Callable[..., None]) @@ -62,6 +62,9 @@ def pipeline( extra: Optional[Dict[str, Any]] = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, + on_init: Optional["InitHookSpecification"] = None, + on_init_kwargs: Optional[Dict[str, Any]] = None, + on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, @@ -85,6 +88,9 @@ def pipeline( extra: Optional[Dict[str, Any]] = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, + on_init: Optional["InitHookSpecification"] = None, + on_init_kwargs: Optional[Dict[str, Any]] = None, + on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: 
Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, @@ -113,6 +119,14 @@ def pipeline( on_success: Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). + on_init: Callback function to run on initialization of the pipeline. Can + be a function with no arguments, or a source path to such a function + (e.g. `module.my_function`) if the function returns a value, it will + be stored as the pipeline state. + on_init_kwargs: Arguments for the init hook. + on_cleanup: Callback function to run on cleanup of the pipeline. Can be a + function with no arguments, or a source path to such a function + (e.g. `module.my_function`). model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. @@ -140,6 +154,9 @@ def inner_decorator(func: "F") -> "Pipeline": extra=extra, on_failure=on_failure, on_success=on_success, + on_init=on_init, + on_init_kwargs=on_init_kwargs, + on_cleanup=on_cleanup, model=model, retry=retry, substitutions=substitutions, diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index 1ab52fd8058..f388dc5f581 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -64,6 +64,7 @@ ) from zenml.models import ( CodeReferenceRequest, + DeploymentResponse, LogsRequest, PipelineBuildBase, PipelineBuildResponse, @@ -86,6 +87,7 @@ from zenml.steps import BaseStep from zenml.steps.entrypoint_function_utils import StepArtifact from zenml.steps.step_invocation import StepInvocation +from zenml.steps.utils import get_unique_step_output_names from zenml.utils import ( code_repository_utils, code_utils, @@ -111,7 +113,7 @@ from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model from zenml.models import ArtifactVersionResponse - from zenml.types import HookSpecification + from zenml.types import HookSpecification, InitHookSpecification StepConfigurationUpdateOrDict = Union[ Dict[str, Any], StepConfigurationUpdate @@ -146,6 +148,9 @@ def __init__( extra: Optional[Dict[str, Any]] = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, + on_init: Optional["InitHookSpecification"] = None, + on_init_kwargs: Optional[Dict[str, Any]] = None, + on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, @@ -177,6 +182,14 @@ def __init__( on_success: Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). + on_init: Callback function to run on initialization of the pipeline. + Can be a function with no arguments, or a source path to such a + function (e.g. `module.my_function`) if the function returns a + value, it will be stored as the pipeline state. + on_init_kwargs: Arguments for the init hook. + on_cleanup: Callback function to run on cleanup of the pipeline. Can + be a function with no arguments, or a source path to such a + function with no arguments (e.g. `module.my_function`). model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. 
@@ -204,6 +217,9 @@ def __init__( extra=extra, on_failure=on_failure, on_success=on_success, + on_init=on_init, + on_init_kwargs=on_init_kwargs, + on_cleanup=on_cleanup, model=model, retry=retry, substitutions=substitutions, @@ -330,6 +346,9 @@ def configure( extra: Optional[Dict[str, Any]] = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, + on_init: Optional["InitHookSpecification"] = None, + on_init_kwargs: Optional[Dict[str, Any]] = None, + on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, parameters: Optional[Dict[str, Any]] = None, @@ -372,6 +391,14 @@ def configure( on_success: Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). + on_init: Callback function to run on initialization of the pipeline. + Can be a function with no arguments, or a source path to such a + function (e.g. `module.my_function`) if the function returns a + value, it will be stored as the pipeline state. + on_init_kwargs: Arguments for the init hook. + on_cleanup: Callback function to run on cleanup of the pipeline. Can + be a function with no arguments, or a source path to such a + function with no arguments (e.g. `module.my_function`). model: configuration of the model version in the Model Control Plane. retry: Retry configuration for the pipeline steps. parameters: input parameters for the pipeline. @@ -386,16 +413,49 @@ def configure( Returns: The pipeline instance that this method was called on. + + Raises: + ValueError: If on_init_kwargs is provided but on_init is not and + the init hook source is found in the current pipeline + configuration. 
""" failure_hook_source = None if on_failure: # string of on_failure hook function to be used for this pipeline - failure_hook_source = resolve_and_validate_hook(on_failure) + failure_hook_source, _ = resolve_and_validate_hook( + on_failure, allow_exception_arg=True + ) success_hook_source = None if on_success: # string of on_success hook function to be used for this pipeline - success_hook_source = resolve_and_validate_hook(on_success) + success_hook_source, _ = resolve_and_validate_hook(on_success) + + init_hook_kwargs = None + init_hook_source = None + if on_init or on_init_kwargs: + if not on_init and self.configuration.init_hook_source: + # load the init hook source from the existing configuration if + # not provided; this is needed for partial updates + on_init = source_utils.load( + self.configuration.init_hook_source + ) + if not on_init: + raise ValueError( + "on_init is not provided and no init hook source is found " + "in the existing configuration" + ) + + # string of on_init hook function and JSON-able arguments to be used + # for this pipeline + init_hook_source, init_hook_kwargs = resolve_and_validate_hook( + on_init, on_init_kwargs + ) + + cleanup_hook_source = None + if on_cleanup: + # string of on_cleanup hook function to be used for this pipeline + cleanup_hook_source, _ = resolve_and_validate_hook(on_cleanup) if merge and tags and self._configuration.tags: # Merge tags explicitly here as the recursive update later only @@ -419,6 +479,9 @@ def configure( "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, + "init_hook_source": init_hook_source, + "init_hook_kwargs": init_hook_kwargs, + "cleanup_hook_source": cleanup_hook_source, "model": model, "retry": retry, "parameters": parameters, @@ -632,6 +695,37 @@ def build( code_repository=code_repository, ) + def deploy( + self, + deployment_name: str, + timeout: Optional[int] = None, + *args: Any, + **kwargs: Any, + ) -> DeploymentResponse: + """Deploy the pipeline for online inference. + + Args: + deployment_name: The name to use for the deployment. + timeout: The maximum time in seconds to wait for the pipeline to be + deployed. + *args: Pipeline entrypoint input arguments. + **kwargs: Pipeline entrypoint input keyword arguments. + + Returns: + The deployment response. + """ + self.prepare(*args, **kwargs) + snapshot = self._create_snapshot(**self._run_args) + + stack = Client().active_stack + + stack.prepare_pipeline_submission(snapshot=snapshot) + return stack.deploy_pipeline( + snapshot=snapshot, + deployment_name=deployment_name, + timeout=timeout, + ) + def _create_snapshot( self, *, @@ -1656,30 +1750,106 @@ def _reconfigure_from_file_with_overrides( with self.__suppress_configure_warnings__(): self.configure(**_from_config_file) - def _compute_output_schema(self) -> Dict[str, Any]: + def _compute_output_schema(self) -> Optional[Dict[str, Any]]: """Computes the output schema for the pipeline. Returns: The output schema for the pipeline. 
""" - - def _get_schema_output_name(output_artifact: "StepArtifact") -> str: - return ( - output_artifact.invocation_id.replace("-", "_") - + "-" - + output_artifact.output_name.replace("-", "_") + try: + # Generate unique step output names + unique_step_output_mapping = get_unique_step_output_names( + { + (o.invocation_id, o.output_name): o + for o in self._output_artifacts + } ) - fields: Dict[str, Any] = { - _get_schema_output_name(output_artifact): ( - output_artifact.annotation.resolved_annotation, - ..., + fields: Dict[str, Any] = { + entry[1]: ( + entry[0].annotation.resolved_annotation, + ..., + ) + for _, entry in unique_step_output_mapping.items() + } + output_model_class: Type[BaseModel] = create_model( + "PipelineOutput", + __config__=ConfigDict(arbitrary_types_allowed=True), + **fields, ) - for output_artifact in self._output_artifacts - } - output_model_class: Type[BaseModel] = create_model( - "PipelineOutput", - __config__=ConfigDict(arbitrary_types_allowed=True), - **fields, + return output_model_class.model_json_schema(mode="serialization") + except Exception as e: + logger.debug( + f"Failed to generate the output schema for " + f"pipeline `{self.name}: {e}. This is most likely " + "because some of the pipeline outputs are not JSON " + "serializable. This means that the pipeline cannot be " + "deployed.", + ) + + return None + + def _compute_input_model(self) -> Optional[Type[BaseModel]]: + """Create a Pydantic model that represents the pipeline input parameters. + + Returns: + A Pydantic model that represents the pipeline input + parameters. + """ + from zenml.steps.entrypoint_function_utils import ( + validate_entrypoint_function, ) - return output_model_class.model_json_schema(mode="serialization") + + try: + entrypoint_definition = validate_entrypoint_function( + self.entrypoint + ) + + defaults: Dict[str, Any] = self._parameters + model_args: Dict[str, Any] = {} + for name, param in entrypoint_definition.inputs.items(): + if name in defaults: + default_value = defaults[name] + elif param.default is not inspect.Parameter.empty: + default_value = param.default + else: + default_value = ... + + model_args[name] = (param.annotation, default_value) + + model_args["__config__"] = ConfigDict(extra="forbid") + params_model: Type[BaseModel] = create_model( + "PipelineInput", + **model_args, + ) + return params_model + except Exception as e: + logger.debug( + f"Failed to generate the input parameters model for pipeline " + f"`{self.name}: {e}. This means that the pipeline cannot be " + "deployed.", + ) + return None + + def _compute_input_schema(self) -> Optional[Dict[str, Any]]: + """Create a JSON schema that represents the pipeline input parameters. + + Returns: + A JSON schema that represents the pipeline input parameters. + """ + input_model = self._compute_input_model() + if not input_model: + return None + + try: + return input_model.model_json_schema() + except Exception as e: + logger.debug( + f"Failed to generate the input parameters schema for " + f"pipeline `{self.name}: {e}. This is most likely " + "because some of the pipeline inputs are not JSON " + "serializable. 
This means that the pipeline cannot be " + "deployed.", + ) + + return None diff --git a/src/zenml/stack/flavor_registry.py b/src/zenml/stack/flavor_registry.py index 18f937dc4eb..4a1237c8171 100644 --- a/src/zenml/stack/flavor_registry.py +++ b/src/zenml/stack/flavor_registry.py @@ -67,6 +67,7 @@ def builtin_flavors(self) -> List[Type[Flavor]]: GCPContainerRegistryFlavor, GitHubContainerRegistryFlavor, ) + from zenml.deployers import DockerDeployerFlavor from zenml.image_builders import LocalImageBuilderFlavor from zenml.orchestrators import ( LocalDockerOrchestratorFlavor, @@ -83,6 +84,7 @@ def builtin_flavors(self) -> List[Type[Flavor]]: GCPContainerRegistryFlavor, GitHubContainerRegistryFlavor, LocalImageBuilderFlavor, + DockerDeployerFlavor, ] return flavors diff --git a/src/zenml/stack/stack.py b/src/zenml/stack/stack.py index ebceec343df..4651b50c9ca 100644 --- a/src/zenml/stack/stack.py +++ b/src/zenml/stack/stack.py @@ -57,6 +57,7 @@ from zenml.config.step_run_info import StepRunInfo from zenml.container_registries import BaseContainerRegistry from zenml.data_validators import BaseDataValidator + from zenml.deployers import BaseDeployer from zenml.experiment_trackers.base_experiment_tracker import ( BaseExperimentTracker, ) @@ -65,6 +66,7 @@ from zenml.model_deployers import BaseModelDeployer from zenml.model_registries import BaseModelRegistry from zenml.models import ( + DeploymentResponse, PipelineRunResponse, PipelineSnapshotBase, PipelineSnapshotResponse, @@ -109,6 +111,7 @@ def __init__( data_validator: Optional["BaseDataValidator"] = None, image_builder: Optional["BaseImageBuilder"] = None, model_registry: Optional["BaseModelRegistry"] = None, + deployer: Optional["BaseDeployer"] = None, ): """Initializes and validates a stack instance. @@ -131,6 +134,7 @@ def __init__( data_validator: Data validator component of the stack. image_builder: Image builder component of the stack. model_registry: Model registry component of the stack. + deployer: Deployer component of the stack. 
""" self._id = id self._name = name @@ -148,6 +152,7 @@ def __init__( self._data_validator = data_validator self._model_registry = model_registry self._image_builder = image_builder + self._deployer = deployer @classmethod def from_model(cls, stack_model: "StackResponse") -> "Stack": @@ -230,6 +235,7 @@ def from_components( from zenml.artifact_stores import BaseArtifactStore from zenml.container_registries import BaseContainerRegistry from zenml.data_validators import BaseDataValidator + from zenml.deployers import BaseDeployer from zenml.experiment_trackers import BaseExperimentTracker from zenml.feature_stores import BaseFeatureStore from zenml.image_builders import BaseImageBuilder @@ -324,6 +330,10 @@ def _raise_type_error( ): _raise_type_error(model_registry, BaseModelRegistry) + deployer = components.get(StackComponentType.DEPLOYER) + if deployer is not None and not isinstance(deployer, BaseDeployer): + _raise_type_error(deployer, BaseDeployer) + return Stack( id=id, name=name, @@ -341,6 +351,7 @@ def _raise_type_error( data_validator=data_validator, image_builder=image_builder, model_registry=model_registry, + deployer=deployer, ) @property @@ -365,6 +376,7 @@ def components(self) -> Dict[StackComponentType, "StackComponent"]: self.data_validator, self.image_builder, self.model_registry, + self.deployer, ] if component is not None } @@ -496,6 +508,15 @@ def model_registry(self) -> Optional["BaseModelRegistry"]: """ return self._model_registry + @property + def deployer(self) -> Optional["BaseDeployer"]: + """The deployer of the stack. + + Returns: + The deployer of the stack. + """ + return self._deployer + def dict(self) -> Dict[str, str]: """Converts the stack into a dictionary. @@ -764,6 +785,7 @@ def validate_image_builder(self) -> None: requires_image_builder = ( self.orchestrator.flavor != "local" or self.step_operator + or self.deployer or (self.model_deployer and self.model_deployer.flavor != "mlflow") ) skip_default_image_builder = handle_bool_env_var( @@ -859,6 +881,39 @@ def submit_pipeline( snapshot=snapshot, stack=self, placeholder_run=placeholder_run ) + def deploy_pipeline( + self, + snapshot: "PipelineSnapshotResponse", + deployment_name: str, + timeout: Optional[int] = None, + ) -> "DeploymentResponse": + """Deploys a pipeline on this stack. + + Args: + snapshot: The pipeline snapshot. + deployment_name: The name to use for the deployment. + timeout: The maximum time in seconds to wait for the pipeline to be + deployed. + + Returns: + The deployment response. + + Raises: + RuntimeError: If the stack does not have a deployer. + """ + if not self.deployer: + raise RuntimeError( + "The stack does not have a deployer. Please add a " + "deployer to the stack in order to deploy a pipeline." 
+ ) + + return self.deployer.provision_deployment( + snapshot=snapshot, + stack=self, + deployment_name_or_id=deployment_name, + timeout=timeout, + ) + def _get_active_components_for_step( self, step_config: "StepConfiguration" ) -> Dict[StackComponentType, "StackComponent"]: diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 549f94648a4..ef18b99603e 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -720,12 +720,14 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: failure_hook_source = None if on_failure: # string of on_failure hook function to be used for this step - failure_hook_source = resolve_and_validate_hook(on_failure) + failure_hook_source, _ = resolve_and_validate_hook( + on_failure, allow_exception_arg=True + ) success_hook_source = None if on_success: # string of on_success hook function to be used for this step - success_hook_source = resolve_and_validate_hook(on_success) + success_hook_source, _ = resolve_and_validate_hook(on_success) if merge and secrets and self._configuration.secrets: secrets = self._configuration.secrets + list(secrets) diff --git a/src/zenml/steps/step_context.py b/src/zenml/steps/step_context.py index 9afd6331f5a..196727fc647 100644 --- a/src/zenml/steps/step_context.py +++ b/src/zenml/steps/step_context.py @@ -27,7 +27,7 @@ from zenml.exceptions import StepContextError from zenml.logger import get_logger from zenml.utils.callback_registry import CallbackRegistry -from zenml.utils.singleton import SingletonMetaClass +from zenml.utils.singleton import SingletonMetaClass, ThreadLocalSingleton if TYPE_CHECKING: from zenml.artifacts.artifact_config import ArtifactConfig @@ -61,7 +61,56 @@ def get_step_context() -> "StepContext": ) -class StepContext(metaclass=SingletonMetaClass): +def get_or_create_run_context() -> "RunContext": + """Get or create the context of the currently running pipeline. + + Returns: + The context of the currently running pipeline. + """ + return RunContext() + + +class RunContext(metaclass=SingletonMetaClass): + """Provides context shared between all steps in a pipeline run.""" + + def __init__(self) -> None: + """Create the run context.""" + self.initialized = False + self._state: Optional[Any] = None + + @property + def state(self) -> Optional[Any]: + """Returns the pipeline state. + + Returns: + The pipeline state or None. + + Raises: + RuntimeError: If the run context is not initialized. + """ + if not self.initialized: + raise RuntimeError( + "Run context not initialized. The run state is only available " + "in the context of a running pipeline." + ) + return self._state + + def initialize(self, state: Optional[Any]) -> None: + """Initialize the run context. + + Args: + state: Optional state for the pipeline run + + Raises: + RuntimeError: If the run context is already initialized. + """ + if self.initialized: + raise RuntimeError("Run context already initialized.") + self._state = state + self.initialized = True + + +class StepContext(metaclass=ThreadLocalSingleton): """Provides additional context inside a step function. This singleton class is used to access information about the current run, @@ -167,6 +216,15 @@ def pipeline(self) -> "PipelineResponse": f"a pipeline associated with it." ) + @property + def pipeline_state(self) -> Optional[Any]: + """Returns the pipeline state. + + Returns: + The pipeline state or None. + """ + return get_or_create_run_context().state + @property def model(self) -> "Model": """Returns configured Model. 
diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py
index 12c28b67aad..e6e058b6d84 100644
--- a/src/zenml/steps/utils.py
+++ b/src/zenml/steps/utils.py
@@ -18,6 +18,7 @@
 import contextlib
 import inspect
 import textwrap
+from collections import Counter
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -25,6 +26,7 @@
     Dict,
     Optional,
     Tuple,
+    TypeVar,
     Union,
 )
 from uuid import UUID
@@ -584,3 +586,42 @@ def single_step_pipeline() -> None:
         return outputs[0]
     else:
         return tuple(outputs)
+
+
+T = TypeVar("T")
+
+
+def get_unique_step_output_names(
+    step_outputs: Dict[Tuple[str, str], T],
+) -> Dict[Tuple[str, str], Tuple[T, str]]:
+    """Get unique step output names.
+
+    Given a dictionary of step outputs indexed by (invocation_id, output_name),
+    where the value is an arbitrary context object, return the same dictionary
+    complemented with unique step output names.
+
+    If an output name is repeated, the invocation ID is prepended to the
+    output name to make it unique.
+
+    Args:
+        step_outputs: The step outputs as arbitrary objects indexed by
+            (invocation_id, output_name).
+
+    Returns:
+        The input dictionary complemented with unique step output names.
+    """
+    output_name_count = Counter(
+        output_name for _, output_name in step_outputs.keys()
+    )
+
+    unique_step_output_mapping = {}
+    for invocation_id, output_name in step_outputs.keys():
+        if output_name_count[output_name] > 1:
+            unique_step_output_name = f"{invocation_id}.{output_name}"
+        else:
+            unique_step_output_name = output_name
+        unique_step_output_mapping[invocation_id, output_name] = (
+            step_outputs[invocation_id, output_name],
+            unique_step_output_name,
+        )
+    return unique_step_output_mapping
diff --git a/src/zenml/types.py b/src/zenml/types.py
index 5c5e21313fa..65ea344221f 100644
--- a/src/zenml/types.py
+++ b/src/zenml/types.py
@@ -13,7 +13,7 @@
 # permissions and limitations under the License.
 """Custom ZenML types."""

-from typing import TYPE_CHECKING, Callable, Union
+from typing import TYPE_CHECKING, Any, Callable, Union

 if TYPE_CHECKING:
     from types import FunctionType
@@ -21,6 +21,9 @@
     from zenml.config.source import Source

     HookSpecification = Union[str, Source, FunctionType, Callable[..., None]]
+    InitHookSpecification = Union[
+        str, Source, FunctionType, Callable[..., Any]
+    ]


 class HTMLString(str):
diff --git a/src/zenml/utils/networking_utils.py b/src/zenml/utils/networking_utils.py
index 0d720275d1a..1b1cd28cade 100644
--- a/src/zenml/utils/networking_utils.py
+++ b/src/zenml/utils/networking_utils.py
@@ -14,7 +14,7 @@
 """Utility functions for networking."""

 import socket
-from typing import Optional, cast
+from typing import List, Optional, Tuple, cast
 from urllib.parse import urlparse

 from zenml.environment import Environment
@@ -67,6 +67,43 @@ def find_available_port() -> int:
     return cast(int, port)


+def lookup_preferred_or_free_port(
+    preferred_ports: List[int] = [],
+    allocate_port_if_busy: bool = True,
+    range: Tuple[int, int] = SCAN_PORT_RANGE,
+) -> int:
+    """Find a preferred TCP port that is available or search for a free TCP port.
+
+    If a list of preferred TCP port values is explicitly requested, they
+    will be checked in order.
+
+    Args:
+        preferred_ports: A list of preferred TCP port values.
+        allocate_port_if_busy: If True, allocate a free port if all the
+            preferred ports are busy; otherwise an exception will be raised.
+        range: The range of ports to search for a free port.
+
+    Returns:
+        An available TCP port number.
+
+    Raises:
+        IOError: If none of the preferred TCP ports are available and
+            `allocate_port_if_busy` is disabled, or if no free TCP port
+            could be otherwise allocated.
+    """
+    # If a port value is explicitly configured, attempt to use it first
+    if preferred_ports:
+        for port in preferred_ports:
+            if port_available(port):
+                return port
+        if not allocate_port_if_busy:
+            raise IOError(
+                f"None of the preferred TCP ports {preferred_ports} are "
+                "available."
+            )
+
+    available_port = scan_for_available_port(start=range[0], stop=range[1])
+    if available_port:
+        return available_port
+    raise IOError(f"No free TCP ports found in range {range}")
+
+
 def scan_for_available_port(
     start: int = SCAN_PORT_RANGE[0], stop: int = SCAN_PORT_RANGE[1]
 ) -> Optional[int]:
diff --git a/src/zenml/utils/settings_utils.py b/src/zenml/utils/settings_utils.py
index ab7fca1e86f..c9254bd899d 100644
--- a/src/zenml/utils/settings_utils.py
+++ b/src/zenml/utils/settings_utils.py
@@ -16,7 +16,10 @@
 import re
 from typing import TYPE_CHECKING, Dict, Sequence, Type

-from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY
+from zenml.config.constants import (
+    DOCKER_SETTINGS_KEY,
+    RESOURCE_SETTINGS_KEY,
+)
 from zenml.enums import StackComponentType

 if TYPE_CHECKING:
diff --git a/src/zenml/utils/singleton.py b/src/zenml/utils/singleton.py
index 1b89c082449..b34075bffb5 100644
--- a/src/zenml/utils/singleton.py
+++ b/src/zenml/utils/singleton.py
@@ -13,6 +13,7 @@
 # permissions and limitations under the License.
 """Utility class to turn classes into singleton classes."""

+import contextvars
 from typing import Any, Optional, cast


@@ -87,3 +88,97 @@ def _exists(cls) -> bool:
             `True` if the singleton instance exists, `False` otherwise.
         """
         return cls.__singleton_instance is not None
+
+
+class ThreadLocalSingleton(type):
+    """Thread-local singleton metaclass using contextvars.
+
+    This metaclass creates singleton instances that are isolated per execution
+    context (thread or asyncio task). Each context gets its own singleton
+    instance, allowing for thread-safe and coroutine-safe singleton behavior.
+
+    Use this metaclass when you need singleton behavior but want isolation
+    between different execution contexts:
+
+    ```python
+    class DatabaseConnection(metaclass=ThreadLocalSingleton):
+        def __init__(self, connection_string: str):
+            self._connection_string = connection_string
+            self._connected = False
+
+        def connect(self):
+            if not self._connected:
+                # Connect to database
+                self._connected = True
+
+        @property
+        def connection_string(self):
+            return self._connection_string
+
+    # In context 1 (e.g., thread 1)
+    db1 = DatabaseConnection("postgres://localhost/db1")
+    db1.connect()
+
+    # In context 2 (e.g., thread 2)
+    db2 = DatabaseConnection("postgres://localhost/db2")
+    # db2 is a different instance from db1, isolated by context
+    ```
+    """
+
+    def __init__(cls, *args: Any, **kwargs: Any) -> None:
+        """Initialize a thread-local singleton class.
+
+        Args:
+            *args: Additional arguments.
+            **kwargs: Additional keyword arguments.
+        """
+        super().__init__(*args, **kwargs)
+        cls.__context_instance: contextvars.ContextVar[
+            Optional["ThreadLocalSingleton"]
+        ] = contextvars.ContextVar(f"{cls.__name__}_instance", default=None)
+
+    def __call__(cls, *args: Any, **kwargs: Any) -> "ThreadLocalSingleton":
+        """Create or return the singleton instance for the current context.
+
+        Args:
+            *args: Additional arguments.
+            **kwargs: Additional keyword arguments.
+ + Returns: + The singleton instance for the current execution context. + """ + instance = cls.__context_instance.get() + if instance is None: + instance = cast( + "ThreadLocalSingleton", super().__call__(*args, **kwargs) + ) + cls.__context_instance.set(instance) + + return instance + + def _clear(cls, instance: Optional["ThreadLocalSingleton"] = None) -> None: + """Clear or replace the singleton instance in the current context. + + Args: + instance: The new singleton instance for the current context. + If None, clears the current instance. + """ + cls.__context_instance.set(instance) + + def _instance(cls) -> Optional["ThreadLocalSingleton"]: + """Get the singleton instance for the current context. + + Returns: + The singleton instance for the current execution context, + or None if no instance exists in this context. + """ + return cls.__context_instance.get() + + def _exists(cls) -> bool: + """Check if a singleton instance exists in the current context. + + Returns: + True if a singleton instance exists in the current context, + False otherwise. + """ + return cls.__context_instance.get() is not None diff --git a/src/zenml/utils/tag_utils.py b/src/zenml/utils/tag_utils.py index 9866644a448..7a67bc0893a 100644 --- a/src/zenml/utils/tag_utils.py +++ b/src/zenml/utils/tag_utils.py @@ -44,6 +44,9 @@ # Manual tagging of a snapshot add_tags(tags=[...], snapshot=...) +# Manual tagging of a deployment +add_tags(tags=[...], deployment=...) + # Manual tagging of an artifact add_tags(tags=[...], artifact=...) @@ -72,6 +75,9 @@ # Manual tag removal from a snapshot remove_tags(tags=[...], snapshot=...) +# Manual tag removal from a deployment +remove_tags(tags=[...], deployment=...) + # Manual tag removal from an artifact remove_tags(tags=[...], artifact=...) @@ -191,6 +197,14 @@ def add_tags( ) -> None: ... +@overload +def add_tags( + *, + tags: List[Union[str, Tag]], + deployment: Union[UUID, str], +) -> None: ... + + def add_tags( tags: List[Union[str, Tag]], # Pipelines @@ -201,6 +215,8 @@ def add_tags( run_template: Optional[Union[UUID, str]] = None, # Snapshots snapshot: Optional[Union[UUID, str]] = None, + # Deployments + deployment: Optional[Union[UUID, str]] = None, # Artifacts artifact: Optional[Union[UUID, str]] = None, # Artifact Versions @@ -217,6 +233,7 @@ def add_tags( run: The id, name or prefix of the run. run_template: The ID or the name of the run template. snapshot: The ID of the snapshot. + deployment: The ID or the name of the deployment. artifact: The ID or the name of the artifact. artifact_version_id: The ID of the artifact version. artifact_name: The name of the artifact. 
@@ -243,6 +260,7 @@ def add_tags( run, run_template, snapshot, + deployment, artifact, artifact_version_id, artifact_name, @@ -261,6 +279,7 @@ def add_tags( pipeline, run_template, snapshot, + deployment, artifact, artifact_version_id, artifact_name, @@ -279,6 +298,7 @@ def add_tags( pipeline, run, snapshot, + deployment, artifact, artifact_version_id, artifact_name, @@ -299,6 +319,7 @@ def add_tags( pipeline, run, run_template, + deployment, artifact, artifact_version_id, artifact_name, @@ -312,6 +333,25 @@ def add_tags( resource_id = snapshot_model.id resource_type = TaggableResourceTypes.PIPELINE_SNAPSHOT + # Tag a deployment + elif deployment is not None and all( + v is None + for v in [ + pipeline, + run, + run_template, + snapshot, + artifact, + artifact_version_id, + artifact_name, + artifact_version, + infer_artifact, + ] + ): + deployment_model = client.get_deployment(name_id_or_prefix=deployment) + resource_id = deployment_model.id + resource_type = TaggableResourceTypes.DEPLOYMENT + # Tag an artifact elif artifact is not None and all( v is None @@ -320,6 +360,7 @@ def add_tags( run, run_template, snapshot, + deployment, artifact_version_id, artifact_name, artifact_version, @@ -338,6 +379,7 @@ def add_tags( run, run_template, snapshot, + deployment, artifact, artifact_name, artifact_version, @@ -355,6 +397,7 @@ def add_tags( run, run_template, snapshot, + deployment, artifact, artifact_version_id, infer_artifact, @@ -374,6 +417,7 @@ def add_tags( run, run_template, snapshot, + deployment, artifact, artifact_version_id, artifact_version, @@ -431,6 +475,7 @@ def add_tags( run, run_template, snapshot, + deployment, artifact, artifact_version_id, artifact_name, @@ -546,6 +591,14 @@ def remove_tags( ) -> None: ... +@overload +def remove_tags( + *, + tags: List[str], + deployment: Union[UUID, str], +) -> None: ... + + @overload def remove_tags( *, @@ -590,6 +643,8 @@ def remove_tags( run_template: Optional[Union[UUID, str]] = None, # Snapshots snapshot: Optional[Union[UUID, str]] = None, + # Deployments + deployment: Optional[Union[UUID, str]] = None, # Artifacts artifact: Optional[Union[UUID, str]] = None, # Artifact Versions @@ -606,6 +661,7 @@ def remove_tags( run: The id, name or prefix of the run. run_template: The ID or the name of the run template. snapshot: The ID of the snapshot. + deployment: The ID or the name of the deployment. artifact: The ID or the name of the artifact. artifact_version_id: The ID of the artifact version. artifact_name: The name of the artifact. 
@@ -630,6 +686,7 @@ def remove_tags( for v in [ run_template, snapshot, + deployment, run, artifact, artifact_version_id, @@ -648,6 +705,7 @@ def remove_tags( for v in [ pipeline, snapshot, + deployment, run, artifact, artifact_version_id, @@ -668,6 +726,7 @@ def remove_tags( for v in [ pipeline, run_template, + deployment, run, artifact, artifact_version_id, @@ -682,6 +741,25 @@ def remove_tags( resource_id = snapshot_model.id resource_type = TaggableResourceTypes.PIPELINE_SNAPSHOT + # Remove tags from a deployment + elif deployment is not None and all( + v is None + for v in [ + pipeline, + run_template, + snapshot, + run, + artifact, + artifact_version_id, + artifact_version, + artifact_name, + infer_artifact, + ] + ): + deployment_model = client.get_deployment(name_id_or_prefix=deployment) + resource_id = deployment_model.id + resource_type = TaggableResourceTypes.DEPLOYMENT + # Remove tags from a run elif run is not None and all( v is None @@ -689,6 +767,7 @@ def remove_tags( pipeline, run_template, snapshot, + deployment, artifact, artifact_version_id, artifact_name, @@ -707,6 +786,7 @@ def remove_tags( pipeline, run_template, snapshot, + deployment, run, artifact_version_id, artifact_name, @@ -725,6 +805,7 @@ def remove_tags( pipeline, run_template, snapshot, + deployment, run, artifact, artifact_name, @@ -742,6 +823,7 @@ def remove_tags( pipeline, run_template, snapshot, + deployment, run, artifact, artifact_version_id, @@ -761,6 +843,7 @@ def remove_tags( pipeline, run_template, snapshot, + deployment, run, artifact, artifact_version_id, @@ -818,6 +901,7 @@ def remove_tags( run, run_template, snapshot, + deployment, artifact, artifact_version_id, artifact_name, diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 1a4bcd69732..e58afe3bfc8 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -931,6 +931,7 @@ def generate_access_token( expires_in: Optional[int] = None, schedule_id: Optional[UUID] = None, pipeline_run_id: Optional[UUID] = None, + deployment_id: Optional[UUID] = None, ) -> OAuthTokenResponse: """Generates an access token for the given user. @@ -949,6 +950,7 @@ def generate_access_token( expire. schedule_id: The ID of the schedule to scope the token to. pipeline_run_id: The ID of the pipeline run to scope the token to. + deployment_id: The ID of the deployment to scope the token to. Returns: An authentication response with an access token. @@ -1024,6 +1026,7 @@ def generate_access_token( api_key_id=api_key.id if api_key else None, schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, + deployment_id=deployment_id, # Set the session ID if this is a cross-site request session_id=session_id, ).encode(expires=expires) diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index c7c391bb43e..a95fb9673b2 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -43,6 +43,8 @@ class JWTToken(BaseModel): schedule_id: The id of the schedule for which the token was issued. pipeline_run_id: The id of the pipeline run for which the token was issued. + deployment_id: The id of the deployment for which the token was + issued. step_run_id: The id of the step run for which the token was issued. session_id: The id of the authenticated session (used for CSRF). 
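As a standalone illustration of how the new `deployment_id` claim round-trips through the token: the value is written as a string on encode and must parse back into a UUID on decode, otherwise the token is rejected. This is a simplified mirror of the `JWTToken` handling shown here, not the server implementation, and the helper name is made up:

```python
from typing import Any, Dict, Optional
from uuid import UUID, uuid4


def parse_deployment_claim(claims: Dict[str, Any]) -> Optional[UUID]:
    """Parse the deployment_id claim back into a UUID, as decode_token does."""
    raw = claims.pop("deployment_id", None)
    if raw is None:
        return None
    try:
        return UUID(raw)
    except ValueError:
        # The server raises CredentialsNotValid in this case.
        raise ValueError(
            "Invalid JWT token: the deployment_id claim is not a valid UUID"
        )


deployment_id = uuid4()
claims = {"deployment_id": str(deployment_id)}  # as written by JWTToken.encode
assert parse_deployment_claim(claims) == deployment_id
```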
@@ -54,6 +56,7 @@ class JWTToken(BaseModel): api_key_id: Optional[UUID] = None schedule_id: Optional[UUID] = None pipeline_run_id: Optional[UUID] = None + deployment_id: Optional[UUID] = None session_id: Optional[UUID] = None claims: Dict[str, Any] = {} @@ -147,6 +150,16 @@ def decode_token( "UUID" ) + deployment_id: Optional[UUID] = None + if "deployment_id" in claims: + try: + deployment_id = UUID(claims.pop("deployment_id")) + except ValueError: + raise CredentialsNotValid( + "Invalid JWT token: the deployment_id claim is not a valid " + "UUID" + ) + session_id: Optional[UUID] = None if "session_id" in claims: try: @@ -163,6 +176,7 @@ def decode_token( api_key_id=api_key_id, schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, + deployment_id=deployment_id, session_id=session_id, claims=claims, ) @@ -200,6 +214,8 @@ def encode(self, expires: Optional[datetime] = None) -> str: claims["schedule_id"] = str(self.schedule_id) if self.pipeline_run_id: claims["pipeline_run_id"] = str(self.pipeline_run_id) + if self.deployment_id: + claims["deployment_id"] = str(self.deployment_id) if self.session_id: claims["session_id"] = str(self.session_id) diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 0af1e833a5c..bc8984e403b 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -63,6 +63,7 @@ class ResourceType(StrEnum): # We keep this name for backwards compatibility PIPELINE_SNAPSHOT = "pipeline_deployment" PIPELINE_BUILD = "pipeline_build" + DEPLOYMENT = "deployment" SCHEDULE = "schedule" RUN_TEMPLATE = "run_template" SERVICE = "service" diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index 8bb2b6e2d96..d3db78fe409 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -440,6 +440,8 @@ def get_resource_type_for_model( CodeRepositoryResponse, ComponentRequest, ComponentResponse, + DeploymentRequest, + DeploymentResponse, EventSourceRequest, EventSourceResponse, FlavorRequest, @@ -505,6 +507,8 @@ def get_resource_type_for_model( ModelVersionResponse: ResourceType.MODEL_VERSION, PipelineBuildRequest: ResourceType.PIPELINE_BUILD, PipelineBuildResponse: ResourceType.PIPELINE_BUILD, + DeploymentRequest: ResourceType.DEPLOYMENT, + DeploymentResponse: ResourceType.DEPLOYMENT, PipelineSnapshotRequest: ResourceType.PIPELINE_SNAPSHOT, PipelineSnapshotResponse: ResourceType.PIPELINE_SNAPSHOT, PipelineRequest: ResourceType.PIPELINE, @@ -643,6 +647,7 @@ def get_schema_for_resource_type( ArtifactSchema, ArtifactVersionSchema, CodeRepositorySchema, + DeploymentSchema, EventSourceSchema, FlavorSchema, ModelSchema, @@ -682,6 +687,7 @@ def get_schema_for_resource_type( ResourceType.SERVICE_ACCOUNT: UserSchema, # ResourceType.PROJECT: ProjectSchema, ResourceType.PIPELINE_RUN: PipelineRunSchema, + ResourceType.DEPLOYMENT: DeploymentSchema, ResourceType.PIPELINE_SNAPSHOT: PipelineSnapshotSchema, ResourceType.PIPELINE_BUILD: PipelineBuildSchema, ResourceType.RUN_TEMPLATE: RunTemplateSchema, diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index fc9a86a02f7..fb3c60a2902 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -478,6 +478,7 @@ def api_token( expires_in: Optional[int] = None, schedule_id: Optional[UUID] = None, pipeline_run_id: Optional[UUID] = None, + deployment_id: Optional[UUID] = None, auth_context: AuthContext = 
Security(authorize), ) -> str: """Generate an API token for the current user. @@ -488,14 +489,14 @@ def api_token( * Generic API token: This token is short-lived and can be used for generic automation tasks. The expiration can be set by the user, but the server will impose a maximum expiration time. - * Workload API token: This token is scoped to a specific pipeline run, step - run or schedule and is used by pipeline workloads to authenticate with the - server. A pipeline run ID, step run ID or schedule ID must be provided and - the generated token will only be valid for the indicated pipeline run, step - run or schedule. No time limit is imposed on the validity of the token. - A workload API token can be used to authenticate and generate another - workload API token, but only for the same schedule, pipeline run ID or step - run ID, in that order. + * Workload API token: This token is scoped to a specific pipeline run, + schedule or deployment and is used by pipeline workloads to + authenticate with the server. A pipeline run ID, schedule ID or deployment + ID must be provided and the generated token will only be valid for the + indicated pipeline run, schedule or deployment. + No time limit is imposed on the validity of the token. A workload API token + can be used to authenticate and generate another workload API token, but + only for the same schedule, pipeline run ID or deployment ID, in that order. Args: token_type: The type of API token to generate. @@ -506,6 +507,8 @@ def api_token( schedule_id: The ID of the schedule to scope the workload API token to. pipeline_run_id: The ID of the pipeline run to scope the workload API token to. + deployment_id: The ID of the deployment to scope the workload + API token to. auth_context: The authentication context. Returns: @@ -521,10 +524,10 @@ def api_token( raise AuthorizationException("Not authenticated.") if token_type == APITokenType.GENERIC: - if schedule_id or pipeline_run_id: + if schedule_id or pipeline_run_id or deployment_id: raise ValueError( - "Generic API tokens cannot be scoped to a schedule or pipeline " - "run." + "Generic API tokens cannot be scoped to a schedule, pipeline " + "run or deployment." ) config = server_config() @@ -548,10 +551,12 @@ def api_token( schedule_id = schedule_id or token.schedule_id pipeline_run_id = pipeline_run_id or token.pipeline_run_id + deployment_id = deployment_id or token.deployment_id - if not pipeline_run_id and not schedule_id: + if not pipeline_run_id and not schedule_id and not deployment_id: raise ValueError( - "Workload API tokens must be scoped to a schedule or pipeline run." + "Workload API tokens must be scoped to a schedule, pipeline run or " + "deployment." ) if schedule_id and token.schedule_id and schedule_id != token.schedule_id: @@ -572,6 +577,17 @@ def api_token( f"pipeline run {token.pipeline_run_id}." ) + if ( + deployment_id + and token.deployment_id + and deployment_id != token.deployment_id + ): + raise AuthorizationException( + f"Unable to scope API token to deployment {deployment_id}. The " + f"token used to authorize this request is already scoped to " + f"deployment {token.deployment_id}." + ) + project_id: Optional[UUID] = None if schedule_id: @@ -614,6 +630,19 @@ def api_token( "for security reasons." 
) + if deployment_id: + # The deployment must exist + try: + deployment = zen_store().get_deployment( + deployment_id, hydrate=False + ) + except KeyError: + raise ValueError( + f"Deployment {deployment_id} does not exist and API tokens cannot " + "be generated for non-existent deployments for security reasons." + ) + project_id = deployment.project_id + assert project_id is not None verify_permission( resource_type=ResourceType.PIPELINE_RUN, @@ -628,6 +657,7 @@ def api_token( device=auth_context.device, schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, + deployment_id=deployment_id, # Don't include the access token as a cookie in the response response=None, # Never expire the token diff --git a/src/zenml/zen_server/routers/deployment_endpoints.py b/src/zenml/zen_server/routers/deployment_endpoints.py new file mode 100644 index 00000000000..7b0ec15e237 --- /dev/null +++ b/src/zenml/zen_server/routers/deployment_endpoints.py @@ -0,0 +1,185 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Endpoint definitions for deployments.""" + +from uuid import UUID + +from fastapi import ( + APIRouter, + Depends, + Security, +) + +from zenml.constants import ( + API, + DEPLOYMENTS, + VERSION_1, +) +from zenml.models import ( + DeploymentFilter, + DeploymentRequest, + DeploymentResponse, + DeploymentUpdate, +) +from zenml.models.v2.base.page import Page +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType +from zenml.zen_server.utils import ( + async_fastapi_endpoint_wrapper, + make_dependable, + zen_store, +) + +router = APIRouter( + prefix=API + VERSION_1 + DEPLOYMENTS, + tags=["deployments"], + responses={401: error_response, 403: error_response}, +) + + +@router.post( + "", + responses={401: error_response, 409: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper +def create_deployment( + deployment: DeploymentRequest, + _: AuthContext = Security(authorize), +) -> DeploymentResponse: + """Creates a deployment. + + Args: + deployment: Deployment to create. + + Returns: + The created deployment. + """ + return verify_permissions_and_create_entity( + request_model=deployment, + create_method=zen_store().create_deployment, + ) + + +@router.get( + "", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def list_deployments( + deployment_filter_model: DeploymentFilter = Depends( + make_dependable(DeploymentFilter) + ), + hydrate: bool = False, + _: AuthContext = Security(authorize), +) -> Page[DeploymentResponse]: + """Gets a list of deployments. 
+ + Args: + deployment_filter_model: Filter model used for pagination, sorting, + filtering. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + List of deployment objects matching the filter criteria. + """ + return verify_permissions_and_list_entities( + filter_model=deployment_filter_model, + resource_type=ResourceType.DEPLOYMENT, + list_method=zen_store().list_deployments, + hydrate=hydrate, + ) + + +@router.get( + "/{deployment_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def get_deployment( + deployment_id: UUID, + hydrate: bool = True, + _: AuthContext = Security(authorize), +) -> DeploymentResponse: + """Gets a specific deployment using its unique id. + + Args: + deployment_id: ID of the deployment to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A specific deployment object. + """ + return verify_permissions_and_get_entity( + id=deployment_id, + get_method=zen_store().get_deployment, + hydrate=hydrate, + ) + + +@router.put( + "/{deployment_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def update_deployment( + deployment_id: UUID, + deployment_update: DeploymentUpdate, + _: AuthContext = Security(authorize), +) -> DeploymentResponse: + """Updates a specific deployment. + + Args: + deployment_id: ID of the deployment to update. + deployment_update: Update model for the deployment. + + Returns: + The updated deployment. + """ + return verify_permissions_and_update_entity( + id=deployment_id, + update_model=deployment_update, + get_method=zen_store().get_deployment, + update_method=zen_store().update_deployment, + ) + + +@router.delete( + "/{deployment_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper +def delete_deployment( + deployment_id: UUID, + _: AuthContext = Security(authorize), +) -> None: + """Deletes a specific deployment. + + Args: + deployment_id: ID of the deployment to delete. 
+ """ + verify_permissions_and_delete_entity( + id=deployment_id, + get_method=zen_store().get_deployment, + delete_method=zen_store().delete_deployment, + ) diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 0060df7aedc..50719982766 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -57,6 +57,7 @@ artifact_version_endpoints, auth_endpoints, code_repositories_endpoints, + deployment_endpoints, devices_endpoints, event_source_endpoints, flavors_endpoints, @@ -263,6 +264,7 @@ async def dashboard(request: Request) -> Any: app.include_router(auth_endpoints.router) app.include_router(devices_endpoints.router) app.include_router(code_repositories_endpoints.router) +app.include_router(deployment_endpoints.router) app.include_router(plugin_endpoints.plugin_router) app.include_router(event_source_endpoints.event_source_router) app.include_router(flavors_endpoints.router) diff --git a/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_deployments.py b/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_deployments.py new file mode 100644 index 00000000000..1e46142046b --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_deployments.py @@ -0,0 +1,86 @@ +"""add pipeline deployment [0d69e308846a]. + +Revision ID: 0d69e308846a +Revises: 8ad841ad9bfe +Create Date: 2025-08-26 10:30:52.737833 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = "0d69e308846a" +down_revision = "8ad841ad9bfe" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "deployment", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("project_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column( + "status", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("url", sa.TEXT(), nullable=True), + sa.Column("auth_key", sa.TEXT(), nullable=True), + sa.Column( + "deployment_metadata", + sa.String(length=16777215).with_variant(mysql.MEDIUMTEXT, "mysql"), + nullable=False, + ), + sa.Column( + "snapshot_id", + sqlmodel.sql.sqltypes.GUID(), + nullable=True, + ), + sa.Column("deployer_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.ForeignKeyConstraint( + ["snapshot_id"], + ["pipeline_snapshot.id"], + name="fk_deployment_snapshot_id_pipeline_snapshot", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["deployer_id"], + ["stack_component.id"], + name="fk_deployment_deployer_id_stack_component", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + name="fk_deployment_project_id_project", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_deployment_user_id_user", + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "name", + "project_id", + name="unique_deployment_name_in_project", + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("deployment") + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/migrations/versions/8ad841ad9bfe_pipeline_snapshots.py b/src/zenml/zen_stores/migrations/versions/8ad841ad9bfe_pipeline_snapshots.py index 1954e6514dd..e23ca4af6b4 100644 --- a/src/zenml/zen_stores/migrations/versions/8ad841ad9bfe_pipeline_snapshots.py +++ b/src/zenml/zen_stores/migrations/versions/8ad841ad9bfe_pipeline_snapshots.py @@ -236,10 +236,10 @@ def add_unlisted_pipeline_if_necessary() -> None: now = utc_now() - for project_id in projects_with_orphaned_snapshots: + for project_id_row in projects_with_orphaned_snapshots: existing_pipeline = connection.execute( sa.select(pipeline_table.c.id) - .where(pipeline_table.c.project_id == project_id) + .where(pipeline_table.c.project_id == project_id_row.project_id) .where(pipeline_table.c.name == "unlisted") ).fetchone() @@ -255,14 +255,17 @@ def add_unlisted_pipeline_if_necessary() -> None: name="unlisted", description="Auto-created pipeline for snapshots without " "pipeline reference", - project_id=project_id, + project_id=project_id_row.project_id, user_id=None, ) ) connection.execute( sa.update(pipeline_snapshot_table) - .where(pipeline_snapshot_table.c.project_id == project_id) + .where( + pipeline_snapshot_table.c.project_id + == project_id_row.project_id + ) .where(pipeline_snapshot_table.c.pipeline_id.is_(None)) .values(pipeline_id=unlisted_pipeline_id) ) @@ -366,7 +369,7 @@ def migrate_run_templates() -> None: snapshot_updates = [ { "id_": source_snapshot_id, - "version": template_name, + "name": template_name, "description": template_description, } for template_name, template_description, source_snapshot_id in connection.execute( @@ -378,7 +381,7 @@ def migrate_run_templates() -> None: sa.update(pipeline_snapshot_table) .where(pipeline_snapshot_table.c.id == sa.bindparam("id_")) .values( - version=sa.bindparam("version"), + name=sa.bindparam("name"), description=sa.bindparam("description"), ), snapshot_updates, diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 8dcf898944c..65e0bf91f19 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -68,6 +68,7 @@ CURRENT_USER, DEACTIVATE, DEFAULT_HTTP_TIMEOUT, + DEPLOYMENTS, DEVICES, DISABLE_CLIENT_SERVER_MISMATCH_WARNING, ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING, @@ -165,6 +166,10 @@ ComponentResponse, ComponentUpdate, DeployedStack, + DeploymentFilter, + DeploymentRequest, + DeploymentResponse, + DeploymentUpdate, EventSourceFilter, EventSourceRequest, EventSourceResponse, @@ -1759,6 +1764,98 @@ def run_snapshot( return PipelineRunResponse.model_validate(response_body) + # -------------------- Deployments -------------------- + + def create_deployment( + self, deployment: DeploymentRequest + ) -> DeploymentResponse: + """Create a new deployment. + + Args: + deployment: The deployment to create. + + Returns: + The newly created deployment. + """ + return self._create_resource( + resource=deployment, + route=DEPLOYMENTS, + response_model=DeploymentResponse, + ) + + def get_deployment( + self, deployment_id: UUID, hydrate: bool = True + ) -> DeploymentResponse: + """Get a deployment with a given ID. + + Args: + deployment_id: ID of the deployment. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The deployment. 
+ """ + return self._get_resource( + resource_id=deployment_id, + route=DEPLOYMENTS, + response_model=DeploymentResponse, + params={"hydrate": hydrate}, + ) + + def list_deployments( + self, + deployment_filter_model: DeploymentFilter, + hydrate: bool = False, + ) -> Page[DeploymentResponse]: + """List all deployments matching the given filter criteria. + + Args: + deployment_filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A page of all deployments matching the filter criteria. + """ + return self._list_paginated_resources( + route=DEPLOYMENTS, + response_model=DeploymentResponse, + filter_model=deployment_filter_model, + params={"hydrate": hydrate}, + ) + + def update_deployment( + self, deployment_id: UUID, deployment_update: DeploymentUpdate + ) -> DeploymentResponse: + """Update a deployment. + + Args: + deployment_id: The ID of the deployment to update. + deployment_update: The update to apply. + + Returns: + The updated deployment. + """ + return self._update_resource( + resource_id=deployment_id, + resource_update=deployment_update, + route=DEPLOYMENTS, + response_model=DeploymentResponse, + ) + + def delete_deployment(self, deployment_id: UUID) -> None: + """Delete a deployment. + + Args: + deployment_id: The ID of the deployment to delete. + """ + self._delete_resource( + resource_id=deployment_id, + route=DEPLOYMENTS, + ) + # -------------------- Run templates -------------------- def create_run_template( @@ -3973,7 +4070,7 @@ def get_api_token( expires_in: Optional[int] = None, schedule_id: Optional[UUID] = None, pipeline_run_id: Optional[UUID] = None, - step_run_id: Optional[UUID] = None, + deployment_id: Optional[UUID] = None, ) -> str: """Get an API token. @@ -3982,7 +4079,7 @@ def get_api_token( expires_in: The time in seconds until the token expires. schedule_id: The ID of the schedule to get a token for. pipeline_run_id: The ID of the pipeline run to get a token for. - step_run_id: The ID of the step run to get a token for. + deployment_id: The ID of the deployment to get a token for. Returns: The API token. 
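A hedged sketch of the deployment CRUD surface added to the REST store above. It assumes a client connected to a ZenML server and uses an illustrative filter value; the exact fields available on `DeploymentFilter` beyond pagination are not shown in this hunk:

```python
from zenml.client import Client
from zenml.models import DeploymentFilter

store = Client().zen_store

# List deployments matching a name filter, then fetch one by ID with
# metadata included (hydrate=True).
page = store.list_deployments(DeploymentFilter(name="my_deployment"))
for deployment in page.items:
    hydrated = store.get_deployment(deployment.id, hydrate=True)
    print(hydrated.name)
```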
@@ -3999,8 +4096,8 @@ def get_api_token( params["schedule_id"] = schedule_id if pipeline_run_id: params["pipeline_run_id"] = pipeline_run_id - if step_run_id: - params["step_run_id"] = step_run_id + if deployment_id: + params["deployment_id"] = deployment_id response_body = self.get(API_TOKEN, params=params) if not isinstance(response_body, str): raise ValueError( diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 0fb08bc8de6..b98adfcfeac 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -30,6 +30,7 @@ from zenml.zen_stores.schemas.device_schemas import OAuthDeviceSchema from zenml.zen_stores.schemas.event_source_schemas import EventSourceSchema from zenml.zen_stores.schemas.pipeline_build_schemas import PipelineBuildSchema +from zenml.zen_stores.schemas.deployment_schemas import DeploymentSchema from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.flavor_schemas import FlavorSchema from zenml.zen_stores.schemas.server_settings_schemas import ServerSettingsSchema @@ -86,6 +87,7 @@ "BaseSchema", "CodeReferenceSchema", "CodeRepositorySchema", + "DeploymentSchema", "EventSourceSchema", "FlavorSchema", "LogsSchema", diff --git a/src/zenml/zen_stores/schemas/deployment_schemas.py b/src/zenml/zen_stores/schemas/deployment_schemas.py new file mode 100644 index 00000000000..c6284f7aad5 --- /dev/null +++ b/src/zenml/zen_stores/schemas/deployment_schemas.py @@ -0,0 +1,274 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""SQLModel implementation of pipeline deployments table.""" + +import json +from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from uuid import UUID + +from sqlalchemy import TEXT, Column, UniqueConstraint +from sqlalchemy.dialects.mysql import MEDIUMTEXT +from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.sql.base import ExecutableOption +from sqlmodel import Field, Relationship, String + +from zenml.constants import MEDIUMTEXT_MAX_LENGTH +from zenml.enums import DeploymentStatus, TaggableResourceTypes +from zenml.logger import get_logger +from zenml.models.v2.core.deployment import ( + DeploymentRequest, + DeploymentResponse, + DeploymentResponseBody, + DeploymentResponseMetadata, + DeploymentResponseResources, + DeploymentUpdate, +) +from zenml.utils.time_utils import utc_now +from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.component_schemas import StackComponentSchema +from zenml.zen_stores.schemas.pipeline_snapshot_schemas import ( + PipelineSnapshotSchema, +) +from zenml.zen_stores.schemas.project_schemas import ProjectSchema +from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import jl_arg + +if TYPE_CHECKING: + from zenml.zen_stores.schemas.tag_schemas import TagSchema + +logger = get_logger(__name__) + + +class DeploymentSchema(NamedSchema, table=True): + """SQL Model for pipeline deployment.""" + + __tablename__ = "deployment" + __table_args__ = ( + UniqueConstraint( + "name", + "project_id", + name="unique_deployment_name_in_project", + ), + ) + + project_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ProjectSchema.__tablename__, + source_column="project_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + project: "ProjectSchema" = Relationship(back_populates="deployments") + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship(back_populates="deployments") + + status: str + url: Optional[str] = Field( + default=None, + sa_column=Column(TEXT, nullable=True), + ) + auth_key: Optional[str] = Field( + default=None, + sa_column=Column(TEXT, nullable=True), + ) + deployment_metadata: str = Field( + default="{}", + sa_column=Column( + String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( + MEDIUMTEXT, "mysql" + ), + nullable=False, + ), + ) + snapshot_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=PipelineSnapshotSchema.__tablename__, + source_column="snapshot_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + snapshot: Optional["PipelineSnapshotSchema"] = Relationship( + back_populates="deployment", + ) + + deployer_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=StackComponentSchema.__tablename__, + source_column="deployer_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + deployer: Optional["StackComponentSchema"] = Relationship() + + tags: List["TagSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.DEPLOYMENT.value}', foreign(TagResourceSchema.resource_id)==DeploymentSchema.id)", + secondary="tag_resource", + secondaryjoin="TagSchema.id == 
foreign(TagResourceSchema.tag_id)", + order_by="TagSchema.name", + overlaps="tags", + ), + ) + + @classmethod + def get_query_options( + cls, + include_metadata: bool = False, + include_resources: bool = False, + **kwargs: Any, + ) -> Sequence[ExecutableOption]: + """Get the query options for the schema. + + Args: + include_metadata: Whether metadata will be included when converting + the schema to a model. + include_resources: Whether resources will be included when + converting the schema to a model. + **kwargs: Keyword arguments to allow schema specific logic + + Returns: + A list of query options. + """ + options = [] + + if include_resources: + options.extend( + [ + joinedload(jl_arg(DeploymentSchema.user)), + joinedload(jl_arg(DeploymentSchema.deployer)), + selectinload(jl_arg(DeploymentSchema.snapshot)).joinedload( + jl_arg(PipelineSnapshotSchema.pipeline) + ), + ] + ) + + return options + + def to_model( + self, + include_metadata: bool = False, + include_resources: bool = False, + **kwargs: Any, + ) -> DeploymentResponse: + """Convert a `DeploymentSchema` to a `DeploymentResponse`. + + Args: + include_metadata: Whether to include metadata in the response. + include_resources: Whether to include resources in the response. + kwargs: Additional keyword arguments. + + Returns: + The created `DeploymentResponse`. + """ + status: Optional[DeploymentStatus] = None + if self.status in DeploymentStatus.values(): + status = DeploymentStatus(self.status) + elif self.status is not None: + status = DeploymentStatus.UNKNOWN + logger.warning( + f"Deployment status '{self.status}' used for deployment " + f"{self.name} is not a valid DeploymentStatus value. " + "Using UNKNOWN instead." + ) + + body = DeploymentResponseBody( + user_id=self.user_id, + project_id=self.project_id, + created=self.created, + updated=self.updated, + url=self.url, + status=status, + ) + + metadata = None + if include_metadata: + metadata = DeploymentResponseMetadata( + deployment_metadata=json.loads(self.deployment_metadata), + auth_key=self.auth_key, + ) + + resources = None + if include_resources: + resources = DeploymentResponseResources( + user=self.user.to_model() if self.user else None, + tags=[tag.to_model() for tag in self.tags], + snapshot=self.snapshot.to_model() if self.snapshot else None, + deployer=self.deployer.to_model() if self.deployer else None, + pipeline=self.snapshot.pipeline.to_model() + if self.snapshot and self.snapshot.pipeline + else None, + ) + + return DeploymentResponse( + id=self.id, + name=self.name, + body=body, + metadata=metadata, + resources=resources, + ) + + def update( + self, + update: DeploymentUpdate, + ) -> "DeploymentSchema": + """Updates a `DeploymentSchema` from a `DeploymentUpdate`. + + Args: + update: The `DeploymentUpdate` to update from. + + Returns: + The updated `DeploymentSchema`. + """ + for field, value in update.model_dump( + exclude_unset=True, exclude_none=True + ).items(): + if field == "deployment_metadata": + setattr(self, field, json.dumps(value)) + elif hasattr(self, field): + setattr(self, field, value) + + self.updated = utc_now() + return self + + @classmethod + def from_request(cls, request: DeploymentRequest) -> "DeploymentSchema": + """Convert a `DeploymentRequest` to a `DeploymentSchema`. + + Args: + request: The request model to convert. + + Returns: + The converted schema. 
+ """ + return cls( + name=request.name, + project_id=request.project, + user_id=request.user, + status=DeploymentStatus.UNKNOWN.value, + snapshot_id=request.snapshot_id, + deployer_id=request.deployer_id, + auth_key=request.auth_key, + ) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index e6693c7c278..1a192d79762 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -346,6 +346,9 @@ def from_request( if request.trigger_info.step_run_id: triggered_by = request.trigger_info.step_run_id triggered_by_type = PipelineRunTriggeredByType.STEP_RUN.value + elif request.trigger_info.deployment_id: + triggered_by = request.trigger_info.deployment_id + triggered_by_type = PipelineRunTriggeredByType.DEPLOYMENT.value return cls( project_id=request.project, @@ -422,8 +425,8 @@ def get_upstream_steps(self) -> Dict[str, List[str]]: The list of upstream steps for each step. Raises: - RuntimeError: If the pipeline run has no deployment or - the deployment has no pipeline spec. + RuntimeError: If the pipeline run has no snapshot or + the snapshot has no pipeline spec. """ if self.snapshot and self.snapshot.pipeline_spec: pipeline_spec = PipelineSpec.model_validate_json( diff --git a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py index 15f7de8144c..290dcec937b 100644 --- a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py @@ -53,6 +53,9 @@ from zenml.zen_stores.schemas.utils import jl_arg if TYPE_CHECKING: + from zenml.zen_stores.schemas.deployment_schemas import ( + DeploymentSchema, + ) from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -202,6 +205,9 @@ class PipelineSnapshotSchema(BaseSchema, table=True): "order_by": "asc(StepConfigurationSchema.index)", } ) + deployment: Optional["DeploymentSchema"] = Relationship( + back_populates="snapshot" + ) step_count: int tags: List["TagSchema"] = Relationship( sa_relationship_kwargs=dict( @@ -448,7 +454,6 @@ def to_model( included. **kwargs: Keyword arguments to allow schema specific logic - Returns: The response. 
""" @@ -456,12 +461,17 @@ def to_model( if self.build and not self.build.is_local and self.build.stack_id: runnable = True + deployable = False + if self.build and self.stack and self.stack.has_deployer: + deployable = True + body = PipelineSnapshotResponseBody( user_id=self.user_id, project_id=self.project_id, created=self.created, updated=self.updated, runnable=runnable, + deployable=deployable, ) metadata = None if include_metadata: @@ -546,6 +556,9 @@ def to_model( code_reference=self.code_reference.to_model() if self.code_reference else None, + deployment=self.deployment.to_model() + if self.deployment + else None, tags=[tag.to_model() for tag in self.tags], latest_run_id=latest_run.id if latest_run else None, latest_run_status=latest_run.status if latest_run else None, diff --git a/src/zenml/zen_stores/schemas/project_schemas.py b/src/zenml/zen_stores/schemas/project_schemas.py index 68694b8dfe0..e639ba57c2f 100644 --- a/src/zenml/zen_stores/schemas/project_schemas.py +++ b/src/zenml/zen_stores/schemas/project_schemas.py @@ -33,6 +33,7 @@ ActionSchema, ArtifactVersionSchema, CodeRepositorySchema, + DeploymentSchema, EventSourceSchema, ModelSchema, ModelVersionSchema, @@ -122,6 +123,10 @@ class ProjectSchema(NamedSchema, table=True): back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) + deployments: List["DeploymentSchema"] = Relationship( + back_populates="project", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request(cls, project: ProjectRequest) -> "ProjectSchema": diff --git a/src/zenml/zen_stores/schemas/stack_schemas.py b/src/zenml/zen_stores/schemas/stack_schemas.py index 845876637ea..c74227c21b9 100644 --- a/src/zenml/zen_stores/schemas/stack_schemas.py +++ b/src/zenml/zen_stores/schemas/stack_schemas.py @@ -19,11 +19,11 @@ from uuid import UUID from sqlalchemy import UniqueConstraint -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, object_session from sqlalchemy.sql.base import ExecutableOption -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import Field, Relationship, SQLModel, select -from zenml.enums import SecretResourceTypes +from zenml.enums import SecretResourceTypes, StackComponentType from zenml.models import ( StackRequest, StackResponse, @@ -123,6 +123,41 @@ class StackSchema(NamedSchema, table=True): ), ) + @property + def has_deployer(self) -> bool: + """If the stack has a deployer component. + + Returns: + If the stack has a deployer component. + + Raises: + RuntimeError: if the stack has no DB session. + """ + from zenml.zen_stores.schemas import ( + StackComponentSchema, + StackCompositionSchema, + ) + + if session := object_session(self): + query = ( + select(StackComponentSchema.id) + .where( + StackComponentSchema.type + == StackComponentType.DEPLOYER.value + ) + .where( + StackCompositionSchema.component_id + == StackComponentSchema.id + ) + .where(StackCompositionSchema.stack_id == self.id) + ) + + return session.execute(query).first() is not None + else: + raise RuntimeError( + "Missing DB session to check if stack has a deployer component." 
+ ) + @classmethod def from_request( cls, diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index e6b1044e110..3699158849d 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -44,6 +44,7 @@ ArtifactSchema, ArtifactVersionSchema, CodeRepositorySchema, + DeploymentSchema, EventSourceSchema, FlavorSchema, ModelSchema, @@ -169,6 +170,9 @@ class UserSchema(NamedSchema, table=True): back_populates="service_account", sa_relationship_kwargs={"cascade": "delete"}, ) + deployments: List["DeploymentSchema"] = Relationship( + back_populates="user", + ) tags: List["TagSchema"] = Relationship( back_populates="user", ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index dd8669d915e..831dca781d9 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -134,6 +134,7 @@ ArtifactSaveType, AuthScheme, DatabaseBackupStrategy, + DeploymentStatus, ExecutionMode, ExecutionStatus, LoggingLevels, @@ -201,6 +202,10 @@ DefaultComponentRequest, DefaultStackRequest, DeployedStack, + DeploymentFilter, + DeploymentRequest, + DeploymentResponse, + DeploymentUpdate, EventSourceFilter, EventSourceRequest, EventSourceResponse, @@ -355,6 +360,7 @@ BaseSchema, CodeReferenceSchema, CodeRepositorySchema, + DeploymentSchema, EventSourceSchema, FlavorSchema, ModelSchema, @@ -3949,6 +3955,22 @@ def delete_stack_component(self, component_id: UUID) -> None: f"The default {stack_component.type} cannot be deleted." ) + if stack_component.type == StackComponentType.DEPLOYER: + deployments = session.exec( + select(DeploymentSchema) + .where(DeploymentSchema.deployer_id == stack_component.id) + .where(DeploymentSchema.status != DeploymentStatus.ABSENT) + ).all() + + if len(deployments) > 0: + raise IllegalOperationError( + f"The {stack_component.name} deployer stack component " + f"cannot be deleted because there are still " + f"{len(deployments)} deployments being managed by it " + f"and this would result in orphaned resources." + f"Please deprovision or delete the deployments first." + ) + if len(stack_component.stacks) > 0: raise IllegalOperationError( f"Stack Component `{stack_component.name}` of type " @@ -5197,6 +5219,181 @@ def run_snapshot( "Running a snapshot is not possible with a local store." ) + # -------------------- Deployments -------------------- + + @track_decorator(AnalyticsEvent.CREATE_DEPLOYMENT) + def create_deployment( + self, deployment: DeploymentRequest + ) -> DeploymentResponse: + """Create a new deployment. + + Args: + deployment: The deployment to create. + + Returns: + The newly created deployment. 
+ """ + with Session(self.engine) as session: + self._set_request_user_id( + request_model=deployment, session=session + ) + self._verify_name_uniqueness( + resource=deployment, + schema=DeploymentSchema, + session=session, + ) + self._get_reference_schema_by_id( + resource=deployment, + reference_schema=PipelineSnapshotSchema, + reference_id=deployment.snapshot_id, + session=session, + ) + self._get_reference_schema_by_id( + resource=deployment, + reference_schema=StackComponentSchema, + reference_id=deployment.deployer_id, + session=session, + reference_type="deployer", + ) + deployment_schema = DeploymentSchema.from_request(deployment) + session.add(deployment_schema) + session.commit() + + self._attach_tags_to_resources( + tags=deployment.tags, + resources=deployment_schema, + session=session, + ) + + session.refresh(deployment_schema) + return deployment_schema.to_model( + include_metadata=True, include_resources=True + ) + + def get_deployment( + self, deployment_id: UUID, hydrate: bool = True + ) -> DeploymentResponse: + """Get a deployment with a given ID. + + Args: + deployment_id: ID of the deployment. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The deployment. + """ + with Session(self.engine) as session: + deployment = self._get_schema_by_id( + resource_id=deployment_id, + schema_class=DeploymentSchema, + session=session, + ) + return deployment.to_model( + include_metadata=hydrate, include_resources=True + ) + + def list_deployments( + self, + deployment_filter_model: DeploymentFilter, + hydrate: bool = False, + ) -> Page[DeploymentResponse]: + """List all deployments matching the given filter criteria. + + Args: + deployment_filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A page of all deployments matching the filter criteria. + """ + with Session(self.engine) as session: + self._set_filter_project_id( + filter_model=deployment_filter_model, + session=session, + ) + query = select(DeploymentSchema) + return self.filter_and_paginate( + session=session, + query=query, + table=DeploymentSchema, + filter_model=deployment_filter_model, + hydrate=hydrate, + ) + + def update_deployment( + self, + deployment_id: UUID, + deployment_update: DeploymentUpdate, + ) -> DeploymentResponse: + """Update a deployment. + + Args: + deployment_id: The ID of the deployment to update. + deployment_update: The update to apply. + + Returns: + The updated deployment. 
+ """ + with Session(self.engine) as session: + deployment = self._get_schema_by_id( + resource_id=deployment_id, + schema_class=DeploymentSchema, + session=session, + ) + + self._verify_name_uniqueness( + resource=deployment_update, + schema=deployment, + session=session, + ) + self._get_reference_schema_by_id( + resource=deployment, + reference_schema=PipelineSnapshotSchema, + reference_id=deployment_update.snapshot_id, + session=session, + ) + + deployment.update(deployment_update) + session.add(deployment) + session.commit() + + self._attach_tags_to_resources( + tags=deployment_update.add_tags, + resources=deployment, + session=session, + ) + self._detach_tags_from_resources( + tags=deployment_update.remove_tags, + resources=deployment, + session=session, + ) + + session.refresh(deployment) + + return deployment.to_model( + include_metadata=True, include_resources=True + ) + + @track_decorator(AnalyticsEvent.DELETE_DEPLOYMENT) + def delete_deployment(self, deployment_id: UUID) -> None: + """Delete a deployment. + + Args: + deployment_id: The ID of the deployment to delete. + """ + with Session(self.engine) as session: + deployment = self._get_schema_by_id( + resource_id=deployment_id, + schema_class=DeploymentSchema, + session=session, + ) + + session.delete(deployment) + session.commit() + # -------------------- Run templates -------------------- @track_decorator(AnalyticsEvent.CREATED_RUN_TEMPLATE) @@ -12944,6 +13141,7 @@ def _get_taggable_resource_type( PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, PipelineSnapshotSchema: TaggableResourceTypes.PIPELINE_SNAPSHOT, + DeploymentSchema: TaggableResourceTypes.DEPLOYMENT, } if type(resource) not in resource_types: raise ValueError( @@ -12986,6 +13184,7 @@ def _get_schema_from_resource_type( TaggableResourceTypes.PIPELINE_RUN: PipelineRunSchema, TaggableResourceTypes.RUN_TEMPLATE: RunTemplateSchema, TaggableResourceTypes.PIPELINE_SNAPSHOT: PipelineSnapshotSchema, + TaggableResourceTypes.DEPLOYMENT: DeploymentSchema, } return resource_type_to_schema_mapping[resource_type] @@ -13654,6 +13853,32 @@ def _create_tag_resource_schemas( resource_type=TaggableResourceTypes.PIPELINE_SNAPSHOT, ) ) + elif isinstance(resource, DeploymentSchema): + if not resource.snapshot: + continue + scope_id = resource.snapshot.pipeline_id + scope_ids[TaggableResourceTypes.DEPLOYMENT].append( + scope_id + ) + + # TODO: This is very inefficient, we should use a + # better query + older_deployments = self.list_deployments( + DeploymentFilter( + id=f"notequals:{resource.id}", + project=resource.project.id, + pipeline=scope_id, + tags=[tag_schema.name], + ) + ) + if older_deployments.items: + detach_resources.append( + TagResourceRequest( + tag_id=tag_schema.id, + resource_id=older_deployments.items[0].id, + resource_type=TaggableResourceTypes.DEPLOYMENT, + ) + ) else: raise ValueError( "Can not attach exclusive tag to resource of type " diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 13abfa23d3b..40a1a74adf1 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -49,6 +49,10 @@ ComponentResponse, ComponentUpdate, DeployedStack, + DeploymentFilter, + DeploymentRequest, + DeploymentResponse, + DeploymentUpdate, EventSourceFilter, EventSourceRequest, EventSourceResponse, @@ -1384,6 +1388,89 @@ def run_snapshot( The created pipeline run. 
""" + # -------------------- Deployments -------------------- + + @abstractmethod + def create_deployment( + self, deployment: DeploymentRequest + ) -> DeploymentResponse: + """Create a new deployment. + + Args: + deployment: The deployment to create. + + Returns: + The newly created deployment. + + Raises: + EntityExistsError: If a deployment with the same name already + exists in the same project. + """ + + @abstractmethod + def get_deployment( + self, deployment_id: UUID, hydrate: bool = True + ) -> DeploymentResponse: + """Get a deployment with a given ID. + + Args: + deployment_id: ID of the deployment. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The deployment. + + Raises: + KeyError: If the deployment does not exist. + """ + + @abstractmethod + def list_deployments( + self, + deployment_filter_model: DeploymentFilter, + hydrate: bool = False, + ) -> Page[DeploymentResponse]: + """List all deployments matching the given filter criteria. + + Args: + deployment_filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all deployments matching the filter criteria. + """ + + @abstractmethod + def update_deployment( + self, deployment_id: UUID, deployment_update: DeploymentUpdate + ) -> DeploymentResponse: + """Update a deployment. + + Args: + deployment_id: The ID of the deployment to update. + deployment_update: The update to apply. + + Returns: + The updated deployment. + + Raises: + KeyError: If the deployment does not exist. + """ + + @abstractmethod + def delete_deployment(self, deployment_id: UUID) -> None: + """Delete a deployment. + + Args: + deployment_id: The ID of the deployment to delete. + + Raises: + KeyError: If the deployment does not exist. + """ + # -------------------- Run templates -------------------- @abstractmethod diff --git a/tests/integration/functional/deployers/server/conftest.py b/tests/integration/functional/deployers/server/conftest.py new file mode 100644 index 00000000000..1695b7d2a20 --- /dev/null +++ b/tests/integration/functional/deployers/server/conftest.py @@ -0,0 +1,31 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Test-specific fixtures for deployment integration tests.""" + +from types import SimpleNamespace +from typing import Generator, Tuple + +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def auto_environment() -> Generator[ + Tuple[SimpleNamespace, SimpleNamespace], None, None +]: + """Override the global auto_environment fixture with a lightweight stub. + + Yields: + The active environment and a client connected with it. 
+ """ + yield SimpleNamespace(), SimpleNamespace() diff --git a/tests/integration/functional/deployers/server/test_app_endpoints.py b/tests/integration/functional/deployers/server/test_app_endpoints.py new file mode 100644 index 00000000000..4a4272e0650 --- /dev/null +++ b/tests/integration/functional/deployers/server/test_app_endpoints.py @@ -0,0 +1,400 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Integration tests for FastAPI deployment application endpoints.""" + +import importlib +from types import ModuleType, SimpleNamespace +from typing import Generator, Optional, Tuple +from uuid import uuid4 + +import pytest +from fastapi.testclient import TestClient +from pydantic import BaseModel + +import zenml.deployers.server.app as deployment_app +from zenml.deployers.server.models import ( + BaseDeploymentInvocationRequest, + BaseDeploymentInvocationResponse, + DeploymentInfo, + DeploymentInvocationResponseMetadata, + ExecutionMetrics, + PipelineInfo, + ServiceInfo, + SnapshotInfo, +) + + +class MockWeatherRequest(BaseModel): + """Mock Pydantic model for testing.""" + + city: str + temperature: int = 20 + + +class StubDeploymentService: + """Stub service implementing the interface used by the FastAPI app.""" + + def __init__(self, deployment_id: str) -> None: + """Initialize the stub service. + + Args: + snapshot_id: The ID of the snapshot to use for the service. + """ + self._healthy = True + self.initialized = False + self.cleaned_up = False + self.last_request: Optional[BaseDeploymentInvocationRequest] = None + self.input_schema = { + "type": "object", + "properties": {"city": {"type": "string"}}, + } + self.output_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + } + self.snapshot = SimpleNamespace( + id=uuid4(), + name="snapshot", + pipeline_configuration=SimpleNamespace( + name="test_pipeline", + environment={}, + init_hook_source=None, + init_hook_kwargs=None, + cleanup_hook_source=None, + ), + pipeline_spec=SimpleNamespace( + parameters={"city": "London"}, + input_schema=self.input_schema, + output_schema=self.output_schema, + ), + ) + self.deployment = SimpleNamespace( + id=uuid4(), + name="deployment", + snapshot=self.snapshot, + ) + + @property + def input_model(self) -> type[BaseModel]: # noqa: D401 + """Expose the request model expected by the service. + + Returns: + The request model expected by the service. + """ + + return MockWeatherRequest + + def initialize(self) -> None: # noqa: D401 + """Mark the service as initialized for verification in tests.""" + + self.initialized = True + + def cleanup(self) -> None: # noqa: D401 + """Mark the service as cleaned up for shutdown assertions.""" + + self.cleaned_up = True + + def is_healthy(self) -> bool: # noqa: D401 + """Return the current health flag used by tests.""" + + return self._healthy + + def set_health(self, healthy: bool) -> None: # noqa: D401 + """Set the health of the service. + + Args: + healthy: The health of the service. 
+ """ + self._healthy = healthy + + def get_service_info(self) -> ServiceInfo: # noqa: D401 + """Retrieve public metadata describing the stub deployment.""" + + return ServiceInfo( + deployment=DeploymentInfo( + id=self.deployment.id, + name=self.deployment.name, + ), + snapshot=SnapshotInfo( + id=self.snapshot.id, name=self.snapshot.name + ), + pipeline=PipelineInfo( + name=self.snapshot.pipeline_configuration.name, + parameters=self.snapshot.pipeline_spec.parameters, + input_schema=self.input_schema, + output_schema=self.output_schema, + ), + total_executions=1, + last_execution_time=None, + status="healthy" if self._healthy else "unhealthy", + uptime=1.0, + ) + + def get_execution_metrics(self) -> ExecutionMetrics: # noqa: D401 + """Return execution metrics describing recent pipeline activity.""" + + return ExecutionMetrics(total_executions=1, last_execution_time=None) + + def execute_pipeline( + self, request: BaseDeploymentInvocationRequest + ) -> BaseDeploymentInvocationResponse: # noqa: D401 + """Execute the pipeline. + + Args: + request: The request to execute the pipeline. + + Returns: + The response from the pipeline. + """ + self.last_request = request + return BaseDeploymentInvocationResponse( + success=True, + outputs={"result": "ok"}, + execution_time=0.5, + metadata=DeploymentInvocationResponseMetadata( + deployment_id=self.deployment.id, + deployment_name=self.deployment.name, + pipeline_name="test_pipeline", + run_id=None, + run_name=None, + parameters_used=request.parameters.model_dump(), + snapshot_id=self.snapshot.id, + snapshot_name=self.snapshot.name, + ), + error=None, + ) + + +@pytest.fixture +def client_service_pair( + monkeypatch: pytest.MonkeyPatch, +) -> Generator[ + Tuple[TestClient, StubDeploymentService, ModuleType], None, None +]: + """Provide a fresh FastAPI client and stub service per test. + + Args: + monkeypatch: The monkeypatch fixture. + + Yields: + A tuple containing the FastAPI client, the stub service, and the reloaded app. + """ + reloaded_app = importlib.reload(deployment_app) + service = StubDeploymentService(str(uuid4())) + + monkeypatch.setenv("ZENML_DEPLOYMENT_ID", str(service.deployment.id)) + monkeypatch.delenv("ZENML_DEPLOYMENT_TEST_MODE", raising=False) + + def _service_factory(_: str) -> StubDeploymentService: + """Factory function for creating a stub service. + + Args: + _: The snapshot ID to use for the service. + + Returns: + The stub service. 
+ """ + return service + + monkeypatch.setattr( + reloaded_app, + "PipelineDeploymentService", + _service_factory, + ) + + with TestClient(reloaded_app.app) as client: + yield client, service, reloaded_app + + +class TestFastAPIAppEndpoints: + """Integration tests for FastAPI application endpoints.""" + + def test_root_endpoint( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Ensure the root endpoint renders the deployment overview.""" + client, service, _ = client_service_pair + response = client.get("/") + assert response.status_code == 200 + assert "ZenML Pipeline Deployment" in response.text + assert "test_pipeline" in response.text + assert service.initialized is True + + def test_health_endpoint_healthy( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Ensure the health endpoint returns OK for healthy services.""" + client, _, _ = client_service_pair + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == "OK" + + def test_health_endpoint_unhealthy( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Return a 503 status when the service reports unhealthy.""" + client, service, _ = client_service_pair + service.set_health(False) + response = client.get("/health") + assert response.status_code == 503 + + def test_info_endpoint( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Expose pipeline and snapshot metadata via /info.""" + client, service, _ = client_service_pair + response = client.get("/info") + assert response.status_code == 200 + data = response.json() + assert data["pipeline"]["name"] == "test_pipeline" + assert data["pipeline"]["input_schema"] == service.input_schema + assert data["snapshot"]["name"] == "snapshot" + + def test_metrics_endpoint( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Surface execution metrics through the metrics endpoint.""" + client, _, _ = client_service_pair + response = client.get("/metrics") + assert response.status_code == 200 + data = response.json() + assert data["total_executions"] == 1 + assert data["last_execution_time"] is None + + def test_invoke_endpoint_success( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Propagate successful execution responses for valid payloads.""" + client, service, _ = client_service_pair + payload = {"parameters": {"city": "Paris", "temperature": 25}} + + response = client.post("/invoke", json=payload) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["outputs"] == {"result": "ok"} + assert service.last_request.parameters.city == "Paris" + + def test_invoke_endpoint_execution_failure( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Propagate failure responses without raising errors.""" + client, service, module = client_service_pair + failure_response = BaseDeploymentInvocationResponse( + success=False, + outputs=None, + execution_time=0.1, + metadata=DeploymentInvocationResponseMetadata( + deployment_id=service.deployment.id, + deployment_name=service.deployment.name, + pipeline_name="test_pipeline", + run_id=None, + run_name=None, + parameters_used={}, + snapshot_id=service.snapshot.id, + 
snapshot_name=service.snapshot.name, + ), + error="Pipeline execution failed", + ) + + service.execute_pipeline = lambda request: failure_response + + response = client.post( + "/invoke", json={"parameters": {"city": "Paris"}} + ) + assert response.status_code == 200 + assert response.json()["success"] is False + + def test_cleanup_called_on_shutdown( + self, + monkeypatch: pytest.MonkeyPatch, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Trigger service cleanup when the application shuts down.""" + reloaded_app = importlib.reload(deployment_app) + service = StubDeploymentService(str(uuid4())) + monkeypatch.setenv("ZENML_DEPLOYMENT_ID", str(service.deployment.id)) + monkeypatch.setattr( + reloaded_app, + "PipelineDeploymentService", + lambda deployment_id: service, + ) + with TestClient(reloaded_app.app): + pass + + assert service.initialized is True + assert service.cleaned_up is True + + +class TestOpenAPIIntegration: + """Integration tests for OpenAPI schema installation.""" + + def test_openapi_includes_invoke_models( + self, + client_service_pair: Tuple[ + TestClient, StubDeploymentService, ModuleType + ], + ) -> None: + """Include invoke request / response models within the OpenAPI schema.""" + client, service, module = client_service_pair + schema = client.get("/openapi.json").json() + operation = schema["paths"]["/invoke"]["post"] + + request_schema = operation["requestBody"]["content"][ + "application/json" + ]["schema"] + if "$ref" in request_schema: + ref = request_schema["$ref"].split("/")[-1] + request_schema = schema["components"]["schemas"][ref] + + parameters_schema = request_schema["properties"]["parameters"] + assert parameters_schema["properties"]["city"]["type"] == "string" + + response_schema = operation["responses"]["200"]["content"][ + "application/json" + ]["schema"] + if "$ref" in response_schema: + ref = response_schema["$ref"].split("/")[-1] + response_schema = schema["components"]["schemas"][ref] + + outputs_schema = response_schema["properties"]["outputs"] + if "$ref" in outputs_schema: + ref = outputs_schema["$ref"].split("/")[-1] + outputs_schema = schema["components"]["schemas"][ref] + + assert outputs_schema["properties"]["result"]["type"] == "string" diff --git a/tests/integration/integrations/aws/deployers/test_app_runner_deployer.py b/tests/integration/integrations/aws/deployers/test_app_runner_deployer.py new file mode 100644 index 00000000000..bdc7c16a3be --- /dev/null +++ b/tests/integration/integrations/aws/deployers/test_app_runner_deployer.py @@ -0,0 +1,99 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ + +import re +from typing import List, Optional, Tuple + +import requests +from bs4 import BeautifulSoup + +from zenml.enums import StackComponentType +from zenml.integrations.aws.flavors import AWSDeployerFlavor +from zenml.integrations.aws.flavors.aws_deployer_flavor import ( + DEFAULT_RESOURCE_COMBINATIONS, +) + + +def test_aws_app_runner_deployer_flavor_attributes(): + """Tests that the basic attributes of the AWS App Runner deployer flavor are set correctly.""" + flavor = AWSDeployerFlavor() + assert flavor.type == StackComponentType.DEPLOYER + assert flavor.name == "aws" + + +DG_URL = "https://docs.aws.amazon.com/apprunner/latest/dg/architecture.html" # table with exact valid pairs + + +def _fetch_documented_supported_resource_combinations( + html: str, +) -> List[Tuple[float, float]]: + """Parse the 'App Runner supported configurations' table into (vCPU, GB) pairs.""" + soup = BeautifulSoup(html, "html.parser") + # Find the section that contains the supported configurations table + # Strategy: locate the heading text, then the first following table + heading = None + for tag in soup.find_all(["h1", "h2", "h3", "h4", "h5", "h6"]): + if "supported configurations" in tag.get_text(strip=True).lower(): + heading = tag + break + if not heading: + raise RuntimeError("Supported configurations heading not found") + + table = heading.find_next("table") + if not table: + raise RuntimeError("Supported configurations table not found") + + rows = [] + for tr in table.find_all("tr"): + cells = [c.get_text(strip=True) for c in tr.find_all(["td", "th"])] + if len(cells) != 2: # Expect 2 columns: CPU | Memory + continue + if cells[0].lower().startswith("cpu"): # skip header row + continue + cpu = _to_vcpu(cells[0]) + mem = _to_gb(cells[1]) + if cpu is not None and mem is not None: + rows.append((cpu, mem)) + + if not rows: + raise RuntimeError("No (CPU, Memory) pairs parsed from table") + # Deduplicate and sort + return sorted(set(rows), key=lambda x: (x[0], x[1])) + + +def _to_vcpu(s: str) -> Optional[float]: + m = re.search(r"([\d.]+)\s*v?CPU", s, re.IGNORECASE) + return float(m.group(1)) if m else None + + +def _to_gb(s: str) -> Optional[float]: + # Accept MB/GB but App Runner uses GB in the table + m = re.search(r"([\d.]+)\s*GB", s, re.IGNORECASE) + if m: + return float(m.group(1)) + m = re.search(r"(\d+)\s*MB", s, re.IGNORECASE) + return (float(m.group(1)) / 1024.0) if m else None + + +def test_aws_app_runner_deployer_flavor_resource_combinations(): + """Tests that the resource combinations of the AWS App Runner deployer flavor are set correctly.""" + dg_html = requests.get(DG_URL, timeout=15).text + supported_combinations = _fetch_documented_supported_resource_combinations( + dg_html + ) + # If this test fails, it is likely because the default resource combinations in the + # AWS App Runner deployer flavor are no longer up to date and need to be + # updated to match + assert DEFAULT_RESOURCE_COMBINATIONS == supported_combinations diff --git a/tests/unit/config/test_compiler.py b/tests/unit/config/test_compiler.py index c1ffd6679b7..6a3976ba5c2 100644 --- a/tests/unit/config/test_compiler.py +++ b/tests/unit/config/test_compiler.py @@ -349,10 +349,11 @@ def test_success_hook_merging( pipeline_instance.configure(on_success=pipeline_hook) step_instance_1.configure(on_success=step_hook) + success_hook_source, _ = resolve_and_validate_hook(step_hook) run_config = PipelineRunConfiguration( steps={ "_empty_step": StepConfigurationUpdate( - success_hook_source=resolve_and_validate_hook(step_hook) + 
success_hook_source=success_hook_source ) }, ) @@ -368,21 +369,20 @@ def test_success_hook_merging( compiled_pipeline_success_hook = ( snapshot.pipeline_configuration.success_hook_source ) - assert compiled_pipeline_success_hook == resolve_and_validate_hook( - pipeline_hook - ) + resolved_hook, _ = resolve_and_validate_hook(pipeline_hook) + assert compiled_pipeline_success_hook == resolved_hook compiled_step_1_success_hook = snapshot.step_configurations[ "_empty_step" ].config.success_hook_source - assert compiled_step_1_success_hook == resolve_and_validate_hook(step_hook) + resolved_hook, _ = resolve_and_validate_hook(step_hook) + assert compiled_step_1_success_hook == resolved_hook compiled_step_2_success_hook = snapshot.step_configurations[ "_empty_step_2" ].config.success_hook_source - assert compiled_step_2_success_hook == resolve_and_validate_hook( - pipeline_hook - ) + resolved_hook, _ = resolve_and_validate_hook(pipeline_hook) + assert compiled_step_2_success_hook == resolved_hook def test_failure_hook_merging( @@ -400,10 +400,11 @@ def test_failure_hook_merging( pipeline_instance.configure(on_failure=pipeline_hook) step_instance_1.configure(on_failure=step_hook) + failure_hook_source, _ = resolve_and_validate_hook(step_hook) run_config = PipelineRunConfiguration( steps={ "_empty_step": StepConfigurationUpdate( - failure_hook_source=resolve_and_validate_hook(step_hook) + failure_hook_source=failure_hook_source ) }, ) @@ -419,21 +420,20 @@ def test_failure_hook_merging( compiled_pipeline_failure_hook = ( snapshot.pipeline_configuration.failure_hook_source ) - assert compiled_pipeline_failure_hook == resolve_and_validate_hook( - pipeline_hook - ) + resolved_hook, _ = resolve_and_validate_hook(pipeline_hook) + assert compiled_pipeline_failure_hook == resolved_hook compiled_step_1_failure_hook = snapshot.step_configurations[ "_empty_step" ].config.failure_hook_source - assert compiled_step_1_failure_hook == resolve_and_validate_hook(step_hook) + resolved_hook, _ = resolve_and_validate_hook(step_hook) + assert compiled_step_1_failure_hook == resolved_hook compiled_step_2_failure_hook = snapshot.step_configurations[ "_empty_step_2" ].config.failure_hook_source - assert compiled_step_2_failure_hook == resolve_and_validate_hook( - pipeline_hook - ) + resolved_hook, _ = resolve_and_validate_hook(pipeline_hook) + assert compiled_step_2_failure_hook == resolved_hook def test_stack_component_settings_for_missing_component_are_ignored( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 036ef4f37e8..bad0c00d28d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -632,6 +632,7 @@ def sample_snapshot_response_model( created=datetime.now(), updated=datetime.now(), runnable=True, + deployable=True, ), metadata=PipelineSnapshotResponseMetadata( run_name_template="", diff --git a/tests/unit/deployers/__init__.py b/tests/unit/deployers/__init__.py new file mode 100644 index 00000000000..a3dbc650c37 --- /dev/null +++ b/tests/unit/deployers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""Unit tests for deployers.""" \ No newline at end of file diff --git a/tests/unit/deployers/server/__init__.py b/tests/unit/deployers/server/__init__.py new file mode 100644 index 00000000000..f1725e01594 --- /dev/null +++ b/tests/unit/deployers/server/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Unit tests for deployment functionality.""" \ No newline at end of file diff --git a/tests/unit/deployers/server/conftest.py b/tests/unit/deployers/server/conftest.py new file mode 100644 index 00000000000..64041fd320a --- /dev/null +++ b/tests/unit/deployers/server/conftest.py @@ -0,0 +1,16 @@ +"""Test-specific fixtures for deployment unit tests.""" + +from types import SimpleNamespace +from typing import Iterator, Tuple + +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def auto_environment() -> Iterator[Tuple[SimpleNamespace, SimpleNamespace]]: + """Override the global auto_environment fixture with a lightweight stub. + + Yields: + The active environment and a connected client stub. + """ + yield SimpleNamespace(), SimpleNamespace() diff --git a/tests/unit/deployers/server/test_app.py b/tests/unit/deployers/server/test_app.py new file mode 100644 index 00000000000..eb22fc9ab41 --- /dev/null +++ b/tests/unit/deployers/server/test_app.py @@ -0,0 +1,400 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Unit tests for deployment app functionality.""" + +from __future__ import annotations + +import asyncio +import json +from typing import cast +from uuid import uuid4 + +import pytest +from fastapi import FastAPI, HTTPException, Request +from fastapi.security import HTTPAuthorizationCredentials +from fastapi.testclient import TestClient +from pydantic import BaseModel +from pytest_mock import MockerFixture + +from zenml.deployers.server.app import ( + _build_invoke_router, + app, + get_pipeline_service, + lifespan, + runtime_error_handler, + value_error_handler, + verify_token, +) +from zenml.deployers.server.models import ( + BaseDeploymentInvocationResponse, + DeploymentInfo, + DeploymentInvocationResponseMetadata, + ExecutionMetrics, + PipelineInfo, + ServiceInfo, + SnapshotInfo, +) +from zenml.deployers.server.service import PipelineDeploymentService + + +class MockWeatherRequest(BaseModel): + """Mock Pydantic model for testing.""" + + city: str + temperature: int = 20 + + +@pytest.fixture +def mock_service(mocker: MockerFixture) -> PipelineDeploymentService: + """Mock pipeline deployment service configured for the app tests.""" + + service = cast( + PipelineDeploymentService, + mocker.MagicMock(spec=PipelineDeploymentService), + ) + snapshot_id = uuid4() + deployment_id = uuid4() + + service.input_model = MockWeatherRequest + service.is_healthy.return_value = True + service.input_schema = { + "type": "object", + "properties": {"city": {"type": "string"}}, + } + service.output_schema = { + "type": "object", + "properties": {"result": {"type": "string"}}, + } + + service.get_service_info.return_value = ServiceInfo( + deployment=DeploymentInfo(id=deployment_id, name="deployment"), + snapshot=SnapshotInfo(id=snapshot_id, name="snapshot"), + pipeline=PipelineInfo( + name="test_pipeline", + parameters={"city": "London"}, + input_schema=service.input_schema, + output_schema=service.output_schema, + ), + total_executions=3, + last_execution_time=None, + status="healthy", + uptime=12.34, + ) + service.get_execution_metrics.return_value = ExecutionMetrics( + total_executions=3, + last_execution_time=None, + ) + service.execute_pipeline.return_value = BaseDeploymentInvocationResponse( + success=True, + outputs={"result": "ok"}, + execution_time=0.5, + metadata=DeploymentInvocationResponseMetadata( + deployment_id=deployment_id, + deployment_name="deployment", + pipeline_name="test_pipeline", + run_id=None, + run_name=None, + parameters_used={"city": "Paris", "temperature": 25}, + snapshot_id=snapshot_id, + snapshot_name="snapshot", + ), + error=None, + ) + return service + + +class TestDeploymentAppRoutes: + """Test FastAPI app routes.""" + + def test_root_endpoint( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Root endpoint returns HTML with pipeline information.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + monkeypatch.setattr( + "zenml.deployers.server.app._service", mock_service + ) + with TestClient(app) as client: + response = client.get("/") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/html") + assert "ZenML Pipeline Deployment" in response.text + assert "test_pipeline" in response.text + + def test_health_endpoint( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Health endpoint returns OK when service is healthy.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + monkeypatch.setattr( + 
"zenml.deployers.server.app._service", mock_service + ) + with TestClient(app) as client: + response = client.get("/health") + + assert response.status_code == 200 + assert response.json() == "OK" + + def test_health_endpoint_unhealthy( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Health endpoint raises when service reports unhealthy state.""" + mock_service.is_healthy.return_value = False + + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + monkeypatch.setattr( + "zenml.deployers.server.app._service", mock_service + ) + with TestClient(app) as client: + response = client.get("/health") + + assert response.status_code == 503 + assert response.json()["detail"] == "Service is unhealthy" + + def test_info_endpoint( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Info endpoint returns service metadata.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + monkeypatch.setattr( + "zenml.deployers.server.app._service", mock_service + ) + with TestClient(app) as client: + response = client.get("/info") + + assert response.status_code == 200 + data = response.json() + assert data["pipeline"]["name"] == "test_pipeline" + assert data["pipeline"]["parameters"] == {"city": "London"} + assert data["status"] == "healthy" + assert data["snapshot"]["name"] == "snapshot" + + def test_metrics_endpoint( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Metrics endpoint exposes execution metrics.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + monkeypatch.setattr( + "zenml.deployers.server.app._service", mock_service + ) + with TestClient(app) as client: + response = client.get("/metrics") + + assert response.status_code == 200 + data = response.json() + assert data["total_executions"] == 3 + assert data["last_execution_time"] is None + + def test_info_endpoint_includes_schemas( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Info endpoint includes input/output schemas.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + monkeypatch.setattr( + "zenml.deployers.server.app._service", mock_service + ) + with TestClient(app) as client: + response = client.get("/info") + + data = response.json() + assert data["pipeline"]["input_schema"] == mock_service.input_schema + assert data["pipeline"]["output_schema"] == mock_service.output_schema + + def test_get_pipeline_service_returns_current_instance( + self, + mock_service: PipelineDeploymentService, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Ensure get_pipeline_service exposes the underlying instance.""" + monkeypatch.setattr( + "zenml.deployers.server.app._service", mock_service + ) + assert get_pipeline_service() is mock_service + + +class TestDeploymentAppInvoke: + """Test pipeline invocation via FastAPI.""" + + def test_invoke_endpoint_executes_service( + self, mock_service: PipelineDeploymentService + ) -> None: + """Invoke router validates payloads and calls the service.""" + fast_app = FastAPI() + fast_app.include_router(_build_invoke_router(mock_service)) + + with TestClient(fast_app) as client: + payload = {"parameters": {"city": "Paris", "temperature": 25}} + response = client.post("/invoke", json=payload) + + assert response.status_code == 200 + assert response.json()["success"] is True + mock_service.execute_pipeline.assert_called_once() + request_arg = 
mock_service.execute_pipeline.call_args.args[0] + assert request_arg.parameters.city == "Paris" + assert request_arg.skip_artifact_materialization is False + + def test_invoke_endpoint_validation_error( + self, mock_service: PipelineDeploymentService + ) -> None: + """Invalid payloads trigger validation errors.""" + fast_app = FastAPI() + fast_app.include_router(_build_invoke_router(mock_service)) + + with TestClient(fast_app) as client: + response = client.post("/invoke", json={"parameters": {}}) + + assert response.status_code == 422 + mock_service.execute_pipeline.assert_not_called() + + def test_verify_token_with_auth_enabled( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Token verification when authentication is enabled.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_AUTH_KEY", "test-auth-key") + + credentials = HTTPAuthorizationCredentials( + scheme="Bearer", credentials="test-auth-key" + ) + assert verify_token(credentials) is None + + with pytest.raises(HTTPException): + verify_token( + HTTPAuthorizationCredentials( + scheme="Bearer", credentials="wrong" + ) + ) + + with pytest.raises(HTTPException): + verify_token(None) + + def test_verify_token_with_auth_disabled( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Token verification when authentication is disabled.""" + monkeypatch.delenv("ZENML_DEPLOYMENT_AUTH_KEY", raising=False) + assert verify_token(None) is None + + +class TestDeploymentAppLifecycle: + """Test app lifecycle management.""" + + def test_lifespan_test_mode(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Lifespan exits early in test mode.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") + + async def _run() -> None: + async with lifespan(app): + pass + + asyncio.run(_run()) + + def test_lifespan_normal_mode( + self, + monkeypatch: pytest.MonkeyPatch, + mocker: MockerFixture, + ) -> None: + """Lifespan initializes and cleans up service in normal mode.""" + monkeypatch.setenv("ZENML_DEPLOYMENT_ID", "test-deployment-id") + + mock_service = cast( + PipelineDeploymentService, + mocker.MagicMock(spec=PipelineDeploymentService), + ) + mock_service.input_model = MockWeatherRequest + mock_service.initialize = mocker.MagicMock() + mock_service.cleanup = mocker.MagicMock() + + mocker.patch( + "zenml.deployers.server.app.PipelineDeploymentService", + return_value=mock_service, + ) + mock_include = mocker.patch.object(app, "include_router") + + async def _run() -> None: + async with lifespan(app): + pass + + asyncio.run(_run()) + + mock_include.assert_called() + mock_service.initialize.assert_called_once() + mock_service.cleanup.assert_called_once() + + def test_lifespan_missing_snapshot_id( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Lifespan raises when no snapshot id is configured.""" + monkeypatch.delenv("ZENML_DEPLOYMENT_ID", raising=False) + + async def _run() -> None: + with pytest.raises(ValueError, match="ZENML_DEPLOYMENT_ID"): + async with lifespan(app): + pass + + asyncio.run(_run()) + + +class TestDeploymentAppErrorHandling: + """Test app error handling.""" + + def test_value_error_handler(self) -> None: + """ValueError exception handler returns 400 with message.""" + request = Request( + {"type": "http", "method": "POST", "url": "http://test"} + ) + error = ValueError("Test error") + + response = value_error_handler(request, error) + assert response.status_code == 400 + payload = json.loads(response.body) + assert payload["detail"] == "Test error" + + def test_runtime_error_handler(self) -> None: + """RuntimeError 
exception handler returns 500 with message.""" + request = Request( + {"type": "http", "method": "POST", "url": "http://test"} + ) + error = RuntimeError("Runtime error") + + response = runtime_error_handler(request, error) + assert response.status_code == 500 + payload = json.loads(response.body) + assert payload["detail"] == "Runtime error" + + +class TestBuildInvokeRouter: + """Test the invoke router building functionality.""" + + def test_build_invoke_router( + self, mock_service: PipelineDeploymentService + ) -> None: + """Building the invoke router exposes /invoke route.""" + router = _build_invoke_router(mock_service) + + assert router is not None + routes = [route.path for route in router.routes] + assert "/invoke" in routes diff --git a/tests/unit/deployers/server/test_parameter_flow.py b/tests/unit/deployers/server/test_parameter_flow.py new file mode 100644 index 00000000000..e65cc204555 --- /dev/null +++ b/tests/unit/deployers/server/test_parameter_flow.py @@ -0,0 +1,109 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Comprehensive test for parameter resolution and flow in deployment.""" + +from unittest.mock import MagicMock + +import pytest + +from zenml.deployers.server import runtime + + +class TestOutputRecording: + """Test output recording and retrieval functionality.""" + + @pytest.fixture(autouse=True) + def setup_deployment_state(self): + """Set up deployment state for each test.""" + runtime.stop() + yield + runtime.stop() + + def test_record_and_get_outputs(self): + """Test recording and retrieving step outputs.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={"param": "value"}, + ) + + # Record some outputs + runtime.record_step_outputs( + "step1", {"result": "output1", "score": 0.95} + ) + runtime.record_step_outputs("step2", {"prediction": "class_a"}) + + # Retrieve all outputs + all_outputs = runtime.get_outputs() + + assert "step1" in all_outputs + assert "step2" in all_outputs + assert all_outputs["step1"]["result"] == "output1" + assert all_outputs["step1"]["score"] == 0.95 + assert all_outputs["step2"]["prediction"] == "class_a" + + def test_record_outputs_inactive_context(self): + """Test that recording does nothing when context is inactive.""" + # Don't start context + runtime.record_step_outputs("step1", {"result": "output1"}) + + # Should not record anything + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", snapshot=snapshot, parameters={} + ) + + outputs = runtime.get_outputs() + assert outputs == {} + + def test_record_empty_outputs(self): + """Test recording empty outputs.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", snapshot=snapshot, parameters={} + ) + + # Record empty outputs + runtime.record_step_outputs("step1", {}) + runtime.record_step_outputs("step2", None) + + outputs = 
runtime.get_outputs() + assert outputs == {} + + def test_multiple_output_updates(self): + """Test multiple updates to same step outputs.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", snapshot=snapshot, parameters={} + ) + + # Record outputs in multiple calls + runtime.record_step_outputs("step1", {"result": "first"}) + runtime.record_step_outputs("step1", {"score": 0.8}) + runtime.record_step_outputs( + "step1", {"result": "updated"} + ) # Should overwrite + + outputs = runtime.get_outputs() + assert outputs["step1"]["result"] == "updated" + assert outputs["step1"]["score"] == 0.8 diff --git a/tests/unit/deployers/server/test_runtime.py b/tests/unit/deployers/server/test_runtime.py new file mode 100644 index 00000000000..494604eb044 --- /dev/null +++ b/tests/unit/deployers/server/test_runtime.py @@ -0,0 +1,288 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Unit tests for deployment runtime context management.""" + +from unittest.mock import MagicMock + +import pytest + +from zenml.deployers.server import runtime + + +class TestDeploymentRuntimeContext: + """Test deployment runtime context management.""" + + @pytest.fixture(autouse=True) + def setup_runtime(self): + """Ensure clean runtime state before each test.""" + runtime.stop() + yield + runtime.stop() + + def test_context_lifecycle(self): + """Test basic context start/stop lifecycle.""" + assert not runtime.is_active() + + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + # Start context + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={"city": "Berlin", "temperature": 25}, + skip_artifact_materialization=True, + ) + + assert runtime.is_active() + + # Stop context + runtime.stop() + + assert not runtime.is_active() + + def test_step_outputs_recording(self): + """Test step outputs recording and retrieval.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + ) + + # Record some outputs + runtime.record_step_outputs( + "step1", {"result": "output1", "score": 0.95} + ) + runtime.record_step_outputs("step2", {"prediction": "class_a"}) + + # Retrieve all outputs + all_outputs = runtime.get_outputs() + + assert "step1" in all_outputs + assert "step2" in all_outputs + assert all_outputs["step1"]["result"] == "output1" + assert all_outputs["step1"]["score"] == 0.95 + assert all_outputs["step2"]["prediction"] == "class_a" + + def test_step_outputs_inactive_context(self): + """Test that recording does nothing when context is inactive.""" + # Don't start context + runtime.record_step_outputs("step1", {"result": "output1"}) + + # Start context and check - should be empty + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + ) + + outputs = runtime.get_outputs() + assert outputs == {} + + def 
test_step_outputs_empty_data(self): + """Test recording empty outputs.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + ) + + # Record empty outputs + runtime.record_step_outputs("step1", {}) + runtime.record_step_outputs("step2", None) + + outputs = runtime.get_outputs() + assert outputs == {} + + def test_step_outputs_multiple_updates(self): + """Test multiple updates to same step outputs.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + ) + + # Record outputs in multiple calls + runtime.record_step_outputs("step1", {"result": "first"}) + runtime.record_step_outputs("step1", {"score": 0.8}) + runtime.record_step_outputs( + "step1", {"result": "updated"} + ) # Should overwrite + + outputs = runtime.get_outputs() + assert outputs["step1"]["result"] == "updated" + assert outputs["step1"]["score"] == 0.8 + + def test_in_memory_data_storage(self): + """Test in-memory data storage and retrieval.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + ) + + # Store some data + runtime.put_in_memory_data("memory://artifact/1", {"data": "value1"}) + runtime.put_in_memory_data("memory://artifact/2", "string_value") + + # Retrieve data + assert runtime.get_in_memory_data("memory://artifact/1") == { + "data": "value1" + } + assert ( + runtime.get_in_memory_data("memory://artifact/2") == "string_value" + ) + assert runtime.get_in_memory_data("memory://missing") is None + + def test_in_memory_data_inactive_context(self): + """Test in-memory data operations when context is inactive.""" + # Don't start context + runtime.put_in_memory_data("memory://artifact/1", {"data": "value"}) + + # Should not store anything + assert runtime.get_in_memory_data("memory://artifact/1") is None + + def test_context_isolation(self): + """Test that multiple contexts don't interfere with each other.""" + snapshot1 = MagicMock() + snapshot1.id = "snapshot-1" + + snapshot2 = MagicMock() + snapshot2.id = "snapshot-2" + + # Start first context + runtime.start( + request_id="request-1", + snapshot=snapshot1, + parameters={"city": "Berlin"}, + ) + + runtime.record_step_outputs("step1", {"result": "berlin_result"}) + runtime.put_in_memory_data("memory://artifact/1", "berlin_data") + + # Verify first context state + assert runtime.get_outputs()["step1"]["result"] == "berlin_result" + assert ( + runtime.get_in_memory_data("memory://artifact/1") == "berlin_data" + ) + + # Stop first context + runtime.stop() + + # Start second context + runtime.start( + request_id="request-2", + snapshot=snapshot2, + parameters={"city": "Munich"}, + ) + + # Should have clean state + assert runtime.get_outputs() == {} + assert runtime.get_in_memory_data("memory://artifact/1") is None + + def test_skip_artifact_materialization_setting(self): + """Test skip_artifact_materialization setting functionality.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + # Test with skip_artifact_materialization=True + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + skip_artifact_materialization=True, + ) + + assert runtime.should_skip_artifact_materialization() is True + assert runtime.should_skip_artifact_materialization() is True + + runtime.stop() + + # Test with skip_artifact_materialization=False + runtime.start( + 
request_id="test-request", + snapshot=snapshot, + parameters={}, + skip_artifact_materialization=False, + ) + + assert runtime.should_skip_artifact_materialization() is False + + runtime.stop() + + # Test with skip_artifact_materialization=False (default) + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={}, + ) + + assert runtime.should_skip_artifact_materialization() is False + + def test_skip_artifact_materialization_inactive_context(self): + """Test skip_artifact_materialization functions when context is inactive.""" + assert runtime.should_skip_artifact_materialization() is False + + def test_context_reset_clears_all_data(self): + """Test that context reset clears all stored data.""" + snapshot = MagicMock() + snapshot.id = "test-snapshot" + + runtime.start( + request_id="test-request", + snapshot=snapshot, + parameters={"city": "Berlin"}, + skip_artifact_materialization=True, + ) + + # Store various types of data + runtime.record_step_outputs("step1", {"result": "output"}) + runtime.put_in_memory_data("memory://artifact/1", "data") + + # Verify data is stored + assert runtime.is_active() + assert runtime.get_outputs() != {} + assert runtime.should_skip_artifact_materialization() is True + + # Stop context (triggers reset) + runtime.stop() + + # Verify everything is cleared + assert not runtime.is_active() + + # Start new context to verify clean state + runtime.start( + request_id="new-request", + snapshot=snapshot, + parameters={}, + ) + + assert runtime.get_outputs() == {} + assert runtime.get_in_memory_data("memory://artifact/1") is None + assert runtime.should_skip_artifact_materialization() is False diff --git a/tests/unit/deployers/server/test_service.py b/tests/unit/deployers/server/test_service.py new file mode 100644 index 00000000000..7878aa1781e --- /dev/null +++ b/tests/unit/deployers/server/test_service.py @@ -0,0 +1,340 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Unit tests for the pipeline deployment service.""" + +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from typing import Dict, Iterator +from uuid import UUID, uuid4 + +import pytest +from pydantic import BaseModel +from pytest_mock import MockerFixture + +from zenml.deployers.server.models import BaseDeploymentInvocationRequest +from zenml.deployers.server.service import PipelineDeploymentService + + +class WeatherParams(BaseModel): + """Minimal parameter model used for service tests.""" + + city: str + temperature: int = 20 + + +def _make_snapshot() -> SimpleNamespace: + """Create a snapshot stub with the attributes accessed by the service.""" + + pipeline_configuration = SimpleNamespace( + name="test_pipeline", + environment={}, + init_hook_source=None, + init_hook_kwargs={}, + cleanup_hook_source=None, + ) + pipeline_spec = SimpleNamespace( + parameters={"city": "London"}, + input_schema={"type": "object"}, + output_schema={"type": "object"}, + outputs=[], + source="test.module.pipeline", + ) + stack = SimpleNamespace(name="test_stack") + + return SimpleNamespace( + id=uuid4(), + name="snapshot", + pipeline_configuration=pipeline_configuration, + pipeline_spec=pipeline_spec, + step_configurations={}, + stack=stack, + ) + + +def _make_deployment() -> SimpleNamespace: + """Create a deployment stub with the attributes accessed by the service.""" + return SimpleNamespace( + id=uuid4(), name="deployment", snapshot=_make_snapshot() + ) + + +def _make_service_stub(mocker: MockerFixture) -> PipelineDeploymentService: + """Create a service instance without running __init__ for isolated tests.""" + deployment = _make_deployment() + service = PipelineDeploymentService.__new__(PipelineDeploymentService) + service._client = mocker.MagicMock() + service._orchestrator = mocker.MagicMock() + mocker.patch.object( + type(service), + "input_model", + new_callable=mocker.PropertyMock, + return_value=WeatherParams, + ) + service.service_start_time = 100.0 + service.last_execution_time = None + service.total_executions = 0 + service.deployment = deployment + service.snapshot = deployment.snapshot + return service + + +def test_initialization_loads_deployment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """__init__ should load the deployment from the store.""" + deployment = _make_deployment() + + class DummyZenStore: + """In-memory zen store stub that records requested snapshot IDs.""" + + def __init__(self) -> None: + self.requested_snapshot_id: UUID | None = None + self.requested_deployment_id: UUID | None = None + + def get_snapshot(self, snapshot_id: UUID) -> SimpleNamespace: # noqa: D401 + """Return the stored snapshot and remember the requested ID.""" + + self.requested_snapshot_id = snapshot_id + return deployment.snapshot + + def get_deployment(self, deployment_id: UUID) -> SimpleNamespace: # noqa: D401 + """Return the stored deployment and remember the requested ID.""" + + self.requested_deployment_id = deployment_id + return deployment + + dummy_store = DummyZenStore() + + class DummyClient: + """Client stub providing access to the dummy zen store.""" + + def __init__(self) -> None: + self.zen_store = dummy_store + + monkeypatch.setattr("zenml.deployers.server.service.Client", DummyClient) + + service = PipelineDeploymentService(deployment.id) + + assert service.deployment is deployment + assert service.snapshot is deployment.snapshot + assert dummy_store.requested_deployment_id == deployment.id + assert 
dummy_store.requested_snapshot_id is None + + +def test_initialize_sets_up_orchestrator( + monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture +) -> None: + """initialize should activate integrations and build orchestrator.""" + deployment = _make_deployment() + + class DummyZenStore: + """Zen store stub that supplies the prepared snapshot.""" + + def get_snapshot(self, snapshot_id: UUID) -> SimpleNamespace: # noqa: D401 + return deployment.snapshot + + def get_deployment(self, deployment_id: UUID) -> SimpleNamespace: # noqa: D401 + return deployment + + class DummyClient: + """Client stub exposing only the attributes required by the service.""" + + def __init__(self) -> None: + self.zen_store = DummyZenStore() + + monkeypatch.setattr("zenml.deployers.server.service.Client", DummyClient) + + mock_orchestrator = mocker.MagicMock() + monkeypatch.setattr( + "zenml.deployers.server.service.SharedLocalOrchestrator", + mocker.MagicMock(return_value=mock_orchestrator), + ) + + @contextmanager + def _noop_env(_: object) -> Iterator[None]: + """Provide a no-op temporary environment context manager for tests.""" + + yield + + monkeypatch.setattr( + "zenml.deployers.server.service.env_utils.temporary_environment", + _noop_env, + ) + + service = PipelineDeploymentService(uuid4()) + service.initialize() + + assert service._orchestrator is mock_orchestrator + + +def test_execute_pipeline_calls_subroutines(mocker: MockerFixture) -> None: + """execute_pipeline should orchestrate helper methods and return response.""" + service = _make_service_stub(mocker) + + placeholder_run = mocker.MagicMock() + deployment_snapshot = mocker.MagicMock() + captured_outputs: Dict[str, Dict[str, object]] = { + "step1": {"result": "value"} + } + mapped_outputs = {"result": "value"} + + service._prepare_execute_with_orchestrator = mocker.MagicMock( + return_value=(placeholder_run, deployment_snapshot) + ) + service._execute_with_orchestrator = mocker.MagicMock( + return_value=captured_outputs + ) + service._map_outputs = mocker.MagicMock(return_value=mapped_outputs) + service._build_response = mocker.MagicMock(return_value="response") + + request = BaseDeploymentInvocationRequest( + parameters=WeatherParams(city="Berlin") + ) + result = service.execute_pipeline(request) + + assert result == "response" + service._prepare_execute_with_orchestrator.assert_called_once_with( + resolved_params={"city": "Berlin", "temperature": 20} + ) + service._execute_with_orchestrator.assert_called_once_with( + placeholder_run=placeholder_run, + deployment_snapshot=deployment_snapshot, + resolved_params={"city": "Berlin", "temperature": 20}, + skip_artifact_materialization=False, + ) + service._map_outputs.assert_called_once_with(captured_outputs) + service._build_response.assert_called_once() + + +def test_map_outputs_returns_filtered_mapping(mocker: MockerFixture) -> None: + """_map_outputs should align runtime outputs to pipeline spec.""" + service = _make_service_stub(mocker) + service.snapshot.pipeline_spec.outputs = [ + SimpleNamespace(step_name="trainer", output_name="model"), + SimpleNamespace(step_name="trainer", output_name="metrics"), + SimpleNamespace(step_name="evaluator", output_name="report"), + ] + + runtime_outputs = { + "trainer": {"model": "model-artifact", "metrics": {"f1": 0.9}}, + "evaluator": {"report": "report-artifact"}, + } + + mapped = service._map_outputs(runtime_outputs) + assert mapped == { + "model": "model-artifact", + "metrics": {"f1": 0.9}, + "report": "report-artifact", + } + + +def 
test_map_outputs_handles_missing_data(mocker: MockerFixture) -> None: + """_map_outputs should return empty dict when no runtime outputs.""" + service = _make_service_stub(mocker) + + assert service._map_outputs(None) == {} + + +def test_build_response_success( + monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture +) -> None: + """_build_response should return a successful response payload.""" + service = _make_service_stub(mocker) + + placeholder_run = mocker.MagicMock() + placeholder_run.id = uuid4() + placeholder_run.name = "placeholder" + + resolved_params = {"city": "Berlin", "temperature": 20} + + concrete_run = mocker.MagicMock() + concrete_run.id = uuid4() + concrete_run.name = "run" + service._client.get_pipeline_run.return_value = concrete_run + + monkeypatch.setattr( + "zenml.deployers.server.service.time.time", lambda: 110.0 + ) + + response = service._build_response( + resolved_params=resolved_params, + start_time=100.0, + mapped_outputs={"result": "value"}, + placeholder_run=placeholder_run, + ) + + assert response.success is True + assert response.outputs == {"result": "value"} + assert response.metadata.pipeline_name == "test_pipeline" + assert response.metadata.run_id == concrete_run.id + assert service.total_executions == 1 + assert service.last_execution_time is not None + + +def test_build_response_error( + monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture +) -> None: + """_build_response should capture errors and omit outputs.""" + service = _make_service_stub(mocker) + + monkeypatch.setattr( + "zenml.deployers.server.service.time.time", lambda: 105.0 + ) + + response = service._build_response( + resolved_params={"city": "Berlin"}, + start_time=100.0, + mapped_outputs=None, + placeholder_run=None, + error=RuntimeError("failure"), + ) + + assert response.success is False + assert response.outputs is None + assert response.error == "failure" + + +def test_get_service_info_aggregates_snapshot( + mocker: MockerFixture, +) -> None: + """get_service_info should expose pipeline metadata and schemas.""" + service = _make_service_stub(mocker) + + info = service.get_service_info() + + assert info.pipeline.name == "test_pipeline" + assert info.snapshot.id == service.snapshot.id + assert info.pipeline.parameters == {"city": "London"} + assert info.pipeline.input_schema == {"type": "object"} + + +def test_execution_metrics_reflect_counters(mocker: MockerFixture) -> None: + """get_execution_metrics should return counters from service state.""" + service = _make_service_stub(mocker) + service.total_executions = 5 + service.last_execution_time = None + + metrics = service.get_execution_metrics() + assert metrics.total_executions == 5 + assert metrics.last_execution_time is None + + +def test_input_output_schema_properties(mocker: MockerFixture) -> None: + """input_schema and output_schema expose snapshot schemas.""" + service = _make_service_stub(mocker) + + assert service.input_schema == {"type": "object"} + assert service.output_schema == {"type": "object"} diff --git a/tests/unit/deployers/server/test_service_outputs.py b/tests/unit/deployers/server/test_service_outputs.py new file mode 100644 index 00000000000..3cc14811021 --- /dev/null +++ b/tests/unit/deployers/server/test_service_outputs.py @@ -0,0 +1,152 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at:
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""Unit tests for PipelineDeploymentService output mapping with in-memory mode."""
+
+from types import SimpleNamespace
+from typing import Generator
+from uuid import uuid4
+
+import pytest
+from pydantic import BaseModel
+from pytest_mock import MockerFixture
+
+from zenml.deployers.server import runtime
+from zenml.deployers.server.models import BaseDeploymentInvocationRequest
+from zenml.deployers.server.service import PipelineDeploymentService
+
+
+class _DummyParams(BaseModel):
+    """Minimal params model to bypass real pipeline loading."""
+
+    city: str = "Berlin"
+
+
+class _DummySnapshot:
+    def __init__(self) -> None:
+        self.id = uuid4()
+        self.name = "snapshot"
+        self.pipeline_configuration = SimpleNamespace(
+            name="test-pipeline",
+            environment={},
+            init_hook_source=None,
+            init_hook_kwargs=None,
+            cleanup_hook_source=None,
+        )
+        self.pipeline_spec = SimpleNamespace(
+            parameters={},
+            input_schema=None,
+            output_schema=None,
+            outputs=[SimpleNamespace(step_name="step1", output_name="result")],
+        )
+        self.step_configurations = {}
+        self.stack = SimpleNamespace(name="test-stack")
+
+
+class _DummyDeployment:
+    def __init__(self) -> None:
+        self.id = uuid4()
+        self.name = "test-deployment"
+        self.snapshot = _DummySnapshot()
+
+
+class _DummyRun:
+    def __init__(self) -> None:
+        self.id = uuid4()
+        self.name = "test-run"
+
+
+@pytest.fixture(autouse=True)
+def clean_runtime_state() -> Generator[None, None, None]:
+    """Ensure runtime state is reset before and after each test."""
+
+    runtime.stop()
+    yield
+    runtime.stop()
+
+
+def _make_service(
+    monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture
+) -> PipelineDeploymentService:
+    """Construct a deployment service instance backed by dummy artifacts."""
+
+    deployment = _DummyDeployment()
+
+    class DummyZenStore:
+        """Return the snapshot associated with the provided ID."""
+
+        def get_deployment(self, deployment_id: object) -> _DummyDeployment:  # noqa: D401
+            return deployment
+
+        def get_snapshot(self, snapshot_id: object) -> _DummySnapshot:  # noqa: D401
+            return deployment.snapshot
+
+        def create_snapshot(self, request: object) -> _DummySnapshot:  # noqa: D401
+            """Return the snapshot that would be created in the real store."""
+
+            return deployment.snapshot
+
+    class DummyClient:
+        """Client stub exposing zen_store and active stack attributes."""
+
+        def __init__(self) -> None:
+            self.zen_store = DummyZenStore()
+            self.active_stack = mocker.MagicMock()
+
+        def get_pipeline_run(
+            self, *args: object, **kwargs: object
+        ) -> _DummyRun:  # noqa: D401
+            """Return a dummy pipeline run."""
+
+            return _DummyRun()
+
+    monkeypatch.setattr("zenml.deployers.server.service.Client", DummyClient)
+
+    service = PipelineDeploymentService(uuid4())
+    service.params_model = _DummyParams
+    return service
+
+
+def test_service_captures_in_memory_outputs(
+    monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture
+) -> None:
+    """Service should capture in-memory outputs before stopping runtime."""
+
+    service = _make_service(monkeypatch, mocker)
+
+    placeholder_run = _DummyRun()
+    monkeypatch.setattr(
+        "zenml.deployers.server.service.run_utils.create_placeholder_run",
+        mocker.MagicMock(return_value=placeholder_run),
+    )
+    monkeypatch.setattr(
+        "zenml.orchestrators.utils.deployment_snapshot_request_from_source_snapshot",
+        lambda source_snapshot, deployment_parameters: SimpleNamespace(),
+    )
+
+    class _DummyOrchestrator:
+        def run(self, snapshot, stack, placeholder_run):  # noqa: D401
+            runtime.record_step_outputs("step1", {"result": "fast_value"})
+
+    service._orchestrator = _DummyOrchestrator()
+
+    request = BaseDeploymentInvocationRequest(
+        parameters=_DummyParams(),
+        skip_artifact_materialization=True,
+    )
+
+    response = service.execute_pipeline(request)
+
+    assert response.success is True
+    assert response.outputs == {"result": "fast_value"}
+    assert service.total_executions == 1
diff --git a/tests/unit/materializers/test_in_memory_materializer.py b/tests/unit/materializers/test_in_memory_materializer.py
new file mode 100644
index 00000000000..a600a21fa2d
--- /dev/null
+++ b/tests/unit/materializers/test_in_memory_materializer.py
@@ -0,0 +1,74 @@
+"""Unit tests for the simple in-memory materializer."""
+
+from typing import Any, Dict
+
+from zenml.materializers.in_memory_materializer import InMemoryMaterializer
+
+
+def test_in_memory_materializer_uses_runtime(monkeypatch) -> None:
+    """Verify that the materializer stores and loads data via the runtime."""
+
+    stored: Dict[str, Any] = {}
+
+    # Patch the deployment runtime helpers used by the materializer.
+    from zenml.deployers.server import runtime
+
+    monkeypatch.setattr(runtime, "is_active", lambda: True)
+    monkeypatch.setattr(
+        runtime, "should_skip_artifact_materialization", lambda: True
+    )
+    monkeypatch.setattr(runtime, "put_in_memory_data", stored.__setitem__)
+    monkeypatch.setattr(runtime, "get_in_memory_data", stored.get)
+
+    # Simple approach - no wrapping needed
+    materializer = InMemoryMaterializer(
+        uri="s3://bucket/artifact", artifact_store=None
+    )
+
+    payload = {"foo": "bar"}
+    materializer.save(payload)
+
+    # Data should be stored with original URI as key
+    assert stored["s3://bucket/artifact"] == payload
+
+    loaded = materializer.load(dict)
+    assert loaded == payload
+
+
+def test_in_memory_materializer_metadata_methods() -> None:
+    """Test that metadata methods return empty results in deployment mode."""
+
+    materializer = InMemoryMaterializer(
+        uri="s3://bucket/artifact", artifact_store=None
+    )
+
+    # All metadata methods should return empty/None in deployment mode
+    assert materializer.extract_full_metadata({}) == {}
+    assert materializer.compute_content_hash({}) is None
+    assert materializer.save_visualizations({}) == {}
+
+
+def test_in_memory_materializer_missing_data() -> None:
+    """Test that loading missing data raises appropriate error."""
+
+    from zenml.deployers.server import runtime
+
+    materializer = InMemoryMaterializer(
+        uri="s3://missing/artifact", artifact_store=None
+    )
+
+    # Mock runtime to return None for missing data
+    def mock_get_data(uri):
+        return None
+
+    import unittest.mock
+
+    with unittest.mock.patch.object(
+        runtime, "get_in_memory_data", mock_get_data
+    ):
+        try:
+            materializer.load(dict)
+            assert False, "Should have raised RuntimeError"
+        except RuntimeError as e:
+            assert "No data available" in str(e)
+            assert "s3://missing/artifact" in str(e)
diff --git a/tests/unit/orchestrators/test_step_runner_outputs.py b/tests/unit/orchestrators/test_step_runner_outputs.py
new file mode 100644
index 00000000000..ad0f815b373
--- /dev/null
+++ b/tests/unit/orchestrators/test_step_runner_outputs.py
@@ -0,0 +1,116 @@
+"""Unit tests for StepRunner output artifact persistence behavior."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Tuple, Type
+from unittest.mock import MagicMock
+
+import pytest
+
+from zenml.enums import ArtifactType
+from zenml.materializers.base_materializer import BaseMaterializer
+from zenml.orchestrators.step_runner import StepRunner
+
+
+class _DummyMaterializer(BaseMaterializer):
+    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA
+    ASSOCIATED_TYPES = (int,)
+
+    def __init__(self, uri: str, artifact_store: Any) -> None:
+        self.uri = uri
+
+    def validate_save_type_compatibility(self, data_type: type) -> None:  # noqa: D401
+        return None
+
+
+def test_store_multiple_output_artifacts(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Ensure multiple outputs are persisted and mapped correctly.
+
+    Args:
+        monkeypatch: The monkeypatch object.
+    """
+    # Prepare a StepRunner with minimal dependencies
+    dummy_step = MagicMock()
+    dummy_step.config.outputs = {"out1": MagicMock(), "out2": MagicMock()}
+    runner = StepRunner(step=dummy_step, stack=MagicMock())
+
+    # Patch get_step_context to provide required properties/methods
+    class _Ctx:
+        class _PR:  # pipeline_run
+            class _P:
+                name = "pipe"
+
+            pipeline = _P()
+
+            class _Cfg:
+                tags = []
+
+            config = _Cfg()
+
+        class _SR:  # step_run
+            name = "step"
+
+        pipeline_run = _PR()
+        step_run = _SR()
+
+        def get_output_metadata(self, name: str):  # noqa: D401
+            return {}
+
+        def get_output_tags(self, name: str):  # noqa: D401
+            return []
+
+    monkeypatch.setattr(
+        "zenml.orchestrators.step_runner.get_step_context", lambda: _Ctx()
+    )
+
+    # Prepare inputs to _store_output_artifacts
+    output_data = {"out1": 1, "out2": 2}
+    output_materializers: Dict[str, Tuple[Type[_DummyMaterializer], ...]] = {
+        "out1": (_DummyMaterializer,),
+        "out2": (_DummyMaterializer,),
+    }
+    output_uris = {"out1": "memory://uri1", "out2": "memory://uri2"}
+    output_annotations = {
+        "out1": MagicMock(artifact_config=None),
+        "out2": MagicMock(artifact_config=None),
+    }
+
+    # Patch artifact pre-store util to avoid I/O and return request objects
+    requests_created: List[Any] = []
+
+    def _fake_store(**kwargs: Any):  # noqa: D401
+        requests_created.append(kwargs)
+        return MagicMock()
+
+    monkeypatch.setattr(
+        "zenml.orchestrators.step_runner._store_artifact_data_and_prepare_request",
+        lambda **kwargs: _fake_store(**kwargs),
+    )
+
+    # Patch batch_create_artifact_versions to return two distinct responses
+    resp1 = MagicMock(id="a1")
+    resp2 = MagicMock(id="a2")
+    monkeypatch.setattr(
+        "zenml.orchestrators.step_runner.Client",
+        lambda: MagicMock(
+            zen_store=MagicMock(
+                batch_create_artifact_versions=lambda reqs: [resp1, resp2]
+            )
+        ),
+    )
+
+    result = runner._store_output_artifacts(
+        output_data=output_data,
+        output_materializers=output_materializers,
+        output_artifact_uris=output_uris,
+        output_annotations=output_annotations,
+        artifact_metadata_enabled=False,
+        artifact_visualization_enabled=False,
+    )
+
+    # Ensure both outputs are present and mapped correctly
+    assert set(result.keys()) == {"out1", "out2"}
+    assert result["out1"].id == "a1"
+    assert result["out2"].id == "a2"
diff --git a/tests/unit/steps/test_base_step.py b/tests/unit/steps/test_base_step.py
index a39a0e11120..54081071c33 100644
--- a/tests/unit/steps/test_base_step.py
+++ b/tests/unit/steps/test_base_step.py
@@ -21,7 +21,7 @@
 from zenml import pipeline, step
 from zenml.enums import ExecutionMode
-from zenml.exceptions import StepInterfaceError
+from zenml.exceptions import HookValidationException, StepInterfaceError
 from zenml.materializers import BuiltInMaterializer
 from zenml.materializers.base_materializer import BaseMaterializer
 from zenml.models import ArtifactVersionResponse
@@ -691,7 +691,7 @@ def test_configure_step_with_failure_hook(one_step_pipeline):
 
     # Test 2
     is_hook_called = False
-    with pytest.raises(ValueError):
+    with pytest.raises(HookValidationException):
         one_step_pipeline(
             exception_step.with_options(
                 on_failure=on_failure_with_wrong_params
@@ -701,13 +701,13 @@ def test_configure_step_with_failure_hook(one_step_pipeline):
 
     # Test 3
     is_hook_called = False
-    with pytest.raises(ValueError):
+    with pytest.raises(BaseException):
         one_step_pipeline(
             exception_step.with_options(
                 on_failure=on_failure_with_not_annotated_params
             )
         ).with_options(unlisted=True)()
-    assert not is_hook_called
+    assert is_hook_called
 
     # Test 4
     is_hook_called = False
@@ -750,7 +750,7 @@ def test_configure_step_with_success_hook(one_step_pipeline):
 
     # Test 1
     is_hook_called = False
-    with pytest.raises(ValueError):
+    with pytest.raises(HookValidationException):
         one_step_pipeline(
             passing_step.with_options(on_success=on_success_with_wrong_params)
         ).with_options(unlisted=True)()
@@ -758,7 +758,7 @@ def test_configure_step_with_success_hook(one_step_pipeline):
 
     # Test 2
     is_hook_called = False
-    with pytest.raises(ValueError):
+    with pytest.raises(HookValidationException):
         one_step_pipeline(
             passing_step.with_options(
                 on_success=on_success_with_not_annotated_params