Skip to content

Commit

Permalink
FIX Make DistanceMetrics support readonly buffers attributes (#21694)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjerphan committed Nov 19, 2021
1 parent 432ae47 commit 33c39e5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
9 changes: 9 additions & 0 deletions doc/whats_new/v1.0.rst
Expand Up @@ -28,6 +28,15 @@ Changelog
and :class:`decomposition.MiniBatchSparsePCA` to be convex and match the referenced
article. :pr:`19210` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.metrics`
......................

- |Fix| All :class:`sklearn.metrics.DistanceMetric` subclasses now correctly support
read-only buffer attributes.
This fixes a regression introduced in 1.0.0 with respect to 0.24.2.
:pr:`21694` by :user:`Julien Jerphanion <jjerphan>`.


:mod:`sklearn.preprocessing`
............................

Expand Down
12 changes: 6 additions & 6 deletions sklearn/metrics/_dist_metrics.pyx
Expand Up @@ -29,7 +29,7 @@ cdef DTYPE_t INF = np.inf

from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE
from ..utils._typedefs import DTYPE, ITYPE

from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper

######################################################################
# newObj function
Expand Down Expand Up @@ -214,8 +214,8 @@ cdef class DistanceMetric:
set state for pickling
"""
self.p = state[0]
self.vec = state[1]
self.mat = state[2]
self.vec = ReadonlyArrayWrapper(state[1])
self.mat = ReadonlyArrayWrapper(state[2])
if self.__class__.__name__ == "PyFuncDistance":
self.func = state[3]
self.kwargs = state[4]
Expand Down Expand Up @@ -444,7 +444,7 @@ cdef class SEuclideanDistance(DistanceMetric):
D(x, y) = \sqrt{ \sum_i \frac{ (x_i - y_i) ^ 2}{V_i} }
"""
def __init__(self, V):
self.vec = np.asarray(V, dtype=DTYPE)
self.vec = ReadonlyArrayWrapper(np.asarray(V, dtype=DTYPE))
self.size = self.vec.shape[0]
self.p = 2

Expand Down Expand Up @@ -605,7 +605,7 @@ cdef class WMinkowskiDistance(DistanceMetric):
raise ValueError("WMinkowskiDistance requires finite p. "
"For p=inf, use ChebyshevDistance.")
self.p = p
self.vec = np.asarray(w, dtype=DTYPE)
self.vec = ReadonlyArrayWrapper(np.asarray(w, dtype=DTYPE))
self.size = self.vec.shape[0]

def _validate_data(self, X):
Expand Down Expand Up @@ -665,7 +665,7 @@ cdef class MahalanobisDistance(DistanceMetric):
if VI.ndim != 2 or VI.shape[0] != VI.shape[1]:
raise ValueError("V/VI must be square")

self.mat = np.asarray(VI, dtype=float, order='C')
self.mat = ReadonlyArrayWrapper(np.asarray(VI, dtype=float, order='C'))

self.size = self.mat.shape[0]

Expand Down
24 changes: 23 additions & 1 deletion sklearn/metrics/tests/test_dist_metrics.py
Expand Up @@ -158,11 +158,16 @@ def check_pdist_bool(metric, D_true):
assert_array_almost_equal(D12, D_true)


@pytest.mark.parametrize("use_read_only_kwargs", [True, False])
@pytest.mark.parametrize("metric", METRICS_DEFAULT_PARAMS)
def test_pickle(metric):
def test_pickle(use_read_only_kwargs, metric):
argdict = METRICS_DEFAULT_PARAMS[metric]
keys = argdict.keys()
for vals in itertools.product(*argdict.values()):
if use_read_only_kwargs:
for val in vals:
if isinstance(val, np.ndarray):
val.setflags(write=False)
kwargs = dict(zip(keys, vals))
check_pickle(metric, kwargs)

Expand Down Expand Up @@ -242,3 +247,20 @@ def custom_metric(x, y):
pyfunc = DistanceMetric.get_metric("pyfunc", func=custom_metric)
eucl = DistanceMetric.get_metric("euclidean")
assert_array_almost_equal(pyfunc.pairwise(X), eucl.pairwise(X) ** 2)


def test_readonly_kwargs():
# Non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/21685

rng = check_random_state(0)

weights = rng.rand(100)
VI = rng.rand(10, 10)
weights.setflags(write=False)
VI.setflags(write=False)

# Those distances metrics have to support readonly buffers.
DistanceMetric.get_metric("seuclidean", V=weights)
DistanceMetric.get_metric("wminkowski", p=1, w=weights)
DistanceMetric.get_metric("mahalanobis", VI=VI)

0 comments on commit 33c39e5

Please sign in to comment.