Skip to content
Browse files

Backport some performance improvements from master.

Stripped down version of the real stuff.
  • Loading branch information...
1 parent 965ecb8 commit 39087109d235c562ae335aa3864f7e5ef2076a4f @fabianp fabianp committed Feb 23, 2011
Showing with 25 additions and 12 deletions.
  1. +25 −12 scikits/learn/neighbors.py
View
37 scikits/learn/neighbors.py
@@ -37,7 +37,7 @@ class Neighbors(BaseEstimator, ClassifierMixin):
>>> neigh.fit(samples, labels)
Neighbors(n_neighbors=3, window_size=1)
>>> print neigh.predict([[0,0,0]])
- [ 0.]
+ [0]
Notes
-----
@@ -54,9 +54,16 @@ def __init__(self, n_neighbors=5, window_size=1):
def fit(self, X, Y=()):
# we need Y to be an integer, because afterwards we'll use it as an index
self.Y = np.asanyarray(Y, dtype=np.int)
- self.ball_tree = BallTree(X, self.window_size)
+ X = np.asanyarray(X)
+
+ if X.shape[1] < 20:
+ self.ball_tree = BallTree(X, self.window_size)
+ else:
+ self.ball_tree = None
+ self._fit_X = X
return self
+
def kneighbors(self, data, n_neighbors=None):
"""Finds the K-neighbors of a point.
@@ -135,19 +142,25 @@ def predict(self, T, n_neighbors=None):
T = np.asanyarray(T)
if n_neighbors is None:
n_neighbors = self.n_neighbors
- return _predict_from_BallTree(self.ball_tree, self.Y, T, n_neighbors)
+ if self.ball_tree is None:
+ from .metrics.pairwise import euclidian_distances
+ dist = euclidian_distances(
+ T, self._fit_X)
+ neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+ pred_labels = self.Y[neigh_ind]
+ else:
+ neigh_ind = self.ball_tree.query(
+ T, self.n_neighbors, return_distance=False)
+ pred_labels = self.Y[neigh_ind]
-def _predict_from_BallTree(ball_tree, Y, test, n_neighbors):
- """Predict target from BallTree object containing the data points.
+ if n_neighbors == 1:
+ return pred_labels.flatten()
+
+ from scipy import stats
+ mode, _ = stats.mode(pred_labels, axis=1)
+ return mode.flatten().astype(np.int)
- This is a helper method, not meant to be used directly. It will
- not check that input is of the correct type.
- """
- Y_ = Y[ball_tree.query(test, k=n_neighbors, return_distance=False)]
- if n_neighbors == 1:
- return Y_
- return (stats.mode(Y_, axis=1)[0]).ravel()
###############################################################################
# Neighbors Barycenter class for regression problems

0 comments on commit 3908710

Please sign in to comment.
Something went wrong with that request. Please try again.