Skip to content

Commit

Permalink
ENH: address discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
dengemann committed Nov 29, 2013
1 parent 163e659 commit 4c10c82
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 30 deletions.
13 changes: 9 additions & 4 deletions sklearn/feature_extraction/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from numpy.testing import assert_array_equal
from numpy.testing import assert_raises
from sklearn.utils.testing import (assert_in, assert_less, assert_greater,
assert_warns)
assert_warns_message)

from collections import defaultdict, Mapping
from functools import partial
Expand Down Expand Up @@ -180,8 +180,11 @@ def test_unicode_decode_error():
assert_raises(UnicodeDecodeError, ca, text_bytes)

# Check the old interface
ca = assert_warns(DeprecationWarning, CountVectorizer, analyzer='char',
ngram_range=(3, 6), charset='ascii').build_analyzer()
in_warning_message = 'charset'
ca = assert_warns_message(DeprecationWarning, in_warning_message,
CountVectorizer, analyzer='char',
ngram_range=(3, 6),
charset='ascii').build_analyzer()
assert_raises(UnicodeDecodeError, ca, text_bytes)


Expand Down Expand Up @@ -349,7 +352,9 @@ def test_tfidf_no_smoothing():
1. / np.array([0.])
numpy_provides_div0_warning = len(w) == 1

tfidf = assert_warns(RuntimeWarning,tr.fit_transform, X).toarray()
in_warning_message = 'divide by zero'
tfidf = assert_warns_message(RuntimeWarning, in_warning_message,
tr.fit_transform, X).toarray()
if not numpy_provides_div0_warning:
raise SkipTest("Numpy does not provide div 0 warnings.")

Expand Down
38 changes: 23 additions & 15 deletions sklearn/linear_model/tests/test_least_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.utils.testing import assert_less
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.testing import ignore_warnings, assert_warns_message
from sklearn import linear_model, datasets

diabetes = datasets.load_diabetes()
Expand Down Expand Up @@ -182,7 +182,10 @@ def test_singular_matrix():
# to give a good answer
X1 = np.array([[1, 1.], [1., 1.]])
y1 = np.array([1, 1])
alphas, active, coef_path = ignore_warnings(linear_model.lars_path)(X1, y1)
in_warn_message = 'Dropping a regressor'
f = assert_warns_message
alphas, active, coef_path = f(UserWarning, in_warn_message,
linear_model.lars_path, X1, y1)
assert_array_almost_equal(coef_path.T, [[0, 0], [1, 0]])


Expand Down Expand Up @@ -315,22 +318,27 @@ def test_lasso_lars_vs_lasso_cd_ill_conditioned():
y += sigma * rng.rand(*y.shape)
y = y.squeeze()

f = ignore_warnings
lars_alphas, _, lars_coef = f(linear_model.lars_path)(X, y,
method='lasso')

_, lasso_coef2, _ = f(linear_model.lasso_path)(X, y,
alphas=lars_alphas,
tol=1e-6,
fit_intercept=False)

lasso_coef = np.zeros((w.shape[0], len(lars_alphas)))
for i, model in enumerate(f(linear_model.lasso_path)(X, y,
f = assert_warns_message
def in_warn_message(msg):
return 'Early stopping' in msg or 'Dropping regressor' in msg
lars_alphas, _, lars_coef = f(UserWarning,
in_warn_message,
linear_model.lars_path, X, y, method='lasso')

with ignore_warnings():
_, lasso_coef2, _ = linear_model.lasso_path(X, y,
alphas=lars_alphas,
tol=1e-6,
fit_intercept=False)

lasso_coef = np.zeros((w.shape[0], len(lars_alphas)))
iter_models = enumerate(linear_model.lasso_path(X, y,
alphas=lars_alphas,
tol=1e-6,
return_models=True,
fit_intercept=False)):
lasso_coef[:, i] = model.coef_
fit_intercept=False))
for i, model in iter_models:
lasso_coef[:, i] = model.coef_

np.testing.assert_array_almost_equal(lars_coef, lasso_coef, decimal=1)
np.testing.assert_array_almost_equal(lars_coef, lasso_coef2, decimal=1)
Expand Down
16 changes: 8 additions & 8 deletions sklearn/metrics/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,42 +1826,42 @@ def test_prf_warnings():
for average in [None, 'weighted', 'macro']:
msg = ('Precision and F-score are ill-defined and '
'being set to 0.0 in labels with no predicted samples.')
my_assert(w, f, msg, [0, 1, 2], [1, 1, 2], average=average)
my_assert(w, msg, f, [0, 1, 2], [1, 1, 2], average=average)

msg = ('Recall and F-score are ill-defined and '
'being set to 0.0 in labels with no true samples.')
my_assert(w, f, msg, [1, 1, 2], [0, 1, 2], average=average)
my_assert(w, msg, f, [1, 1, 2], [0, 1, 2], average=average)

# average of per-sample scores
msg = ('Precision and F-score are ill-defined and '
'being set to 0.0 in samples with no predicted labels.')
my_assert(w, f, msg, np.array([[1, 0], [1, 0]]),
my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
np.array([[1, 0], [0, 0]]), average='samples')

msg = ('Recall and F-score are ill-defined and '
'being set to 0.0 in samples with no true labels.')
my_assert(w, f, msg, np.array([[1, 0], [0, 0]]), np.array([[1, 0], [1, 0]]),
my_assert(w, msg, f, np.array([[1, 0], [0, 0]]), np.array([[1, 0], [1, 0]]),
average='samples')

# single score: micro-average
msg = ('Precision and F-score are ill-defined and '
'being set to 0.0 due to no predicted samples.')
my_assert(w, f, msg, np.array([[1, 1], [1, 1]]),
my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
np.array([[0, 0], [0, 0]]), average='micro')

msg =('Recall and F-score are ill-defined and '
'being set to 0.0 due to no true samples.')
my_assert(w, f, msg, np.array([[0, 0], [0, 0]]),
my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
np.array([[1, 1], [1, 1]]), average='micro')

# single positive label
msg = ('Precision and F-score are ill-defined and '
'being set to 0.0 due to no predicted samples.')
my_assert(w, f, msg, [1, 1], [-1, -1], average='macro')
my_assert(w, msg, f, [1, 1], [-1, -1], average='macro')

msg = ('Recall and F-score are ill-defined and '
'being set to 0.0 due to no true samples.')
my_assert(w, f, msg, [-1, -1], [1, 1], average='macro')
my_assert(w, msg, f, [-1, -1], [1, 1], average='macro')


def test__check_clf_targets():
Expand Down
60 changes: 57 additions & 3 deletions sklearn/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,27 @@ def _assert_greater(a, b, msg=None):

# To remove when we support numpy 1.7
def assert_warns(warning_class, func, *args, **kw):
"""Test that a certain warning occurs.
Parameters
----------
warning_class : the warning class
The class to test for, e.g. UserWarning.
func : callable
Callable object to trigger warnings.
*args : the positional arguments to `func`.
**kw : the keyword arguments to `func`
Returns
-------
result : the return value of `func`
"""

# very important to avoid uncontrolled state propagation
clean_warning_registry()
with warnings.catch_warnings(record=True) as w:
Expand All @@ -102,8 +123,34 @@ def assert_warns(warning_class, func, *args, **kw):

return result

def assert_warns_message(warning_class, func, message, *args, **kw):

def assert_warns_message(warning_class, message, func, *args, **kw):
# very important to avoid uncontrolled state propagation
"""Test that a certain warning occurs and with a certain message.
Parameters
----------
warning_class : the warning class
The class to test for, e.g. UserWarning.
message : str | callable
The entire message or a substring to test for. If callable,
it takes a string as argument and will trigger an assertion error
if it returns `False`.
func : callable
Callable object to trigger warnings.
*args : the positional arguments to `func`.
**kw : the keyword arguments to `func`.
Returns
-------
result : the return value of `func`
"""
clean_warning_registry()
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
Expand All @@ -119,8 +166,15 @@ def assert_warns_message(warning_class, func, message, *args, **kw):
raise AssertionError("First warning for %s is not a "
"%s( is %s)"
% (func.__name__, warning_class, w[0]))
msg = str(w[0].message)
if msg != message:

# substring will match, the entire message with typo won't
msg = w[0].message # For Python 3 compatibility
msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
if callable(message): # add support for certain tests
check_in_message = message
else:
check_in_message = lambda msg : message in msg
if not check_in_message(msg):
raise AssertionError("The message received ('%s') for <%s> is "
"not the one you expected ('%s')"
% (msg, func.__name__, message
Expand Down

0 comments on commit 4c10c82

Please sign in to comment.