8 changes: 8 additions & 0 deletions src/pytorch_tabular/config/config.py
@@ -677,6 +677,9 @@ class OptimizerConfig:
         lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
             decided based on this metric
 
+        lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch"
+            or "step". Defaults to `epoch`.
+
     """
 
     optimizer: str = field(
@@ -709,6 +712,11 @@ class OptimizerConfig:
metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
)

lr_scheduler_interval: Optional[str] = field(
default="epoch",
metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
)

@staticmethod
def read_from_yaml(filename: str = "config/optimizer_config.yml"):
config = _read_yaml(filename)
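For context, a minimal usage sketch of the new option (the scheduler name and parameter values here are illustrative, not taken from this PR): a per-step scheduler such as OneCycleLR needs the interval overridden from the default "epoch" to "step".

    from pytorch_tabular.config import OptimizerConfig

    # Illustrative sketch: OneCycleLR advances once per optimizer step, so the
    # default interval of "epoch" would stall it; "step" steps it every batch.
    optimizer_config = OptimizerConfig(
        optimizer="Adam",
        lr_scheduler="OneCycleLR",
        lr_scheduler_params={"max_lr": 1e-2, "total_steps": 1000},  # example values
        lr_scheduler_interval="step",
    )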
7 changes: 5 additions & 2 deletions src/pytorch_tabular/models/base_model.py
@@ -588,8 +588,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
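The nested dict here is PyTorch Lightning's documented lr-scheduler configuration format, which is why the previously flat "monitor" key moves inside it. A standalone sketch of the same pattern in plain Lightning (the optimizer, scheduler, and metric name are assumptions for illustration):

    import torch

    # Inside a LightningModule; sketch only. Lightning reads "scheduler",
    # "monitor", and "interval" from the "lr_scheduler" dict returned here.
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min")
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sched,
                "monitor": "valid_loss",  # required for ReduceLROnPlateau
                "interval": "epoch",      # "step" would step it every batch
            },
        }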
9 changes: 6 additions & 3 deletions src/pytorch_tabular/ssl_models/base_model.py
@@ -181,7 +181,7 @@ def test_step(self, batch, batch_idx):
     def on_validation_epoch_end(self) -> None:
         if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
             warnings.warn(
-                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning"
+                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False to turn off this warning"
             )
         super().on_validation_epoch_end()

@@ -219,8 +219,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
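To summarize what the new "interval" key buys, here is a rough paraphrase of Lightning's documented stepping behavior (not its actual source):

    # Paraphrased training loop: where scheduler.step() fires for each interval.
    for epoch in range(max_epochs):
        for batch in train_dataloader:
            optimizer.step()
            if interval == "step":
                scheduler.step()  # per-batch schedulers, e.g. OneCycleLR
        if interval == "epoch":
            scheduler.step()      # per-epoch schedulers, e.g. StepLR

Before this change, schedulers configured through OptimizerConfig were implicitly stepped at Lightning's default interval of "epoch", so per-step schedulers could not be used correctly.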