Skip to content

Commit

Permalink
Merge 5158a02 into 0541dd0
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Jun 2, 2019
2 parents 0541dd0 + 5158a02 commit d842b95
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions mlxtend/feature_selection/tests/test_sequential_feature_selector.py
Expand Up @@ -38,22 +38,26 @@ def nan_roc_auc_score(y_true, y_score, average='macro', sample_weight=None):
average=average, sample_weight=sample_weight)


def dict_compare_utility(d1, d2, decimal=3):
assert d1.keys() == d2.keys(), "%s != %s" % (d1, d2)
for i in d1:
err_msg = ("d1[%s]['feature_idx']"
" != d2[%s]['feature_idx']" % (i, i))
assert d1[i]['feature_idx'] == d1[i]["feature_idx"], err_msg
assert_almost_equal(d1[i]['avg_score'],
d2[i]['avg_score'],
def dict_compare_utility(d_actual, d_desired, decimal=3):
assert d_actual.keys() == d_desired.keys(), "%s != %s" \
% (d_actual, d_desired)
for i in d_actual:
err_msg = ("d_actual[%s]['feature_idx']"
" != d_desired[%s]['feature_idx']" % (i, i))
assert d_actual[i]['feature_idx'] == d_desired[i]["feature_idx"],\
err_msg
assert_almost_equal(actual=d_actual[i]['avg_score'],
desired=d_desired[i]['avg_score'],
decimal=decimal,
err_msg=("d1[%s]['avg_score']"
" != d2[%s]['avg_score']" % (i, i)))
assert_almost_equal(d1[i]['cv_scores'],
d2[i]['cv_scores'],
err_msg=("d_actual[%s]['avg_score']"
" != d_desired[%s]['avg_score']"
% (i, i)))
assert_almost_equal(actual=d_actual[i]['cv_scores'],
desired=d_desired[i]['cv_scores'],
decimal=decimal,
err_msg=("d1[%s]['cv_scores']"
" != d2[%s]['cv_scores']" % (i, i)))
err_msg=("d_actual[%s]['cv_scores']"
" != d_desired[%s]['cv_scores']"
% (i, i)))


def test_run_default():
Expand Down Expand Up @@ -182,7 +186,7 @@ def test_knn_wo_cv():
3: {'avg_score': 0.97333333333333338,
'cv_scores': np.array([0.97333333]),
'feature_idx': (1, 2, 3)}}
dict_compare_utility(d1=expect, d2=sfs1.subsets_)
dict_compare_utility(d_actual=sfs1.subsets_, d_desired=expect)


def test_knn_cv3():
Expand Down Expand Up @@ -216,7 +220,7 @@ def test_knn_cv3():
0.94444444,
0.97222222]),
'feature_idx': (1, 2, 3)}}
dict_compare_utility(d1=expect, d2=sfs1.subsets_)
dict_compare_utility(d_actual=sfs1.subsets_, d_desired=expect)


def test_knn_cv3_groups():
Expand Down Expand Up @@ -244,7 +248,8 @@ def test_knn_cv3_groups():
3: {'cv_scores': np.array([0.97916667, 0.95918367, 0.94339623]),
'feature_idx': (1, 2, 3),
'avg_score': 0.9605821888503829}}
dict_compare_utility(d1=expect, d2=sfs1.subsets_, decimal=3)
dict_compare_utility(d_actual=sfs1.subsets_, d_desired=expect, decimal=3)


def test_knn_rbf_groupkfold():
nan_roc_auc_scorer = make_scorer(nan_roc_auc_score)
Expand Down Expand Up @@ -294,7 +299,7 @@ def test_knn_rbf_groupkfold():
'avg_score': 0.55,
'feature_idx': (1, 2, 3)}}

dict_compare_utility(d1=expect, d2=sfs1.subsets_, decimal=1)
dict_compare_utility(d_actual=sfs1.subsets_, d_desired=expect, decimal=1)


def test_knn_option_sfs():
Expand Down

0 comments on commit d842b95

Please sign in to comment.