Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantic mdm #130

Merged
merged 43 commits into from
May 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f2bd941
Update classification.py
gcattan Apr 20, 2023
a4a856b
add template for QMDM
gcattan Apr 20, 2023
d0fff94
add test template
gcattan Apr 20, 2023
c37d07c
add mdm docplex model
gcattan Apr 20, 2023
f26a902
flake8
gcattan Apr 20, 2023
169cdd8
add docplex mdm test
gcattan Apr 20, 2023
9cae55f
flake8
gcattan Apr 20, 2023
db10cc4
list to numpy array conversion
gcattan Apr 20, 2023
f7f9e7c
use matrices no vectors for MDM
gcattan Apr 20, 2023
4605fcd
get binary square matrices
gcattan Apr 20, 2023
0285db8
flake8
gcattan Apr 20, 2023
58358f5
use ndarray
gcattan Apr 20, 2023
42a2aa0
use covariance matrices for MDM
gcattan Apr 20, 2023
82de818
fix fixture being called from conftest
gcattan Apr 20, 2023
6ed3cd2
fix docplex test?
gcattan Apr 20, 2023
1f3c6dc
running mdm
gcattan Apr 21, 2023
1a6ce99
passing docplex test
gcattan Apr 21, 2023
559d3ad
flake8
gcattan Apr 21, 2023
5f90bec
test_classification pass
gcattan Apr 21, 2023
a68e5c9
cvx method used inside QuanticMDM. May be distance function is wrong …
gcattan Apr 21, 2023
1b85cdf
logeuclid distance
gcattan Apr 22, 2023
1540553
create distance package
gcattan Apr 22, 2023
57618d8
override _predict_distances
gcattan Apr 22, 2023
82650a2
create test_utils_distance
gcattan Apr 22, 2023
623d5b5
fix a couple of bugs with tests
gcattan May 16, 2023
6b6c82f
docplex doc
gcattan May 16, 2023
d10f79c
flake8 docplex
gcattan May 16, 2023
f612bb4
Merge branch 'pyRiemann:main' into quantic-mdm
gcattan May 17, 2023
8a178ea
Merge branch 'main' into quantic-mdm
gcattan May 25, 2023
28fd6eb
correct some typos. Reintroduce TestQuantumClassifierWithDefaultRiema…
gcattan May 26, 2023
cbc01e1
configure backend for QAOA
gcattan May 26, 2023
40de3cb
add global optimizer
gcattan May 26, 2023
e96b290
set the global optimizer in QuanticMDM
gcattan May 26, 2023
e2505e4
complete api
gcattan May 26, 2023
bdcbb56
flake8
gcattan May 26, 2023
447574e
flake8
gcattan May 26, 2023
ea5940c
flake8 ignore _global_optimizer no used
gcattan May 26, 2023
e89ff60
replace logeuclid by eudlic
gcattan May 26, 2023
f2c3544
use mne_sample
gcattan May 26, 2023
39b1ae4
change by logeuclid
gcattan May 26, 2023
b07cfdd
diminish number of split
gcattan May 26, 2023
8127a8e
update docker
gcattan May 26, 2023
f6650e4
fix distance_methods not found
gcattan May 26, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ RUN mkdir /root/mne_data
RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.23.1
RUN pip install protobuf==4.23.2
RUN pip install google_cloud_firestore==2.11.1
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.23.1-py3.8.egg/google/__init__.py
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.23.2-py3.8.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
Expand Down
3 changes: 3 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Classification
QuanticClassifierBase
QuanticSVM
QuanticVQC
QuanticMDM
QuantumClassifierWithDefaultRiemannianPipeline


Expand Down Expand Up @@ -83,6 +84,8 @@ Docplex
pyQiskitOptimizer
ClassicalOptimizer
NaiveQAOAOptimizer
set_global_optimizer
get_global_optimizer

Math
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ qiskit-optimization==0.5.0
qiskit-aer==0.12.0
scipy==1.7.3
moabb>=0.4.6
git+https://github.com/pyRiemann/pyRiemann#egg=pyriemann
pyriemann==0.4
docplex>=2.21.207
firebase_admin==6.1.0
99 changes: 99 additions & 0 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
get_spsa)
from .utils import get_provider, get_devices, get_simulator
from pyriemann.estimation import XdawnCovariances
from pyriemann.classification import MDM
from pyriemann.tangentspace import TangentSpace
from pyriemann_qiskit.datasets import get_feature_dimension
from pyriemann_qiskit.utils import (ClassicalOptimizer,
NaiveQAOAOptimizer,
set_global_optimizer)

logger.level = logging.INFO

Expand Down Expand Up @@ -470,6 +474,101 @@ def predict(self, X):
return self._map_0_1_to_classes(labels)


class QuanticMDM(QuanticClassifierBase):

"""Quantum-enhanced MDM

# This class is a convex implementation of the MDM [1]_,
# that can runs with quantum optimization.
# Only log-euclidian distance between trial and class prototypes
# is supported at the moment, but any type of metric
# can be used for centroid estimation.

Notes
-----
.. versionadded:: 0.0.4

Parameters
----------
metric : string | dict, default={"mean": 'logeuclid', "distance": 'convex'}
The type of metric used for centroid and distance estimation.
see `mean_covariance` for the list of supported metric.
the metric could be a dict with two keys, `mean` and `distance` in
order to pass different metrics for the centroid estimation and the
distance estimation. Typical usecase is to pass 'logeuclid' metric for
the mean in order to boost the computional speed and 'riemann' for the
distance in order to keep the good sensitivity for the classification.

See Also
--------
QuanticClassifierBase
pyriemann.classification.MDM

References
----------
.. [1] `Multiclass Brain-Computer Interface Classification by Riemannian
Geometry
<https://hal.archives-ouvertes.fr/hal-00681328>`_
A. Barachant, S. Bonnet, M. Congedo, and C. Jutten. IEEE Transactions
on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012.
.. [2] `Riemannian geometry applied to BCI classification
<https://hal.archives-ouvertes.fr/hal-00602700/>`_
A. Barachant, S. Bonnet, M. Congedo and C. Jutten. 9th International
Conference Latent Variable Analysis and Signal Separation
(LVA/ICA 2010), LNCS vol. 6365, 2010, p. 629-636.
"""

def __init__(self,
metric={"mean": 'logeuclid', "distance": 'convex'},
**parameters):
QuanticClassifierBase.__init__(self, **parameters)
self.metric = metric

def _init_algo(self, n_features):
self._log("Convex MDM initiating algorithm")
classifier = MDM(metric=self.metric)
if self.quantum:
self._optimizer = \
NaiveQAOAOptimizer(quantum_instance=self._quantum_instance)
else:
self._optimizer = ClassicalOptimizer()
set_global_optimizer(self._optimizer)
return classifier

def predict_proba(self, X):
"""Return the probabilities associated with predictions.

Parameters
----------
X : ndarray, shape (n_trials, n_channels, n_channels)
ndarray of trials.

Returns
-------
prob : ndarray, shape (n_samples, n_classes)
prob[n, 0] == True if the nth sample is assigned to 1st class;
prob[n, 1] == True if the nth sample is assigned to 2nd class.
"""
return self._classifier.predict_proba(X)

def predict(self, X):
"""Calculates the predictions.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Input vector, where `n_samples` is the number of samples and
`n_features` is the number of features.

Returns
-------
pred : array, shape (n_samples,)
Class labels for samples in X.
"""
labels = self._predict(X)
return self._map_0_1_to_classes(labels)


class QuantumClassifierWithDefaultRiemannianPipeline(BaseEstimator,
ClassifierMixin,
TransformerMixin):
Expand Down
6 changes: 5 additions & 1 deletion pyriemann_qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
square_int_mat_var,
square_bin_mat_var,
ClassicalOptimizer,
NaiveQAOAOptimizer)
NaiveQAOAOptimizer,
set_global_optimizer,
get_global_optimizer)
from .firebase_connector import (
FirebaseConnector,
Cache,
Expand All @@ -30,6 +32,8 @@
'square_bin_mat_var',
'ClassicalOptimizer',
'NaiveQAOAOptimizer',
'set_global_optimizer',
'get_global_optimizer',
'logeucl_dist_convex',
'FirebaseConnector',
'Cache',
Expand Down
31 changes: 29 additions & 2 deletions pyriemann_qiskit/utils/distance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
from docplex.mp.model import Model
from pyriemann_qiskit.utils.docplex import ClassicalOptimizer
from pyriemann.utils.distance import distance_logeuclid
from pyriemann_qiskit.utils.docplex import (ClassicalOptimizer,
get_global_optimizer)
from pyriemann.classification import MDM
from pyriemann.utils.distance import (distance_logeuclid,
distance_methods)


def logeucl_dist_convex(X, y, optimizer=ClassicalOptimizer()):
Expand Down Expand Up @@ -33,6 +37,8 @@ def logeucl_dist_convex(X, y, optimizer=ClassicalOptimizer()):
http://ibmdecisionoptimization.github.io/docplex-doc/mp/_modules/docplex/mp/model.html#Model
"""

optimizer = get_global_optimizer(optimizer)

n_classes, _, _ = X.shape
classes = range(n_classes)

Expand All @@ -56,3 +62,24 @@ def dist(m1, m2):
result = optimizer.solve(prob, reshape=False)

return result


_mdm_predict_distances_original = MDM._predict_distances


def predict_distances(mdm, X):
if mdm.metric_dist == 'convex':
centroids = np.array(mdm.covmeans_)
return np.array([logeucl_dist_convex(centroids, x) for x in X])
else:
return _mdm_predict_distances_original(mdm, X)


MDM._predict_distances = predict_distances

# This is only for validation inside the MDM.
# In fact, we override the _predict_distances method
# inside MDM to directly use logeucl_dist_convex when the metric is "convex"
# This is due to the fact the the signature of this method is different from
# the usual distance functions.
distance_methods['convex'] = logeucl_dist_convex
52 changes: 49 additions & 3 deletions pyriemann_qiskit/utils/docplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,45 @@
from pyriemann_qiskit.utils import cov_to_corr_matrix, get_simulator


_global_optimizer = None


def set_global_optimizer(optimizer):
"""Set the value of the global optimizer

Parameters
----------
optimizer: pyQiskitOptimizer
An instance of pyQiskitOptimizer.

Notes
-----
.. versionadded:: 0.0.4
"""
_global_optimizer = optimizer # noqa


def get_global_optimizer(default):
"""Get the value of the global optimizer

Parameters
----------
default: pyQiskitOptimizer
An instance of pyQiskitOptimizer.
It will be returned by default if the global optimizer is None.

Returns
-------
optimizer : pyQiskitOptimizer
The global optimizer.

Notes
-----
.. versionadded:: 0.0.4
"""
return _global_optimizer if _global_optimizer is not None else default


def square_cont_mat_var(prob, channels,
name='cont_covmat'):
"""Creates a 2-dimensional dictionary of continuous decision variables,
Expand Down Expand Up @@ -360,6 +399,9 @@ class NaiveQAOAOptimizer(pyQiskitOptimizer):
----------
upper_bound : int (default: 7)
The maximum integer value for matrix normalization.
backend: QuantumInstance (default: None)
A quantum backend instance.
If None, AerSimulator will be used.

Notes
-----
Expand All @@ -370,9 +412,10 @@ class NaiveQAOAOptimizer(pyQiskitOptimizer):
--------
pyQiskitOptimizer
"""
def __init__(self, upper_bound=7):
def __init__(self, upper_bound=7, quantum_instance=None):
pyQiskitOptimizer.__init__(self)
self.upper_bound = upper_bound
self.quantum_instance = quantum_instance

"""Transform all values in the covariance matrix
to integers.
Expand Down Expand Up @@ -439,8 +482,11 @@ def covmat_var(self, prob, channels, name):
def _solve_qp(self, qp, reshape=True):
conv = IntegerToBinary()
qubo = conv.convert(qp)
backend = get_simulator()
quantum_instance = QuantumInstance(backend)
if self.quantum_instance is None:
backend = get_simulator()
quantum_instance = QuantumInstance(backend)
else:
quantum_instance = self.quantum_instance
qaoa_mes = QAOA(quantum_instance=quantum_instance,
initial_point=[0., 0.])
qaoa = MinimumEigenOptimizer(qaoa_mes)
Expand Down
5 changes: 4 additions & 1 deletion pyriemann_qiskit/utils/mean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from docplex.mp.model import Model
from pyriemann.utils.mean import mean_methods
from pyriemann_qiskit.utils.docplex import ClassicalOptimizer
from pyriemann_qiskit.utils.docplex import (ClassicalOptimizer,
get_global_optimizer)


def fro_mean_convex(covmats, sample_weight=None,
Expand Down Expand Up @@ -32,6 +33,8 @@ def fro_mean_convex(covmats, sample_weight=None,
http://ibmdecisionoptimization.github.io/docplex-doc/mp/_modules/docplex/mp/model.html#Model
"""

optimizer = get_global_optimizer(optimizer)

n_trials, n_channels, _ = covmats.shape
channels = range(n_channels)
trials = range(n_trials)
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ def _get_dataset(n_samples, n_features, n_classes, type="bin"):
elif type == "bin":
samples = _get_binary_feats(n_samples, n_features)
labels = _get_labels(n_samples, n_classes)
elif type == "rand_cov":
samples = make_covariances(n_samples, n_features, 0,
return_params=False)
labels = _get_labels(n_samples, n_classes)
elif type == "bin_cov":
samples_0 = make_covariances(n_samples // n_classes, n_features, 0,
return_params=False)
samples_1 = samples_0 * 2
samples = np.concatenate((samples_0, samples_1), axis=0)
labels = _get_labels(n_samples, n_classes)
else:
samples, labels = get_mne_sample()
return samples, labels
Expand Down
29 changes: 26 additions & 3 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pyriemann_qiskit.classification \
import (QuanticSVM,
QuanticVQC,
QuanticMDM,
QuantumClassifierWithDefaultRiemannianPipeline)
from pyriemann_qiskit.datasets import get_mne_sample
from pyriemann_qiskit.utils.filtering import NaiveDimRed
Expand All @@ -17,6 +18,8 @@
[make_pipeline(XdawnCovariances(nfilter=1),
TangentSpace(), NaiveDimRed(),
QuanticSVM(quantum=False)),
make_pipeline(XdawnCovariances(nfilter=1),
QuanticMDM(quantum=False)),
QuantumClassifierWithDefaultRiemannianPipeline(
nfilter=1,
shots=None)])
Expand Down Expand Up @@ -118,9 +121,7 @@ def additional_steps(self):


class TestClassicalSVM(BinaryFVT):
""" Perform standard SVC test
(canary test to assess pipeline correctness)
"""
""" Perform functional validation testing of Quantic SVM"""
def get_params(self):
quantum_instance = QuanticSVM(quantum=False, verbose=False)
return {
Expand Down Expand Up @@ -188,6 +189,28 @@ def check(self):
assert len(self.prediction) == len(self.labels)


class TestClassicalMDM(BinaryFVT):
"""Test the classical version of MDM is used
when quantum is false
https://quantum-computing.ibm.com/
Note that the "real quantum version" of this test may also take some time.
"""
def get_params(self):
quantum_instance = QuanticMDM(quantum=False, verbose=False)
return {
"n_samples": 100,
"n_features": 9,
"quantum_instance": quantum_instance,
"type": "bin_cov"
}

def check(self):
assert self.prediction[:self.class_len].all() == \
self.quantum_instance.classes_[0]
assert self.prediction[self.class_len:].all() == \
self.quantum_instance.classes_[1]


class TestQuantumClassifierWithDefaultRiemannianPipeline(BinaryFVT):
"""Functional testing for riemann quantum classifier."""
def get_params(self):
Expand Down
Loading