Skip to content

Commit

Permalink
Add support for indexing circuits and moments by qubits (#2773)
Browse files Browse the repository at this point in the history
Fixes #2762
  • Loading branch information
fedimser committed Mar 3, 2020
1 parent 5d9a464 commit cc457d3
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 5 deletions.
49 changes: 45 additions & 4 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Operations. Each Operation is a Gate that acts on some Qubits, for a given
Moment the Operations must all act on distinct Qubits.
"""

from collections import defaultdict
from fractions import Fraction
from itertools import groupby
Expand Down Expand Up @@ -85,6 +84,13 @@ class Circuit:
and sliced,
circuit[1:3] is a new Circuit made up of two moments, the first being
circuit[1] and the second being circuit[2];
circuit[:, qubit] is a new Circuit with the same moments, but with only
those operations which act on the given Qubit;
circuit[:, qubits], where 'qubits' is list of Qubits, is a new Circuit
with the same moments, but only with those operations which touch
any of the given qubits;
circuit[1:3, qubit] is equivalent to circuit[1:3][:, qubit];
circuit[1:3, qubits] is equivalent to circuit[1:3][:, qubits];
and concatenated,
circuit1 + circuit2 is a new Circuit made up of the moments in circuit1
followed by the moments in circuit2;
Expand Down Expand Up @@ -211,15 +217,50 @@ def __getitem__(self, key: slice) -> 'Circuit':
def __getitem__(self, key: int) -> 'cirq.Moment':
pass

@overload
def __getitem__(self, key: Tuple[int, 'cirq.Qid']) -> 'cirq.Operation':
pass

@overload
def __getitem__(self,
key: Tuple[int, Iterable['cirq.Qid']]) -> 'cirq.Moment':
pass

@overload
def __getitem__(self, key: Tuple[slice, 'cirq.Qid']) -> 'cirq.Circuit':
pass

@overload
def __getitem__(self,
key: Tuple[slice, Iterable['cirq.Qid']]) -> 'cirq.Circuit':
pass

def __getitem__(self, key):
if isinstance(key, slice):
sliced_circuit = Circuit(device=self.device)
sliced_circuit._moments = self._moments[key]
return sliced_circuit
if isinstance(key, int):
if hasattr(key, '__index__'):
return self._moments[key]

raise TypeError('__getitem__ called with key not of type slice or int.')
if isinstance(key, tuple):
if len(key) != 2:
raise ValueError('If key is tuple, it must be a pair.')
moment_idx, qubit_idx = key
# moment_idx - int or slice; qubit_idx - Qid or Iterable[Qid].
selected_moments = self._moments[moment_idx]
# selected_moments - Moment or list[Moment].
if isinstance(selected_moments, list):
if isinstance(qubit_idx, cirq.Qid):
qubit_idx = [qubit_idx]
new_circuit = Circuit(device=self.device)
new_circuit._moments = [
moment[qubit_idx] for moment in selected_moments
]
return new_circuit
return selected_moments[qubit_idx]

raise TypeError(
'__getitem__ called with key not of type slice, int or tuple.')

@overload
def __setitem__(self, key: int, value: 'cirq.Moment'):
Expand Down
109 changes: 109 additions & 0 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3615,3 +3615,112 @@ def test_transform_qubits():
assert c.transform_qubits(lambda q: q).device is cg.Foxtail
assert c.transform_qubits(lambda q: q, new_device=cg.Bristlecone
).device is cg.Bristlecone


def test_indexing_by_pair():
# 0: ───H───@───X───@───
# │ │
# 1: ───────H───@───@───
# │ │
# 2: ───────────H───X───
q = cirq.LineQubit.range(3)
c = cirq.Circuit([
cirq.H(q[0]),
cirq.H(q[1]).controlled_by(q[0]),
cirq.H(q[2]).controlled_by(q[1]),
cirq.X(q[0]),
cirq.CCNOT(*q),
])

# Indexing by single moment and qubit.
assert c[0, q[0]] == c[0][q[0]] == cirq.H(q[0])
assert c[1, q[0]] == c[1, q[1]] == cirq.H(q[1]).controlled_by(q[0])
assert c[2, q[0]] == c[2][q[0]] == cirq.X(q[0])
assert c[2, q[1]] == c[2, q[2]] == cirq.H(q[2]).controlled_by(q[1])
assert c[3, q[0]] == c[3, q[1]] == c[3, q[2]] == cirq.CCNOT(*q)

# Indexing by moment and qubit - throws if there is no operation.
with pytest.raises(KeyError, match="Moment doesn't act on given qubit"):
_ = c[0, q[1]]

# Indexing by single moment and multiple qubits.
assert c[0, q] == c[0]
assert c[1, q] == c[1]
assert c[2, q] == c[2]
assert c[3, q] == c[3]
assert c[0, q[0:2]] == c[0]
assert c[0, q[1:3]] == cirq.Moment([])
assert c[1, q[1:2]] == c[1]
assert c[2, [q[0]]] == cirq.Moment([cirq.X(q[0])])
assert c[2, q[1:3]] == cirq.Moment([cirq.H(q[2]).controlled_by(q[1])])
assert c[np.int64(2), q[0:2]] == c[2]

# Indexing by single qubit.
assert c[:, q[0]] == cirq.Circuit([
cirq.Moment([cirq.H(q[0])]),
cirq.Moment([cirq.H(q[1]).controlled_by(q[0])]),
cirq.Moment([cirq.X(q[0])]),
cirq.Moment([cirq.CCNOT(q[0], q[1], q[2])]),
])
assert c[:, q[1]] == cirq.Circuit([
cirq.Moment([]),
cirq.Moment([cirq.H(q[1]).controlled_by(q[0])]),
cirq.Moment([cirq.H(q[2]).controlled_by(q[1])]),
cirq.Moment([cirq.CCNOT(q[0], q[1], q[2])]),
])
assert c[:, q[2]] == cirq.Circuit([
cirq.Moment([]),
cirq.Moment([]),
cirq.Moment([cirq.H(q[2]).controlled_by(q[1])]),
cirq.Moment([cirq.CCNOT(q[0], q[1], q[2])]),
])

# Indexing by several qubits.
assert c[:, q] == c[:, q[0:2]] == c[:, [q[0], q[2]]] == c
assert c[:, q[1:3]] == cirq.Circuit([
cirq.Moment([]),
cirq.Moment([cirq.H(q[1]).controlled_by(q[0])]),
cirq.Moment([cirq.H(q[2]).controlled_by(q[1])]),
cirq.Moment([cirq.CCNOT(q[0], q[1], q[2])]),
])

# Indexing by several moments and one qubit.
assert c[1:3, q[0]] == cirq.Circuit([
cirq.H(q[1]).controlled_by(q[0]),
cirq.X(q[0]),
])
assert c[1::2, q[2]] == cirq.Circuit([
cirq.Moment([]),
cirq.Moment([cirq.CCNOT(*q)]),
])

# Indexing by several moments and several qubits.
assert c[0:2, q[1:3]] == cirq.Circuit([
cirq.Moment([]),
cirq.Moment([cirq.H(q[1]).controlled_by(q[0])]),
])
assert c[::2, q[0:2]] == cirq.Circuit([
cirq.Moment([cirq.H(q[0])]),
cirq.Moment([cirq.H(q[2]).controlled_by(q[1]),
cirq.X(q[0])]),
])

# Equivalent ways of indexing.
assert c[0:2, q[1:3]] == c[0:2][:, q[1:3]] == c[:, q[1:3]][0:2]

# Passing more than 2 items is forbidden.
with pytest.raises(ValueError, match='If key is tuple, it must be a pair.'):
_ = c[0, q[1], 0]

# Can't swap indices.
with pytest.raises(TypeError,
match='list indices must be integers or slices'):
_ = c[q[1], 0]


def test_indexing_by_numpy_integer():
q = cirq.NamedQubit('q')
c = cirq.Circuit(cirq.X(q), cirq.Y(q))

assert c[np.int32(1)] == cirq.Moment([cirq.Y(q)])
assert c[np.int64(1)] == cirq.Moment([cirq.Y(q)])
40 changes: 39 additions & 1 deletion cirq/ops/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""A simplified time-slice of operations within a sequenced circuit."""

from typing import (Any, Callable, Iterable, Sequence, TypeVar, Union, Tuple,
FrozenSet, TYPE_CHECKING, Iterator)
FrozenSet, TYPE_CHECKING, Iterator, overload)
from cirq import protocols
from cirq.ops import raw_types

Expand All @@ -33,6 +33,13 @@ class Moment:
moment should execute at the same time (to the extent possible; not all
operations have the same duration) and it is expected that all operations
in a moment should be completed before beginning the next moment.
Moment can be indexed by qubit or list of qubits:
moment[qubit] returns the Operation in the moment which touches the
given qubit, or throws KeyError if there is no such operation.
moment[qubits] returns another Moment which consists only of those
operations which touch at least one of the given qubits. If there
are no such operations, returns an empty Moment.
"""

def __init__(self, operations: Iterable[raw_types.Operation] = ()) -> None:
Expand Down Expand Up @@ -117,6 +124,18 @@ def without_operations_touching(self, qubits: Iterable[raw_types.Qid]):
operation for operation in self.operations
if qubits.isdisjoint(frozenset(operation.qubits)))

def _operation_touching(self, qubit: raw_types.Qid) -> 'cirq.Operation':
"""Returns the operation touching given qubit.
Args:
qubit: Operations that touch this qubit will be returned.
Returns:
The operation which touches `qubit`.
"""
for op in self.operations:
if qubit in op.qubits:
return op
raise KeyError("Moment doesn't act on given qubit")

def __copy__(self):
return type(self)(self.operations)

Expand Down Expand Up @@ -198,6 +217,25 @@ def __add__(self, other):
return self.with_operation(other)
return NotImplemented

# pylint: disable=function-redefined
@overload
def __getitem__(self, key: raw_types.Qid) -> 'cirq.Operation':
pass

@overload
def __getitem__(self, key: Iterable[raw_types.Qid]) -> 'cirq.Moment':
pass

def __getitem__(self, key):
if isinstance(key, raw_types.Qid):
return self._operation_touching(key)
elif isinstance(key, Iterable):
qubits_to_keep = frozenset(key)
ops_to_keep = tuple(
op for op in self.operations
if not qubits_to_keep.isdisjoint(frozenset(op.qubits)))
return Moment(ops_to_keep)


def _list_repr_with_indented_item_lines(items: Sequence[Any]) -> str:
block = '\n'.join([repr(op) + ',' for op in items])
Expand Down
32 changes: 32 additions & 0 deletions cirq/ops/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,35 @@ def test_add():
circuit2 = cirq.Circuit(cirq.CNOT(a, b), cirq.Y(b))
circuit2[1] += cirq.X(a)
assert circuit2 == expected_circuit


def test_indexes_by_qubit():
a, b, c = cirq.LineQubit.range(3)
moment = cirq.Moment([cirq.H(a), cirq.CNOT(b, c)])

assert moment[a] == cirq.H(a)
assert moment[b] == cirq.CNOT(b, c)
assert moment[c] == cirq.CNOT(b, c)


def test_throws_when_indexed_by_unused_qubit():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment([cirq.H(a)])

with pytest.raises(KeyError, match="Moment doesn't act on given qubit"):
_ = moment[b]


def test_indexes_by_list_of_qubits():
q = cirq.LineQubit.range(4)
moment = cirq.Moment([cirq.Z(q[0]), cirq.CNOT(q[1], q[2])])

assert moment[[q[0]]] == Moment([cirq.Z(q[0])])
assert moment[[q[1]]] == Moment([cirq.CNOT(q[1], q[2])])
assert moment[[q[2]]] == Moment([cirq.CNOT(q[1], q[2])])
assert moment[[q[3]]] == Moment([])
assert moment[q[0:2]] == moment
assert moment[q[1:3]] == Moment([cirq.CNOT(q[1], q[2])])
assert moment[q[2:4]] == Moment([cirq.CNOT(q[1], q[2])])
assert moment[[q[0], q[3]]] == Moment([cirq.Z(q[0])])
assert moment[q] == moment

0 comments on commit cc457d3

Please sign in to comment.