Skip to content

Commit

Permalink
PERF: Avoid re-computing mask in nanmedian (#50838)
Browse files Browse the repository at this point in the history
* PERF: Avoid re-computing mask in nanmedian

* Add gh ref

* Fix
  • Loading branch information
phofl committed Jan 19, 2023
1 parent b5abe5d commit 4d0cc6f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ Performance improvements
- Performance improvement in :class:`Period` constructor when constructing from a string or integer (:issue:`38312`)
- Performance improvement in :func:`to_datetime` when using ``'%Y%m%d'`` format (:issue:`17410`)
- Performance improvement in :func:`to_datetime` when format is given or can be inferred (:issue:`50465`)
- Performance improvement in :meth:`Series.median` for nullable dtypes (:issue:`50838`)
- Performance improvement in :func:`read_csv` when passing :func:`to_datetime` lambda-function to ``date_parser`` and inputs have mixed timezone offsetes (:issue:`35296`)
- Performance improvement in :func:`isna` and :func:`isnull` (:issue:`50658`)
- Performance improvement in :meth:`.SeriesGroupBy.value_counts` with categorical dtype (:issue:`46202`)
Expand Down
13 changes: 8 additions & 5 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,19 @@ def nanmedian(values, *, axis: AxisInt | None = None, skipna: bool = True, mask=
2.0
"""

def get_median(x):
mask = notna(x)
if not skipna and not mask.all():
def get_median(x, _mask=None):
if _mask is None:
_mask = notna(x)
else:
_mask = ~_mask
if not skipna and not _mask.all():
return np.nan
with warnings.catch_warnings():
# Suppress RuntimeWarning about All-NaN slice
warnings.filterwarnings(
"ignore", "All-NaN slice encountered", RuntimeWarning
)
res = np.nanmedian(x[mask])
res = np.nanmedian(x[_mask])
return res

values, mask, dtype, _, _ = _get_values(values, skipna, mask=mask)
Expand Down Expand Up @@ -796,7 +799,7 @@ def get_median(x):

else:
# otherwise return a scalar value
res = get_median(values) if notempty else np.nan
res = get_median(values, mask) if notempty else np.nan
return _wrap_results(res, dtype)


Expand Down

0 comments on commit 4d0cc6f

Please sign in to comment.