Skip to content

Commit

Permalink
speedup pairwise_distance (#256)
Browse files Browse the repository at this point in the history
* add faster individual pairwise distance functions

* Update whatsnew.rst

* Update test_utils_distance.py

* Apply suggestions from code review

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* remove individual pairwise functions from api

* remove check for array_equal in distance functions

* make pairwise_distance functions private

* Implement code review

* correct typos and clean code

* set Y input to None if applicable (faster)

---------

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Co-authored-by: qbarthelemy <q.barthelemy@gmail.com>
  • Loading branch information
3 people committed Aug 11, 2023
1 parent aa8fee7 commit 0f35b46
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -12,6 +12,8 @@ v0.6.dev

- Update pyRiemann from Python 3.7 to 3.8. pr`254` by :user:`qbarthelemy`

- Speedup pairwise distance function :func:`pyriemann.utils.distance.pairwise_distance` by adding individual functions for 'euclid', 'harmonic', 'logeuclid' and 'riemann' metrics. :pr:`256` by :user:`gabelstein`


v0.5 (Jun 2023)
---------------
Expand Down
174 changes: 170 additions & 4 deletions pyriemann/utils/distance.py
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from scipy.linalg import eigvalsh, solve
from sklearn.metrics import euclidean_distances

from .base import logm, sqrtm, invsqrtm

Expand Down Expand Up @@ -368,9 +369,6 @@ def distance_wasserstein(A, B, squared=False):
return d2 if squared else np.sqrt(d2)


###############################################################################


distance_functions = {
'euclid': distance_euclid,
'harmonic': distance_harmonic,
Expand Down Expand Up @@ -446,6 +444,165 @@ def distance(A, B, metric='riemann', squared=False):
return d


###############################################################################


def _pairwise_distance_euclid(X, Y=None, squared=False):
"""Pairwise Euclidean distance matrix.
Compute the matrix of Euclidean distances between pairs of elements of X
and Y.
Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distances.
Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Euclidean distances between pairs of elements of X if Y is None, or
between elements of X and Y.
See Also
--------
pairwise_distance
distance_euclid
"""
if Y is None:
dist = euclidean_distances(X.reshape(len(X), -1), squared=squared)
else:
dist = euclidean_distances(X.reshape(len(X), -1),
Y.reshape(len(Y), -1),
squared=squared)
return dist


def _pairwise_distance_harmonic(X, Y=None, squared=False):
"""Pairwise harmonic distance matrix.
Compute the matrix of harmonic distances between pairs of elements of X and
Y.
Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distances.
Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Harmonic distances between pairs of elements of X if Y is None, or
between elements of X and Y.
See Also
--------
pairwise_distance
distance_harmonic
"""
invX = np.linalg.inv(X)
if Y is None:
invY = None
else:
invY = np.linalg.inv(Y)

return _pairwise_distance_euclid(invX, invY, squared=squared)


def _pairwise_distance_logeuclid(X, Y=None, squared=False):
"""Pairwise Log-Euclidean distance matrix.
Compute the matrix of Log-Euclidean distances between pairs of elements of
X and Y.
Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distances.
Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Log-Euclidean distances between pairs of elements of X if Y is None, or
between elements of X and Y.
See Also
--------
pairwise_distance
distance_logeuclid
"""
logX = logm(X)
if Y is None:
logY = None
else:
logY = logm(Y)

return _pairwise_distance_euclid(logX, logY, squared=squared)


def _pairwise_distance_riemann(X, Y=None, squared=False):
"""Pairwise Riemannian distance matrix.
Compute the matrix of Riemannian distances between pairs of elements of X
and Y.
Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distances.
Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Riemannian distances between pairs of elements of X if Y is None, or
between elements of X and Y.
See Also
--------
pairwise_distance
distance_riemann
"""
XisY = False
if Y is None:
XisY = True
Y = X

n_matrices_X, n_matrices_Y = len(X), len(Y)
Xinv12 = invsqrtm(X)
dist = np.zeros((n_matrices_X, n_matrices_Y))

# row by row so it fits in memory
for i, x_ in enumerate(Xinv12):
evals_ = np.linalg.eigvalsh(x_ @ Y[i * XisY:] @ x_)
d2 = np.sum(np.log(evals_) ** 2, -1)
dist[i, i * XisY:] = d2

if XisY:
dist += dist.T

return dist if squared else np.sqrt(dist)


def pairwise_distance(X, Y=None, metric='riemann', squared=False):
"""Pairwise distance matrix.
Expand All @@ -462,7 +619,7 @@ def pairwise_distance(X, Y=None, metric='riemann', squared=False):
'kullback_right', 'kullback_sym', 'logdet', 'logeuclid', 'riemann',
'wasserstein', or a callable function.
squared : bool, default False
Return squared distance.
Return squared distances.
.. versionadded:: 0.5
Expand All @@ -477,6 +634,15 @@ def pairwise_distance(X, Y=None, metric='riemann', squared=False):
--------
distance
"""
if metric == 'euclid':
return _pairwise_distance_euclid(X, Y=Y, squared=squared)
elif metric == 'harmonic':
return _pairwise_distance_harmonic(X, Y=Y, squared=squared)
elif metric == 'logeuclid':
return _pairwise_distance_logeuclid(X, Y=Y, squared=squared)
elif metric == 'riemann':
return _pairwise_distance_riemann(X, Y=Y, squared=squared)

n_matrices_X, _, _ = X.shape

# compute full pairwise matrix for non-symmetric metrics
Expand Down
17 changes: 15 additions & 2 deletions tests/test_utils_distance.py
Expand Up @@ -244,17 +244,30 @@ def test_distance_wrapper_between_set_and_matrix(dist, get_covmats):

@pytest.mark.parametrize("dist", get_distances())
@pytest.mark.parametrize("Y", [None, True])
def test_pairwise_distance_matrix(get_covmats, dist, Y):
@pytest.mark.parametrize("squared", [False, True])
def test_pairwise_distance_matrix(get_covmats, dist, Y, squared):
n_matrices_X, n_matrices_Y, n_channels = 6, 4, 5
X = get_covmats(n_matrices_X, n_channels)
if Y is None:
n_matrices_Y = n_matrices_X
Y_ = X
else:
Y = get_covmats(n_matrices_Y, n_channels)
Y_ = Y

pdist = pairwise_distance(X, Y, metric=dist)
pdist = pairwise_distance(X, Y, metric=dist, squared=squared)
assert pdist.shape == (n_matrices_X, n_matrices_Y)

for i in range(n_matrices_X):
for j in range(n_matrices_Y):
assert np.isclose(pdist[i, j],
distance(X[i],
Y_[j],
metric=dist,
squared=squared),
atol=1e-5,
rtol=1e-5)

if Y is None and dist not in ["kullback", "kullback_right"]:
assert is_sym(pdist)
else:
Expand Down

0 comments on commit 0f35b46

Please sign in to comment.