Skip to content

Commit

Permalink
MNT Move DistanceMetric under metrics (#21177)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjerphan authored and glemaitre committed Dec 25, 2021
1 parent 1b97932 commit 71ca247
Show file tree
Hide file tree
Showing 27 changed files with 166 additions and 107 deletions.
7 changes: 3 additions & 4 deletions doc/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,8 @@ General Concepts

Note that for most distance metrics, we rely on implementations from
:mod:`scipy.spatial.distance`, but may reimplement for efficiency in
our context. The :mod:`neighbors` module also duplicates some metric
implementations for integration with efficient binary tree search data
structures.
our context. The :class:`metrics.DistanceMetric` interface is used to implement
distance metrics for integration with efficient neighbors search.

pd
A shorthand for `Pandas <https://pandas.pydata.org>`_ due to the
Expand Down Expand Up @@ -1023,7 +1022,7 @@ such as:

Further examples:

* :class:`neighbors.DistanceMetric`
* :class:`metrics.DistanceMetric`
* :class:`gaussian_process.kernels.Kernel`
* ``tree.Criterion``

Expand Down
11 changes: 10 additions & 1 deletion doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,16 @@ further details.

metrics.consensus_score

Distance metrics
----------------

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class.rst

metrics.DistanceMetric

Pairwise metrics
----------------
Expand Down Expand Up @@ -1317,7 +1327,6 @@ Model validation
:template: class.rst

neighbors.BallTree
neighbors.DistanceMetric
neighbors.KDTree
neighbors.KernelDensity
neighbors.KNeighborsClassifier
Expand Down
6 changes: 3 additions & 3 deletions doc/modules/density.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ The form of these kernels is as follows:
:math:`K(x; h) \propto \cos(\frac{\pi x}{2h})` if :math:`x < h`

The kernel density estimator can be used with any of the valid distance
metrics (see :class:`~sklearn.neighbors.DistanceMetric` for a list of available metrics), though
the results are properly normalized only for the Euclidean metric. One
particularly useful metric is the
metrics (see :class:`~sklearn.metrics.DistanceMetric` for a list of
available metrics), though the results are properly normalized only
for the Euclidean metric. One particularly useful metric is the
`Haversine distance <https://en.wikipedia.org/wiki/Haversine_formula>`_
which measures the angular distance between points on a sphere. Here
is an example of using a kernel density estimate for a visualization
Expand Down
23 changes: 23 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,29 @@ Changelog
message when the solver does not support sparse matrices with int64 indices.
:pr:`21093` by `Tom Dupre la Tour`_.

:mod:`sklearn.metrics`
......................

- |API| :class:`metrics.DistanceMetric` has been moved from
:mod:`sklearn.neighbors` to :mod:`sklearn.metric`.
Using `neighbors.DistanceMetric` for imports is still valid for
backward compatibility, but this alias will be removed in 1.3.
:pr:`21177` by :user:`Julien Jerphanion <jjerphan>`.

:mod:`sklearn.model_selection`
..............................

- |Enhancement| raise an error during cross-validation when the fits for all the
splits failed. Similarly raise an error during grid-search when the fits for
all the models and all the splits failed. :pr:`21026` by :user:`Loïc Estève <lesteve>`.

:mod:`sklearn.pipeline`
.......................

- |Enhancement| Added support for "passthrough" in :class:`FeatureUnion`.
Setting a transformer to "passthrough" will pass the features unchanged.
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.

:mod:`sklearn.utils`
....................

Expand Down
4 changes: 2 additions & 2 deletions sklearn/cluster/_agglomerative.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from ..base import BaseEstimator, ClusterMixin
from ..metrics.pairwise import paired_distances
from ..neighbors import DistanceMetric
from ..neighbors._dist_metrics import METRIC_MAPPING
from ..metrics import DistanceMetric
from ..metrics._dist_metrics import METRIC_MAPPING
from ..utils import check_array
from ..utils._fast_dict import IntFloatDict
from ..utils.fixes import _astype_copy_false
Expand Down
15 changes: 7 additions & 8 deletions sklearn/cluster/_hierarchical_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ctypedef np.int8_t INT8

np.import_array()

from ..neighbors._dist_metrics cimport DistanceMetric
from ..metrics._dist_metrics cimport DistanceMetric
from ..utils._fast_dict cimport IntFloatDict

# C++
Expand Down Expand Up @@ -236,8 +236,8 @@ def max_merge(IntFloatDict a, IntFloatDict b,
def average_merge(IntFloatDict a, IntFloatDict b,
np.ndarray[ITYPE_t, ndim=1] mask,
ITYPE_t n_a, ITYPE_t n_b):
"""Merge two IntFloatDicts with the average strategy: when the
same key is present in the two dicts, the weighted average of the two
"""Merge two IntFloatDicts with the average strategy: when the
same key is present in the two dicts, the weighted average of the two
values is used.
Parameters
Expand Down Expand Up @@ -290,13 +290,13 @@ def average_merge(IntFloatDict a, IntFloatDict b,


###############################################################################
# An edge object for fast comparisons
# An edge object for fast comparisons

cdef class WeightedEdge:
cdef public ITYPE_t a
cdef public ITYPE_t b
cdef public DTYPE_t weight

def __init__(self, DTYPE_t weight, ITYPE_t a, ITYPE_t b):
self.weight = weight
self.a = a
Expand Down Expand Up @@ -326,7 +326,7 @@ cdef class WeightedEdge:
return self.weight > other.weight
elif op == 5:
return self.weight >= other.weight

def __repr__(self):
return "%s(weight=%f, a=%i, b=%i)" % (self.__class__.__name__,
self.weight,
Expand Down Expand Up @@ -475,7 +475,7 @@ def mst_linkage_core(
dist_metric: DistanceMetric
A DistanceMetric object conforming to the API from
``sklearn.neighbors._dist_metrics.pxd`` that will be
``sklearn.metrics._dist_metrics.pxd`` that will be
used to compute distances.
Returns
Expand Down Expand Up @@ -534,4 +534,3 @@ def mst_linkage_core(
current_node = new_node

return np.array(result)

5 changes: 3 additions & 2 deletions sklearn/cluster/tests/test_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from scipy.sparse.csgraph import connected_components

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.neighbors.tests.test_dist_metrics import METRICS_DEFAULT_PARAMS
from sklearn.metrics.tests.test_dist_metrics import METRICS_DEFAULT_PARAMS
from sklearn.utils._testing import assert_almost_equal, create_memmap_backed_data
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import ignore_warnings
Expand All @@ -31,14 +31,15 @@
_fix_connectivity,
)
from sklearn.feature_extraction.image import grid_to_graph
from sklearn.metrics import DistanceMetric
from sklearn.metrics.pairwise import (
PAIRED_DISTANCES,
cosine_distances,
manhattan_distances,
pairwise_distances,
)
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.neighbors import kneighbors_graph, DistanceMetric
from sklearn.neighbors import kneighbors_graph
from sklearn.cluster._hierarchical_fast import (
average_merge,
max_merge,
Expand Down
3 changes: 3 additions & 0 deletions sklearn/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from ._classification import brier_score_loss
from ._classification import multilabel_confusion_matrix

from ._dist_metrics import DistanceMetric

from . import cluster
from .cluster import adjusted_mutual_info_score
from .cluster import adjusted_rand_score
Expand Down Expand Up @@ -115,6 +117,7 @@
"davies_bouldin_score",
"DetCurveDisplay",
"det_curve",
"DistanceMetric",
"euclidean_distances",
"explained_variance_score",
"f1_score",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
#!python
#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True
# cython: boundscheck=False
# cython: cdivision=True
# cython: initializedcheck=False
# cython: wraparound=False

cimport cython
cimport numpy as np
from libc.math cimport fabs, sqrt, exp, cos, pow
from libc.math cimport sqrt, exp

from ._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t
from ._typedefs import DTYPE, ITYPE
from ..utils._typedefs cimport DTYPE_t, ITYPE_t

######################################################################
# Inline distance functions
Expand Down Expand Up @@ -60,7 +58,7 @@ cdef class DistanceMetric:
cdef DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
ITYPE_t size) nogil except -1

cdef DTYPE_t rdist(self, DTYPE_t* x1, DTYPE_t* x2,
cdef DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
ITYPE_t size) nogil except -1

cdef int pdist(self, const DTYPE_t[:, ::1] X, DTYPE_t[:, ::1] D) except -1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#!python
#cython: boundscheck=False
#cython: wraparound=False
#cython: initializedcheck=False
#cython: cdivision=True
# cython: boundscheck=False
# cython: cdivision=True
# cython: initializedcheck=False
# cython: wraparound=False

# By Jake Vanderplas (2013) <jakevdp@cs.washington.edu>
# written for the scikit-learn project
Expand All @@ -19,7 +18,7 @@ cdef extern from "arrayobject.h":
int typenum, void* data)


cdef inline np.ndarray _buffer_to_ndarray(DTYPE_t* x, np.npy_intp n):
cdef inline np.ndarray _buffer_to_ndarray(const DTYPE_t* x, np.npy_intp n):
# Wrap a memory buffer with an ndarray. Warning: this is not robust.
# In particular, if x is deallocated before the returned array goes
# out of scope, this could cause memory errors. Since there is not
Expand All @@ -33,8 +32,8 @@ cdef inline np.ndarray _buffer_to_ndarray(DTYPE_t* x, np.npy_intp n):
from libc.math cimport fabs, sqrt, exp, pow, cos, sin, asin
cdef DTYPE_t INF = np.inf

from ._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE
from ._typedefs import DTYPE, ITYPE
from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE
from ..utils._typedefs import DTYPE, ITYPE


######################################################################
Expand Down Expand Up @@ -98,7 +97,7 @@ cdef class DistanceMetric:
Examples
--------
>>> from sklearn.neighbors import DistanceMetric
>>> from sklearn.metrics import DistanceMetric
>>> dist = DistanceMetric.get_metric('euclidean')
>>> X = [[0, 1, 2],
[3, 4, 5]]
Expand Down Expand Up @@ -291,14 +290,13 @@ cdef class DistanceMetric:

cdef DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
ITYPE_t size) nogil except -1:
"""Compute the reduced distance between vectors x1 and x2.
"""Compute the rank-preserving surrogate distance between vectors x1 and x2.
This can optionally be overridden in a base class.
The reduced distance is any measure that yields the same rank as the
distance, but is more efficient to compute. For example, for the
Euclidean metric, the reduced distance is the squared-euclidean
distance.
The rank-preserving surrogate distance is any measure that yields the same
rank as the distance, but is more efficient to compute. For example, for the
Euclidean metric, the surrogate distance is the squared-euclidean distance.
"""
return self.dist(x1, x2, size)

Expand All @@ -323,25 +321,24 @@ cdef class DistanceMetric:
return 0

cdef DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1:
"""Convert the reduced distance to the distance"""
"""Convert the rank-preserving surrogate distance to the distance"""
return rdist

cdef DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1:
"""Convert the distance to the reduced distance"""
"""Convert the distance to the rank-preserving surrogate distance"""
return dist

def rdist_to_dist(self, rdist):
"""Convert the Reduced distance to the true distance.
"""Convert the rank-preserving surrogate distance to the distance.
The reduced distance, defined for some metrics, is a computationally
more efficient measure which preserves the rank of the true distance.
For example, in the Euclidean distance metric, the reduced distance
is the squared-euclidean distance.
The surrogate distance is any measure that yields the same rank as the
distance, but is more efficient to compute. For example, for the
Euclidean metric, the surrogate distance is the squared-euclidean distance.
Parameters
----------
rdist : double
Reduced distance.
Surrogate distance.
Returns
-------
Expand All @@ -351,12 +348,11 @@ cdef class DistanceMetric:
return rdist

def dist_to_rdist(self, dist):
"""Convert the true distance to the reduced distance.
"""Convert the true distance to the rank-preserving surrogate distance.
The reduced distance, defined for some metrics, is a computationally
more efficient measure which preserves the rank of the true distance.
For example, in the Euclidean distance metric, the reduced distance
is the squared-euclidean distance.
The surrogate distance is any measure that yields the same rank as the
distance, but is more efficient to compute. For example, for the
Euclidean metric, the surrogate distance is the squared-euclidean distance.
Parameters
----------
Expand All @@ -366,7 +362,7 @@ cdef class DistanceMetric:
Returns
-------
double
Reduced distance.
Surrogate distance.
"""
return dist

Expand Down Expand Up @@ -519,7 +515,7 @@ cdef class ChebyshevDistance(DistanceMetric):
Examples
--------
>>> from sklearn.neighbors.dist_metrics import DistanceMetric
>>> from sklearn.metrics.dist_metrics import DistanceMetric
>>> dist = DistanceMetric.get_metric('chebyshev')
>>> X = [[0, 1, 2],
... [3, 4, 5]]
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def haversine_distances(X, Y=None):
array([[ 0. , 11099.54035582],
[11099.54035582, 0. ]])
"""
from ..neighbors import DistanceMetric
from ..metrics import DistanceMetric

return DistanceMetric.get_metric("haversine").pairwise(X, Y)

Expand Down
8 changes: 8 additions & 0 deletions sklearn/metrics/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import numpy as np

from numpy.distutils.misc_util import Configuration

Expand All @@ -18,6 +19,13 @@ def configuration(parent_package="", top_path=None):
"_pairwise_fast", sources=["_pairwise_fast.pyx"], libraries=libraries
)

config.add_extension(
"_dist_metrics",
sources=["_dist_metrics.pyx"],
include_dirs=[np.get_include(), os.path.join(np.get_include(), "numpy")],
libraries=libraries,
)

config.add_subpackage("tests")

return config
Expand Down

0 comments on commit 71ca247

Please sign in to comment.