diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index 0fb0ce86e..be7f5b0ff 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -21,10 +21,11 @@ The CHANGELOG for the current development version is available at - The `'support'` column returned by `frequent_patterns.association_rules` was changed to compute the support of "antecedant union consequent", and new `antecedant support'` and `'consequent support'` column were added to avoid ambiguity. [#245](https://github.com/rasbt/mlxtend/pull/245) - Added `'leverage'` and `'conviction` as evaluation metrics to the `frequent_patterns.association_rules` function. [#246](https://github.com/rasbt/mlxtend/pull/246) & [#247](https://github.com/rasbt/mlxtend/pull/247) +- Allow the `OnehotTransactions` to be cloned via scikit-learn's `clone` function, which is required by e.g., scikit-learn's `FeatureUnion` or `GridSearchCV` (via [Iaroslav Shcherbatyi](https://github.com/iaroslav-ai)). [#249](https://github.com/rasbt/mlxtend/pull/249) ##### Bug Fixes -- +- Allow `OneHot` ### Version 0.8.0 (2017-09-09) diff --git a/mlxtend/preprocessing/onehot.py b/mlxtend/preprocessing/onehot.py index bb4cc1bc6..698972a76 100644 --- a/mlxtend/preprocessing/onehot.py +++ b/mlxtend/preprocessing/onehot.py @@ -5,6 +5,7 @@ # License: BSD 3 clause import numpy as np +from sklearn.base import BaseEstimator, TransformerMixin def one_hot(y, num_labels='auto', dtype='float'): @@ -51,7 +52,7 @@ def one_hot(y, num_labels='auto', dtype='float'): return ary.astype(dtype) -class OnehotTransactions(object): +class OnehotTransactions(BaseEstimator, TransformerMixin): """One-hot encoder class for transaction data in Python lists Parameters diff --git a/mlxtend/preprocessing/tests/test_onehot_transactions.py b/mlxtend/preprocessing/tests/test_onehot_transactions.py index 2c2fd870d..8fe9571f2 100644 --- a/mlxtend/preprocessing/tests/test_onehot_transactions.py +++ b/mlxtend/preprocessing/tests/test_onehot_transactions.py @@ -6,6 +6,8 @@ import numpy as np from mlxtend.preprocessing import OnehotTransactions +from sklearn.base import clone +from mlxtend.utils import assert_raises dataset = [['Apple', 'Beer', 'Rice', 'Chicken'], @@ -63,3 +65,19 @@ def test_inverse_transform(): oht.fit(dataset) np.testing.assert_array_equal(np.array(data_sorted), np.array(oht.inverse_transform(expect))) + + +def test_cloning(): + + oht = OnehotTransactions() + oht.fit(dataset) + oht2 = clone(oht) + + msg = ("'OnehotTransactions' object has no attribute 'columns_'") + assert_raises(AttributeError, + msg, + oht2.transform, + dataset) + + trans = oht2.fit_transform(dataset) + np.testing.assert_array_equal(expect, trans)