diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index c6b00726a2c..78a9765f8c7 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1450,6 +1450,64 @@ def zip( ) from ex return result + def concat_ragged( + *circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT + ) -> 'cirq.AbstractCircuit': + """Concatenates circuits, overlapping them if possible due to ragged edges. + + Starts with the first circuit (index 0), then iterates over the other + circuits while folding them in. To fold two circuits together, they + are placed one after the other and then moved inward until just before + their operations would collide. If any of the circuits do not share + qubits and so would not collide, the starts or ends of the circuits will + be aligned, acording to the given align parameter. + + Beware that this method is *not* associative. For example: + + >>> a, b = cirq.LineQubit.range(2) + >>> A = cirq.Circuit(cirq.H(a)) + >>> B = cirq.Circuit(cirq.H(b)) + >>> f = cirq.Circuit.concat_ragged + >>> f(f(A, B), A) == f(A, f(B, A)) + False + >>> len(f(f(f(A, B), A), B)) == len(f(f(A, f(B, A)), B)) + False + + Args: + *circuits: The circuits to concatenate. + align: When to stop when sliding the circuits together. + 'left': Stop when the starts of the circuits align. + 'right': Stop when the ends of the circuits align. + 'first': Stop the first time either the starts or the ends align. Circuits + are never overlapped more than needed to align their starts (in case + the left circuit is smaller) or to align their ends (in case the right + circuit is smaller) + + Returns: + The concatenated and overlapped circuit. + """ + if len(circuits) == 0: + return Circuit() + n_acc = len(circuits[0]) + + if isinstance(align, str): + align = Alignment[align.upper()] + + # Allocate a buffer large enough to append and prepend all the circuits. + pad_len = sum(len(c) for c in circuits) - n_acc + buffer = np.zeros(shape=pad_len * 2 + n_acc, dtype=object) + + # Put the initial circuit in the center of the buffer. + offset = pad_len + buffer[offset : offset + n_acc] = circuits[0].moments + + # Accumulate all the circuits into the buffer. + for k in range(1, len(circuits)): + offset, n_acc = _concat_ragged_helper(offset, n_acc, buffer, circuits[k].moments, align) + + return cirq.Circuit(buffer[offset : offset + n_acc]) + + @_compat.deprecated(deadline='v0.16', fix='Renaming to concat_ragged') def tetris_concat( *circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT ) -> 'cirq.AbstractCircuit': @@ -1503,7 +1561,7 @@ def tetris_concat( # Accumulate all the circuits into the buffer. for k in range(1, len(circuits)): - offset, n_acc = _tetris_concat_helper(offset, n_acc, buffer, circuits[k].moments, align) + offset, n_acc = _concat_ragged_helper(offset, n_acc, buffer, circuits[k].moments, align) return cirq.Circuit(buffer[offset : offset + n_acc]) @@ -1633,7 +1691,7 @@ def _overlap_collision_time( return upper_bound -def _tetris_concat_helper( +def _concat_ragged_helper( c1_offset: int, n1: int, buf: np.ndarray, c2: Sequence['cirq.Moment'], align: 'cirq.Alignment' ) -> Tuple[int, int]: n2 = len(c2) @@ -1846,6 +1904,14 @@ def __pow__(self, exponent: int) -> 'cirq.Circuit': __hash__ = None # type: ignore + def concat_ragged( + *circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT + ) -> 'cirq.Circuit': + return AbstractCircuit.concat_ragged(*circuits, align=align).unfreeze(copy=False) + + concat_ragged.__doc__ = AbstractCircuit.concat_ragged.__doc__ + + @_compat.deprecated(deadline='v0.16', fix='Renaming to concat_ragged') def tetris_concat( *circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT ) -> 'cirq.Circuit': diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 0ea762d732a..0f9c32fe332 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4317,33 +4317,240 @@ def _circuit_diagram_info_(self, args): assert '|c>' in circuit._repr_html_() -def test_tetris_concat(): +def test_tetris_concat_deprecated(): a, b = cirq.LineQubit.range(2) empty = cirq.Circuit() - assert cirq.Circuit.tetris_concat(empty, empty) == empty - assert cirq.Circuit.tetris_concat() == empty - assert empty.tetris_concat(empty) == empty - assert empty.tetris_concat(empty, empty) == empty + with cirq.testing.assert_deprecated('ragged', deadline='v0.16', count=None): + assert cirq.Circuit.tetris_concat(empty, empty) == empty + assert cirq.Circuit.tetris_concat() == empty + assert empty.tetris_concat(empty) == empty + assert empty.tetris_concat(empty, empty) == empty + + ha = cirq.Circuit(cirq.H(a)) + hb = cirq.Circuit(cirq.H(b)) + assert ha.tetris_concat(hb) == ha.zip(hb) + + assert ha.tetris_concat(empty) == ha + assert empty.tetris_concat(ha) == ha + + hac = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b)) + assert hac.tetris_concat(hb) == hac + hb + assert hb.tetris_concat(hac) == hb.zip(hac) + + zig = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b), cirq.H(b)) + assert zig.tetris_concat(zig) == cirq.Circuit( + cirq.H(a), + cirq.CNOT(a, b), + cirq.Moment(cirq.H(a), cirq.H(b)), + cirq.CNOT(a, b), + cirq.H(b), + ) + + zag = cirq.Circuit(cirq.H(a), cirq.H(a), cirq.CNOT(a, b), cirq.H(b), cirq.H(b)) + assert zag.tetris_concat(zag) == cirq.Circuit( + cirq.H(a), + cirq.H(a), + cirq.CNOT(a, b), + cirq.Moment(cirq.H(a), cirq.H(b)), + cirq.Moment(cirq.H(a), cirq.H(b)), + cirq.CNOT(a, b), + cirq.H(b), + cirq.H(b), + ) + + space = cirq.Circuit(cirq.Moment()) * 10 + f = cirq.Circuit.tetris_concat + assert len(f(space, ha)) == 10 + assert len(f(space, ha, ha, ha)) == 10 + assert len(f(space, f(ha, ha, ha))) == 10 + assert len(f(space, ha, align='LEFT')) == 10 + assert len(f(space, ha, ha, ha, align='RIGHT')) == 12 + assert len(f(space, f(ha, ha, ha, align='LEFT'))) == 10 + assert len(f(space, f(ha, ha, ha, align='RIGHT'))) == 10 + assert len(f(space, f(ha, ha, ha), align='LEFT')) == 10 + assert len(f(space, f(ha, ha, ha), align='RIGHT')) == 10 + + # L shape overlap (vary c1). + assert 7 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 5), + cirq.Circuit([cirq.H(b)] * 5, cirq.CZ(a, b)), + ) + ) + assert 7 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 4), + cirq.Circuit([cirq.H(b)] * 5, cirq.CZ(a, b)), + ) + ) + assert 7 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 1), + cirq.Circuit([cirq.H(b)] * 5, cirq.CZ(a, b)), + ) + ) + assert 8 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 6), + cirq.Circuit([cirq.H(b)] * 5, cirq.CZ(a, b)), + ) + ) + assert 9 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 7), + cirq.Circuit([cirq.H(b)] * 5, cirq.CZ(a, b)), + ) + ) + + # L shape overlap (vary c2). + assert 7 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 5), + cirq.Circuit([cirq.H(b)] * 5, cirq.CZ(a, b)), + ) + ) + assert 7 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 5), + cirq.Circuit([cirq.H(b)] * 4, cirq.CZ(a, b)), + ) + ) + assert 7 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 5), + cirq.Circuit([cirq.H(b)] * 1, cirq.CZ(a, b)), + ) + ) + assert 8 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 5), + cirq.Circuit([cirq.H(b)] * 6, cirq.CZ(a, b)), + ) + ) + assert 9 == len( + f( + cirq.Circuit(cirq.CZ(a, b), [cirq.H(a)] * 5), + cirq.Circuit([cirq.H(b)] * 7, cirq.CZ(a, b)), + ) + ) + + # When scanning sees a possible hit, continues scanning for earlier hit. + assert 10 == len( + f( + cirq.Circuit( + cirq.Moment(), + cirq.Moment(), + cirq.Moment(), + cirq.Moment(), + cirq.Moment(), + cirq.Moment(cirq.H(a)), + cirq.Moment(), + cirq.Moment(), + cirq.Moment(cirq.H(b)), + ), + cirq.Circuit( + cirq.Moment(), + cirq.Moment(), + cirq.Moment(), + cirq.Moment(cirq.H(a)), + cirq.Moment(), + cirq.Moment(cirq.H(b)), + ), + ) + ) + # Correct tie breaker when one operation sees two possible hits. + for cz_order in [cirq.CZ(a, b), cirq.CZ(b, a)]: + assert 3 == len( + f( + cirq.Circuit(cirq.Moment(cz_order), cirq.Moment(), cirq.Moment()), + cirq.Circuit(cirq.Moment(cirq.H(a)), cirq.Moment(cirq.H(b))), + ) + ) + + # Types. + v = ha.freeze().tetris_concat(empty) + assert type(v) is cirq.FrozenCircuit and v == ha.freeze() + v = ha.tetris_concat(empty.freeze()) + assert type(v) is cirq.Circuit and v == ha + v = ha.freeze().tetris_concat(empty) + assert type(v) is cirq.FrozenCircuit and v == ha.freeze() + v = cirq.Circuit.tetris_concat(ha, empty) + assert type(v) is cirq.Circuit and v == ha + v = cirq.FrozenCircuit.tetris_concat(ha, empty) + assert type(v) is cirq.FrozenCircuit and v == ha.freeze() + + +def test_tetris_concat_alignment_deprecated(): + a, b = cirq.LineQubit.range(2) + + with cirq.testing.assert_deprecated('ragged', deadline='v0.16', count=None): + + assert cirq.Circuit.tetris_concat( + cirq.Circuit(cirq.X(a)), + cirq.Circuit(cirq.Y(b)) * 4, + cirq.Circuit(cirq.Z(a)), + align='first', + ) == cirq.Circuit( + cirq.Moment(cirq.X(a), cirq.Y(b)), + cirq.Moment(cirq.Y(b)), + cirq.Moment(cirq.Y(b)), + cirq.Moment(cirq.Z(a), cirq.Y(b)), + ) + + assert cirq.Circuit.tetris_concat( + cirq.Circuit(cirq.X(a)), + cirq.Circuit(cirq.Y(b)) * 4, + cirq.Circuit(cirq.Z(a)), + align='left', + ) == cirq.Circuit( + cirq.Moment(cirq.X(a), cirq.Y(b)), + cirq.Moment(cirq.Z(a), cirq.Y(b)), + cirq.Moment(cirq.Y(b)), + cirq.Moment(cirq.Y(b)), + ) + + assert cirq.Circuit.tetris_concat( + cirq.Circuit(cirq.X(a)), + cirq.Circuit(cirq.Y(b)) * 4, + cirq.Circuit(cirq.Z(a)), + align='right', + ) == cirq.Circuit( + cirq.Moment(cirq.Y(b)), + cirq.Moment(cirq.Y(b)), + cirq.Moment(cirq.Y(b)), + cirq.Moment(cirq.X(a), cirq.Y(b)), + cirq.Moment(cirq.Z(a)), + ) + + +def test_concat_ragged(): + a, b = cirq.LineQubit.range(2) + empty = cirq.Circuit() + + assert cirq.Circuit.concat_ragged(empty, empty) == empty + assert cirq.Circuit.concat_ragged() == empty + assert empty.concat_ragged(empty) == empty + assert empty.concat_ragged(empty, empty) == empty ha = cirq.Circuit(cirq.H(a)) hb = cirq.Circuit(cirq.H(b)) - assert ha.tetris_concat(hb) == ha.zip(hb) + assert ha.concat_ragged(hb) == ha.zip(hb) - assert ha.tetris_concat(empty) == ha - assert empty.tetris_concat(ha) == ha + assert ha.concat_ragged(empty) == ha + assert empty.concat_ragged(ha) == ha hac = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b)) - assert hac.tetris_concat(hb) == hac + hb - assert hb.tetris_concat(hac) == hb.zip(hac) + assert hac.concat_ragged(hb) == hac + hb + assert hb.concat_ragged(hac) == hb.zip(hac) zig = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b), cirq.H(b)) - assert zig.tetris_concat(zig) == cirq.Circuit( + assert zig.concat_ragged(zig) == cirq.Circuit( cirq.H(a), cirq.CNOT(a, b), cirq.Moment(cirq.H(a), cirq.H(b)), cirq.CNOT(a, b), cirq.H(b) ) zag = cirq.Circuit(cirq.H(a), cirq.H(a), cirq.CNOT(a, b), cirq.H(b), cirq.H(b)) - assert zag.tetris_concat(zag) == cirq.Circuit( + assert zag.concat_ragged(zag) == cirq.Circuit( cirq.H(a), cirq.H(a), cirq.CNOT(a, b), @@ -4355,7 +4562,7 @@ def test_tetris_concat(): ) space = cirq.Circuit(cirq.Moment()) * 10 - f = cirq.Circuit.tetris_concat + f = cirq.Circuit.concat_ragged assert len(f(space, ha)) == 10 assert len(f(space, ha, ha, ha)) == 10 assert len(f(space, f(ha, ha, ha))) == 10 @@ -4464,22 +4671,22 @@ def test_tetris_concat(): ) # Types. - v = ha.freeze().tetris_concat(empty) + v = ha.freeze().concat_ragged(empty) assert type(v) is cirq.FrozenCircuit and v == ha.freeze() - v = ha.tetris_concat(empty.freeze()) + v = ha.concat_ragged(empty.freeze()) assert type(v) is cirq.Circuit and v == ha - v = ha.freeze().tetris_concat(empty) + v = ha.freeze().concat_ragged(empty) assert type(v) is cirq.FrozenCircuit and v == ha.freeze() - v = cirq.Circuit.tetris_concat(ha, empty) + v = cirq.Circuit.concat_ragged(ha, empty) assert type(v) is cirq.Circuit and v == ha - v = cirq.FrozenCircuit.tetris_concat(ha, empty) + v = cirq.FrozenCircuit.concat_ragged(ha, empty) assert type(v) is cirq.FrozenCircuit and v == ha.freeze() -def test_tetris_concat_alignment(): +def test_concat_ragged_alignment(): a, b = cirq.LineQubit.range(2) - assert cirq.Circuit.tetris_concat( + assert cirq.Circuit.concat_ragged( cirq.Circuit(cirq.X(a)), cirq.Circuit(cirq.Y(b)) * 4, cirq.Circuit(cirq.Z(a)), align='first' ) == cirq.Circuit( cirq.Moment(cirq.X(a), cirq.Y(b)), @@ -4488,7 +4695,7 @@ def test_tetris_concat_alignment(): cirq.Moment(cirq.Z(a), cirq.Y(b)), ) - assert cirq.Circuit.tetris_concat( + assert cirq.Circuit.concat_ragged( cirq.Circuit(cirq.X(a)), cirq.Circuit(cirq.Y(b)) * 4, cirq.Circuit(cirq.Z(a)), align='left' ) == cirq.Circuit( cirq.Moment(cirq.X(a), cirq.Y(b)), @@ -4497,7 +4704,7 @@ def test_tetris_concat_alignment(): cirq.Moment(cirq.Y(b)), ) - assert cirq.Circuit.tetris_concat( + assert cirq.Circuit.concat_ragged( cirq.Circuit(cirq.X(a)), cirq.Circuit(cirq.Y(b)) * 4, cirq.Circuit(cirq.Z(a)), align='right' ) == cirq.Circuit( cirq.Moment(cirq.Y(b)), diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 13548c2b51e..b6cb08bada0 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -30,7 +30,7 @@ import numpy as np -from cirq import ops, protocols +from cirq import ops, protocols, _compat if TYPE_CHECKING: @@ -176,6 +176,14 @@ def _resolve_parameters_( ) -> 'cirq.FrozenCircuit': return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze() + def concat_ragged( + *circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT + ) -> 'cirq.FrozenCircuit': + return AbstractCircuit.concat_ragged(*circuits, align=align).freeze() + + concat_ragged.__doc__ = AbstractCircuit.concat_ragged.__doc__ + + @_compat.deprecated(deadline='v0.16', fix='Renaming to concat_ragged') def tetris_concat( *circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT ) -> 'cirq.FrozenCircuit':