Skip to content

Commit

Permalink
Return same FrozenCircuit instance from Circuit.freeze() until mutated
Browse files Browse the repository at this point in the history
This simplifies working with frozen circuits and circuit identity by
storing the `FrozenCircuit` instance returned by `Circuit.freeze()` and
returning the same instance until it is "invalidated" by any mutations
of the Circuit itself. Note that if we make additional changes to the
Circuit implementation, such as adding new mutating methods, we will
need to remember to add `self._frozen = None` statements to invalidate
the frozen representation in those places as well. To reduce risk, I
have put these invalidation statements immediately after any mutation
operations on `self._moments`, even in places where some invalidations
could be elided or pushed to the end of a method; it seems safer to keep
these close together in the source code. I have also added an
implementation note calling out this detail and reminding future
developers to invalidate when needed if adding other `Circuit` mutations.
  • Loading branch information
maffoo committed Oct 20, 2023
1 parent 96b3842 commit 298750e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 13 deletions.
52 changes: 39 additions & 13 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,28 +173,20 @@ def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) ->
def moments(self) -> Sequence['cirq.Moment']:
pass

@abc.abstractmethod
def freeze(self) -> 'cirq.FrozenCircuit':
"""Creates a FrozenCircuit from this circuit.
If 'self' is a FrozenCircuit, the original object is returned.
"""
from cirq.circuits import FrozenCircuit

if isinstance(self, FrozenCircuit):
return self

return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST)

@abc.abstractmethod
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
"""Creates a Circuit from this circuit.
Args:
copy: If True and 'self' is a Circuit, returns a copy that circuit.
"""
if isinstance(self, Circuit):
return Circuit.copy(self) if copy else self

return Circuit(self, strategy=InsertStrategy.EARLIEST)

def __bool__(self):
return bool(self.moments)
Expand Down Expand Up @@ -1743,7 +1735,10 @@ def __init__(
together. This option does not affect later insertions into the
circuit.
"""
# Implementation note: we set self._frozen = None any time self._moments
# is mutated, to "invalidate" the frozen instance.
self._moments: List['cirq.Moment'] = []
self._frozen: Optional['cirq.FrozenCircuit'] = None
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
if all(isinstance(c, Moment) for c in flattened_contents):
self._moments[:] = cast(Iterable[Moment], flattened_contents)
Expand Down Expand Up @@ -1810,12 +1805,29 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
for i in range(length):
if i in moments_by_index:
self._moments.append(moments_by_index[i].with_operations(op_lists_by_index[i]))
self._frozen = None
else:
self._moments.append(Moment(op_lists_by_index[i]))
self._frozen = None

def __copy__(self) -> 'cirq.Circuit':
return self.copy()

def freeze(self) -> 'cirq.FrozenCircuit':
"""Gets a frozen version of this circuit.
Repeated calls to `.freeze()` will return the same FrozenCircuit
instance as long as this circuit is not mutated.
"""
from cirq.circuits import FrozenCircuit

if self._frozen is None:
self._frozen = FrozenCircuit.from_moments(*self._moments)
return self._frozen

def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
return self.copy() if copy else self

def copy(self) -> 'Circuit':
"""Return a copy of this circuit."""
copied_circuit = Circuit()
Expand All @@ -1841,11 +1853,13 @@ def __setitem__(self, key, value):
raise TypeError('Can only assign Moments into Circuits.')

self._moments[key] = value
self._frozen = None

# pylint: enable=function-redefined

def __delitem__(self, key: Union[int, slice]):
del self._moments[key]
self._frozen = None

def __iadd__(self, other):
self.append(other)
Expand Down Expand Up @@ -1874,6 +1888,7 @@ def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
self._moments *= int(repetitions)
self._frozen = None
return self

def __mul__(self, repetitions: _INT_TYPE):
Expand Down Expand Up @@ -2017,6 +2032,7 @@ def _pick_or_create_inserted_op_moment_index(

if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
self._moments.insert(splitter_index, Moment())
self._frozen = None
return splitter_index

if strategy is InsertStrategy.INLINE:
Expand Down Expand Up @@ -2074,13 +2090,16 @@ def insert(
for moment_or_op in list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)):
if isinstance(moment_or_op, Moment):
self._moments.insert(k, moment_or_op)
self._frozen = None
k += 1
else:
op = moment_or_op
p = self._pick_or_create_inserted_op_moment_index(k, op, strategy)
while p >= len(self._moments):
self._moments.append(Moment())
self._frozen = None
self._moments[p] = self._moments[p].with_operation(op)
self._frozen = None
k = max(k, p + 1)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
Expand Down Expand Up @@ -2119,6 +2138,7 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) ->
break

self._moments[i] = self._moments[i].with_operation(op)
self._frozen = None
op_index += 1

if op_index >= len(flat_ops):
Expand Down Expand Up @@ -2165,6 +2185,7 @@ def _push_frontier(
if n_new_moments > 0:
insert_index = min(late_frontier.values())
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
self._frozen = None
for q in update_qubits:
if early_frontier.get(q, 0) > insert_index:
early_frontier[q] += n_new_moments
Expand All @@ -2191,13 +2212,13 @@ def _insert_operations(
if len(operations) != len(insertion_indices):
raise ValueError('operations and insertion_indices must have the same length.')
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
self._frozen = False
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
for op_index, moment_index in enumerate(insertion_indices):
moment_to_ops[moment_index].append(operations[op_index])
for moment_index, new_ops in moment_to_ops.items():
self._moments[moment_index] = Moment(
self._moments[moment_index].operations + tuple(new_ops)
)
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)
self._frozen = None

def insert_at_frontier(
self,
Expand Down Expand Up @@ -2259,6 +2280,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
old_op for old_op in copy._moments[i].operations if op != old_op
)
self._moments = copy._moments
self._frozen = None

def batch_replace(
self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']]
Expand All @@ -2283,6 +2305,7 @@ def batch_replace(
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
)
self._moments = copy._moments
self._frozen = None

def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
"""Inserts operations into empty spaces in existing moments.
Expand All @@ -2303,6 +2326,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']])
for i, insertions in insert_intos:
copy._moments[i] = copy._moments[i].with_operations(insertions)
self._moments = copy._moments
self._frozen = None

def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
"""Applies a batched insert operation to the circuit.
Expand Down Expand Up @@ -2337,6 +2361,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
if next_index > insert_index:
shift += next_index - insert_index
self._moments = copy._moments
self._frozen = None

def append(
self,
Expand Down Expand Up @@ -2367,6 +2392,7 @@ def clear_operations_touching(
for k in moment_indices:
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)
self._frozen = None

@property
def moments(self) -> Sequence['cirq.Moment']:
Expand Down
20 changes: 20 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4521,6 +4521,26 @@ def test_freeze_not_relocate_moments():
assert [mc is fc for mc, fc in zip(c, f)] == [True, True]


def test_freeze_returns_same_instance_if_not_mutated():
q = cirq.q(0)
c = cirq.Circuit(cirq.X(q), cirq.measure(q))
f0 = c.freeze()
f1 = c.freeze()
assert f1 is f0

c.append(cirq.Y(q))
f2 = c.freeze()
f3 = c.freeze()
assert f2 is not f1
assert f3 is f2

c[-1] = cirq.Moment(cirq.Y(q))
f4 = c.freeze()
f5 = c.freeze()
assert f4 is not f3
assert f5 is f4


def test_factorize_one_factor():
circuit = cirq.Circuit()
q0, q1, q2 = cirq.LineQubit.range(3)
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

def freeze(self) -> 'cirq.FrozenCircuit':
return self

def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
return Circuit.from_moments(*self)

@property
def tags(self) -> Tuple[Hashable, ...]:
"""Returns a tuple of the Circuit's tags."""
Expand Down

0 comments on commit 298750e

Please sign in to comment.