Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,64 @@ def zip(
) from ex
return result

def concat_ragged(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.AbstractCircuit':
"""Concatenates circuits, overlapping them if possible due to ragged edges.

Starts with the first circuit (index 0), then iterates over the other
circuits while folding them in. To fold two circuits together, they
are placed one after the other and then moved inward until just before
their operations would collide. If any of the circuits do not share
qubits and so would not collide, the starts or ends of the circuits will
be aligned, acording to the given align parameter.

Beware that this method is *not* associative. For example:

>>> a, b = cirq.LineQubit.range(2)
>>> A = cirq.Circuit(cirq.H(a))
>>> B = cirq.Circuit(cirq.H(b))
>>> f = cirq.Circuit.concat_ragged
>>> f(f(A, B), A) == f(A, f(B, A))
False
>>> len(f(f(f(A, B), A), B)) == len(f(f(A, f(B, A)), B))
False

Args:
*circuits: The circuits to concatenate.
align: When to stop when sliding the circuits together.
'left': Stop when the starts of the circuits align.
'right': Stop when the ends of the circuits align.
'first': Stop the first time either the starts or the ends align. Circuits
are never overlapped more than needed to align their starts (in case
the left circuit is smaller) or to align their ends (in case the right
circuit is smaller)

Returns:
The concatenated and overlapped circuit.
"""
if len(circuits) == 0:
return Circuit()
n_acc = len(circuits[0])

if isinstance(align, str):
align = Alignment[align.upper()]

# Allocate a buffer large enough to append and prepend all the circuits.
pad_len = sum(len(c) for c in circuits) - n_acc
buffer = np.zeros(shape=pad_len * 2 + n_acc, dtype=object)

# Put the initial circuit in the center of the buffer.
offset = pad_len
buffer[offset : offset + n_acc] = circuits[0].moments

# Accumulate all the circuits into the buffer.
for k in range(1, len(circuits)):
offset, n_acc = _concat_ragged_helper(offset, n_acc, buffer, circuits[k].moments, align)

return cirq.Circuit(buffer[offset : offset + n_acc])

@_compat.deprecated(deadline='v0.16', fix='Renaming to concat_ragged')
def tetris_concat(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.AbstractCircuit':
Expand Down Expand Up @@ -1503,7 +1561,7 @@ def tetris_concat(

# Accumulate all the circuits into the buffer.
for k in range(1, len(circuits)):
offset, n_acc = _tetris_concat_helper(offset, n_acc, buffer, circuits[k].moments, align)
offset, n_acc = _concat_ragged_helper(offset, n_acc, buffer, circuits[k].moments, align)

return cirq.Circuit(buffer[offset : offset + n_acc])

Expand Down Expand Up @@ -1633,7 +1691,7 @@ def _overlap_collision_time(
return upper_bound


def _tetris_concat_helper(
def _concat_ragged_helper(
c1_offset: int, n1: int, buf: np.ndarray, c2: Sequence['cirq.Moment'], align: 'cirq.Alignment'
) -> Tuple[int, int]:
n2 = len(c2)
Expand Down Expand Up @@ -1846,6 +1904,14 @@ def __pow__(self, exponent: int) -> 'cirq.Circuit':

__hash__ = None # type: ignore

def concat_ragged(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.Circuit':
return AbstractCircuit.concat_ragged(*circuits, align=align).unfreeze(copy=False)

concat_ragged.__doc__ = AbstractCircuit.concat_ragged.__doc__

@_compat.deprecated(deadline='v0.16', fix='Renaming to concat_ragged')
def tetris_concat(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.Circuit':
Expand Down
Loading