diff --git a/cirq-core/cirq/interop/quirk/cells/input_rotation_cells.py b/cirq-core/cirq/interop/quirk/cells/input_rotation_cells.py index fd80d1f7f5d..53a8c2bd41d 100644 --- a/cirq-core/cirq/interop/quirk/cells/input_rotation_cells.py +++ b/cirq-core/cirq/interop/quirk/cells/input_rotation_cells.py @@ -142,15 +142,10 @@ def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'): sign_char = '-' if self.exponent_sign == -1 else '' symbols = list(sub_result.wire_symbols) symbols.extend(f'A{i}' for i in range(len(self.register))) - qubit_index = ( - len(self.base_operation.controls) - if isinstance(self.base_operation, ops.ControlledOperation) - else 0 - ) return cirq.CircuitDiagramInfo( tuple(symbols), exponent=f'({sign_char}A/2^{len(self.register)})', - exponent_qubit_index=qubit_index, + exponent_qubit_index=sub_result.exponent_qubit_index or 0, auto_exponent_parens=False, ) diff --git a/cirq-core/cirq/interop/quirk/cells/input_rotation_cells_test.py b/cirq-core/cirq/interop/quirk/cells/input_rotation_cells_test.py index d0de73c06fd..226e595ac85 100644 --- a/cirq-core/cirq/interop/quirk/cells/input_rotation_cells_test.py +++ b/cirq-core/cirq/interop/quirk/cells/input_rotation_cells_test.py @@ -64,9 +64,9 @@ def test_input_rotation_cells(): assert_url_to_circuit_returns( '{"cols":[["•","Z^(A/2^n)","inputA2"]]}', diagram=""" -0: ───@─────────── +0: ───@^(A/2^2)─── │ -1: ───Z^(A/2^2)─── +1: ───@─────────── │ 2: ───A0────────── │ diff --git a/cirq-core/cirq/interop/quirk/url_to_circuit_test.py b/cirq-core/cirq/interop/quirk/url_to_circuit_test.py index 0bc51c099bf..996a5d1225e 100644 --- a/cirq-core/cirq/interop/quirk/url_to_circuit_test.py +++ b/cirq-core/cirq/interop/quirk/url_to_circuit_test.py @@ -371,21 +371,21 @@ def test_completes_weight_zero_billion_laughs(): def test_example_qft_circuit(): qft_example_diagram = """ -0: ───×───────────────H───S───────T───────────Z─────────────────────Z────────────────────────────────Z──────────────────────────────────────────Z────────────────────────────────────────────────────Z────────────────────────────────────────────────────────────── - │ │ │ │ │ │ │ │ -1: ───┼───×───────────────@───H───┼───S───────┼─────────T───────────┼──────────Z─────────────────────┼─────────Z────────────────────────────────┼─────────Z──────────────────────────────────────────┼─────────Z──────────────────────────────────────────────────── - │ │ │ │ │ │ │ │ │ │ │ │ │ │ -2: ───┼───┼───×───────────────────@───@───H───┼─────────┼───S───────┼──────────┼─────────T───────────┼─────────┼──────────Z─────────────────────┼─────────┼─────────Z────────────────────────────────┼─────────┼─────────Z────────────────────────────────────────── - │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ -3: ───┼───┼───┼───×───────────────────────────@^(1/8)───@───@───H───┼──────────┼─────────┼───S───────┼─────────┼──────────┼─────────T───────────┼─────────┼─────────┼──────────Z─────────────────────┼─────────┼─────────┼─────────Z──────────────────────────────── - │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ -4: ───┼───┼───┼───×─────────────────────────────────────────────────@^(1/16)───@^(1/8)───@───@───H───┼─────────┼──────────┼─────────┼───S───────┼─────────┼─────────┼──────────┼─────────T───────────┼─────────┼─────────┼─────────┼──────────Z───────────────────── - │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ -5: ───┼───┼───×──────────────────────────────────────────────────────────────────────────────────────@^0.031───@^(1/16)───@^(1/8)───@───@───H───┼─────────┼─────────┼──────────┼─────────┼───S───────┼─────────┼─────────┼─────────┼──────────┼─────────T─────────── - │ │ │ │ │ │ │ │ │ │ │ │ │ │ -6: ───┼───×─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────@^0.016───@^0.031───@^(1/16)───@^(1/8)───@───@───H───┼─────────┼─────────┼─────────┼──────────┼─────────┼───S─────── - │ │ │ │ │ │ │ │ -7: ───×──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────@^0.008───@^0.016───@^0.031───@^(1/16)───@^(1/8)───@───@───H─── +0: ───×───────────────H───@───────────@────────────────────@──────────────────────────────@─────────────────────────────────────────@───────────────────────────────────────────────────@─────────────────────────────────────────────────────────────@─────────────────────────────────────────────────────────────────────── + │ │ │ │ │ │ │ │ +1: ───┼───×───────────────@^0.5───H───┼────────@───────────┼─────────@────────────────────┼──────────@──────────────────────────────┼─────────@─────────────────────────────────────────┼─────────@───────────────────────────────────────────────────┼─────────@───────────────────────────────────────────────────────────── + │ │ │ │ │ │ │ │ │ │ │ │ │ │ +2: ───┼───┼───×───────────────────────@^0.25───@^0.5───H───┼─────────┼────────@───────────┼──────────┼─────────@────────────────────┼─────────┼──────────@──────────────────────────────┼─────────┼─────────@─────────────────────────────────────────┼─────────┼─────────@─────────────────────────────────────────────────── + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ +3: ───┼───┼───┼───×────────────────────────────────────────@^(1/8)───@^0.25───@^0.5───H───┼──────────┼─────────┼────────@───────────┼─────────┼──────────┼─────────@────────────────────┼─────────┼─────────┼──────────@──────────────────────────────┼─────────┼─────────┼─────────@───────────────────────────────────────── + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ +4: ───┼───┼───┼───×───────────────────────────────────────────────────────────────────────@^(1/16)───@^(1/8)───@^0.25───@^0.5───H───┼─────────┼──────────┼─────────┼────────@───────────┼─────────┼─────────┼──────────┼─────────@────────────────────┼─────────┼─────────┼─────────┼──────────@────────────────────────────── + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ +5: ───┼───┼───×─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────@^0.031───@^(1/16)───@^(1/8)───@^0.25───@^0.5───H───┼─────────┼─────────┼──────────┼─────────┼────────@───────────┼─────────┼─────────┼─────────┼──────────┼─────────@──────────────────── + │ │ │ │ │ │ │ │ │ │ │ │ │ │ +6: ───┼───×─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────@^0.016───@^0.031───@^(1/16)───@^(1/8)───@^0.25───@^0.5───H───┼─────────┼─────────┼─────────┼──────────┼─────────┼────────@─────────── + │ │ │ │ │ │ │ │ +7: ───×───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────@^0.008───@^0.016───@^0.031───@^(1/16)───@^(1/8)───@^0.25───@^0.5───H─── """ qft_example_json = ( diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index e74cb35f308..596ebb8147d 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -485,7 +485,7 @@ def test_circuit_diagram(): class MockGate(cirq.testing.TwoQubitGate): def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: self.captured_diagram_args = args - return cirq.CircuitDiagramInfo(wire_symbols=tuple(['MOCK']), exponent=1, connected=True) + return cirq.CircuitDiagramInfo(wire_symbols=tuple(['M1', 'M2']), exponent=1, connected=True) def test_uninformed_circuit_diagram_info(): @@ -496,7 +496,7 @@ def test_uninformed_circuit_diagram_info(): args = cirq.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT assert cirq.circuit_diagram_info(cgate, args) == cirq.CircuitDiagramInfo( - wire_symbols=('@', 'MOCK'), exponent=1, connected=True + wire_symbols=('@', 'M1', 'M2'), exponent=1, connected=True, exponent_qubit_index=1 ) assert mock_gate.captured_diagram_args == args diff --git a/cirq-core/cirq/ops/controlled_operation.py b/cirq-core/cirq/ops/controlled_operation.py index 6e2a87c2b03..8e311e0febd 100644 --- a/cirq-core/cirq/ops/controlled_operation.py +++ b/cirq-core/cirq/ops/controlled_operation.py @@ -237,12 +237,19 @@ def get_symbol(vals): return f"({','.join(map(str, vals))})" wire_symbols = (*(get_symbol(vals) for vals in self.control_values), *sub_info.wire_symbols) + exponent_qubit_index = None + if sub_info.exponent_qubit_index is not None: + exponent_qubit_index = sub_info.exponent_qubit_index + len(self.control_values) + elif sub_info.exponent is not None: + # For a multi-qubit `sub_operation`, if the `exponent_qubit_index` is None, the qubit + # on which the exponent gets drawn in the controlled case (smallest ordered qubit of + # sub_operation) can be different from the uncontrolled case (lexicographically largest + # qubit of sub_operation). See tests for example. + exponent_qubit_index = len(self.control_values) return protocols.CircuitDiagramInfo( wire_symbols=wire_symbols, exponent=sub_info.exponent, - exponent_qubit_index=None - if sub_info.exponent_qubit_index is None - else sub_info.exponent_qubit_index + 1, + exponent_qubit_index=exponent_qubit_index, ) def _json_dict_(self) -> Dict[str, Any]: diff --git a/cirq-core/cirq/ops/controlled_operation_test.py b/cirq-core/cirq/ops/controlled_operation_test.py index 7c0a02a4a93..8ce6be78dc5 100644 --- a/cirq-core/cirq/ops/controlled_operation_test.py +++ b/cirq-core/cirq/ops/controlled_operation_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Union, Tuple, cast +import itertools import numpy as np import pytest @@ -247,11 +248,27 @@ def test_circuit_diagram(): class MockGate(cirq.testing.TwoQubitGate): + def __init__(self, exponent_qubit_index=None): + self._exponent_qubit_index = exponent_qubit_index + def _circuit_diagram_info_( self, args: protocols.CircuitDiagramInfoArgs ) -> protocols.CircuitDiagramInfo: self.captured_diagram_args = args - return cirq.CircuitDiagramInfo(wire_symbols=tuple(['MOCK']), exponent=1, connected=True) + return cirq.CircuitDiagramInfo( + wire_symbols=tuple(['M1', 'M2']), + exponent=1, + exponent_qubit_index=self._exponent_qubit_index, + connected=True, + ) + + +def test_controlled_diagram_exponent(): + for q in itertools.permutations(cirq.LineQubit.range(5)): + for idx in [None, 0, 1]: + op = MockGate(idx)(*q[:2]).controlled_by(*q[2:]) + add = 0 if idx is None else idx + assert cirq.circuit_diagram_info(op).exponent_qubit_index == len(q[2:]) + add def test_uninformed_circuit_diagram_info(): @@ -262,7 +279,7 @@ def test_uninformed_circuit_diagram_info(): args = protocols.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT assert cirq.circuit_diagram_info(c_op, args) == cirq.CircuitDiagramInfo( - wire_symbols=('@', 'MOCK'), exponent=1, connected=True + wire_symbols=('@', 'M1', 'M2'), exponent=1, connected=True, exponent_qubit_index=1 ) assert mock_gate.captured_diagram_args == args diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 1fafa322fab..5c2d160ddbe 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -19,6 +19,7 @@ AbstractSet, Any, cast, + Collection, Dict, FrozenSet, Iterable, @@ -351,5 +352,19 @@ def _equal_up_to_global_phase_( return False return protocols.equal_up_to_global_phase(self.gate, other.gate, atol=atol) + def controlled_by( + self, + *control_qubits: 'cirq.Qid', + control_values: Optional[Sequence[Union[int, Collection[int]]]] = None, + ) -> 'cirq.Operation': + if len(control_qubits) == 0: + return self + qubits = tuple(control_qubits) + return self._gate.controlled( + num_controls=len(qubits), + control_values=control_values, + control_qid_shape=tuple(q.dimension for q in qubits), + ).on(*(qubits + self._qubits)) + TV = TypeVar('TV', bound=raw_types.Gate) diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 98549bd51f5..7dd8b5eeb31 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -127,13 +127,13 @@ def test_gate(): def test_op(): - a, b, c = cirq.LineQubit.range(3) + a, b, c, d = cirq.LineQubit.range(4) g = ValiGate() - op = g(a) + op = g(a, b) assert op.controlled_by() is op - controlled_op = op.controlled_by(b, c) + controlled_op = op.controlled_by(c, d) assert controlled_op.sub_operation == op - assert controlled_op.controls == (b, c) + assert controlled_op.controls == (c, d) def test_op_validate(): diff --git a/cirq-core/cirq/ops/three_qubit_gates.py b/cirq-core/cirq/ops/three_qubit_gates.py index 12793a636e8..fc48391c4c8 100644 --- a/cirq-core/cirq/ops/three_qubit_gates.py +++ b/cirq-core/cirq/ops/three_qubit_gates.py @@ -14,7 +14,17 @@ """Common quantum gates that target three qubits.""" -from typing import AbstractSet, Any, List, Optional, Tuple, TYPE_CHECKING +from typing import ( + AbstractSet, + Any, + Collection, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import sympy @@ -30,6 +40,7 @@ pauli_gates, raw_types, swap_gates, + raw_types, ) if TYPE_CHECKING: @@ -169,6 +180,31 @@ def __str__(self) -> str: def _num_qubits_(self) -> int: return 3 + def controlled( + self, + num_controls: int = None, + control_values: Optional[Sequence[Union[int, Collection[int]]]] = None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + ) -> raw_types.Gate: + """Returns a controlled `ZPowGate` with two additional controls. + + The `controlled` method of the `Gate` class, of which this class is a + child, returns a `ControlledGate` with `sub_gate = self`. This method + overrides this behavior to return a `ControlledGate` with + `sub_gate = ZPowGate`. + """ + if num_controls == 0: + return self + return controlled_gate.ControlledGate( + controlled_gate.ControlledGate( + common_gates.ZPowGate(exponent=self._exponent, global_shift=self._global_shift), + num_controls=2, + ), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + @value.value_equality() class ThreeQubitDiagonalGate(raw_types.Gate): @@ -432,6 +468,31 @@ def __str__(self) -> str: def _num_qubits_(self) -> int: return 3 + def controlled( + self, + num_controls: int = None, + control_values: Optional[Sequence[Union[int, Collection[int]]]] = None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + ) -> raw_types.Gate: + """Returns a controlled `XPowGate` with two additional controls. + + The `controlled` method of the `Gate` class, of which this class is a + child, returns a `ControlledGate` with `sub_gate = self`. This method + overrides this behavior to return a `ControlledGate` with + `sub_gate = XPowGate`. + """ + if num_controls == 0: + return self + return controlled_gate.ControlledGate( + controlled_gate.ControlledGate( + common_gates.XPowGate(exponent=self._exponent, global_shift=self._global_shift), + num_controls=2, + ), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + @value.value_equality() class CSwapGate(gate_features.InterchangeableQubitsGate, raw_types.Gate): @@ -580,6 +641,28 @@ def __repr__(self) -> str: def _num_qubits_(self) -> int: return 3 + def controlled( + self, + num_controls: int = None, + control_values: Optional[Sequence[Union[int, Collection[int]]]] = None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + ) -> raw_types.Gate: + """Returns a controlled `SWAP` with one additional control. + + The `controlled` method of the `Gate` class, of which this class is a + child, returns a `ControlledGate` with `sub_gate = self`. This method + overrides this behavior to return a `ControlledGate` with + `sub_gate = SWAP`. + """ + if num_controls == 0: + return self + return controlled_gate.ControlledGate( + controlled_gate.ControlledGate(swap_gates.SWAP, num_controls=1), + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + CCZ = CCZPowGate() document( diff --git a/cirq-core/cirq/ops/three_qubit_gates_test.py b/cirq-core/cirq/ops/three_qubit_gates_test.py index 748bc9d301b..61d29a1a322 100644 --- a/cirq-core/cirq/ops/three_qubit_gates_test.py +++ b/cirq-core/cirq/ops/three_qubit_gates_test.py @@ -39,6 +39,8 @@ def test_eigen_gates_consistent_protocols(eigen_gate_type): (cirq.CSWAP, False), (cirq.ThreeQubitDiagonalGate([2, 3, 5, 7, 11, 13, 17, 19]), True), (cirq.ThreeQubitDiagonalGate([0, 0, 0, 0, 0, 0, 0, 0]), True), + (cirq.CCX, False), + (cirq.CCZ, False), ), ) def test_consistent_protocols(gate, ignoring_global_phase): @@ -320,3 +322,10 @@ def test_resolve(resolve_fn): diagonal_gate = resolve_fn(diagonal_gate, {'b': 19}) assert diagonal_gate == cirq.ThreeQubitDiagonalGate(diagonal_angles) assert not cirq.is_parameterized(diagonal_gate) + + +@pytest.mark.parametrize('gate', [cirq.CCX, cirq.CCZ, cirq.CSWAP]) +def test_controlled_ops_consistency(gate): + a, b, c, d = cirq.LineQubit.range(4) + assert gate.controlled(0) is gate + assert gate(a, b, c).controlled_by(d) == gate(d, b, c).controlled_by(a) diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 387249245de..bb4bfe1278c 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -36,6 +36,10 @@ assert_phase_by_is_consistent_with_unitary, ) +from cirq.testing.consistent_controlled_gate_op import ( + assert_controlled_and_controlled_by_identical, +) + from cirq.testing.consistent_decomposition import ( assert_decompose_is_consistent_with_unitary, ) diff --git a/cirq-core/cirq/testing/consistent_controlled_gate_op.py b/cirq-core/cirq/testing/consistent_controlled_gate_op.py new file mode 100644 index 00000000000..1d63358d76e --- /dev/null +++ b/cirq-core/cirq/testing/consistent_controlled_gate_op.py @@ -0,0 +1,56 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Optional, Union, Collection + +from cirq import protocols, devices, ops + + +def assert_controlled_and_controlled_by_identical( + gate: ops.Gate, + *, + num_controls: Sequence[int] = (2, 1, 3, 10), + control_values: Optional[Sequence[Optional[Sequence[Union[int, Collection[int]]]]]] = None, +) -> None: + """Checks that gate.on().controlled_by() == gate.controlled().on()""" + if control_values is not None: + if len(num_controls) != len(control_values): + raise ValueError(f"len(num_controls) != len(control_values)") + for i, num_control in enumerate(num_controls): + control_value = control_values[i] if control_values else None + if control_value is not None and len(control_value) != num_control: + raise ValueError(f"len(control_values[{i}]) != num_controls[{i}]") + _assert_gate_consistent(gate, num_control, control_value) + + +def _assert_gate_consistent( + gate: ops.Gate, + num_controls: int, + control_values: Optional[Sequence[Union[int, Collection[int]]]], +) -> None: + if isinstance(gate, ops.DensePauliString) and protocols.is_parameterized(gate): + # Parameterized `DensePauliString`s cannot be applied to qubits to produce valid operations. + # TODO: This behavior should be fixed (https://github.com/quantumlib/Cirq/issues/4508) + return None + gate_controlled = gate.controlled(num_controls, control_values) + qubits = devices.LineQid.for_gate(gate_controlled) + control_qubits = qubits[:num_controls] + gate_qubits = qubits[num_controls:] + gate_controlled_on = gate_controlled.on(*control_qubits, *gate_qubits) + gate_on_controlled_by = gate.on(*gate_qubits).controlled_by( + *control_qubits, control_values=control_values + ) + assert ( + gate_controlled_on == gate_on_controlled_by + ), "gate.controlled().on() and gate.on().controlled() should return the same operations." diff --git a/cirq-core/cirq/testing/consistent_controlled_gate_op_test.py b/cirq-core/cirq/testing/consistent_controlled_gate_op_test.py new file mode 100644 index 00000000000..ade2a64c8b7 --- /dev/null +++ b/cirq-core/cirq/testing/consistent_controlled_gate_op_test.py @@ -0,0 +1,88 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Union, Collection, Tuple, List + +import pytest + +import numpy as np + +import cirq + + +class GoodGate(cirq.EigenGate, cirq.SingleQubitGate): + def _eigen_components(self) -> List[Tuple[float, np.ndarray]]: + # coverage: ignore + return [ + (0, np.diag([1, 0])), + (1, np.diag([0, 1])), + ] + + +class BadGateOperation(cirq.GateOperation): + def controlled_by( + self, + *control_qubits: 'cirq.Qid', + control_values: Optional[Sequence[Union[int, Collection[int]]]] = None, + ) -> 'cirq.Operation': + return cirq.ControlledOperation(control_qubits, self, control_values) + + +class BadGate(cirq.EigenGate, cirq.SingleQubitGate): + def _eigen_components(self) -> List[Tuple[float, np.ndarray]]: + # coverage: ignore + return [ + (0, np.diag([1, 0])), + (1, np.diag([0, 1])), + ] + + def on(self, *qubits: 'cirq.Qid') -> 'cirq.Operation': + return BadGateOperation(self, list(qubits)) + + def controlled( + self, + num_controls: int = None, + control_values: Optional[Sequence[Union[int, Collection[int]]]] = None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + ) -> 'cirq.Gate': + ret = super().controlled(num_controls, control_values, control_qid_shape) + if num_controls == 1 and control_values is None: + return cirq.CZPowGate(exponent=self._exponent, global_shift=self._global_shift) + return ret + + +def test_assert_controlled_and_controlled_by_identical(): + cirq.testing.assert_controlled_and_controlled_by_identical(GoodGate()) + + with pytest.raises(AssertionError): + cirq.testing.assert_controlled_and_controlled_by_identical(BadGate()) + + with pytest.raises(ValueError, match=r'len\(num_controls\) != len\(control_values\)'): + cirq.testing.assert_controlled_and_controlled_by_identical( + GoodGate(), num_controls=[1, 2], control_values=[(1,)] + ) + + with pytest.raises(ValueError, match=r'len\(control_values\[1\]\) != num_controls\[1\]'): + cirq.testing.assert_controlled_and_controlled_by_identical( + GoodGate(), + num_controls=[1, 2], + control_values=[ + (1,), + ( + 1, + 1, + 1, + ), + ], + ) diff --git a/cirq-core/cirq/testing/consistent_protocols.py b/cirq-core/cirq/testing/consistent_protocols.py index cfa208c5067..28ed4db47ec 100644 --- a/cirq-core/cirq/testing/consistent_protocols.py +++ b/cirq-core/cirq/testing/consistent_protocols.py @@ -43,6 +43,7 @@ assert_specifies_has_unitary_if_unitary, ) from cirq.testing.equivalent_repr_eval import assert_equivalent_repr +from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical def assert_implements_consistent_protocols( @@ -160,6 +161,8 @@ def _assert_meets_standards_helper( ) if isinstance(val, ops.EigenGate): assert_eigen_shifts_is_consistent_with_eigen_components(val) + if isinstance(val, ops.Gate): + assert_controlled_and_controlled_by_identical(val) def assert_commutes_magic_method_consistent_with_unitaries( diff --git a/cirq-core/cirq/testing/consistent_protocols_test.py b/cirq-core/cirq/testing/consistent_protocols_test.py index fd561b1e9e3..a8ef6e8f655 100644 --- a/cirq-core/cirq/testing/consistent_protocols_test.py +++ b/cirq-core/cirq/testing/consistent_protocols_test.py @@ -22,6 +22,7 @@ import cirq from cirq._compat import proper_repr from cirq.type_workarounds import NotImplementedType +import cirq.testing.consistent_controlled_gate_op_test as controlled_gate_op_test class GoodGate(cirq.SingleQubitGate): @@ -253,6 +254,9 @@ def test_assert_implements_consistent_protocols(): BadGateRepr(phase_exponent=0.25), global_vals={'BadGateRepr': BadGateRepr} ) + with pytest.raises(AssertionError): + cirq.testing.assert_implements_consistent_protocols(controlled_gate_op_test.BadGate()) + def test_assert_eigengate_implements_consistent_protocols(): cirq.testing.assert_eigengate_implements_consistent_protocols(