Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiclass classification #178

Merged
merged 54 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
17ff0f8
Update conftest.py
gcattan Aug 25, 2023
cf579c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2023
197cb41
Update test_classification.py
gcattan Aug 25, 2023
694330d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2023
b593f2f
Update conftest.py
gcattan Aug 25, 2023
6fa3efe
Update test_classification.py
gcattan Aug 25, 2023
79d22b1
Update test_classification.py
gcattan Aug 25, 2023
5e255d7
Update test_classification.py
gcattan Aug 25, 2023
25a359d
Update conftest.py
gcattan Aug 25, 2023
c9aa4b5
Update conftest.py
gcattan Aug 25, 2023
1440ecf
Update classification.py
gcattan Aug 26, 2023
83f7b1c
Update test_classification.py
gcattan Aug 26, 2023
bd18206
Update Dockerfile
gcattan Aug 26, 2023
3b2e23a
Update Dockerfile
gcattan Aug 26, 2023
fe2a7d6
Update classification.py
gcattan Aug 26, 2023
1afb8c3
Update test_classification.py
gcattan Aug 26, 2023
365bd37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
dfdc359
Update test_classification.py
gcattan Aug 26, 2023
98b8878
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
a78fa9d
Update test_classification.py
gcattan Aug 26, 2023
5075593
Update classification.py
gcattan Aug 26, 2023
0dbfeca
Update test_classification.py
gcattan Aug 26, 2023
d2a4373
Update test_classification.py
gcattan Aug 26, 2023
fd5cb0c
Update conftest.py
gcattan Aug 26, 2023
d3e5133
Update test_classification.py
gcattan Aug 26, 2023
f8611a7
Update conftest.py
gcattan Aug 26, 2023
ed7908c
Update classification.py
gcattan Aug 26, 2023
fa78440
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
0b7511a
Update classification.py
gcattan Aug 26, 2023
ee56f57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
8f3a326
Update test_classification.py
gcattan Aug 26, 2023
a347a7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
de4a988
Update utils.py
gcattan Aug 26, 2023
fcab439
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
7cfb8cb
Create plot_multiclass_classification.py
gcattan Aug 26, 2023
2096b30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
9cd7179
Update plot_multiclass_classification.py
gcattan Aug 26, 2023
8d4678b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
1365e11
Update plot_multiclass_classification.py
gcattan Aug 26, 2023
7b76245
Update plot_multiclass_classification.py
gcattan Aug 26, 2023
53973bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
dfe66b8
- remove custom implementation of one-hot encoding (already done with…
gcattan Aug 26, 2023
42b25e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
b562fb3
rename Multiclasses -> Multiclass
gcattan Aug 26, 2023
26e58b4
inverted y_pred and y_true in balanced accuracy
gcattan Aug 26, 2023
a8a9f85
Testing the code. Comments have been updated.
toncho11 Sep 5, 2023
a60245d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2023
48341e9
Updated comments.
toncho11 Sep 5, 2023
1f4d788
Update conftest.py
toncho11 Sep 5, 2023
4637bbc
Update plot_classify_EEG_quantum_svm.py
toncho11 Sep 5, 2023
07a1701
Update utils.py
toncho11 Sep 5, 2023
2bc685b
Update comments in classification.py
toncho11 Sep 5, 2023
593f2b7
Update multiclass_classification.py
toncho11 Sep 5, 2023
9961dea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ RUN mkdir /root/mne_data
RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.24.0rc2
RUN pip install protobuf==4.24.2
RUN pip install google_cloud_firestore==2.11.1
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.0rc2-py3.8.egg/google/__init__.py
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.2-py3.8-linux-x86_64.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
Expand Down
124 changes: 124 additions & 0 deletions examples/ERP/multiclass_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
====================================================================
Multiclass EEG classification with Quantum Pipeline
====================================================================

This example demonstrates multiclass EEG classification with a quantum
classifier.
We will be comparing the performance of VQC vs Quantum SVM vs
Classical SVM vs Quantum MDM vs MDM.
Execution takes approximately 1h.
"""
# Author: Gregoire Cattan
# Modified from plot_classify_EEG_quantum_svm
# License: BSD (3-clause)

from pyriemann_qiskit.datasets import get_mne_sample
from pyriemann_qiskit.pipelines import (
QuantumClassifierWithDefaultRiemannianPipeline,
QuantumMDMWithRiemannianPipeline,
)
from pyriemann.estimation import ERPCovariances
from pyriemann.classification import MDM
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
confusion_matrix,
ConfusionMatrixDisplay,
balanced_accuracy_score,
)
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt


print(__doc__)

###############################################################################
# Get the data

# Use MNE sample. The include_auditory parameter select 4 classes.
X, y = get_mne_sample(n_trials=-1, include_auditory=True)

# evaluation without k-fold cross-validation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

###############################################################################
# Decoding in tangent space with a quantum classifier

# Our helper class QuantumClassifierWithDefaultRiemannianPipeline allows to
# auto-configure the parameters of the pipelines.
# Warning: these are not optimal parameters

# Piepeline 1
quantum_svm = QuantumClassifierWithDefaultRiemannianPipeline(
dim_red=PCA(n_components=5),
)

# Piepeline 2
classical_svm = QuantumClassifierWithDefaultRiemannianPipeline(
shots=None, # 'None' forces classic SVM
dim_red=PCA(n_components=5),
)

# Piepeline 3
vqc = QuantumClassifierWithDefaultRiemannianPipeline(
dim_red=PCA(n_components=5),
# These parameters are specific to VQC.
# The pipeline will detect this and instantiate a VQC under the hood
spsa_trials=40,
two_local_reps=3,
)

# Piepeline 4
quantum_mdm = QuantumMDMWithRiemannianPipeline()

# Piepeline 5
mdm = make_pipeline(ERPCovariances(estimator="lwf"), MDM())

classifiers = [vqc, quantum_svm, classical_svm, quantum_mdm, mdm]

n_classifiers = len(classifiers)

# https://stackoverflow.com/questions/61825227/plotting-multiple-confusion-matrix-side-by-side
f, axes = plt.subplots(1, n_classifiers, sharey="row")

disp = None

# Compute results
for idx in range(n_classifiers):
# Training and classification
clf = classifiers[idx]
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Printing the results
acc = balanced_accuracy_score(y_test, y_pred)
acc_str = "%0.2f" % acc

# Results visualization
# A confusion matrix is reported for each classifier. A perfectly performing
# classifier will have only its diagonal filled and the rest will be zeros.
names = ["aud left", "aud right", "vis left", "vis right"]
title = (
("VQC (" if idx == 0 else "Quantum SVM (" if idx == 1 else "Classical SVM (")
if idx == 2
else "Quantum MDM ("
if idx == 3
else "R-MDM (" + acc_str + ")"
)
axe = axes[idx]
cm = confusion_matrix(y_pred, y_test)
disp = ConfusionMatrixDisplay(cm, display_labels=names)
disp.plot(ax=axe, xticks_rotation=45)
disp.ax_.set_title(title)
disp.im_.colorbar.remove()
disp.ax_.set_xlabel("")
if idx > 0:
disp.ax_.set_ylabel("")

# Display all the confusion matrices
if disp:
f.text(0.4, 0.1, "Predicted label", ha="left")
plt.subplots_adjust(wspace=0.40, hspace=0.1)
f.colorbar(disp.im_, ax=axes)
plt.show()
4 changes: 2 additions & 2 deletions examples/ERP/plot_classify_EEG_quantum_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Decoding applied to EEG data in sensor space using RG.
Xdawn spatial filtering is applied on covariances matrices, which are
then projected in the tangent space and classified with a quantum SVM
classifier. It is compared to the classical SVM on binary classification.
classifier. It is compared to the classical SVM on binary classification.

"""
# Author: Gregoire Cattan
Expand Down Expand Up @@ -67,7 +67,7 @@
y_pred = clf.predict(X_test)

# Printing the results
acc = balanced_accuracy_score(y_pred, y_test)
acc = balanced_accuracy_score(y_test, y_pred)
acc_str = "%0.2f" % acc

names = ["vis left", "vis right"]
Expand Down
82 changes: 32 additions & 50 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
"""Module for classification function."""
"""
Contains the base class for all quantum classifiers
as well as several quantum classifiers than can run
in several modes quantum/classical and simulated/real
quantum computer.
"""
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC
Expand Down Expand Up @@ -37,13 +42,10 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
* tasks on a real quantum computer are assigned to a queue
before being executed on a back-end (delayed execution)

WARNING: At the moment this implementation only supports binary
classification.

Parameters
----------
quantum : bool (default: True)
- If true will run on local or remote backend
- If true will run on local or remote quantum backend
(depending on q_account_token value).
- If false, will perform classical computing instead
q_account_token : string (default:None)
Expand All @@ -62,6 +64,8 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added support for multi-class classification.

Attributes
----------
Expand Down Expand Up @@ -120,24 +124,24 @@ def _log(self, *values):
print("[QClass] ", *values)

def _split_classes(self, X, y):
self._log(
"[Warning] Splitting first class from second class."
"Only binary classification is supported."
)
X_class1 = X[y == self.classes_[1]]
X_class0 = X[y == self.classes_[0]]
return (X_class1, X_class0)
n_classes = len(self.classes_)
X_classes = []
for idx in range(n_classes):
X_classes.append(X[y == self.classes_[idx]])
return X_classes

def _map_classes_to_0_1(self, y):
def _map_classes_to_indices(self, y):
y_copy = y.copy()
y_copy[y == self.classes_[0]] = 0
y_copy[y == self.classes_[1]] = 1
n_classes = len(self.classes_)
for idx in range(n_classes):
y_copy[y == self.classes_[idx]] = idx
return y_copy

def _map_0_1_to_classes(self, y):
def _map_indices_to_classes(self, y):
y_copy = y.copy()
y_copy[y == 0] = self.classes_[0]
y_copy[y == 1] = self.classes_[1]
n_classes = len(self.classes_)
for idx in range(n_classes):
y_copy[y == idx] = self.classes_[idx]
return y_copy

def fit(self, X, y):
Expand All @@ -151,11 +155,6 @@ def fit(self, X, y):
y : ndarray, shape (n_samples,)
Target vector relative to X.

Raises
------
Exception
Raised if the number of classes is different from 2

Returns
-------
self : QuanticClassifierBase instance
Expand All @@ -165,17 +164,13 @@ def fit(self, X, y):

self._log("Fitting: ", X.shape)
self.classes_ = np.unique(y)
if len(self.classes_) != 2:
raise Exception(
"Only binary classification \
is currently supported."
)

class1, class0 = self._split_classes(X, y)
y = self._map_classes_to_0_1(y)
X_classes = self._split_classes(X, y)
y = self._map_classes_to_indices(y)

self._training_input[self.classes_[1]] = class1
self._training_input[self.classes_[0]] = class0
n_classes = len(self.classes_)
for idx in range(n_classes):
self._training_input[self.classes_[idx]] = X_classes[idx]

n_features = get_feature_dimension(self._training_input)
self._log("Feature dimension = ", n_features)
Expand Down Expand Up @@ -226,7 +221,7 @@ def score(self, X, y):
accuracy : double
Accuracy of predictions from X with respect y.
"""
y = self._map_classes_to_0_1(y)
y = self._map_classes_to_indices(y)
self._log("Testing...")
return self._classifier.score(X, y)

Expand Down Expand Up @@ -399,7 +394,7 @@ def predict(self, X):
Class labels for samples in X.
"""
labels = self._predict(X)
return self._map_0_1_to_classes(labels)
return self._map_indices_to_classes(labels)


class QuanticVQC(QuanticClassifierBase):
Expand Down Expand Up @@ -439,6 +434,7 @@ class QuanticVQC(QuanticClassifierBase):
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Fix: copy estimator not keeping base class parameters.
Added support for multi-class classification.

See Also
--------
Expand Down Expand Up @@ -498,20 +494,6 @@ def _init_algo(self, n_features):
)
return vqc

def _map_classes_to_0_1(self, y):
# Label must be one-hot encoded for VQC
y_copy = np.ndarray((y.shape[0], 2))
y_copy[y == self.classes_[0]] = [1, 0]
y_copy[y == self.classes_[1]] = [0, 1]
return y_copy

def _map_0_1_to_classes(self, y):
# Decode one-hot encoded labels
y_copy = np.ndarray((y.shape[0], 1))
y_copy[(y == [1, 0]).all()] = self.classes_[0]
y_copy[(y == [0, 1]).all()] = self.classes_[1]
return y_copy

def predict_proba(self, X):
"""Returns the probabilities associated with predictions.

Expand Down Expand Up @@ -545,7 +527,7 @@ def predict(self, X):
Class labels for samples in X.
"""
labels = self._predict(X)
return self._map_0_1_to_classes(labels)
return self._map_indices_to_classes(labels)


class QuanticMDM(QuanticClassifierBase):
Expand Down Expand Up @@ -670,4 +652,4 @@ def predict(self, X):
Class labels for samples in X.
"""
labels = self._predict(X)
return self._map_0_1_to_classes(labels)
return self._map_indices_to_classes(labels)
15 changes: 13 additions & 2 deletions pyriemann_qiskit/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Contains helper methods and classes to manage datasets.
"""
from warnings import warn
import numpy as np

Expand All @@ -10,7 +13,7 @@
from sklearn.datasets import make_classification


def get_mne_sample(n_trials=10):
def get_mne_sample(n_trials=10, include_auditory=False):
"""Return sample data from the mne dataset.

```
Expand All @@ -31,6 +34,9 @@ def get_mne_sample(n_trials=10):
n_trials : int (default:10)
Number of trials to return.
If -1, then all trials are returned.
include_auditory : boolean (default:False)
If True, it returns also the auditory stimulation
in the MNE dataset.

Returns
-------
Expand All @@ -42,6 +48,8 @@ def get_mne_sample(n_trials=10):
Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Possibility to include auditory stimulation.

References
----------
Expand All @@ -55,7 +63,10 @@ def get_mne_sample(n_trials=10):
raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"
tmin, tmax = -0.0, 1
event_id = dict(vis_l=3, vis_r=4) # select only two classes
if include_auditory:
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
else:
event_id = dict(vis_l=3, vis_r=4)

# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
Expand Down
8 changes: 6 additions & 2 deletions pyriemann_qiskit/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,14 @@ class QuantumClassifierWithDefaultRiemannianPipeline(BasePipeline):
feature_reps : int (default: 2)
The number of repeated circuits for the ZZFeatureMap,
greater or equal to 1.
spsa_trials : int (default: 40)
spsa_trials : int (default: None)
Maximum number of iterations to perform using SPSA optimizer.
two_local_reps : int (default: 3)
For VQC, you can use 40 as a default.
VQC is only enabled if spsa_trials and two_local_reps are not None.
two_local_reps : int (default: None)
The number of repetition for the two-local cricuit.
VQC is only enabled if spsa_trials and two_local_reps are not None.
For VQC, you can use 3 as a default.
params: Dict (default: {})
Additional parameters to pass to the nested instance
of the quantum classifier.
Expand Down
Loading