BUG: spatial: improve dtype stability for all distance metrics #14909

Closed
wants to merge 9 commits into from


AtsushiSakai
Member

Reference issue

fix #9961

What does this implement/fix?

As reported in #9961, the distance metric functions can currently return different results for the same values depending on the input dtype (integer vs. float).
To fix this, I moved the type conversion code from sqeuclidean into _validate_vector so that it applies to every distance metric.
I also added a test for it.

@AtsushiSakai AtsushiSakai added defect A clear bug or issue that prevents SciPy from being installed or used as expected scipy.spatial and removed scipy.spatial labels Oct 23, 2021
# or the dtype is not string.
if not (hasattr(u, "dtype") and
        (np.issubdtype(u.dtype, np.inexact)
         or (u.dtype.kind == 'S'))):
Contributor

Do we want to use, e.g., an np.character subdtype check here, in case another flexible dtype such as Unicode is used?
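A minimal sketch of the difference being asked about, using only NumPy. The variable names are illustrative and not taken from the PR. It compares the `dtype.kind == 'S'` test from the diff with a check against the abstract `np.character` type, which covers both byte strings and Unicode strings.

```python
import numpy as np

byte_arr = np.array([b"a", b"b"])  # byte strings, dtype kind 'S'
text_arr = np.array(["a", "b"])    # Unicode strings, dtype kind 'U'

# The check in the diff only recognizes byte-string arrays.
kind_check = [arr.dtype.kind == "S" for arr in (byte_arr, text_arr)]

# np.character is the abstract parent of both np.bytes_ and np.str_.
character_check = [
    np.issubdtype(arr.dtype, np.character) for arr in (byte_arr, text_arr)
]

print(kind_check)       # [True, False]
print(character_check)  # [True, True]
```

With the kind-based test, a Unicode array would fall through to the float conversion branch, which is the case the question raises.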

# for stability only when the dtype is not a subtype of inexact
# or the dtype is not string.
if not (hasattr(u, "dtype") and
        (np.issubdtype(u.dtype, np.inexact)
Contributor

So we always convert array-like inputs (even a list of "characters") as long as they are not a NumPy array proper?

For proper NumPy arrays, we do coerce to a floating point type when it makes sense, and that includes additional checking for things like characters that can't easily be converted to floats (which we don't check for with other array-like).

I don't know if it matters too much--perhaps attempting an asarray() temporary conversion for the array-like values just for this check might be overkill anyway.

Member Author

@AtsushiSakai AtsushiSakai Oct 24, 2021

Thank you. Right, distance.hamming might receive a list of strings. I think only a NumPy array should be converted to a floating-point array. I fixed it and added some tests for it. PTAL.
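To illustrate why string input matters here, the following NumPy-only sketch computes an unweighted Hamming-style mismatch fraction by hand and then shows that forcing the same list to float fails. It does not call SciPy, and the variable names are assumptions made for the example.

```python
import numpy as np

u = ["a", "b", "c", "d"]
v = ["a", "x", "c", "y"]

# Unweighted Hamming distance: the fraction of positions that differ.
# Elementwise comparison works for string arrays without any float cast.
mismatch_fraction = np.mean(np.asarray(u) != np.asarray(v))

# Forcing the same input to a float array raises ValueError,
# so a blanket float conversion would break this kind of input.
try:
    np.asarray(u, dtype=np.float64)
except ValueError:
    float_conversion_failed = True
else:
    float_conversion_failed = False

print(mismatch_fraction)        # 0.5
print(float_conversion_failed)  # True
```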

elif metric == seuclidean:
    assert_almost_equal(metric(ai, bi, vo), metric(af, bf, vo))
else:
    assert_almost_equal(metric(ai, bi), metric(af, bf))
Contributor

Here and above, I suppose we should use assert_allclose, which the NumPy docs recommend over assert_almost_equal for new code: https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_almost_equal.html
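A small sketch of one practical difference between the two assertions, with made-up values. With default settings, assert_almost_equal uses an absolute tolerance tied to 7 decimals, while assert_allclose uses a relative tolerance (rtol=1e-7, atol=0), so only the latter notices a factor-of-two error between very small numbers.

```python
from numpy.testing import assert_allclose, assert_almost_equal

expected = 1.0e-10
actual = 2.0e-10  # off by a factor of two, but tiny in absolute terms

# Passes: the absolute difference (1e-10) is far below 1.5e-7.
assert_almost_equal(actual, expected)

# Fails: the relative difference is 100%, far above rtol=1e-7.
try:
    assert_allclose(actual, expected)
except AssertionError:
    relative_check_failed = True
else:
    relative_check_failed = False

print(relative_check_failed)  # True
```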

@@ -298,6 +298,15 @@ def _validate_seuclidean_kwargs(X, m, n, **kwargs):


def _validate_vector(u, dtype=None):

    if dtype is None and type(u) is np.ndarray:
Contributor

Doesn't this change the behavior for sqeuclidean() from what it was before? I'm not sure if we need to worry about it too much, but I'll just outline this below with an example.

This would mean two changes I think:

  • we would no longer coerce subclasses of ndarray to floats, because of the new type(u) is np.ndarray check
  • we would no longer coerce array-likes to float64 arrays, which it looks like we did before because of the short-circuit on hasattr(u, "dtype")
import numpy as np
from scipy.spatial.distance import sqeuclidean

np.random.seed(31415)

a = np.random.randint(100, size=100, dtype='uint8')
b = np.random.randint(100, size=100, dtype='uint8')

class C(np.ndarray):
    pass

c_arr_a = a.view(C)
c_arr_b = b.view(C)
print(sqeuclidean(c_arr_a, c_arr_b))
print(sqeuclidean(c_arr_a.astype(float), c_arr_b.astype(float)))
print(sqeuclidean(c_arr_a.tolist(), c_arr_b.tolist()))
print(type(sqeuclidean(c_arr_a.tolist(), c_arr_b.tolist())))

# issue9961 feature branch:
# 248
# 166136.0
# 166136
# <class 'numpy.int64'>
# master:
# 166136.0
# 166136.0
# 166136.0
# <class 'numpy.float64'>
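The 248 printed on the feature branch is consistent with uint8 wraparound: 166136 modulo 256 is 248. The sketch below reproduces the effect with small hand-picked arrays, assuming the squared distance is computed as a subtraction followed by a dot product, both carried out in uint8. The values are illustrative and are not the arrays from the example above.

```python
import numpy as np

a = np.array([10, 200, 30], dtype=np.uint8)
b = np.array([20, 100, 250], dtype=np.uint8)

# Subtraction and dot product both stay in uint8, so every
# intermediate value wraps modulo 256.
d = a - b
wrapped = np.dot(d, d)

# The same computation after converting to float64 is exact.
df = a.astype(np.float64) - b.astype(np.float64)
exact = np.dot(df, df)

print(wrapped, wrapped.dtype)  # 132 uint8
print(exact)                   # 58500.0
print(int(exact) % 256)        # 132
print(166136 % 256)            # 248
```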

Member Author

Thank you again. I missed the a.view(C) case and the type mismatch case, and I have extended the tests to cover them.
The only solution I could come up with is to convert the input u to an ndarray beforehand. Do you have a better idea?

# or the dtype is not string.
u = np.asarray(u)
if not (np.issubdtype(u.dtype, np.inexact)
        or (np.issubdtype(u.dtype, np.character))):
Member

@peterbell10 peterbell10 Oct 26, 2021

I don't think this is a good idea because it affects metrics expecting boolean input, where floating-point input isn't necessary or helpful. Instead, how about creating a new function, _validate_real_vector, and calling it only in the functions that require real inputs?

@AtsushiSakai
Member Author

jensenshannon currently does not use _validate_vector. I'm not sure whether this is intentional.

p = np.asarray(p)
q = np.asarray(q)

@peterbell10
Member

It can't use _validate_vector because it supports an axis argument and higher-dimensional input.

@AtsushiSakai
Member Author

I rebased to fix CI failures.

Labels
defect A clear bug or issue that prevents SciPy from being installed or used as expected scipy.spatial
Development

Successfully merging this pull request may close these issues.

unexpected behavior from scipy.spatial.distances with different dtypes (uint8)
3 participants