Skip to content

Commit

Permalink
Changed safe_asarray to check_arrary.
Browse files Browse the repository at this point in the history
Updated the insert function.
  • Loading branch information
maheshakya committed Jul 30, 2014
1 parent 37286a3 commit 06ce72a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
19 changes: 8 additions & 11 deletions sklearn/neighbors/lsh_forest.py
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import itertools
from ..base import BaseEstimator
from ..utils.validation import safe_asarray
from ..utils.validation import check_array
from ..utils import check_random_state

from ..random_projection import GaussianRandomProjection
Expand Down Expand Up @@ -314,7 +314,7 @@ def fit(self, X=None):
if X is None:
raise ValueError("X cannot be None")

self._input_array = safe_asarray(X)
self._input_array = check_array(X)
self._n_dim = self._input_array.shape[1]

self.max_label_length = 32
Expand Down Expand Up @@ -405,7 +405,7 @@ def kneighbors(self, X, n_neighbors=None, return_distance=False):
if n_neighbors is not None:
self.n_neighbors = n_neighbors

X = safe_asarray(X)
X = check_array(X)
x_dim = X.ndim

if x_dim == 1:
Expand Down Expand Up @@ -452,7 +452,7 @@ def radius_neighbors(self, X, radius=None, return_distance=False):
if radius is not None:
self.radius = radius

X = safe_asarray(X)
X = check_array(X)
x_dim = X.ndim

if x_dim == 1:
Expand All @@ -477,7 +477,8 @@ def radius_neighbors(self, X, radius=None, return_distance=False):

def insert(self, X):
"""
Inserts new data into the LSH Forest.
Inserts new data into the LSH Forest. Cost is proportional
to new total size, so additions should be batched.
Parameters
----------
Expand All @@ -488,20 +489,16 @@ def insert(self, X):
raise ValueError("estimator should be fitted before"
" inserting.")

X = safe_asarray(X)
X = check_array(X)

if X.ndim != 2:
raise ValueError("X should be a 2-D matrix")
if X.shape[1] != self._input_array.shape[1]:
raise ValueError("Number of features in X and"
" fitted array does not match.")
n_samples = X.shape[0]
input_array_size = self._input_array.shape[0]

for i in range(self.n_trees):
bin_X = []
for j in range(n_samples):
bin_X.append(self._convert_to_hash(X[j], i))
bin_X = [self._convert_to_hash(X[j], i) for j in range(n_samples)]
# gets the position to be added in the tree.
positions = self._trees[i].searchsorted(bin_X)
# adds the hashed value into the tree.
Expand Down
3 changes: 0 additions & 3 deletions sklearn/neighbors/tests/test_lsh_forest.py
Expand Up @@ -222,9 +222,6 @@ def test_insert():

lshf.fit(X)

# Insert 1D array
assert_raises(ValueError, lshf.insert,
np.random.randn(dim))
# Insert wrong dimension
assert_raises(ValueError, lshf.insert,
np.random.randn(samples_insert, dim-1))
Expand Down

0 comments on commit 06ce72a

Please sign in to comment.