From 7fdc39ae3aed9dc1b3c762284f29d1669c041e8a Mon Sep 17 00:00:00 2001 From: Kristof Van Engeland Date: Wed, 15 Aug 2018 14:42:01 +0200 Subject: [PATCH] Fix column naming for DataFrames with MultiIndex columns (#166) --- .gitignore | 4 ++- sklearn_pandas/dataframe_mapper.py | 2 +- tests/test_dataframe_mapper.py | 50 ++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 941a393..250c89a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ .tox/ build/ dist/ -.cache/ \ No newline at end of file +.cache/ +.idea/ +.pytest_cache/ diff --git a/sklearn_pandas/dataframe_mapper.py b/sklearn_pandas/dataframe_mapper.py index 57b27f4..f51a911 100644 --- a/sklearn_pandas/dataframe_mapper.py +++ b/sklearn_pandas/dataframe_mapper.py @@ -233,7 +233,7 @@ def get_names(self, columns, transformer, x, alias=None): if alias is not None: name = alias elif isinstance(columns, list): - name = '_'.join(columns) + name = '_'.join(map(str, columns)) else: name = columns num_cols = x.shape[1] if len(x.shape) > 1 else 1 diff --git a/tests/test_dataframe_mapper.py b/tests/test_dataframe_mapper.py index 004e809..5df1149 100644 --- a/tests/test_dataframe_mapper.py +++ b/tests/test_dataframe_mapper.py @@ -108,6 +108,29 @@ def complex_dataframe(): 'feat2': [1, 2, 3, 2, 3, 4]}) +@pytest.fixture +def multiindex_dataframe(): + """Example MultiIndex DataFrame, taken from pandas documentation + """ + iterables = [['bar', 'baz', 'foo', 'qux'], ['one', 'two']] + index = pd.MultiIndex.from_product(iterables, names=['first', 'second']) + df = pd.DataFrame(np.random.randn(10, 8), columns=index) + return df + + +@pytest.fixture +def multiindex_dataframe_incomplete(multiindex_dataframe): + """Example MultiIndex DataFrame with missing entries + """ + df = multiindex_dataframe + mask_array = np.zeros(df.size) + mask_array[:20] = 1 + np.random.shuffle(mask_array) + mask = mask_array.reshape(df.shape).astype(bool) + df.mask(mask, inplace=True) + return df + + def test_transformed_names_simple(simple_dataframe): """ Get transformed names of features in `transformed_names` attribute @@ -234,6 +257,33 @@ def test_complex_df(complex_dataframe): assert len(transformed[c]) == len(df[c]) +def test_numeric_column_names(complex_dataframe): + """ + Get a dataframe from a complex mapped dataframe with numeric column names + """ + df = complex_dataframe + df.columns = [0, 1, 2] + mapper = DataFrameMapper( + [(0, None), (1, None), (2, None)], df_out=True) + transformed = mapper.fit_transform(df) + assert len(transformed) == len(complex_dataframe) + for c in df.columns: + assert len(transformed[c]) == len(df[c]) + + +def test_multiindex_df(multiindex_dataframe_incomplete): + """ + Get a dataframe from a multiindex dataframe with missing data + """ + df = multiindex_dataframe_incomplete + mapper = DataFrameMapper([([c], Imputer()) for c in df.columns], + df_out=True) + transformed = mapper.fit_transform(df) + assert len(transformed) == len(multiindex_dataframe_incomplete) + for c in df.columns: + assert len(transformed[str(c)]) == len(df[c]) + + def test_binarizer_df(): """ Check level names from LabelBinarizer