Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create light_benchmark.py #228

Merged
merged 12 commits into from
Dec 16, 2023
Merged
39 changes: 39 additions & 0 deletions .github/workflows/light_benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Light Benchmark

on:
push:
paths:
- 'pyriemann_qiskit/**'
- 'examples/**'
- '.github/workflows/light_benchmark.yml'
pull_request:
paths:
- 'pyriemann_qiskit/**'
- 'examples/**'
- '.github/workflows/light_benchmark.yml'

jobs:
light_benchmark:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Cache dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: light_benchmark.yml
- name: Install dependencies
run: |
pip install .[docs]
- name: Run benchmark script (PR)
run: |
python benchmarks/light_benchmark.py
- uses: actions/checkout@v4
with:
ref: 'main'
- name: Install dependencies
run: |
pip install .[docs]
- name: Run benchmark script (main)
run: |
python benchmarks/light_benchmark.py
14 changes: 7 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.25.1
RUN pip install google_cloud_firestore==2.14.0rc1
RUN pip install google_cloud_firestore==2.14.0
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.9/site-packages/protobuf-4.25.1-py3.9-linux-x86_64.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'

ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ]
89 changes: 89 additions & 0 deletions benchmarks/light_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
====================================================================
Light Benchmark
====================================================================

This benchmark is a non-regression performance test, intended
to run on Ci with each PRs.

"""
# Author: Gregoire Cattan
# Modified from plot_classify_P300_bi.py of pyRiemann
# License: BSD (3-clause)

from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from sklearn.preprocessing import LabelEncoder
import warnings
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from moabb import set_log_level
from moabb.datasets import bi2012
from moabb.paradigms import P300
from pyriemann_qiskit.pipelines import (
QuantumClassifierWithDefaultRiemannianPipeline,
)
from sklearn.decomposition import PCA

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

##############################################################################

paradigm = P300(resample=128)

dataset = bi2012() # MOABB provides several other P300 datasets

X, y, _ = paradigm.get_data(dataset, subjects=[1])

y = LabelEncoder().fit_transform(y)

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)

pipelines = {}

pipelines["RG+QuantumSVM"] = QuantumClassifierWithDefaultRiemannianPipeline(
shots=512,
nfilter=2,
dim_red=PCA(n_components=5),
)

pipelines["RG+LDA"] = make_pipeline(
XdawnCovariances(
nfilter=2,
estimator="lwf",
xdawn_estimator="scm",
),
TangentSpace(),
PCA(n_components=10),
LDA(solver="lsqr", shrinkage="auto"),
)

scores = {}

for key, pipeline in pipelines.items():
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
score = balanced_accuracy_score(y_test, y_pred)
scores[key] = score


print("Scores: ", scores)
Loading