-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat/Quantum art visualization (#183)
* add n-trials to get-qiskit-dataset * add example * spiral implementation * add visualization package * modify get_linearly_separable_dataset * improve example * display the number of parameters inside the circuit * comment code * just adapts parameters * [pre-commit.ci] auto fixes from pre-commit.com hooks * fix dockerfile * update version of google_cloud_firestore * fix dockerfile --------- Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
0d98b92
commit 823f494
Showing
6 changed files
with
188 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
==================================================================== | ||
Art visualization of Variational Quantum Classifier. | ||
==================================================================== | ||
Display the variability of the weights inside the variational quantum | ||
classifier. | ||
""" | ||
|
||
# Author: Gregoire Cattan | ||
# License: BSD (3-clause) | ||
|
||
from pyriemann_qiskit.utils.hyper_params_factory import gen_two_local | ||
import matplotlib.pyplot as plt | ||
from pyriemann_qiskit.datasets import get_linearly_separable_dataset | ||
from pyriemann_qiskit.classification import QuanticVQC | ||
from pyriemann_qiskit.visualization import weights_spiral | ||
|
||
|
||
print(__doc__) | ||
|
||
############################################################################### | ||
# In this example we will display weights variability of the parameter inside | ||
# the variational quantum circuit which is used by VQC. | ||
# | ||
# The idea is simple : | ||
# - We initialize a VQC with different number of parameters and number of samples | ||
# - We train the VQC a couple of time and we store the fitted weights. | ||
# - We compute variability of the weight and display it in a fashion way. | ||
|
||
# Let's start by defining some plot area. | ||
fig, axes = plt.subplots(2, 2) | ||
fig.suptitle("VQC weights variability") | ||
|
||
# We will compute weight variability for different number of samples | ||
for i, n_samples in enumerate([2, 20]): | ||
# ... and for differente number of parameters. | ||
# (n_reps controls the number of parameters inside the circuit) | ||
for j, n_reps in enumerate([1, 3]): | ||
# instanciate VQC. | ||
vqc = QuanticVQC(gen_var_form=gen_two_local(reps=n_reps)) | ||
|
||
# Get data. We will use a toy dataset here. | ||
X, y = get_linearly_separable_dataset(n_samples=n_samples) | ||
|
||
# Compute and display weight variability after training | ||
axe = axes[i, j] | ||
# ... This is all done in this method | ||
# It displays a spiral. Each "branch of the spiral" is a parameter inside VQC. | ||
# The larger is the branch, the higher is the parameter variability. | ||
weights_spiral(axe, vqc, X, y, n_trainings=5) | ||
n_params = vqc.parameter_count | ||
|
||
# Just improve display of the graphics. | ||
if j == 0: | ||
axe.set_ylabel(f"n_samples: {n_samples}") | ||
if i == 0: | ||
axe.set_xlabel(f"n_params: {n_params}") | ||
axe.xaxis.set_label_position("top") | ||
axe.set_xticks(()) | ||
axe.set_yticks(()) | ||
|
||
plt.tight_layout() | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .art import weights_spiral | ||
|
||
__all__ = ["weights_spiral"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
def weights_spiral(axe, vqc, X, y, n_trainings=5): | ||
"""Artistic representation of vqc training. | ||
Display a spiral. Each "branch" of the spiral corresponds to a parameter inside VQC. | ||
When the branch is "large" it means that the weight of the parameter varies a lot | ||
between different trainings. | ||
Notes | ||
----- | ||
.. versionadded:: 0.1.0 | ||
Parameters | ||
---------- | ||
axe: Axe | ||
Pointer to the matplotlib plot or subplot. | ||
vqc: QuanticVQC | ||
The instance of VQC to evaluate. | ||
X: ndarray, shape (n_samples, n_features) | ||
Input vector, where `n_samples` is the number of samples and | ||
`n_features` is the number of features. | ||
y: ndarray, shape (n_samples,) | ||
Predicted target vector relative to X. | ||
n_trainings: int (default: 5) | ||
Number of trainings to run, in order to evaluate the variability of the | ||
parameters' weights. | ||
Returns | ||
------- | ||
X: ndarray, shape (n_samples, n_features) | ||
Input vector, where `n_samples` is the number of samples and | ||
`n_features` is the number of features. | ||
y: ndarray, shape (n_samples,) | ||
Predicted target vector relative to X. | ||
""" | ||
|
||
weights = [] | ||
|
||
for i in range(5): | ||
vqc.fit(X, y) | ||
train_weights = vqc._classifier.weights | ||
weights.append(train_weights) | ||
|
||
df = pd.DataFrame(weights) | ||
|
||
theta = np.arange(0, 8 * np.pi, 0.1) | ||
a = 1 | ||
b = 0.2 | ||
|
||
n_params = len(df.columns) | ||
|
||
max_var = df.var().max() | ||
|
||
# https://matplotlib.org/3.1.1/gallery/misc/fill_spiral.html | ||
for i in range(n_params): | ||
dt = 2 * np.pi / n_params * i | ||
x = a * np.cos(theta + dt) * np.exp(b * theta) | ||
y = a * np.sin(theta + dt) * np.exp(b * theta) | ||
|
||
var = df[i].var() | ||
|
||
dt = dt + (var / max_var) * np.pi / 4.0 | ||
|
||
x2 = a * np.cos(theta + dt) * np.exp(b * theta) | ||
y2 = a * np.sin(theta + dt) * np.exp(b * theta) | ||
|
||
xf = np.concatenate((x, x2[::-1])) | ||
yf = np.concatenate((y, y2[::-1])) | ||
|
||
axe.fill(xf, yf) |