Skip to content

Commit

Permalink
Add shots and feature map exposure (#14)
Browse files Browse the repository at this point in the history
* hyperparams object?

* Revert "hyperparams object?"

This reverts commit 2ad611f.

* expose `shots` parameter

* expose `feature map`

* push factory

* add pooch in setup requirements

* Revert "add pooch in setup requirements"

This reverts commit 6501bc9.

* update doc requirements according to a3e4e650b83a2a9ccfa5eb9a8b4234d99577d00d

* add documentation for factory

* flake8

* add gen_zz_feature_map to API

* move import of gen_zz_feature_map to the end of import section

* Improve documentation for gen_zz_feature_map

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* install git in gh action

* approval missing

* Revert "approval missing"

This reverts commit 78b9677.

* Revert "install git in gh action"

This reverts commit f16bd5d.

* update pyriemann requirements?

* Revert "update pyriemann requirements?"

This reverts commit cd20619.

* upgrade pip in doc pipeline

* typo in ghpages

* Revert "typo in ghpages"

This reverts commit c771e53.

* improve ghpages pipeline

* install git

* fix requirement in setup.py
https://stackoverflow.com/questions/32688688/how-to-write-setup-py-to-include-a-git-repository-as-a-dependency

* missing -y option

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/utils/hyper_params_factory.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/utils/hyper_params_factory.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/utils/hyper_params_factory.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/utils/hyper_params_factory.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* fix typing of entanglement attribute

* - add tests
- correct and improve description of gen_zz_feature_map

* various refactoring and linting

* typo

* generalize -> string value -> value

* rename feature_dim -> n_features

* use pytest.raises(ValueError)

Co-authored-by: gcattan <gregoire.cattan@ibm.com>
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
  • Loading branch information
3 people authored Nov 25, 2021
1 parent 274d5b4 commit f945753
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 19 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/deploy_ghpages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ jobs:
with:
docs-folder: "doc/"
pre-build-command: |
python -m pip install --upgrade pip
apt-get update
apt-get -y install git-all
pip install -e .
pip install -r doc/doc-requirements.txt
pip install -r doc/requirements.txt
- name: Upload generated HTML as artifact
uses: actions/upload-artifact@v2
with:
Expand Down
18 changes: 17 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,20 @@ Classification

QuanticClassifierBase
QuanticSVM
QuanticVQC
QuanticVQC


Utils function
--------------

Utils functions are low level functions for the `classification` module.

Hyper-parameters generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _hyper_params_factory_api:
.. currentmodule:: pyriemann_qiskit.utils.hyper_params_factory

.. autosummary::
:toctree: generated/

gen_zz_feature_map
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ sphinx-gallery
sphinx-bootstrap_theme
numpydoc
cython
mne
mne[data]>=0.24
seaborn
scikit-learn
joblib
Expand Down
37 changes: 22 additions & 15 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Module for classification function."""
import numpy as np

from sklearn.base import BaseEstimator, ClassifierMixin

from qiskit import BasicAer, IBMQ
from qiskit.circuit.library import ZZFeatureMap, TwoLocal
from qiskit.circuit.library import TwoLocal
from qiskit.aqua import QuantumInstance, aqua_globals
from qiskit.aqua.quantum_instance import logger
from qiskit.aqua.algorithms import QSVM, SklearnSVM, VQC
Expand All @@ -13,6 +11,8 @@
from qiskit.aqua.components.optimizers import SPSA
from datetime import datetime
import logging
from .utils.hyper_params_factory import gen_zz_feature_map

logger.level = logging.INFO


Expand Down Expand Up @@ -45,6 +45,10 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
the classification task will be running on a IBM quantum backend
verbose : bool (default:True)
If true will output all intermediate results and logs
shots : int (default:1024)
Number of repetitions of each circuit, for sampling
gen_feature_map : Callable[int, QuantumCircuit | FeatureMap]
Function generating a feature map to encode data into a quantum state.
Notes
-----
Expand All @@ -68,11 +72,14 @@ class QuanticClassifierBase(BaseEstimator, ClassifierMixin):
"""

def __init__(self, quantum=True, q_account_token=None, verbose=True):
def __init__(self, quantum=True, q_account_token=None, verbose=True,
shots=1024, gen_feature_map=gen_zz_feature_map()):
self.verbose = verbose
self._log("Initializing Quantum Classifier")
self.q_account_token = q_account_token
self.quantum = quantum
self.shots = shots
self.gen_feature_map = gen_feature_map
# protected field for child classes
self._training_input = {}

Expand Down Expand Up @@ -146,15 +153,14 @@ def fit(self, X, y):
self._training_input[self.classes_[1]] = class1
self._training_input[self.classes_[0]] = class0

feature_dim = get_feature_dimension(self._training_input)
self._log("Feature dimension = ", feature_dim)
self._feature_map = ZZFeatureMap(feature_dimension=feature_dim, reps=2,
entanglement='linear')
n_features = get_feature_dimension(self._training_input)
self._log("Feature dimension = ", n_features)
self._feature_map = self.gen_feature_map(n_features)
if self.quantum:
if not hasattr(self, "_backend"):
def filters(device):
return (
device.configuration().n_qubits >= feature_dim
device.configuration().n_qubits >= n_features
and not device.configuration().simulator
and device.status().operational)
devices = self._provider.backends(filters=filters)
Expand All @@ -166,14 +172,15 @@ def filters(device):
self._log("Quantum backend = ", self._backend)
seed_sim = aqua_globals.random_seed
seed_trs = aqua_globals.random_seed
self._quantum_instance = QuantumInstance(self._backend, shots=1024,
self._quantum_instance = QuantumInstance(self._backend,
shots=self.shots,
seed_simulator=seed_sim,
seed_transpiler=seed_trs)
self._classifier = self._init_algo(feature_dim)
self._classifier = self._init_algo(n_features)
self._train(X, y)
return self

def _init_algo(self, feature_dim):
def _init_algo(self, n_features):
raise Exception("Init algo method was not implemented")

def _train(self, X, y):
Expand Down Expand Up @@ -240,7 +247,7 @@ class QuanticSVM(QuanticClassifierBase):
"""

def _init_algo(self, feature_dim):
def _init_algo(self, n_features):
# Although we do not train the classifier at this location
# training_input are required by Qiskit library.
self._log("SVM initiating algorithm")
Expand Down Expand Up @@ -333,10 +340,10 @@ def __init__(self, q_account_token=None,
q_account_token=q_account_token,
verbose=verbose)

def _init_algo(self, feature_dim):
def _init_algo(self, n_features):
self._log("VQC training...")
self._optimizer = SPSA(maxiter=40, c0=4.0, skip_calibration=True)
self._var_form = TwoLocal(feature_dim,
self._var_form = TwoLocal(n_features,
['ry', 'rz'], 'cz', reps=3)
# Although we do not train the classifier at this location
# training_input are required by Qiskit library.
Expand Down
5 changes: 5 additions & 0 deletions pyriemann_qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import hyper_params_factory

__all__ = [
'hyper_params_factory',
]
36 changes: 36 additions & 0 deletions pyriemann_qiskit/utils/hyper_params_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from qiskit.circuit.library import ZZFeatureMap


def gen_zz_feature_map(reps=2, entanglement='linear'):
"""Return a callable that generate a ZZFeatureMap.
A feature map encodes data into a quantum state.
A ZZFeatureMap is a second-order Pauli-Z evolution circuit.
Parameters
----------
reps : int (default 2)
The number of repeated circuits, greater or equal to 1.
entanglement : str | list[list[list[int]]] | \
Callable[int, list[list[list[int]]]]
Specifies the entanglement structure.
Entanglement structure can be provided with indices or string.
Possible string values are: 'full', 'linear', 'circular' and 'sca'.
Consult [1]_ for more details on entanglement structure.
Returns
-------
ret : ZZFeatureMap
An instance of ZZFeatureMap
References
----------
.. [1] \
https://qiskit.org/documentation/stable/0.19/stubs/qiskit.circuit.library.NLocal.html
"""
if reps < 1:
raise ValueError("Parameter reps must be superior \
or equal to 1 (Got %d)" % reps)

return lambda n_features: ZZFeatureMap(feature_dimension=n_features,
reps=reps,
entanglement=entanglement)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
},
platforms='any',
python_requires=">=3.6",
install_requires=['cython', 'pyriemann', 'qiskit==0.20.0', 'cvxpy==1.1.12'],
install_requires=['cython', 'pyriemann @ git+https://github.com/pyRiemann/pyRiemann#egg=pyriemann', 'qiskit==0.20.0', 'cvxpy==1.1.12'],
extras_require={'docs': ['sphinx-gallery', 'sphinx-bootstrap_theme', 'numpydoc', 'mne', 'seaborn'],
'tests': ['pytest', 'seaborn', 'flake8']},
zip_safe=False,
Expand Down
32 changes: 32 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,35 @@ def get_feats(rndstate):
def _gen_feat(n_samples, n_features):
return generate_feat(n_samples, n_features, rndstate)
return _gen_feat


def _get_linear_entanglement(n_qbits_in_block, n_features):
return [list(range(i, i + n_qbits_in_block))
for i in range(n_features - n_qbits_in_block + 1)]


def _get_pauli_z_rep_linear_entanglement(n_features):
num_qubits_by_block = [1, 2]
indices_by_block = []
for n in num_qubits_by_block:
linear = _get_linear_entanglement(n, n_features)
indices_by_block.append(linear)
return indices_by_block


@pytest.fixture
def get_pauli_z_linear_entangl_handle():
def _get_pauli_z_linear_entangl_handle(n_features):
indices = _get_pauli_z_rep_linear_entanglement(n_features)
return lambda _: [indices]

return _get_pauli_z_linear_entangl_handle


@pytest.fixture
def get_pauli_z_linear_entangl_idx():
def _get_pauli_z_linear_entangl_idx(reps, n_features):
indices = _get_pauli_z_rep_linear_entanglement(n_features)
return [indices for _ in range(reps)]

return _get_pauli_z_linear_entangl_idx
38 changes: 38 additions & 0 deletions tests/test_utils_hyper_params_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from pyriemann_qiskit.utils.hyper_params_factory import gen_zz_feature_map


@pytest.mark.parametrize(
'entanglement', ['full', 'linear', 'circular', 'sca']
)
def test_gen_zz_feature_map_entangl_strings(entanglement):
"""Test gen_zz_feature_map with different string options of entanglement"""
n_features = 2
feature_map = gen_zz_feature_map(entanglement=entanglement)(n_features)
assert isinstance(feature_map.parameters, set)


def test_gen_zz_feature_map_entangl_idx(get_pauli_z_linear_entangl_idx):
"""Test gen_zz_feature_map with valid indices value"""
n_features = 2
reps = 2
indices = get_pauli_z_linear_entangl_idx(reps, n_features)
feature_map_handle = gen_zz_feature_map(reps=reps, entanglement=indices)
feature_map = feature_map_handle(n_features)
assert isinstance(feature_map.parameters, set)


def test_gen_zz_feature_map_entangl_handle(get_pauli_z_linear_entangl_handle):
"""Test gen_zz_feature_map with a valid callable"""
n_features = 2
indices = get_pauli_z_linear_entangl_handle(n_features)
feature_map = gen_zz_feature_map(entanglement=indices)(n_features)
assert isinstance(feature_map.parameters, set)


def test_gen_zz_feature_map_entangl_invalid_value():
"""Test gen_zz_feature_map with uncorrect value"""
n_features = 2
feature_map = gen_zz_feature_map(entanglement="invalid")(n_features)
with pytest.raises(ValueError):
feature_map.parameters

0 comments on commit f945753

Please sign in to comment.