Skip to content

Commit

Permalink
Exposure of Gamma, optimizer and variational form (#16)
Browse files Browse the repository at this point in the history
* expose gamma

* expose optimizer

* gen_two_local factory and tests

* expose two local parameter

* code smells: exceptions

* complete api.rst

* typo

* echo git location

* apt-get update before pip upgrade

* just left pip install

* just left pip install -e

* no with more clear log, retry sudo apt install git

* missing -y option

* apt-get update

* retry fix-missing

* check if git in path

* just left  pip install -r doc/requirements.txt

* try upgrading pip...

* revert changes

* typo

* liste usr/bin

* try git-core

* typo

* fix-missing options...

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/utils/hyper_params_factory.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/utils/hyper_params_factory.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* - fix pipeline
https://stackoverflow.com/questions/68802802/repository-http-security-debian-org-debian-security-buster-updates-inrelease
- fix docstring in section raises

* flake8

Co-authored-by: gcattan <gregoire.cattan@ibm.com>
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
  • Loading branch information
3 people authored Dec 3, 2021
1 parent f945753 commit 769605b
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 24 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/deploy_ghpages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ jobs:
with:
docs-folder: "doc/"
pre-build-command: |
apt-get --allow-releaseinfo-change update
python -m pip install --upgrade pip
apt-get update
apt-get -y install git-all
pip install -e .
apt-get -y install --fix-missing git-core
pip install -r doc/requirements.txt
- name: Upload generated HTML as artifact
uses: actions/upload-artifact@v2
Expand Down
4 changes: 3 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ Hyper-parameters generation
.. autosummary::
:toctree: generated/

gen_zz_feature_map
gen_zz_feature_map
gen_two_local
get_spsa
67 changes: 48 additions & 19 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from qiskit import BasicAer, IBMQ
from qiskit.circuit.library import TwoLocal
from qiskit.aqua import QuantumInstance, aqua_globals
from qiskit.aqua.quantum_instance import logger
from qiskit.aqua.algorithms import QSVM, SklearnSVM, VQC
from qiskit.aqua.utils import get_feature_dimension
from qiskit.providers.ibmq import least_busy
from qiskit.aqua.components.optimizers import SPSA
from datetime import datetime
import logging
from .utils.hyper_params_factory import gen_zz_feature_map
from .utils.hyper_params_factory import (gen_zz_feature_map,
gen_two_local,
get_spsa)

logger.level = logging.INFO

Expand Down Expand Up @@ -47,7 +47,8 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
If true will output all intermediate results and logs
shots : int (default:1024)
Number of repetitions of each circuit, for sampling
gen_feature_map : Callable[int, QuantumCircuit | FeatureMap]
gen_feature_map : Callable[int, QuantumCircuit | FeatureMap] \
(default : Callable[int, ZZFeatureMap])
Function generating a feature map to encode data into a quantum state.
Notes
Expand Down Expand Up @@ -134,6 +135,11 @@ def fit(self, X, y):
y : ndarray, shape (n_samples,)
Target vector relative to X.
Raises
------
Exception
Raised if the number of classes is different than 2
Returns
-------
self : QuanticClassifierBase instance
Expand Down Expand Up @@ -231,6 +237,12 @@ class QuanticSVM(QuanticClassifierBase):
-----
.. versionadded:: 0.0.1
Parameters
----------
gamma : float | None (default:None)
Used as input for sklearn rbf_kernel which is used internally.
See [3]_ for more information about gamma.
See Also
--------
QuanticClassifierBase
Expand All @@ -245,16 +257,23 @@ class QuanticSVM(QuanticClassifierBase):
Nature, vol. 567, no. 7747, pp. 209–212, Mar. 2019,
doi: 10.1038/s41586-019-0980-2.
.. [3] Available from: \
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.rbf_kernel.html
"""

def __init__(self, gamma=None, **parameters):
QuanticClassifierBase.__init__(self, **parameters)
self.gamma = gamma

def _init_algo(self, n_features):
# Although we do not train the classifier at this location
# training_input are required by Qiskit library.
self._log("SVM initiating algorithm")
if self.quantum:
classifier = QSVM(self._feature_map, self._training_input)
else:
classifier = SklearnSVM(self._training_input)
classifier = SklearnSVM(self._training_input, gamma=self.gamma)
return classifier

def predict_proba(self, X):
Expand Down Expand Up @@ -307,11 +326,12 @@ class QuanticVQC(QuanticClassifierBase):
Parameters
----------
q_account_token : string (default:None)
If quantum==True and q_account_token provided,
the classification task will be running on a IBM quantum backend
verbose : bool (default:True)
If true will output all intermediate results and logs
optimizer : Optimizer (default:SPSA)
The classical optimizer to use.
See [3] for details.
gen_var_form : Callable[int, QuantumCircuit | VariationalForm] \
(default: Callable[int, TwoLocal])
Function generating a variational form instance.
Notes
-----
Expand All @@ -321,6 +341,11 @@ class QuanticVQC(QuanticClassifierBase):
--------
QuanticClassifierBase
Raises
------
ValueError
Raised if ``quantum`` is False
References
----------
.. [1] H. Abraham et al., Qiskit:
Expand All @@ -332,22 +357,26 @@ class QuanticVQC(QuanticClassifierBase):
Nature, vol. 567, no. 7747, pp. 209–212, Mar. 2019,
doi: 10.1038/s41586-019-0980-2.
.. [3] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.aqua.algorithms.VQC.html
"""

def __init__(self, q_account_token=None,
verbose=True, **parameters):
QuanticClassifierBase.__init__(self,
q_account_token=q_account_token,
verbose=verbose)
def __init__(self, optimizer=get_spsa(), gen_var_form=gen_two_local(),
**parameters):
if "quantum" in parameters and not parameters["quantum"]:
raise ValueError("VQC can only run on a quantum \
computer or simulator.")
QuanticClassifierBase.__init__(self, **parameters)
self.optimizer = optimizer
self.gen_var_form = gen_var_form

def _init_algo(self, n_features):
self._log("VQC training...")
self._optimizer = SPSA(maxiter=40, c0=4.0, skip_calibration=True)
self._var_form = TwoLocal(n_features,
['ry', 'rz'], 'cz', reps=3)
var_form = self.gen_var_form(n_features)
# Although we do not train the classifier at this location
# training_input are required by Qiskit library.
vqc = VQC(self._optimizer, self._feature_map, self._var_form,
vqc = VQC(self.optimizer, self._feature_map, var_form,
self._training_input)
return vqc

Expand Down
123 changes: 123 additions & 0 deletions pyriemann_qiskit/utils/hyper_params_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from qiskit.circuit.library import ZZFeatureMap
from qiskit.aqua.components.optimizers import SPSA
from qiskit.circuit.library import TwoLocal


def gen_zz_feature_map(reps=2, entanglement='linear'):
Expand All @@ -22,6 +24,11 @@ def gen_zz_feature_map(reps=2, entanglement='linear'):
ret : ZZFeatureMap
An instance of ZZFeatureMap
Raises
------
ValueError
Raised if ``reps`` is lower than 1.
References
----------
.. [1] \
Expand All @@ -34,3 +41,119 @@ def gen_zz_feature_map(reps=2, entanglement='linear'):
return lambda n_features: ZZFeatureMap(feature_dimension=n_features,
reps=reps,
entanglement=entanglement)


# Valid gates for two local circuits
gates = ['ch', 'cx', 'cy', 'cz', 'crx', 'cry', 'crz',
'h', 'i', 'id', 'iden',
'rx', 'rxx', 'ry', 'ryy', 'rz', 'rzx', 'rzz',
's', 'sdg', 'swap',
'x', 'y', 'z', 't', 'tdg']


def _check_gates_in_blocks(blocks):
if isinstance(blocks, list):
for gate in blocks:
if gate not in gates:
raise ValueError("Gate %s is not a valid gate" % gate)
else:
if blocks not in gates:
raise ValueError("Gate %s is not a valid gate"
% blocks)


def gen_two_local(reps=3, rotation_blocks=['ry', 'rz'],
entanglement_blocks='cz'):
"""Return a callable that generate a TwoLocal circuit.
The two-local circuit is a parameterized circuit consisting
of alternating rotation layers and entanglement layers [1]_.
Parameters
----------
reps : int (default 3)
Specifies how often a block consisting of a rotation layer
and entanglement layer is repeated.
rotation_blocks : str | list[str]
The gates used in the rotation layer.
Valid string values are defined in `gates`.
entanglement_blocks : str | list[str]
The gates used in the entanglement layer.
Valid string values are defined in `gates`.
Returns
-------
ret : TwoLocal
An instance of a TwoLocal circuit
Raises
------
ValueError
Raised if ``rotation_blocks`` or ``entanglement_blocks`` contain
a non valid gate
References
----------
.. [1] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.circuit.library.TwoLocal.html
"""
if reps < 1:
raise ValueError("Parameter reps must be superior \
or equal to 1 (Got %d)" % reps)

_check_gates_in_blocks(rotation_blocks)

_check_gates_in_blocks(entanglement_blocks)

return lambda n_features: TwoLocal(n_features,
rotation_blocks,
entanglement_blocks, reps=reps)


def get_spsa(max_trials=40, c=(None, None, None, None, 4.0)):
"""Return an instance of SPSA.
SPSA [1, 2]_ is an algorithmic method for optimizing systems
with multiple unknown parameters.
For more details, see [3]_ and [4]_.
Parameters
----------
max_trials : int (default:40)
Maximum number of iterations to perform.
c : tuple[float | None] (default:(None, None, None, None, 4.0))
The 5 control parameters for SPSA algorithms.
See [3]_ for implementation details.
Auto calibration of SPSA will be skiped if one
of the parameters is different from None.
Returns
-------
ret : SPSA
An instance of SPSA
References
----------
.. [1] Spall, J. C. (2012), “Stochastic Optimization,”
in Handbook of Computational Statistics:
Concepts and Methods (2nd ed.)
(J. Gentle, W. Härdle, and Y. Mori, eds.),
Springer−Verlag, Heidelberg, Chapter 7, pp. 173–201.
dx.doi.org/10.1007/978-3-642-21551-3_7
.. [2] Spall, J. C. (1999), "Stochastic Optimization:
Stochastic Approximation and Simulated Annealing,"
in Encyclopedia of Electrical and Electronics Engineering
(J. G. Webster, ed.),
Wiley, New York, vol. 20, pp. 529–542
.. [3] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.aqua.components.optimizers.SPSA.html
.. [4] https://www.jhuapl.edu/SPSA/#Overview
"""
params = {}
for i in range(5):
if c[i] is not None:
params["c" + str(i)] = c[i]
if len(params) > 0:
params["skip_calibration"] = True
return SPSA(max_trials=max_trials, **params)
6 changes: 6 additions & 0 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import numpy as np
from pyriemann.classification import TangentSpace
from pyriemann.estimation import XdawnCovariances
Expand All @@ -17,6 +18,11 @@ def test_params(get_covmats, get_labels):
cross_val_score(clf, covset, labels, cv=skf, scoring='roc_auc')


def test_vqc_classical_should_return_value_error():
with pytest.raises(ValueError):
QuanticVQC(quantum=False)


def test_qsvm_init():
"""Test init of quantum classifiers"""
# if "classical" computation enable,
Expand Down
Loading

0 comments on commit 769605b

Please sign in to comment.