Skip to content

Commit

Permalink
Merge pull request #334 from rasbt/dropaxis2
Browse files Browse the repository at this point in the history
Also drop axis in column_selector if a tuple is provided and drop_axis=True
  • Loading branch information
rasbt committed Feb 20, 2018
2 parents b30749f + d2ca89b commit eb0f665
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mlxtend/feature_selection/column_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def transform(self, X, y=None):
"""
t = X[:, self.cols]

if t.shape[-1] == 1 and self.drop_axis:
t = t.reshape(-1)
if len(t.shape) == 1 and not self.drop_axis:
t = t[:, np.newaxis]
return t
Expand Down
10 changes: 8 additions & 2 deletions mlxtend/feature_selection/tests/test_column_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ def test_ColumnSelector():

def test_ColumnSelector_drop_axis():
X1_in = np.ones((4, 8))
X1_out = ColumnSelector(cols=(1), drop_axis=True).transform(X1_in)
X1_out = ColumnSelector(cols=1, drop_axis=True).transform(X1_in)
assert X1_out.shape == (4,)

X1_out = ColumnSelector(cols=(1)).transform(X1_in)
X1_out = ColumnSelector(cols=(1,), drop_axis=True).transform(X1_in)
assert X1_out.shape == (4,)

X1_out = ColumnSelector(cols=1).transform(X1_in)
assert X1_out.shape == (4, 1)

X1_out = ColumnSelector(cols=(1,)).transform(X1_in)
assert X1_out.shape == (4, 1)


Expand Down

0 comments on commit eb0f665

Please sign in to comment.