From 6cc6da16956f80067da5f635382800f5bd90612a Mon Sep 17 00:00:00 2001
From: Soren Macbeth
Date: Wed, 15 Jan 2025 01:22:55 -0800
Subject: [PATCH 1/3] Make tensor dtypes `np.float32` for MPS devices

NumPy defaults to numpy.float64 where these values should be
numpy.float32. This caused training to fail on MPS devices; with this
change, training works on my M1.
---
 .../ssl_models/common/noise_generators.py |  2 +-
 src/pytorch_tabular/tabular_datamodule.py | 12 +++++-------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/pytorch_tabular/ssl_models/common/noise_generators.py b/src/pytorch_tabular/ssl_models/common/noise_generators.py
index 2da372b4..d9acfeca 100644
--- a/src/pytorch_tabular/ssl_models/common/noise_generators.py
+++ b/src/pytorch_tabular/ssl_models/common/noise_generators.py
@@ -18,7 +18,7 @@ class SwapNoiseCorrupter(nn.Module):

     def __init__(self, probas):
         super().__init__()
-        self.probas = torch.from_numpy(np.array(probas))
+        self.probas = torch.from_numpy(np.array(probas, dtype=np.float32))

     def forward(self, x):
         should_swap = torch.bernoulli(self.probas.to(x.device) * torch.ones(x.shape).to(x.device))
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 3d09bb2e..71fe5635 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -67,7 +67,7 @@ def __init__(
         if isinstance(target, str):
             self.y = self.y.reshape(-1, 1)  # .astype(np.int64)
         else:
-            self.y = np.zeros((self.n, 1))  # .astype(np.int64)
+            self.y = np.zeros((self.n, 1), dtype=np.float32)  # .astype(np.int64)
         if task == "classification":
             self.y = self.y.astype(np.int64)
@@ -502,7 +502,7 @@ def _cache_dataset(self):

     def split_train_val(self, train):
         logger.debug(
-            "No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
+            f"No validation data provided. Using {self.config.validation_split * 100}% of train data as validation"
         )
         val_idx = train.sample(
             int(self.config.validation_split * len(train)),
@@ -753,18 +753,16 @@ def _load_dataset_from_cache(self, tag: str = "train"):
             try:
                 dataset = getattr(self, f"_{tag}_dataset")
             except AttributeError:
-                raise AttributeError(
-                    f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
-                )
+                raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader")
         elif self.cache_mode is self.CACHE_MODES.DISK:
             try:
                 dataset = torch.load(self.cache_dir / f"{tag}_dataset")
             except FileNotFoundError:
                 raise FileNotFoundError(
-                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"
+                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader"
                 )
         elif self.cache_mode is self.CACHE_MODES.INFERENCE:
-            raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead")
+            raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead")
         else:
             raise ValueError(f"{self.cache_mode} is not a valid cache mode")
         return dataset

From a196f80735f86304b0356375f7f7b16424351d21 Mon Sep 17 00:00:00 2001
From: Soren Macbeth
Date: Wed, 19 Feb 2025 20:19:15 -0800
Subject: [PATCH 2/3] Add lr scheduler interval config

---
 src/pytorch_tabular/config/config.py         | 20 ++++++++++++++------
 src/pytorch_tabular/models/base_model.py     |  7 +++++--
 src/pytorch_tabular/ssl_models/base_model.py | 15 +++++++++------
 3 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index 55aa500b..ed8235b2 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -192,9 +192,9 @@ class DataConfig:
     )

     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:

     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
@@ -677,6 +677,9 @@ class OptimizerConfig:
         lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau
             is decided based on this metric

+        lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch"
+            or "step". Defaults to `epoch`.
+
     """

     optimizer: str = field(
@@ -709,6 +712,11 @@ class OptimizerConfig:
         metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
     )

+    lr_scheduler_interval: Optional[str] = field(
+        default="epoch",
+        metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
+    )
+
     @staticmethod
     def read_from_yaml(filename: str = "config/optimizer_config.yml"):
         config = _read_yaml(filename)
diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py
index 824eb710..b07d141b 100644
--- a/src/pytorch_tabular/models/base_model.py
+++ b/src/pytorch_tabular/models/base_model.py
@@ -588,8 +588,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py
index 03b31313..75773e75 100644
--- a/src/pytorch_tabular/ssl_models/base_model.py
+++ b/src/pytorch_tabular/ssl_models/base_model.py
@@ -85,9 +85,9 @@ def __init__(
         self._setup_metrics()

     def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
-        assert (encoder is not None) or (
-            encoder_config is not None
-        ), "Either encoder or encoder_config must be provided"
+        assert (encoder is not None) or (encoder_config is not None), (
+            "Either encoder or encoder_config must be provided"
+        )
         # assert (decoder is not None) or (decoder_config is not None),
         # "Either decoder or decoder_config must be provided"
         if encoder is not None:
@@ -181,7 +181,7 @@ def test_step(self, batch, batch_idx):
     def on_validation_epoch_end(self) -> None:
         if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
             warnings.warn(
-                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning"
+                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False to turn off this warning"
             )
         super().on_validation_epoch_end()

@@ -219,8 +219,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt

From bc50a909adb4736b24bc9425db6e3bad58b15f27 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 20 Feb 2025 05:34:29 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_tabular/config/config.py         | 12 ++++++------
 src/pytorch_tabular/ssl_models/base_model.py |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index ed8235b2..999c2c4a 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -192,9 +192,9 @@ class DataConfig:
     )

     def __post_init__(self):
-        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
-            "There should be at-least one feature defined in categorical, continuous, or date columns"
-        )
+        assert (
+            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
+        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:

     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
-                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
-            )
+            assert all(
+                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
+            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py
index 75773e75..6b9150a7 100644
--- a/src/pytorch_tabular/ssl_models/base_model.py
+++ b/src/pytorch_tabular/ssl_models/base_model.py
@@ -85,9 +85,9 @@ def __init__(
         self._setup_metrics()

     def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
-        assert (encoder is not None) or (encoder_config is not None), (
-            "Either encoder or encoder_config must be provided"
-        )
+        assert (encoder is not None) or (
+            encoder_config is not None
+        ), "Either encoder or encoder_config must be provided"
         # assert (decoder is not None) or (decoder_config is not None),
         # "Either decoder or decoder_config must be provided"
         if encoder is not None:
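
Usage notes (illustrative sketches, not part of the patches above).

A minimal demonstration of the dtype issue PATCH 1/3 fixes, relying only
on NumPy's documented default dtype; PyTorch's MPS backend does not
support float64 tensors, which is why the patch pins float32:

import numpy as np

# NumPy allocates float64 by default; the patch makes float32 explicit
# so the resulting torch tensors are usable on MPS devices.
assert np.zeros((10, 1)).dtype == np.float64
assert np.zeros((10, 1), dtype=np.float32).dtype == np.float32

And a sketch of how the option added in PATCH 2/3 would be used, assuming
the documented pytorch_tabular OptimizerConfig API; the OneCycleLR
parameter values here are placeholders, not part of the patch:

from pytorch_tabular.config import OptimizerConfig

# "step" steps the scheduler after every optimizer step rather than
# once per epoch, which is what schedulers like OneCycleLR expect;
# the default "epoch" keeps the previous behavior.
optimizer_config = OptimizerConfig(
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={"max_lr": 1e-2, "total_steps": 1000},  # placeholder values
    lr_scheduler_interval="step",
)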