diff --git a/.github/workflows/light_benchmark.yml b/.github/workflows/light_benchmark.yml new file mode 100644 index 00000000..3a5fbc06 --- /dev/null +++ b/.github/workflows/light_benchmark.yml @@ -0,0 +1,39 @@ +name: Light Benchmark + +on: + push: + paths: + - 'pyriemann_qiskit/**' + - 'examples/**' + - '.github/workflows/light_benchmark.yml' + pull_request: + paths: + - 'pyriemann_qiskit/**' + - 'examples/**' + - '.github/workflows/light_benchmark.yml' + +jobs: + light_benchmark: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Cache dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: light_benchmark.yml + - name: Install dependencies + run: | + pip install .[docs] + - name: Run benchmark script (PR) + run: | + python benchmarks/light_benchmark.py + - uses: actions/checkout@v4 + with: + ref: 'main' + - name: Install dependencies + run: | + pip install .[docs] + - name: Run benchmark script (main) + run: | + python benchmarks/light_benchmark.py diff --git a/Dockerfile b/Dockerfile index c34983f7..9d1466c7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,15 +25,15 @@ RUN mkdir /home/mne_data ## Workaround for firestore RUN pip install protobuf==4.25.1 -RUN pip install google_cloud_firestore==2.14.0rc1 +RUN pip install google_cloud_firestore==2.14.0 ### Missing __init__ file in protobuf RUN touch /usr/local/lib/python3.9/site-packages/protobuf-4.25.1-py3.9-linux-x86_64.egg/google/__init__.py ## google.cloud.location is never used in these files, and is missing in path. -RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py' -RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py' -RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py' -RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py' -RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py' -RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0rc1-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py' +RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py' +RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py' +RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py' +RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py' +RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py' +RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.14.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py' ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ] diff --git a/benchmarks/light_benchmark.py b/benchmarks/light_benchmark.py new file mode 100644 index 00000000..4addb7c2 --- /dev/null +++ b/benchmarks/light_benchmark.py @@ -0,0 +1,89 @@ +""" +==================================================================== +Light Benchmark +==================================================================== + +This benchmark is a non-regression performance test, intended +to run on Ci with each PRs. + +""" +# Author: Gregoire Cattan +# Modified from plot_classify_P300_bi.py of pyRiemann +# License: BSD (3-clause) + +from pyriemann.estimation import XdawnCovariances +from pyriemann.tangentspace import TangentSpace +from sklearn.pipeline import make_pipeline +from sklearn.model_selection import train_test_split +from sklearn.metrics import balanced_accuracy_score +from sklearn.preprocessing import LabelEncoder +import warnings +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA +from moabb import set_log_level +from moabb.datasets import bi2012 +from moabb.paradigms import P300 +from pyriemann_qiskit.pipelines import ( + QuantumClassifierWithDefaultRiemannianPipeline, +) +from sklearn.decomposition import PCA + +print(__doc__) + +############################################################################## +# getting rid of the warnings about the future +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) + +warnings.filterwarnings("ignore") + +set_log_level("info") + +############################################################################## +# Create Pipelines +# ---------------- +# +# Pipelines must be a dict of sklearn pipeline transformer. + +############################################################################## + +paradigm = P300(resample=128) + +dataset = bi2012() # MOABB provides several other P300 datasets + +X, y, _ = paradigm.get_data(dataset, subjects=[1]) + +y = LabelEncoder().fit_transform(y) + +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.33, random_state=42 +) + +pipelines = {} + +pipelines["RG+QuantumSVM"] = QuantumClassifierWithDefaultRiemannianPipeline( + shots=512, + nfilter=2, + dim_red=PCA(n_components=5), +) + +pipelines["RG+LDA"] = make_pipeline( + XdawnCovariances( + nfilter=2, + estimator="lwf", + xdawn_estimator="scm", + ), + TangentSpace(), + PCA(n_components=10), + LDA(solver="lsqr", shrinkage="auto"), +) + +scores = {} + +for key, pipeline in pipelines.items(): + pipeline.fit(X_train, y_train) + y_pred = pipeline.predict(X_test) + score = balanced_accuracy_score(y_test, y_pred) + scores[key] = score + + +print("Scores: ", scores)