Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speedup pairwise_distance #256

Merged
merged 11 commits into from Aug 11, 2023
4 changes: 4 additions & 0 deletions doc/api.rst
Expand Up @@ -232,6 +232,10 @@ Distances
distance_riemann
distance_wasserstein
pairwise_distance
pairwise_distance_euclid
pairwise_distance_harmonic
pairwise_distance_logeuclid
pairwise_distance_riemann
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not make these functions public to avoid over crowding the public namespace.


distance_mahalanobis

Expand Down
1 change: 1 addition & 0 deletions doc/whatsnew.rst
Expand Up @@ -10,6 +10,7 @@ A catalog of new features, improvements, and bug-fixes in each release.
v0.6.dev
--------

- Speedup pairwise distance function :func:`pyriemann.utils.distance.pairwise_distance` by adding individual functions. :pr:`256` by :user:`gabelstein`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Say for what metric it's faster


v0.5 (Jun 2023)
---------------
Expand Down
182 changes: 182 additions & 0 deletions pyriemann/utils/distance.py
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from scipy.linalg import eigvalsh, solve
from sklearn.metrics import euclidean_distances

from .base import logm, sqrtm, invsqrtm

Expand Down Expand Up @@ -446,6 +447,178 @@ def distance(A, B, metric='riemann', squared=False):
return d


def pairwise_distance_euclid(X, Y=None, squared=False):
"""Pairwise Euclidean distance matrix.

Compute the matrix of Euclidan distances between pairs of elements of X and
Y.

Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distance.

Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Euclidean Distances between pairs of elements of X if Y is None, or
between elements of X and Y.

Notes
-----
.. versionadded:: 0.6

See Also
--------
pairwise_distance
distance_euclid
"""
if isinstance(Y, type(None)) or np.array_equal(X, Y):
gabelstein marked this conversation as resolved.
Show resolved Hide resolved
dist = euclidean_distances(X.reshape(len(X), -1), squared=squared)
np.fill_diagonal(dist, 0) # diagonal is not exactly 0
gabelstein marked this conversation as resolved.
Show resolved Hide resolved
else:
dist = euclidean_distances(X.reshape(len(X), -1),
Y.reshape(len(Y), -1),
squared=squared)
return dist


def pairwise_distance_harmonic(X, Y=None, squared=False):
"""Pairwise harmonic distance matrix.

Compute the matrix of harmonic distances between pairs of elements of X and
Y.

Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distance.

Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Harmonic Distances between pairs of elements of X if Y is None, or
between elements of X and Y.
Notes
-----
.. versionadded:: 0.6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need these now that it's private


See Also
--------
pairwise_distance
distance_harmonic
"""
if isinstance(Y, type(None)) or np.array_equal(X, Y):
invY = invX = np.linalg.inv(X)
else:
invX = np.linalg.inv(X)
invY = np.linalg.inv(Y)

return pairwise_distance_euclid(invX, invY, squared=squared)


def pairwise_distance_logeuclid(X, Y=None, squared=False):
"""Pairwise Log-Euclidean distance matrix.

Compute the matrix of Log-Euclidan distances between pairs of elements of X
and Y.

Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distance.

Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Log-Euclidean Distances between pairs of elements of X if Y is None, or
between elements of X and Y.

Notes
-----
.. versionadded:: 0.6

See Also
--------
pairwise_distance
distance_logeuclid
"""
if isinstance(Y, type(None)) or np.array_equal(X, Y):
logY = logX = logm(X)
else:
logX = logm(X)
logY = logm(Y)

return pairwise_distance_euclid(logX, logY, squared=squared)


def pairwise_distance_riemann(X, Y=None, squared=False):
"""Pairwise Riemann distance matrix.

Compute the matrix of Riemann distances between pairs of elements of X
and Y.

Parameters
----------
X : ndarray, shape (n_matrices_X, n, n)
First set of matrices.
Y : None | ndarray, shape (n_matrices_Y, n, n), default=None
Second set of matrices. If None, Y is set to X.
squared : bool, default False
Return squared distance.

Returns
-------
dist : ndarray, shape (n_matrices_X, n_matrices_X) or (n_matrices_X, \
n_matrices_Y)
Riemann Distances between pairs of elements of X if Y is None, or
between elements of X and Y.

Notes
-----
.. versionadded:: 0.6

See Also
--------
pairwise_distance
distance_riemann
"""
XisY = False
if isinstance(Y, type(None)) or np.array_equal(X, Y):
XisY = True
Y = X

n_matrices_X, n_matrices_Y = len(X), len(Y)
Xinv12 = invsqrtm(X)
res = np.zeros((n_matrices_X, n_matrices_Y))

# row by row so it fits in memory
for i, x_ in enumerate(Xinv12):
evals_ = np.linalg.eigvalsh(x_ @ Y[i * XisY:] @ x_)
res_ = np.sum(np.log(evals_) ** 2, -1)
res[i, i * XisY:] = res_

if XisY:
res = res + res.T

return res if squared else np.sqrt(res)


def pairwise_distance(X, Y=None, metric='riemann', squared=False):
"""Pairwise distance matrix.

Expand Down Expand Up @@ -477,6 +650,15 @@ def pairwise_distance(X, Y=None, metric='riemann', squared=False):
--------
distance
"""
if metric == 'euclid':
return pairwise_distance_euclid(X, Y=Y, squared=squared)
elif metric == 'harmonic':
return pairwise_distance_harmonic(X, Y=Y, squared=squared)
elif metric == 'logeuclid':
return pairwise_distance_logeuclid(X, Y=Y, squared=squared)
elif metric == 'riemann':
return pairwise_distance_riemann(X, Y=Y, squared=squared)

n_matrices_X, _, _ = X.shape

# compute full pairwise matrix for non-symmetric metrics
Expand Down
17 changes: 15 additions & 2 deletions tests/test_utils_distance.py
Expand Up @@ -244,17 +244,30 @@ def test_distance_wrapper_between_set_and_matrix(dist, get_covmats):

@pytest.mark.parametrize("dist", get_distances())
@pytest.mark.parametrize("Y", [None, True])
def test_pairwise_distance_matrix(get_covmats, dist, Y):
@pytest.mark.parametrize("squared", [False, True])
def test_pairwise_distance_matrix(get_covmats, dist, Y, squared):
n_matrices_X, n_matrices_Y, n_channels = 6, 4, 5
X = get_covmats(n_matrices_X, n_channels)
if Y is None:
n_matrices_Y = n_matrices_X
Y_ = X
else:
Y = get_covmats(n_matrices_Y, n_channels)
Y_ = Y

pdist = pairwise_distance(X, Y, metric=dist)
pdist = pairwise_distance(X, Y, metric=dist, squared=squared)
assert pdist.shape == (n_matrices_X, n_matrices_Y)

for i in range(n_matrices_X):
for j in range(n_matrices_Y):
assert np.isclose(pdist[i, j],
distance(X[i],
Y_[j],
metric=dist,
squared=squared),
atol=1e-5,
rtol=1e-5)

if Y is None and dist not in ["kullback", "kullback_right"]:
assert is_sym(pdist)
else:
Expand Down