From 3593454143f61d42b132c1413a40b34fd30d2d09 Mon Sep 17 00:00:00 2001 From: cargecla1 <138342606+cargecla1@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:48:56 +1000 Subject: [PATCH] Update MLFlow saving method Updating the MLFlow saving method including tracking uri. --- docs/userguide/torch_forecasting_models.md | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md index b777aa1d03..2988295df6 100644 --- a/docs/userguide/torch_forecasting_models.md +++ b/docs/userguide/torch_forecasting_models.md @@ -500,7 +500,7 @@ mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=None, import mlflow.pytorch from mlflow.client import MlflowClient -model_name = "Darts" +model_name = "darts-NBEATS" with mlflow.start_run(nested=True) as run: @@ -545,14 +545,13 @@ mlflow.pytorch.get_default_conda_env() mlflow.pytorch.get_default_pip_requirements() # Set tracking uri -mlflow.set_tracking_uri("sqlite:///mlruns.db") +model_uri = f"runs:/{run.info.run_id}/darts-NBEATS" -# Save Darts model (this need to be added via new cell) -mlflow.log_artifact("NBeatsModel.pickle") +# Save Darts model as an artifact +model_path = 'nbeats_air_passengers' +mlflow.sklearn.save_model(model, model_path) # Registering model -model_name = "NBEATS" -model_uri = f"runs:/{run.info.run_id}/darts-NBEATS" mlflow.register_model(model_uri=model_uri, name=model_name) ```