Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/logeuclid mean and logeuclid distance to convex hull #244

Merged
merged 90 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
9ecc93c
- change `convex` by `cpm` (stands for constraint programming model)
Feb 18, 2024
f65c4b2
- parametetrize tests
Feb 18, 2024
fc4a323
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2024
3a1eff9
fix qiskit version
Feb 18, 2024
38d4ca9
Update doc/api.rst
gcattan Feb 19, 2024
8f657ea
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 19, 2024
c396820
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2024
fcff079
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 19, 2024
9be8e98
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 19, 2024
99a130a
Update pyriemann_qiskit/classification.py
gcattan Feb 19, 2024
e82d404
- Rename cpm to cpm-le in some places
Feb 19, 2024
650e987
rename fro_mean_cpm -> mean_euclid_cpm
Feb 19, 2024
39d6395
rename cpm_metric -> metric
Feb 19, 2024
80c145d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2024
e2d654b
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 20, 2024
663c863
- remove shrink
Feb 20, 2024
297117a
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 20, 2024
d15a905
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
f9fedfa
add missing reference to distance
Feb 20, 2024
865a0e5
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 20, 2024
0630778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
b843c41
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 20, 2024
5d9bc28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
803d76c
add references to cpm_le
Feb 20, 2024
aba9638
fix references
Feb 20, 2024
1c88275
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 21, 2024
5c019dc
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 21, 2024
d00ad94
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
fabc426
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
5eab650
Update pyriemann_qiskit/utils/mean.py
gcattan Feb 21, 2024
7080dce
- use logeuclid_cpm everywhere
Feb 21, 2024
c051659
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
8e7c898
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
c43d50a
add test
Feb 21, 2024
3549dc0
- remove shrinkage
Feb 21, 2024
94c6435
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
4eec384
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 21, 2024
c7ff8ab
rename X, y -> A, B
Feb 21, 2024
1f7d490
fix failure on test due to regularization
Feb 21, 2024
fc5dd3c
Tentative to improve tests on Ci
Feb 21, 2024
e84652e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
749d057
improvement of tests
Feb 21, 2024
7c72b13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
13ebae6
fix firestore
Feb 21, 2024
7cdd80b
add regularization
Feb 21, 2024
ecc0351
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
f1f0c3e
add missing regularization
Feb 21, 2024
c8923bb
fix tests
Feb 21, 2024
9ee3473
just remove logeuclid_cpm for the moment for test performance.
Feb 21, 2024
284dc86
- change the behavior of the metric parameter in pipeline so it is th…
Feb 22, 2024
716aad6
add regularization to light_benchmark
Feb 22, 2024
8280e93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 22, 2024
0c0ea99
fix import in test
Feb 22, 2024
5cf8ff0
lint
Feb 22, 2024
c8079f7
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
2d0e8ab
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
548e32d
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
8cde304
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
1be5a3d
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
b69854e
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 22, 2024
b082823
Improve doc
Feb 22, 2024
633a6bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 22, 2024
8f5aeab
diminish number of trial for SPSA in benchmark script
Feb 23, 2024
b6057f3
improve mean
qbarthelemy Feb 23, 2024
2145c49
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 23, 2024
53bfcea
add distance output
Feb 23, 2024
22abf09
fix lint
Feb 23, 2024
725da6b
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 24, 2024
8e8eda3
add missing imports
Feb 24, 2024
0f7fdd8
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 24, 2024
31fee4f
fix test
Feb 24, 2024
0e3eca3
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
f8c3deb
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
f27b796
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
e4a5fcc
add constraint
Feb 26, 2024
6cf239d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
f404dee
Update pyriemann_qiskit/utils/distance.py
gcattan Feb 26, 2024
659217b
add regularization based on GH console error
Feb 26, 2024
0fb54a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
473d822
fix lint
Feb 26, 2024
34a9267
applu suggestion on method predict_distances
Feb 26, 2024
aebf84f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
0c6785d
add more regularization for GH CI
Feb 26, 2024
18122dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
2eb9141
wrong file: revert changes to test_utils_distance,
Feb 26, 2024
9f2be5d
add check_weights
qbarthelemy Feb 27, 2024
dccfd51
replace make_covariances by make_matrices in conftest.py
Feb 27, 2024
b2753cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
f4bf157
fix test by using mne data
Feb 27, 2024
f6a432d
remove line with check_weights
Feb 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ RUN mkdir /root/mne_data
RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.25.2
RUN pip install google_cloud_firestore==2.14.0
RUN pip install protobuf==4.25.3
RUN pip install google_cloud_firestore==2.15.0
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.9/site-packages/protobuf-4.25.2-py3.9.egg/google/__init__.py
RUN touch /usr/local/lib/python3.9/site-packages/protobuf-4.25.3-py3.9.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'

ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ]
10 changes: 6 additions & 4 deletions benchmarks/light_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# Modified from plot_classify_P300_bi.py of pyRiemann
# License: BSD (3-clause)

from pyriemann.estimation import XdawnCovariances
from pyriemann.estimation import XdawnCovariances, Shrinkage
from pyriemann.tangentspace import TangentSpace
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -78,15 +78,17 @@
)

pipelines["RG_VQC"] = QuantumClassifierWithDefaultRiemannianPipeline(
shots=100, spsa_trials=5, two_local_reps=2, params={"seed": 42}
shots=100, spsa_trials=1, two_local_reps=2, params={"seed": 42}
)

pipelines["QMDM_mean"] = QuantumMDMWithRiemannianPipeline(
convex_metric="mean", quantum=True
metric={"mean": "euclid_cpm", "distance": "euclid"},
quantum=True,
regularization=Shrinkage(shrinkage=0.9),
)

pipelines["QMDM_dist"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
metric={"mean": "logeuclid", "distance": "logeuclid_cpm"}, quantum=True
)

pipelines["RG_LDA"] = make_pipeline(
Expand Down
5 changes: 3 additions & 2 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Mean
.. autosummary::
:toctree: generated/

fro_mean_convex
mean_euclid_cpm
mean_logeuclid_cpm

Distance
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -106,7 +107,7 @@ Distance
.. autosummary::
:toctree: generated/

logeucl_dist_convex
distance_logeuclid_cpm

Docplex
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ imbalanced-learn==0.11.0
joblib
pandas
cvxpy==1.4.1
qiskit==0.45.0
qiskit_machine_learning==0.6.1
qiskit-ibm-provider==0.7.3
qiskit-optimization==0.5.0
Expand Down
6 changes: 3 additions & 3 deletions examples/ERP/classify_P300_bi_illiteracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,19 @@ def placeholder(key):
placeholder(PIP.xDAWN_LDA.value)

pipelines[PIP.ERPCov_CvxMDM_Dist.value] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=False
metric="distance", quantum=False
)
placeholder(PIP.ERPCov_CvxMDM_Dist.value)

# Quantum Pipelines

pipelines[PIP.ERPCov_QMDM_Dist.value] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
metric="distance", quantum=True
)
placeholder(PIP.ERPCov_QMDM_Dist.value)

pipelines[PIP.ERPCov_QMDM_Dist.value] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
metric="distance", quantum=True
)
placeholder(PIP.ERPCov_QMDM_Dist.value)

Expand Down
16 changes: 8 additions & 8 deletions examples/ERP/classify_P300_bi_quantum_mdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
The mean and the distance in MDM algorithm are formulated as
optimization problems. These optimization problems are translated
to Qiskit using Docplex and additional glue code. These optimizations
are enabled when we use convex mean or convex distance. This is set
using the 'convex_metric' parameter of the QuantumMDMWithRiemannianPipeline.
are enabled when we use cpm mean or cpm distance. This is set
using the 'metric' parameter of the QuantumMDMWithRiemannianPipeline.

Classification can be run either on emulation or real quantum computer.

Expand Down Expand Up @@ -43,7 +43,7 @@
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300

# inject convex distance and mean to pyriemann (if not done already)
# inject cpm distance and mean to pyriemann (if not done already)
from pyriemann_qiskit.utils import distance, mean # noqa
from pyriemann_qiskit.pipelines import (
QuantumMDMVotingClassifier,
Expand Down Expand Up @@ -107,15 +107,15 @@

pipelines = {}

pipelines["mean=convex/distance=euclid"] = QuantumMDMWithRiemannianPipeline(
convex_metric="mean", quantum=quantum
pipelines["mean=logeuclid_cpm/distance=logeuclid"] = QuantumMDMWithRiemannianPipeline(
metric="mean", quantum=quantum
)

pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=quantum
pipelines["mean=logeuclid/distance=logeuclid_cpm"] = QuantumMDMWithRiemannianPipeline(
metric="distance", quantum=quantum
)

pipelines["Voting convex"] = QuantumMDMVotingClassifier(quantum=quantum)
pipelines["Voting logeuclid_cpm"] = QuantumMDMVotingClassifier(quantum=quantum)

##############################################################################
# Run evaluation
Expand Down
6 changes: 3 additions & 3 deletions examples/MI/classify_alexmi_with_quantum_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import MotorImagery

# inject convex distance and mean to pyriemann (if not done already)
# inject cpm distance and mean to pyriemann (if not done already)
from pyriemann_qiskit.utils import distance, mean # noqa
from pyriemann_qiskit.pipelines import (
QuantumMDMWithRiemannianPipeline,
Expand Down Expand Up @@ -68,8 +68,8 @@
pipelines = {}

# Will run QAOA under the hood
pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
pipelines["mean=logeuclid/distance=cpm"] = QuantumMDMWithRiemannianPipeline(
metric="distance", quantum=True
)

# Classical baseline for evaluation
Expand Down
34 changes: 27 additions & 7 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from qiskit_ibm_provider import IBMProvider, least_busy
from qiskit_machine_learning.algorithms import QSVC, VQC, PegasosQSVC
from qiskit_machine_learning.kernels.quantum_kernel import QuantumKernel
from qiskit_optimization.algorithms import CobylaOptimizer
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC

Expand Down Expand Up @@ -582,7 +583,7 @@ class QuanticMDM(QuanticClassifierBase):

"""Quantum-enhanced MDM classifier

This class is a convex implementation of the Minimum Distance to Mean (MDM)
This class is a quantic implementation of the Minimum Distance to Mean (MDM)
[1]_, which can run with quantum optimization.
Only log-Euclidean distance between trial and class prototypes is supported
at the moment, but any type of metric can be used for centroid estimation.
Expand All @@ -593,11 +594,13 @@ class QuanticMDM(QuanticClassifierBase):
.. versionchanged:: 0.1.0
Fix: copy estimator not keeping base class parameters.
.. versionchanged:: 0.2.0
Add seed parameter
Add seed parameter.
Add regularization parameter.
Add classical_optimizer parameter.

Parameters
----------
metric : string | dict, default={"mean": 'logeuclid', "distance": 'convex'}
metric : string | dict, default={"mean": 'logeuclid', "distance": 'cpm'}
The type of metric used for centroid and distance estimation.
see `mean_covariance` for the list of supported metric.
the metric could be a dict with two keys, `mean` and `distance` in
Expand All @@ -606,7 +609,7 @@ class QuanticMDM(QuanticClassifierBase):
the mean in order to boost the computional speed and 'riemann' for the
distance in order to keep the good sensitivity for the classification.
quantum : bool (default: True)
Only applies if `metric` contains a convex distance or mean.
Only applies if `metric` contains a cpm distance or mean.

- If true will run on local or remote backend
(depending on q_account_token value),
Expand All @@ -624,6 +627,10 @@ class QuanticMDM(QuanticClassifierBase):
Random seed for the simulation
upper_bound : int (default: 7)
The maximum integer value for matrix normalization.
regularization: MixinTransformer (defulat: None)
Additional post-processing to regularize means.
classical_optimizer : OptimizationAlgorithm
An instance of OptimizationAlgorithm [3]_

See Also
--------
Expand All @@ -642,26 +649,32 @@ class QuanticMDM(QuanticClassifierBase):
A. Barachant, S. Bonnet, M. Congedo and C. Jutten. 9th International
Conference Latent Variable Analysis and Signal Separation
(LVA/ICA 2010), LNCS vol. 6365, 2010, p. 629-636.
.. [3] \
https://qiskit-community.github.io/qiskit-optimization/stubs/qiskit_optimization.algorithms.OptimizationAlgorithm.html#optimizationalgorithm
"""

def __init__(
self,
metric={"mean": "logeuclid", "distance": "convex"},
metric={"mean": "logeuclid", "distance": "logeuclid_cpm"},
quantum=True,
q_account_token=None,
verbose=True,
shots=1024,
seed=None,
upper_bound=7,
regularization=None,
classical_optimizer=CobylaOptimizer(rhobeg=2.1, rhoend=0.000001),
):
QuanticClassifierBase.__init__(
self, quantum, q_account_token, verbose, shots, None, seed
)
self.metric = metric
self.upper_bound = upper_bound
self.regularization = regularization
self.classical_optimizer = classical_optimizer

def _init_algo(self, n_features):
self._log("Convex MDM initiating algorithm")
self._log("Quantic MDM initiating algorithm")
classifier = MDM(metric=self.metric)
if self.quantum:
self._log("Using NaiveQAOAOptimizer")
Expand All @@ -670,10 +683,17 @@ def _init_algo(self, n_features):
)
else:
self._log("Using ClassicalOptimizer (COBYLA)")
self._optimizer = ClassicalOptimizer()
self._optimizer = ClassicalOptimizer(self.classical_optimizer)
set_global_optimizer(self._optimizer)
return classifier

def _train(self, X, y):
QuanticClassifierBase._train(self, X, y)
if self.regularization is not None:
self._classifier.covmeans_ = self.regularization.fit_transform(
self._classifier.covmeans_
)

def predict(self, X):
"""Calculates the predictions.

Expand Down
Loading
Loading