Skip to content

Commit

Permalink
Resolving torcmetrics
Browse files Browse the repository at this point in the history
Resolving torchmetrics issue.
  • Loading branch information
cargecla1 committed Apr 29, 2024
1 parent f1a923b commit 65dbd48
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions docs/userguide/torch_forecasting_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,6 @@ transformer = Scaler()
train = transformer.fit_transform(train)
val = transformer.transform(val)

# any TorchMetric or val_loss can be used as the monitor
torch_metrics = torchmetrics.regression.MeanAbsolutePercentageError()

# MLflow setup
## Run this command with environment activated: mlflow ui --port xxxx (e.g. 5000, 5001, 5002)
# Copy and paste url from command line to web browser
Expand Down Expand Up @@ -515,13 +512,12 @@ with mlflow.start_run(nested=True) as run:

# Define model hyperparameters to log
params = {
"model_type": "Darts_Pytorch_model",
"input_chunk_length": 24,
"output_chunk_length": 12,
"n_epochs": 500,
"model_name": "NBEATS_MLflow",
"log_tensorboard": True,
"torch_metrics": "torchmetrics.regression.MeanAbsolutePercentageError()",
"torch_metrics": MeanAbsolutePercentageError(),
"nr_epochs_val_period": 1,
}

Expand All @@ -531,7 +527,6 @@ with mlflow.start_run(nested=True) as run:
# create the model
model = NBEATSModel(
**params,
torch_metrics=torch_metrics,
)

# use validation dataset
Expand Down

0 comments on commit 65dbd48

Please sign in to comment.