Skip to content

Commit

Permalink
Actually test equality in assert_groupby_results_equal (#8272)
Browse files (browse the repository at this point in the history)
Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Keith Kraus (https://github.com/kkraus14)
  - Michael Wang (https://github.com/isVoid)
  - Christopher Harris (https://github.com/cwharris)
  - Gera Shegalov (https://github.com/gerashegalov)

URL: #8272
  • Loading branch information
shwina committed May 20, 2021
1 parent 944e932 commit 75e12d1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/groupby/groupby.py
Expand Up @@ -110,6 +110,7 @@ def cumcount(self):
)
.groupby(self.grouping, sort=self._sort)
.agg("cumcount")
.reset_index(drop=True)
)

@cached_property
Expand Down Expand Up @@ -225,9 +226,10 @@ def nth(self, n):
"""
Return the nth row from each group.
"""
result = self.agg(lambda x: x.nth(n))
sizes = self.size()
return result[n < sizes]
result = self.agg(lambda x: x.nth(n)).sort_index()
sizes = self.size().sort_index()

return result[sizes > n]

def serialize(self):
header = {}
Expand Down
46 changes: 38 additions & 8 deletions python/cudf/cudf/tests/test_groupby.py
Expand Up @@ -30,14 +30,28 @@
_index_type_aggs = {"count", "idxmin", "idxmax", "cumcount"}


def assert_groupby_results_equal(expect, got, sort=True, **kwargs):
def assert_groupby_results_equal(
expect, got, sort=True, as_index=True, by=None, **kwargs
):
# Because we don't sort by index by default in groupby,
# sort expect and got by index before comparing
if sort:
expect = expect.sort_index()
got = got.sort_index()
else:
assert_eq(expect.sort_index(), got.sort_index(), **kwargs)
if as_index:
expect = expect.sort_index()
got = got.sort_index()
else:
assert by is not None
if isinstance(expect, (pd.DataFrame, cudf.DataFrame)):
expect = expect.sort_values(by=by).reset_index(drop=True)
else:
expect = expect.sort_values().reset_index(drop=True)

if isinstance(got, cudf.DataFrame):
got = got.sort_values(by=by).reset_index(drop=True)
else:
got = got.sort_values().reset_index(drop=True)

assert_eq(expect, got, **kwargs)


def make_frame(
Expand Down Expand Up @@ -201,17 +215,25 @@ def test_groupby_getitem_getattr(as_index):
pdf = pd.DataFrame({"x": [1, 3, 1], "y": [1, 2, 3], "z": [1, 4, 5]})
gdf = cudf.from_pandas(pdf)
assert_groupby_results_equal(
pdf.groupby("x")["y"].sum(), gdf.groupby("x")["y"].sum(),
pdf.groupby("x")["y"].sum(),
gdf.groupby("x")["y"].sum(),
as_index=as_index,
by="x",
)
assert_groupby_results_equal(
pdf.groupby("x").y.sum(), gdf.groupby("x").y.sum(),
pdf.groupby("x").y.sum(),
gdf.groupby("x").y.sum(),
as_index=as_index,
by="x",
)
assert_groupby_results_equal(
pdf.groupby("x")[["y"]].sum(), gdf.groupby("x")[["y"]].sum(),
)
assert_groupby_results_equal(
pdf.groupby(["x", "y"], as_index=as_index).sum(),
gdf.groupby(["x", "y"], as_index=as_index).sum(),
as_index=as_index,
by=["x", "y"],
)


Expand Down Expand Up @@ -1088,7 +1110,13 @@ def test_groupby_datetime(nelem, as_index, agg):
else:
pdres = pdg.agg({"datetime": agg})
gdres = gdg.agg({"datetime": agg})
assert_groupby_results_equal(pdres, gdres, check_dtype=check_dtype)
assert_groupby_results_equal(
pdres,
gdres,
check_dtype=check_dtype,
as_index=as_index,
by=["datetime"],
)


def test_groupby_dropna():
Expand Down Expand Up @@ -1349,6 +1377,8 @@ def test_reset_index_after_empty_groupby():
assert_groupby_results_equal(
pdf.groupby("a").sum().reset_index(),
gdf.groupby("a").sum().reset_index(),
as_index=False,
by="a",
)


Expand Down

0 comments on commit 75e12d1

Please sign in to comment.