Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update has_custom_name for legacy artifacts #2384

Merged
4 changes: 4 additions & 0 deletions src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def save_artifact(
# Get or create the artifact
try:
artifact = client.list_artifacts(name=name)[0]
if artifact.has_custom_name != has_custom_name:
client.update_artifact(
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
)
except IndexError:
artifact = client.zen_store.create_artifact(
ArtifactRequest(
Expand Down
3 changes: 3 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,6 +2761,7 @@ def update_artifact(
new_name: Optional[str] = None,
add_tags: Optional[List[str]] = None,
remove_tags: Optional[List[str]] = None,
has_custom_name: Optional[bool] = None,
) -> ArtifactResponse:
"""Update an artifact.

Expand All @@ -2769,6 +2770,7 @@ def update_artifact(
new_name: The new name of the artifact.
add_tags: Tags to add to the artifact.
remove_tags: Tags to remove from the artifact.
has_custom_name: Whether the artifact has a custom name.

Returns:
The updated artifact.
Expand All @@ -2778,6 +2780,7 @@ def update_artifact(
name=new_name,
add_tags=add_tags,
remove_tags=remove_tags,
has_custom_name=has_custom_name,
)
return self.zen_store.update_artifact(
artifact_id=artifact.id, artifact_update=artifact_update
Expand Down
1 change: 1 addition & 0 deletions src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ArtifactUpdate(BaseModel):
name: Optional[str] = None
add_tags: Optional[List[str]] = None
remove_tags: Optional[List[str]] = None
has_custom_name: Optional[bool] = None


# ------------------ Response Model ------------------
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/orchestrators/output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import os
from typing import TYPE_CHECKING, Dict, Sequence
from uuid import uuid4

from zenml.io import fileio
from zenml.logger import get_logger
Expand Down Expand Up @@ -44,11 +45,14 @@ def generate_artifact_uri(
Returns:
The URI of the output artifact.
"""
for banned_character in ["<", ">", ":", '"', "/", "\\", "|", "?", "*"]:
output_name = output_name.replace(banned_character, "_")
return os.path.join(
artifact_store.path,
step_run.name,
output_name,
str(step_run.id),
str(uuid4())[:8], # add random subfolder to avoid collisions
)


Expand Down
2 changes: 2 additions & 0 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def update(self, artifact_update: ArtifactUpdate) -> "ArtifactSchema":
if artifact_update.name:
self.name = artifact_update.name
self.has_custom_name = True
if artifact_update.has_custom_name is not None:
self.has_custom_name = artifact_update.has_custom_name
return self


Expand Down
36 changes: 36 additions & 0 deletions tests/integration/functional/artifacts/test_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,39 @@ def _inner_pipeline(force_disable_cache: bool = False):
assert (
len(mvrm.data_artifact_ids["cacheable"]) == 1
), f"Failed on {i} run"


@step
def standard_name_producer() -> str:
return "standard"


@step
def custom_name_producer() -> (
Annotated[str, "pipeline_::standard_name_producer::output"]
):
return "custom"


def test_update_of_has_custom_name(clean_client: "Client"):
"""Test that update of has_custom_name works."""

@pipeline(enable_cache=False)
def pipeline_():
standard_name_producer()

@pipeline(enable_cache=False)
def pipeline_2():
custom_name_producer()

# run 2 times to see both ways switching
for i in range(2):
pipeline_()
assert not clean_client.get_artifact(
"pipeline_::standard_name_producer::output"
).has_custom_name, f"Standard name validation failed in {i+1} run"

pipeline_2()
assert clean_client.get_artifact(
"pipeline_::standard_name_producer::output"
).has_custom_name, f"Custom name validation failed in {i+1} run"
11 changes: 3 additions & 8 deletions tests/unit/orchestrators/test_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

import os

import pytest

from zenml.config.step_configurations import Step
from zenml.orchestrators import output_utils

Expand All @@ -41,11 +39,8 @@ def test_output_artifact_preparation(create_step_run, local_stack):
"output_name",
str(step_run.id),
)
output_artifact_uris["output_name"] = "/".join(
output_artifact_uris["output_name"].split("/")[:-1]
)
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
assert output_artifact_uris == {"output_name": expected_path}
assert os.path.isdir(expected_path)

# artifact directory already exists
with pytest.raises(RuntimeError):
output_utils.prepare_output_artifact_uris(
step_run=step_run, stack=local_stack, step=step
)
Loading