Skip to content

Commit

Permalink
feat/Quantum art visualization (#183)
Browse files Browse the repository at this point in the history
* add n-trials to get-qiskit-dataset

* add example

* spiral implementation

* add visualization package

* modify get_linearly_separable_dataset

* improve example

* display the number of parameters inside the circuit

* comment code

* just adapts parameters

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* fix dockerfile

* update version of google_cloud_firestore

* fix dockerfile

---------

Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 29, 2023
1 parent 0d98b92 commit 823f494
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 13 deletions.
18 changes: 9 additions & 9 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ RUN mkdir /root/mne_data
RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.24.2
RUN pip install google_cloud_firestore==2.11.1
RUN pip install protobuf==4.24.3
RUN pip install google_cloud_firestore==2.12.0
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.2-py3.8-linux-x86_64.egg/google/__init__.py
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.3-py3.8.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/async_client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/async_client.py'

ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ]
65 changes: 65 additions & 0 deletions examples/toys_dataset/plot_quantum_art_vqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
====================================================================
Art visualization of Variational Quantum Classifier.
====================================================================
Display the variability of the weights inside the variational quantum
classifier.
"""

# Author: Gregoire Cattan
# License: BSD (3-clause)

from pyriemann_qiskit.utils.hyper_params_factory import gen_two_local
import matplotlib.pyplot as plt
from pyriemann_qiskit.datasets import get_linearly_separable_dataset
from pyriemann_qiskit.classification import QuanticVQC
from pyriemann_qiskit.visualization import weights_spiral


print(__doc__)

###############################################################################
# In this example we will display weights variability of the parameter inside
# the variational quantum circuit which is used by VQC.
#
# The idea is simple :
# - We initialize a VQC with different number of parameters and number of samples
# - We train the VQC a couple of time and we store the fitted weights.
# - We compute variability of the weight and display it in a fashion way.

# Let's start by defining some plot area.
fig, axes = plt.subplots(2, 2)
fig.suptitle("VQC weights variability")

# We will compute weight variability for different number of samples
for i, n_samples in enumerate([2, 20]):
# ... and for differente number of parameters.
# (n_reps controls the number of parameters inside the circuit)
for j, n_reps in enumerate([1, 3]):
# instanciate VQC.
vqc = QuanticVQC(gen_var_form=gen_two_local(reps=n_reps))

# Get data. We will use a toy dataset here.
X, y = get_linearly_separable_dataset(n_samples=n_samples)

# Compute and display weight variability after training
axe = axes[i, j]
# ... This is all done in this method
# It displays a spiral. Each "branch of the spiral" is a parameter inside VQC.
# The larger is the branch, the higher is the parameter variability.
weights_spiral(axe, vqc, X, y, n_trainings=5)
n_params = vqc.parameter_count

# Just improve display of the graphics.
if j == 0:
axe.set_ylabel(f"n_samples: {n_samples}")
if i == 0:
axe.set_xlabel(f"n_params: {n_params}")
axe.xaxis.set_label_position("top")
axe.set_xticks(())
axe.set_yticks(())

plt.tight_layout()
plt.show()
18 changes: 18 additions & 0 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,24 @@ def predict(self, X):
labels = self._predict(X)
return self._map_indices_to_classes(labels)

@property
def parameter_count(self):
"""Returns the number of parameters inside the variational circuit.
This is determined by the `gen_var_form` attribute of this instance.
Returns
-------
n_params : int
The number of parameters in the variational circuit.
Returns 0 if the instance is not fit yet.
"""

if hasattr(self, "_classifier"):
return len(self._classifier.ansatz.parameters)

self._log("Instance not initialized. Parameter count is 0.")
return 0


class QuanticMDM(QuanticClassifierBase):

Expand Down
23 changes: 19 additions & 4 deletions pyriemann_qiskit/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,19 @@ def get_mne_sample(n_trials=10, include_auditory=False):
return X, y


def get_qiskit_dataset():
def get_qiskit_dataset(n_samples=30):
"""Return qiskit dataset.
Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.
Parameters
----------
n_samples : int (default: 30)
Number of trials to return.
Returns
-------
Expand All @@ -117,20 +124,27 @@ def get_qiskit_dataset():

feature_dim = 2
X, _, _, _ = ad_hoc_data(
training_size=30, test_size=0, n=feature_dim, gap=0.3, plot_data=False
training_size=n_samples, test_size=0, n=feature_dim, gap=0.3, plot_data=False
)

y = np.concatenate(([0] * 30, [1] * 30))
y = np.concatenate(([0] * n_samples, [1] * n_samples))

return (X, y)


def get_linearly_separable_dataset():
def get_linearly_separable_dataset(n_samples=100):
"""Return a linearly separable dataset.
Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.
Parameters
----------
n_samples : int (default: 100)
Number of trials to return.
Returns
-------
Expand All @@ -142,6 +156,7 @@ def get_linearly_separable_dataset():
"""
X, y = make_classification(
n_samples=n_samples,
n_features=2,
n_redundant=0,
n_informative=2,
Expand Down
3 changes: 3 additions & 0 deletions pyriemann_qiskit/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .art import weights_spiral

__all__ = ["weights_spiral"]
74 changes: 74 additions & 0 deletions pyriemann_qiskit/visualization/art.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pandas as pd
import numpy as np


def weights_spiral(axe, vqc, X, y, n_trainings=5):
"""Artistic representation of vqc training.
Display a spiral. Each "branch" of the spiral corresponds to a parameter inside VQC.
When the branch is "large" it means that the weight of the parameter varies a lot
between different trainings.
Notes
-----
.. versionadded:: 0.1.0
Parameters
----------
axe: Axe
Pointer to the matplotlib plot or subplot.
vqc: QuanticVQC
The instance of VQC to evaluate.
X: ndarray, shape (n_samples, n_features)
Input vector, where `n_samples` is the number of samples and
`n_features` is the number of features.
y: ndarray, shape (n_samples,)
Predicted target vector relative to X.
n_trainings: int (default: 5)
Number of trainings to run, in order to evaluate the variability of the
parameters' weights.
Returns
-------
X: ndarray, shape (n_samples, n_features)
Input vector, where `n_samples` is the number of samples and
`n_features` is the number of features.
y: ndarray, shape (n_samples,)
Predicted target vector relative to X.
"""

weights = []

for i in range(5):
vqc.fit(X, y)
train_weights = vqc._classifier.weights
weights.append(train_weights)

df = pd.DataFrame(weights)

theta = np.arange(0, 8 * np.pi, 0.1)
a = 1
b = 0.2

n_params = len(df.columns)

max_var = df.var().max()

# https://matplotlib.org/3.1.1/gallery/misc/fill_spiral.html
for i in range(n_params):
dt = 2 * np.pi / n_params * i
x = a * np.cos(theta + dt) * np.exp(b * theta)
y = a * np.sin(theta + dt) * np.exp(b * theta)

var = df[i].var()

dt = dt + (var / max_var) * np.pi / 4.0

x2 = a * np.cos(theta + dt) * np.exp(b * theta)
y2 = a * np.sin(theta + dt) * np.exp(b * theta)

xf = np.concatenate((x, x2[::-1]))
yf = np.concatenate((y, y2[::-1]))

axe.fill(xf, yf)

0 comments on commit 823f494

Please sign in to comment.