Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename model version to a model #13

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ dmypy.json

*.zen
.vscode
.local
2 changes: 1 addition & 1 deletion template/configs/inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ settings:
- pyarrow

# configuration of the Model Control Plane
model_version:
model:
name: "breast_cancer_classifier"
version: "production"
license: Apache 2.0
Expand Down
2 changes: 1 addition & 1 deletion template/configs/training_rf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ settings:
- pyarrow

# configuration of the Model Control Plane
model_version:
model:
name: breast_cancer_classifier
version: rf
license: Apache 2.0
Expand Down
2 changes: 1 addition & 1 deletion template/configs/training_sgd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ settings:
- pyarrow

# configuration of the Model Control Plane
model_version:
model:
name: breast_cancer_classifier
version: sgd
license: Apache 2.0
Expand Down
4 changes: 2 additions & 2 deletions template/pipelines/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def inference(random_state: str, target: str):
target: Name of target column in dataset.
"""
# Get the production model artifact
model = get_pipeline_context().model_version.get_artifact("sklearn_classifier")
model = get_pipeline_context().model.get_artifact("sklearn_classifier")

# Get the preprocess pipeline artifact associated with this version
preprocess_pipeline = get_pipeline_context().model_version.get_artifact(
preprocess_pipeline = get_pipeline_context().model.get_artifact(
"preprocess_pipeline"
)

Expand Down
20 changes: 10 additions & 10 deletions template/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
"\n",
"import random\n",
"import pandas as pd\n",
"from zenml import step, ExternalArtifact, pipeline, ModelVersion, get_step_context\n",
"from zenml import step, ExternalArtifact, pipeline, Model, get_step_context\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"from uuid import UUID\n",
Expand Down Expand Up @@ -729,7 +729,7 @@
"all the models produced as you develop your experiments and use-cases. Luckily, ZenML offers a *Model Control Plane*,\n",
"which is a central register of all your ML models.\n",
"\n",
"You can easily create a ZenML `Model` and associate it with your pipelines using the `ModelVersion` object:"
"You can easily create a ZenML Model and associate it with your pipelines using the `Model` object:"
]
},
{
Expand All @@ -742,7 +742,7 @@
"pipeline_settings = {}\n",
"\n",
"# Lets add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model_version\"] = ModelVersion(\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" license=\"Apache 2.0\",\n",
" description=\"A breast cancer classifier\",\n",
Expand All @@ -758,7 +758,7 @@
"outputs": [],
"source": [
"# Let's train the SGD model and set the version name to \"sgd\"\n",
"pipeline_settings[\"model_version\"].version = \"sgd\"\n",
"pipeline_settings[\"model\"].version = \"sgd\"\n",
"\n",
"# the `with_options` method allows us to pass in pipeline settings\n",
"# and returns a configured pipeline\n",
Expand All @@ -780,7 +780,7 @@
"outputs": [],
"source": [
"# Let's train the RF model and set the version name to \"rf\"\n",
"pipeline_settings[\"model_version\"].version = \"rf\"\n",
"pipeline_settings[\"model\"].version = \"rf\"\n",
"\n",
"# the `with_options` method allows us to pass in pipeline settings\n",
"# and returns a configured pipeline\n",
Expand Down Expand Up @@ -939,11 +939,11 @@
"@step\n",
"def inference_predict(dataset_inf: pd.DataFrame) -> Annotated[pd.Series, \"predictions\"]:\n",
" \"\"\"Predictions step\"\"\"\n",
" # Get the model_version\n",
" model_version = get_step_context().model_version\n",
" # Get the model\n",
" model = get_step_context().model\n",
"\n",
" # run prediction from memory\n",
" predictor = model_version.load_artifact(\"sklearn_classifier\")\n",
" predictor = model.load_artifact(\"sklearn_classifier\")\n",
" predictions = predictor.predict(dataset_inf)\n",
"\n",
" predictions = pd.Series(predictions, name=\"predicted\")\n",
Expand Down Expand Up @@ -994,7 +994,7 @@
"id": "c7afe7be",
"metadata": {},
"source": [
"The way to load the right model is to pass in the `production` stage into the `ModelVersion` config this time.\n",
"The way to load the right model is to pass in the `production` stage into the `Model` config this time.\n",
"This will ensure to always load the production model, decoupled from all other pipelines:"
]
},
Expand All @@ -1008,7 +1008,7 @@
"pipeline_settings = {\"enable_cache\": False}\n",
"\n",
"# Lets add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model_version\"] = ModelVersion(\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" version=\"production\", # We can pass in the stage name here!\n",
" license=\"Apache 2.0\",\n",
Expand Down
2 changes: 1 addition & 1 deletion template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(
with open(pipeline_args["config_path"], "r") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
zenml_model = client.get_model_version(
config["model_version"]["name"], config["model_version"]["version"]
config["model"]["name"], config["model"]["version"]
)
preprocess_pipeline_artifact = zenml_model.get_artifact("preprocess_pipeline")

Expand Down
12 changes: 6 additions & 6 deletions template/steps/model_promoter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,26 @@ def model_promoter(accuracy: float, stage: str = "production") -> bool:
is_promoted = True

# Get the model in the current context
current_model_version = get_step_context().model_version
current_model = get_step_context().model

# Get the model that is in the production stage
client = Client()
try:
stage_model_version = client.get_model_version(
current_model_version.name, stage
stage_model = client.get_model_version(
current_model.name, stage
)
# We compare their metrics
prod_accuracy = (
stage_model_version.get_artifact("sklearn_classifier")
stage_model.get_artifact("sklearn_classifier")
.run_metadata["test_accuracy"]
.value
)
if float(accuracy) > float(prod_accuracy):
# If current model has better metrics, we promote it
is_promoted = True
current_model_version.set_stage(stage, force=True)
current_model.set_stage(stage, force=True)
except KeyError:
# If no such model exists, current one is promoted
is_promoted = True
current_model_version.set_stage(stage, force=True)
current_model.set_stage(stage, force=True)
return is_promoted
Loading