Skip to content

Commit

Permalink
Bug/607 fidelity statevector kernel cannot be pickled (#778)
Browse files Browse the repository at this point in the history
* Made fidelity_statevector_kernel picklable

Added a new param to store cache size and a custom __getstate__  and __setstate__ to handle removing and re-initliasing the lru cache during pickle/unpickling respectively.

* updated notes

* name changes

* spell corrections

* Updated description

* Added unittest for pickling

* Spelling changes

* Making error messages clearer

* Spelling -_-

* Update releasenotes/notes/fix-fid_statevector_kernel-pickling-b7fa2b13a15ec9c6.yaml

Co-authored-by: Declan Millar <declan.millar@ibm.com>

* Update .gitignore

Co-authored-by: Declan Millar <declan.millar@ibm.com>

* Update test/kernels/test_fidelity_statevector_kernel.py

Co-authored-by: Declan Millar <declan.millar@ibm.com>

* Update qiskit_machine_learning/kernels/fidelity_statevector_kernel.py

Co-authored-by: Declan Millar <declan.millar@ibm.com>

* Update test/kernels/test_fidelity_statevector_kernel.py

Co-authored-by: Declan Millar <declan.millar@ibm.com>

* Update qiskit_machine_learning/kernels/fidelity_statevector_kernel.py

Co-authored-by: Declan Millar <declan.millar@ibm.com>

* Added Any class

---------

Co-authored-by: M. Emre Sahin <40424147+OkuyanBoga@users.noreply.github.com>
Co-authored-by: Declan Millar <declan.millar@ibm.com>
Co-authored-by: Anton Dekusar <62334182+adekusar-drl@users.noreply.github.com>
(cherry picked from commit c59063a)
  • Loading branch information
oscar-wallis authored and mergify[bot] committed Feb 29, 2024
1 parent 8e18006 commit 833f391
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 4 deletions.
15 changes: 12 additions & 3 deletions qiskit_machine_learning/kernels/fidelity_statevector_kernel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2023.
# (C) Copyright IBM 2023, 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -14,7 +14,7 @@
from __future__ import annotations

from functools import lru_cache
from typing import Type, TypeVar
from typing import Type, TypeVar, Any

import numpy as np

Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(
self._auto_clear_cache = auto_clear_cache
self._shots = shots
self._enforce_psd = enforce_psd

self._cache_size = cache_size
# Create the statevector cache at the instance level.
self._get_statevector = lru_cache(maxsize=cache_size)(self._get_statevector_)

Expand Down Expand Up @@ -160,3 +160,12 @@ def clear_cache(self):
"""Clear the statevector cache."""
# pylint: disable=no-member
self._get_statevector.cache_clear()

def __getstate__(self) -> dict[str, Any]:
kernel = dict(self.__dict__)
kernel["_get_statevector"] = None
return kernel

def __setstate__(self, kernel: dict[str, Any]):
self.__dict__ = kernel
self._get_statevector = lru_cache(maxsize=self._cache_size)(self._get_statevector_)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fixed a bug where :class:`.FidelityStatevectorKernel` threw an error when pickled.
52 changes: 51 additions & 1 deletion test/kernels/test_fidelity_statevector_kernel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2023.
# (C) Copyright IBM 2023, 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -15,6 +15,7 @@

import functools
import itertools
import pickle
import sys
import unittest

Expand Down Expand Up @@ -343,6 +344,55 @@ def test_properties(self):
self.assertEqual(qc, kernel.feature_map)
self.assertEqual(1, kernel.num_features)

def test_pickling(self):
"""Test that the kernel can be pickled correctly and without error."""
# Compares original kernel with copies made using pickle module and get & set state directly
qc = QuantumCircuit(1)
qc.ry(Parameter("w"), 0)
kernel1 = FidelityStatevectorKernel(feature_map=qc)

pickled_obj = pickle.dumps(kernel1)
kernel2 = pickle.loads(pickled_obj)

kernel3 = FidelityStatevectorKernel()
kernel3.__setstate__(kernel1.__getstate__())

with self.subTest("Pickle fail, kernels are not the same type"):
self.assertEqual(type(kernel1), type(kernel2))

with self.subTest("Pickle fail, kernels are not the same type"):
self.assertEqual(type(kernel1), type(kernel3))

with self.subTest("Pickle fail, kernels are not unique objects"):
self.assertNotEqual(kernel1, kernel2)

with self.subTest("Pickle fail, kernels are not unique objects"):
self.assertNotEqual(kernel1, kernel3)

with self.subTest("Pickle fail, caches are not the same type"):
self.assertEqual(type(kernel1._get_statevector), type(kernel2._get_statevector))

with self.subTest("Pickle fail, caches are not the same type"):
self.assertEqual(type(kernel1._get_statevector), type(kernel3._get_statevector))

# Remove cache to check dict properties are otherwise identical.
# - caches are never identical as they have different RAM locations.
kernel1.__dict__["_get_statevector"] = None
kernel2.__dict__["_get_statevector"] = None
kernel3.__dict__["_get_statevector"] = None

# Confirm changes were made.
with self.subTest("Pickle fail, caches have not been removed from kernels"):
self.assertEqual(kernel1._get_statevector, None)
self.assertEqual(kernel2._get_statevector, None)
self.assertEqual(kernel3._get_statevector, None)

with self.subTest("Pickle fail, properties of kernels (bar cache) are not identical"):
self.assertEqual(kernel1.__dict__, kernel2.__dict__)

with self.subTest("Pickle fail, properties of kernels (bar cache) are not identical"):
self.assertEqual(kernel1.__dict__, kernel3.__dict__)


@ddt
class TestStatevectorKernelDuplicates(QiskitMachineLearningTestCase):
Expand Down

0 comments on commit 833f391

Please sign in to comment.