From 0f35b4666d96462033f927fb05294c3dd50c74e9 Mon Sep 17 00:00:00 2001 From: Gabriel Wagner vom Berg <42769425+gabelstein@users.noreply.github.com> Date: Fri, 11 Aug 2023 14:11:01 +0200 Subject: [PATCH] speedup pairwise_distance (#256) * add faster individual pairwise distance functions * Update whatsnew.rst * Update test_utils_distance.py * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * remove individual pairwise functions from api * remove check for array_equal in distance functions * make pairwise_distance functions private * Implement code review * correct typos and clean code * set Y input to None if applicable (faster) --------- Co-authored-by: Alexandre Gramfort Co-authored-by: qbarthelemy --- doc/whatsnew.rst | 2 + pyriemann/utils/distance.py | 174 ++++++++++++++++++++++++++++++++++- tests/test_utils_distance.py | 17 +++- 3 files changed, 187 insertions(+), 6 deletions(-) diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index d29e4d9c..a2b09cec 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -12,6 +12,8 @@ v0.6.dev - Update pyRiemann from Python 3.7 to 3.8. :pr:`254` by :user:`qbarthelemy` +- Speedup pairwise distance function :func:`pyriemann.utils.distance.pairwise_distance` by adding individual functions for 'euclid', 'harmonic', 'logeuclid' and 'riemann' metrics. 
:pr:`256` by :user:`gabelstein` + v0.5 (Jun 2023) --------------- diff --git a/pyriemann/utils/distance.py b/pyriemann/utils/distance.py index c7e5bdf0..3b33969a 100644 --- a/pyriemann/utils/distance.py +++ b/pyriemann/utils/distance.py @@ -2,6 +2,7 @@ import numpy as np from scipy.linalg import eigvalsh, solve +from sklearn.metrics import euclidean_distances from .base import logm, sqrtm, invsqrtm @@ -368,9 +369,6 @@ def distance_wasserstein(A, B, squared=False): return d2 if squared else np.sqrt(d2) -############################################################################### - - distance_functions = { 'euclid': distance_euclid, 'harmonic': distance_harmonic, @@ -446,6 +444,165 @@ def distance(A, B, metric='riemann', squared=False): return d +############################################################################### + + +def _pairwise_distance_euclid(X, Y=None, squared=False): + """Pairwise Euclidean distance matrix. + + Compute the matrix of Euclidean distances between pairs of elements of X + and Y. + + Parameters + ---------- + X : ndarray, shape (n_matrices_X, n, n) + First set of matrices. + Y : None | ndarray, shape (n_matrices_Y, n, n), default=None + Second set of matrices. If None, Y is set to X. + squared : bool, default False + Return squared distances. + + Returns + ------- + dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \ + n_matrices_Y) + Euclidean distances between pairs of elements of X if Y is None, or + between elements of X and Y. + + See Also + -------- + pairwise_distance + distance_euclid + """ + if Y is None: + dist = euclidean_distances(X.reshape(len(X), -1), squared=squared) + else: + dist = euclidean_distances(X.reshape(len(X), -1), + Y.reshape(len(Y), -1), + squared=squared) + return dist + + +def _pairwise_distance_harmonic(X, Y=None, squared=False): + """Pairwise harmonic distance matrix. + + Compute the matrix of harmonic distances between pairs of elements of X and + Y. 
+ + Parameters + ---------- + X : ndarray, shape (n_matrices_X, n, n) + First set of matrices. + Y : None | ndarray, shape (n_matrices_Y, n, n), default=None + Second set of matrices. If None, Y is set to X. + squared : bool, default False + Return squared distances. + + Returns + ------- + dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \ + n_matrices_Y) + Harmonic distances between pairs of elements of X if Y is None, or + between elements of X and Y. + + See Also + -------- + pairwise_distance + distance_harmonic + """ + invX = np.linalg.inv(X) + if Y is None: + invY = None + else: + invY = np.linalg.inv(Y) + + return _pairwise_distance_euclid(invX, invY, squared=squared) + + +def _pairwise_distance_logeuclid(X, Y=None, squared=False): + """Pairwise Log-Euclidean distance matrix. + + Compute the matrix of Log-Euclidean distances between pairs of elements of + X and Y. + + Parameters + ---------- + X : ndarray, shape (n_matrices_X, n, n) + First set of matrices. + Y : None | ndarray, shape (n_matrices_Y, n, n), default=None + Second set of matrices. If None, Y is set to X. + squared : bool, default False + Return squared distances. + + Returns + ------- + dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \ + n_matrices_Y) + Log-Euclidean distances between pairs of elements of X if Y is None, or + between elements of X and Y. + + See Also + -------- + pairwise_distance + distance_logeuclid + """ + logX = logm(X) + if Y is None: + logY = None + else: + logY = logm(Y) + + return _pairwise_distance_euclid(logX, logY, squared=squared) + + +def _pairwise_distance_riemann(X, Y=None, squared=False): + """Pairwise Riemannian distance matrix. + + Compute the matrix of Riemannian distances between pairs of elements of X + and Y. + + Parameters + ---------- + X : ndarray, shape (n_matrices_X, n, n) + First set of matrices. + Y : None | ndarray, shape (n_matrices_Y, n, n), default=None + Second set of matrices. If None, Y is set to X. 
+ squared : bool, default False + Return squared distances. + + Returns + ------- + dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \ + n_matrices_Y) + Riemannian distances between pairs of elements of X if Y is None, or + between elements of X and Y. + + See Also + -------- + pairwise_distance + distance_riemann + """ + XisY = False + if Y is None: + XisY = True + Y = X + + n_matrices_X, n_matrices_Y = len(X), len(Y) + Xinv12 = invsqrtm(X) + dist = np.zeros((n_matrices_X, n_matrices_Y)) + + # row by row so it fits in memory + for i, x_ in enumerate(Xinv12): + evals_ = np.linalg.eigvalsh(x_ @ Y[i * XisY:] @ x_) + d2 = np.sum(np.log(evals_) ** 2, -1) + dist[i, i * XisY:] = d2 + + if XisY: + dist += dist.T + + return dist if squared else np.sqrt(dist) + + def pairwise_distance(X, Y=None, metric='riemann', squared=False): """Pairwise distance matrix. @@ -462,7 +619,7 @@ def pairwise_distance(X, Y=None, metric='riemann', squared=False): 'kullback_right', 'kullback_sym', 'logdet', 'logeuclid', 'riemann', 'wasserstein', or a callable function. squared : bool, default False - Return squared distance. + Return squared distances. .. 
versionadded:: 0.5 @@ -477,6 +634,15 @@ def pairwise_distance(X, Y=None, metric='riemann', squared=False): -------- distance """ + if metric == 'euclid': + return _pairwise_distance_euclid(X, Y=Y, squared=squared) + elif metric == 'harmonic': + return _pairwise_distance_harmonic(X, Y=Y, squared=squared) + elif metric == 'logeuclid': + return _pairwise_distance_logeuclid(X, Y=Y, squared=squared) + elif metric == 'riemann': + return _pairwise_distance_riemann(X, Y=Y, squared=squared) + n_matrices_X, _, _ = X.shape # compute full pairwise matrix for non-symmetric metrics diff --git a/tests/test_utils_distance.py b/tests/test_utils_distance.py index 15402979..2ce14266 100644 --- a/tests/test_utils_distance.py +++ b/tests/test_utils_distance.py @@ -244,17 +244,30 @@ def test_distance_wrapper_between_set_and_matrix(dist, get_covmats): @pytest.mark.parametrize("dist", get_distances()) @pytest.mark.parametrize("Y", [None, True]) -def test_pairwise_distance_matrix(get_covmats, dist, Y): +@pytest.mark.parametrize("squared", [False, True]) +def test_pairwise_distance_matrix(get_covmats, dist, Y, squared): n_matrices_X, n_matrices_Y, n_channels = 6, 4, 5 X = get_covmats(n_matrices_X, n_channels) if Y is None: n_matrices_Y = n_matrices_X + Y_ = X else: Y = get_covmats(n_matrices_Y, n_channels) + Y_ = Y - pdist = pairwise_distance(X, Y, metric=dist) + pdist = pairwise_distance(X, Y, metric=dist, squared=squared) assert pdist.shape == (n_matrices_X, n_matrices_Y) + for i in range(n_matrices_X): + for j in range(n_matrices_Y): + assert np.isclose(pdist[i, j], + distance(X[i], + Y_[j], + metric=dist, + squared=squared), + atol=1e-5, + rtol=1e-5) + if Y is None and dist not in ["kullback", "kullback_right"]: assert is_sym(pdist) else: