Skip to content
55 changes: 54 additions & 1 deletion pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" test with the .transform """
from datetime import timedelta
from io import StringIO

import numpy as np
Expand All @@ -23,6 +24,17 @@
from pandas.core.groupby.groupby import DataError


@pytest.fixture
def df_for_transformation_func():
    """Build the seven-row frame shared by the transformation tests.

    Columns:
    - "A": integers taking three distinct values (121, 231, 676).
    - "B": floats.
    - "C": seven consecutive daily timestamps starting at 2013-11-03.
    """
    int_keys = [121, 121, 121, 121, 231, 231, 676]
    float_values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0]
    daily_dates = pd.date_range("2013-11-03", periods=7, freq="D")

    column_data = {"A": int_keys, "B": float_values, "C": daily_dates}
    return DataFrame(column_data)


def assert_fp_equal(a, b):
    """Assert that every element of ``a`` and ``b`` differs by less than 1e-12."""
    tolerance = 1e-12
    within_tolerance = np.abs(a - b) < tolerance
    assert within_tolerance.all()

Expand Down Expand Up @@ -318,7 +330,7 @@ def test_dispatch_transform(tsframe):
tm.assert_frame_equal(filled, expected)


def test_transform_transformation_func(transformation_func):
def test_transform_transformation_func(transformation_func, df_for_transformation_func):
# GH 30918
df = DataFrame(
{
Expand Down Expand Up @@ -346,6 +358,47 @@ def test_transform_transformation_func(transformation_func):
tm.assert_frame_equal(result, expected)


def test_groupby_transform_corrwith(df_for_transformation_func):
    """Check ``corrwith`` on a groupby, and that ``transform`` rejects it.

    Calling ``corrwith`` on the grouped frame must match the expected
    per-group frame, while ``transform("corrwith", ...)`` is expected to
    raise ``AttributeError``.
    """
    # GH 27905
    frame = df_for_transformation_func
    grouped = frame.groupby("A")

    # One row per distinct value of "A".
    group_index = pd.Index([121, 231, 676], name="A")
    expected = pd.DataFrame(
        {"B": [1, np.nan, np.nan], "A": [np.nan] * 3}, index=group_index
    )
    result = grouped.corrwith(frame)
    tm.assert_frame_equal(result, expected)

    msg = "'Series' object has no attribute 'corrwith'"
    with pytest.raises(AttributeError, match=msg):
        grouped.transform("corrwith", frame)


def test_groupby_transform_tshift(df_for_transformation_func):
    """Check ``tshift`` on a groupby directly and through ``transform``.

    The direct ``g.tshift(2, "D")`` call must reproduce the fixture frame
    with column "C" moved forward by two days.  The same shift routed
    through ``g.transform`` is compared against that expectation too; a
    mismatch there is reported as an expected failure (GH 32344) instead
    of a hard failure.
    """
    # GH 27905
    df = df_for_transformation_func
    g = df.set_index("C").groupby("A")
    result = g.tshift(2, "D")

    # ``expected`` is the fixture frame itself, with "C" shifted in place.
    df["C"] = df["C"] + timedelta(days=2)
    expected = df
    tm.assert_frame_equal(
        result.reset_index().reindex(columns=["A", "B", "C"]), expected
    )

    op1 = g.transform(lambda x: x.tshift(2, "D"))
    op2 = g.transform("tshift", 2, "D")

    for transformed in [op1, op2]:
        # Run the comparison first and only then mark the expected failure.
        # Calling ``pytest.xfail`` up front raises immediately, which would
        # leave the assertion below unreachable and never check ``op2``.
        try:
            tm.assert_frame_equal(
                transformed.reset_index().reindex(columns=["A", "B", "C"]),
                expected,
            )
        except AssertionError:
            pytest.xfail(
                "The output of groupby.transform with tshift is wrong, see GH 32344"
            )


def test_transform_select_columns(df):
f = lambda x: x.mean()
result = df.groupby("A")[["C", "D"]].transform(f)
Expand Down