Skip to content

Commit

Permalink
Making log more concise
Browse files Browse the repository at this point in the history
Making the log parameters method more concise as suggested.
  • Loading branch information
cargecla1 committed Apr 6, 2024
1 parent ce635bc commit f798244
Showing 1 changed file with 15 additions and 28 deletions.
43 changes: 15 additions & 28 deletions docs/userguide/torch_forecasting_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -512,33 +512,20 @@ with mlflow.start_run(nested=True) as run:
# dataset is used for model training
mlflow.log_input(dataset, context="training")

mlflow.log_param("model_type", "Darts_Pytorch_model")
mlflow.log_param("input_chunk_length", 24)
mlflow.log_param("output_chunk_length", 12)
mlflow.log_param("n_epochs", 500)
mlflow.log_param("model_name", 'NBEATS_MLflow')
mlflow.log_param("log_tensorboard", True)
mlflow.log_param("torch_metrics", "torchmetrics.regression.MeanAbsolutePercentageError()")
mlflow.log_param("nr_epochs_val_period", 1)
mlflow.log_param("pl_trainer_kwargs", "{callbacks: [loss_logger]}")


from pytorch_lightning.callbacks import Callback

class LossLogger(Callback):
def __init__(self):
self.train_loss = []
self.val_loss = []

# will automatically be called at the end of each epoch
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.train_loss.append(float(trainer.callback_metrics["train_loss"]))

def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.val_loss.append(float(trainer.callback_metrics["val_loss"]))


loss_logger = LossLogger()
# Define model hyperparameters to log
params = {
"model_type": "Darts_Pytorch_model",
"input_chunk_length": 24,
"output_chunk_length": 12,
"n_epochs": 500,
"model_name": "NBEATS_MLflow",
"log_tensorboard": True,
"torch_metrics": "torchmetrics.regression.MeanAbsolutePercentageError()",
"nr_epochs_val_period": 1,
}

# Log hyperparameters
mlflow.log_params(params)

# create the model
model = NBEATSModel(
Expand All @@ -549,7 +536,7 @@ with mlflow.start_run(nested=True) as run:
log_tensorboard=True,
torch_metrics=torch_metrics,
nr_epochs_val_period=1,
pl_trainer_kwargs={"callbacks": [loss_logger]})
)

# use validation dataset
model.fit(
Expand Down

0 comments on commit f798244

Please sign in to comment.