Open
Description
Describe the bug
SimpleImputer from sklearn raises an AttributeError when it is fitted on a DataFrame column that uses the pyarrow-backed string dtype (string[pyarrow]). The failure occurs while imputing missing values with strategy='most_frequent'.
Steps/Code to Reproduce
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer

# Create a DataFrame with pyarrow string types
data = {
    'mpg': [18, 25, 120, 120],
    'make': ['Ford', 'Chevy', np.nan, 'Tesla']
}
df = (pd.DataFrame(data)
      .astype({'make': 'string[pyarrow]'})
)

# Initialize SimpleImputer
imputer = SimpleImputer(strategy='most_frequent')

# Attempt to fit and transform the DataFrame
imputer.fit_transform(df[["make"]])
Expected Results
The SimpleImputer should handle pyarrow string types and impute the missing values without raising an error.
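For comparison, the same imputation goes through when the column is first cast to object dtype with np.nan as the missing marker. The sketch below is a possible workaround, not something proposed in this report: the helper name to_object_with_nan is hypothetical, the sample data has an extra 'Ford' row so the mode is unambiguous, and the fallback to the plain "string" dtype is only there so the snippet also runs without pyarrow installed.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Assumption: use the pyarrow-backed dtype when pyarrow is installed,
# otherwise fall back to pandas' default nullable string dtype.
try:
    import pyarrow  # noqa: F401
    string_dtype = "string[pyarrow]"
except ImportError:
    string_dtype = "string"


def to_object_with_nan(frame: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: cast to object dtype and mark missing cells with np.nan."""
    return frame.astype(object).where(frame.notna(), np.nan)


df = pd.DataFrame(
    {"make": ["Ford", "Chevy", None, "Tesla", "Ford"]}
).astype({"make": string_dtype})

imputer = SimpleImputer(strategy="most_frequent")
# The imputer now sees an object column whose missing entries are np.nan.
imputed = imputer.fit_transform(to_object_with_nan(df[["make"]]))
print(imputed.ravel().tolist())
```

The cast gives up the pyarrow storage for the imputed column, so this only sidesteps the error; it does not make SimpleImputer accept string[pyarrow] input directly.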
Actual Results
AttributeError Traceback (most recent call last)
Cell In[158], line 18
15 imputer = SimpleImputer(strategy='most_frequent')
17 # Attempt to fit and transform the DataFrame
---> 18 imputer.fit_transform(df[["make"]])
File ~/.envs/menv/lib/python3.10/site-packages/sklearn/utils/_set_output.py:295, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
293 @wraps(f)
294 def wrapped(self, X, *args, **kwargs):
--> 295 data_to_wrap = f(self, X, *args, **kwargs)
296 if isinstance(data_to_wrap, tuple):
297 # only wrap the first output for cross decomposition
298 return_tuple = (
299 _wrap_data_with_container(method, data_to_wrap[0], X, self),
300 *data_to_wrap[1:],
301 )
File ~/.envs/menv/lib/python3.10/site-packages/sklearn/base.py:1098, in TransformerMixin.fit_transform(self, X, y, **fit_params)
1083 warnings.warn(
1084 (
1085 f"This object ({self.__class__.__name__}) has a `transform`"
(...)
1093 UserWarning,
1094 )
1096 if y is None:
1097 # fit method of arity 1 (unsupervised transformation)
-> 1098 return self.fit(X, **fit_params).transform(X)
1099 else:
1100 # fit method of arity 2 (supervised transformation)
1101 return self.fit(X, y, **fit_params).transform(X)
File ~/.envs/menv/lib/python3.10/site-packages/sklearn/base.py:1474, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
1467 estimator._validate_params()
1469 with config_context(
1470 skip_parameter_validation=(
1471 prefer_skip_nested_validation or global_skip_validation
1472 )
1473 ):
-> 1474 return fit_method(estimator, *args, **kwargs)
File ~/.envs/menv/lib/python3.10/site-packages/sklearn/impute/_base.py:427, in SimpleImputer.fit(self, X, y)
423 self.statistics_ = self._sparse_fit(
424 X, self.strategy, self.missing_values, fill_value
425 )
426 else:
--> 427 self.statistics_ = self._dense_fit(
428 X, self.strategy, self.missing_values, fill_value
429 )
431 return self
File ~/.envs/menv/lib/python3.10/site-packages/sklearn/impute/_base.py:510, in SimpleImputer._dense_fit(self, X, strategy, missing_values, fill_value)
503 elif strategy == "most_frequent":
504 # Avoid use of scipy.stats.mstats.mode due to the required
505 # additional overhead and slow benchmarking performance.
506 # See Issue 14325 and PR 14399 for full discussion.
507
508 # To be able access the elements by columns
509 X = X.transpose()
--> 510 mask = missing_mask.transpose()
512 if X.dtype.kind == "O":
513 most_frequent = np.empty(X.shape[0], dtype=object)
AttributeError: 'bool' object has no attribute 'transpose'
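The traceback shows that missing_mask is a plain bool rather than an array by the time _dense_fit reaches it, but the report does not say why. One possible reading, which is an assumption and not confirmed here, is that the column arrives as an object array holding pd.NA, and that a comparison-based mask cannot be evaluated elementwise because pd.NA refuses to be coerced to True or False. The sketch below only illustrates that pd.NA behaviour and contrasts it with pd.isna, which does return an elementwise mask; it does not call any scikit-learn internals.

```python
import numpy as np
import pandas as pd

# An object array shaped like the 'make' column, with pd.NA as the missing marker.
values = np.array(["Ford", "Chevy", pd.NA, "Tesla"], dtype=object)

# Comparing pd.NA with itself yields pd.NA, and pd.NA cannot be turned into a bool.
try:
    bool(pd.NA != pd.NA)
    na_is_ambiguous = False
except TypeError:
    na_is_ambiguous = True

# pd.isna understands pd.NA and returns a boolean array of the same shape.
mask = pd.isna(values)

print(na_is_ambiguous)
print(mask)
```

A mask built this way is a NumPy array, so it supports the transpose() call that fails in the traceback.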
Versions
System:
    python: 3.10.14 (main, Mar 19 2024, 21:46:16) [Clang 15.0.0 (clang-1500.3.9.4)]
    executable: /Users/matt/.envs/menv/bin/python
    machine: macOS-14.2.1-arm64-arm-64bit

Python dependencies:
    sklearn: 1.4.1.post1
    pip: 24.0
    setuptools: 67.6.1
    numpy: 1.23.5
    scipy: 1.10.0
    Cython: 0.29.37
    pandas: 2.2.2
    matplotlib: 3.6.2
    joblib: 1.2.0
    threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: /Users/matt/.envs/menv/lib/python3.10/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
    version: 0.3.20
    threading_layer: pthreads
    architecture: armv8
    num_threads: 8

    user_api: openmp
    internal_api: openmp
    prefix: libomp
    filepath: /Users/matt/.envs/menv/lib/python3.10/site-packages/sklearn/.dylibs/libomp.dylib
    version: None
    num_threads: 8

    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: /Users/matt/.envs/menv/lib/python3.10/site-packages/scipy/.dylibs/libopenblas.0.dylib
    version: 0.3.18
    threading_layer: pthreads
    architecture: armv8
    num_threads: 8