
[WIP] Detect precision issues in euclidean distance calculations #12142

Closed
wants to merge 3 commits

Conversation

Member

@rth rth commented Sep 24, 2018

Following the discussion in #9354, this is an attempt to warn the user when numerical precision issues may occur in euclidean distances computed with the quadratic expansion. This affects both 32-bit and 64-bit floats.

The overall idea is that we need to detect when the distance between two vectors is much less than the vector norms (by a factor of about 1e7 in 64-bit).

Here we take a simpler approach and instead only consider the global distribution of the compared vectors. If all considered vectors lie within a very thin shell, such that the shell thickness divided by the mean vector norm is below a given threshold, a warning is raised (cf. image below).

[Image: euclidean_distances_accuracy_illustration, showing clusters A, B and C of compared vectors relative to the origin]

  • If all vectors are within cluster A, a warning is raised, as it should be. For instance, an example of this would be [[100000001, 100000000]] and [[100000000, 100000001]], as reported by @kno10 in Numerical precision of euclidean_distances with float32 #9354 (comment).
  • If points belong to both clusters A and B (e.g. obtained by adding a minus sign to one of the above vectors), a warning will be raised when it shouldn't be (false positive).
  • If points belong to the very tight clusters A and C, no warning will be raised even though one should be.

Here the cost of checking for precision issues is only O(n_samples_a + n_samples_b), which is negligible compared to the O(n_samples_a * n_samples_b * n_features) cost of computing the pairwise distances.
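
To make the check concrete, here is a minimal standalone sketch of the heuristic (the helper name and the 1e-7 threshold are hypothetical illustrations, not the actual diff in this PR):

import warnings

import numpy as np

def warn_if_thin_shell(X, Y, rel_thickness_threshold=1e-7):
    # Only the norms of the compared vectors are needed, so the check
    # costs O(n_samples_a + n_samples_b).
    norms = np.concatenate([np.linalg.norm(X, axis=1),
                            np.linalg.norm(Y, axis=1)])
    mean_norm = norms.mean()
    if mean_norm == 0:
        return
    # If every vector lies in a thin spherical shell around the origin, the
    # expansion ||x||^2 - 2 x.y + ||y||^2 can cancel catastrophically.
    rel_thickness = (norms.max() - norms.min()) / mean_norm
    if rel_thickness < rel_thickness_threshold:
        warnings.warn(
            "euclidean_distances computed with the quadratic expansion may "
            "be inaccurate: all compared vectors lie within a shell of "
            "relative thickness %.2e." % rel_thickness
        )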

This is more of a proof of concept; more work would be needed to make it actually useful.

TODO

  • Check how this problem is addressed in other libraries. It seems to be a common issue, e.g. from the PyTorch discussion forum:

    This seems to be a well documented issue with this approach and existing libraries will use thresholding to default to the direct computation when needed.

  • Contact the scipy community about it.
  • Generalize the above approach. I think it might be possible to get an exhaustive list of problematic points fairly efficiently (possibly with one pass through the sorted norms ‖A‖₂, ‖B‖₂, or using e.g. BallTree.query_radius), then recompute those points accurately. The main idea is that comparing only the vector norms may indicate, at low computational cost, where precision could be problematic, even though this yields false positives (see the sketch below).
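
A minimal sketch of what that generalization could look like, assuming a deliberately naive per-pair fallback (hypothetical code, not part of this PR):

import numpy as np

def euclidean_distances_with_fallback(X, Y, rtol=1e-7):
    XX = (X ** 2).sum(axis=1)
    YY = (Y ** 2).sum(axis=1)
    # Fast quadratic-expansion path, clipping small negative values
    # before the square root (the np.maximum(0, ...) hack).
    D = np.sqrt(np.maximum(XX[:, None] - 2 * X @ Y.T + YY[None, :], 0))
    # Cheap necessary condition for trouble: if ||x - y|| is tiny relative to
    # ||x|| and ||y||, the two norms must themselves agree to within ~rtol.
    # This over-flags pairs (false positives) but should not miss real ones.
    nX, nY = np.sqrt(XX), np.sqrt(YY)
    suspect = (np.abs(nX[:, None] - nY[None, :])
               <= rtol * np.maximum(nX[:, None], nY[None, :]))
    # Recompute the flagged pairs with the direct formula sqrt(sum((x - y)**2)).
    for i, j in zip(*np.nonzero(suspect)):
        D[i, j] = np.sqrt(((X[i] - Y[j]) ** 2).sum())
    return D

The Python loop is only a placeholder; the point is that the norm comparison itself costs at most O(n_samples_a * n_samples_b), and only the flagged pairs pay the O(n_features) recomputation.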

kno10 and others added 2 commits September 23, 2018 11:13
Surprisingly bad precision, isn't it?

Note that the traditional computation sqrt(sum((x-y)**2)) gets the results exact.
warning_message += (
    "Consider standardizing features by removing the mean, "
    "or setting globally "
    "euclidean_distances_algorithm='exact' for slower but "
Member

should we not just back off to using the exact algorithm in this case?

Member Author

Yes, that would probably be the best outcome. But for this to actually be useful, I would like to explore more fine-grained detection of problematic points, instead of considering just the min/max of the distribution.

Contributor

Standardizing, centering, etc. will not help in all cases. Consider the 1-d data set -10001, -10000, 10000, 10001 with FP32; essentially the symmetric version of the example I provided before. The mean is zero, and scaling will not resolve the relative closeness of the pairs of values.
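
For illustration only, a minimal float32 reproduction of this point (not part of the original comment):

import numpy as np

# kno10's 1-d example: the data is already zero-mean, so centering changes nothing.
X = np.array([[-10001.0], [-10000.0], [10000.0], [10001.0]], dtype=np.float32)
sq_norms = (X ** 2).sum(axis=1)
D = np.sqrt(np.maximum(sq_norms[:, None] - 2 * X @ X.T + sq_norms[None, :], 0))
print(D)
# In float32 the expansion returns 0.0 for the (-10001, -10000) and
# (10000, 10001) pairs, while the direct formula
# np.sqrt(((X[3] - X[2]) ** 2).sum()) gives the correct distance 1.0.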

Member Author

@rth rth Oct 6, 2018

I agree, sample statistics are not the solution here. This was merely an attempt at prototyping. Theoretically though, since we already have the norms, we could use a comparison between the norms as a proxy to determine potentially problematic points, at a cost of O(n_samples_a * n_samples_b) compared to the overall O(n_samples_a * n_samples_b * n_features) cost of the full calculation, or less if we use some tree structure. Then we could recompute those points exactly, though that would have its own limitations...

@jnothman
Member

jnothman commented Sep 27, 2018 via email

@kno10
Contributor

kno10 commented Oct 6, 2018

Some additional good values for testing implementations include (note that these aim at a different thing: causing the distances to become negative due to rounding; they are derived from test values for verifying variance implementations, and the values are 1 ULP apart, while the numeric accuracy of the dot-based equation used here is supposedly sqrt(1 ULP)):

import numpy

fp32 = numpy.array([[2.8], [2.800001]], numpy.float32)
fp64 = numpy.array([[1.2], [1.2 - 2e-16]], numpy.float64)

These will cause havoc to a naive implementation:

print(numpy.sqrt((fp32 ** 2).sum(axis=1) - 2 * numpy.dot(fp32, fp32.transpose()) + (fp32 ** 2).sum(axis=1).reshape((-1, 1))))
print(numpy.sqrt((fp64 ** 2).sum(axis=1) - 2 * numpy.dot(fp64, fp64.transpose()) + (fp64 ** 2).sum(axis=1).reshape((-1, 1))))

Notice the NaNs:

cc.py:54: RuntimeWarning: invalid value encountered in sqrt
  print(numpy.sqrt((fp32 ** 2).sum(axis=1) - 2 * numpy.dot(fp32, fp32.transpose()) + (fp32 ** 2).sum(axis=1).reshape((-1, 1))))
[[ 0. nan]
 [nan  0.]]
cc.py:55: RuntimeWarning: invalid value encountered in sqrt
  print(numpy.sqrt((fp64 ** 2).sum(axis=1) - 2 * numpy.dot(fp64, fp64.transpose()) + (fp64 ** 2).sum(axis=1).reshape((-1, 1))))
[[ 0. nan]
 [nan  0.]]

The current released code should pass this test because of the np.maximum(0, ...) hack; but when working on improved versions, make sure not to break this.

@jnothman
Member

Happy to close?

@rth
Member Author

rth commented Feb 28, 2019

Agreed, closing following discussion in #9354 (comment)

@rth rth closed this Feb 28, 2019