Skip to content

Commit

Permalink
ENH: use warning.catch_warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
GaelVaroquaux committed Apr 22, 2012
1 parent f18a5f4 commit 0040a41
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions sklearn/tests/test_cross_validation.py
Expand Up @@ -164,28 +164,22 @@ def test_train_test_split_errors():

def test_shuffle_split_warnings():
# change warnings.warn to catch the message
warn_queue = []
warnings_warn = warnings.warn
warnings.warn = lambda msg: warn_queue.append(msg)

expected_message = ("test_fraction is deprecated in 0.11 and scheduled "
"for removal in 0.12, use test_size instead",
"train_fraction is deprecated in 0.11 and scheduled "
"for removal in 0.12, use train_size instead")

cval.ShuffleSplit(10, 3, test_fraction=0.1)
cval.ShuffleSplit(10, 3, train_fraction=0.1)
cval.train_test_split(range(3), test_fraction=0.1)
cval.train_test_split(range(3), train_fraction=0.1)
with warnings.catch_warnings(record=True) as warn_queue:
cval.ShuffleSplit(10, 3, test_fraction=0.1)
cval.ShuffleSplit(10, 3, train_fraction=0.1)
cval.train_test_split(range(3), test_fraction=0.1)
cval.train_test_split(range(3), train_fraction=0.1)

assert_equal(len(warn_queue), 4)
assert_equal(warn_queue[0], expected_message[0])
assert_equal(warn_queue[1], expected_message[1])
assert_equal(warn_queue[2], expected_message[0])
assert_equal(warn_queue[3], expected_message[1])

# restore default behavior
warnings.warn = warnings_warn
assert_equal(warn_queue[0].message.message, expected_message[0])
assert_equal(warn_queue[1].message.message, expected_message[1])
assert_equal(warn_queue[2].message.message, expected_message[0])
assert_equal(warn_queue[3].message.message, expected_message[1])


def test_train_test_split():
Expand Down

0 comments on commit 0040a41

Please sign in to comment.