Skip to content

Commit

Permalink
Densify one column at a time to do slicing, revert changes to sparsef…
Browse files Browse the repository at this point in the history
…uncs_fast
  • Loading branch information
hamsal committed Aug 14, 2014
1 parent e82770e commit 72f5cdd
Showing 1 changed file with 3 additions and 19 deletions.
22 changes: 3 additions & 19 deletions sklearn/neighbors/classification.py
Expand Up @@ -13,7 +13,6 @@

from scipy import stats
from sklearn.utils.extmath import weighted_mode
from sklearn.utils.sparsefuncs_fast import csr_row_mode

from sklearn.neighbors.base import \
_check_weights, _get_weights, \
Expand Down Expand Up @@ -170,28 +169,13 @@ def predict(self, X):
indptr = array.array('i', [0])

for k, classes_k in enumerate(classes_):
# Using _y[neigh_ind, k] is not supported with scipy <0.13
# so we recreate fancy indexing using numpy functions
_y.sum_duplicates()
_y_data_k = _y.data[_y.indptr[k]:_y.indptr[k+1]]
_y_indices_k = _y.indices[_y.indptr[k]:_y.indptr[k+1]]

# Find the neigh_ind in _y.data using _y.indices as a guide
data_index = np.searchsorted(_y_indices_k, neigh_ind)
data_index[data_index == _y_data_k.shape[0]] = 0
if _y_data_k.size == 0:
neigh_lbls_k = np.zeros(shape=data_index.shape)
else:
neigh_lbls_k = _y_data_k[data_index]
# Replace incorrect nonzero elements with correct zeros
neigh_lbls_k[_y_indices_k[data_index] != neigh_ind] = 0
neigh_lbls_k = _y.getcol(k).toarray().ravel()[neigh_ind]

if weights is None:
mode = csr_row_mode(sp.csr_matrix(neigh_lbls_k))
mode, _ = stats.mode(neigh_lbls_k, axis=1)
mode = sp.csc_matrix(mode, dtype=np.intp)
else:
mode = mode = csr_row_mode(sp.csr_matrix(neigh_lbls_k),
weights)
mode, _ = weighted_mode(neigh_lbls_k, weights, axis=1)
mode = sp.csc_matrix(mode, dtype=np.intp)

data.extend(mode.data)
Expand Down

0 comments on commit 72f5cdd

Please sign in to comment.