jakevdp
Member

@jakevdp jakevdp commented Feb 5, 2016

This check makes too many assumptions about the user-defined distance, and should probably be removed. (fixes #6287)

@agramfort
Member

LGTM

@yenchenlin
Contributor

Hello @jakevdp

Once this pilot check is deleted, users may receive ambiguous error messages.

For example, if the custom distance function returns a string instead of a float, the error is:

File "sklearn/neighbors/dist_metrics.pyx", line 1114, in sklearn.neighbors.dist_metrics.PyFuncDistance.dist (sklearn/neighbors/dist_metrics.c:11202)
TypeError: a float is required

What if we just changed this check slightly, as in #6289?

It will raise

ValueError("Customize distance function must return a float")

in the above case.

@jakevdp
Member Author

jakevdp commented Feb 5, 2016

I'd initially put the check where it is because it happens only once. A check such as #6289 will happen for every evaluation, and I'm afraid the impact on performance will be quite large (though admittedly, the user-defined function is not particularly performant as-is).

@yenchenlin
Contributor

@jakevdp Thanks for the feedback!

What do you think about changing
https://github.com/jakevdp/scikit-learn/blob/fix6287/sklearn/neighbors/dist_metrics.pyx#L1103
to

d = self.func(x1arr, x2arr, **self.kwargs)
try:
    return d
except TypeError:
    raise TypeError("Customize function must return a float")

Since the usual case (i.e. the user didn't do something silly) raises no exception, I think the try-except will cost very little.
What do you think?

@yenchenlin
Contributor

I've tested it with the following script:

import timeit

import numpy as np
from sklearn.neighbors import BallTree

n_samples = 10 ** 5
n_dim = 100
X = np.arange(n_samples * n_dim).reshape(n_samples, n_dim)

def correct_distance(x, y):
    return np.sum((x - y) ** 2)

def balltree():
    # Building the tree calls the custom metric many times.
    BallTree(X, metric=correct_distance)

timer = timeit.Timer(balltree)
print(min(timer.repeat(number=10)))

The version without try-except prints 80.8509781361 s,
while the version with try-except prints 80.3812861443 s,

which means that adding try-except here barely affects performance.
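
For anyone without a scikit-learn build at hand, the cost of a try-except that never triggers can be estimated with a standard-library-only sketch. It is an analogy, not the Cython code in this PR: `sq_dist`, `call_plain` and `call_guarded` are hypothetical names, the error message is illustrative, and the explicit `float(d)` stands in for the conversion that happens when the real method returns.

```python
import timeit

def sq_dist(x, y):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((a - b) ** 2 for a, b in zip(x, y))

def call_plain(func, x, y):
    # No guard: a bad return value fails with whatever message float() gives.
    return float(func(x, y))

def call_guarded(func, x, y):
    # Guarded: the try block only does extra work if an exception is raised.
    d = func(x, y)
    try:
        return float(d)
    except TypeError:
        raise TypeError("Custom distance function must return a float")

x, y = list(range(100)), list(range(100, 200))
t_plain = min(timeit.repeat(lambda: call_plain(sq_dist, x, y), number=2000, repeat=3))
t_guarded = min(timeit.repeat(lambda: call_guarded(sq_dist, x, y), number=2000, repeat=3))
print(f"plain:   {t_plain:.4f} s")
print(f"guarded: {t_guarded:.4f} s")
```

The absolute numbers depend on the machine; only the gap between the two lines matters here.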

@jakevdp
Member Author

jakevdp commented Feb 6, 2016

@yenchenlin1994 – great idea! I added that to the PR.

d = self.func(x1arr, x2arr, **self.kwargs)
try:
    return d
except TypeError:

Member

Maybe my knowledge of Python is lacking, but I am not sure that the exception will really be caught the way we expect.

If I understand things correctly, the exception we are expecting will be raised outside this function, so I suspect that the try/except will not trigger. It seems to me that whether it triggers depends on the frame in which the exception is raised.

Since this is not trivial, I think it would be great to have a test showing that the exception is indeed raised.

Contributor

Hello @GaelVaroquaux

I think the test could be something like the following script,
which shows that the exception is indeed raised when the custom distance function returns a non-float.

import numpy as np
from sklearn.neighbors import BallTree

def wrong_distance(x, y):
    return "1"

n_samples = 10 ** 3
n_dim = 10
X = np.arange(n_samples * n_dim).reshape(n_samples, n_dim)
# Expected to raise TypeError because the metric returns a string.
b = BallTree(X, metric=wrong_distance)

Member Author

Thanks @yenchenlin1994, that looks good. Would you like to submit a PR with that test? You'll have to put it in a function, probably in sklearn/neighbors/tests/test_dist_metrics.py. You can use numpy.testing.assert_raises_regexp to assert that the expected exception is being raised. One more comment: to make the test faster, you could use far fewer than 1000x10 points: even something like 5x2 would probably do.

Let me know if you need help putting that test together!

Member Author

oh, and @GaelVaroquaux – I had the same thought! I had to run a script like the one @yenchenlin1994 suggested to convince myself that it would catch the exception. I suspect it's caught because Cython generates code that does the type checking in the same block as the return statement.

Member

That's slightly crazy ^^
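
The behaviour discussed in this review thread can be mimicked in plain Python, as a rough analogy only. Following the suspicion voiced above, the sketch assumes that the conversion of the returned object to a C double is part of the return statement. That conversion is written out explicitly here as the hypothetical helper `as_c_double`, which makes it visible that the `TypeError` is raised inside the `try` block. The helper goes through `array("d", ...)` because plain `float()` would accept the string "1". The function names and the error message are illustrative.

```python
from array import array

def as_c_double(obj):
    """Hypothetical stand-in for the object-to-double conversion on return."""
    # array("d", ...) only accepts real numbers and raises TypeError otherwise.
    return array("d", [obj])[0]

def dist(func, x1, x2):
    d = func(x1, x2)
    try:
        # The conversion belongs to the return statement, so it runs inside the try.
        return as_c_double(d)
    except TypeError:
        raise TypeError("Custom distance function must return a float")

def wrong_distance(x, y):
    return "1"

def correct_distance(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))

print(dist(correct_distance, [0.0, 0.0], [3.0, 4.0]))
try:
    dist(wrong_distance, [0.0, 0.0], [3.0, 4.0])
except TypeError as exc:
    print("caught:", exc)
```

If the conversion happened in the caller instead, after `dist` had already returned, the `except` clause would never see the error, which is the concern raised at the top of this thread.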

@yenchenlin
Contributor

@jakevdp Thanks, I'll send a PR right after this PR gets merged.

@jakevdp
Member Author

jakevdp commented Feb 8, 2016

We'll probably want the tests before merging this PR. You could either write a PR to my branch, or write a PR to master now and I'll cherry-pick the commit.

@yenchenlin
Contributor

@jakevdp I've written a PR to your branch.
Please let me know if I did it wrong.
Thanks!

@jakevdp jakevdp changed the title from "BUG: remove checks from PyFunc distance metric (fixes #6287)" to "[MRG] BUG: remove checks from PyFunc distance metric (fixes #6287)" Feb 18, 2016
@jakevdp
Member Author

jakevdp commented Feb 18, 2016

I think this can be merged.

@amueller
Member

amueller commented Oct 7, 2016

@jakevdp can you please rebase?

@amueller
Member

amueller commented Oct 7, 2016

It would be kinda nice to add a regression test against the original issue, i.e. have a metric that fails on 10d data but works on 3d data and test it with 3d data?
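
A rough, standard-library-only sketch of the situation such a regression test would cover. `pilot_check` is a hypothetical reconstruction of the kind of one-off check this PR removes (calling the metric on made-up input of an assumed size); it is not copied from the scikit-learn source, and `metric_3d` is an illustrative name.

```python
import random

def metric_3d(x, y):
    """A valid user metric that is only defined for 3-dimensional points."""
    if len(x) != 3 or len(y) != 3:
        raise ValueError("metric_3d expects 3-dimensional points")
    return sum((a - b) ** 2 for a, b in zip(x, y)) ** 0.5

def pilot_check(func, n_dim=10):
    """Hypothetical one-off check: call func on random input of an assumed size."""
    x = [random.random() for _ in range(n_dim)]
    return float(func(x, x))

# The metric is fine on the data it was written for...
points = [(0.0, 0.0, 0.0), (1.0, 2.0, 2.0)]
print(metric_3d(points[0], points[1]))

# ...but a pilot check that assumes 10 dimensions rejects it anyway.
try:
    pilot_check(metric_3d)
except ValueError as exc:
    print("pilot check failed:", exc)
```

A real regression test would build the tree or distance metric on 3d data with such a metric and assert that no error is raised.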

@amueller amueller added this to the 0.19 milestone Oct 8, 2016
@jakevdp
Member Author

jakevdp commented Oct 10, 2016

Rebased. Let me add a couple tests...

@jakevdp
Member Author

jakevdp commented Oct 10, 2016

Regression test and @yenchenlin's test added. If all tests pass, I think this can be merged.

@jakevdp
Member Author

jakevdp commented Oct 10, 2016

Tests failed on old numpy versions... switched to using sklearn's backport of assert_raises_regex. We'll see if that does the trick

@jakevdp
Member Author

jakevdp commented Oct 10, 2016

Flake8 error due to my over-zealous copy-paste

@amueller
Member

LGTM if everything passes. @agramfort still looks good to you?

@amueller amueller changed the title from "[MRG] BUG: remove checks from PyFunc distance metric (fixes #6287)" to "[MRG + 1] BUG: remove checks from PyFunc distance metric (fixes #6287)" Oct 10, 2016
@jakevdp
Member Author

jakevdp commented Oct 10, 2016

Tests all pass. Good to merge?

@jnothman
Member

LGTM

@jnothman jnothman merged commit cbd3bca into scikit-learn:master Oct 11, 2016
@jnothman
Member

Added what's new in 0948ce9

jnothman added a commit that referenced this pull request Oct 11, 2016
amueller added a commit to amueller/scikit-learn that referenced this pull request Oct 14, 2016
amueller added a commit to amueller/scikit-learn that referenced this pull request Oct 14, 2016
# Conflicts:
#	doc/whats_new.rst
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017

Successfully merging this pull request may close these issues.

BallTree calls custom functions with non desired values
6 participants