Skip to content

Commit

Permalink
Add Circuit.zip method (#3075)
Browse files Browse the repository at this point in the history
Handy to have around when building circuits up in tiled pieces and wanting to guarantee the moment structure comes out right.
  • Loading branch information
Strilanc committed Jun 8, 2020
1 parent 2b8b085 commit 8cc7d93
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
67 changes: 67 additions & 0 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,73 @@ def _insert_operations(self, operations: Sequence['cirq.Operation'],
self._moments[moment_index] = ops.Moment(
self._moments[moment_index].operations + tuple(new_ops))

def zip(*circuits):
"""Combines operations from circuits in a moment-by-moment fashion.
Moment k of the resulting circuit will have all operations from moment
k of each of the given circuits.
When the given circuits have different lengths, the shorter circuits are
implicitly padded with empty moments. This differs from the behavior of
python's built-in zip function, which would instead truncate the longer
circuits.
The zipped circuits can't have overlapping operations occurring at the
same moment index.
Args:
circuits: The circuits to merge together.
Returns:
The merged circuit.
Raises:
ValueError:
The zipped circuits have overlapping operations occurring at the
same moment index.
Examples:
>>> import cirq
>>> a, b, c, d = cirq.LineQubit.range(4)
>>> circuit1 = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b))
>>> circuit2 = cirq.Circuit(cirq.X(c), cirq.Y(c), cirq.Z(c))
>>> circuit3 = cirq.Circuit(cirq.Moment(), cirq.Moment(cirq.S(d)))
>>> print(circuit1.zip(circuit2))
0: ───H───@───────
1: ───────X───────
<BLANKLINE>
2: ───X───Y───Z───
>>> print(circuit1.zip(circuit2, circuit3))
0: ───H───@───────
1: ───────X───────
<BLANKLINE>
2: ───X───Y───Z───
<BLANKLINE>
3: ───────S───────
>>> print(cirq.Circuit.zip(circuit3, circuit2, circuit1))
0: ───H───@───────
1: ───────X───────
<BLANKLINE>
2: ───X───Y───Z───
<BLANKLINE>
3: ───────S───────
"""
circuits = list(circuits)
n = max([len(c) for c in circuits], default=0)

result = cirq.Circuit()
for k in range(n):
try:
result.append(cirq.Moment(c[k] for c in circuits if k < len(c)))
except ValueError as ex:
raise ValueError(
f"Overlapping operations between zipped circuits "
f"at moment index {k}.\n{ex}") from ex
return result

def insert_at_frontier(self,
operations: 'cirq.OP_TREE',
start: int,
Expand Down
53 changes: 53 additions & 0 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3764,3 +3764,56 @@ def test_deprecated():
circuit = cirq.Circuit([cirq.H(q)])
with cirq.testing.assert_logs('final_state_vector', 'deprecated'):
_ = circuit.final_wavefunction()


def test_zip():
a, b, c, d = cirq.LineQubit.range(4)

circuit1 = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b))
circuit2 = cirq.Circuit(cirq.X(c), cirq.Y(c), cirq.Z(c))
circuit3 = cirq.Circuit(cirq.Moment(), cirq.Moment(cirq.S(d)))

# Calling works both static-style and instance-style.
assert circuit1.zip(circuit2) == cirq.Circuit.zip(circuit1, circuit2)

# Empty cases.
assert cirq.Circuit.zip() == cirq.Circuit()
assert cirq.Circuit.zip(cirq.Circuit()) == cirq.Circuit()
assert cirq.Circuit().zip(cirq.Circuit()) == cirq.Circuit()
assert circuit1.zip(cirq.Circuit()) == circuit1
assert cirq.Circuit(cirq.Moment()).zip(cirq.Circuit()) == cirq.Circuit(
cirq.Moment())
assert cirq.Circuit().zip(cirq.Circuit(cirq.Moment())) == cirq.Circuit(
cirq.Moment())

# Small cases.
assert circuit1.zip(circuit2) == circuit2.zip(circuit1) == cirq.Circuit(
cirq.Moment(
cirq.H(a),
cirq.X(c),
),
cirq.Moment(
cirq.CNOT(a, b),
cirq.Y(c),
),
cirq.Moment(cirq.Z(c),),
)
assert circuit1.zip(circuit2, circuit3) == cirq.Circuit(
cirq.Moment(
cirq.H(a),
cirq.X(c),
),
cirq.Moment(
cirq.CNOT(a, b),
cirq.Y(c),
cirq.S(d),
),
cirq.Moment(cirq.Z(c),),
)

# Overlapping operations.
with pytest.raises(ValueError, match="moment index 1.*\n.*CNOT"):
_ = cirq.Circuit.zip(
cirq.Circuit(cirq.X(a), cirq.CNOT(a, b)),
cirq.Circuit(cirq.X(b), cirq.Z(b)),
)

0 comments on commit 8cc7d93

Please sign in to comment.