Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposure of Gamma, optimizer and variational form #16

Merged
merged 32 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a514d1e
expose gamma
gcattan Nov 26, 2021
da5d917
expose optimizer
gcattan Nov 26, 2021
c71aaa5
gen_two_local factory and tests
gcattan Nov 26, 2021
8c77b87
expose two local parameter
gcattan Nov 26, 2021
19d660f
code smells: exceptions
gcattan Nov 26, 2021
db897d5
complete api.rst
gcattan Nov 26, 2021
a39f2a7
typo
gcattan Dec 1, 2021
53eb71c
echo git location
gcattan Dec 1, 2021
8ef5af7
apt-get update before pip upgrade
gcattan Dec 1, 2021
db97e8f
just left pip install
gcattan Dec 1, 2021
c8d964c
just left pip install -e
gcattan Dec 1, 2021
94a21a7
no with more clear log, retry sudo apt install git
gcattan Dec 1, 2021
6471310
missing -y option
gcattan Dec 1, 2021
42862b7
apt-get update
gcattan Dec 1, 2021
72ca6b4
retry fix-missing
gcattan Dec 1, 2021
0df4472
check if git in path
gcattan Dec 1, 2021
0eab8a1
just left pip install -r doc/requirements.txt
gcattan Dec 1, 2021
1ee0547
try upgrading pip...
gcattan Dec 1, 2021
a483a97
revert changes
gcattan Dec 1, 2021
7508721
typo
gcattan Dec 1, 2021
a8bb4dd
liste usr/bin
gcattan Dec 1, 2021
07201fe
try git-core
gcattan Dec 1, 2021
62c56ca
typo
gcattan Dec 1, 2021
5912048
fix-missing options...
gcattan Dec 1, 2021
f1617f9
Update pyriemann_qiskit/classification.py
gcattan Dec 1, 2021
de212f1
Update pyriemann_qiskit/classification.py
gcattan Dec 1, 2021
01a6f7d
Update pyriemann_qiskit/classification.py
gcattan Dec 1, 2021
979faaa
Update pyriemann_qiskit/classification.py
gcattan Dec 1, 2021
a8a5d79
Update pyriemann_qiskit/utils/hyper_params_factory.py
gcattan Dec 1, 2021
a9a29c8
Update pyriemann_qiskit/utils/hyper_params_factory.py
gcattan Dec 1, 2021
54acb79
- fix pipeline
gcattan Dec 1, 2021
7264519
flake8
gcattan Dec 1, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/deploy_ghpages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ jobs:
with:
docs-folder: "doc/"
pre-build-command: |
apt-get update --fix-missing
python -m pip install --upgrade pip
apt-get update
apt-get -y install git-all
pip install -e .
apt-get -y install git
pip install -r doc/requirements.txt
- name: Upload generated HTML as artifact
uses: actions/upload-artifact@v2
Expand Down
4 changes: 3 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ Hyper-parameters generation
.. autosummary::
:toctree: generated/

gen_zz_feature_map
gen_zz_feature_map
gen_two_local
get_spsa
62 changes: 42 additions & 20 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from qiskit import BasicAer, IBMQ
from qiskit.circuit.library import TwoLocal
from qiskit.aqua import QuantumInstance, aqua_globals
from qiskit.aqua.quantum_instance import logger
from qiskit.aqua.algorithms import QSVM, SklearnSVM, VQC
from qiskit.aqua.utils import get_feature_dimension
from qiskit.providers.ibmq import least_busy
from qiskit.aqua.components.optimizers import SPSA
from datetime import datetime
import logging
from .utils.hyper_params_factory import gen_zz_feature_map
from .utils.hyper_params_factory import (gen_zz_feature_map,
gen_two_local,
get_spsa)

logger.level = logging.INFO

Expand Down Expand Up @@ -47,7 +47,8 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
If true will output all intermediate results and logs
shots : int (default:1024)
Number of repetitions of each circuit, for sampling
gen_feature_map : Callable[int, QuantumCircuit | FeatureMap]
gen_feature_map : Callable[int, QuantumCircuit | FeatureMap] \
(default : Callable[int, ZZFeatureMap])
Function generating a feature map to encode data into a quantum state.

Notes
Expand Down Expand Up @@ -134,6 +135,8 @@ def fit(self, X, y):
y : ndarray, shape (n_samples,)
Target vector relative to X.

:raises Exception: if the number of classes different than 2

Returns
-------
self : QuanticClassifierBase instance
Expand Down Expand Up @@ -231,6 +234,12 @@ class QuanticSVM(QuanticClassifierBase):
-----
.. versionadded:: 0.0.1

Attributes
----------
gcattan marked this conversation as resolved.
Show resolved Hide resolved
gamma : float | None (default:None)
Used as input for sklearn rbf_kernel which is used internally.
See [3]_ for more information about gamma.

See Also
--------
QuanticClassifierBase
Expand All @@ -244,17 +253,23 @@ class QuanticSVM(QuanticClassifierBase):
‘Supervised learning with quantum-enhanced feature spaces’,
Nature, vol. 567, no. 7747, pp. 209–212, Mar. 2019,
doi: 10.1038/s41586-019-0980-2.
.. [3] Available from: \
gcattan marked this conversation as resolved.
Show resolved Hide resolved
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.rbf_kernel.html

"""

def __init__(self, gamma=None, **parameters):
QuanticClassifierBase.__init__(self, **parameters)
self.gamma = gamma

def _init_algo(self, n_features):
# Although we do not train the classifier at this location
# training_input are required by Qiskit library.
self._log("SVM initiating algorithm")
if self.quantum:
classifier = QSVM(self._feature_map, self._training_input)
else:
classifier = SklearnSVM(self._training_input)
classifier = SklearnSVM(self._training_input, gamma=self.gamma)
return classifier

def predict_proba(self, X):
Expand Down Expand Up @@ -305,13 +320,14 @@ class QuanticVQC(QuanticClassifierBase):
Note there is no classical version of this algorithm.
This will always run on a quantum computer (simulated or not)

Parameters
Attributes
gcattan marked this conversation as resolved.
Show resolved Hide resolved
----------
q_account_token : string (default:None)
If quantum==True and q_account_token provided,
the classification task will be running on a IBM quantum backend
verbose : bool (default:True)
If true will output all intermediate results and logs
optimizer : Optimizer (default:SPSA)
The classical optimizer to use.
See [3] for details.
gen_var_form : Callable[int, QuantumCircuit | VariationalForm] \
(default: Callable[int, TwoLocal])
Function generating a variational form instance.

Notes
-----
Expand All @@ -321,6 +337,8 @@ class QuanticVQC(QuanticClassifierBase):
--------
QuanticClassifierBase

:raises ValueError: if `quantum` is False

References
----------
.. [1] H. Abraham et al., Qiskit:
Expand All @@ -332,22 +350,26 @@ class QuanticVQC(QuanticClassifierBase):
Nature, vol. 567, no. 7747, pp. 209–212, Mar. 2019,
doi: 10.1038/s41586-019-0980-2.

.. [3] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.aqua.algorithms.VQC.html?highlight=vqc#qiskit.aqua.algorithms.VQC
gcattan marked this conversation as resolved.
Show resolved Hide resolved

"""

def __init__(self, q_account_token=None,
verbose=True, **parameters):
QuanticClassifierBase.__init__(self,
q_account_token=q_account_token,
verbose=verbose)
def __init__(self, optimizer=get_spsa(), gen_var_form=gen_two_local(),
**parameters):
if "quantum" in parameters and not parameters["quantum"]:
raise ValueError("VQC can only run on a quantum \
computer or simulator.")
QuanticClassifierBase.__init__(self, **parameters)
self.optimizer = optimizer
self.gen_var_form = gen_var_form

def _init_algo(self, n_features):
self._log("VQC training...")
self._optimizer = SPSA(maxiter=40, c0=4.0, skip_calibration=True)
self._var_form = TwoLocal(n_features,
['ry', 'rz'], 'cz', reps=3)
var_form = self.gen_var_form(n_features)
# Although we do not train the classifier at this location
# training_input are required by Qiskit library.
vqc = VQC(self._optimizer, self._feature_map, self._var_form,
vqc = VQC(self.optimizer, self._feature_map, var_form,
self._training_input)
return vqc

Expand Down
117 changes: 117 additions & 0 deletions pyriemann_qiskit/utils/hyper_params_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from qiskit.circuit.library import ZZFeatureMap
from qiskit.aqua.components.optimizers import SPSA
from qiskit.circuit.library import TwoLocal


def gen_zz_feature_map(reps=2, entanglement='linear'):
Expand All @@ -22,6 +24,8 @@ def gen_zz_feature_map(reps=2, entanglement='linear'):
ret : ZZFeatureMap
An instance of ZZFeatureMap

:raises ValueError: if `reps` lower than 1.

References
----------
.. [1] \
Expand All @@ -34,3 +38,116 @@ def gen_zz_feature_map(reps=2, entanglement='linear'):
return lambda n_features: ZZFeatureMap(feature_dimension=n_features,
reps=reps,
entanglement=entanglement)


# Valid gates for two local circuits
gates = ['ch', 'cx', 'cy', 'cz', 'crx', 'cry', 'crz',
'h', 'i', 'id', 'iden',
'rx', 'rxx', 'ry', 'ryy', 'rz', 'rzx', 'rzz',
's', 'sdg', 'swap',
'x', 'y', 'z', 't', 'tdg']


def _check_gates_in_blocks(blocks):
if isinstance(blocks, list):
for gate in blocks:
if gate not in gates:
raise ValueError("Gate %s is not a valid gate" % gate)
else:
if blocks not in gates:
raise ValueError("Gate %s is not a valid gate"
% blocks)


def gen_two_local(reps=3, rotation_blocks=['ry', 'rz'],
entanglement_blocks='cz'):
"""Return a callable that generate a TwoLocal circuit.
The two-local circuit is a parameterized circuit consisting
of alternating rotation layers and entanglement layers [1]_.

Parameters
----------
reps : int (default 3)
Specifies how often a block consisting of a rotation layer
and entanglement layer is repeated.
rotation_blocks : str | list[str]
The gates used in the rotation layer.
Valid string values are defined in `gates`.
entanglement_blocks : str | list[str]
The gates used in the entanglement layer.
Valid string values are defined in `gates`.

Returns
-------
ret : TwoLocal
An instance of a TwoLocal circuit

:raises ValueError: if `rotation_blocks` or `entanglement_blocks` contains
a non valid gate

References
----------
.. [1] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.circuit.library.TwoLocal.html
"""
if reps < 1:
raise ValueError("Parameter reps must be superior \
or equal to 1 (Got %d)" % reps)

_check_gates_in_blocks(rotation_blocks)

_check_gates_in_blocks(entanglement_blocks)

return lambda n_features: TwoLocal(n_features,
rotation_blocks,
entanglement_blocks, reps=reps)


def get_spsa(max_trials=40, c=(None, None, None, None, 4.0)):
"""Return an instance of SPSA.
SPSA [1, 2]_ is an algorithmic method for optimizing systems
with multiple unknown parameters.
For more details, see [3] and [4].
gcattan marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
max_trials : int (default:40)
Maximum number of iterations to perform.
c : tuple[float | None] (default:(None, None, None, None, 4.0))
The 5 control parameters for SPSA algorithms.
See [3] for implementation details.
gcattan marked this conversation as resolved.
Show resolved Hide resolved
Auto calibration of SPSA will be skiped if one
of the parameters is different from None.

Returns
-------
ret : SPSA
An instance of SPSA

References
----------
.. [1] Spall, J. C. (2012), “Stochastic Optimization,”
in Handbook of Computational Statistics:
Concepts and Methods (2nd ed.)
(J. Gentle, W. Härdle, and Y. Mori, eds.),
Springer−Verlag, Heidelberg, Chapter 7, pp. 173–201.
dx.doi.org/10.1007/978-3-642-21551-3_7

.. [2] Spall, J. C. (1999), "Stochastic Optimization:
Stochastic Approximation and Simulated Annealing,"
in Encyclopedia of Electrical and Electronics Engineering
(J. G. Webster, ed.),
Wiley, New York, vol. 20, pp. 529–542

.. [3] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.aqua.components.optimizers.SPSA.html

.. [4] https://www.jhuapl.edu/SPSA/#Overview
"""
params = {}
for i in range(5):
if c[i] is not None:
params["c" + str(i)] = c[i]
if len(params) > 0:
params["skip_calibration"] = True
return SPSA(max_trials=max_trials, **params)
6 changes: 6 additions & 0 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import numpy as np
from pyriemann.classification import TangentSpace
from pyriemann.estimation import XdawnCovariances
Expand All @@ -17,6 +18,11 @@ def test_params(get_covmats, get_labels):
cross_val_score(clf, covset, labels, cv=skf, scoring='roc_auc')


def test_vqc_classical_should_return_value_error():
with pytest.raises(ValueError):
QuanticVQC(quantum=False)


def test_qsvm_init():
"""Test init of quantum classifiers"""
# if "classical" computation enable,
Expand Down
Loading