Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create tags table #2036

Merged
merged 37 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c9729d9
working option 2
avishniakov Nov 9, 2023
caecabc
working option 2
avishniakov Nov 9, 2023
27624c6
working option 2
avishniakov Nov 9, 2023
15df83c
fix mysql
avishniakov Nov 9, 2023
a813455
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 9, 2023
331e1a1
fix bug + add tests
avishniakov Nov 9, 2023
7ab1a22
rename
avishniakov Nov 9, 2023
200776e
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
a265b64
add alembic branches output on divergence
avishniakov Nov 10, 2023
4ff8b26
add client functions
avishniakov Nov 10, 2023
4438a5f
strenums for types/colors
avishniakov Nov 10, 2023
9f7daf6
add tags cli
avishniakov Nov 10, 2023
e9b539a
add tags cli
avishniakov Nov 10, 2023
da14b2b
try bypass alembic branching
avishniakov Nov 10, 2023
fb26bd4
remove tag<>resource endpoints
avishniakov Nov 10, 2023
df82b6e
rely on sql for tag links
avishniakov Nov 10, 2023
1dd232b
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
17bced5
fix migration bug with uuids
avishniakov Nov 10, 2023
3bb6f82
remove `tagged`
avishniakov Nov 10, 2023
a8f508e
calm down branching check on zenml import
avishniakov Nov 10, 2023
3206946
update signature in tests
avishniakov Nov 10, 2023
994b59a
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
ee57e87
update signature in tests
avishniakov Nov 10, 2023
7033f23
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 10, 2023
4b51589
resolve branching
avishniakov Nov 10, 2023
81a0390
Auto-update of E2E template
actions-user Nov 10, 2023
de13f1e
move tagging code to sql store
avishniakov Nov 13, 2023
4adcd0b
resolve branching
avishniakov Nov 13, 2023
dce6439
Merge branch 'develop' into feature/OSS-2610-create-tags-table
avishniakov Nov 13, 2023
3f79f50
resolve alembic
avishniakov Nov 13, 2023
08649c8
stabilize test case
avishniakov Nov 13, 2023
005adcb
better cleanups in tests
avishniakov Nov 13, 2023
d6e76aa
workaround fix for quickstart
avishniakov Nov 13, 2023
5324f2b
revert hard cleanup
avishniakov Nov 13, 2023
1355e7c
explicit asserts in cli
avishniakov Nov 13, 2023
f388ae0
revert workaround fix for quickstart
avishniakov Nov 13, 2023
ed33c9f
Temporarily fix quickstart until the certificate is renewed
stefannica Nov 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,6 @@ zenml_tutorial/
mlstacks_reset.sh

.local/

# exclude installed dashboard folder
src/zenml/zen_server/dashboard
2 changes: 1 addition & 1 deletion src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def update_model(
trade_offs=tradeoffs,
ethics=ethical,
limitations=limitations,
tags=tag,
add_tags=tag,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect this to take a list of strings not just one string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a list, just the limitation of click - the name of arg must match the name of arg in the function and the user pass it as -t a -t b ...

user=Client().active_user.id,
workspace=Client().active_workspace.id,
),
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEVICE_AUTHORIZATION = "/device_authorization"
DEVICE_VERIFY = "/verify"
API_TOKEN = "/api_token"
TAGS = "/tags"
TAG_RESOURCES = "/tag_resources"
avishniakov marked this conversation as resolved.
Show resolved Hide resolved

# model metadata yaml file name
MODEL_METADATA_YAML_FILE_NAME = "model_metadata.yaml"
Expand Down
22 changes: 22 additions & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,25 @@ class ModelStages(StrEnum):
STAGING = "staging"
PRODUCTION = "production"
ARCHIVED = "archived"


class ColorVariants(Enum):
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
"""All possible color variants for frontend."""

GREY = 1
PURPLE = 2
RED = 3
GREEN = 4
YELLOW = 5
ORANGE = 6
LIME = 7
TEAL = 8
TURQUOISE = 9
MAGENTA = 10
BLUE = 11


class TaggableResourceTypes(Enum):
"""All possible resource types for tagging."""

MODEL = 1
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 10 additions & 1 deletion src/zenml/model/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pydantic import BaseModel, root_validator

from zenml.constants import RUNNING_MODEL_VERSION
from zenml.enums import ExecutionStatus, ModelStages
from zenml.enums import ExecutionStatus, ModelStages, TaggableResourceTypes
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger

Expand Down Expand Up @@ -177,6 +177,7 @@ def get_or_create_model(self) -> "ModelResponseModel":
"""
from zenml.client import Client
from zenml.models.model_models import ModelRequestModel
from zenml.utils.tag_utils import create_links

zenml_client = Client()
try:
Expand All @@ -198,9 +199,17 @@ def get_or_create_model(self) -> "ModelResponseModel":
model_request = ModelRequestModel.parse_obj(model_request)
try:
model = zenml_client.create_model(model=model_request)
if model_request.tags:
create_links(
model_request.tags,
model.id,
TaggableResourceTypes.MODEL,
)
logger.info(f"New model `{self.name}` was created implicitly.")
except EntityExistsError:
# this is backup logic, if model was created somehow in between get and create calls
pass
finally:
model = zenml_client.get_model(model_name_or_id=self.name)

return model
Expand Down
15 changes: 14 additions & 1 deletion src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,14 @@
ModelVersionFilterModel,
ModelVersionUpdateModel,
)

from zenml.models.tag_models import (
TagFilterModel,
TagResourceResponseModel,
TagResourceRequestModel,
TagResponseModel,
TagRequestModel,
TagUpdateModel,
)

ComponentResponseModel.update_forward_refs(
UserResponseModel=UserResponseModel,
Expand Down Expand Up @@ -445,6 +452,12 @@
"StepRunRequestModel",
"StepRunResponseModel",
"StepRunUpdateModel",
"TagFilterModel",
"TagResourceResponseModel",
"TagResourceRequestModel",
"TagResponseModel",
"TagRequestModel",
"TagUpdateModel",
"TeamFilterModel",
"TeamRequestModel",
"TeamResponseModel",
Expand Down
5 changes: 1 addition & 4 deletions src/zenml/models/model_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""Model base model to support Model Control Plane feature."""

from typing import List, Optional
from typing import Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -58,6 +58,3 @@ class ModelBaseModel(BaseModel):
title="The ethical implications of the model",
max_length=TEXT_FIELD_MAX_LENGTH,
)
tags: Optional[List[str]] = Field(
title="Tags associated with the model",
)
11 changes: 9 additions & 2 deletions src/zenml/models/model_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from zenml.models.filter_models import WorkspaceScopedFilterModel
from zenml.models.model_base_model import ModelBaseModel
from zenml.models.pipeline_run_models import PipelineRunResponseModel
from zenml.models.tag_models import TagResponseModel

if TYPE_CHECKING:
from sqlmodel.sql.expression import Select, SelectOfScalar
Expand Down Expand Up @@ -624,7 +625,9 @@ class ModelRequestModel(
):
"""Model request model."""

pass
tags: Optional[List[str]] = Field(
title="Tags associated with the model",
)


class ModelResponseModel(
Expand All @@ -636,6 +639,9 @@ class ModelResponseModel(
latest_version: name of latest version, if any
"""

tags: Optional[List[TagResponseModel]] = Field(
title="Tags associated with the model",
)
latest_version: Optional[str]

@property
Expand Down Expand Up @@ -707,4 +713,5 @@ class ModelUpdateModel(BaseModel):
limitations: Optional[str]
trade_offs: Optional[str]
ethics: Optional[str]
tags: Optional[List[str]]
add_tags: Optional[List[str]]
remove_tags: Optional[List[str]]
143 changes: 143 additions & 0 deletions src/zenml/models/tag_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Models representing tags."""


import random
from typing import Any, Dict, Optional
from uuid import UUID

from pydantic import BaseModel, Field, root_validator, validator

from zenml.enums import ColorVariants, TaggableResourceTypes
from zenml.models.base_models import (
BaseRequestModel,
BaseResponseModel,
)
from zenml.models.constants import STR_FIELD_MAX_LENGTH
from zenml.models.filter_models import BaseFilterModel

# Tags


def _validate_color(color: str) -> str:
try:
if str(color).isdigit():
color = str(ColorVariants(int(color)).value)
else:
color = str(getattr(ColorVariants, color.upper()).value)
except NameError:
raise ValueError(
f"Given color value `{color}` does not "
"match any of defined ColorVariants "
f"`{list(ColorVariants.__members__.keys())}`."
)
return color


class TagBaseModel(BaseModel):
"""Base model for tags."""

name: str = Field(
description="The unique title of the tag.",
max_length=STR_FIELD_MAX_LENGTH,
)
color: Optional[str] = Field(
description="The color variant assigned to the tag.",
max_length=STR_FIELD_MAX_LENGTH,
)

@root_validator(pre=True)
def _set_random_color_if_none(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
if not values.get("color", None):
values["color"] = random.choice(list(ColorVariants)).name.lower()
else:
_validate_color(values["color"])
return values


class TagResponseModel(TagBaseModel, BaseResponseModel):
"""Response model for tags."""

tagged_count: int = Field(
fa9r marked this conversation as resolved.
Show resolved Hide resolved
description="The count of resources tagged with this tag."
)


class TagFilterModel(BaseFilterModel):
"""Model to enable advanced filtering of all tags."""

name: Optional[str]
color: Optional[str]

@validator("color", pre=True)
def _translate_color_to_integer(
cls, color: Optional[str]
) -> Optional[str]:
if not color:
return None
else:
return _validate_color(color)


class TagRequestModel(TagBaseModel, BaseRequestModel):
"""Request model for tags."""


class TagUpdateModel(BaseModel):
"""Update model for tags."""

name: Optional[str]
color: Optional[str]

@validator("color", pre=True)
def _translate_color_to_integer(
cls, color: Optional[str]
) -> Optional[str]:
if not color:
return None
else:
return _validate_color(color)


# Tags <> Resources


class TagResourceBaseModel(BaseModel):
"""Base model for tag resource relationships."""

tag_id: UUID
resource_id: UUID
resource_type: TaggableResourceTypes

@property
def tag_resource_id(self) -> UUID:
"""Get stable ID from tag_id and resource_id.

Returns:
The generated stable ID.
"""
from zenml.utils.tag_utils import _get_tag_resource_id

return _get_tag_resource_id(self.tag_id, self.resource_id)


class TagResourceResponseModel(TagResourceBaseModel, BaseResponseModel):
"""Response model for tag resource relationships."""


class TagResourceRequestModel(TagResourceBaseModel, BaseRequestModel):
"""Request model for tag resource relationships."""
79 changes: 79 additions & 0 deletions src/zenml/utils/tag_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Utility functions for handling tags."""

from typing import List
from uuid import UUID

from zenml.enums import TaggableResourceTypes
from zenml.exceptions import EntityExistsError
from zenml.models.tag_models import TagRequestModel, TagResourceRequestModel
from zenml.utils.uuid_utils import generate_uuid_from_string


def _get_tag_resource_id(tag_id: UUID, resource_id: UUID) -> UUID:
return generate_uuid_from_string(str(tag_id) + str(resource_id))


def create_links(
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
tag_names: List[str],
resource_id: UUID,
resource_type: TaggableResourceTypes,
) -> None:
"""Creates a tag<>resource link if not present.

Args:
tag_names: The list of names of the tags.
resource_id: The id of the resource.
resource_type: The type of the resource to create link with
"""
from zenml.client import Client

zs = Client().zen_store
for tag_name in tag_names:
try:
tag = zs.get_tag(tag_name)
except KeyError:
tag = zs.create_tag(TagRequestModel(name=tag_name))
try:
zs.create_tag_resource(
TagResourceRequestModel(
tag_id=tag.id,
resource_id=resource_id,
resource_type=resource_type,
)
)
except EntityExistsError:
pass


def delete_links(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea where to write this - but do we have any way of implementing cascading delete? Like if I delete resource 1 - will all entries for tags on resource 1 be deleted

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be cascaded - this is need for removal of tags via update, I'll rename it for clarity

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This folk is cascading (set in Models and other taggable entities going forward)

tags: List["TagResourceSchema"] = Relationship(
        back_populates="model",
        sa_relationship_kwargs=dict(
            primaryjoin=f"and_(TagResourceSchema.resource_type=={TaggableResourceTypes.MODEL.value}, foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
            cascade="delete",
        ),
    )

tag_names: List[str],
resource_id: UUID,
) -> None:
"""Deletes tag<>resource link if present.

Args:
tag_names: The list of names of the tags.
resource_id: The id of the resource.
"""
from zenml.client import Client

zs = Client().zen_store
for tag_name in tag_names:
try:
tag = zs.get_tag(tag_name)
zs.delete_tag_resource(_get_tag_resource_id(tag.id, resource_id))
except KeyError:
pass
Loading
Loading