Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/Quantum art visualization #183

Merged
merged 13 commits into from
Sep 29, 2023
18 changes: 9 additions & 9 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ RUN mkdir /root/mne_data
RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.24.2
RUN pip install google_cloud_firestore==2.11.1
RUN pip install protobuf==4.24.3
RUN pip install google_cloud_firestore==2.12.0
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.2-py3.8-linux-x86_64.egg/google/__init__.py
RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.3-py3.8.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/async_client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.12.0-py3.8.egg/google/cloud/firestore_v1/services/firestore/async_client.py'

ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ]
65 changes: 65 additions & 0 deletions examples/toys_dataset/plot_quantum_art_vqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
====================================================================
Art visualization of Variational Quantum Classifier.
====================================================================

Display the variability of the weights inside the variational quantum
classifier.

"""

# Author: Gregoire Cattan
# License: BSD (3-clause)

from pyriemann_qiskit.utils.hyper_params_factory import gen_two_local
import matplotlib.pyplot as plt
from pyriemann_qiskit.datasets import get_linearly_separable_dataset
from pyriemann_qiskit.classification import QuanticVQC
from pyriemann_qiskit.visualization import weights_spiral


print(__doc__)

###############################################################################
# In this example we will display weights variability of the parameter inside
# the variational quantum circuit which is used by VQC.
#
# The idea is simple :
# - We initialize a VQC with different number of parameters and number of samples
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add . at the end of this sentence.

# - We train the VQC a couple of time and we store the fitted weights.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

time -> times

# - We compute variability of the weight and display it in a fashion way.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fashion -> fashionable


# Let's start by defining some plot area.
fig, axes = plt.subplots(2, 2)
fig.suptitle("VQC weights variability")

# We will compute weight variability for different number of samples
for i, n_samples in enumerate([2, 20]):
# ... and for differente number of parameters.
# (n_reps controls the number of parameters inside the circuit)
for j, n_reps in enumerate([1, 3]):
# instanciate VQC.
vqc = QuanticVQC(gen_var_form=gen_two_local(reps=n_reps))

# Get data. We will use a toy dataset here.
X, y = get_linearly_separable_dataset(n_samples=n_samples)

# Compute and display weight variability after training
axe = axes[i, j]
# ... This is all done in this method
# It displays a spiral. Each "branch of the spiral" is a parameter inside VQC.
# The larger is the branch, the higher is the parameter variability.
weights_spiral(axe, vqc, X, y, n_trainings=5)
n_params = vqc.parameter_count

# Just improve display of the graphics.
if j == 0:
axe.set_ylabel(f"n_samples: {n_samples}")
if i == 0:
axe.set_xlabel(f"n_params: {n_params}")
axe.xaxis.set_label_position("top")
axe.set_xticks(())
axe.set_yticks(())

plt.tight_layout()
plt.show()
18 changes: 18 additions & 0 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,24 @@ def predict(self, X):
labels = self._predict(X)
return self._map_indices_to_classes(labels)

@property
def parameter_count(self):
"""Returns the number of parameters inside the variational circuit.
This is determined by the `gen_var_form` attribute of this instance.

Returns
-------
n_params : int
The number of parameters in the variational circuit.
Returns 0 if the instance is not fit yet.
"""

if hasattr(self, "_classifier"):
return len(self._classifier.ansatz.parameters)

self._log("Instance not initialized. Parameter count is 0.")
return 0


class QuanticMDM(QuanticClassifierBase):

Expand Down
23 changes: 19 additions & 4 deletions pyriemann_qiskit/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,19 @@ def get_mne_sample(n_trials=10, include_auditory=False):
return X, y


def get_qiskit_dataset():
def get_qiskit_dataset(n_samples=30):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't be confusing for someone who wants the entire dataset to get only 30 samples when calling the function without parameters?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dataset is generated. So you have to specify somehow the number of samples you want.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then the name could be "generate_qiskit_dataset".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point.

"""Return qiskit dataset.

Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.

Parameters
----------
n_samples : int (default: 30)
Number of trials to return.

Returns
-------
Expand All @@ -117,20 +124,27 @@ def get_qiskit_dataset():

feature_dim = 2
X, _, _, _ = ad_hoc_data(
training_size=30, test_size=0, n=feature_dim, gap=0.3, plot_data=False
training_size=n_samples, test_size=0, n=feature_dim, gap=0.3, plot_data=False
)

y = np.concatenate(([0] * 30, [1] * 30))
y = np.concatenate(([0] * n_samples, [1] * n_samples))

return (X, y)


def get_linearly_separable_dataset():
def get_linearly_separable_dataset(n_samples=100):
"""Return a linearly separable dataset.

Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.

Parameters
----------
n_samples : int (default: 100)
Number of trials to return.

Returns
-------
Expand All @@ -142,6 +156,7 @@ def get_linearly_separable_dataset():

"""
X, y = make_classification(
n_samples=n_samples,
n_features=2,
n_redundant=0,
n_informative=2,
Expand Down
3 changes: 3 additions & 0 deletions pyriemann_qiskit/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .art import weights_spiral

__all__ = ["weights_spiral"]
74 changes: 74 additions & 0 deletions pyriemann_qiskit/visualization/art.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pandas as pd
import numpy as np


def weights_spiral(axe, vqc, X, y, n_trainings=5):
"""Artistic representation of vqc training.

Display a spiral. Each "branch" of the spiral corresponds to a parameter inside VQC.
When the branch is "large" it means that the weight of the parameter varies a lot
between different trainings.

Notes
-----
.. versionadded:: 0.1.0

Parameters
----------
axe: Axe
Pointer to the matplotlib plot or subplot.
vqc: QuanticVQC
The instance of VQC to evaluate.
X: ndarray, shape (n_samples, n_features)
Input vector, where `n_samples` is the number of samples and
`n_features` is the number of features.
y: ndarray, shape (n_samples,)
Predicted target vector relative to X.
n_trainings: int (default: 5)
Number of trainings to run, in order to evaluate the variability of the
parameters' weights.

Returns
-------
X: ndarray, shape (n_samples, n_features)
Input vector, where `n_samples` is the number of samples and
`n_features` is the number of features.
y: ndarray, shape (n_samples,)
Predicted target vector relative to X.

"""

weights = []

for i in range(5):
vqc.fit(X, y)
train_weights = vqc._classifier.weights
weights.append(train_weights)

df = pd.DataFrame(weights)

theta = np.arange(0, 8 * np.pi, 0.1)
a = 1
b = 0.2

n_params = len(df.columns)

max_var = df.var().max()

# https://matplotlib.org/3.1.1/gallery/misc/fill_spiral.html
for i in range(n_params):
dt = 2 * np.pi / n_params * i
x = a * np.cos(theta + dt) * np.exp(b * theta)
y = a * np.sin(theta + dt) * np.exp(b * theta)

var = df[i].var()

dt = dt + (var / max_var) * np.pi / 4.0

x2 = a * np.cos(theta + dt) * np.exp(b * theta)
y2 = a * np.sin(theta + dt) * np.exp(b * theta)

xf = np.concatenate((x, x2[::-1]))
yf = np.concatenate((y, y2[::-1]))

axe.fill(xf, yf)