Skip to content

Commit

Permalink
Update MLFlow saving method
Browse files Browse the repository at this point in the history
Updating the MLFlow saving method including tracking uri.
  • Loading branch information
cargecla1 committed Apr 29, 2024
1 parent 65dbd48 commit 3593454
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions docs/userguide/torch_forecasting_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=None,
import mlflow.pytorch
from mlflow.client import MlflowClient

model_name = "Darts"
model_name = "darts-NBEATS"

with mlflow.start_run(nested=True) as run:

Expand Down Expand Up @@ -545,14 +545,13 @@ mlflow.pytorch.get_default_conda_env()
mlflow.pytorch.get_default_pip_requirements()

# Set tracking uri
mlflow.set_tracking_uri("sqlite:///mlruns.db")
model_uri = f"runs:/{run.info.run_id}/darts-NBEATS"

# Save Darts model (this need to be added via new cell)
mlflow.log_artifact("NBeatsModel.pickle")
# Save Darts model as an artifact
model_path = 'nbeats_air_passengers'
mlflow.sklearn.save_model(model, model_path)

# Registering model
model_name = "NBEATS"
model_uri = f"runs:/{run.info.run_id}/darts-NBEATS"
mlflow.register_model(model_uri=model_uri, name=model_name)
```

Expand Down

0 comments on commit 3593454

Please sign in to comment.