possible SciKit Learn version issue #26

Closed
AruniRC opened this issue Apr 6, 2017 · 3 comments

AruniRC commented Apr 6, 2017

When training the network from scratch using ./train/train_model.sh 0, the following error happens at the "Solving LtoAB" step of the training:

caffe_traininglayers.py", line 331, in encode_points_mtx_nd
    (dists,inds) = self.nbrs.kneighbors(pts_flt)

....

  File "sklearn/neighbors/binary_tree.pxi", line 1309, in sklearn.neighbors.ball_tree.BinaryTree.query (sklearn/neighbors/ball_tree.c:11514)
  File "sklearn/neighbors/binary_tree.pxi", line 587, in sklearn.neighbors.ball_tree.NeighborsHeap.__init__ (sklearn/neighbors/ball_tree.c:5582)
TypeError: 'float' object cannot be interpreted as an index

The data pts_flt is a float32 numpy ndarray. Could this be due to a version problem in sklearn itself (I am using 0.18.1)?
Please let me know which version of scikit-learn is used for this codebase, and I'll match it and try training again.

thanks,
Aruni
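One way to answer the version question above is to have each environment print its installed NumPy and scikit-learn versions. This is a minimal sketch that assumes Python 3.8+ for `importlib.metadata`. On older interpreters, `numpy.__version__` and `sklearn.__version__` report the same information. The function name `installed_versions` is hypothetical and is not part of the colorization code.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("numpy", "scikit-learn")):
    """Hypothetical helper: map each distribution name to its installed
    version string, or to None if the package is not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(name, ver if ver is not None else "(not installed)")
```

Comparing this output between the machine that trains successfully and the one that fails narrows the problem to a specific library upgrade.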

AruniRC commented Apr 7, 2017

Issue solved:

In colorization/resources/caffe_traininglayers.py, the original line was

self.nbrs = nn.NearestNeighbors(n_neighbors=NN, algorithm='ball_tree').fit(self.cc)

This passes n_neighbors as 10.0 rather than 10, which leads to the indexing error. I believe later versions of numpy and sklearn have become stricter and no longer accept integer-valued floats where an integer index or size is expected. n_neighbors=NN should read n_neighbors=self.NN.
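The sketch below shows the underlying behavior without the colorization code or scikit-learn. Current NumPy releases refuse a float in an array shape, even when that float is integer-valued. The traceback suggests that NeighborsHeap sizes an internal buffer by the number of neighbors, so passing 10.0 would fail the same way. That is an assumption, not something verified against the sklearn source. The names `n_pts` and `n_nbrs` are illustrative only.

```python
import numpy as np

n_pts = 3       # illustrative: number of query points
n_nbrs = 10.0   # integer-valued but a float, like NN in the original call

error_message = None
try:
    # A float in the shape tuple is rejected, even though 10.0 == 10.
    np.zeros((n_pts, n_nbrs))
except TypeError as err:
    error_message = str(err)
    print("TypeError:", error_message)

# Converting to a true int first allocates the buffer as intended.
buf = np.zeros((n_pts, int(n_nbrs)))
print(buf.shape)  # (3, 10)
```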

That line should be changed to:

self.nbrs = nn.NearestNeighbors(n_neighbors=self.NN, algorithm='ball_tree').fit(self.cc)
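The one-line fix works because self.NN presumably already holds an int, as the reported success suggests. If NN can also arrive as a float from a config file or command line, another option is to coerce it once, explicitly. This is a sketch under that assumption. `as_int_count` is a hypothetical helper, not part of the colorization code or of scikit-learn. It accepts 10.0 but rejects 10.5, so a bad setting is not silently truncated.

```python
import numbers


def as_int_count(value):
    """Hypothetical helper: return value as a plain int.

    Accepts ints and integer-valued floats such as 10.0; raises ValueError
    for non-integral values such as 10.5 (or inf/nan), and TypeError for
    bools, which are technically ints but almost certainly a mistake here.
    """
    if isinstance(value, bool):
        raise TypeError("expected a number, got a bool")
    if isinstance(value, numbers.Integral):
        return int(value)
    if isinstance(value, numbers.Real) and float(value).is_integer():
        return int(value)
    raise ValueError("expected an integer-valued count, got %r" % (value,))


NN = 10.0                  # the float that reached n_neighbors originally
print(as_int_count(NN))    # 10
```

In the constructor, the call could then read n_neighbors=as_int_count(NN). The conversion happens in one place, and a genuinely invalid value fails loudly.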

Hope this helps anyone else with this issue.

@AruniRC AruniRC closed this as completed Apr 7, 2017
@AruniRC AruniRC reopened this Apr 7, 2017
@crazyzsy

Thank you, I solved it.

@youyingyin

@AruniRC oh man, you saved me! Thank you so much!

4 participants