Skip to content

Commit

Permalink
Create light_benchmark.py (#228)
Browse files Browse the repository at this point in the history
* Create light_benchmark.py

* push skeleton of light_benchmark.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* use pip to install dependencies

* use paradigm to get data

* fix bug with loop+dictionnary

* fix label encoding

* run benchmarl on main too

* syntax error

* missing `with`

* little clean up

* fix dockerfile

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gcattan and pre-commit-ci[bot] committed Dec 16, 2023
1 parent a4f8cc6 commit 3eabe4e
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 7 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/light_benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Light Benchmark

on:
push:
paths:
- 'pyriemann_qiskit/**'
- 'examples/**'
- '.github/workflows/light_benchmark.yml'
pull_request:
paths:
- 'pyriemann_qiskit/**'
- 'examples/**'
- '.github/workflows/light_benchmark.yml'

jobs:
light_benchmark:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Cache dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: light_benchmark.yml
- name: Install dependencies
run: |
pip install .[docs]
- name: Run benchmark script (PR)
run: |
python benchmarks/light_benchmark.py
- uses: actions/checkout@v4
with:
ref: 'main'
- name: Install dependencies
run: |
pip install .[docs]
- name: Run benchmark script (main)
run: |
python benchmarks/light_benchmark.py
14 changes: 7 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.25.1
RUN pip install google_cloud_firestore==2.14.0rc1
RUN pip install google_cloud_firestore==2.14.0
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.9/site-packages/protobuf-4.25.1-py3.9-linux-x86_64.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'

ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ]
89 changes: 89 additions & 0 deletions benchmarks/light_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
====================================================================
Light Benchmark
====================================================================
This benchmark is a non-regression performance test, intended
to run on Ci with each PRs.
"""
# Author: Gregoire Cattan
# Modified from plot_classify_P300_bi.py of pyRiemann
# License: BSD (3-clause)

from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from sklearn.preprocessing import LabelEncoder
import warnings
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from moabb import set_log_level
from moabb.datasets import bi2012
from moabb.paradigms import P300
from pyriemann_qiskit.pipelines import (
QuantumClassifierWithDefaultRiemannianPipeline,
)
from sklearn.decomposition import PCA

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

##############################################################################

paradigm = P300(resample=128)

dataset = bi2012() # MOABB provides several other P300 datasets

X, y, _ = paradigm.get_data(dataset, subjects=[1])

y = LabelEncoder().fit_transform(y)

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)

pipelines = {}

pipelines["RG+QuantumSVM"] = QuantumClassifierWithDefaultRiemannianPipeline(
shots=512,
nfilter=2,
dim_red=PCA(n_components=5),
)

pipelines["RG+LDA"] = make_pipeline(
XdawnCovariances(
nfilter=2,
estimator="lwf",
xdawn_estimator="scm",
),
TangentSpace(),
PCA(n_components=10),
LDA(solver="lsqr", shrinkage="auto"),
)

scores = {}

for key, pipeline in pipelines.items():
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
score = balanced_accuracy_score(y_test, y_pred)
scores[key] = score


print("Scores: ", scores)

0 comments on commit 3eabe4e

Please sign in to comment.