Skip to content

Commit

Permalink
Create tags table (#2036)
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Nov 14, 2023
1 parent dcec2d0 commit d847b31
Show file tree
Hide file tree
Showing 31 changed files with 2,184 additions and 146 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/setup-python-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ jobs:
if: ${{ inputs.os == 'ubuntu-latest' && inputs.python-version == '3.8' }}

- name: Check for alembic branch divergence
env:
ZENML_DEBUG: 0
run: |
bash scripts/check-alembic-branches.sh
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart/steps/training_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def training_data_loader() -> (
):
"""Load the Census Income dataset as tuple of Pandas DataFrame / Series."""
# Load the dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
column_names = [
"age",
"workclass",
Expand Down
1 change: 1 addition & 0 deletions scripts/check-alembic-branches.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ output=$(alembic branches)

# Check if there's any output
if [[ -n "$output" ]]; then
echo $output
echo "Warning: Diverging Alembic branches detected."
exit 1 # Exit with failure status
else
Expand Down
1 change: 1 addition & 0 deletions src/zenml/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,3 +1578,4 @@ def my_pipeline(...):
from zenml.cli.stack_recipes import * # noqa
from zenml.cli.user_management import * # noqa
from zenml.cli.workspace import * # noqa
from zenml.cli.tag import * # noqa
184 changes: 100 additions & 84 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""CLI functionality to interact with Model Control Plane."""
# from functools import partial
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

import click

Expand All @@ -23,22 +23,58 @@
from zenml.cli.cli import TagGroup, cli
from zenml.client import Client
from zenml.enums import CliCategories, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
from zenml.models.model_models import (
ModelFilterModel,
ModelRequestModel,
ModelResponseModel,
ModelUpdateModel,
ModelVersionArtifactFilterModel,
ModelVersionFilterModel,
ModelVersionPipelineRunFilterModel,
ModelVersionResponseModel,
ModelVersionUpdateModel,
)

# from zenml.utils.pagination_utils import depaginate
from zenml.utils.dict_utils import remove_none_values

logger = get_logger(__name__)


def _model_to_print(model: ModelResponseModel) -> Dict[str, Any]:
return {
"id": model.id,
"name": model.name,
"latest_version": model.latest_version,
"description": model.description,
"tags": [t.name for t in model.tags],
"use_cases": model.use_cases,
"audience": model.audience,
"limitations": model.limitations,
"trade_offs": model.trade_offs,
"ethics": model.ethics,
"license": model.license,
"updated": model.updated.date(),
}


def _model_version_to_print(
model_version: ModelVersionResponseModel,
) -> Dict[str, Any]:
return {
"id": model_version.id,
"name": model_version.name,
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"artifact_objects_count": len(model_version.artifact_object_ids),
"model_objects_count": len(model_version.model_object_ids),
"deployments_count": len(model_version.deployment_ids),
"pipeline_runs_count": len(model_version.pipeline_run_ids),
"updated": model_version.updated.date(),
}


@cli.group(cls=TagGroup, tag=CliCategories.MODEL_CONTROL_PLANE)
def model() -> None:
"""Interact with models and model versions in the Model Control Plane."""
Expand All @@ -57,11 +93,10 @@ def list_models(**kwargs: Any) -> None:
if not models:
cli_utils.declare("No models found.")
return

cli_utils.print_pydantic_models(
models,
exclude_columns=["user", "workspace"],
)
to_print = []
for model in models:
to_print.append(_model_to_print(model))
cli_utils.print_table(to_print)


@model.command("register", help="Register a new model.")
Expand Down Expand Up @@ -151,28 +186,26 @@ def register_model(
limitations: The know limitations of the model.
tag: Tags associated with the model.
"""
model = Client().create_model(
ModelRequestModel(
name=name,
license=license,
description=description,
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
user=Client().active_user.id,
workspace=Client().active_workspace.id,
try:
model = Client().create_model(
ModelRequestModel(
name=name,
license=license,
description=description,
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
user=Client().active_user.id,
workspace=Client().active_workspace.id,
)
)
)
except (EntityExistsError, ValueError) as e:
cli_utils.error(str(e))

cli_utils.print_pydantic_models(
[
model,
],
exclude_columns=["user", "workspace"],
)
cli_utils.print_table([_model_to_print(model)])


@model.command("update", help="Update an existing model.")
Expand Down Expand Up @@ -227,7 +260,15 @@ def register_model(
@click.option(
"--tag",
"-t",
help="Tags associated with the model.",
help="Tags to be added to the model.",
type=str,
required=False,
multiple=True,
)
@click.option(
"--remove-tag",
"-r",
help="Tags to be removed from the model.",
type=str,
required=False,
multiple=True,
Expand All @@ -242,6 +283,7 @@ def update_model(
ethical: Optional[str],
limitations: Optional[str],
tag: Optional[List[str]],
remove_tag: Optional[List[str]],
) -> None:
"""Register a new model in the Model Control Plane.
Expand All @@ -254,31 +296,31 @@ def update_model(
tradeoffs: The tradeoffs of the model.
ethical: The ethical implications of the model.
limitations: The know limitations of the model.
tag: Tags associated with the model.
tag: Tags to be added to the model.
remove_tag: Tags to be removed from the model.
"""
model_id = Client().get_model(model_name_or_id=model_name_or_id).id
model = Client().update_model(
model_id=model_id,
model_update=ModelUpdateModel(
update_dict = remove_none_values(
dict(
license=license,
description=description,
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
add_tags=tag,
remove_tags=remove_tag,
user=Client().active_user.id,
workspace=Client().active_workspace.id,
),
)
)

cli_utils.print_pydantic_models(
[
model,
],
exclude_columns=["user", "workspace"],
model = Client().update_model(
model_id=model_id,
model_update=ModelUpdateModel(**update_dict),
)
cli_utils.print_table([_model_to_print(model)])


@model.command("delete", help="Delete an existing model.")
Expand Down Expand Up @@ -342,29 +384,11 @@ def list_model_versions(model_name_or_id: str, **kwargs: Any) -> None:
cli_utils.declare("No model versions found.")
return

to_print = []
for model_version in model_versions:
model_version.artifact_objects_count = len( # type: ignore[attr-defined]
model_version.artifact_object_ids
)
model_version.model_objects_count = len(model_version.model_object_ids) # type: ignore[attr-defined]
model_version.deployments_count = len(model_version.deployment_ids) # type: ignore[attr-defined]
model_version.pipeline_runs_count = len(model_version.pipeline_run_ids) # type: ignore[attr-defined]
to_print.append(_model_version_to_print(model_version))

cli_utils.print_pydantic_models(
model_versions,
columns=[
"id",
"name",
"number",
"description",
"stage",
"artifact_objects_count",
"model_objects_count",
"deployments_count",
"pipeline_runs_count",
"updated",
],
)
cli_utils.print_table(to_print)


@version.command("update", help="Update an existing model version stage.")
Expand Down Expand Up @@ -409,31 +433,23 @@ def update_model_version(
)
except RuntimeError:
if not force:
cli_utils.print_pydantic_models(
Client().list_model_versions(
model_name_or_id=model_version.model.id,
model_version_filter_model=ModelVersionFilterModel(
stage=stage
),
),
columns=[
"id",
"name",
"number",
"description",
"stage",
"artifact_objects_count",
"model_objects_count",
"deployments_count",
"pipeline_runs_count",
"updated",
],
cli_utils.print_table(
[
_model_version_to_print(
Client().get_model_version(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=stage,
)
)
]
)

confirmation = cli_utils.confirmation(
"Are you sure you want to change the status of this model "
f"version to '{stage}'? This stage is already taken by "
"another model version and if you will proceed the current "
"model version in this stage will be archived."
"Are you sure you want to change the status of model "
f"version '{model_version_name_or_number_or_id}' to "
f"'{stage}'?\nThis stage is already taken by "
"model version shown above and if you will proceed this "
"model version will get into archived stage."
)
if not confirmation:
cli_utils.declare("Model version stage update canceled.")
Expand Down
Loading

0 comments on commit d847b31

Please sign in to comment.