diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 6ec3403c2a..41d7e1ce0a 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -189,6 +189,10 @@ def save_artifact( # Get or create the artifact try: artifact = client.list_artifacts(name=name)[0] + if artifact.has_custom_name != has_custom_name: + client.update_artifact( + name_id_or_prefix=artifact.id, has_custom_name=has_custom_name + ) except IndexError: artifact = client.zen_store.create_artifact( ArtifactRequest( diff --git a/src/zenml/client.py b/src/zenml/client.py index ef5d9e3a81..9830fc25a7 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -2761,6 +2761,7 @@ def update_artifact( new_name: Optional[str] = None, add_tags: Optional[List[str]] = None, remove_tags: Optional[List[str]] = None, + has_custom_name: Optional[bool] = None, ) -> ArtifactResponse: """Update an artifact. @@ -2769,6 +2770,7 @@ def update_artifact( new_name: The new name of the artifact. add_tags: Tags to add to the artifact. remove_tags: Tags to remove from the artifact. + has_custom_name: Whether the artifact has a custom name. Returns: The updated artifact. @@ -2778,6 +2780,7 @@ def update_artifact( name=new_name, add_tags=add_tags, remove_tags=remove_tags, + has_custom_name=has_custom_name, ) return self.zen_store.update_artifact( artifact_id=artifact.id, artifact_update=artifact_update diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index 3938dbff1f..8fb548924d 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -60,6 +60,7 @@ class ArtifactUpdate(BaseModel): name: Optional[str] = None add_tags: Optional[List[str]] = None remove_tags: Optional[List[str]] = None + has_custom_name: Optional[bool] = None # ------------------ Response Model ------------------ diff --git a/src/zenml/orchestrators/output_utils.py b/src/zenml/orchestrators/output_utils.py index 0bf5385cf6..21f8b74ccb 100644 --- a/src/zenml/orchestrators/output_utils.py +++ b/src/zenml/orchestrators/output_utils.py @@ -15,6 +15,7 @@ import os from typing import TYPE_CHECKING, Dict, Sequence +from uuid import uuid4 from zenml.io import fileio from zenml.logger import get_logger @@ -44,11 +45,14 @@ def generate_artifact_uri( Returns: The URI of the output artifact. """ + for banned_character in ["<", ">", ":", '"', "/", "\\", "|", "?", "*"]: + output_name = output_name.replace(banned_character, "_") return os.path.join( artifact_store.path, step_run.name, output_name, str(step_run.id), + str(uuid4())[:8], # add random subfolder to avoid collisions ) diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index a7a60574bd..93a1a81394 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -143,6 +143,8 @@ def update(self, artifact_update: ArtifactUpdate) -> "ArtifactSchema": if artifact_update.name: self.name = artifact_update.name self.has_custom_name = True + if artifact_update.has_custom_name is not None: + self.has_custom_name = artifact_update.has_custom_name return self diff --git a/tests/integration/functional/artifacts/test_artifact_config.py b/tests/integration/functional/artifacts/test_artifact_config.py index 5d018e9ad5..68990ab768 100644 --- a/tests/integration/functional/artifacts/test_artifact_config.py +++ b/tests/integration/functional/artifacts/test_artifact_config.py @@ -442,3 +442,39 @@ def _inner_pipeline(force_disable_cache: bool = False): assert ( len(mvrm.data_artifact_ids["cacheable"]) == 1 ), f"Failed on {i} run" + + +@step +def standard_name_producer() -> str: + return "standard" + + +@step +def custom_name_producer() -> ( + Annotated[str, "pipeline_::standard_name_producer::output"] +): + return "custom" + + +def test_update_of_has_custom_name(clean_client: "Client"): + """Test that update of has_custom_name works.""" + + @pipeline(enable_cache=False) + def pipeline_(): + standard_name_producer() + + @pipeline(enable_cache=False) + def pipeline_2(): + custom_name_producer() + + # run 2 times to see both ways switching + for i in range(2): + pipeline_() + assert not clean_client.get_artifact( + "pipeline_::standard_name_producer::output" + ).has_custom_name, f"Standard name validation failed in {i+1} run" + + pipeline_2() + assert clean_client.get_artifact( + "pipeline_::standard_name_producer::output" + ).has_custom_name, f"Custom name validation failed in {i+1} run" diff --git a/tests/unit/orchestrators/test_output_utils.py b/tests/unit/orchestrators/test_output_utils.py index e63c4d4607..6bb4d12d48 100644 --- a/tests/unit/orchestrators/test_output_utils.py +++ b/tests/unit/orchestrators/test_output_utils.py @@ -14,8 +14,6 @@ import os -import pytest - from zenml.config.step_configurations import Step from zenml.orchestrators import output_utils @@ -41,11 +39,9 @@ def test_output_artifact_preparation(create_step_run, local_stack): "output_name", str(step_run.id), ) + output_artifact_uris["output_name"] = os.path.split( + output_artifact_uris["output_name"] + )[0] + assert output_artifact_uris == {"output_name": expected_path} assert os.path.isdir(expected_path) - - # artifact directory already exists - with pytest.raises(RuntimeError): - output_utils.prepare_output_artifact_uris( - step_run=step_run, stack=local_stack, step=step - )