diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index a5d4259480ba8..762107a261090 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -175,7 +175,7 @@ Groupby/Resample/Rolling - Bug in ``DataFrame.resample(...).size()`` where an empty ``DataFrame`` did not return a ``Series`` (:issue:`14962`) - Bug in :func:`infer_freq` causing indices with 2-day gaps during the working week to be wrongly inferred as business daily (:issue:`16624`) - Bug in ``.rolling(...).quantile()`` which incorrectly used different defaults than :func:`Series.quantile()` and :func:`DataFrame.quantile()` (:issue:`9413`, :issue:`16211`) - +- Bug in ``groupby.transform()`` that would coerce boolean dtypes back to float (:issue:`16875`) Sparse ^^^^^^ diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 16b0a5c8a74ca..6532e17695c86 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -110,9 +110,7 @@ def trans(x): # noqa np.prod(result.shape)): return result - if issubclass(dtype.type, np.floating): - return result.astype(dtype) - elif is_bool_dtype(dtype) or is_integer_dtype(dtype): + if is_bool_dtype(dtype) or is_integer_dtype(dtype): # if we don't have any elements, just astype it if not np.prod(result.shape): @@ -144,6 +142,9 @@ def trans(x): # noqa # hit here if (new_result == result).all(): return new_result + elif (issubclass(dtype.type, np.floating) and + not is_bool_dtype(result.dtype)): + return result.astype(dtype) # a datetimelike # GH12821, iNaT is casted to float diff --git a/pandas/tests/dtypes/test_cast.py b/pandas/tests/dtypes/test_cast.py index 767e99d98cf29..6e07487b3e04f 100644 --- a/pandas/tests/dtypes/test_cast.py +++ b/pandas/tests/dtypes/test_cast.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, date import numpy as np -from pandas import Timedelta, Timestamp, DatetimeIndex, DataFrame, NaT +from pandas import Timedelta, Timestamp, DatetimeIndex, DataFrame, NaT, Series from pandas.core.dtypes.cast import ( maybe_downcast_to_dtype, @@ -45,6 +45,12 @@ def test_downcast_conv(self): expected = np.array([8, 8, 8, 8, 9]) assert (np.array_equal(result, expected)) + # GH16875 coercing of bools + ser = Series([True, True, False]) + result = maybe_downcast_to_dtype(ser, np.dtype(np.float64)) + expected = ser + tm.assert_series_equal(result, expected) + # conversions expected = np.array([1, 2]) diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index 40434ff510421..98839a17d6e0c 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -195,6 +195,19 @@ def test_transform_bug(self): expected = Series(np.arange(5, 0, step=-1), name='B') assert_series_equal(result, expected) + def test_transform_numeric_to_boolean(self): + # GH 16875 + # inconsistency in transforming boolean values + expected = pd.Series([True, True], name='A') + + df = pd.DataFrame({'A': [1.1, 2.2], 'B': [1, 2]}) + result = df.groupby('B').A.transform(lambda x: True) + assert_series_equal(result, expected) + + df = pd.DataFrame({'A': [1, 2], 'B': [1, 2]}) + result = df.groupby('B').A.transform(lambda x: True) + assert_series_equal(result, expected) + def test_transform_datetime_to_timedelta(self): # GH 15429 # transforming a datetime to timedelta