-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create light_benchmark.py * push skeleton of light_benchmark.py * [pre-commit.ci] auto fixes from pre-commit.com hooks * use pip to install dependencies * use paradigm to get data * fix bug with loop+dictionnary * fix label encoding * run benchmarl on main too * syntax error * missing `with` * little clean up * fix dockerfile --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
a4f8cc6
commit 3eabe4e
Showing
3 changed files
with
135 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
name: Light Benchmark | ||
|
||
on: | ||
push: | ||
paths: | ||
- 'pyriemann_qiskit/**' | ||
- 'examples/**' | ||
- '.github/workflows/light_benchmark.yml' | ||
pull_request: | ||
paths: | ||
- 'pyriemann_qiskit/**' | ||
- 'examples/**' | ||
- '.github/workflows/light_benchmark.yml' | ||
|
||
jobs: | ||
light_benchmark: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Cache dependencies | ||
uses: actions/cache@v3 | ||
with: | ||
path: ~/.cache/pip | ||
key: light_benchmark.yml | ||
- name: Install dependencies | ||
run: | | ||
pip install .[docs] | ||
- name: Run benchmark script (PR) | ||
run: | | ||
python benchmarks/light_benchmark.py | ||
- uses: actions/checkout@v4 | ||
with: | ||
ref: 'main' | ||
- name: Install dependencies | ||
run: | | ||
pip install .[docs] | ||
- name: Run benchmark script (main) | ||
run: | | ||
python benchmarks/light_benchmark.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
""" | ||
==================================================================== | ||
Light Benchmark | ||
==================================================================== | ||
This benchmark is a non-regression performance test, intended | ||
to run on Ci with each PRs. | ||
""" | ||
# Author: Gregoire Cattan | ||
# Modified from plot_classify_P300_bi.py of pyRiemann | ||
# License: BSD (3-clause) | ||
|
||
from pyriemann.estimation import XdawnCovariances | ||
from pyriemann.tangentspace import TangentSpace | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import balanced_accuracy_score | ||
from sklearn.preprocessing import LabelEncoder | ||
import warnings | ||
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA | ||
from moabb import set_log_level | ||
from moabb.datasets import bi2012 | ||
from moabb.paradigms import P300 | ||
from pyriemann_qiskit.pipelines import ( | ||
QuantumClassifierWithDefaultRiemannianPipeline, | ||
) | ||
from sklearn.decomposition import PCA | ||
|
||
print(__doc__) | ||
|
||
############################################################################## | ||
# getting rid of the warnings about the future | ||
warnings.simplefilter(action="ignore", category=FutureWarning) | ||
warnings.simplefilter(action="ignore", category=RuntimeWarning) | ||
|
||
warnings.filterwarnings("ignore") | ||
|
||
set_log_level("info") | ||
|
||
############################################################################## | ||
# Create Pipelines | ||
# ---------------- | ||
# | ||
# Pipelines must be a dict of sklearn pipeline transformer. | ||
|
||
############################################################################## | ||
|
||
paradigm = P300(resample=128) | ||
|
||
dataset = bi2012() # MOABB provides several other P300 datasets | ||
|
||
X, y, _ = paradigm.get_data(dataset, subjects=[1]) | ||
|
||
y = LabelEncoder().fit_transform(y) | ||
|
||
X_train, X_test, y_train, y_test = train_test_split( | ||
X, y, test_size=0.33, random_state=42 | ||
) | ||
|
||
pipelines = {} | ||
|
||
pipelines["RG+QuantumSVM"] = QuantumClassifierWithDefaultRiemannianPipeline( | ||
shots=512, | ||
nfilter=2, | ||
dim_red=PCA(n_components=5), | ||
) | ||
|
||
pipelines["RG+LDA"] = make_pipeline( | ||
XdawnCovariances( | ||
nfilter=2, | ||
estimator="lwf", | ||
xdawn_estimator="scm", | ||
), | ||
TangentSpace(), | ||
PCA(n_components=10), | ||
LDA(solver="lsqr", shrinkage="auto"), | ||
) | ||
|
||
scores = {} | ||
|
||
for key, pipeline in pipelines.items(): | ||
pipeline.fit(X_train, y_train) | ||
y_pred = pipeline.predict(X_test) | ||
score = balanced_accuracy_score(y_test, y_pred) | ||
scores[key] = score | ||
|
||
|
||
print("Scores: ", scores) |