diff --git a/.gitignore b/.gitignore index d951f3fb9cbad..03f035db527fe 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,6 @@ doc/source/savefig/ # Pyodide/WASM related files # ############################## /.pyodide-xbuildenv-* + +local.py +.venv/ diff --git a/doc/source/reference/extensions.rst b/doc/source/reference/extensions.rst index e412793a328a3..ff93a3fd25104 100644 --- a/doc/source/reference/extensions.rst +++ b/doc/source/reference/extensions.rst @@ -58,6 +58,7 @@ objects. api.extensions.ExtensionArray.isin api.extensions.ExtensionArray.isna api.extensions.ExtensionArray.ravel + api.extensions.ExtensionArray.map api.extensions.ExtensionArray.repeat api.extensions.ExtensionArray.searchsorted api.extensions.ExtensionArray.shift diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 653a900fbfe45..bb6a19536f427 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1612,15 +1612,20 @@ def to_numpy( result[~mask] = data[~mask]._pa_array.to_numpy() return result - def map(self, mapper, na_action: Literal["ignore"] | None = None): + def map(self, mapper, + na_action: Literal["ignore"] | None = None, + preserve_dtype: bool = False): if is_numeric_dtype(self.dtype): - return map_array(self.to_numpy(), mapper, na_action=na_action) + result = map_array(self.to_numpy(), mapper, na_action=na_action) + if preserve_dtype: + result = self._cast_pointwise_result(result) + return result else: # For "mM" cases, the super() method passes `self` without the # to_numpy call, which inside map_array casts to ndarray[object]. # Without the to_numpy() call, NA is preserved instead of changed # to None. - return super().map(mapper, na_action) + return super().map(mapper, na_action, preserve_dtype=preserve_dtype) @doc(ExtensionArray.duplicated) def duplicated( diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 07c297b2c15ff..c37ee6c98079c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -2516,7 +2516,12 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs) - def map(self, mapper, na_action: Literal["ignore"] | None = None): + def map( + self, + mapper, + na_action: Literal["ignore"] | None = None, + preserve_dtype: bool = False, + ): """ Map values using an input mapping or function. @@ -2528,6 +2533,12 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None): If 'ignore', propagate NA values, without passing them to the mapping correspondence. If 'ignore' is not supported, a ``NotImplementedError`` should be raised. + preserve_dtype : bool, default False + If True, attempt to cast the elementwise result back to the + original ExtensionArray type (and dtype) when possible. This is + primarily intended for identity or dtype-preserving mappings. + If False, the result of the mapping is returned as produced by + the underlying implementation (typically a NumPy ndarray). Returns ------- @@ -2536,7 +2547,10 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None): If the function returns a tuple with more than one element a MultiIndex will be returned. """ - return map_array(self, mapper, na_action=na_action) + results = map_array(self, mapper, na_action=na_action) + if preserve_dtype: + results = self._cast_pointwise_result(results) + return results # ------------------------------------------------------------------------ # GroupBy Methods diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index b00e362e1309a..9d0ee3f833218 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1395,8 +1395,17 @@ def max(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): ) return self._wrap_reduction_result("max", result, skipna=skipna, axis=axis) - def map(self, mapper, na_action: Literal["ignore"] | None = None): - return map_array(self.to_numpy(), mapper, na_action=na_action) + def map( + self, + mapper, + na_action: Literal["ignore"] | None = None, + preserve_dtype: bool = False, + ): + """See ExtensionArray.map.""" + result = map_array(self.to_numpy(), mapper, na_action=na_action) + if preserve_dtype: + result = self._cast_pointwise_result(result) + return result @overload def any(