Skip to content

Commit

Permalink
Quantic mdm (#130)
Browse files Browse the repository at this point in the history
* Update classification.py

* add template for QMDM

* add test template

* add mdm docplex model

* flake8

* add docplex mdm test

* flake8

* list to numpy array conversion

* use matrices no vectors for MDM

* get binary square matrices

* flake8

* use ndarray

* use covariance matrices for MDM

* fix fixture being called from conftest

* fix docplex test?

* running mdm

* passing docplex test

* flake8

* test_classification pass

* cvx method used inside QuanticMDM. May be distance function is wrong inside cvx method?

* logeuclid distance

* create distance package

* override _predict_distances

* create test_utils_distance

* fix a couple of bugs with tests

* docplex doc

* flake8 docplex

* correct some typos. Reintroduce TestQuantumClassifierWithDefaultRiemannianPipeline

* configure backend for QAOA

* add global optimizer

* set the global optimizer in QuanticMDM

* complete api

* flake8

* flake8

* flake8 ignore _global_optimizer no used

* replace logeuclid by eudlic

* use mne_sample

* change by logeuclid

* diminish number of split

* update docker

* fix distance_methods not found

---------

Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com>
  • Loading branch information
gcattan and gcattan authored May 27, 2023
1 parent 7b5cf5b commit 05f4043
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ RUN mkdir /root/mne_data
RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.23.1
RUN pip install protobuf==4.23.2
RUN pip install google_cloud_firestore==2.11.1
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.23.1-py3.8.egg/google/__init__.py
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.23.2-py3.8.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
Expand Down
3 changes: 3 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Classification
QuanticClassifierBase
QuanticSVM
QuanticVQC
QuanticMDM
QuantumClassifierWithDefaultRiemannianPipeline


Expand Down Expand Up @@ -83,6 +84,8 @@ Docplex
pyQiskitOptimizer
ClassicalOptimizer
NaiveQAOAOptimizer
set_global_optimizer
get_global_optimizer

Math
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ qiskit-optimization==0.5.0
qiskit-aer==0.12.0
scipy==1.7.3
moabb>=0.4.6
git+https://github.com/pyRiemann/pyRiemann#egg=pyriemann
pyriemann==0.4
docplex>=2.21.207
firebase_admin==6.1.0
99 changes: 99 additions & 0 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
get_spsa)
from .utils import get_provider, get_devices, get_simulator
from pyriemann.estimation import XdawnCovariances
from pyriemann.classification import MDM
from pyriemann.tangentspace import TangentSpace
from pyriemann_qiskit.datasets import get_feature_dimension
from pyriemann_qiskit.utils import (ClassicalOptimizer,
NaiveQAOAOptimizer,
set_global_optimizer)

logger.level = logging.INFO

Expand Down Expand Up @@ -470,6 +474,101 @@ def predict(self, X):
return self._map_0_1_to_classes(labels)


class QuanticMDM(QuanticClassifierBase):

"""Quantum-enhanced MDM
# This class is a convex implementation of the MDM [1]_,
# that can runs with quantum optimization.
# Only log-euclidian distance between trial and class prototypes
# is supported at the moment, but any type of metric
# can be used for centroid estimation.
Notes
-----
.. versionadded:: 0.0.4
Parameters
----------
metric : string | dict, default={"mean": 'logeuclid', "distance": 'convex'}
The type of metric used for centroid and distance estimation.
see `mean_covariance` for the list of supported metric.
the metric could be a dict with two keys, `mean` and `distance` in
order to pass different metrics for the centroid estimation and the
distance estimation. Typical usecase is to pass 'logeuclid' metric for
the mean in order to boost the computional speed and 'riemann' for the
distance in order to keep the good sensitivity for the classification.
See Also
--------
QuanticClassifierBase
pyriemann.classification.MDM
References
----------
.. [1] `Multiclass Brain-Computer Interface Classification by Riemannian
Geometry
<https://hal.archives-ouvertes.fr/hal-00681328>`_
A. Barachant, S. Bonnet, M. Congedo, and C. Jutten. IEEE Transactions
on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012.
.. [2] `Riemannian geometry applied to BCI classification
<https://hal.archives-ouvertes.fr/hal-00602700/>`_
A. Barachant, S. Bonnet, M. Congedo and C. Jutten. 9th International
Conference Latent Variable Analysis and Signal Separation
(LVA/ICA 2010), LNCS vol. 6365, 2010, p. 629-636.
"""

def __init__(self,
metric={"mean": 'logeuclid', "distance": 'convex'},
**parameters):
QuanticClassifierBase.__init__(self, **parameters)
self.metric = metric

def _init_algo(self, n_features):
self._log("Convex MDM initiating algorithm")
classifier = MDM(metric=self.metric)
if self.quantum:
self._optimizer = \
NaiveQAOAOptimizer(quantum_instance=self._quantum_instance)
else:
self._optimizer = ClassicalOptimizer()
set_global_optimizer(self._optimizer)
return classifier

def predict_proba(self, X):
"""Return the probabilities associated with predictions.
Parameters
----------
X : ndarray, shape (n_trials, n_channels, n_channels)
ndarray of trials.
Returns
-------
prob : ndarray, shape (n_samples, n_classes)
prob[n, 0] == True if the nth sample is assigned to 1st class;
prob[n, 1] == True if the nth sample is assigned to 2nd class.
"""
return self._classifier.predict_proba(X)

def predict(self, X):
"""Calculates the predictions.
Parameters
----------
X : ndarray, shape (n_samples, n_features)
Input vector, where `n_samples` is the number of samples and
`n_features` is the number of features.
Returns
-------
pred : array, shape (n_samples,)
Class labels for samples in X.
"""
labels = self._predict(X)
return self._map_0_1_to_classes(labels)


class QuantumClassifierWithDefaultRiemannianPipeline(BaseEstimator,
ClassifierMixin,
TransformerMixin):
Expand Down
6 changes: 5 additions & 1 deletion pyriemann_qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
square_int_mat_var,
square_bin_mat_var,
ClassicalOptimizer,
NaiveQAOAOptimizer)
NaiveQAOAOptimizer,
set_global_optimizer,
get_global_optimizer)
from .firebase_connector import (
FirebaseConnector,
Cache,
Expand All @@ -30,6 +32,8 @@
'square_bin_mat_var',
'ClassicalOptimizer',
'NaiveQAOAOptimizer',
'set_global_optimizer',
'get_global_optimizer',
'logeucl_dist_convex',
'FirebaseConnector',
'Cache',
Expand Down
31 changes: 29 additions & 2 deletions pyriemann_qiskit/utils/distance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
from docplex.mp.model import Model
from pyriemann_qiskit.utils.docplex import ClassicalOptimizer
from pyriemann.utils.distance import distance_logeuclid
from pyriemann_qiskit.utils.docplex import (ClassicalOptimizer,
get_global_optimizer)
from pyriemann.classification import MDM
from pyriemann.utils.distance import (distance_logeuclid,
distance_methods)


def logeucl_dist_convex(X, y, optimizer=ClassicalOptimizer()):
Expand Down Expand Up @@ -33,6 +37,8 @@ def logeucl_dist_convex(X, y, optimizer=ClassicalOptimizer()):
http://ibmdecisionoptimization.github.io/docplex-doc/mp/_modules/docplex/mp/model.html#Model
"""

optimizer = get_global_optimizer(optimizer)

n_classes, _, _ = X.shape
classes = range(n_classes)

Expand All @@ -56,3 +62,24 @@ def dist(m1, m2):
result = optimizer.solve(prob, reshape=False)

return result


_mdm_predict_distances_original = MDM._predict_distances


def predict_distances(mdm, X):
if mdm.metric_dist == 'convex':
centroids = np.array(mdm.covmeans_)
return np.array([logeucl_dist_convex(centroids, x) for x in X])
else:
return _mdm_predict_distances_original(mdm, X)


MDM._predict_distances = predict_distances

# This is only for validation inside the MDM.
# In fact, we override the _predict_distances method
# inside MDM to directly use logeucl_dist_convex when the metric is "convex"
# This is due to the fact the the signature of this method is different from
# the usual distance functions.
distance_methods['convex'] = logeucl_dist_convex
52 changes: 49 additions & 3 deletions pyriemann_qiskit/utils/docplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,45 @@
from pyriemann_qiskit.utils import cov_to_corr_matrix, get_simulator


_global_optimizer = None


def set_global_optimizer(optimizer):
"""Set the value of the global optimizer
Parameters
----------
optimizer: pyQiskitOptimizer
An instance of pyQiskitOptimizer.
Notes
-----
.. versionadded:: 0.0.4
"""
_global_optimizer = optimizer # noqa


def get_global_optimizer(default):
"""Get the value of the global optimizer
Parameters
----------
default: pyQiskitOptimizer
An instance of pyQiskitOptimizer.
It will be returned by default if the global optimizer is None.
Returns
-------
optimizer : pyQiskitOptimizer
The global optimizer.
Notes
-----
.. versionadded:: 0.0.4
"""
return _global_optimizer if _global_optimizer is not None else default


def square_cont_mat_var(prob, channels,
name='cont_covmat'):
"""Creates a 2-dimensional dictionary of continuous decision variables,
Expand Down Expand Up @@ -360,6 +399,9 @@ class NaiveQAOAOptimizer(pyQiskitOptimizer):
----------
upper_bound : int (default: 7)
The maximum integer value for matrix normalization.
backend: QuantumInstance (default: None)
A quantum backend instance.
If None, AerSimulator will be used.
Notes
-----
Expand All @@ -370,9 +412,10 @@ class NaiveQAOAOptimizer(pyQiskitOptimizer):
--------
pyQiskitOptimizer
"""
def __init__(self, upper_bound=7):
def __init__(self, upper_bound=7, quantum_instance=None):
pyQiskitOptimizer.__init__(self)
self.upper_bound = upper_bound
self.quantum_instance = quantum_instance

"""Transform all values in the covariance matrix
to integers.
Expand Down Expand Up @@ -439,8 +482,11 @@ def covmat_var(self, prob, channels, name):
def _solve_qp(self, qp, reshape=True):
conv = IntegerToBinary()
qubo = conv.convert(qp)
backend = get_simulator()
quantum_instance = QuantumInstance(backend)
if self.quantum_instance is None:
backend = get_simulator()
quantum_instance = QuantumInstance(backend)
else:
quantum_instance = self.quantum_instance
qaoa_mes = QAOA(quantum_instance=quantum_instance,
initial_point=[0., 0.])
qaoa = MinimumEigenOptimizer(qaoa_mes)
Expand Down
5 changes: 4 additions & 1 deletion pyriemann_qiskit/utils/mean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from docplex.mp.model import Model
from pyriemann.utils.mean import mean_methods
from pyriemann_qiskit.utils.docplex import ClassicalOptimizer
from pyriemann_qiskit.utils.docplex import (ClassicalOptimizer,
get_global_optimizer)


def fro_mean_convex(covmats, sample_weight=None,
Expand Down Expand Up @@ -32,6 +33,8 @@ def fro_mean_convex(covmats, sample_weight=None,
http://ibmdecisionoptimization.github.io/docplex-doc/mp/_modules/docplex/mp/model.html#Model
"""

optimizer = get_global_optimizer(optimizer)

n_trials, n_channels, _ = covmats.shape
channels = range(n_channels)
trials = range(n_trials)
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ def _get_dataset(n_samples, n_features, n_classes, type="bin"):
elif type == "bin":
samples = _get_binary_feats(n_samples, n_features)
labels = _get_labels(n_samples, n_classes)
elif type == "rand_cov":
samples = make_covariances(n_samples, n_features, 0,
return_params=False)
labels = _get_labels(n_samples, n_classes)
elif type == "bin_cov":
samples_0 = make_covariances(n_samples // n_classes, n_features, 0,
return_params=False)
samples_1 = samples_0 * 2
samples = np.concatenate((samples_0, samples_1), axis=0)
labels = _get_labels(n_samples, n_classes)
else:
samples, labels = get_mne_sample()
return samples, labels
Expand Down
29 changes: 26 additions & 3 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pyriemann_qiskit.classification \
import (QuanticSVM,
QuanticVQC,
QuanticMDM,
QuantumClassifierWithDefaultRiemannianPipeline)
from pyriemann_qiskit.datasets import get_mne_sample
from pyriemann_qiskit.utils.filtering import NaiveDimRed
Expand All @@ -17,6 +18,8 @@
[make_pipeline(XdawnCovariances(nfilter=1),
TangentSpace(), NaiveDimRed(),
QuanticSVM(quantum=False)),
make_pipeline(XdawnCovariances(nfilter=1),
QuanticMDM(quantum=False)),
QuantumClassifierWithDefaultRiemannianPipeline(
nfilter=1,
shots=None)])
Expand Down Expand Up @@ -118,9 +121,7 @@ def additional_steps(self):


class TestClassicalSVM(BinaryFVT):
""" Perform standard SVC test
(canary test to assess pipeline correctness)
"""
""" Perform functional validation testing of Quantic SVM"""
def get_params(self):
quantum_instance = QuanticSVM(quantum=False, verbose=False)
return {
Expand Down Expand Up @@ -188,6 +189,28 @@ def check(self):
assert len(self.prediction) == len(self.labels)


class TestClassicalMDM(BinaryFVT):
"""Test the classical version of MDM is used
when quantum is false
https://quantum-computing.ibm.com/
Note that the "real quantum version" of this test may also take some time.
"""
def get_params(self):
quantum_instance = QuanticMDM(quantum=False, verbose=False)
return {
"n_samples": 100,
"n_features": 9,
"quantum_instance": quantum_instance,
"type": "bin_cov"
}

def check(self):
assert self.prediction[:self.class_len].all() == \
self.quantum_instance.classes_[0]
assert self.prediction[self.class_len:].all() == \
self.quantum_instance.classes_[1]


class TestQuantumClassifierWithDefaultRiemannianPipeline(BinaryFVT):
"""Functional testing for riemann quantum classifier."""
def get_params(self):
Expand Down
Loading

0 comments on commit 05f4043

Please sign in to comment.