Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

ENH: use warning.catch_warnings

  • Loading branch information...
commit 0040a41d870f1785ac65219074727ac8ece1609a 1 parent f18a5f4
@GaelVaroquaux GaelVaroquaux authored
Showing with 9 additions and 15 deletions.
  1. +9 −15 sklearn/tests/test_cross_validation.py
View
24 sklearn/tests/test_cross_validation.py
@@ -164,28 +164,22 @@ def test_train_test_split_errors():
def test_shuffle_split_warnings():
# change warnings.warn to catch the message
- warn_queue = []
- warnings_warn = warnings.warn
- warnings.warn = lambda msg: warn_queue.append(msg)
-
expected_message = ("test_fraction is deprecated in 0.11 and scheduled "
"for removal in 0.12, use test_size instead",
"train_fraction is deprecated in 0.11 and scheduled "
"for removal in 0.12, use train_size instead")
- cval.ShuffleSplit(10, 3, test_fraction=0.1)
- cval.ShuffleSplit(10, 3, train_fraction=0.1)
- cval.train_test_split(range(3), test_fraction=0.1)
- cval.train_test_split(range(3), train_fraction=0.1)
+ with warnings.catch_warnings(record=True) as warn_queue:
+ cval.ShuffleSplit(10, 3, test_fraction=0.1)
+ cval.ShuffleSplit(10, 3, train_fraction=0.1)
+ cval.train_test_split(range(3), test_fraction=0.1)
+ cval.train_test_split(range(3), train_fraction=0.1)
assert_equal(len(warn_queue), 4)
- assert_equal(warn_queue[0], expected_message[0])
- assert_equal(warn_queue[1], expected_message[1])
- assert_equal(warn_queue[2], expected_message[0])
- assert_equal(warn_queue[3], expected_message[1])
-
- # restore default behavior
- warnings.warn = warnings_warn
+ assert_equal(warn_queue[0].message.message, expected_message[0])
+ assert_equal(warn_queue[1].message.message, expected_message[1])
+ assert_equal(warn_queue[2].message.message, expected_message[0])
+ assert_equal(warn_queue[3].message.message, expected_message[1])
def test_train_test_split():
Please sign in to comment.
Something went wrong with that request. Please try again.