diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index f5e429def99..92eba4da851 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -241,11 +241,22 @@ def test_deprecated_on_each_for_depolarizing_channel_one_qubit(): def test_deprecated_on_each_for_depolarizing_channel_two_qubits(): - q0, q1 = cirq.LineQubit.range(2) + q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6) op = cirq.DepolarizingChannel(p=0.1, n_qubits=2) - with pytest.raises(ValueError, match="one qubit"): + op.on_each([(q0, q1)]) + op.on_each([(q0, q1), (q2, q3)]) + op.on_each(zip([q0, q2, q4], [q1, q3, q5])) + op.on_each((q0, q1)) + op.on_each([q0, q1]) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): op.on_each(q0, q1) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + op.on_each([('bogus object 0', 'bogus object 1')]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + op.on_each(['01']) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + op.on_each([(False, None)]) def test_depolarizing_channel_apply_two_qubits(): diff --git a/cirq-core/cirq/ops/gate_features.py b/cirq-core/cirq/ops/gate_features.py index e977d770ce3..4b94c9bfc5c 100644 --- a/cirq-core/cirq/ops/gate_features.py +++ b/cirq-core/cirq/ops/gate_features.py @@ -18,7 +18,7 @@ """ import abc -from typing import Union, Iterable, Any, List +from typing import Union, Iterable, Any, List, Sequence from cirq.ops import raw_types @@ -36,20 +36,39 @@ class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta): def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]: """Returns a list of operations applying the gate to all targets. - Args: - *targets: The qubits to apply this gate to. - + *targets: The qubits to apply this gate to. For single-qubit gates + this can be provided as varargs or a combination of nested + iterables. For multi-qubit gates this must be provided as an + `Iterable[Sequence[Qid]]`, where each sequence has `num_qubits` + qubits. Returns: Operations applying this gate to the target qubits. - Raises: - ValueError if targets are not instances of Qid or List[Qid]. - ValueError if the gate operates on two or more Qids. + ValueError if targets are not instances of Qid or Iterable[Qid]. + ValueError if the gate qubit number is incompatible. """ + operations: List[raw_types.Operation] = [] if self._num_qubits_() > 1: - raise ValueError('This gate only supports on_each when it is a one qubit gate.') - operations = [] # type: List[raw_types.Operation] + iterator: Iterable = targets + if len(targets) == 1: + if not isinstance(targets[0], Iterable): + raise TypeError(f'{targets[0]} object is not iterable.') + t0 = list(targets[0]) + iterator = [t0] if t0 and isinstance(t0[0], raw_types.Qid) else t0 + for target in iterator: + if not isinstance(target, Sequence): + raise ValueError( + f'Inputs to multi-qubit gates must be Sequence[Qid].' + f' Type: {type(target)}' + ) + if not all(isinstance(x, raw_types.Qid) for x in target): + raise ValueError(f'All values in sequence should be Qids, but got {target}') + if len(target) != self._num_qubits_(): + raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}') + operations.append(self.on(*target)) + return operations + for target in targets: if isinstance(target, raw_types.Qid): operations.append(self.on(target)) diff --git a/cirq-core/cirq/ops/identity_test.py b/cirq-core/cirq/ops/identity_test.py index 2ec9191b741..9220d4dcb3c 100644 --- a/cirq-core/cirq/ops/identity_test.py +++ b/cirq-core/cirq/ops/identity_test.py @@ -66,8 +66,28 @@ def test_identity_on_each_only_single_qubit(): cirq.IdentityGate(1, (3,)).on(q0_3), cirq.IdentityGate(1, (3,)).on(q1_3), ] - with pytest.raises(ValueError, match='one qubit'): - cirq.IdentityGate(num_qubits=2).on_each(q0, q1) + + +def test_identity_on_each_two_qubits(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + q0_3, q1_3 = q0.with_dimension(3), q1.with_dimension(3) + assert cirq.IdentityGate(2).on_each([(q0, q1)]) == [cirq.IdentityGate(2)(q0, q1)] + assert cirq.IdentityGate(2).on_each([(q0, q1), (q2, q3)]) == [ + cirq.IdentityGate(2)(q0, q1), + cirq.IdentityGate(2)(q2, q3), + ] + assert cirq.IdentityGate(2, (3, 3)).on_each([(q0_3, q1_3)]) == [ + cirq.IdentityGate(2, (3, 3))(q0_3, q1_3), + ] + assert cirq.IdentityGate(2).on_each((q0, q1)) == [cirq.IdentityGate(2)(q0, q1)] + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + cirq.IdentityGate(2).on_each(q0, q1) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + cirq.IdentityGate(2).on_each([[(q0, q1)]]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + cirq.IdentityGate(2).on_each([(q0,)]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + cirq.IdentityGate(2).on_each([(q0, q1, q2)]) @pytest.mark.parametrize('num_qubits', [1, 2, 4]) diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index afb85664751..08cf6a9c910 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AbstractSet +from typing import AbstractSet, Iterator, Any import pytest import numpy as np @@ -739,3 +739,176 @@ def qubits(self): cirq.act_on(NoActOn()(q).with_tags("test"), args) with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(MissingActOn().with_tags("test"), args) + + +def test_single_qubit_gate_validates_on_each(): + class Dummy(cirq.SingleQubitGate): + def matrix(self): + pass + + g = Dummy() + assert g.num_qubits() == 1 + + test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)] + + _ = g.on_each(*test_qubits) + _ = g.on_each(test_qubits) + + test_non_qubits = [str(i) for i in range(3)] + with pytest.raises(ValueError): + _ = g.on_each(*test_non_qubits) + with pytest.raises(ValueError): + _ = g.on_each(*test_non_qubits) + + +def test_on_each(): + class CustomGate(cirq.SingleQubitGate): + pass + + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = CustomGate() + + assert c.on_each() == [] + assert c.on_each(a) == [c(a)] + assert c.on_each(a, b) == [c(a), c(b)] + assert c.on_each(b, a) == [c(b), c(a)] + + assert c.on_each([]) == [] + assert c.on_each([a]) == [c(a)] + assert c.on_each([a, b]) == [c(a), c(b)] + assert c.on_each([b, a]) == [c(b), c(a)] + assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)] + + with pytest.raises(ValueError): + c.on_each('abcd') + with pytest.raises(ValueError): + c.on_each(['abcd']) + with pytest.raises(ValueError): + c.on_each([a, 'abcd']) + + qubit_iterator = (q for q in [a, b, a, b]) + assert isinstance(qubit_iterator, Iterator) + assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)] + + +def test_on_each_two_qubits(): + class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.TwoQubitGate): + pass + + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + g = CustomGate() + + assert g.on_each([]) == [] + assert g.on_each([(a, b)]) == [g(a, b)] + assert g.on_each([[a, b]]) == [g(a, b)] + assert g.on_each([(b, a)]) == [g(b, a)] + assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)] + assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)] + assert g.on_each() == [] + assert g.on_each((b, a)) == [g(b, a)] + assert g.on_each((a, b), (a, b)) == [g(a, b), g(a, b)] + assert g.on_each(*zip([a, b], [b, a])) == [g(a, b), g(b, a)] + with pytest.raises(TypeError, match='object is not iterable'): + g.on_each(a) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each(a, b) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([12]) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([(a, b), 12]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, b), [(a, b)]]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each([()]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each([(a,)]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each([(a, b, a)]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each(zip([a, a])) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each(zip([a, a], [b, b], [a, a])) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each('ab') + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each(('ab',)) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([('ab',)]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'ab')]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'b')]) + + qubit_iterator = (qs for qs in [[a, b], [a, b]]) + assert isinstance(qubit_iterator, Iterator) + assert g.on_each(qubit_iterator) == [g(a, b), g(a, b)] + + +def test_on_each_three_qubits(): + class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.ThreeQubitGate): + pass + + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = cirq.NamedQubit('c') + g = CustomGate() + + assert g.on_each([]) == [] + assert g.on_each([(a, b, c)]) == [g(a, b, c)] + assert g.on_each([[a, b, c]]) == [g(a, b, c)] + assert g.on_each([(c, b, a)]) == [g(c, b, a)] + assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)] + assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)] + assert g.on_each() == [] + assert g.on_each((c, b, a)) == [g(c, b, a)] + assert g.on_each((a, b, c), (c, b, a)) == [g(a, b, c), g(c, b, a)] + assert g.on_each(*zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)] + with pytest.raises(TypeError, match='object is not iterable'): + g.on_each(a) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each(a, b, c) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([12]) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([(a, b, c), 12]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, b, c), [(a, b, c)]]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each([(a,)]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each([(a, b)]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each([(a, b, c, a)]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each(zip([a, a], [b, b])) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each('abc') + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each(('abc',)) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([('abc',)]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'abc')]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'bc')]) + + qubit_iterator = (qs for qs in [[a, b, c], [a, b, c]]) + assert isinstance(qubit_iterator, Iterator) + assert g.on_each(qubit_iterator) == [g(a, b, c), g(a, b, c)] + + +def test_on_each_iterable_qid(): + class QidIter(cirq.Qid): + @property + def dimension(self) -> int: + return 2 + + def _comparison_key(self) -> Any: + return 1 + + def __iter__(self): + raise NotImplementedError() + + assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())