[MRG+ ½] Fixing #7155 in stochastic_gradient.py #7159

Closed · wants to merge 25 commits

Commits (25)
5da25f5
Fix #7155
yl565 Aug 7, 2016
dcbe371
Fix test errors
yl565 Aug 7, 2016
3aa26ac
add_tests
yl565 Aug 8, 2016
db59c22
fix_py26_error
yl565 Aug 8, 2016
9aed2b5
add spaces
yl565 Aug 8, 2016
bc1a700
use decorator, do not set fit params at init
yl565 Aug 9, 2016
91b648e
fix coef_ initialization error raised in passive_aggressive.py
yl565 Aug 9, 2016
899d0d1
fix pep8 and redundent _CheckProba in predict_log_proba
yl565 Aug 9, 2016
b2c5865
Revert to master
yl565 Aug 18, 2016
e2fe921
Implement solution suggested by @jnothman
yl565 Aug 18, 2016
60f2ed1
Revert estimator_checks.py to master
yl565 Aug 18, 2016
80ac0e4
Fix pep8 and modify comments as @raghavrv suggested
yl565 Aug 18, 2016
d84f5f3
Simplified docstring of `hasattr_nested`and renamed it `_hasattr_nested`
yl565 Aug 19, 2016
9e3eacb
Moving tests and ignore other delegates if the attribute has already …
yl565 Aug 22, 2016
cb274d4
Changed from `tuple([delegate])` to `(delegate,)`
yl565 Aug 22, 2016
5594f05
Check first if the first item in delegate is an attribute of object
yl565 Aug 25, 2016
0b3634c
Add support for list. Use attrgetter instead of hasattr
yl565 Aug 25, 2016
be9ae6f
Modify docstring, simplify docstring example
yl565 Aug 26, 2016
e1770fa
Move doctests to unit tests, add more tests, rename `method_name`, fi…
yl565 Sep 3, 2016
265fc70
remove unused import
yl565 Sep 3, 2016
23b325b
Alternative solution with check_is_fitted
yl565 Sep 3, 2016
2524e3c
Merge branch 'fix_7155_alt_check_fitted' into fix_7155
yl565 Sep 3, 2016
7afe3b9
Use new function if_fitted_delegate_has_method
yl565 Sep 4, 2016
62de9c3
Fix pep8
yl565 Sep 4, 2016
fa6b974
revert back to 265fc70
yl565 Sep 9, 2016
12 changes: 6 additions & 6 deletions sklearn/grid_search.py
@@ -425,7 +425,7 @@ def score(self, X, y=None):
                          ChangedBehaviorWarning)
        return self.scorer_(self.best_estimator_, X, y)

    @if_delegate_has_method(delegate='estimator')
    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def predict(self, X):
        """Call predict on the estimator with the best found parameters.

@@ -441,7 +441,7 @@ def predict(self, X):
        """
        return self.best_estimator_.predict(X)

    @if_delegate_has_method(delegate='estimator')
    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def predict_proba(self, X):
        """Call predict_proba on the estimator with the best found parameters.

@@ -457,7 +457,7 @@ def predict_proba(self, X):
        """
        return self.best_estimator_.predict_proba(X)

    @if_delegate_has_method(delegate='estimator')
    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def predict_log_proba(self, X):
        """Call predict_log_proba on the estimator with the best found parameters.

@@ -473,7 +473,7 @@ def predict_log_proba(self, X):
        """
        return self.best_estimator_.predict_log_proba(X)

    @if_delegate_has_method(delegate='estimator')
    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def decision_function(self, X):
        """Call decision_function on the estimator with the best found parameters.

@@ -489,7 +489,7 @@ def decision_function(self, X):
        """
        return self.best_estimator_.decision_function(X)

    @if_delegate_has_method(delegate='estimator')
    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def transform(self, X):
        """Call transform on the estimator with the best found parameters.

@@ -505,7 +505,7 @@ def transform(self, X):
        """
        return self.best_estimator_.transform(X)

    @if_delegate_has_method(delegate='estimator')
    @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
    def inverse_transform(self, Xt):
        """Call inverse_transform on the estimator with the best found parameters.

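
The effect of this change can be sketched as follows (not part of the diff; it mirrors the test added at the bottom of this PR and assumes a scikit-learn build containing this branch). Before fit(), best_estimator_ does not exist, so the availability of predict_proba is judged from the unfitted estimator; after fit(), it is judged from best_estimator_, which reflects the grid-searched loss:

    import numpy as np
    from sklearn.linear_model import SGDClassifier
    from sklearn.grid_search import GridSearchCV  # sklearn.model_selection in later releases

    X = np.arange(20).reshape(5, -1)
    y = [0, 0, 1, 1, 1]

    # The base estimator uses loss='hinge' (no predict_proba), but the grid
    # only searches loss='log' (predict_proba available).
    clf = GridSearchCV(SGDClassifier(loss='hinge'), param_grid={'loss': ['log']})

    print(hasattr(clf, 'predict_proba'))  # False: falls back to the unfitted 'estimator'
    clf.fit(X, y)
    print(hasattr(clf, 'predict_proba'))  # True: 'best_estimator_' was fitted with loss='log'
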
63 changes: 35 additions & 28 deletions sklearn/utils/metaestimators.py
@@ -14,16 +14,22 @@ class _IffHasAttrDescriptor(object):
"""Implements a conditional property using the descriptor protocol.

Using this class to create a decorator will raise an ``AttributeError``
if the ``attribute_name`` is not present on the base object.
if none of the delegates (specified in ``delegate_names``) is an attribute
of the base object or none of the delegates has an attribute
``attribute_name``.

Review comment (Member): "none of the delegates has" -> "the first found delegate does not have"

This allows ducktyping of the decorated method based on ``attribute_name``.
This allows ducktyping of the decorated method based on
``delegate.attribute_name`` where ``delegate`` is the first item in
``delegate_names`` that is an attribute of the base object.

Review comment (Member): I find this sentence hard to understand. Maybe make it two sentences, and say "Here delegate is the first item in delegate_names for which hasattr(object, delegate) is True"? Not sure.

Reply (Contributor, Author): How about the following:

    Using this class to create a decorator for a class method object.attribute_name.
    An AttributeError will be raised if none of the delegates (specified in
    delegate_names) is an attribute of object (i.e. hasattr(object, delegate)
    is False) or none of the delegates has an attribute attribute_name (i.e.
    hasattr(delegate, attribute_name) is False).

    This allows ducktyping of the decorated method object.attribute_name based on
    object.delegate.attribute_name where delegate is the first item in
    delegate_names for which hasattr(object, delegate) is True.

Review comment (Member): I don't like "for a class method object.attribute_name":

  • attribute_name can be any attribute, not necessarily a method
  • a class method named "object.attribute_name" is weird or even contradictory.

Review comment (Member): Let's just split the second sentence as @amueller suggested.


See https://docs.python.org/3/howto/descriptor.html for an explanation of
descriptors.
"""
    def __init__(self, fn, attribute_name):
    def __init__(self, fn, delegate_names, attribute_name):
        self.fn = fn
        self.get_attribute = attrgetter(attribute_name)
        self.delegate_names = delegate_names
        self.attribute_name = attribute_name

Review comment (Member): This is now a tuple. The variable name should reflect this.

Follow-up comment (Member): I.e. should be attr_names or similar.


        # update the docstring of the descriptor
        update_wrapper(self, fn)

@@ -32,7 +38,17 @@ def __get__(self, obj, type=None):
        if obj is not None:
            # delegate only on instances, not the classes.
            # this is to allow access to the docstrings.
            self.get_attribute(obj)
            for delegate_name in self.delegate_names:
                try:
                    delegate = attrgetter(delegate_name)(obj)
                except AttributeError:
                    continue
                else:
                    getattr(delegate, self.attribute_name)
                    break
            else:
                attrgetter(self.delegate_names[-1])(obj)

        # lambda, but not partial, allows help() to work with update_wrapper
        out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
        # update the docstring of the returned function
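
For readers unfamiliar with the for/else construct used above, here is a standalone sketch of the same lookup logic outside the descriptor (the class names are made up for illustration). The loop probes each delegate name in order, breaks at the first delegate that exists and has the attribute, and the else branch re-runs the last lookup so that a meaningful AttributeError is raised when no delegate exists at all:

    from operator import attrgetter

    class _Sub(object):
        def predict(self):
            return 'prediction'

    class _Meta(object):
        # 'best_estimator_' would only appear after a hypothetical fit();
        # 'estimator' is always present.
        def __init__(self):
            self.estimator = _Sub()

    obj = _Meta()
    delegate_names = ('best_estimator_', 'estimator')
    attribute_name = 'predict'

    for delegate_name in delegate_names:
        try:
            delegate = attrgetter(delegate_name)(obj)   # skip delegates that do not exist
        except AttributeError:
            continue
        else:
            getattr(delegate, attribute_name)  # raises if this delegate lacks the attribute
            break
    else:
        # no delegate existed at all: redo the last lookup so AttributeError propagates
        attrgetter(delegate_names[-1])(obj)

    print('%r found via %r' % (attribute_name, delegate_name))  # 'predict' found via 'estimator'
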
@@ -46,27 +62,18 @@ def if_delegate_has_method(delegate):
This enables ducktyping by hasattr returning True according to the
sub-estimator.

>>> from sklearn.utils.metaestimators import if_delegate_has_method
>>>
>>>
>>> class MetaEst(object):
... def __init__(self, sub_est):
... self.sub_est = sub_est
...
... @if_delegate_has_method(delegate='sub_est')
... def predict(self, X):
... return self.sub_est.predict(X)
...
>>> class HasPredict(object):
... def predict(self, X):
... return X.sum(axis=1)
...
>>> class HasNoPredict(object):
... pass
...
>>> hasattr(MetaEst(HasPredict()), 'predict')
True
>>> hasattr(MetaEst(HasNoPredict()), 'predict')
False
Parameters
----------
delegate : string, list of strings or tuple of strings
Name of the sub-estimator that can be accessed as an attribute of the
base object. If a list or a tuple of names are provided, the first
sub-estimator that is an attribute of the base object will be used.

Review comment (@lesteve, Sep 6, 2016): I would leave the doctests as they were or add them as tests to sklearn/utils/tests/test_metaestimators.py. You could also add additional tests in the same file as you did at some point of this PR. Adding tests for if_fitted_delegate_has_method would be good too.

"""
    return lambda fn: _IffHasAttrDescriptor(fn, '%s.%s' % (delegate, fn.__name__))
    if isinstance(delegate, list):
        delegate = tuple(delegate)
    if not isinstance(delegate, tuple):

Review comment (Member): I think list should also be supported.

        delegate = (delegate,)

    return lambda fn: _IffHasAttrDescriptor(fn, delegate,
                                            attribute_name=fn.__name__)
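
Since the doctest above was removed, here is an informal usage sketch of the new tuple form, adapted from the tests added below (it assumes a scikit-learn checkout containing this branch; the delegate name 'best_sub_est_' is made up to mimic an attribute that only exists after fitting):

    from sklearn.utils.metaestimators import if_delegate_has_method

    class HasPredict(object):
        def predict(self, X):
            return X

    class HasNoPredict(object):
        pass

    class MetaEst(object):
        def __init__(self, sub_est):
            self.sub_est = sub_est

        # Only the first name in the tuple that is an attribute of the instance
        # is consulted; its 'predict' decides whether MetaEst.predict is exposed.
        @if_delegate_has_method(delegate=('best_sub_est_', 'sub_est'))
        def predict(self, X):
            return self.sub_est.predict(X)

    print(hasattr(MetaEst(HasPredict()), 'predict'))    # True: falls back to 'sub_est'
    print(hasattr(MetaEst(HasNoPredict()), 'predict'))  # False: 'sub_est' lacks predict

    est = MetaEst(HasPredict())
    est.best_sub_est_ = HasNoPredict()  # once present, 'best_sub_est_' takes priority
    print(hasattr(est, 'predict'))      # False: the first found delegate lacks predict
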
92 changes: 91 additions & 1 deletion sklearn/utils/tests/test_metaestimators.py
@@ -1,5 +1,11 @@
import warnings
import numpy as np
from nose.tools import assert_true, assert_false
from sklearn.utils.metaestimators import if_delegate_has_method
from nose.tools import assert_true
from sklearn.linear_model import SGDClassifier
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    from sklearn.grid_search import GridSearchCV


class Prefix(object):
@@ -24,3 +30,87 @@ def test_delegated_docstring():
                in str(MockMetaEstimator.func.__doc__))
    assert_true("This is a mock delegated function"
                in str(MockMetaEstimator().func.__doc__))


def test_stochastic_gradient_loss_param():

Review comment (@lesteve, Sep 6, 2016): This test doesn't really belong to test_metaestimators.py; maybe somewhere like sklearn/tests/test_grid_search.py would be better.

Review comment (Member): +1 for moving it to sklearn/model_selection/tests/test_search.py

    # Make sure the predict_proba works when loss is specified
    # as one of the parameters in the param_grid.
    param_grid = {
        'loss': ['log'],
    }
    X = np.arange(20).reshape(5, -1)
    y = [0, 0, 1, 1, 1]
    clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
                       param_grid=param_grid)

    # When the estimator is not fitted, `predict_proba` is not available as the
    # loss is 'hinge'.
    assert_false(hasattr(clf, "predict_proba"))
    clf.fit(X, y)
    clf.predict_proba(X)
    clf.predict_log_proba(X)

Review comment (Member): Just to be sure, please assert that when param_grid={'loss':['hinge']}, clf.predict_proba is not available.


    # Make sure `predict_proba` is not available when setting loss=['hinge']
    # in param_grid
    param_grid = {
        'loss': ['hinge'],
    }
    clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'),
                       param_grid=param_grid)
    assert_false(hasattr(clf, "predict_proba"))
    clf.fit(X, y)
    assert_false(hasattr(clf, "predict_proba"))


class MetaEst(object):
    """A mock meta estimator"""
    def __init__(self, sub_est, better_sub_est=None):
        self.sub_est = sub_est
        self.better_sub_est = better_sub_est

    @if_delegate_has_method(delegate='sub_est')
    def predict(self):
        pass


class MetaEstTestTuple(MetaEst):
    """A mock meta estimator to test passing a tuple of delegates"""

    @if_delegate_has_method(delegate=('sub_est', 'better_sub_est'))
    def predict(self):
        pass


class MetaEstTestList(MetaEst):
    """A mock meta estimator to test passing a list of delegates"""

    @if_delegate_has_method(delegate=['sub_est', 'better_sub_est'])
    def predict(self):
        pass


class HasPredict(object):
    """A mock sub-estimator with predict method"""

    def predict(self):
        pass


class HasNoPredict(object):
    """A mock sub-estimator with no predict method"""
    pass


def test_if_delegate_has_method():
    assert_true(hasattr(MetaEst(HasPredict()), 'predict'))
    assert_false(hasattr(MetaEst(HasNoPredict()), 'predict'))
    assert_false(
        hasattr(MetaEstTestTuple(HasNoPredict(), HasNoPredict()), 'predict'))
    assert_true(
        hasattr(MetaEstTestTuple(HasPredict(), HasNoPredict()), 'predict'))
    assert_false(
        hasattr(MetaEstTestTuple(HasNoPredict(), HasPredict()), 'predict'))
    assert_false(
        hasattr(MetaEstTestList(HasNoPredict(), HasPredict()), 'predict'))
    assert_true(
        hasattr(MetaEstTestList(HasPredict(), HasPredict()), 'predict'))