From 580588928b18425decbea828efa1ac93f0486623 Mon Sep 17 00:00:00 2001 From: Manuel Riel Date: Mon, 4 May 2015 19:06:33 +0700 Subject: [PATCH] Return correct subclass when slicing DataFrame. --- doc/source/whatsnew/v0.16.1.txt | 1 + pandas/core/frame.py | 4 +-- pandas/tests/test_frame.py | 53 +++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/doc/source/whatsnew/v0.16.1.txt b/doc/source/whatsnew/v0.16.1.txt index 56d940031119d..d422e7815a5a3 100755 --- a/doc/source/whatsnew/v0.16.1.txt +++ b/doc/source/whatsnew/v0.16.1.txt @@ -300,3 +300,4 @@ Bug Fixes - Bug in ``transform`` when groups are equal in number and dtype to the input index (:issue:`9700`) - Google BigQuery connector now imports dependencies on a per-method basis.(:issue:`9713`) - Updated BigQuery connector to no longer use deprecated ``oauth2client.tools.run()`` (:issue:`8327`) +- Bug in subclassed ``DataFrame``. It may not return the correct class, when slicing or subsetting it. (:issue:`9632`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 01b0d65e055df..cf676b81388a2 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1839,7 +1839,7 @@ def _getitem_multilevel(self, key): result.columns = result_columns else: new_values = self.values[:, loc] - result = DataFrame(new_values, index=self.index, + result = self._constructor(new_values, index=self.index, columns=result_columns).__finalize__(self) if len(result.columns) == 1: top = result.columns[0] @@ -1847,7 +1847,7 @@ def _getitem_multilevel(self, key): (type(top) == tuple and top[0] == '')): result = result[''] if isinstance(result, Series): - result = Series(result, index=self.index, name=key) + result = self._constructor_sliced(result, index=self.index, name=key) result._set_is_copy(self) return result diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index 3f60f10e81013..4964d13f7ac28 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -2791,6 +2791,59 @@ def test_insert_error_msmgs(self): with assertRaisesRegexp(TypeError, msg): df['gr'] = df.groupby(['b', 'c']).count() + def test_frame_subclassing_and_slicing(self): + # Subclass frame and ensure it returns the right class on slicing it + # In reference to PR 9632 + + class CustomSeries(Series): + @property + def _constructor(self): + return CustomSeries + + def custom_series_function(self): + return 'OK' + + class CustomDataFrame(DataFrame): + "Subclasses pandas DF, fills DF with simulation results, adds some custom plotting functions." + + def __init__(self, *args, **kw): + super(CustomDataFrame, self).__init__(*args, **kw) + + @property + def _constructor(self): + return CustomDataFrame + + _constructor_sliced = CustomSeries + + def custom_frame_function(self): + return 'OK' + + data = {'col1': range(10), + 'col2': range(10)} + cdf = CustomDataFrame(data) + + # Did we get back our own DF class? + self.assertTrue(isinstance(cdf, CustomDataFrame)) + + # Do we get back our own Series class after selecting a column? + cdf_series = cdf.col1 + self.assertTrue(isinstance(cdf_series, CustomSeries)) + self.assertEqual(cdf_series.custom_series_function(), 'OK') + + # Do we get back our own DF class after slicing row-wise? + cdf_rows = cdf[1:5] + self.assertTrue(isinstance(cdf_rows, CustomDataFrame)) + self.assertEqual(cdf_rows.custom_frame_function(), 'OK') + + # Make sure sliced part of multi-index frame is custom class + mcol = pd.MultiIndex.from_tuples([('A', 'A'), ('A', 'B')]) + cdf_multi = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) + self.assertTrue(isinstance(cdf_multi['A'], CustomDataFrame)) + + mcol = pd.MultiIndex.from_tuples([('A', ''), ('B', '')]) + cdf_multi2 = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) + self.assertTrue(isinstance(cdf_multi2['A'], CustomSeries)) + def test_constructor_subclass_dict(self): # Test for passing dict subclass to constructor data = {'col1': tm.TestSubDict((x, 10.0 * x) for x in range(10)),