Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/logeuclid mean and logeuclid distance to convex hull #244

Merged
merged 90 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
9ecc93c
- change `convex` by `cpm` (stands for constraint programming model)
gcattan Feb 18, 2024
f65c4b2
- parametetrize tests
gcattan Feb 18, 2024
fc4a323
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2024
3a1eff9
fix qiskit version
gcattan Feb 18, 2024
38d4ca9
Update doc/api.rst
gcattan Feb 19, 2024
8f657ea
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 19, 2024
c396820
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2024
fcff079
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 19, 2024
9be8e98
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 19, 2024
99a130a
Update pyriemann_qiskit/classification.py
gcattan Feb 19, 2024
e82d404
- Rename cpm to cpm-le in some places
gcattan Feb 19, 2024
650e987
rename fro_mean_cpm -> mean_euclid_cpm
gcattan Feb 19, 2024
39d6395
rename cpm_metric -> metric
gcattan Feb 19, 2024
80c145d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2024
e2d654b
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 20, 2024
663c863
- remove shrink
gcattan Feb 20, 2024
297117a
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 20, 2024
d15a905
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
f9fedfa
add missing reference to distance
gcattan Feb 20, 2024
865a0e5
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 20, 2024
0630778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
b843c41
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 20, 2024
5d9bc28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
803d76c
add references to cpm_le
gcattan Feb 20, 2024
aba9638
fix references
gcattan Feb 20, 2024
1c88275
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 21, 2024
5c019dc
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 21, 2024
d00ad94
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
fabc426
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
5eab650
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 21, 2024
7080dce
- use logeuclid_cpm everywhere
gcattan Feb 21, 2024
c051659
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
8e7c898
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
c43d50a
add test
gcattan Feb 21, 2024
3549dc0
- remove shrinkage
gcattan Feb 21, 2024
94c6435
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
4eec384
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
c7ff8ab
rename X, y -> A, B
gcattan Feb 21, 2024
1f7d490
fix failure on test due to regularization
gcattan Feb 21, 2024
fc5dd3c
Tentative to improve tests on Ci
gcattan Feb 21, 2024
e84652e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
749d057
improvement of tests
gcattan Feb 21, 2024
7c72b13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
13ebae6
fix firestore
gcattan Feb 21, 2024
7cdd80b
add regularization
gcattan Feb 21, 2024
ecc0351
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
f1f0c3e
add missing regularization
gcattan Feb 21, 2024
c8923bb
fix tests
gcattan Feb 21, 2024
9ee3473
just remove logeuclid_cpm for the moment for test performance.
gcattan Feb 21, 2024
284dc86
- change the behavior of the metric parameter in pipeline so it is th…
gcattan Feb 22, 2024
716aad6
add regularization to light_benchmark
gcattan Feb 22, 2024
8280e93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 22, 2024
0c0ea99
fix import in test
gcattan Feb 22, 2024
5cf8ff0
lint
gcattan Feb 22, 2024
c8079f7
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
2d0e8ab
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
548e32d
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
8cde304
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
1be5a3d
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
b69854e
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
b082823
Improve doc
gcattan Feb 22, 2024
633a6bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 22, 2024
8f5aeab
diminish number of trial for SPSA in benchmark script
gcattan Feb 23, 2024
b6057f3
improve mean
qbarthelemy Feb 23, 2024
2145c49
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 23, 2024
53bfcea
add distance output
gcattan Feb 23, 2024
22abf09
fix lint
gcattan Feb 23, 2024
725da6b
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 24, 2024
8e8eda3
add missing imports
gcattan Feb 24, 2024
0f7fdd8
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 24, 2024
31fee4f
fix test
gcattan Feb 24, 2024
0e3eca3
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
f8c3deb
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
f27b796
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
e4a5fcc
add constraint
gcattan Feb 26, 2024
6cf239d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
f404dee
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
659217b
add regularization based on GH console error
gcattan Feb 26, 2024
0fb54a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
473d822
fix lint
gcattan Feb 26, 2024
34a9267
applu suggestion on method predict_distances
gcattan Feb 26, 2024
aebf84f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
0c6785d
add more regularization for GH CI
gcattan Feb 26, 2024
18122dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
2eb9141
wrong file: revert changes to test_utils_distance,
gcattan Feb 26, 2024
9f2be5d
add check_weights
qbarthelemy Feb 27, 2024
dccfd51
replace make_covariances by make_matrices in conftest.py
gcattan Feb 27, 2024
b2753cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
f4bf157
fix test by using mne data
gcattan Feb 27, 2024
f6a432d
remove line with check_weights
gcattan Feb 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/light_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@
)

pipelines["QMDM_mean"] = QuantumMDMWithRiemannianPipeline(
convex_metric="mean", quantum=True
cpm_metric="mean", quantum=True
)

pipelines["QMDM_dist"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
cpm_metric="distance", quantum=True
)

pipelines["RG_LDA"] = make_pipeline(
Expand Down
5 changes: 3 additions & 2 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Mean
.. autosummary::
:toctree: generated/

fro_mean_convex
fro_mean_cpm
le_mean_cpm
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Distance
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -106,7 +107,7 @@ Distance
.. autosummary::
:toctree: generated/

logeucl_dist_convex
logeucl_dist_cpm
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved

Docplex
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ imbalanced-learn==0.11.0
joblib
pandas
cvxpy==1.4.1
qiskit==0.45.0
qiskit_machine_learning==0.6.1
qiskit-ibm-provider==0.7.3
qiskit-optimization==0.5.0
Expand Down
16 changes: 8 additions & 8 deletions examples/ERP/classify_P300_bi_quantum_mdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
The mean and the distance in MDM algorithm are formulated as
optimization problems. These optimization problems are translated
to Qiskit using Docplex and additional glue code. These optimizations
are enabled when we use convex mean or convex distance. This is set
using the 'convex_metric' parameter of the QuantumMDMWithRiemannianPipeline.
are enabled when we use cpm mean or cpm distance. This is set
using the 'cpm_metric' parameter of the QuantumMDMWithRiemannianPipeline.

Classification can be run either on emulation or real quantum computer.

Expand Down Expand Up @@ -43,7 +43,7 @@
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300

# inject convex distance and mean to pyriemann (if not done already)
# inject cpm distance and mean to pyriemann (if not done already)
from pyriemann_qiskit.utils import distance, mean # noqa
from pyriemann_qiskit.pipelines import (
QuantumMDMVotingClassifier,
Expand Down Expand Up @@ -107,15 +107,15 @@

pipelines = {}

pipelines["mean=convex/distance=euclid"] = QuantumMDMWithRiemannianPipeline(
convex_metric="mean", quantum=quantum
pipelines["mean=cpm/distance=euclid"] = QuantumMDMWithRiemannianPipeline(
cpm_metric="mean", quantum=quantum
)

pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=quantum
pipelines["mean=logeuclid/distance=cpm"] = QuantumMDMWithRiemannianPipeline(
cpm_metric="distance", quantum=quantum
)

pipelines["Voting convex"] = QuantumMDMVotingClassifier(quantum=quantum)
pipelines["Voting cpm"] = QuantumMDMVotingClassifier(quantum=quantum)

##############################################################################
# Run evaluation
Expand Down
6 changes: 3 additions & 3 deletions examples/MI/classify_alexmi_with_quantum_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import MotorImagery

# inject convex distance and mean to pyriemann (if not done already)
# inject cpm distance and mean to pyriemann (if not done already)
from pyriemann_qiskit.utils import distance, mean # noqa
from pyriemann_qiskit.pipelines import (
QuantumMDMWithRiemannianPipeline,
Expand Down Expand Up @@ -68,8 +68,8 @@
pipelines = {}

# Will run QAOA under the hood
pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
pipelines["mean=logeuclid/distance=cpm"] = QuantumMDMWithRiemannianPipeline(
cpm_metric="distance", quantum=True
)

# Classical baseline for evaluation
Expand Down
10 changes: 5 additions & 5 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ class QuanticMDM(QuanticClassifierBase):

"""Quantum-enhanced MDM classifier

This class is a convex implementation of the Minimum Distance to Mean (MDM)
This class is a cpm implementation of the Minimum Distance to Mean (MDM)
gcattan marked this conversation as resolved.
Show resolved Hide resolved
[1]_, which can run with quantum optimization.
Only log-Euclidean distance between trial and class prototypes is supported
at the moment, but any type of metric can be used for centroid estimation.
Expand All @@ -597,7 +597,7 @@ class QuanticMDM(QuanticClassifierBase):

Parameters
----------
metric : string | dict, default={"mean": 'logeuclid', "distance": 'convex'}
metric : string | dict, default={"mean": 'logeuclid', "distance": 'cpm'}
The type of metric used for centroid and distance estimation.
see `mean_covariance` for the list of supported metric.
the metric could be a dict with two keys, `mean` and `distance` in
Expand All @@ -606,7 +606,7 @@ class QuanticMDM(QuanticClassifierBase):
the mean in order to boost the computional speed and 'riemann' for the
distance in order to keep the good sensitivity for the classification.
quantum : bool (default: True)
Only applies if `metric` contains a convex distance or mean.
Only applies if `metric` contains a cpm distance or mean.

- If true will run on local or remote backend
(depending on q_account_token value),
Expand Down Expand Up @@ -646,7 +646,7 @@ class QuanticMDM(QuanticClassifierBase):

def __init__(
self,
metric={"mean": "logeuclid", "distance": "convex"},
metric={"mean": "logeuclid", "distance": "cpm_le"},
quantum=True,
q_account_token=None,
verbose=True,
Expand All @@ -661,7 +661,7 @@ def __init__(
self.upper_bound = upper_bound

def _init_algo(self, n_features):
self._log("Convex MDM initiating algorithm")
self._log("cpm MDM initiating algorithm")
classifier = MDM(metric=self.metric)
if self.quantum:
self._log("Using NaiveQAOAOptimizer")
Expand Down
40 changes: 20 additions & 20 deletions pyriemann_qiskit/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,19 @@ def _create_pipe(self):

class QuantumMDMWithRiemannianPipeline(BasePipeline):

"""MDM with Riemannian pipeline adapted for convex metrics.
"""MDM with Riemannian pipeline adapted for cpm metrics.

It can run on classical or quantum optimizer.

Parameters
----------
convex_metric : string (default: "distance")
cpm_metric : string (default: "distance")
gcattan marked this conversation as resolved.
Show resolved Hide resolved
`metric` passed to the inner QuanticMDM depends on the
`convex_metric` as follows (convex_metric => metric):
`cpm_metric` as follows (cpm_metric => metric):

- "distance" => {mean=logeuclid, distance=convex},
- "mean" => {mean=convex, distance=euclid},
- "both" => {mean=convex, distance=convex},
- "distance" => {mean=logeuclid, distance=cpm},
- "mean" => {mean=cpm, distance=euclid},
- "both" => {mean=cpm, distance=cpm},
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
- other => same as "distance".
quantum : bool (default: True)
- If true will run on local or remote backend
Expand Down Expand Up @@ -351,14 +351,14 @@ class QuantumMDMWithRiemannianPipeline(BasePipeline):

def __init__(
self,
convex_metric="distance",
cpm_metric="distance",
quantum=True,
q_account_token=None,
verbose=True,
shots=1024,
upper_bound=7,
):
self.convex_metric = convex_metric
self.cpm_metric = cpm_metric
self.quantum = quantum
self.q_account_token = q_account_token
self.verbose = verbose
Expand All @@ -368,14 +368,14 @@ def __init__(
BasePipeline.__init__(self, "QuantumMDMWithRiemannianPipeline")

def _create_pipe(self):
if self.convex_metric == "both":
metric = {"mean": "convex", "distance": "convex"}
elif self.convex_metric == "mean":
metric = {"mean": "convex", "distance": "euclid"}
if self.cpm_metric == "both":
metric = {"mean": "cpm_le", "distance": "cpm_le"}
elif self.cpm_metric == "mean":
metric = {"mean": "cpm_le", "distance": "logeuclid"}
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
else:
metric = {"mean": "logeuclid", "distance": "convex"}
metric = {"mean": "logeuclid", "distance": "cpm_le"}

if metric["mean"] == "convex":
if metric["mean"] == "cpm_le":
if self.quantum:
covariances = XdawnCovariances(
nfilter=1, estimator="scm", xdawn_estimator="lwf"
Expand Down Expand Up @@ -407,8 +407,8 @@ class QuantumMDMVotingClassifier(BasePipeline):
Voting classifier with two configurations of
QuantumMDMWithRiemannianPipeline:

- with mean = convex and distance = euclid,
- with mean = logeuclid and distance = convex.
- with mean = cpm and distance = euclid,
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
- with mean = logeuclid and distance = cpm.

Parameters
----------
Expand Down Expand Up @@ -460,15 +460,15 @@ def __init__(
BasePipeline.__init__(self, "QuantumMDMVotingClassifier")

def _create_pipe(self):
clf_mean_logeuclid_dist_convex = QuantumMDMWithRiemannianPipeline(
clf_mean_logeuclid_dist_cpm = QuantumMDMWithRiemannianPipeline(
"distance",
self.quantum,
self.q_account_token,
self.verbose,
self.shots,
self.upper_bound,
)
clf_mean_convex_dist_euclid = QuantumMDMWithRiemannianPipeline(
clf_mean_cpm_dist_euclid = QuantumMDMWithRiemannianPipeline(
"mean",
self.quantum,
self.q_account_token,
Expand All @@ -480,8 +480,8 @@ def _create_pipe(self):
return make_pipeline(
VotingClassifier(
[
("mean_logeuclid_dist_convex", clf_mean_logeuclid_dist_convex),
("mean_convex_dist_euclid ", clf_mean_convex_dist_euclid),
("mean_logeuclid_dist_cpm", clf_mean_logeuclid_dist_cpm),
("mean_cpm_dist_euclid ", clf_mean_cpm_dist_euclid),
],
voting="soft",
)
Expand Down
4 changes: 2 additions & 2 deletions pyriemann_qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
add_moabb_dataframe_results_to_caches,
convert_caches_to_dataframes,
)
from .distance import logeucl_dist_convex
from .distance import logeucl_dist_cpm

__all__ = [
"hyper_params_factory",
Expand All @@ -36,7 +36,7 @@
"NaiveQAOAOptimizer",
"set_global_optimizer",
"get_global_optimizer",
"logeucl_dist_convex",
"logeucl_dist_cpm",
"FirebaseConnector",
"Cache",
"generate_caches",
Expand Down
12 changes: 6 additions & 6 deletions pyriemann_qiskit/utils/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pyriemann.utils.base import logm

gcattan marked this conversation as resolved.
Show resolved Hide resolved

gcattan marked this conversation as resolved.
Show resolved Hide resolved
def logeucl_dist_convex(X, y, optimizer=ClassicalOptimizer()):
"""Convex formulation of the MDM algorithm with log-Euclidean metric.
def logeucl_dist_cpm(X, y, optimizer=ClassicalOptimizer()):
"""Constraint Programming Model (CPM) formulation of the MDM algorithm with log-Euclidean metric.
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -67,9 +67,9 @@ def log_prod(m1, m2):


def predict_distances(mdm, X):
if mdm.metric_dist == "convex":
if mdm.metric_dist == "cpm_le":
centroids = np.array(mdm.covmeans_)
return np.array([logeucl_dist_convex(centroids, x) for x in X])
return np.array([logeucl_dist_cpm(centroids, x) for x in X])
else:
return _mdm_predict_distances_original(mdm, X)

Expand All @@ -78,7 +78,7 @@ def predict_distances(mdm, X):

# This is only for validation inside the MDM.
# In fact, we override the _predict_distances method
# inside MDM to directly use logeucl_dist_convex when the metric is "convex"
# inside MDM to directly use logeucl_dist_cpm when the metric is "cpm_le"
# This is due to the fact the the signature of this method is different from
# the usual distance functions.
distance_functions["convex"] = logeucl_dist_convex
distance_functions["cpm_le"] = logeucl_dist_cpm
45 changes: 41 additions & 4 deletions pyriemann_qiskit/utils/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from pyriemann.utils.mean import mean_functions
from pyriemann_qiskit.utils.docplex import ClassicalOptimizer, get_global_optimizer
from pyriemann.estimation import Shrinkage
from pyriemann.utils.base import logm, expm
import numpy as np


def fro_mean_convex(
def fro_mean_cpm(
gcattan marked this conversation as resolved.
Show resolved Hide resolved
covmats, sample_weight=None, optimizer=ClassicalOptimizer(), shrink=True
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
):
"""Convex formulation of the mean with Frobenius distance.
"""Constraint Programm Model (CPM) formulation of the mean with Frobenius distance.
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand All @@ -25,7 +27,7 @@ def fro_mean_convex(
Returns
-------
mean : ndarray, shape (n_channels, n_channels)
Convex-optimized Frobenius mean.
CPM-optimized Frobenius mean.
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Notes
-----
Expand Down Expand Up @@ -66,4 +68,39 @@ def _fro_dist(A, B):
return result


mean_functions["convex"] = fro_mean_convex
def le_mean_cpm(
covmats, sample_weight=None, optimizer=ClassicalOptimizer(), shrink=True
gcattan marked this conversation as resolved.
Show resolved Hide resolved
):
"""Constraint Programm Model (CPM) formulation of the mean with log-euclidian distance.
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
covmats: ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices.
sample_weights: None | ndarray, shape (n_matrices,), default=None
Weights for each matrix. Never used in practice.
It is kept only for standardization with pyRiemann.
optimizer: pyQiskitOptimizer
An instance of pyQiskitOptimizer.
shrink: boolean (default: true)
If True, it applies shrinkage regularization [2]_
of the resulting covariance matrix.
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
mean : ndarray, shape (n_channels, n_channels)
CPM-optimized Frobenius mean.
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Notes
-----
.. versionadded:: 0.2.0

"""

log_covmats = logm(covmats)
result = fro_mean_cpm(log_covmats, sample_weight, optimizer, shrink)
return expm(result)


mean_functions["cpm_fro"] = fro_mean_cpm
mean_functions["cpm_le"] = le_mean_cpm
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ numpy<1.27
cython
scikit-learn==1.3.2
git+https://github.com/pyRiemann/pyRiemann#egg=pyriemann
qiskit==0.45.0
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
qiskit_machine_learning==0.6.1
qiskit-ibm-provider==0.7.3
qiskit-optimization==0.5.0
Expand Down
Loading
Loading