Skip to content

Commit

Permalink
Refactor how we generate default json resolver (#3532)
Browse files Browse the repository at this point in the history
This came out of #3516; removing two levels of indentation makes the formatting change less drastic, but IMO also makes the code more understandable by removing the singleton `_ResolverCache` class.
  • Loading branch information
maffoo committed Nov 25, 2020
1 parent 92e928e commit 47d3418
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 170 deletions.
317 changes: 153 additions & 164 deletions cirq/protocols/json_serialization.py
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
import json
import numbers
import pathlib
from typing import (
Any,
Callable,
cast,
Dict,
IO,
Expand All @@ -36,183 +38,169 @@
from typing_extensions import Protocol

from cirq._doc import doc_private
from cirq.ops import raw_types # Tells mypy that the raw_types module exists
from cirq.ops import raw_types
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq.ops.pauli_gates
import cirq.devices.unconstrained_device


class _ResolverCache:
"""Lazily import and build registry to avoid circular imports."""

def __init__(self):
self._crd = None

@property
def cirq_class_resolver_dictionary(self) -> Dict[str, Type]:
if self._crd is None:
import cirq
from cirq.devices.noise_model import _NoNoiseModel
from cirq.experiments import (CrossEntropyResult,
CrossEntropyResultDict,
GridInteractionLayer)
from cirq.experiments.grid_parallel_two_qubit_xeb import (
GridParallelXEBMetadata)
from cirq.google.devices.known_devices import (
_NamedConstantXmonDevice)

def _identity_operation_from_dict(qubits, **kwargs):
return cirq.identity_each(*qubits)

def single_qubit_matrix_gate(matrix):
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix, dtype=np.complex128)
return cirq.MatrixGate(matrix, qid_shape=(matrix.shape[0],))

def two_qubit_matrix_gate(matrix):
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix, dtype=np.complex128)
return cirq.MatrixGate(matrix, qid_shape=(2, 2))

self._crd = {
'AmplitudeDampingChannel': cirq.AmplitudeDampingChannel,
'AsymmetricDepolarizingChannel':
cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'ProductState': cirq.ProductState,
'CCNotPowGate': cirq.CCNotPowGate,
'CCXPowGate': cirq.CCXPowGate,
'CCZPowGate': cirq.CCZPowGate,
'CNotPowGate': cirq.CNotPowGate,
'Calibration': cirq.google.Calibration,
'CalibrationLayer': cirq.google.CalibrationLayer,
'CalibrationResult': cirq.google.CalibrationResult,
'CalibrationTag': cirq.google.CalibrationTag,
'ControlledGate': cirq.ControlledGate,
'ControlledOperation': cirq.ControlledOperation,
'CSwapGate': cirq.CSwapGate,
'CXPowGate': cirq.CXPowGate,
'CZPowGate': cirq.CZPowGate,
'CrossEntropyResult': CrossEntropyResult,
'CrossEntropyResultDict': CrossEntropyResultDict,
'Circuit': cirq.Circuit,
'CliffordState': cirq.CliffordState,
'CliffordTableau': cirq.CliffordTableau,
'DepolarizingChannel': cirq.DepolarizingChannel,
'ConstantQubitNoiseModel': cirq.ConstantQubitNoiseModel,
'Duration': cirq.Duration,
'FrozenCircuit': cirq.FrozenCircuit,
'FSimGate': cirq.FSimGate,
'DensePauliString': cirq.DensePauliString,
'MutableDensePauliString': cirq.MutableDensePauliString,
'MutablePauliString': cirq.MutablePauliString,
'GateOperation': cirq.GateOperation,
'GateTabulation': cirq.google.GateTabulation,
'GeneralizedAmplitudeDampingChannel':
cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
'GridInteractionLayer': GridInteractionLayer,
'GridParallelXEBMetadata': GridParallelXEBMetadata,
'GridQid': cirq.GridQid,
'GridQubit': cirq.GridQubit,
'HPowGate': cirq.HPowGate,
'ISwapPowGate': cirq.ISwapPowGate,
'IdentityGate': cirq.IdentityGate,
'IdentityOperation': _identity_operation_from_dict,
'InitObsSetting': cirq.work.InitObsSetting,
'LinearDict': cirq.LinearDict,
'LineQubit': cirq.LineQubit,
'LineQid': cirq.LineQid,
'MatrixGate': cirq.MatrixGate,
'MeasurementGate': cirq.MeasurementGate,
'_MeasurementSpec': cirq.work._MeasurementSpec,
'Moment': cirq.Moment,
'_XEigenState':
cirq.value.product_state._XEigenState, # type: ignore
'_YEigenState':
cirq.value.product_state._YEigenState, # type: ignore
'_ZEigenState':
cirq.value.product_state._ZEigenState, # type: ignore
'_NamedConstantXmonDevice': _NamedConstantXmonDevice,
'_NoNoiseModel': _NoNoiseModel,
'NamedQubit': cirq.NamedQubit,
'NamedQid': cirq.NamedQid,
'NoIdentifierQubit': cirq.testing.NoIdentifierQubit,
'_PauliX': cirq.ops.pauli_gates._PauliX,
'_PauliY': cirq.ops.pauli_gates._PauliY,
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'PasqalDevice': cirq.pasqal.PasqalDevice,
'PasqalVirtualDevice': cirq.pasqal.PasqalVirtualDevice,
'PauliString': cirq.PauliString,
'PhaseDampingChannel': cirq.PhaseDampingChannel,
'PhaseFlipChannel': cirq.PhaseFlipChannel,
'PhaseGradientGate': cirq.PhaseGradientGate,
'PhasedFSimGate': cirq.PhasedFSimGate,
'PhasedISwapPowGate': cirq.PhasedISwapPowGate,
'PhasedXPowGate': cirq.PhasedXPowGate,
'PhasedXZGate': cirq.PhasedXZGate,
'PhysicalZTag': cirq.google.PhysicalZTag,
'RandomGateChannel': cirq.RandomGateChannel,
'QuantumFourierTransformGate': cirq.QuantumFourierTransformGate,
'ResetChannel': cirq.ResetChannel,
'SingleQubitMatrixGate': single_qubit_matrix_gate,
'SingleQubitPauliStringGateOperation':
cirq.SingleQubitPauliStringGateOperation,
'SingleQubitReadoutCalibrationResult':
cirq.experiments.SingleQubitReadoutCalibrationResult,
'StabilizerStateChForm': cirq.StabilizerStateChForm,
'SwapPowGate': cirq.SwapPowGate,
'SycamoreGate': cirq.google.SycamoreGate,
'TaggedOperation': cirq.TaggedOperation,
'ThreeDQubit': cirq.pasqal.ThreeDQubit,
'Result': cirq.Result,
'TrialResult': cirq.TrialResult,
'TwoDQubit': cirq.pasqal.TwoDQubit,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
'TwoQubitDiagonalGate': cirq.TwoQubitDiagonalGate,
'_UnconstrainedDevice':
cirq.devices.unconstrained_device._UnconstrainedDevice,
'VirtualTag': cirq.VirtualTag,
'WaitGate': cirq.WaitGate,
'_QubitAsQid': raw_types._QubitAsQid,
'XPowGate': cirq.XPowGate,
'XXPowGate': cirq.XXPowGate,
'YPowGate': cirq.YPowGate,
'YYPowGate': cirq.YYPowGate,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,

# not a cirq class, but treated as one:
'pandas.DataFrame': pd.DataFrame,
'pandas.Index': pd.Index,
'pandas.MultiIndex': pd.MultiIndex.from_tuples,
'sympy.Symbol': sympy.Symbol,
'sympy.Add': lambda args: sympy.Add(*args),
'sympy.Mul': lambda args: sympy.Mul(*args),
'sympy.Pow': lambda args: sympy.Pow(*args),
'sympy.Float': lambda approx: sympy.Float(approx),
'sympy.Integer': sympy.Integer,
'sympy.Rational': sympy.Rational,
'complex': complex,
}
return self._crd


RESOLVER_CACHE = _ResolverCache()
ObjectFactory = Union[Type, Callable[..., Any]]


@functools.lru_cache(maxsize=1)
def _cirq_class_resolver_dictionary() -> Dict[str, ObjectFactory]:
import cirq
from cirq.devices.noise_model import _NoNoiseModel
from cirq.experiments import (CrossEntropyResult, CrossEntropyResultDict,
GridInteractionLayer)
from cirq.experiments.grid_parallel_two_qubit_xeb import (
GridParallelXEBMetadata)
from cirq.google.devices.known_devices import _NamedConstantXmonDevice

def _identity_operation_from_dict(qubits, **kwargs):
return cirq.identity_each(*qubits)

def single_qubit_matrix_gate(matrix):
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix, dtype=np.complex128)
return cirq.MatrixGate(matrix, qid_shape=(matrix.shape[0],))

def two_qubit_matrix_gate(matrix):
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix, dtype=np.complex128)
return cirq.MatrixGate(matrix, qid_shape=(2, 2))

return {
'AmplitudeDampingChannel': cirq.AmplitudeDampingChannel,
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'ProductState': cirq.ProductState,
'CCNotPowGate': cirq.CCNotPowGate,
'CCXPowGate': cirq.CCXPowGate,
'CCZPowGate': cirq.CCZPowGate,
'CNotPowGate': cirq.CNotPowGate,
'Calibration': cirq.google.Calibration,
'CalibrationLayer': cirq.google.CalibrationLayer,
'CalibrationResult': cirq.google.CalibrationResult,
'CalibrationTag': cirq.google.CalibrationTag,
'ControlledGate': cirq.ControlledGate,
'ControlledOperation': cirq.ControlledOperation,
'CSwapGate': cirq.CSwapGate,
'CXPowGate': cirq.CXPowGate,
'CZPowGate': cirq.CZPowGate,
'CrossEntropyResult': CrossEntropyResult,
'CrossEntropyResultDict': CrossEntropyResultDict,
'Circuit': cirq.Circuit,
'CliffordState': cirq.CliffordState,
'CliffordTableau': cirq.CliffordTableau,
'DepolarizingChannel': cirq.DepolarizingChannel,
'ConstantQubitNoiseModel': cirq.ConstantQubitNoiseModel,
'Duration': cirq.Duration,
'FrozenCircuit': cirq.FrozenCircuit,
'FSimGate': cirq.FSimGate,
'DensePauliString': cirq.DensePauliString,
'MutableDensePauliString': cirq.MutableDensePauliString,
'MutablePauliString': cirq.MutablePauliString,
'GateOperation': cirq.GateOperation,
'GateTabulation': cirq.google.GateTabulation,
'GeneralizedAmplitudeDampingChannel':
cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
'GridInteractionLayer': GridInteractionLayer,
'GridParallelXEBMetadata': GridParallelXEBMetadata,
'GridQid': cirq.GridQid,
'GridQubit': cirq.GridQubit,
'HPowGate': cirq.HPowGate,
'ISwapPowGate': cirq.ISwapPowGate,
'IdentityGate': cirq.IdentityGate,
'IdentityOperation': _identity_operation_from_dict,
'InitObsSetting': cirq.work.InitObsSetting,
'LinearDict': cirq.LinearDict,
'LineQubit': cirq.LineQubit,
'LineQid': cirq.LineQid,
'MatrixGate': cirq.MatrixGate,
'MeasurementGate': cirq.MeasurementGate,
'_MeasurementSpec': cirq.work._MeasurementSpec,
'Moment': cirq.Moment,
'_XEigenState': cirq.value.product_state._XEigenState,
'_YEigenState': cirq.value.product_state._YEigenState,
'_ZEigenState': cirq.value.product_state._ZEigenState,
'_NamedConstantXmonDevice': _NamedConstantXmonDevice,
'_NoNoiseModel': _NoNoiseModel,
'NamedQubit': cirq.NamedQubit,
'NamedQid': cirq.NamedQid,
'NoIdentifierQubit': cirq.testing.NoIdentifierQubit,
'_PauliX': cirq.ops.pauli_gates._PauliX,
'_PauliY': cirq.ops.pauli_gates._PauliY,
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'PasqalDevice': cirq.pasqal.PasqalDevice,
'PasqalVirtualDevice': cirq.pasqal.PasqalVirtualDevice,
'PauliString': cirq.PauliString,
'PhaseDampingChannel': cirq.PhaseDampingChannel,
'PhaseFlipChannel': cirq.PhaseFlipChannel,
'PhaseGradientGate': cirq.PhaseGradientGate,
'PhasedFSimGate': cirq.PhasedFSimGate,
'PhasedISwapPowGate': cirq.PhasedISwapPowGate,
'PhasedXPowGate': cirq.PhasedXPowGate,
'PhasedXZGate': cirq.PhasedXZGate,
'PhysicalZTag': cirq.google.PhysicalZTag,
'RandomGateChannel': cirq.RandomGateChannel,
'QuantumFourierTransformGate': cirq.QuantumFourierTransformGate,
'ResetChannel': cirq.ResetChannel,
'SingleQubitMatrixGate': single_qubit_matrix_gate,
'SingleQubitPauliStringGateOperation':
cirq.SingleQubitPauliStringGateOperation,
'SingleQubitReadoutCalibrationResult':
cirq.experiments.SingleQubitReadoutCalibrationResult,
'StabilizerStateChForm': cirq.StabilizerStateChForm,
'SwapPowGate': cirq.SwapPowGate,
'SycamoreGate': cirq.google.SycamoreGate,
'TaggedOperation': cirq.TaggedOperation,
'ThreeDQubit': cirq.pasqal.ThreeDQubit,
'Result': cirq.Result,
'TrialResult': cirq.TrialResult,
'TwoDQubit': cirq.pasqal.TwoDQubit,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
'TwoQubitDiagonalGate': cirq.TwoQubitDiagonalGate,
'_UnconstrainedDevice':
cirq.devices.unconstrained_device._UnconstrainedDevice,
'VirtualTag': cirq.VirtualTag,
'WaitGate': cirq.WaitGate,
'_QubitAsQid': raw_types._QubitAsQid,
'XPowGate': cirq.XPowGate,
'XXPowGate': cirq.XXPowGate,
'YPowGate': cirq.YPowGate,
'YYPowGate': cirq.YYPowGate,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,

# not a cirq class, but treated as one:
'pandas.DataFrame': pd.DataFrame,
'pandas.Index': pd.Index,
'pandas.MultiIndex': pd.MultiIndex.from_tuples,
'sympy.Symbol': sympy.Symbol,
'sympy.Add': lambda args: sympy.Add(*args),
'sympy.Mul': lambda args: sympy.Mul(*args),
'sympy.Pow': lambda args: sympy.Pow(*args),
'sympy.Float': lambda approx: sympy.Float(approx),
'sympy.Integer': sympy.Integer,
'sympy.Rational': sympy.Rational,
'complex': complex,
}


class JsonResolver(Protocol):
"""Protocol for json resolver functions passed to read_json."""

def __call__(self, cirq_type: str) -> Optional[Type]:
def __call__(self, cirq_type: str) -> Optional[ObjectFactory]:
...


def _cirq_class_resolver(cirq_type: str) -> Optional[Type]:
return RESOLVER_CACHE.cirq_class_resolver_dictionary.get(cirq_type, None)
def _cirq_class_resolver(cirq_type: str) -> Optional[ObjectFactory]:
return _cirq_class_resolver_dictionary().get(cirq_type, None)


DEFAULT_RESOLVERS: List[JsonResolver] = [
Expand Down Expand Up @@ -440,8 +428,9 @@ def _cirq_object_hook(d, resolvers: Sequence[JsonResolver]):
raise ValueError("Could not resolve type '{}' "
"during deserialization".format(d['cirq_type']))

if hasattr(cls, '_from_json_dict_'):
return cls._from_json_dict_(**d)
from_json_dict = getattr(cls, '_from_json_dict_', None)
if from_json_dict is not None:
return from_json_dict(**d)

del d['cirq_type']
return cls(**d)
Expand Down

0 comments on commit 47d3418

Please sign in to comment.