Skip to content

Commit

Permalink
Fix sklearn v0.18 cross_validation compat
Browse files Browse the repository at this point in the history
  • Loading branch information
sinhrks committed Oct 11, 2016
1 parent b9d1d50 commit 63751a4
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Example
[5 rows x 65 columns]
# split to training and test data
>>> train_df, test_df = df.cross_validation.train_test_split()
>>> train_df, test_df = df.model_selection.train_test_split()
# create estimator (accessor is mapped to sklearn namespace)
>>> estimator = df.svm.LinearSVC()
Expand Down
52 changes: 22 additions & 30 deletions doc/source/sklearn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,36 +42,41 @@ You can create ``ModelFrame`` instance from ``scikit-learn`` datasets directly.
The following table shows each ``scikit-learn`` module and the corresponding ``ModelFrame`` accessor. Some accessors also have abbreviated aliases.

================================ ==========================================
================================ ======================================================
``scikit-learn`` ``ModelFrame`` accessor
================================ ==========================================
================================ ======================================================
``sklearn.calibration`` ``ModelFrame.calibration``
``sklearn.cluster`` ``ModelFrame.cluster``
``sklearn.covariance`` ``ModelFrame.covariance``
``sklearn.cross_decomposition`` ``ModelFrame.cross_decomposition``
``sklearn.cross_validation`` ``ModelFrame.cross_validation``, ``crv``
``sklearn.cross_validation`` ``ModelFrame.cross_validation``, ``crv`` (deprecated)
``sklearn.datasets`` (not accessible from accessor)
``sklearn.decomposition`` ``ModelFrame.decomposition``
``sklearn.discriminant_analysis`` ``ModelFrame.discriminant_analysis``, ``da``
``sklearn.dummy`` ``ModelFrame.dummy``
``sklearn.ensemble`` ``ModelFrame.ensemble``
``sklearn.feature_extraction`` ``ModelFrame.feature_extraction``
``sklearn.feature_selection`` ``ModelFrame.feature_selection``
``sklearn.gaussian_process`` ``ModelFrame.gaussian_process``
``sklearn.grid_search`` ``ModelFrame.grid_search``
``sklearn.gaussian_process`` ``ModelFrame.gaussian_process``, ``gp``
``sklearn.grid_search`` ``ModelFrame.grid_search`` (deprecated)
``sklearn.isotonic`` ``ModelFrame.isotonic``
``sklearn.kernel_approximation`` ``ModelFrame.kernel_approximation``
``sklearn.lda`` ``ModelFrame.lda``
``sklearn.learning_curve`` ``ModelFrame.learning_curve``
``sklearn.kernel_ridge`` ``ModelFrame.kernel_ridge``
``sklearn.lda`` ``ModelFrame.lda`` (deprecated)
``sklearn.learning_curve`` ``ModelFrame.learning_curve`` (deprecated)
``sklearn.linear_model`` ``ModelFrame.linear_model``, ``lm``
``sklearn.manifold`` ``ModelFrame.manifold``
``sklearn.metrics`` ``ModelFrame.metrics``
``sklearn.mixture`` ``ModelFrame.mixture``
``sklearn.model_selection`` ``ModelFrame.model_selection``, ``ms``
``sklearn.multiclass`` ``ModelFrame.multiclass``
``sklearn.multioutput`` ``ModelFrame.multioutput``
``sklearn.naive_bayes`` ``ModelFrame.naive_bayes``
``sklearn.neighbors`` ``ModelFrame.neighbors``
``sklearn.neural_network`` ``ModelFrame.neural_network``
``sklearn.pipeline`` ``ModelFrame.pipeline``
``sklearn.preprocessing`` ``ModelFrame.preprocessing``, ``pp``
``sklearn.qda`` ``ModelFrame.qda``
``sklearn.qda`` ``ModelFrame.qda`` (deprecated)
``sklearn.semi_supervised`` ``ModelFrame.semi_supervised``
``sklearn.svm`` ``ModelFrame.svm``
``sklearn.tree`` ``ModelFrame.tree``
Expand Down Expand Up @@ -277,7 +282,7 @@ Cross Validation

.. code-block:: python
>>> train_df, test_df = df.cross_validation.train_test_split()
>>> train_df, test_df = df.model_selection.train_test_split()
>>> train_df
.target sepal length sepal width petal length petal width
124 2 6.7 3.3 5.7 2.1
Expand Down Expand Up @@ -312,30 +317,17 @@ Cross Validation
[38 rows x 5 columns]
Also, there are some iterator classes which return indexes for training sets and test sets. You can slice a ``ModelFrame`` using these indexes.

.. code-block:: python
>>> kf = df.cross_validation.KFold(n=150, n_folds=3)
>>> for train_index, test_index in kf:
... print('training set shape: ', df.iloc[train_index, :].shape,
... 'test set shape: ', df.iloc[test_index, :].shape)
('training set shape: ', (100, 5), 'test set shape: ', (50, 5))
('training set shape: ', (100, 5), 'test set shape: ', (50, 5))
('training set shape: ', (100, 5), 'test set shape: ', (50, 5))
For further simplification, ``ModelFrame.cross_validation.iterate`` can accept such iterators and returns ``ModelFrame`` corresponding to training and test data.
You can iterate over splitter classes via ``ModelFrame.model_selection.iterate``, which yields the ``ModelFrame`` instances corresponding to the training and test data of each split.

.. code-block:: python
>>> kf = df.cross_validation.KFold(n=150, n_folds=3)
>>> kf = df.model_selection.KFold(n_splits=3)
>>> for train_df, test_df in df.model_selection.iterate(kf):
... print('training set shape: ', train_df.shape,
... 'test set shape: ', test_df.shape)
('training set shape: ', (100, 5), 'test set shape: ', (50, 5))
('training set shape: ', (100, 5), 'test set shape: ', (50, 5))
('training set shape: ', (100, 5), 'test set shape: ', (50, 5))
training set shape: (100, 5) test set shape: (50, 5)
training set shape: (100, 5) test set shape: (50, 5)
training set shape: (100, 5) test set shape: (50, 5)
Grid Search
-----------
Expand All @@ -349,8 +341,8 @@ You can perform grid search using ``ModelFrame.fit``.
... {'kernel': ['linear'], 'C': [1, 10, 100]}]
>>> df = pdml.ModelFrame(datasets.load_digits())
>>> cv = df.grid_search.GridSearchCV(df.svm.SVC(C=1), tuned_parameters,
... cv=5, scoring='precision')
>>> cv = df.model_selection.GridSearchCV(df.svm.SVC(C=1), tuned_parameters,
... cv=5)
>>> df.fit(cv)
Expand All @@ -363,7 +355,7 @@ In addition, ``ModelFrame.grid_search`` has a ``describe`` function to organize

.. code-block:: python
>>> df.grid_search.describe(cv)
>>> df.model_selection.describe(cv)
mean std C gamma kernel
0 0.974108 0.013139 1 0.0010 rbf
1 0.951416 0.020010 1 0.0001 rbf
Expand Down
2 changes: 1 addition & 1 deletion doc/source/whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ v0.4.0
Enhancement
^^^^^^^^^^^

- Support scikit-learn v0.17.x.
- Support scikit-learn v0.17.x and v0.18.0.
- Support imbalanced-learn via ``.imbalance`` accessor.

Bug Fix
Expand Down
8 changes: 8 additions & 0 deletions pandas_ml/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,11 +575,15 @@ def _cross_decomposition(self):
@property
@Appender(_shared_docs['skaccessor'] % dict(module='cross_validation'))
def cross_validation(self):
    # Deprecated accessor. No docstring is written here on purpose: the
    # ``Appender`` decorator supplies the documentation text.
    # stacklevel=2 attributes the warning to the code that accessed the
    # property rather than to this getter.
    warnings.warn(
        '.cross_validation is deprecated. Use .ms or .model_selection',
        DeprecationWarning,
        stacklevel=2)
    return self._cross_validation

@property
@Appender(_shared_docs['skaccessor'] % dict(module='cross_validation'))
def crv(self):
    # Deprecated short alias; it returns the same ``_cross_validation``
    # attribute as the long-named accessor. No docstring is written here
    # because the ``Appender`` decorator supplies the documentation text.
    # stacklevel=2 attributes the warning to the code that accessed the
    # property rather than to this getter.
    warnings.warn(
        '.crv is deprecated. Use .ms or .model_selection',
        DeprecationWarning,
        stacklevel=2)
    return self._cross_validation

@cache_readonly
Expand Down Expand Up @@ -666,6 +670,8 @@ def _gaussian_process(self):
@property
@Appender(_shared_docs['skaccessor'] % dict(module='grid_search'))
def grid_search(self):
    # Deprecated accessor. No docstring is written here on purpose: the
    # ``Appender`` decorator supplies the documentation text.
    # stacklevel=2 attributes the warning to the code that accessed the
    # property rather than to this getter.
    warnings.warn(
        '.grid_search is deprecated. Use .ms or .model_selection',
        DeprecationWarning,
        stacklevel=2)
    return self._grid_search

@cache_readonly
Expand Down Expand Up @@ -722,6 +728,8 @@ def lda(self):
@property
@Appender(_shared_docs['skaccessor'] % dict(module='learning_curve'))
def learning_curve(self):
    # Deprecated accessor. No docstring is written here on purpose: the
    # ``Appender`` decorator supplies the documentation text.
    # stacklevel=2 attributes the warning to the code that accessed the
    # property rather than to this getter.
    warnings.warn(
        '.learning_curve is deprecated. Use .ms or .model_selection',
        DeprecationWarning,
        stacklevel=2)
    return self._learning_curve

@cache_readonly
Expand Down
2 changes: 1 addition & 1 deletion pandas_ml/skaccessors/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class CrossValidationMethods(_AccessorMethods):
"""
Accessor to ``sklearn.cross_validation``.
Deprecated. Accessor to ``sklearn.cross_validation``.
"""

_module_name = 'sklearn.cross_validation'
Expand Down
2 changes: 1 addition & 1 deletion pandas_ml/skaccessors/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class GridSearchMethods(_AccessorMethods):
"""
Accessor to ``sklearn.grid_search``.
Deprecated. Accessor to ``sklearn.grid_search``.
"""

_module_name = 'sklearn.grid_search'
Expand Down
2 changes: 1 addition & 1 deletion pandas_ml/skaccessors/learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class LearningCurveMethods(_AccessorMethods):
"""
Accessor to ``sklearn.learning_curve``.
Deprecated. Accessor to ``sklearn.learning_curve``.
"""

_module_name = 'sklearn.learning_curve'
Expand Down
14 changes: 11 additions & 3 deletions pandas_ml/skaccessors/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ def StratifiedShuffleSplit(self, *args, **kwargs):
- ``y``: ``ModelFrame.target``
"""
target = self._target
return self._module.StratifiedShuffleSplit(target.values, *args, **kwargs)
if _SKLEARN_ge_018:
return self._module.StratifiedShuffleSplit(*args, **kwargs)
else:
target = self._target
return self._module.StratifiedShuffleSplit(target.values, *args, **kwargs)

def iterate(self, cv, reset_index=False):
"""
Expand All @@ -45,7 +48,12 @@ def iterate(self, cv, reset_index=False):
msg = "{0} is not a subclass of BaseCrossValidator"
warnings.warn(msg.format(cv.__class__.__name__))

for train_index, test_index in cv.split(self._df.index):
if isinstance(cv, self._module.StratifiedShuffleSplit):
gen = cv.split(self._df.data.values, self._df.target.values)
else:
gen = cv.split(self._df.index)

for train_index, test_index in gen:
train_df = self._df.iloc[train_index, :]
test_df = self._df.iloc[test_index, :]
if reset_index:
Expand Down
13 changes: 8 additions & 5 deletions pandas_ml/skaccessors/test/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,22 @@ def test_check_cv(self):
self.assertIsInstance(result, ms.KFold)

def test_StratifiedShuffleSplit(self):
return

iris = datasets.load_iris()
df = pdml.ModelFrame(iris)
sf1 = df.model_selection.StratifiedShuffleSplit(random_state=self.random_state)
sf2 = ms.StratifiedShuffleSplit(iris.target, random_state=self.random_state)
sf2 = ms.StratifiedShuffleSplit(random_state=self.random_state)

# consume generator
ind1 = [x for x in sf1.split(df.data.values, df.target.values)]
ind2 = [x for x in sf2.split(iris.data, iris.target)]

self.assert_numpy_array_equal(ind1[0], ind1[0])
self.assert_numpy_array_equal(ind1[1], ind2[1])
for i1, i2 in zip(ind1, ind2):
self.assertIsInstance(i1, tuple)
self.assertEqual(len(i1), 2)
self.assertIsInstance(i2, tuple)
self.assertEqual(len(i2), 2)
self.assert_numpy_array_equal(i1[0], i1[0])
self.assert_numpy_array_equal(i1[1], i2[1])

sf1 = df.model_selection.StratifiedShuffleSplit(random_state=self.random_state)
with tm.assert_produces_warning(UserWarning):
Expand Down

0 comments on commit 63751a4

Please sign in to comment.