Multiclass classification (#178)
* Update conftest.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update test_classification.py

add multilabel test for quantum vqc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update conftest.py

change get_binary_feat to get_separable_feat

* Update test_classification.py

* Update test_classification.py

* Update test_classification.py

missing return

* Update conftest.py

replace base by super

* Update conftest.py

typo

* Update classification.py

renamed:
- _map_0_1_to_classes
- _map_1_0_to_classes

remove exception for binary classification in :
- _split_classes and fit

* Update test_classification.py

* Update Dockerfile

* Update Dockerfile

* Update classification.py

change class to idx and split methods to support multi labels classification

* Update test_classification.py

add test for split classes with n_classes = 3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update test_classification.py

fix test on split_classes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update test_classification.py

typo

* Update classification.py

fix bug with split_classes (index error)

* Update test_classification.py

typo

* Update test_classification.py

typo

* Update conftest.py

remove hardcoded n_classes=2

* Update test_classification.py

fix the number of sample

* Update conftest.py

fix class_len not updated after updating number of classes from 2 to 3

* Update classification.py

general one-hot encoded for multi-class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update classification.py

typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update test_classification.py

flake8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update utils.py

changed get mne dataset to return auditory stimulation if required

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Create plot_multiclass_classification.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update plot_multiclass_classification.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update plot_multiclass_classification.py

* Update plot_multiclass_classification.py

fix imports and display

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* - remove custom implementation of one-hot encoding (already done within Qiskit now)
- remove multilabels->multiclasses
- complete example

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* rename Multiclasses -> Multiclass

* inverted y_pred and y_true in balanced accuracy

* Testing the code. Comments have been updated.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Updated comments.

* Update conftest.py

* Update plot_classify_EEG_quantum_svm.py

* Update utils.py

* Update comments in classification.py

* Update multiclass_classification.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com>
Co-authored-by: toncho11 <toncho11@users.noreply.github.com>
4 people authored Sep 5, 2023
1 parent 35187e1 commit 0d98b92
Showing 7 changed files with 245 additions and 78 deletions.
124 changes: 124 additions & 0 deletions examples/ERP/multiclass_classification.py
@@ -0,0 +1,124 @@
"""
====================================================================
Multiclass EEG classification with Quantum Pipeline
====================================================================
This example demonstrates multiclass EEG classification with a quantum
classifier.
We will be comparing the performance of VQC vs Quantum SVM vs
Classical SVM vs Quantum MDM vs MDM.
Execution takes approximately 1h.
"""
# Author: Gregoire Cattan
# Modified from plot_classify_EEG_quantum_svm
# License: BSD (3-clause)

from pyriemann_qiskit.datasets import get_mne_sample
from pyriemann_qiskit.pipelines import (
    QuantumClassifierWithDefaultRiemannianPipeline,
    QuantumMDMWithRiemannianPipeline,
)
from pyriemann.estimation import ERPCovariances
from pyriemann.classification import MDM
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    balanced_accuracy_score,
)
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt


print(__doc__)

###############################################################################
# Get the data

# Use MNE sample. The include_auditory parameter selects 4 classes.
X, y = get_mne_sample(n_trials=-1, include_auditory=True)

# evaluation without k-fold cross-validation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

###############################################################################
# Decoding in tangent space with a quantum classifier

# Our helper class QuantumClassifierWithDefaultRiemannianPipeline
# auto-configures the parameters of the pipelines.
# Warning: these are not optimal parameters

# Pipeline 1
quantum_svm = QuantumClassifierWithDefaultRiemannianPipeline(
    dim_red=PCA(n_components=5),
)

# Pipeline 2
classical_svm = QuantumClassifierWithDefaultRiemannianPipeline(
    shots=None,  # 'None' forces classic SVM
    dim_red=PCA(n_components=5),
)

# Pipeline 3
vqc = QuantumClassifierWithDefaultRiemannianPipeline(
    dim_red=PCA(n_components=5),
    # These parameters are specific to VQC.
    # The pipeline will detect this and instantiate a VQC under the hood
    spsa_trials=40,
    two_local_reps=3,
)

# Pipeline 4
quantum_mdm = QuantumMDMWithRiemannianPipeline()

# Pipeline 5
mdm = make_pipeline(ERPCovariances(estimator="lwf"), MDM())

classifiers = [vqc, quantum_svm, classical_svm, quantum_mdm, mdm]

n_classifiers = len(classifiers)

# https://stackoverflow.com/questions/61825227/plotting-multiple-confusion-matrix-side-by-side
f, axes = plt.subplots(1, n_classifiers, sharey="row")

disp = None

# Display names, in the same order as the classifiers list
clf_names = ["VQC", "Quantum SVM", "Classical SVM", "Quantum MDM", "R-MDM"]

# Compute results
for idx in range(n_classifiers):
    # Training and classification
    clf = classifiers[idx]
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # Computing the balanced accuracy
    acc = balanced_accuracy_score(y_test, y_pred)
    acc_str = "%0.2f" % acc

    # Results visualization
    # A confusion matrix is reported for each classifier. A perfectly performing
    # classifier will have only its diagonal filled and the rest will be zeros.
    names = ["aud left", "aud right", "vis left", "vis right"]
    # Every title carries the classifier name and its balanced accuracy
    title = clf_names[idx] + " (" + acc_str + ")"
    axe = axes[idx]
    # confusion_matrix expects (y_true, y_pred): rows are true labels,
    # columns are predicted labels
    cm = confusion_matrix(y_test, y_pred)
    disp = ConfusionMatrixDisplay(cm, display_labels=names)
    disp.plot(ax=axe, xticks_rotation=45)
    disp.ax_.set_title(title)
    disp.im_.colorbar.remove()
    disp.ax_.set_xlabel("")
    if idx > 0:
        disp.ax_.set_ylabel("")

# Display all the confusion matrices
if disp:
    f.text(0.4, 0.1, "Predicted label", ha="left")
    plt.subplots_adjust(wspace=0.40, hspace=0.1)
    f.colorbar(disp.im_, ax=axes)
    plt.show()
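For reading the plots produced by this example: scikit-learn's `confusion_matrix` takes `(y_true, y_pred)` and places true labels on rows and predicted labels on columns. The following is a minimal pure-Python sketch of that layout; `confusion_counts` and the toy labels are illustrative stand-ins, not part of the example script or of scikit-learn.

```python
def confusion_counts(y_true, y_pred, labels):
    """Count (true, predicted) pairs: rows are true labels, columns are predictions."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


# Toy data (made up for illustration), using the four class names of the example
labels = ["aud left", "aud right", "vis left", "vis right"]
y_true = ["aud left", "aud left", "aud right", "vis left", "vis right", "vis right"]
y_pred = ["aud left", "aud right", "aud right", "vis left", "vis right", "vis left"]

cm = confusion_counts(y_true, y_pred, labels)
n_correct = sum(cm[i][i] for i in range(len(labels)))  # diagonal = correct predictions
```

Swapping the two arguments transposes the matrix, so the axis labels on the plot only make sense when the argument order matches the convention.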
4 changes: 2 additions & 2 deletions examples/ERP/plot_classify_EEG_quantum_svm.py
@@ -6,7 +6,7 @@
 Decoding applied to EEG data in sensor space using RG.
 Xdawn spatial filtering is applied on covariances matrices, which are
 then projected in the tangent space and classified with a quantum SVM
-classifier. It is compared to the classical SVM on binary classification.
+classifier. It is compared to the classical SVM on binary classification.
 """
 # Author: Gregoire Cattan
@@ -67,7 +67,7 @@
     y_pred = clf.predict(X_test)

     # Printing the results
-    acc = balanced_accuracy_score(y_pred, y_test)
+    acc = balanced_accuracy_score(y_test, y_pred)
     acc_str = "%0.2f" % acc

     names = ["vis left", "vis right"]
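The argument order matters because balanced accuracy is the mean of per-class recall, and recall is computed relative to the true labels. A minimal pure-Python sketch of the metric, assuming toy labels; `balanced_accuracy` here is a stand-in written for illustration, not the scikit-learn function itself.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall over the classes present in y_true."""
    recalls = []
    for label in sorted(set(y_true)):
        n_true = sum(1 for t in y_true if t == label)
        n_hit = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        recalls.append(n_hit / n_true)
    return sum(recalls) / len(recalls)


# Toy, imbalanced labels (made up for illustration)
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1]

correct_order = balanced_accuracy(y_true, y_pred)  # recalls 3/4 and 2/2
swapped_order = balanced_accuracy(y_pred, y_true)  # recalls 3/3 and 2/3
```

With these labels the two calls give different values, which is why passing `(y_pred, y_test)` reports a different score than `(y_test, y_pred)`.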
82 changes: 32 additions & 50 deletions pyriemann_qiskit/classification.py
@@ -1,4 +1,9 @@
-"""Module for classification function."""
+"""
+Contains the base class for all quantum classifiers
+as well as several quantum classifiers that can run
+in several modes quantum/classical and simulated/real
+quantum computer.
+"""
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.svm import SVC
@@ -37,13 +42,10 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
     * tasks on a real quantum computer are assigned to a queue
       before being executed on a back-end (delayed execution)

-    WARNING: At the moment this implementation only supports binary
-    classification.
-
     Parameters
     ----------
     quantum : bool (default: True)
-        - If true will run on local or remote backend
+        - If true will run on local or remote quantum backend
           (depending on q_account_token value).
         - If false, will perform classical computing instead
     q_account_token : string (default:None)
@@ -62,6 +64,8 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
     Notes
     -----
     .. versionadded:: 0.0.1
+    .. versionchanged:: 0.1.0
+        Added support for multi-class classification.

     Attributes
     ----------
@@ -120,24 +124,24 @@ def _log(self, *values):
             print("[QClass] ", *values)

     def _split_classes(self, X, y):
-        self._log(
-            "[Warning] Splitting first class from second class."
-            "Only binary classification is supported."
-        )
-        X_class1 = X[y == self.classes_[1]]
-        X_class0 = X[y == self.classes_[0]]
-        return (X_class1, X_class0)
+        n_classes = len(self.classes_)
+        X_classes = []
+        for idx in range(n_classes):
+            X_classes.append(X[y == self.classes_[idx]])
+        return X_classes

-    def _map_classes_to_0_1(self, y):
+    def _map_classes_to_indices(self, y):
         y_copy = y.copy()
-        y_copy[y == self.classes_[0]] = 0
-        y_copy[y == self.classes_[1]] = 1
+        n_classes = len(self.classes_)
+        for idx in range(n_classes):
+            y_copy[y == self.classes_[idx]] = idx
         return y_copy

-    def _map_0_1_to_classes(self, y):
+    def _map_indices_to_classes(self, y):
         y_copy = y.copy()
-        y_copy[y == 0] = self.classes_[0]
-        y_copy[y == 1] = self.classes_[1]
+        n_classes = len(self.classes_)
+        for idx in range(n_classes):
+            y_copy[y == idx] = self.classes_[idx]
         return y_copy

     def fit(self, X, y):
@@ -151,11 +155,6 @@ def fit(self, X, y):
         y : ndarray, shape (n_samples,)
             Target vector relative to X.

-        Raises
-        ------
-        Exception
-            Raised if the number of classes is different from 2
-
         Returns
         -------
         self : QuanticClassifierBase instance
@@ -165,17 +164,13 @@ def fit(self, X, y):

         self._log("Fitting: ", X.shape)
         self.classes_ = np.unique(y)
-        if len(self.classes_) != 2:
-            raise Exception(
-                "Only binary classification \
-                is currently supported."
-            )

-        class1, class0 = self._split_classes(X, y)
-        y = self._map_classes_to_0_1(y)
+        X_classes = self._split_classes(X, y)
+        y = self._map_classes_to_indices(y)

-        self._training_input[self.classes_[1]] = class1
-        self._training_input[self.classes_[0]] = class0
+        n_classes = len(self.classes_)
+        for idx in range(n_classes):
+            self._training_input[self.classes_[idx]] = X_classes[idx]

         n_features = get_feature_dimension(self._training_input)
         self._log("Feature dimension = ", n_features)
@@ -226,7 +221,7 @@ def score(self, X, y):
         accuracy : double
             Accuracy of predictions from X with respect y.
         """
-        y = self._map_classes_to_0_1(y)
+        y = self._map_classes_to_indices(y)
         self._log("Testing...")
         return self._classifier.score(X, y)

@@ -399,7 +394,7 @@ def predict(self, X):
             Class labels for samples in X.
         """
         labels = self._predict(X)
-        return self._map_0_1_to_classes(labels)
+        return self._map_indices_to_classes(labels)


 class QuanticVQC(QuanticClassifierBase):
@@ -439,6 +434,7 @@ class QuanticVQC(QuanticClassifierBase):
     .. versionadded:: 0.0.1
     .. versionchanged:: 0.1.0
         Fix: copy estimator not keeping base class parameters.
+        Added support for multi-class classification.

     See Also
     --------
@@ -498,20 +494,6 @@ def _init_algo(self, n_features):
         )
         return vqc

-    def _map_classes_to_0_1(self, y):
-        # Label must be one-hot encoded for VQC
-        y_copy = np.ndarray((y.shape[0], 2))
-        y_copy[y == self.classes_[0]] = [1, 0]
-        y_copy[y == self.classes_[1]] = [0, 1]
-        return y_copy
-
-    def _map_0_1_to_classes(self, y):
-        # Decode one-hot encoded labels
-        y_copy = np.ndarray((y.shape[0], 1))
-        y_copy[(y == [1, 0]).all()] = self.classes_[0]
-        y_copy[(y == [0, 1]).all()] = self.classes_[1]
-        return y_copy
-
     def predict_proba(self, X):
         """Returns the probabilities associated with predictions.
@@ -545,7 +527,7 @@ def predict(self, X):
             Class labels for samples in X.
         """
         labels = self._predict(X)
-        return self._map_0_1_to_classes(labels)
+        return self._map_indices_to_classes(labels)


 class QuanticMDM(QuanticClassifierBase):
@@ -670,4 +652,4 @@ def predict(self, X):
             Class labels for samples in X.
         """
         labels = self._predict(X)
-        return self._map_0_1_to_classes(labels)
+        return self._map_indices_to_classes(labels)
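The label handling in this file replaces the hardcoded 0/1 mapping with a mapping between arbitrary class labels and their positions in `classes_`. A minimal standalone sketch of that idea, assuming NumPy arrays; the free functions below are illustrative stand-ins for the private methods, not part of the pyriemann-qiskit API, and the data is made up.

```python
import numpy as np


def map_classes_to_indices(y, classes):
    """Replace each label in y by its position in `classes`."""
    y_idx = np.empty(len(y), dtype=int)
    for idx, label in enumerate(classes):
        # The mask is computed on the original y, so earlier writes cannot interfere
        y_idx[y == label] = idx
    return y_idx


def map_indices_to_classes(y_idx, classes):
    """Inverse mapping: look up the label stored at each index."""
    return classes[y_idx]


def split_classes(X, y, classes):
    """Return one array of samples per class, ordered like `classes`."""
    return [X[y == label] for label in classes]


# Toy data with four classes whose labels are not 0..3
y = np.array([3, 1, 4, 1, 2, 3])
X = np.arange(12).reshape(6, 2)

classes = np.unique(y)  # sorted unique labels, as in fit()
y_idx = map_classes_to_indices(y, classes)
y_back = map_indices_to_classes(y_idx, classes)
groups = split_classes(X, y, classes)
```

The round trip through indices recovers the original labels for any number of classes, which is what allows the binary-only check in `fit` to be dropped.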
15 changes: 13 additions & 2 deletions pyriemann_qiskit/datasets/utils.py
@@ -1,3 +1,6 @@
+"""
+Contains helper methods and classes to manage datasets.
+"""
 from warnings import warn
 import numpy as np

@@ -10,7 +13,7 @@
 from sklearn.datasets import make_classification


-def get_mne_sample(n_trials=10):
+def get_mne_sample(n_trials=10, include_auditory=False):
     """Return sample data from the mne dataset.
```
@@ -31,6 +34,9 @@ def get_mne_sample(n_trials=10):
     n_trials : int (default:10)
         Number of trials to return.
         If -1, then all trials are returned.
+    include_auditory : boolean (default:False)
+        If True, it also returns the auditory stimulation
+        in the MNE dataset.

     Returns
     -------
@@ -42,6 +48,8 @@ def get_mne_sample(n_trials=10):
     Notes
     -----
     .. versionadded:: 0.0.1
+    .. versionchanged:: 0.1.0
+        Possibility to include auditory stimulation.

     References
     ----------
@@ -55,7 +63,10 @@ def get_mne_sample(n_trials=10):
     raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
     event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"
     tmin, tmax = -0.0, 1
-    event_id = dict(vis_l=3, vis_r=4)  # select only two classes
+    if include_auditory:
+        event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
+    else:
+        event_id = dict(vis_l=3, vis_r=4)

     # Setup for reading the raw data
     raw = io.Raw(raw_fname, preload=True, verbose=False)
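The effect of `include_auditory` on the number of classes can be seen in isolation. A minimal sketch that reproduces only the event selection, with no MNE dependency; `select_event_id` is a hypothetical helper written for illustration, while the event codes are the ones shown in the diff.

```python
def select_event_id(include_auditory=False):
    """Return the event-name to event-code mapping used to pick trials."""
    visual = {"vis_l": 3, "vis_r": 4}
    if include_auditory:
        # Four classes: auditory left/right plus visual left/right
        return {"aud_l": 1, "aud_r": 2, **visual}
    # Two classes: visual left/right only (the previous behaviour)
    return visual


binary_events = select_event_id()
multiclass_events = select_event_id(include_auditory=True)
```

Because the default stays `False`, existing binary examples keep selecting the two visual classes.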
8 changes: 6 additions & 2 deletions pyriemann_qiskit/pipelines.py
@@ -185,10 +185,14 @@ class QuantumClassifierWithDefaultRiemannianPipeline(BasePipeline):
     feature_reps : int (default: 2)
         The number of repeated circuits for the ZZFeatureMap,
         greater or equal to 1.
-    spsa_trials : int (default: 40)
+    spsa_trials : int (default: None)
         Maximum number of iterations to perform using SPSA optimizer.
-    two_local_reps : int (default: 3)
+        For VQC, you can use 40 as a default.
+        VQC is only enabled if spsa_trials and two_local_reps are not None.
+    two_local_reps : int (default: None)
         The number of repetition for the two-local circuit.
+        VQC is only enabled if spsa_trials and two_local_reps are not None.
+        For VQC, you can use 3 as a default.
     params: Dict (default: {})
         Additional parameters to pass to the nested instance
         of the quantum classifier.
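The docstring change above means VQC becomes opt-in. A rough sketch of the resulting selection rule, under stated assumptions: `choose_classifier` is a hypothetical function, the `shots=1024` default is a placeholder, and the precedence of the VQC check over the `shots` check is assumed rather than taken from the pipeline's source.

```python
def choose_classifier(shots=1024, spsa_trials=None, two_local_reps=None):
    """Return the name of the classifier a pipeline would build (illustrative only)."""
    # Assumption: VQC is selected once both VQC-specific parameters are set
    if spsa_trials is not None and two_local_reps is not None:
        return "VQC"
    # Per the example script, shots=None forces the classical SVM
    if shots is None:
        return "classical SVM"
    return "quantum SVM"


default_choice = choose_classifier()
classical_choice = choose_classifier(shots=None)
vqc_choice = choose_classifier(spsa_trials=40, two_local_reps=3)
partial_choice = choose_classifier(spsa_trials=40)  # only one VQC parameter set
```

Setting only one of the two VQC parameters leaves the pipeline on the SVM path, which matches the wording "only enabled if spsa_trials and two_local_reps are not None".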