diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py index 8e8006c6..35b771fe 100644 --- a/src/pytorch_tabular/categorical_encoders.py +++ b/src/pytorch_tabular/categorical_encoders.py @@ -62,6 +62,8 @@ def transform(self, X): not X[self.cols].isnull().any().any() ), "`handle_missing` = `error` and missing values found in columns to encode." X_encoded = X.copy(deep=True) + category_cols = X_encoded.select_dtypes(include="category").columns + X_encoded[category_cols] = X_encoded[category_cols].astype("object") for col, mapping in self._mapping.items(): X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"]) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 8cdcdaef..81f13a69 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig: else: raise ValueError(f"{config.task} is an unsupported task.") if self.train is not None: + category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns + self.train[category_cols] = self.train[category_cols].astype("object") categorical_cardinality = [ int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values) ] else: + category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns + self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object") categorical_cardinality = [ int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values) ]