Skip to content

Commit

Permalink
API: dont infer dtype for object-dtype groupby reductions (#51205)
Browse files Browse the repository at this point in the history
* API: dont infer dtype for object-dtype groupby reductions

* GH ref
  • Loading branch information
jbrockmendel committed Feb 10, 2023
1 parent f33105f commit 9eec5bf
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,9 @@ Other API changes
- The levels of the index of the :class:`Series` returned from ``Series.sparse.from_coo`` now always have dtype ``int32``. Previously they had dtype ``int64`` (:issue:`50926`)
- :func:`to_datetime` with ``unit`` of either "Y" or "M" will now raise if a sequence contains a non-round ``float`` value, matching the ``Timestamp`` behavior (:issue:`50301`)
- The methods :meth:`Series.round`, :meth:`DataFrame.__invert__`, :meth:`Series.__invert__`, :meth:`DataFrame.swapaxes`, :meth:`DataFrame.first`, :meth:`DataFrame.last`, :meth:`Series.first`, :meth:`Series.last` and :meth:`DataFrame.align` will now always return new objects (:issue:`51032`)
- :class:`DataFrameGroupBy` aggregations (e.g. "sum") with object-dtype columns no longer infer non-object dtypes for their results, explicitly call ``result.infer_objects(copy=False)`` on the result to obtain the old behavior (:issue:`51205`)
- Added :func:`pandas.api.types.is_any_real_numeric_dtype` to check for real numeric dtypes (:issue:`51152`)
-

.. ---------------------------------------------------------------------------
.. _whatsnew_200.deprecations:
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,9 @@ def _agg_py_fallback(
# TODO: if we ever get "rank" working, exclude it here.
res_values = type(values)._from_sequence(res_values, dtype=values.dtype)

elif ser.dtype == object:
res_values = res_values.astype(object, copy=False)

# If we are DataFrameGroupBy and went through a SeriesGroupByPath
# then we need to reshape
# GH#32223 includes case with IntegerArray values, ndarray res_values
Expand Down Expand Up @@ -1537,8 +1540,7 @@ def array_func(values: ArrayLike) -> ArrayLike:
new_mgr = data.grouped_reduce(array_func)
res = self._wrap_agged_manager(new_mgr)
out = self._wrap_aggregated_output(res)
if data.ndim == 2:
# TODO: don't special-case DataFrame vs Series
if self.axis == 1:
out = out.infer_objects(copy=False)
return out

Expand Down
2 changes: 2 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def test_multiindex_groupby_mixed_cols_axis1(func, expected, dtype, result_dtype
expected = DataFrame([expected] * 3, columns=["i", "j", "k"]).astype(
result_dtype_dict
)

tm.assert_frame_equal(result, expected)


Expand Down Expand Up @@ -675,6 +676,7 @@ def test_agg_split_object_part_datetime():
"F": [1],
},
index=np.array([0]),
dtype=object,
)
tm.assert_frame_equal(result, expected)

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/groupby/aggregate/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def test_sum_uint64_overflow():
expected = DataFrame(
{1: [9223372036854775809, 9223372036854775811, 9223372036854775813]},
index=index,
dtype=object,
)

expected.index.name = 0
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,12 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
"sum",
"diff",
"pct_change",
"var",
"mean",
"median",
"min",
"max",
"prod",
)

# Test default behavior; kernels that fail may be enabled in the future but kernels
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2380,7 +2380,9 @@ def test_groupby_duplicate_columns():
).astype(object)
df.columns = ["A", "B", "B"]
result = df.groupby([0, 0, 0, 0]).min()
expected = DataFrame([["e", "a", 1]], index=np.array([0]), columns=["A", "B", "B"])
expected = DataFrame(
[["e", "a", 1]], index=np.array([0]), columns=["A", "B", "B"], dtype=object
)
tm.assert_frame_equal(result, expected)


Expand Down
18 changes: 12 additions & 6 deletions pandas/tests/groupby/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,25 @@ def test_aggregate_numeric_object_dtype():
{"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": [np.nan] * 4},
).astype(object)
result = df.groupby("key").min()
expected = DataFrame(
{"key": ["A", "B"], "col1": ["a", "c"], "col2": [np.nan, np.nan]}
).set_index("key")
expected = (
DataFrame(
{"key": ["A", "B"], "col1": ["a", "c"], "col2": [np.nan, np.nan]},
)
.set_index("key")
.astype(object)
)
tm.assert_frame_equal(result, expected)

# same but with numbers
df = DataFrame(
{"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": range(4)},
).astype(object)
result = df.groupby("key").min()
expected = DataFrame(
{"key": ["A", "B"], "col1": ["a", "c"], "col2": [0, 2]}
).set_index("key")
expected = (
DataFrame({"key": ["A", "B"], "col1": ["a", "c"], "col2": [0, 2]})
.set_index("key")
.astype(object)
)
tm.assert_frame_equal(result, expected)


Expand Down

0 comments on commit 9eec5bf

Please sign in to comment.