diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md index b777aa1d03..2988295df6 100644 --- a/docs/userguide/torch_forecasting_models.md +++ b/docs/userguide/torch_forecasting_models.md @@ -500,7 +500,7 @@ mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=None, import mlflow.pytorch from mlflow.client import MlflowClient -model_name = "Darts" +model_name = "darts-NBEATS" with mlflow.start_run(nested=True) as run: @@ -545,14 +545,13 @@ mlflow.pytorch.get_default_conda_env() mlflow.pytorch.get_default_pip_requirements() # Set tracking uri -mlflow.set_tracking_uri("sqlite:///mlruns.db") +model_uri = f"runs:/{run.info.run_id}/darts-NBEATS" -# Save Darts model (this need to be added via new cell) -mlflow.log_artifact("NBeatsModel.pickle") +# Save Darts model as an artifact +model_path = 'nbeats_air_passengers' +mlflow.sklearn.save_model(model, model_path) # Registering model -model_name = "NBEATS" -model_uri = f"runs:/{run.info.run_id}/darts-NBEATS" mlflow.register_model(model_uri=model_uri, name=model_name) ```