Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create tags table #2036

Merged
merged 37 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c9729d9
working option 2
avishniakov Nov 9, 2023
caecabc
working option 2
avishniakov Nov 9, 2023
27624c6
working option 2
avishniakov Nov 9, 2023
15df83c
fix mysql
avishniakov Nov 9, 2023
a813455
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 9, 2023
331e1a1
fix bug + add tests
avishniakov Nov 9, 2023
7ab1a22
rename
avishniakov Nov 9, 2023
200776e
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
a265b64
add alembic branches output on divergence
avishniakov Nov 10, 2023
4ff8b26
add client functions
avishniakov Nov 10, 2023
4438a5f
strenums for types/colors
avishniakov Nov 10, 2023
9f7daf6
add tags cli
avishniakov Nov 10, 2023
e9b539a
add tags cli
avishniakov Nov 10, 2023
da14b2b
try bypass alembic branching
avishniakov Nov 10, 2023
fb26bd4
remove tag<>resource endpoints
avishniakov Nov 10, 2023
df82b6e
rely on sql for tag links
avishniakov Nov 10, 2023
1dd232b
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
17bced5
fix migration bug with uuids
avishniakov Nov 10, 2023
3bb6f82
remove `tagged`
avishniakov Nov 10, 2023
a8f508e
calm down branching check on zenml import
avishniakov Nov 10, 2023
3206946
update signature in tests
avishniakov Nov 10, 2023
994b59a
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
ee57e87
update signature in tests
avishniakov Nov 10, 2023
7033f23
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
4b51589
resolve branching
avishniakov Nov 10, 2023
81a0390
Auto-update of E2E template
actions-user Nov 10, 2023
de13f1e
move tagging code to sql store
avishniakov Nov 13, 2023
4adcd0b
resolve branching
avishniakov Nov 13, 2023
dce6439
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 13, 2023
3f79f50
resolve alembic
avishniakov Nov 13, 2023
08649c8
stabilize test case
avishniakov Nov 13, 2023
005adcb
better cleanups in tests
avishniakov Nov 13, 2023
d6e76aa
workaround fix for quickstart
avishniakov Nov 13, 2023
5324f2b
revert hard cleanup
avishniakov Nov 13, 2023
1355e7c
explicit asserts in cli
avishniakov Nov 13, 2023
f388ae0
revert workaround fix for quickstart
avishniakov Nov 13, 2023
ed33c9f
Temporarily fix quickstart until the certificate is renewed
stefannica Nov 13, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/setup-python-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ jobs:
if: ${{ inputs.os == 'ubuntu-latest' && inputs.python-version == '3.8' }}

- name: Check for alembic branch divergence
env:
ZENML_DEBUG: 0
run: |
bash scripts/check-alembic-branches.sh

Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart/steps/training_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def training_data_loader() -> (
):
"""Load the Census Income dataset as tuple of Pandas DataFrame / Series."""
# Load the dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
column_names = [
"age",
"workclass",
Expand Down
1 change: 1 addition & 0 deletions scripts/check-alembic-branches.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ output=$(alembic branches)

# Check if there's any output
if [[ -n "$output" ]]; then
echo $output
echo "Warning: Diverging Alembic branches detected."
exit 1 # Exit with failure status
else
Expand Down
1 change: 1 addition & 0 deletions src/zenml/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,3 +1578,4 @@ def my_pipeline(...):
from zenml.cli.stack_recipes import * # noqa
from zenml.cli.user_management import * # noqa
from zenml.cli.workspace import * # noqa
from zenml.cli.tag import * # noqa
184 changes: 100 additions & 84 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""CLI functionality to interact with Model Control Plane."""
# from functools import partial
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

import click

Expand All @@ -23,22 +23,58 @@
from zenml.cli.cli import TagGroup, cli
from zenml.client import Client
from zenml.enums import CliCategories, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
from zenml.models.model_models import (
ModelFilterModel,
ModelRequestModel,
ModelResponseModel,
ModelUpdateModel,
ModelVersionArtifactFilterModel,
ModelVersionFilterModel,
ModelVersionPipelineRunFilterModel,
ModelVersionResponseModel,
ModelVersionUpdateModel,
)

# from zenml.utils.pagination_utils import depaginate
from zenml.utils.dict_utils import remove_none_values

logger = get_logger(__name__)


def _model_to_print(model: ModelResponseModel) -> Dict[str, Any]:
return {
"id": model.id,
"name": model.name,
"latest_version": model.latest_version,
"description": model.description,
"tags": [t.name for t in model.tags],
"use_cases": model.use_cases,
"audience": model.audience,
"limitations": model.limitations,
"trade_offs": model.trade_offs,
"ethics": model.ethics,
"license": model.license,
"updated": model.updated.date(),
}


def _model_version_to_print(
model_version: ModelVersionResponseModel,
) -> Dict[str, Any]:
return {
"id": model_version.id,
"name": model_version.name,
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"artifact_objects_count": len(model_version.artifact_object_ids),
"model_objects_count": len(model_version.model_object_ids),
"deployments_count": len(model_version.deployment_ids),
"pipeline_runs_count": len(model_version.pipeline_run_ids),
"updated": model_version.updated.date(),
}


@cli.group(cls=TagGroup, tag=CliCategories.MODEL_CONTROL_PLANE)
def model() -> None:
"""Interact with models and model versions in the Model Control Plane."""
Expand All @@ -57,11 +93,10 @@ def list_models(**kwargs: Any) -> None:
if not models:
cli_utils.declare("No models found.")
return

cli_utils.print_pydantic_models(
models,
exclude_columns=["user", "workspace"],
)
to_print = []
for model in models:
to_print.append(_model_to_print(model))
cli_utils.print_table(to_print)


@model.command("register", help="Register a new model.")
Expand Down Expand Up @@ -151,28 +186,26 @@ def register_model(
limitations: The know limitations of the model.
tag: Tags associated with the model.
"""
model = Client().create_model(
ModelRequestModel(
name=name,
license=license,
description=description,
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
user=Client().active_user.id,
workspace=Client().active_workspace.id,
try:
model = Client().create_model(
ModelRequestModel(
name=name,
license=license,
description=description,
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
user=Client().active_user.id,
workspace=Client().active_workspace.id,
)
)
)
except (EntityExistsError, ValueError) as e:
cli_utils.error(str(e))

cli_utils.print_pydantic_models(
[
model,
],
exclude_columns=["user", "workspace"],
)
cli_utils.print_table([_model_to_print(model)])


@model.command("update", help="Update an existing model.")
Expand Down Expand Up @@ -227,7 +260,15 @@ def register_model(
@click.option(
"--tag",
"-t",
help="Tags associated with the model.",
help="Tags to be added to the model.",
type=str,
required=False,
multiple=True,
)
@click.option(
"--remove-tag",
"-r",
help="Tags to be removed from the model.",
type=str,
required=False,
multiple=True,
Expand All @@ -242,6 +283,7 @@ def update_model(
ethical: Optional[str],
limitations: Optional[str],
tag: Optional[List[str]],
remove_tag: Optional[List[str]],
) -> None:
"""Register a new model in the Model Control Plane.

Expand All @@ -254,31 +296,31 @@ def update_model(
tradeoffs: The tradeoffs of the model.
ethical: The ethical implications of the model.
limitations: The know limitations of the model.
tag: Tags associated with the model.
tag: Tags to be added to the model.
remove_tag: Tags to be removed from the model.
"""
model_id = Client().get_model(model_name_or_id=model_name_or_id).id
model = Client().update_model(
model_id=model_id,
model_update=ModelUpdateModel(
update_dict = remove_none_values(
dict(
license=license,
description=description,
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
add_tags=tag,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect this to take a list of strings not just one string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a list, just the limitation of click - the name of arg must match the name of arg in the function and the user pass it as -t a -t b ...

remove_tags=remove_tag,
user=Client().active_user.id,
workspace=Client().active_workspace.id,
),
)
)

cli_utils.print_pydantic_models(
[
model,
],
exclude_columns=["user", "workspace"],
model = Client().update_model(
model_id=model_id,
model_update=ModelUpdateModel(**update_dict),
)
cli_utils.print_table([_model_to_print(model)])


@model.command("delete", help="Delete an existing model.")
Expand Down Expand Up @@ -342,29 +384,11 @@ def list_model_versions(model_name_or_id: str, **kwargs: Any) -> None:
cli_utils.declare("No model versions found.")
return

to_print = []
for model_version in model_versions:
model_version.artifact_objects_count = len( # type: ignore[attr-defined]
model_version.artifact_object_ids
)
model_version.model_objects_count = len(model_version.model_object_ids) # type: ignore[attr-defined]
model_version.deployments_count = len(model_version.deployment_ids) # type: ignore[attr-defined]
model_version.pipeline_runs_count = len(model_version.pipeline_run_ids) # type: ignore[attr-defined]
to_print.append(_model_version_to_print(model_version))

cli_utils.print_pydantic_models(
model_versions,
columns=[
"id",
"name",
"number",
"description",
"stage",
"artifact_objects_count",
"model_objects_count",
"deployments_count",
"pipeline_runs_count",
"updated",
],
)
cli_utils.print_table(to_print)


@version.command("update", help="Update an existing model version stage.")
Expand Down Expand Up @@ -409,31 +433,23 @@ def update_model_version(
)
except RuntimeError:
if not force:
cli_utils.print_pydantic_models(
Client().list_model_versions(
model_name_or_id=model_version.model.id,
model_version_filter_model=ModelVersionFilterModel(
stage=stage
),
),
columns=[
"id",
"name",
"number",
"description",
"stage",
"artifact_objects_count",
"model_objects_count",
"deployments_count",
"pipeline_runs_count",
"updated",
],
cli_utils.print_table(
[
_model_version_to_print(
Client().get_model_version(
model_name_or_id=model_version.model.id,
model_version_name_or_number_or_id=stage,
)
)
]
)

confirmation = cli_utils.confirmation(
"Are you sure you want to change the status of this model "
f"version to '{stage}'? This stage is already taken by "
"another model version and if you will proceed the current "
"model version in this stage will be archived."
"Are you sure you want to change the status of model "
f"version '{model_version_name_or_number_or_id}' to "
f"'{stage}'?\nThis stage is already taken by "
"model version shown above and if you will proceed this "
"model version will get into archived stage."
)
if not confirmation:
cli_utils.declare("Model version stage update canceled.")
Expand Down
Loading
Loading