Skip to content

Commit

Permalink
Add tags to cirq.FrozenCircuit (#6266)
Browse files Browse the repository at this point in the history
* Add tags to FrozenCircuit

* Address comments and fix tests

* Address maffoo's comments
  • Loading branch information
tanujkhattar committed Aug 29, 2023
1 parent 83609eb commit 041ce5d
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 10 deletions.
11 changes: 6 additions & 5 deletions cirq-core/cirq/circuits/circuit.py
Expand Up @@ -272,12 +272,15 @@ def __getitem__(self, key):
def __str__(self) -> str:
return self.to_text_diagram()

def __repr__(self) -> str:
cls_name = self.__class__.__name__
def _repr_args(self) -> str:
args = []
if self.moments:
args.append(_list_repr_with_indented_item_lines(self.moments))
return f'cirq.{cls_name}({", ".join(args)})'
return f'{", ".join(args)}'

def __repr__(self) -> str:
cls_name = self.__class__.__name__
return f'cirq.{cls_name}({self._repr_args()})'

def _repr_pretty_(self, p: Any, cycle: bool) -> None:
"""Print ASCII diagram in Jupyter."""
Expand Down Expand Up @@ -1791,7 +1794,6 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):

# "mop" means current moment-or-operation
for mop in ops.flatten_to_ops_or_moments(contents):

# Identify the index of the moment to place this `mop` into.
placement_index = get_earliest_accommodating_moment_index(
mop, qubit_indices, mkey_indices, ckey_indices, length
Expand Down Expand Up @@ -2450,7 +2452,6 @@ def _draw_moment_annotations(
first_annotation_row: int,
transpose: bool,
):

for k, annotation in enumerate(_get_moment_annotations(moment)):
args = protocols.CircuitDiagramInfoArgs(
known_qubits=(),
Expand Down
79 changes: 74 additions & 5 deletions cirq-core/cirq/circuits/frozen_circuit.py
Expand Up @@ -12,7 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An immutable version of the Circuit data structure."""
from typing import AbstractSet, FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
AbstractSet,
FrozenSet,
Hashable,
Iterable,
Iterator,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np

Expand All @@ -34,7 +44,10 @@ class FrozenCircuit(AbstractCircuit, protocols.SerializableByKey):
"""

def __init__(
self, *contents: 'cirq.OP_TREE', strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST
self,
*contents: 'cirq.OP_TREE',
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
tags: Sequence[Hashable] = (),
) -> None:
"""Initializes a frozen circuit.
Expand All @@ -47,9 +60,14 @@ def __init__(
strategy: When initializing the circuit with operations and moments
from `contents`, this determines how the operations are packed
together.
tags: A sequence of any type of object that is useful to attach metadata
to this circuit as long as the type is hashable. If you wish the
resulting circuit to be eventually serialized into JSON, you should
also restrict the tags to be JSON serializable.
"""
base = Circuit(contents, strategy=strategy)
self._moments = tuple(base.moments)
self._tags = tuple(tags)

@classmethod
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
Expand All @@ -61,10 +79,35 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

@property
def tags(self) -> Tuple[Hashable, ...]:
"""Returns a tuple of the Circuit's tags."""
return self._tags

@_compat.cached_property
def untagged(self) -> 'cirq.FrozenCircuit':
"""Returns the underlying FrozenCircuit without any tags."""
return self._from_moments(self._moments) if self.tags else self

def with_tags(self, *new_tags: Hashable) -> 'cirq.FrozenCircuit':
"""Creates a new tagged `FrozenCircuit` with `self.tags` and `new_tags` combined."""
if not new_tags:
return self
new_circuit = FrozenCircuit(tags=self.tags + new_tags)
new_circuit._moments = self._moments
return new_circuit

@_compat.cached_method
def __hash__(self) -> int:
# Explicitly cached for performance
return hash((self.moments,))
return hash((self.moments, self.tags))

def __eq__(self, other):
super_eq = super().__eq__(other)
if super_eq is not True:
return super_eq
other_tags = other.tags if isinstance(other, FrozenCircuit) else ()
return self.tags == other_tags

def __getstate__(self):
# Don't save hash when pickling; see #3777.
Expand Down Expand Up @@ -130,11 +173,23 @@ def all_measurement_key_names(self) -> FrozenSet[str]:

@_compat.cached_method
def _is_parameterized_(self) -> bool:
return super()._is_parameterized_()
return super()._is_parameterized_() or any(
protocols.is_parameterized(tag) for tag in self.tags
)

@_compat.cached_method
def _parameter_names_(self) -> AbstractSet[str]:
return super()._parameter_names_()
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
return super()._parameter_names_() | tag_params

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.FrozenCircuit':
resolved_circuit = super()._resolve_parameters_(resolver, recursive)
resolved_tags = [
protocols.resolve_parameters(tag, resolver, recursive) for tag in self.tags
]
return resolved_circuit.with_tags(*resolved_tags)

def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()
Expand All @@ -161,6 +216,20 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
except:
return NotImplemented

def _repr_args(self) -> str:
moments_repr = super()._repr_args()
tag_repr = ','.join(_compat.proper_repr(t) for t in self._tags)
return f'{moments_repr}, tags=[{tag_repr}]' if self.tags else moments_repr

def _json_dict_(self):
attribute_names = ['moments', 'tags'] if self.tags else ['moments']
ret = protocols.obj_to_dict_helper(self, attribute_names)
return ret

@classmethod
def _from_json_dict_(cls, moments, *, tags=(), **kwargs):
return cls(moments, strategy=InsertStrategy.EARLIEST, tags=tags)

def concat_ragged(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.FrozenCircuit':
Expand Down
32 changes: 32 additions & 0 deletions cirq-core/cirq/circuits/frozen_circuit_test.py
Expand Up @@ -17,6 +17,7 @@
"""

import pytest
import sympy

import cirq

Expand Down Expand Up @@ -74,3 +75,34 @@ def test_immutable():
match="(can't set attribute)|(property 'moments' of 'FrozenCircuit' object has no setter)",
):
c.moments = (cirq.Moment(cirq.H(q)), cirq.Moment(cirq.X(q)))


def test_tagged_circuits():
q = cirq.LineQubit(0)
ops = [cirq.X(q), cirq.H(q)]
tags = [sympy.Symbol("a"), "b"]
circuit = cirq.Circuit(ops)
frozen_circuit = cirq.FrozenCircuit(ops)
tagged_circuit = cirq.FrozenCircuit(ops, tags=tags)
# Test equality
assert tagged_circuit.tags == tuple(tags)
assert circuit == frozen_circuit != tagged_circuit
assert cirq.approx_eq(circuit, frozen_circuit)
assert cirq.approx_eq(frozen_circuit, tagged_circuit)
# Test hash
assert hash(frozen_circuit) != hash(tagged_circuit)
# Test _repr_ and _json_ round trips.
cirq.testing.assert_equivalent_repr(tagged_circuit)
cirq.testing.assert_json_roundtrip_works(tagged_circuit)
# Test utility methods and constructors
assert frozen_circuit.with_tags() is frozen_circuit
assert frozen_circuit.with_tags(*tags) == tagged_circuit
assert tagged_circuit.with_tags("c") == cirq.FrozenCircuit(ops, tags=[*tags, "c"])
assert tagged_circuit.untagged == frozen_circuit
assert frozen_circuit.untagged is frozen_circuit
# Test parameterized protocols
assert cirq.is_parameterized(frozen_circuit) is False
assert cirq.is_parameterized(tagged_circuit) is True
assert cirq.parameter_names(tagged_circuit) == {"a"}
# Tags are not propagated to diagrams yet.
assert str(frozen_circuit) == str(tagged_circuit)

0 comments on commit 041ce5d

Please sign in to comment.