diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 55aa500b..999c2c4a 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -677,6 +677,9 @@ class OptimizerConfig: lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is decided based on this metric + lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch" + or "step". Defaults to `epoch`. + """ optimizer: str = field( @@ -709,6 +712,11 @@ class OptimizerConfig: metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"}, ) + lr_scheduler_interval: Optional[str] = field( + default="epoch", + metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."}, + ) + @staticmethod def read_from_yaml(filename: str = "config/optimizer_config.yml"): config = _read_yaml(filename) diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index 824eb710..b07d141b 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -588,8 +588,11 @@ def configure_optimizers(self): } return { "optimizer": opt, - "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params), - "monitor": self.hparams.lr_scheduler_monitor_metric, + "lr_scheduler": { + "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params), + "monitor": self.hparams.lr_scheduler_monitor_metric, + "interval": self.hparams.lr_scheduler_interval, + }, } else: return opt diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py index 03b31313..6b9150a7 100644 --- a/src/pytorch_tabular/ssl_models/base_model.py +++ b/src/pytorch_tabular/ssl_models/base_model.py @@ -181,7 +181,7 @@ def test_step(self, batch, batch_idx): def on_validation_epoch_end(self) -> None: if hasattr(self.hparams, "log_logits") and self.hparams.log_logits: warnings.warn( - "Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning" + "Logging Logits is disabled for SSL tasks. Set `log_logits` to False to turn off this warning" ) super().on_validation_epoch_end() @@ -219,8 +219,11 @@ def configure_optimizers(self): } return { "optimizer": opt, - "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params), - "monitor": self.hparams.lr_scheduler_monitor_metric, + "lr_scheduler": { + "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params), + "monitor": self.hparams.lr_scheduler_monitor_metric, + "interval": self.hparams.lr_scheduler_interval, + }, } else: return opt