-
Notifications
You must be signed in to change notification settings - Fork 409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create tags table #2036
Create tags table #2036
Changes from 5 commits
c9729d9
caecabc
27624c6
15df83c
a813455
331e1a1
7ab1a22
200776e
a265b64
4ff8b26
4438a5f
9f7daf6
e9b539a
da14b2b
fb26bd4
df82b6e
1dd232b
17bced5
3bb6f82
a8f508e
3206946
994b59a
ee57e87
7033f23
4b51589
81a0390
de13f1e
4adcd0b
dce6439
3f79f50
08649c8
005adcb
d6e76aa
5324f2b
1355e7c
f388ae0
ed33c9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -197,3 +197,6 @@ zenml_tutorial/ | |
mlstacks_reset.sh | ||
|
||
.local/ | ||
|
||
# exclude installed dashboard folder | ||
src/zenml/zen_server/dashboard |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright (c) ZenML GmbH 2022. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at: | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
# or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
"""Models representing tags.""" | ||
|
||
|
||
import random | ||
from typing import Any, Dict, Optional | ||
from uuid import UUID | ||
|
||
from pydantic import BaseModel, Field, root_validator, validator | ||
|
||
from zenml.enums import ColorVariants, TaggableResourceTypes | ||
from zenml.models.base_models import ( | ||
BaseRequestModel, | ||
BaseResponseModel, | ||
) | ||
from zenml.models.constants import STR_FIELD_MAX_LENGTH | ||
from zenml.models.filter_models import BaseFilterModel | ||
|
||
# Tags | ||
|
||
|
||
def _validate_color(color: str) -> str: | ||
try: | ||
if str(color).isdigit(): | ||
color = str(ColorVariants(int(color)).value) | ||
else: | ||
color = str(getattr(ColorVariants, color.upper()).value) | ||
except NameError: | ||
raise ValueError( | ||
f"Given color value `{color}` does not " | ||
"match any of defined ColorVariants " | ||
f"`{list(ColorVariants.__members__.keys())}`." | ||
) | ||
return color | ||
|
||
|
||
class TagBaseModel(BaseModel): | ||
"""Base model for tags.""" | ||
|
||
name: str = Field( | ||
description="The unique title of the tag.", | ||
max_length=STR_FIELD_MAX_LENGTH, | ||
) | ||
color: Optional[str] = Field( | ||
description="The color variant assigned to the tag.", | ||
max_length=STR_FIELD_MAX_LENGTH, | ||
) | ||
|
||
@root_validator(pre=True) | ||
def _set_random_color_if_none( | ||
cls, values: Dict[str, Any] | ||
) -> Dict[str, Any]: | ||
if not values.get("color", None): | ||
values["color"] = random.choice(list(ColorVariants)).name.lower() | ||
else: | ||
_validate_color(values["color"]) | ||
return values | ||
|
||
|
||
class TagResponseModel(TagBaseModel, BaseResponseModel): | ||
"""Response model for tags.""" | ||
|
||
tagged_count: int = Field( | ||
fa9r marked this conversation as resolved.
Show resolved
Hide resolved
|
||
description="The count of resources tagged with this tag." | ||
) | ||
|
||
|
||
class TagFilterModel(BaseFilterModel): | ||
"""Model to enable advanced filtering of all tags.""" | ||
|
||
name: Optional[str] | ||
color: Optional[str] | ||
|
||
@validator("color", pre=True) | ||
def _translate_color_to_integer( | ||
cls, color: Optional[str] | ||
) -> Optional[str]: | ||
if not color: | ||
return None | ||
else: | ||
return _validate_color(color) | ||
|
||
|
||
class TagRequestModel(TagBaseModel, BaseRequestModel): | ||
"""Request model for tags.""" | ||
|
||
|
||
class TagUpdateModel(BaseModel): | ||
"""Update model for tags.""" | ||
|
||
name: Optional[str] | ||
color: Optional[str] | ||
|
||
@validator("color", pre=True) | ||
def _translate_color_to_integer( | ||
cls, color: Optional[str] | ||
) -> Optional[str]: | ||
if not color: | ||
return None | ||
else: | ||
return _validate_color(color) | ||
|
||
|
||
# Tags <> Resources | ||
|
||
|
||
class TagResourceBaseModel(BaseModel): | ||
"""Base model for tag resource relationships.""" | ||
|
||
tag_id: UUID | ||
resource_id: UUID | ||
resource_type: TaggableResourceTypes | ||
|
||
@property | ||
def tag_resource_id(self) -> UUID: | ||
"""Get stable ID from tag_id and resource_id. | ||
|
||
Returns: | ||
The generated stable ID. | ||
""" | ||
from zenml.utils.tag_utils import _get_tag_resource_id | ||
|
||
return _get_tag_resource_id(self.tag_id, self.resource_id) | ||
|
||
|
||
class TagResourceResponseModel(TagResourceBaseModel, BaseResponseModel): | ||
"""Response model for tag resource relationships.""" | ||
|
||
|
||
class TagResourceRequestModel(TagResourceBaseModel, BaseRequestModel): | ||
"""Request model for tag resource relationships.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) ZenML GmbH 2022. All Rights Reserved. | ||
avishniakov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at: | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
# or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
"""Utility functions for handling tags.""" | ||
|
||
from typing import List | ||
from uuid import UUID | ||
|
||
from zenml.enums import TaggableResourceTypes | ||
from zenml.exceptions import EntityExistsError | ||
from zenml.models.tag_models import TagRequestModel, TagResourceRequestModel | ||
from zenml.utils.uuid_utils import generate_uuid_from_string | ||
|
||
|
||
def _get_tag_resource_id(tag_id: UUID, resource_id: UUID) -> UUID: | ||
return generate_uuid_from_string(str(tag_id) + str(resource_id)) | ||
|
||
|
||
def create_links( | ||
avishniakov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tag_names: List[str], | ||
resource_id: UUID, | ||
resource_type: TaggableResourceTypes, | ||
) -> None: | ||
"""Creates a tag<>resource link if not present. | ||
|
||
Args: | ||
tag_names: The list of names of the tags. | ||
resource_id: The id of the resource. | ||
resource_type: The type of the resource to create link with | ||
""" | ||
from zenml.client import Client | ||
|
||
zs = Client().zen_store | ||
for tag_name in tag_names: | ||
try: | ||
tag = zs.get_tag(tag_name) | ||
except KeyError: | ||
tag = zs.create_tag(TagRequestModel(name=tag_name)) | ||
try: | ||
zs.create_tag_resource( | ||
TagResourceRequestModel( | ||
tag_id=tag.id, | ||
resource_id=resource_id, | ||
resource_type=resource_type, | ||
) | ||
) | ||
except EntityExistsError: | ||
pass | ||
|
||
|
||
def delete_links( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No idea where to write this - but do we have any way of implementing cascading delete? Like if I delete resource 1 - will all entries for tags on resource 1 be deleted There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will be cascaded - this is need for removal of tags via update, I'll rename it for clarity There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This folk is cascading (set in Models and other taggable entities going forward) tags: List["TagResourceSchema"] = Relationship(
back_populates="model",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=={TaggableResourceTypes.MODEL.value}, foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
cascade="delete",
),
) |
||
tag_names: List[str], | ||
resource_id: UUID, | ||
) -> None: | ||
"""Deletes tag<>resource link if present. | ||
|
||
Args: | ||
tag_names: The list of names of the tags. | ||
resource_id: The id of the resource. | ||
""" | ||
from zenml.client import Client | ||
|
||
zs = Client().zen_store | ||
for tag_name in tag_names: | ||
try: | ||
tag = zs.get_tag(tag_name) | ||
zs.delete_tag_resource(_get_tag_resource_id(tag.id, resource_id)) | ||
except KeyError: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect this to take a list of strings not just one string
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a list, just the limitation of click - the name of arg must match the name of arg in the function and the user pass it as
-t a -t b ...