Skip to content

Commit

Permalink
Typing distances.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Mar 3, 2022
1 parent d9bc83e commit 1b4366e
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 67 deletions.
5 changes: 4 additions & 1 deletion dcor/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _transform_to_2d(t: T) -> T:
return t


def _can_be_double(x: np.typing.NDArray[Any]) -> bool:
def _can_be_numpy_double(x: ArrayType) -> bool:
"""
Return if the array can be safely converted to double.
Expand All @@ -131,6 +131,9 @@ def _can_be_double(x: np.typing.NDArray[Any]) -> bool:
converted to double (if the roundtrip conversion works).
"""
if get_namespace(x) != np:
return False

return (
(
np.issubdtype(x.dtype, np.floating)
Expand Down
148 changes: 82 additions & 66 deletions dcor/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,77 @@
a double precision floating point number will not cause loss of precision.
"""

import numpy as _np
import scipy.spatial as _spatial
from __future__ import annotations

from dcor._utils import _transform_to_2d
from typing import TypeVar

from ._utils import _can_be_double
import numpy as np
import scipy.spatial as spatial

from dcor._utils import ArrayType, _sqrt, _transform_to_2d, get_namespace

def _cdist_naive(x, y, exponent=1):
from ._utils import _can_be_numpy_double

T = TypeVar("T", bound=ArrayType)


def _cdist_naive(x: T, y: T, exponent: float = 1) -> T:
"""Pairwise distance, custom implementation."""
squared_norms = ((x[_np.newaxis, :, :] - y[:, _np.newaxis, :]) ** 2).sum(2)
xp = get_namespace(x, y)

exponent = exponent / 2
try:
exponent = squared_norms.take(0).from_float(exponent)
except AttributeError:
pass
x = xp.asarray(x)
y = xp.asarray(y)

x = xp.expand_dims(x, axis=0)
y = xp.expand_dims(y, axis=1)

return squared_norms ** exponent
squared_norms = xp.sum(((x - y) ** 2), axis=-1)

try:
return squared_norms ** (exponent / 2)
except TypeError:
return _sqrt(squared_norms ** exponent)

def _pdist_scipy(x, exponent=1):

def _pdist_scipy(
x: np.typing.NDArray[float],
exponent: float = 1,
) -> np.typing.NDArray[float]:
"""Pairwise distance between points in a set."""
metric = 'euclidean'

if exponent != 1:
metric = 'sqeuclidean'

distances = _spatial.distance.pdist(x, metric=metric)
distances = _spatial.distance.squareform(distances)
distances = spatial.distance.pdist(x, metric=metric)
distances = spatial.distance.squareform(distances)

if exponent != 1:
distances **= exponent / 2

return distances


def _cdist_scipy(x, y, exponent=1):
def _cdist_scipy(
x: np.typing.NDArray[float],
y: np.typing.NDArray[float],
exponent: float = 1,
) -> np.typing.NDArray[float]:
"""Pairwise distance between the points in two sets."""
metric = 'euclidean'

if exponent != 1:
metric = 'sqeuclidean'

distances = _spatial.distance.cdist(x, y, metric=metric)
distances = spatial.distance.cdist(x, y, metric=metric)

if exponent != 1:
distances **= exponent / 2

return distances


def _pdist(x, exponent=1):
def _pdist(x: T, exponent: float = 1) -> T:
"""
Pairwise distance between points in a set.
Expand All @@ -67,13 +85,13 @@ def _pdist(x, exponent=1):
can not be converted to double.
"""
if _can_be_double(x):
if _can_be_numpy_double(x):
return _pdist_scipy(x, exponent)
else:
return _cdist_naive(x, x, exponent)

return _cdist_naive(x, x, exponent)


def _cdist(x, y, exponent=1):
def _cdist(x: T, y: T, exponent: float = 1) -> T:
"""
Pairwise distance between points in two sets.
Expand All @@ -82,65 +100,63 @@ def _cdist(x, y, exponent=1):
can not be converted to double.
"""
if _can_be_double(x) and _can_be_double(y):
if _can_be_numpy_double(x) and _can_be_numpy_double(y):
return _cdist_scipy(x, y, exponent)
else:
return _cdist_naive(x, y, exponent)

return _cdist_naive(x, y, exponent)


def pairwise_distances(x, y=None, *, exponent=1):
def pairwise_distances(
x: T,
y: T | None = None,
*,
exponent: float = 1,
) -> T:
r"""
Pairwise distance between points.
Return the pairwise distance between points in two sets, or
in the same set if only one set is passed.
Parameters
----------
x: array_like
An :math:`n \times m` array of :math:`n` observations in
a :math:`m`-dimensional space.
y: array_like
An :math:`l \times m` array of :math:`l` observations in
a :math:`m`-dimensional space. If None, the distances will
be computed between the points in :math:`x`.
exponent: float
Exponent of the Euclidean distance.
Returns
-------
numpy ndarray
Args:
x: An :math:`n \times m` array of :math:`n` observations in
a :math:`m`-dimensional space.
y: An :math:`l \times m` array of :math:`l` observations in
a :math:`m`-dimensional space. If None, the distances will
be computed between the points in :math:`x`.
exponent: Exponent of the Euclidean distance.
Returns:
A :math:`n \times l` matrix where the :math:`(i, j)`-th entry is the
distance between :math:`x[i]` and :math:`y[j]`.
Examples
--------
>>> import numpy as np
>>> import dcor
>>> a = np.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12],
... [13, 14, 15, 16]])
>>> b = np.array([[16, 15, 14, 13],
... [12, 11, 10, 9],
... [8, 7, 6, 5],
... [4, 3, 2, 1]])
>>> dcor.distances.pairwise_distances(a)
array([[ 0., 8., 16., 24.],
[ 8., 0., 8., 16.],
[16., 8., 0., 8.],
[24., 16., 8., 0.]])
>>> dcor.distances.pairwise_distances(a, b)
array([[24.41311123, 16.61324773, 9.16515139, 4.47213595],
[16.61324773, 9.16515139, 4.47213595, 9.16515139],
[ 9.16515139, 4.47213595, 9.16515139, 16.61324773],
[ 4.47213595, 9.16515139, 16.61324773, 24.41311123]])
Examples:
>>> import numpy as np
>>> import dcor
>>> a = np.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12],
... [13, 14, 15, 16]])
>>> b = np.array([[16, 15, 14, 13],
... [12, 11, 10, 9],
... [8, 7, 6, 5],
... [4, 3, 2, 1]])
>>> dcor.distances.pairwise_distances(a)
array([[ 0., 8., 16., 24.],
[ 8., 0., 8., 16.],
[16., 8., 0., 8.],
[24., 16., 8., 0.]])
>>> dcor.distances.pairwise_distances(a, b)
array([[24.41311123, 16.61324773, 9.16515139, 4.47213595],
[16.61324773, 9.16515139, 4.47213595, 9.16515139],
[ 9.16515139, 4.47213595, 9.16515139, 16.61324773],
[ 4.47213595, 9.16515139, 16.61324773, 24.41311123]])
"""
x = _transform_to_2d(x)

if y is None or y is x:
return _pdist(x, exponent=exponent)
else:
y = _transform_to_2d(y)
return _cdist(x, y, exponent=exponent)

y = _transform_to_2d(y)
return _cdist(x, y, exponent=exponent)

0 comments on commit 1b4366e

Please sign in to comment.