Permalink
Browse files

an simple test case for KNeighborsClassifier.predict_proba()

  • Loading branch information...
1 parent 7bb3c6d commit 8b5142a58a734129ffa65937abce030c732cc604 @kernc kernc committed Jun 10, 2012
Showing with 23 additions and 0 deletions.
  1. +23 −0 sklearn/neighbors/tests/test_neighbors.py
View
23 sklearn/neighbors/tests/test_neighbors.py
@@ -189,6 +189,29 @@ def test_kneighbors_classifier(n_samples=40,
assert_array_equal(y_pred, y[:n_test_pts])
+def test_kneighbors_classifier_predict_proba():
+ """Test KNeighborsClassifier.predict_proba() method"""
+ X = np.array([[0,2,0],
+ [0,2,1],
+ [2,0,0],
+ [2,2,0],
+ [0,0,2],
+ [0,0,1]])
+ y = np.array([4, 4, 5, 5, 1, 1])
+ cls = neighbors.KNeighborsClassifier(n_neighbors=3, p=1) # cityblock dist
+ cls.fit(X, y)
+ y_prob = cls.predict_proba(X)
+ real_prob = np.array([[0, 2./3, 1./3],
+ [1./3, 2./3, 0],
+ [1./3, 0, 2./3],
+ [0, 1./3, 2./3],
+ [2./3, 1./3, 0],
+ [2./3, 1./3, 0]])
+ assert_array_equal(real_prob, y_prob)
+
+
+
+
def test_radius_neighbors_classifier(n_samples=40,
n_features=5,
n_test_pts=10,

6 comments on commit 8b5142a

@amueller
scikit-learn member

pep8?

@agramfort
scikit-learn member

shame ... feel free to fix it if not already done

@amueller
scikit-learn member

did it ;)

@kernc

argh, sorry. I forgot. :<
I don't get what's wrong with spaces on blank lines.

@agramfort
scikit-learn member
@mblondel
scikit-learn member

When 40 people work together to make each release, it's important to have a little bit of style consistency...

Please sign in to comment.