Skip to content

Commit

Permalink
Chores/plot distance visualization (#284)
Browse files Browse the repository at this point in the history
* plot distances_visualization on Ci

* fix hardcoded parameter in plot_financial_data

* fix label in plot_classifier_comparison

* apply noplot convention

* remove tight_layout (not necessary)

* add alias module

* add alias file

* missing line at end of file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* fix flake8

* fix warning

* remove alias module

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* missing in previous commit

---------

Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 18, 2024
1 parent 0d8d189 commit 28686ff
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 19 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
5 changes: 5 additions & 0 deletions examples/MI/helpers/alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pyriemann.estimation import ERPCovariances
from pyriemann.classification import MDM
from sklearn.pipeline import make_pipeline

ERPCov_MDM = make_pipeline(ERPCovariances(estimator="lwf"), MDM())
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@
from moabb.paradigms import MotorImagery

# inject cpm distance and mean to pyriemann (if not done already)
from helpers.alias import ERPCov_MDM
from pyriemann_qiskit.utils import distance, mean # noqa
from pyriemann_qiskit.pipelines import (
QuantumMDMWithRiemannianPipeline,
)

from sklearn.pipeline import make_pipeline
from pyriemann.estimation import ERPCovariances
from pyriemann.classification import MDM

print(__doc__)

##############################################################################
Expand Down Expand Up @@ -73,7 +70,7 @@
)

# Classical baseline for evaluation
pipelines["R-MDM"] = make_pipeline(ERPCovariances(estimator="lwf"), MDM())
pipelines["R-MDM"] = ERPCov_MDM

##############################################################################
# Run evaluation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# Modified from plot_classify_EEG_quantum_svm
# License: BSD (3-clause)

from helpers.alias import ERPCov_MDM
from pyriemann_qiskit.datasets import get_mne_sample
from pyriemann_qiskit.pipelines import (
QuantumClassifierWithDefaultRiemannianPipeline,
QuantumMDMWithRiemannianPipeline,
)
from pyriemann.estimation import ERPCovariances
from pyriemann.classification import MDM
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
confusion_matrix,
Expand Down Expand Up @@ -73,7 +71,7 @@
quantum_mdm = QuantumMDMWithRiemannianPipeline()

# Pipeline 5
mdm = make_pipeline(ERPCovariances(estimator="lwf"), MDM())
mdm = ERPCov_MDM

classifiers = [vqc, quantum_svm, classical_svm, quantum_mdm, mdm]

Expand Down
2 changes: 1 addition & 1 deletion examples/other_datasets/plot_financial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def fit(self, X, y):
def transform(self, X):
if not self.process:
return X
return Whitening(dim_red={"n_components": 4}).fit_transform(X)
return Whitening(dim_red={"n_components": self.n_components}).fit_transform(X)


# Create a RandomForest for baseline comparison of direct classification:
Expand Down
16 changes: 13 additions & 3 deletions examples/toys_dataset/plot_classifier_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
generate_linearly_separable_dataset,
generate_qiskit_dataset,
)
from pyriemann_qiskit.classification import QuanticSVM
from pyriemann_qiskit.classification import (
QuanticSVM,
# uncomment to run comparison with QuanticVQC (disabled for CI/CD)
# QuanticVQC
)

# uncomment to run comparison with QuanticVQC (disabled for CI/CD)
# from pyriemann_qiskit.classification import QuanticVQC
Expand All @@ -36,14 +40,20 @@

h = 0.02 # step size in the mesh
labels = (0, 1)
names = ["Linear SVM", "RBF SVM", "VQC", "QSVM"]
names = [
"Linear SVM",
"RBF SVM",
# uncomment to run comparison with QuanticVQC (disabled for CI/CD)
# "VQC",
"QSVM",
]

classifiers = [
SVC(kernel="linear", C=0.025),
SVC(gamma="auto", C=0.001),
# uncomment to run comparison with QuanticVQC (disabled for CI/CD)
# QuanticVQC(),
QuanticSVM(quantum=False),
QuanticSVM(quantum=False), # quantum=False for CI
]

# Warning: There is a known convergence issue with QSVM
Expand Down
1 change: 0 additions & 1 deletion examples/toys_dataset/plot_learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,4 @@
axe.set_ylabel("Evaluated values (MDM)")
axe.set_xlabel("Evaluations")

plt.tight_layout()
plt.show()
8 changes: 7 additions & 1 deletion pyriemann_qiskit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from ._version import __version__
from . import classification, pipelines, ensemble, autoencoders

__all__ = ["__version__", "classification", "pipelines", "ensemble", "autoencoders"]
__all__ = [
"__version__",
"classification",
"pipelines",
"ensemble",
"autoencoders",
]
7 changes: 3 additions & 4 deletions pyriemann_qiskit/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def _create_pipe(self):
gen_feature_map=feature_map,
shots=self.shots,
quantum=is_quantum,
**self.params
**self.params,
)
else:
self._log("QuanticSVM chosen.")
Expand All @@ -296,7 +296,7 @@ def _create_pipe(self):
max_iter=self.max_iter,
gen_feature_map=feature_map,
shots=self.shots,
**self.params
**self.params,
)

return make_pipeline(
Expand Down Expand Up @@ -402,8 +402,7 @@ def __init__(
BasePipeline.__init__(self, "QuantumMDMWithRiemannianPipeline")

def _create_pipe(self):
print(self.metric)
print(self.metric["mean"])
self._log(f"Running QMDM with metric {self.metric}")
if is_qfunction(self.metric["mean"]):
if self.quantum:
covariances = XdawnCovariances(
Expand Down

0 comments on commit 28686ff

Please sign in to comment.