From edab3f3c148d31e8dc7076c130f4ed4bfcebf6e4 Mon Sep 17 00:00:00 2001
From: Snehil Chatterjee
Date: Tue, 8 Oct 2024 21:02:33 +0530
Subject: [PATCH 1/3] Categorical bug fix

---
 src/pytorch_tabular/categorical_encoders.py | 4 +++-
 src/pytorch_tabular/tabular_datamodule.py   | 7 +++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py
index 8e8006c6..62effb0a 100644
--- a/src/pytorch_tabular/categorical_encoders.py
+++ b/src/pytorch_tabular/categorical_encoders.py
@@ -62,6 +62,8 @@ def transform(self, X):
                 not X[self.cols].isnull().any().any()
             ), "`handle_missing` = `error` and missing values found in columns to encode."
         X_encoded = X.copy(deep=True)
+        category_cols = X_encoded.select_dtypes(include='category').columns
+        X_encoded[category_cols] = X_encoded[category_cols].astype('object')
         for col, mapping in self._mapping.items():
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
 
@@ -267,4 +269,4 @@ def save_as_object_file(self, path):
 
     def load_from_object_file(self, path):
         for k, v in pickle.load(open(path, "rb")).items():
-            setattr(self, k, v)
+            setattr(self, k, v)
\ No newline at end of file
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 50849ae4..d6ecf34f 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
         else:
             raise ValueError(f"{config.task} is an unsupported task.")
         if self.train is not None:
+            category_cols = self.train[config.categorical_cols].select_dtypes(include='category').columns
+            self.train[category_cols] = self.train[category_cols].astype('object')
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
             ]
         else:
+            category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include='category').columns
+            self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype('object')
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
             ]
@@ -805,6 +809,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             num_workers=self.config.num_workers,
             sampler=self.train_sampler,
             pin_memory=self.config.pin_memory,
+            **self.config.dataloader_kwargs,
         )
 
     def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -823,6 +828,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             shuffle=False,
             num_workers=self.config.num_workers,
             pin_memory=self.config.pin_memory,
+            **self.config.dataloader_kwargs,
         )
 
     def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -865,6 +871,7 @@ def prepare_inference_dataloader(
             batch_size or self.batch_size,
             shuffle=False,
             num_workers=self.config.num_workers,
+            **self.config.dataloader_kwargs,
         )
 
     def save_dataloader(self, path: Union[str, Path]) -> None:

From f8f6c2c0e3e87013f2e8829cf07c782f78499d4f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 8 Oct 2024 15:41:14 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_tabular/categorical_encoders.py | 6 +++---
 src/pytorch_tabular/tabular_datamodule.py   | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py
index 62effb0a..35b771fe 100644
--- a/src/pytorch_tabular/categorical_encoders.py
+++ b/src/pytorch_tabular/categorical_encoders.py
@@ -62,8 +62,8 @@ def transform(self, X):
                 not X[self.cols].isnull().any().any()
             ), "`handle_missing` = `error` and missing values found in columns to encode."
         X_encoded = X.copy(deep=True)
-        category_cols = X_encoded.select_dtypes(include='category').columns
-        X_encoded[category_cols] = X_encoded[category_cols].astype('object')
+        category_cols = X_encoded.select_dtypes(include="category").columns
+        X_encoded[category_cols] = X_encoded[category_cols].astype("object")
         for col, mapping in self._mapping.items():
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
 
@@ -269,4 +269,4 @@ def save_as_object_file(self, path):
 
     def load_from_object_file(self, path):
         for k, v in pickle.load(open(path, "rb")).items():
-            setattr(self, k, v)
\ No newline at end of file
+            setattr(self, k, v)
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index d6ecf34f..81f13a69 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -301,14 +301,14 @@ def _update_config(self, config) -> InferredConfig:
         else:
             raise ValueError(f"{config.task} is an unsupported task.")
         if self.train is not None:
-            category_cols = self.train[config.categorical_cols].select_dtypes(include='category').columns
-            self.train[category_cols] = self.train[category_cols].astype('object')
+            category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
+            self.train[category_cols] = self.train[category_cols].astype("object")
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
             ]
         else:
-            category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include='category').columns
-            self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype('object')
+            category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
+            self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
             ]

From 0610bd3cb38f3777a11e754558c8cca0ed7d7cfb Mon Sep 17 00:00:00 2001
From: Snehil Chatterjee
Date: Tue, 8 Oct 2024 21:46:17 +0530
Subject: [PATCH 3/3] dataloader kwargs part removed

---
 src/pytorch_tabular/tabular_datamodule.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 81f13a69..e59e85df 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -809,7 +809,6 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             num_workers=self.config.num_workers,
             sampler=self.train_sampler,
             pin_memory=self.config.pin_memory,
-            **self.config.dataloader_kwargs,
         )
 
     def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -828,7 +827,6 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             shuffle=False,
             num_workers=self.config.num_workers,
             pin_memory=self.config.pin_memory,
-            **self.config.dataloader_kwargs,
        )
 
     def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -871,7 +869,6 @@ def prepare_inference_dataloader(
             batch_size or self.batch_size,
             shuffle=False,
             num_workers=self.config.num_workers,
-            **self.config.dataloader_kwargs,
         )
 
     def save_dataloader(self, path: Union[str, Path]) -> None:
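
Note (not part of the patch series): the cast from "category" to "object" in PATCH 1/3 matters because pandas refuses to fill a categorical column with a value that is not already one of its categories, which is exactly what the encoder's fillna(NAN_CATEGORY) call needs to do. Below is a minimal, self-contained sketch of that behaviour; the NAN_CATEGORY value used here is a stand-in, not necessarily the constant defined in pytorch_tabular.

import pandas as pd

# Stand-in for the NAN_CATEGORY sentinel used by pytorch_tabular's encoders.
NAN_CATEGORY = "NaN"

df = pd.DataFrame({"col": pd.Series(["a", "b", None], dtype="category")})

# Filling a category-dtype column with a value outside its categories fails
# (TypeError in recent pandas, ValueError in some older releases).
try:
    df["col"].fillna(NAN_CATEGORY)
except (TypeError, ValueError) as err:
    print(f"category dtype: {err}")

# Casting the categorical columns to object first, as the patch does,
# lets the same fill succeed.
category_cols = df.select_dtypes(include="category").columns
df[category_cols] = df[category_cols].astype("object")
print(df["col"].fillna(NAN_CATEGORY))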