Skip to content

Commit

Permalink
Allow on_each for multi-qubit gates (#4281)
Browse files Browse the repository at this point in the history
Adds on_each support to the SupportsOnEachGate mixin for multi-qubit gates.

The handling here is not as flexible as for single-qubit gates, which allows any tree of gates and applies them depth-first. This allows the following two options for multi-qubit gates:

```
A: varargs form gate.on_each([q1, q2], [q3, q4])
B: explicit form gate.on_each([[q1, q2], [q3, q4]])
```


Discussion here, #4034 (comment). Part of #4236.
  • Loading branch information
daxfohl committed Jul 9, 2021
1 parent e2b4477 commit 4ec906d
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 14 deletions.
15 changes: 13 additions & 2 deletions cirq-core/cirq/ops/common_channels_test.py
Expand Up @@ -241,11 +241,22 @@ def test_deprecated_on_each_for_depolarizing_channel_one_qubit():


def test_deprecated_on_each_for_depolarizing_channel_two_qubits():
q0, q1 = cirq.LineQubit.range(2)
q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6)
op = cirq.DepolarizingChannel(p=0.1, n_qubits=2)

with pytest.raises(ValueError, match="one qubit"):
op.on_each([(q0, q1)])
op.on_each([(q0, q1), (q2, q3)])
op.on_each(zip([q0, q2, q4], [q1, q3, q5]))
op.on_each((q0, q1))
op.on_each([q0, q1])
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
op.on_each(q0, q1)
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
op.on_each([('bogus object 0', 'bogus object 1')])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
op.on_each(['01'])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
op.on_each([(False, None)])


def test_depolarizing_channel_apply_two_qubits():
Expand Down
37 changes: 28 additions & 9 deletions cirq-core/cirq/ops/gate_features.py
Expand Up @@ -18,7 +18,7 @@
"""

import abc
from typing import Union, Iterable, Any, List
from typing import Union, Iterable, Any, List, Sequence

from cirq.ops import raw_types

Expand All @@ -36,20 +36,39 @@ class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta):

def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]:
"""Returns a list of operations applying the gate to all targets.
Args:
*targets: The qubits to apply this gate to.
*targets: The qubits to apply this gate to. For single-qubit gates
this can be provided as varargs or a combination of nested
iterables. For multi-qubit gates this must be provided as an
`Iterable[Sequence[Qid]]`, where each sequence has `num_qubits`
qubits.
Returns:
Operations applying this gate to the target qubits.
Raises:
ValueError if targets are not instances of Qid or List[Qid].
ValueError if the gate operates on two or more Qids.
ValueError if targets are not instances of Qid or Iterable[Qid].
ValueError if the gate qubit number is incompatible.
"""
operations: List[raw_types.Operation] = []
if self._num_qubits_() > 1:
raise ValueError('This gate only supports on_each when it is a one qubit gate.')
operations = [] # type: List[raw_types.Operation]
iterator: Iterable = targets
if len(targets) == 1:
if not isinstance(targets[0], Iterable):
raise TypeError(f'{targets[0]} object is not iterable.')
t0 = list(targets[0])
iterator = [t0] if t0 and isinstance(t0[0], raw_types.Qid) else t0
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, raw_types.Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
return operations

for target in targets:
if isinstance(target, raw_types.Qid):
operations.append(self.on(target))
Expand Down
24 changes: 22 additions & 2 deletions cirq-core/cirq/ops/identity_test.py
Expand Up @@ -66,8 +66,28 @@ def test_identity_on_each_only_single_qubit():
cirq.IdentityGate(1, (3,)).on(q0_3),
cirq.IdentityGate(1, (3,)).on(q1_3),
]
with pytest.raises(ValueError, match='one qubit'):
cirq.IdentityGate(num_qubits=2).on_each(q0, q1)


def test_identity_on_each_two_qubits():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
q0_3, q1_3 = q0.with_dimension(3), q1.with_dimension(3)
assert cirq.IdentityGate(2).on_each([(q0, q1)]) == [cirq.IdentityGate(2)(q0, q1)]
assert cirq.IdentityGate(2).on_each([(q0, q1), (q2, q3)]) == [
cirq.IdentityGate(2)(q0, q1),
cirq.IdentityGate(2)(q2, q3),
]
assert cirq.IdentityGate(2, (3, 3)).on_each([(q0_3, q1_3)]) == [
cirq.IdentityGate(2, (3, 3))(q0_3, q1_3),
]
assert cirq.IdentityGate(2).on_each((q0, q1)) == [cirq.IdentityGate(2)(q0, q1)]
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
cirq.IdentityGate(2).on_each(q0, q1)
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
cirq.IdentityGate(2).on_each([[(q0, q1)]])
with pytest.raises(ValueError, match='Expected 2 qubits'):
cirq.IdentityGate(2).on_each([(q0,)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
cirq.IdentityGate(2).on_each([(q0, q1, q2)])


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
Expand Down
175 changes: 174 additions & 1 deletion cirq-core/cirq/ops/raw_types_test.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AbstractSet
from typing import AbstractSet, Iterator, Any

import pytest
import numpy as np
Expand Down Expand Up @@ -739,3 +739,176 @@ def qubits(self):
cirq.act_on(NoActOn()(q).with_tags("test"), args)
with pytest.raises(TypeError, match="Failed to act"):
cirq.act_on(MissingActOn().with_tags("test"), args)


def test_single_qubit_gate_validates_on_each():
class Dummy(cirq.SingleQubitGate):
def matrix(self):
pass

g = Dummy()
assert g.num_qubits() == 1

test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)]

_ = g.on_each(*test_qubits)
_ = g.on_each(test_qubits)

test_non_qubits = [str(i) for i in range(3)]
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)


def test_on_each():
class CustomGate(cirq.SingleQubitGate):
pass

a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
c = CustomGate()

assert c.on_each() == []
assert c.on_each(a) == [c(a)]
assert c.on_each(a, b) == [c(a), c(b)]
assert c.on_each(b, a) == [c(b), c(a)]

assert c.on_each([]) == []
assert c.on_each([a]) == [c(a)]
assert c.on_each([a, b]) == [c(a), c(b)]
assert c.on_each([b, a]) == [c(b), c(a)]
assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)]

with pytest.raises(ValueError):
c.on_each('abcd')
with pytest.raises(ValueError):
c.on_each(['abcd'])
with pytest.raises(ValueError):
c.on_each([a, 'abcd'])

qubit_iterator = (q for q in [a, b, a, b])
assert isinstance(qubit_iterator, Iterator)
assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)]


def test_on_each_two_qubits():
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.TwoQubitGate):
pass

a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
g = CustomGate()

assert g.on_each([]) == []
assert g.on_each([(a, b)]) == [g(a, b)]
assert g.on_each([[a, b]]) == [g(a, b)]
assert g.on_each([(b, a)]) == [g(b, a)]
assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)]
assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)]
assert g.on_each() == []
assert g.on_each((b, a)) == [g(b, a)]
assert g.on_each((a, b), (a, b)) == [g(a, b), g(a, b)]
assert g.on_each(*zip([a, b], [b, a])) == [g(a, b), g(b, a)]
with pytest.raises(TypeError, match='object is not iterable'):
g.on_each(a)
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
g.on_each(a, b)
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
g.on_each([12])
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
g.on_each([(a, b), 12])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([(a, b), [(a, b)]])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each([()])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each([(a,)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each([(a, b, a)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each(zip([a, a]))
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each(zip([a, a], [b, b], [a, a]))
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each('ab')
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each(('ab',))
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([('ab',)])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([(a, 'ab')])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([(a, 'b')])

qubit_iterator = (qs for qs in [[a, b], [a, b]])
assert isinstance(qubit_iterator, Iterator)
assert g.on_each(qubit_iterator) == [g(a, b), g(a, b)]


def test_on_each_three_qubits():
class CustomGate(cirq.ops.gate_features.SupportsOnEachGate, cirq.ThreeQubitGate):
pass

a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
c = cirq.NamedQubit('c')
g = CustomGate()

assert g.on_each([]) == []
assert g.on_each([(a, b, c)]) == [g(a, b, c)]
assert g.on_each([[a, b, c]]) == [g(a, b, c)]
assert g.on_each([(c, b, a)]) == [g(c, b, a)]
assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)]
assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
assert g.on_each() == []
assert g.on_each((c, b, a)) == [g(c, b, a)]
assert g.on_each((a, b, c), (c, b, a)) == [g(a, b, c), g(c, b, a)]
assert g.on_each(*zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
with pytest.raises(TypeError, match='object is not iterable'):
g.on_each(a)
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
g.on_each(a, b, c)
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
g.on_each([12])
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
g.on_each([(a, b, c), 12])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([(a, b, c), [(a, b, c)]])
with pytest.raises(ValueError, match='Expected 3 qubits'):
g.on_each([(a,)])
with pytest.raises(ValueError, match='Expected 3 qubits'):
g.on_each([(a, b)])
with pytest.raises(ValueError, match='Expected 3 qubits'):
g.on_each([(a, b, c, a)])
with pytest.raises(ValueError, match='Expected 3 qubits'):
g.on_each(zip([a, a], [b, b]))
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each('abc')
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each(('abc',))
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([('abc',)])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([(a, 'abc')])
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
g.on_each([(a, 'bc')])

qubit_iterator = (qs for qs in [[a, b, c], [a, b, c]])
assert isinstance(qubit_iterator, Iterator)
assert g.on_each(qubit_iterator) == [g(a, b, c), g(a, b, c)]


def test_on_each_iterable_qid():
class QidIter(cirq.Qid):
@property
def dimension(self) -> int:
return 2

def _comparison_key(self) -> Any:
return 1

def __iter__(self):
raise NotImplementedError()

assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())

0 comments on commit 4ec906d

Please sign in to comment.