generic Client() getters lazy loading (#2323)
* generic `Client()` getters lazy loading

* Auto-update of Starter template

* coderabbitai

* Auto-update of E2E template

* lint, after coderabbitai

* Auto-update of E2E template

* Auto-update of NLP template

* update test signatures

* add `get_model` and `get_model_version` to lazy loaders

* update test signature

* add `evaluate_all_lazy_load_args`

* lint up again

* make IDE great again

* Auto-update of Starter template

* lint again with new ruff

* DOCS!

* update toc

* update link

* Apply suggestions from code review

Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>

* add MCP link

* fix misuse of static methods

* fix wrapping/evaluation

* fix misuse of static methods

* gentle handle static methods

* check for ClientLazyLoader instances

---------

Co-authored-by: GitHub Actions <actions@github.com>
Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>
3 people committed Jan 31, 2024
1 parent c203cd5 commit 38c0246
Showing 16 changed files with 545 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/book/toc.md
@@ -43,6 +43,7 @@
 * [Visualize artifacts](user-guide/advanced-guide/data-management/visualize-artifacts.md)
 * [Organize data with tags](user-guide/advanced-guide/data-management/tagging.md)
 * [Model management](user-guide/advanced-guide/data-management/model-management.md)
+* [Late materialization in pipelines](user-guide/advanced-guide/data-management/late-materialization.md)
 * [Managing infrastructure](user-guide/advanced-guide/infrastructure-management/infrastructure-management.md)
 * [Understanding environments](user-guide/advanced-guide/infrastructure-management/understanding-environments.md)
 * [Containerize your pipeline](user-guide/advanced-guide/infrastructure-management/containerize-your-pipeline.md)
138 changes: 138 additions & 0 deletions docs/book/user-guide/advanced-guide/data-management/late-materialization.md
@@ -0,0 +1,138 @@
---
description: Always use up-to-date data in ZenML pipelines.
---

# Late materialization in pipelines

Often ZenML pipeline steps consume artifacts produced by one another directly in the pipeline code, but there are scenarios where you need to pull external data into your steps. Such external data could be artifacts produced by non-ZenML code. For those cases it is advised to use [ExternalArtifact](../../../user-guide/starter-guide/manage-artifacts.md#consuming-external-artifacts-within-a-pipeline), but what if we want to exchange data created by other ZenML pipelines?

ZenML pipelines are compiled first and only executed at some later point. During the compilation phase all function calls are evaluated and their results are frozen as step input parameters. This makes late materialization of dynamic objects, such as data artifacts, crucial: without it, it would be impossible to pass not-yet-existing artifacts (or their metadata) as step inputs, which is a common requirement in multi-pipeline settings.
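The compile-versus-run distinction can be illustrated with a tiny deferred-call sketch (plain Python; the names here are illustrative and not part of the ZenML API):

```python
from functools import partial

# Stand-in for a registry lookup whose result may change over time.
versions = {"trained_model": "41"}

def fetch_latest(name: str) -> str:
    return versions[name]

# "Compilation": we only record *how* to fetch the artifact.
deferred = partial(fetch_latest, "trained_model")  # nothing is fetched yet

# The registry changes between compilation and execution...
versions["trained_model"] = "42"

# "Execution": only now is the call evaluated, against fresh state.
print(deferred())  # 42
```

If `fetch_latest` had been called eagerly at compilation time, the step would have been pinned to version `41` even though `42` became the latest before execution.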

We identify two major use cases for exchanging artifacts between pipelines:
* You semantically group your data products using [ZenML Models](model-management.md#linking-artifacts-to-models)
* You prefer to use [ZenML Client](../../../reference/python-client.md#client-methods) to bring all the pieces together

In the sections below we will dive deeper into these use cases from the pipeline perspective.

## Use ZenML Models to exchange artifacts

The ZenML Model is an entity introduced together with [the Model Control Plane](model-management.md). The Model Control Plane is a unified interface for managing your models; it allows you to combine the logic of your pipelines, artifacts, and crucial business data with the actual 'technical model'.

Documentation for [ZenML Models](model-management.md#linking-artifacts-to-models) describes in great detail how you can link various artifacts produced within pipelines to the model. Here we will focus more on the part that relates to consumption.

First, let's have a look at a two-pipeline project, where the first pipeline runs the training logic and the second runs batch inference leveraging the trained model artifact(s):

```python
from typing_extensions import Annotated
from zenml import get_pipeline_context, pipeline, step, Model
from zenml.enums import ModelStages
import pandas as pd
from sklearn.base import ClassifierMixin


@step
def predict(
    model: ClassifierMixin,
    data: pd.DataFrame,
) -> Annotated[pd.Series, "predictions"]:
    predictions = pd.Series(model.predict(data))
    return predictions


@pipeline(
    model=Model(
        name="iris_classifier",
        # Using the production stage
        version=ModelStages.PRODUCTION,
    ),
)
def do_predictions():
    # model name and version are derived from pipeline context
    model = get_pipeline_context().model
    inference_data = load_data()  # `load_data` is assumed to be a step defined elsewhere
    predict(
        # Here, we load in the `trained_model` from a trainer step
        model=model.get_model_artifact("trained_model"),
        data=inference_data,
    )


if __name__ == "__main__":
    do_predictions()
```

In the example above we used the `get_pipeline_context().model` property to acquire the model context in which the pipeline is running. This context is not evaluated during pipeline compilation, because `Production` is a stage rather than a stable version name: another model version may be promoted to `Production` before the step actually executes. The same applies to calls like `model.get_model_artifact("trained_model")`: the call is stored in the step configuration for delayed materialization, which only happens during the step run itself.

It is also possible to achieve the same with bare `Client` methods by reworking the pipeline code as follows:

```python
from zenml import pipeline
from zenml.client import Client
from zenml.enums import ModelStages

@pipeline
def do_predictions():
    # model name and version are directly passed into the client method
    model = Client().get_model_version("iris_classifier", ModelStages.PRODUCTION)
    inference_data = load_data()
    predict(
        # Here, we load in the `trained_model` from a trainer step
        model=model.get_model_artifact("trained_model"),
        data=inference_data,
    )
```

In this case the evaluation of the actual artifact happens only when the step is running.

## Use client methods to exchange artifacts

If you don't yet use the Model Control Plane, you can still exchange data between pipelines with late materialization. Let's rework the `do_predictions` pipeline code once more:

```python
from typing_extensions import Annotated
from zenml import pipeline, step
from zenml.client import Client
import pandas as pd
from sklearn.base import ClassifierMixin


@step
def predict(
    model1: ClassifierMixin,
    model2: ClassifierMixin,
    model1_metric: float,
    model2_metric: float,
    data: pd.DataFrame,
) -> Annotated[pd.Series, "predictions"]:
    # compare on the fly which model performs better
    if model1_metric < model2_metric:
        predictions = pd.Series(model1.predict(data))
    else:
        predictions = pd.Series(model2.predict(data))
    return predictions


@pipeline
def do_predictions():
    # get a specific artifact version
    model_42 = Client().get_artifact_version("trained_model", version="42")
    metric_42 = model_42.run_metadata["MSE"].value

    # get the latest artifact version
    model_latest = Client().get_artifact_version("trained_model")
    metric_latest = model_latest.run_metadata["MSE"].value

    inference_data = load_data()  # `load_data` is assumed to be a step defined elsewhere
    predict(
        model1=model_42,
        model2=model_latest,
        model1_metric=metric_42,
        model2_metric=metric_latest,
        data=inference_data,
    )


if __name__ == "__main__":
    do_predictions()
```

Here we also enriched the `predict` step with a comparison of the two models by their MSE metric, so that predictions are made with the better-performing model. As before, calls like `Client().get_artifact_version("trained_model", version="42")` or `model_latest.run_metadata["MSE"].value` are not evaluated at pipeline compilation time; they are evaluated only at the point of step execution. This ensures that the latest version is the latest at the moment the step runs, not merely the latest at the point of pipeline compilation.
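Under the hood this works by returning a lazy-loading placeholder from `Client` methods at compilation time, which is stored in the step configuration and resolved at step runtime. A much-simplified sketch of the idea follows (illustrative only; the real `ClientLazyLoader` in `src/zenml/client_lazy_loader.py` differs in detail, and `FakeClient` is a toy stand-in):

```python
class LazyCall:
    """Records a client method call instead of executing it immediately."""

    def __init__(self, method_name: str, **kwargs: object) -> None:
        self.method_name = method_name
        self.kwargs = kwargs

    def evaluate(self, client: object) -> object:
        # Only at step runtime is the recorded call replayed on a live client.
        return getattr(client, self.method_name)(**self.kwargs)


class FakeClient:
    """Toy stand-in for `zenml.client.Client`."""

    def get_artifact_version(self, name_id_or_prefix, version=None):
        return f"{name_id_or_prefix}:{version or 'latest'}"


# "Compilation" time: nothing is fetched, only the call is recorded.
lazy = LazyCall("get_artifact_version", name_id_or_prefix="trained_model")

# "Run" time: the placeholder is resolved against the client.
print(lazy.evaluate(FakeClient()))  # trained_model:latest
```

The orchestration layer only needs to check for such placeholder instances in step inputs and evaluate them just before the step body runs.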

<!-- For scarf -->
<figure><img alt="ZenML Scarf" referrerpolicy="no-referrer-when-downgrade" src="https://static.scarf.sh/a.png?x-pxid=f0b4f458-0a54-4fcd-aa95-d5ee424815bc" /></figure>
17 changes: 9 additions & 8 deletions docs/book/user-guide/starter-guide/structuring-a-project.md
@@ -21,16 +21,17 @@ The lines between these pipelines can often get blurry: Some use cases call for

No matter how you design these pipelines, one thing stays consistent: you will often need to transfer or share information (in particular artifacts, models, and metadata) between pipelines. Here are some common patterns that you can use to help facilitate such an exchange:

-## Pattern 1: Artifact exchange between pipelines through `ExternalArtifact`
+## Pattern 1: Artifact exchange between pipelines through `Client`

Let's say we have a feature engineering pipeline and a training pipeline. The feature engineering pipeline is like a factory, pumping out many different datasets. Only a few of these datasets should be selected to be sent to the training pipeline to train an actual model.

<figure><img src="../../.gitbook/assets/artifact_exchange.png" alt=""><figcaption><p>A simple artifact exchange between two pipelines</p></figcaption></figure>

-In this scenario, the [ExternalArtifact](manage-artifacts.md#consuming-artifacts-produced-by-other-pipelines) can be used to facilitate such an exchange:
+In this scenario, the [ZenML Client](../../reference/python-client.md#client-methods) can be used to facilitate such an exchange:

```python
-from zenml import pipeline, ExternalArtifact
+from zenml import pipeline
+from zenml.client import Client

@pipeline
def feature_engineering_pipeline():
@@ -40,10 +41,11 @@ def feature_engineering_pipeline():

 @pipeline
 def training_pipeline():
+    client = Client()
     # Fetch by name alone - uses the latest version of this artifact
-    train_data = ExternalArtifact(name="iris_training_dataset")
+    train_data = client.get_artifact_version(name="iris_training_dataset")
     # For test, we want a particular version
-    test_data = ExternalArtifact(name="iris_testing_dataset", version="raw_2023")
+    test_data = client.get_artifact_version(name="iris_testing_dataset", version="raw_2023")
 
     # We can now send these directly into ZenML steps
     sklearn_classifier = model_trainer(train_data)
@@ -91,9 +93,8 @@ However, this approach has the downside that if the step is cached, then it coul

```python
 from typing_extensions import Annotated
-from zenml import get_pipeline_context, pipeline, ExternalArtifact
+from zenml import get_pipeline_context, pipeline, Model
 from zenml.enums import ModelStages
-from zenml.model import Model
 import pandas as pd
 from sklearn.base import ClassifierMixin

@@ -107,7 +108,7 @@ def predict(
return predictions

 @pipeline(
-    model_config=Model(
+    model=Model(
         name="iris_classifier",
         # Using the production stage
         version=ModelStages.PRODUCTION,
32 changes: 28 additions & 4 deletions src/zenml/client.py
@@ -37,6 +37,10 @@

 from pydantic import SecretStr
 
+from zenml.client_lazy_loader import (
+    client_lazy_loader,
+    evaluate_all_lazy_load_args_in_client_methods,
+)
 from zenml.config.global_config import GlobalConfiguration
 from zenml.config.source import Source
 from zenml.constants import (
@@ -281,6 +285,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> "Client":
         return cls._global_client
 
 
+@evaluate_all_lazy_load_args_in_client_methods
 class Client(metaclass=ClientMetaClass):
     """ZenML client class.
@@ -2840,6 +2845,13 @@ def get_artifact_version(
         Returns:
             The artifact version.
         """
+        if cll := client_lazy_loader(
+            method_name="get_artifact_version",
+            name_id_or_prefix=name_id_or_prefix,
+            version=version,
+            hydrate=hydrate,
+        ):
+            return cll  # type: ignore[return-value]
         return self._get_entity_version_by_id_or_name_or_prefix(
             get_method=self.zen_store.get_artifact_version,
             list_method=self.list_artifact_versions,
@@ -4782,6 +4794,10 @@ def get_model(
         Returns:
             The model of interest.
         """
+        if cll := client_lazy_loader(
+            "get_model", model_name_or_id=model_name_or_id, hydrate=hydrate
+        ):
+            return cll  # type: ignore[return-value]
         return self.zen_store.get_model(
             model_name_or_id=model_name_or_id,
             hydrate=hydrate,
@@ -4906,6 +4922,14 @@ def get_model_version(
             RuntimeError: In case method inputs don't adhere to restrictions.
             KeyError: In case no model version with the identifiers exists.
         """
+        if cll := client_lazy_loader(
+            "get_model_version",
+            model_name_or_id=model_name_or_id,
+            model_version_name_or_number_or_id=model_version_name_or_number_or_id,
+            hydrate=hydrate,
+        ):
+            return cll  # type: ignore[return-value]
+
         if model_version_name_or_number_or_id is None:
             model_version_name_or_number_or_id = ModelStages.LATEST
 
@@ -5397,8 +5421,8 @@ def delete_authorized_device(
 
     # ---- utility prefix matching get functions -----
 
-    @staticmethod
     def _get_entity_by_id_or_name_or_prefix(
+        self,
         get_method: Callable[..., AnyResponse],
         list_method: Callable[..., Page[AnyResponse]],
         name_id_or_prefix: Union[str, UUID],
@@ -5444,7 +5468,7 @@ def _get_entity_by_id_or_name_or_prefix(
 
         # If still no match, try with prefix now
         if entity.total == 0:
-            return Client._get_entity_by_prefix(
+            return self._get_entity_by_prefix(
                 get_method=get_method,
                 list_method=list_method,
                 partial_id_or_name=name_id_or_prefix,
@@ -5468,8 +5492,8 @@ def _get_entity_by_id_or_name_or_prefix(
                 f"only one of the {entity_label}s."
             )
 
-    @staticmethod
     def _get_entity_version_by_id_or_name_or_prefix(
+        self,
         get_method: Callable[..., AnyResponse],
         list_method: Callable[..., Page[AnyResponse]],
         name_id_or_prefix: Union[str, UUID],
@@ -5533,8 +5557,8 @@ def _get_entity_version_by_id_or_name_or_prefix(
                 f"only one of the {entity_label}s."
             )
 
-    @staticmethod
     def _get_entity_by_prefix(
+        self,
         get_method: Callable[..., AnyResponse],
         list_method: Callable[..., Page[AnyResponse]],
         partial_id_or_name: str,