Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warm restart to AJDC #196

Merged
merged 5 commits into from Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -16,6 +16,9 @@ v0.3.1.dev
The different metrics for tangent space mapping can now be defined into :class:`pyriemann.tangentspace.TangentSpace`,
then used for ``transform()`` as well as for ``inverse_transform()``. :pr:`195` by :user:`qbarthelemy`

- Enhance AJD: add ``init`` to :func:`pyriemann.utils.ajd.ajd_pham` and :func:`pyriemann.utils.ajd.rjd`,
add ``warm_restart`` to :class:`pyriemann.spatialfilters.AJDC`. :pr:`196` by :user:`qbarthelemy`

v0.3 (July 2022)
----------------

Expand Down
18 changes: 8 additions & 10 deletions pyriemann/preprocessing.py
Expand Up @@ -19,8 +19,8 @@ class Whitening(BaseEstimator, TransformerMixin):
Parameters
----------
metric : str, default='euclid'
The metric for the estimation of mean covariance matrix used for
whitening and dimension reduction.
The metric for the estimation of mean matrix used for whitening and
dimension reduction.
dim_red : None | dict, default=None
If ``None`` :
no dimension reduction during whitening.
Expand All @@ -34,8 +34,7 @@ class Whitening(BaseEstimator, TransformerMixin):
``val`` must be a float in (0,1], typically ``0.99``.
If ``{'max_cond': val}`` :
dimension reduction selecting the number of components such that
the condition number of the mean covariance matrix is lower than
``val``.
the condition number of the mean matrix is lower than ``val``.
This threshold has a physiological interpretation, because it can
be viewed as the ratio between the power of the strongest component
(usually, eye-blink source) and the power of the lowest component
Expand All @@ -49,9 +48,9 @@ class Whitening(BaseEstimator, TransformerMixin):
n_components_ : int
If fit, the number of components after dimension reduction.
filters_ : ndarray, shape ``(n_channels_, n_components_)``
If fit, the spatial filters to whiten covariance matrices.
If fit, the spatial filters to whiten SPD matrices.
inv_filters_ : ndarray, shape ``(n_components_, n_channels_)``
If fit, the spatial filters to unwhiten covariance matrices.
If fit, the spatial filters to unwhiten SPD matrices.

Notes
-----
Expand All @@ -75,16 +74,15 @@ def fit(self, X, y=None, sample_weight=None):
y : None
Ignored as unsupervised.
sample_weight : None | ndarray, shape (n_matrices,), default=None
Weight of each matrix, to compute the weighted mean covariance
matrix used for whitening and dimension reduction. If None, it uses
equal weights.
Weight of each matrix, to compute the weighted mean matrix used for
whitening and dimension reduction. If None, it uses equal weights.

Returns
-------
self : Whitening instance
The Whitening instance.
"""
# weighted mean of input covariance matrices
# weighted mean of input SPD matrices
Xm = mean_covariance(
X,
metric=self.metric,
Expand Down
58 changes: 49 additions & 9 deletions pyriemann/spatialfilters.py
@@ -1,4 +1,6 @@
"""Spatial filtering function."""
import warnings

import numpy as np
from scipy.linalg import eigh, inv
from sklearn.base import BaseEstimator, TransformerMixin
Expand Down Expand Up @@ -342,12 +344,12 @@ def fit(self, X, y):
ix = np.argsort(np.abs(evals - 0.5))[::-1]
elif len(classes) > 2:
evecs, D = ajd_pham(C)
Ctot = np.array(mean_covariance(C, self.metric))
Ctot = mean_covariance(C, self.metric)
evecs = evecs.T

# normalize
for i in range(evecs.shape[1]):
tmp = np.dot(np.dot(evecs[:, i].T, Ctot), evecs[:, i])
tmp = evecs[:, i].T @ Ctot @ evecs[:, i]
evecs[:, i] /= np.sqrt(tmp)

mutual_info = []
Expand All @@ -357,8 +359,7 @@ def fit(self, X, y):
a = 0
b = 0
for i, c in enumerate(classes):
tmp = np.dot(np.dot(evecs[:, j].T, C[i]),
evecs[:, j])
tmp = evecs[:, j].T @ C[i] @ evecs[:, j]
a += Pc[i] * np.log(np.sqrt(tmp))
b += Pc[i] * (tmp ** 2 - 1)
mi = - (a + (3.0 / 16) * (b ** 2))
Expand Down Expand Up @@ -510,8 +511,30 @@ class AJDC(BaseEstimator, TransformerMixin):
The sampling frequency of the signal.
dim_red : None | dict, default=None
Parameter for dimension reduction of cospectra, because Pham's AJD is
sensitive to matrices conditioning. For more details, see parameter
``dim_red`` of :class:`pyriemann.preprocessing.Whitening`.
sensitive to matrices conditioning.

If ``None`` :
no dimension reduction during whitening.
If ``{'n_components': val}`` :
dimension reduction defining the number of components;
``val`` must be an integer greater than 1.
If ``{'expl_var': val}`` :
dimension reduction selecting the number of components such that
the amount of variance that needs to be explained is greater than
the percentage specified by ``val``.
``val`` must be a float in (0,1], typically ``0.99``.
If ``{'max_cond': val}`` :
dimension reduction selecting the number of components such that
the condition number of the mean matrix is lower than ``val``.
This threshold has a physiological interpretation, because it can
be viewed as the ratio between the power of the strongest component
(usually, eye-blink source) and the power of the lowest component
you don't want to keep (acquisition sensor noise).
``val`` must be a float strictly greater than 1, typically 100.
If ``{'warm_restart': val}`` :
dimension reduction defining the number of components from an
initial joint diagonalizer, and then run AJD from this solution.
``val`` must be a square ndarray.
verbose : bool, default=True
Verbose flag.

Expand All @@ -523,6 +546,8 @@ class AJDC(BaseEstimator, TransformerMixin):
If fit, the frequencies associated to cospectra.
n_sources_ : int
If fit, the number of components of the source space.
diag_filters_ : ndarray, shape ``(n_sources_, n_sources_)``
If fit, the diagonalization filters, also called joint diagonalizer.
forward_filters_ : ndarray, shape ``(n_sources_, n_channels_)``
If fit, the spatial filters used to transform signal into source,
also called deximing or separating matrix.
Expand Down Expand Up @@ -619,6 +644,20 @@ def fit(self, X, y=None):
# estimation of non-diagonality weights, Eq(B.1) in [1]
weights = get_nondiag_weight(self._cosp_channels)

# initial diagonalizer: if warm restart, dimension reduction defined by
# the size of the initial diag filters
init = None
if self.dim_red is None:
warnings.warn('Parameter dim_red should not be let to None')
elif isinstance(self.dim_red, dict) and len(self.dim_red) == 1 \
and next(iter(self.dim_red)) == 'warm_restart':
init = self.dim_red['warm_restart']
if init.ndim != 2 or init.shape[0] != init.shape[1]:
raise ValueError(
'Initial diagonalizer defined in dim_red is not a 2D '
'square matrix (Got shape = %s).' % (init.shape,))
self.dim_red = {'n_components': init.shape[0]}

# dimension reduction and whitening, Eq.(8) in [2], computed on the
# weighted mean of cospectra across frequencies (and conditions)
whit = Whitening(
Expand All @@ -629,14 +668,15 @@ def fit(self, X, y=None):
self.n_sources_ = whit.n_components_

# approximate joint diagonalization, currently by Pham's algorithm [3]
diag_filters, self._cosp_sources = ajd_pham(
self.diag_filters_, self._cosp_sources = ajd_pham(
cosp_rw,
init=init,
n_iter_max=100,
sample_weight=weights)

# computation of forward and backward filters, Eq.(9) and (10) in [2]
self.forward_filters_ = diag_filters @ whit.filters_.T
self.backward_filters_ = whit.inv_filters_.T @ inv(diag_filters)
self.forward_filters_ = self.diag_filters_ @ whit.filters_.T
self.backward_filters_ = whit.inv_filters_.T @ inv(self.diag_filters_)
return self

def transform(self, X):
Expand Down
42 changes: 30 additions & 12 deletions pyriemann/utils/ajd.py
Expand Up @@ -20,7 +20,15 @@ def _get_normalized_weight(sample_weight, data):
return normalized_weight


def rjd(X, eps=1e-8, n_iter_max=1000):
def _check_init_diag(init, n):
if init.shape != (n, n):
raise ValueError(
'Initial diagonalizer shape must be %d x % d (Got %s).'
% (n, n, init.shape,))
return init


def rjd(X, *, init=None, eps=1e-8, n_iter_max=1000):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is technically an API change but I think it's safe to do.

"""Approximate joint diagonalization based on Jacobi angles.

This is a direct implementation of the AJD algorithm by Cardoso and
Expand All @@ -31,6 +39,8 @@ def rjd(X, eps=1e-8, n_iter_max=1000):
----------
X : ndarray, shape (n_matrices, n_channels, n_channels)
Set of symmetric matrices to diagonalize.
init : None | ndarray, shape (n_channels, n_channels), default=None
Initialization for the diagonalizer.
eps : float, default=1e-8
Tolerance for stopping criterion.
n_iter_max : int, default=1000
Expand Down Expand Up @@ -63,7 +73,10 @@ def rjd(X, eps=1e-8, n_iter_max=1000):

# init variables
m, nm = A.shape # n_channels, n_matrices_x_channels
V = np.eye(m)
if init is None:
V = np.eye(m)
else:
V = _check_init_diag(init, m)
encore = True
k = 0

Expand Down Expand Up @@ -105,7 +118,7 @@ def rjd(X, eps=1e-8, n_iter_max=1000):
return V, D


def ajd_pham(X, eps=1e-6, n_iter_max=15, sample_weight=None):
def ajd_pham(X, *, init=None, eps=1e-6, n_iter_max=15, sample_weight=None):
"""Approximate joint diagonalization based on Pham's algorithm.

This is a direct implementation of the Pham's AJD algorithm [1]_.
Expand All @@ -114,6 +127,8 @@ def ajd_pham(X, eps=1e-6, n_iter_max=15, sample_weight=None):
----------
X : ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices to diagonalize.
init : None | ndarray, shape (n_channels, n_channels), default=None
Initialization for the diagonalizer.
eps : float, default=1e-6
Tolerance for stopping criterion.
n_iter_max : int, default=15
Expand Down Expand Up @@ -151,7 +166,10 @@ def ajd_pham(X, eps=1e-6, n_iter_max=15, sample_weight=None):

# Init variables
n_channels, n_matrices_x_channels = A.shape
V = np.eye(n_channels)
if init is None:
V = np.eye(n_channels)
else:
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
V = _check_init_diag(init, n_channels)
epsilon = n_channels * (n_channels - 1) * eps

for it in range(n_iter_max):
Expand Down Expand Up @@ -199,7 +217,7 @@ def ajd_pham(X, eps=1e-6, n_iter_max=15, sample_weight=None):
return V, D


def uwedge(X, init=None, eps=1e-7, n_iter_max=100):
def uwedge(X, *, init=None, eps=1e-7, n_iter_max=100):
"""Approximate joint diagonalization based on UWEDGE.

Implementation of the AJD algorithm by Tichavsky and Yeredor [1]_ [2]_:
Expand Down Expand Up @@ -253,9 +271,9 @@ def uwedge(X, init=None, eps=1e-7, n_iter_max=100):

if init is None:
E, H = np.linalg.eig(M[:, 0:d])
W_est = np.dot(np.diag(1. / np.sqrt(np.abs(E))), H.T)
W_est = H.T / np.sqrt(np.abs(E))[:, np.newaxis]
else:
W_est = init
W_est = _check_init_diag(init, d)

Ms = np.array(M)
Rs = np.zeros((d, n_matrices))
Expand All @@ -269,19 +287,19 @@ def uwedge(X, init=None, eps=1e-7, n_iter_max=100):

crit = np.sum(Ms**2) - np.sum(Rs**2)
while (improve > eps) & (iteration < n_iter_max):
B = np.dot(Rs, Rs.T)
B = Rs @ Rs.T
C1 = np.zeros((d, d))
for i in range(d):
C1[:, i] = np.sum(Ms[:, i:Md:d]*Rs, axis=1)

D0 = B*B.T - np.outer(np.diag(B), np.diag(B))
A0 = (C1 * B - np.dot(np.diag(np.diag(B)), C1.T)) / (D0 + np.eye(d))
D0 = B * B.T - np.outer(np.diag(B), np.diag(B))
A0 = (C1 * B - np.diag(np.diag(B)) @ C1.T) / (D0 + np.eye(d))
A0 += np.eye(d)
W_est = np.linalg.solve(A0, W_est)

Raux = np.dot(np.dot(W_est, M[:, 0:d]), W_est.T)
aux = 1./np.sqrt(np.abs(np.diag(Raux)))
W_est = np.dot(np.diag(aux), W_est)
aux = 1. / np.sqrt(np.abs(np.diag(Raux)))
W_est = np.diag(aux) @ W_est
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use broadcasting if you can. Every time I see a np.diag in numpy my brain produces a strong ERP ;)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I try to remove np.dot and np.diag when I can.
But I only improve the code when I’m sure to break nothing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uwedge seems deep, I couldn't think of a way to get rid of the np.diag nicely.


for k in range(n_matrices):
ini = k*d
Expand Down
56 changes: 32 additions & 24 deletions tests/test_ajd.py
Expand Up @@ -31,17 +31,38 @@ def test_get_normalized_weight_pos(get_covmats):
_get_normalized_weight(w, covmats)


@pytest.mark.parametrize("ajd", [rjd, ajd_pham])
def test_ajd_shape(ajd, get_covmats):
@pytest.mark.parametrize("ajd", [rjd, ajd_pham, uwedge])
@pytest.mark.parametrize("init", [True, False])
def test_ajd(ajd, init, get_covmats_params):
    """Test AJD algos: output shapes, with and without initialization.

    Bug fix: the ``init`` branches were inverted — ``init=True`` previously
    ran WITHOUT an initial diagonalizer and vice versa.
    """
    n_matrices, n_channels = 5, 3
    covmats, _, A = get_covmats_params(n_matrices, n_channels)
    if init:
        # warm start from the mixing matrix A
        V, D = ajd(covmats, init=A)
    else:
        V, D = ajd(covmats)
    assert V.shape == (n_channels, n_channels)
    assert D.shape == (n_matrices, n_channels, n_channels)

    if ajd is rjd:
        assert V.T @ V == approx(np.eye(n_channels))  # check orthogonality


@pytest.mark.parametrize("ajd", [rjd, ajd_pham, uwedge])
def test_ajd_init_error(ajd, get_covmats):
    """Test that AJD algos reject invalid initial diagonalizers."""
    n_matrices, n_channels = 5, 3
    covmats = get_covmats(n_matrices, n_channels)
    bad_inits = [
        np.ones((3, 2, 2)),  # not a 2D array
        np.ones((3, 2)),     # not a square array
        np.ones((2, 2)),     # shape not equal to n_channels
    ]
    for bad_init in bad_inits:
        with pytest.raises(ValueError):
            ajd(covmats, init=bad_init)


def test_pham(get_covmats):
"""Test pham's ajd"""
def test_pham_weight_none_equivalent_uniform(get_covmats):
"""Test pham's ajd weights: none is equivalent to uniform values"""
n_matrices, n_channels, w_val = 5, 3, 2
covmats = get_covmats(n_matrices, n_channels)
V, D = ajd_pham(covmats)
Expand All @@ -53,8 +74,8 @@ def test_pham(get_covmats):
assert_array_equal(D, Dw)


def test_pham_pos_weight(get_covmats):
# Test that weight must be strictly positive
def test_pham_weight_positive(get_covmats):
"""Test pham's ajd weights: must be strictly positive"""
n_matrices, n_channels, w_val = 5, 3, 2
covmats = get_covmats(n_matrices, n_channels)
w = w_val * np.ones(n_matrices)
Expand All @@ -63,9 +84,9 @@ def test_pham_pos_weight(get_covmats):
ajd_pham(covmats, sample_weight=w)


def test_pham_zero_weight(get_covmats):
# now test that setting one weight to almost zero it's almost
# like not passing the matrix
def test_pham_weight_zero(get_covmats):
"""Test pham's ajd weights: setting one weight to almost zero it's almost
like not passing the matrix"""
n_matrices, n_channels, w_val = 5, 3, 2
covmats = get_covmats(n_matrices, n_channels)
w = w_val * np.ones(n_matrices)
Expand All @@ -75,16 +96,3 @@ def test_pham_zero_weight(get_covmats):
Vw, Dw = ajd_pham(covmats, sample_weight=w)
assert V == approx(Vw, rel=1e-4, abs=1e-8)
assert D == approx(Dw[1:], rel=1e-4, abs=1e-8)


@pytest.mark.parametrize("init", [True, False])
def test_uwedge(init, get_covmats_params):
    """Test uwedge, with and without initialization.

    Bug fix: the ``init`` branches were inverted — ``init=True`` previously
    ran WITHOUT an initial diagonalizer and vice versa.
    """
    n_matrices, n_channels = 5, 3
    covmats, _, A = get_covmats_params(n_matrices, n_channels)
    if init:
        # warm start from the mixing matrix A
        V, D = uwedge(covmats, init=A)
    else:
        V, D = uwedge(covmats)
    assert V.shape == (n_channels, n_channels)
    assert D.shape == (n_matrices, n_channels, n_channels)