[MRG + 1] BUG: remove checks from PyFunc distance metric (fixes #6287) #6288
LGTM
Hello @jakevdp, once this pilot check is deleted, users may receive ambiguous error messages. For example,
What if we just change this check a little, such as
It will raise
in the above case.
I initially put the check where it is because it runs only once. A check such as #6289 would run on every evaluation, and I'm afraid the performance impact would be quite large (though admittedly, the user-defined function is not particularly performant as-is).
@jakevdp Thanks for your useful opinion! What do you think about changing it to:
    d = self.func(x1arr, x2arr, **self.kwargs)
    try:
        # Cython converts d to the declared C return type here, which
        # raises TypeError if d is not a number.
        return d
    except TypeError:
        raise TypeError("Customize function must return a float")
Since the usual case (i.e. if the user didn't do something silly) raises no exception, I think
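The proposal relies on Cython converting `d` to a C double at `return d`. Pure Python has no such implicit conversion, so a pure-Python sketch has to make it explicit. In the minimal sketch below, `call_metric` is a hypothetical name and `ctypes.c_double(d)` is an assumed stand-in for Cython's conversion; like that conversion, it rejects a `str` with a `TypeError`.

```python
import ctypes

import numpy as np


def call_metric(func, x1, x2, **kwargs):
    """Hypothetical pure-Python analogue of the proposed PyFunc wrapper."""
    d = func(x1, x2, **kwargs)
    try:
        # Stand-in for Cython's conversion of `d` to a C double on return;
        # it raises TypeError for non-numeric objects such as a str.
        return ctypes.c_double(d).value
    except TypeError as exc:
        raise TypeError("Custom distance function must return a float") from exc


if __name__ == "__main__":
    x, y = np.zeros(3), np.ones(3)
    print(call_metric(lambda a, b: np.sum((a - b) ** 2), x, y))  # 3.0
```

Because the conversion sits inside the `try`, a well-behaved metric pays only for the conversion, and the custom message appears only when the metric returns something that cannot become a double.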
I've tested it with the following script:
    import timeit

    import numpy as np
    from sklearn.neighbors import BallTree

    n_samples = 10 ** 5
    n_dim = 100
    X = np.asarray(range(n_samples * n_dim)).reshape(n_samples, n_dim)

    def correct_distance(x, y):
        return np.sum((x - y) ** 2)

    def balltree():
        # Time only the tree construction with a user-defined metric.
        b = BallTree(X, metric=correct_distance)

    timer = timeit.Timer(balltree)
    print(min(timer.repeat(number=10)))
The no which means adding
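A smaller, sklearn-free way to compare the options discussed so far is a pure-Python micro-benchmark. This is only a sketch under stated assumptions: it measures Python-level overhead, not the Cython code in the PR; `float()` stands in for the conversion (unlike Cython's conversion it accepts numeric strings, so only the no-exception path is timed); and `check_every_call` is a hypothetical per-evaluation check, not the code from #6289. Timings depend on the machine, so no expected output is given.

```python
import functools
import timeit

import numpy as np


def metric(x, y):
    # A well-behaved user metric (squared Euclidean distance).
    return np.sum((x - y) ** 2)


def convert_plain(x, y):
    # Baseline: evaluate the metric and convert, with no error handling.
    return float(metric(x, y))


def convert_in_try(x, y):
    # Same work wrapped in try/except; on the happy path the handler never runs.
    d = metric(x, y)
    try:
        return float(d)
    except TypeError:
        raise TypeError("metric must return a float")


def check_every_call(x, y):
    # Hypothetical explicit type test performed on every evaluation.
    d = metric(x, y)
    if not isinstance(d, (float, np.floating)):
        raise TypeError("metric must return a float")
    return float(d)


def bench(number=20000, repeat=3, n_dim=100):
    """Return the best-of-`repeat` time in seconds for each variant."""
    x, y = np.zeros(n_dim), np.ones(n_dim)
    timings = {}
    for f in (convert_plain, convert_in_try, check_every_call):
        call = functools.partial(f, x, y)
        timings[f.__name__] = min(timeit.repeat(call, number=number, repeat=repeat))
    return timings


if __name__ == "__main__":
    for name, seconds in bench().items():
        print(f"{name:>18}: {seconds:.4f} s")
```

On CPython 3.11 and later, entering a `try` block costs essentially nothing when no exception is raised, so the first two timings should be close there.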
@yenchenlin1994 – great idea! I added that to the PR.
    d = self.func(x1arr, x2arr, **self.kwargs)
    try:
        return d
    except TypeError:
Maybe my knowledge of Python is lacking, but I am not sure the exception will really be caught the way I think.
Indeed, if I understand things correctly, the exception we are expecting will be raised outside this function, so I suspect the try/except will not trigger. It seems to me that whether it triggers depends on the frame in which the exception is raised.
I think it would be great to have a test showing that the exception is indeed raised, given that this is not obvious.
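The concern holds in plain Python: `return d` hands the object back unchanged, so any conversion error surfaces in whichever frame consumes the value. A minimal sketch with hypothetical names (`naive_wrapper`, `caller`), again assuming `ctypes.c_double` as a stand-in for a conversion to a C double:

```python
import ctypes


def naive_wrapper(func, x, y):
    d = func(x, y)
    try:
        # In pure Python, `return d` performs no conversion, so nothing in
        # this block can raise a TypeError for a non-float `d`.
        return d
    except TypeError:
        raise TypeError("custom distance function must return a float")


def caller(func, x, y):
    # The conversion happens here, in the caller's frame, outside the
    # wrapper's try/except, so the custom message is never produced.
    return ctypes.c_double(naive_wrapper(func, x, y)).value


if __name__ == "__main__":
    try:
        caller(lambda a, b: "1", [0.0], [1.0])
    except TypeError as exc:
        print(exc)  # ctypes' generic message, not the custom one
```

Whether the PR's `except` fires therefore hinges on where Cython places the conversion, which is what the suggested test checks.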
Hello @GaelVaroquaux
I think the test could be something like the following script;
it shows that this exception is indeed raised when the custom distance function returns a non-float:
    import numpy as np
    from sklearn.neighbors import BallTree

    def wrong_distance(x, y):
        # Deliberately returns a str instead of a float.
        return "1"

    n_samples = 10 ** 3
    n_dim = 10
    X = np.asarray(range(n_samples * n_dim)).reshape(n_samples, n_dim)
    # Building the tree evaluates the metric, which should raise TypeError.
    b = BallTree(X, metric=wrong_distance)
Thanks @yenchenlin1994, that looks good. Would you like to submit a PR with that test? You'll have to put it in a function, probably in sklearn/neighbors/tests/test_dist_metrics.py. You can use numpy.testing.assert_raises_regex to assert that the expected exception is being raised. One more comment: to make the test faster, you could use far fewer than 1000x10 points: even something like 5x2 would probably do it.
Let me know if you need help putting that test together!
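Following that suggestion, the check can be wrapped in a test function with a tiny input and a regex assertion. To keep the sketch self-contained it does not import scikit-learn: `pairwise_pyfunc` is a hypothetical stand-in for the tree construction that evaluates the user metric, with `ctypes.c_double` assumed to emulate the Cython conversion. A real test would call `BallTree(X, metric=wrong_distance)` instead.

```python
import ctypes

import numpy as np
from numpy.testing import assert_raises_regex


def pairwise_pyfunc(X, func):
    """Hypothetical stand-in for a tree build that calls a user metric."""
    n = X.shape[0]
    D = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            d = func(X[i], X[j])
            try:
                # Emulated conversion to a C double; rejects e.g. a str.
                D[i, j] = ctypes.c_double(d).value
            except TypeError as exc:
                raise TypeError("Custom distance function must return a float") from exc
    return D


def test_bad_pyfunc_metric():
    def wrong_distance(x, y):
        return "1"

    # 5 x 2 points are enough to trigger the error; 1000 x 10 is unnecessary.
    X = np.ones((5, 2))
    assert_raises_regex(TypeError, "must return a float",
                        pairwise_pyfunc, X, wrong_distance)
```

Note that the emulated conversion matters here: assigning `"1"` directly into a float array would silently parse it as `1.0`.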
oh, and @GaelVaroquaux – I had the same thought! I had to run a script like the one @yenchenlin1994 suggested to convince myself that it would catch the exception. I suspect it's caught because the function returns a C double, so Cython generates the type-conversion code inside the same try block as the return statement.
That's slightly crazy ^^
@jakevdp Thanks, I'll send a PR right after this PR gets merged.
We'll probably want the tests before merging this PR. You could either write a PR to my branch, or write a PR to master now and I'll cherry-pick the commit.
@jakevdp I've written a PR to your branch.
I think this can be merged.
@jakevdp can you please rebase?
It would be kinda nice to add a regression test against the original issue, i.e. have a metric that fails on 10d data but works on 3d data, and test it with 3d data?
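The regression test suggested here targets checks that probe the metric on data of a fixed shape. A minimal sketch, assuming (as the 10d/3d wording implies) that the removed check evaluated the metric on 10-dimensional vectors; `metric_3d`, `pilot_check`, and `all_pairs` are hypothetical names:

```python
import numpy as np


def metric_3d(x, y):
    # Only valid for 3-dimensional points: unpacking fails for other sizes.
    dx, dy, dz = x - y
    return float(np.sqrt(dx * dx + dy * dy + dz * dz))


def pilot_check(func, probe_dim=10):
    """Hypothetical up-front check that probes the metric on fixed-size data."""
    probe = np.random.random(probe_dim)
    # Raises ValueError for metric_3d even if the user's data is 3-d.
    return func(probe, probe)


def all_pairs(X, func):
    # Evaluate the metric only on the user's actual data.
    return np.array([[func(a, b) for b in X] for a in X])
```

With 3-d data, `all_pairs` succeeds while `pilot_check` rejects the same metric, which is the failure mode the regression test should guard against.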
Rebased. Let me add a couple of tests...
Regression test and @yenchenlin's test added. If all tests pass, I think this can be merged.
Tests failed on old numpy versions... switched to using sklearn's backport of
Flake8 error due to my over-zealous copy-paste
LGTM if everything passes. @agramfort still looks good to you?
Tests all pass. Good to merge?
LGTM
Added what's new in 0948ce9
…t-learn#6287) (scikit-learn#6288) # Conflicts: # sklearn/neighbors/tests/test_dist_metrics.py
# Conflicts: # doc/whats_new.rst
This check makes too many assumptions about the user-defined distance, and should probably be removed. (fixes #6287)