From d136c765ba9fc426aa9a08806f26b08e6e7b9a3c Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Thu, 12 May 2022 04:15:43 +0000 Subject: [PATCH] Add json serialization to diagonal gates --- cirq-core/cirq/json_resolver_cache.py | 5 ++++- cirq-core/cirq/ops/diagonal_gate.py | 20 ++++++++++++++++++- cirq-core/cirq/ops/diagonal_gate_test.py | 4 ++++ cirq-core/cirq/ops/three_qubit_gates.py | 12 +++++++++-- cirq-core/cirq/ops/three_qubit_gates_test.py | 6 ++++++ cirq-core/cirq/ops/two_qubit_diagonal_gate.py | 9 ++++++++- .../cirq/ops/two_qubit_diagonal_gate_test.py | 4 ++++ .../json_test_data/DiagonalGate.json | 9 +++++++++ .../json_test_data/DiagonalGate.repr | 1 + .../ThreeQubitDiagonalGate.json | 13 ++++++++++++ .../ThreeQubitDiagonalGate.repr | 1 + .../json_test_data/TwoQubitDiagonalGate.json | 9 +++++++++ .../json_test_data/TwoQubitDiagonalGate.repr | 1 + .../cirq/protocols/json_test_data/spec.py | 3 --- 14 files changed, 89 insertions(+), 8 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/DiagonalGate.json create mode 100644 cirq-core/cirq/protocols/json_test_data/DiagonalGate.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.json create mode 100644 cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.json create mode 100644 cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.repr diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index c4a6765c764..6b6ece07f5c 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -97,6 +97,7 @@ def _symmetricalqidpair(qids): 'CXPowGate': cirq.CXPowGate, 'CZPowGate': cirq.CZPowGate, 'CZTargetGateset': cirq.CZTargetGateset, + 'DiagonalGate': cirq.DiagonalGate, 'DensePauliString': cirq.DensePauliString, 'DepolarizingChannel': cirq.DepolarizingChannel, 'DeviceMetadata': cirq.DeviceMetadata, @@ -164,7 +165,6 @@ def _symmetricalqidpair(qids): 'QuantumFourierTransformGate': cirq.QuantumFourierTransformGate, 'QubitPermutationGate': cirq.QubitPermutationGate, 'RandomGateChannel': cirq.RandomGateChannel, - 'TensoredConfusionMatrices': cirq.TensoredConfusionMatrices, 'RepetitionsStoppingCriteria': cirq.work.RepetitionsStoppingCriteria, 'ResetChannel': cirq.ResetChannel, 'Result': cirq.ResultDict, # Keep support for Cirq < 0.14. @@ -181,8 +181,11 @@ def _symmetricalqidpair(qids): 'SwapPowGate': cirq.SwapPowGate, 'SympyCondition': cirq.SympyCondition, 'TaggedOperation': cirq.TaggedOperation, + 'TensoredConfusionMatrices': cirq.TensoredConfusionMatrices, 'TiltedSquareLattice': cirq.TiltedSquareLattice, + 'ThreeQubitDiagonalGate': cirq.ThreeQubitDiagonalGate, 'TrialResult': cirq.ResultDict, # keep support for Cirq < 0.11. + 'TwoQubitDiagonalGate': cirq.TwoQubitDiagonalGate, 'TwoQubitGateTabulation': cirq.TwoQubitGateTabulation, '_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice, 'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria, diff --git a/cirq-core/cirq/ops/diagonal_gate.py b/cirq-core/cirq/ops/diagonal_gate.py index b01bf84c754..ab92eea80bb 100644 --- a/cirq-core/cirq/ops/diagonal_gate.py +++ b/cirq-core/cirq/ops/diagonal_gate.py @@ -18,7 +18,18 @@ passed as a list. """ -from typing import AbstractSet, Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import ( + AbstractSet, + Any, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import sympy @@ -79,6 +90,10 @@ def __init__(self, diag_angles_radians: Sequence['cirq.TParamVal']) -> None: """ self._diag_angles_radians: Tuple['cirq.TParamVal', ...] = tuple(diag_angles_radians) + @property + def diag_angles_radians(self) -> Tuple['cirq.TParamVal', ...]: + return self._diag_angles_radians + def _num_qubits_(self): return int(np.log2(len(self._diag_angles_radians))) @@ -190,6 +205,9 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits)) return decomposed_circ + def _json_dict_(self) -> Dict[str, Any]: + return protocols.obj_to_dict_helper(self, attribute_names=["diag_angles_radians"]) + def __repr__(self) -> str: return 'cirq.DiagonalGate([{}])'.format( ','.join(proper_repr(angle) for angle in self._diag_angles_radians) diff --git a/cirq-core/cirq/ops/diagonal_gate_test.py b/cirq-core/cirq/ops/diagonal_gate_test.py index 4d9d599b6a5..9bb987dab62 100644 --- a/cirq-core/cirq/ops/diagonal_gate_test.py +++ b/cirq-core/cirq/ops/diagonal_gate_test.py @@ -39,6 +39,10 @@ def test_consistent_protocols(gate): cirq.testing.assert_implements_consistent_protocols(gate) +def test_property(): + assert cirq.DiagonalGate([2, 3, 5, 7]).diag_angles_radians == (2, 3, 5, 7) + + @pytest.mark.parametrize('n', [1, 2, 3, 4, 5, 6, 7, 8, 9]) def test_decomposition_unitary(n): diagonal_angles = np.random.randn(2**n) diff --git a/cirq-core/cirq/ops/three_qubit_gates.py b/cirq-core/cirq/ops/three_qubit_gates.py index 2be5ef92acb..16da10c20ad 100644 --- a/cirq-core/cirq/ops/three_qubit_gates.py +++ b/cirq-core/cirq/ops/three_qubit_gates.py @@ -18,6 +18,7 @@ AbstractSet, Any, Collection, + Dict, List, Optional, Sequence, @@ -204,7 +205,7 @@ def controlled( class ThreeQubitDiagonalGate(raw_types.Gate): """A gate given by a diagonal 8x8 matrix.""" - def __init__(self, diag_angles_radians: List[value.TParamVal]) -> None: + def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None: r"""A three qubit gate with only diagonal elements. This gate's off-diagonal elements are zero and it's on diagonal @@ -215,7 +216,11 @@ def __init__(self, diag_angles_radians: List[value.TParamVal]) -> None: If these values are $(x_0, x_1, \ldots , x_7)$ then the unitary has diagonal values $(e^{i x_0}, e^{i x_1}, \ldots, e^{i x_7})$. """ - self._diag_angles_radians: List[value.TParamVal] = diag_angles_radians + self._diag_angles_radians: Tuple[value.TParamVal, ...] = tuple(diag_angles_radians) + + @property + def diag_angles_radians(self) -> Tuple[value.TParamVal, ...]: + return self._diag_angles_radians def _is_parameterized_(self) -> bool: return any(protocols.is_parameterized(angle) for angle in self._diag_angles_radians) @@ -350,6 +355,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]: } ) + def _json_dict_(self) -> Dict[str, Any]: + return protocols.obj_to_dict_helper(self, attribute_names=["diag_angles_radians"]) + def __repr__(self) -> str: return 'cirq.ThreeQubitDiagonalGate([{}])'.format( ','.join(proper_repr(angle) for angle in self._diag_angles_radians) diff --git a/cirq-core/cirq/ops/three_qubit_gates_test.py b/cirq-core/cirq/ops/three_qubit_gates_test.py index b900a7fb599..c8660cb2f26 100644 --- a/cirq-core/cirq/ops/three_qubit_gates_test.py +++ b/cirq-core/cirq/ops/three_qubit_gates_test.py @@ -206,6 +206,12 @@ def test_decomposition_cost(op: cirq.Operation, max_two_cost: int): assert two_cost == max_two_cost +def test_diagonal_gate_property(): + assert cirq.ThreeQubitDiagonalGate([2, 3, 5, 7, 0, 0, 0, 1]).diag_angles_radians == ( + (2, 3, 5, 7, 0, 0, 0, 1) + ) + + @pytest.mark.parametrize( 'gate', [cirq.CCX, cirq.CSWAP, cirq.CCZ, cirq.ThreeQubitDiagonalGate([2, 3, 5, 7, 11, 13, 17, 19])], diff --git a/cirq-core/cirq/ops/two_qubit_diagonal_gate.py b/cirq-core/cirq/ops/two_qubit_diagonal_gate.py index 8220d05c0d0..75c74655a93 100644 --- a/cirq-core/cirq/ops/two_qubit_diagonal_gate.py +++ b/cirq-core/cirq/ops/two_qubit_diagonal_gate.py @@ -17,7 +17,7 @@ passed as a list. """ -from typing import AbstractSet, Any, Tuple, Optional, Sequence, TYPE_CHECKING +from typing import AbstractSet, Any, Dict, Tuple, Optional, Sequence, TYPE_CHECKING import numpy as np import sympy @@ -46,6 +46,10 @@ def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None: """ self._diag_angles_radians: Tuple[value.TParamVal, ...] = tuple(diag_angles_radians) + @property + def diag_angles_radians(self) -> Tuple[value.TParamVal, ...]: + return self._diag_angles_radians + def _num_qubits_(self) -> int: return 2 @@ -118,6 +122,9 @@ def __repr__(self) -> str: ','.join(proper_repr(angle) for angle in self._diag_angles_radians) ) + def _json_dict_(self) -> Dict[str, Any]: + return protocols.obj_to_dict_helper(self, attribute_names=["diag_angles_radians"]) + def _quil_( self, qubits: Tuple['cirq.Qid', ...], formatter: 'cirq.QuilFormatter' ) -> Optional[str]: diff --git a/cirq-core/cirq/ops/two_qubit_diagonal_gate_test.py b/cirq-core/cirq/ops/two_qubit_diagonal_gate_test.py index f72550174de..405b0fcd029 100644 --- a/cirq-core/cirq/ops/two_qubit_diagonal_gate_test.py +++ b/cirq-core/cirq/ops/two_qubit_diagonal_gate_test.py @@ -34,6 +34,10 @@ def test_consistent_protocols(gate): cirq.testing.assert_implements_consistent_protocols(gate) +def test_property(): + assert cirq.TwoQubitDiagonalGate([2, 3, 5, 7]).diag_angles_radians == (2, 3, 5, 7) + + def test_parameterized_decompose(): angles = sympy.symbols('x0, x1, x2, x3') parameterized_op = cirq.TwoQubitDiagonalGate(angles).on(*cirq.LineQubit.range(2)) diff --git a/cirq-core/cirq/protocols/json_test_data/DiagonalGate.json b/cirq-core/cirq/protocols/json_test_data/DiagonalGate.json new file mode 100644 index 00000000000..96ab1e498c4 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/DiagonalGate.json @@ -0,0 +1,9 @@ +{ + "cirq_type": "DiagonalGate", + "diag_angles_radians": [ + 0.0, + 1.0, + -1.0, + 0.0 + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/DiagonalGate.repr b/cirq-core/cirq/protocols/json_test_data/DiagonalGate.repr new file mode 100644 index 00000000000..1cedb5597ca --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/DiagonalGate.repr @@ -0,0 +1 @@ +cirq.DiagonalGate(diag_angles_radians=[0.0, 1.0, -1.0, 0.0]) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.json b/cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.json new file mode 100644 index 00000000000..24f07e9af08 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.json @@ -0,0 +1,13 @@ +{ + "cirq_type": "ThreeQubitDiagonalGate", + "diag_angles_radians": [ + 0.0, + 1.0, + -1.0, + 0.0, + 0.5, + 0.5, + 0.5, + 0.5 + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.repr b/cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.repr new file mode 100644 index 00000000000..8af9dec393d --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ThreeQubitDiagonalGate.repr @@ -0,0 +1 @@ +cirq.ThreeQubitDiagonalGate(diag_angles_radians=[0.0, 1.0, -1.0, 0.0, 0.5, 0.5, 0.5, 0.5]) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.json b/cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.json new file mode 100644 index 00000000000..5130515f1c4 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.json @@ -0,0 +1,9 @@ +{ + "cirq_type": "TwoQubitDiagonalGate", + "diag_angles_radians": [ + 0.0, + 1.0, + -1.0, + 0.0 + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.repr b/cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.repr new file mode 100644 index 00000000000..1cf76594eb5 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/TwoQubitDiagonalGate.repr @@ -0,0 +1 @@ +cirq.TwoQubitDiagonalGate(diag_angles_radians=[0.0, 1.0, -1.0, 0.0]) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index ba93126fb14..a1dee019724 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -44,7 +44,6 @@ 'LinearCombinationOfOperations', 'Linspace', 'ListSweep', - 'DiagonalGate', 'NeutralAtomDevice', 'PauliInteractionGate', 'PauliSum', @@ -66,9 +65,7 @@ 'SparseSimulatorStep', 'StateVectorMixin', 'TextDiagramDrawer', - 'ThreeQubitDiagonalGate', 'Timestamp', - 'TwoQubitDiagonalGate', 'TwoQubitGateTabulationResult', 'UnitSweep', 'StateVectorSimulatorState',