Skip to content

Commit

Permalink
Replace parse_obj (#2623)
Browse files Browse the repository at this point in the history
* Replace parse_obj

* use hydrated artifact version

* update test signature

* not link artifact to the step in save of test

---------

Co-authored-by: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com>
  • Loading branch information
AlexejPenner and avishniakov committed Apr 23, 2024
1 parent 429f067 commit 097839e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,9 @@ def _load_input_artifact(
"""
# Skip materialization for `UnmaterializedArtifact`.
if data_type == UnmaterializedArtifact:
return UnmaterializedArtifact.parse_obj(artifact)
return UnmaterializedArtifact(
**artifact.get_hydrated_version().dict()
)

if data_type is Any or is_union(get_origin(data_type)):
# Entrypoint function does not define a specific type for the input,
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/orchestrators/test_step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest

from zenml import save_artifact
from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
from zenml.config.pipeline_configurations import PipelineConfiguration
from zenml.config.step_configurations import Step
Expand Down Expand Up @@ -147,11 +148,13 @@ def test_running_a_failing_step(
mock_publish_successful_step_run.assert_not_called()


def test_loading_unmaterialized_input_artifact(
local_stack, sample_artifact_version_model
):
def test_loading_unmaterialized_input_artifact(local_stack, clean_client):
"""Tests that having an input of type `UnmaterializedArtifact` does not
materialize the artifact but instead returns the response model."""
artifact_response = save_artifact(
42, "main_answer", manual_save=False
).get_hydrated_version()

step = Step.parse_obj(
{
"spec": {
Expand All @@ -164,8 +167,6 @@ def test_loading_unmaterialized_input_artifact(
}
)
runner = StepRunner(step=step, stack=local_stack)
artifact_response = sample_artifact_version_model

artifact = runner._load_input_artifact(
artifact=artifact_response, data_type=UnmaterializedArtifact
)
Expand Down

0 comments on commit 097839e

Please sign in to comment.