Skip to content

Commit

Permalink
Use MeasurementKey in CircuitOperation (#4086)
Browse files Browse the repository at this point in the history
* Add the concept of measurement key path and use it for nested/repeated CircuitOperations. Also add `with_key_path` protocol.

* Format and docstrings

* json and other fixes

* Change to immutable default for pylint

* Make full_join_string_lists private
  • Loading branch information
smitsanghavi committed May 11, 2021
1 parent 6f9eedc commit b4c445f
Show file tree
Hide file tree
Showing 26 changed files with 620 additions and 135 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@
Duration,
DURATION_LIKE,
LinearDict,
MEASUREMENT_KEY_SEPARATOR,
MeasurementKey,
PeriodicValue,
RANDOM_STATE_OR_SEED_LIKE,
Expand Down Expand Up @@ -531,6 +532,7 @@
trace_distance_from_angle_list,
unitary,
validate_mixture,
with_key_path,
with_measurement_key_mapping,
)

Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,11 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
[protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments]
)

def _with_key_path_(self, path: Tuple[str, ...]):
return self._with_sliced_moments(
[protocols.with_key_path(moment, path) for moment in self.moments]
)

def _qid_shape_(self) -> Tuple[int, ...]:
return self.qid_shape()

Expand Down
117 changes: 41 additions & 76 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
import dataclasses
import numpy as np

from cirq import circuits, ops, protocols, study
from cirq import circuits, ops, protocols, value, study
from cirq._compat import proper_repr

if TYPE_CHECKING:
import cirq


INT_TYPE = Union[int, np.integer]
MEASUREMENT_KEY_SEPARATOR = ':'
REPETITION_ID_SEPARATOR = '-'


def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
Expand All @@ -40,40 +40,18 @@ def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
return None


def cartesian_product_of_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
def _full_join_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
if list1 is None and list2 is None:
return None # coverage: ignore
if list1 is None:
return list2 # coverage: ignore
if list2 is None:
return list1
return [
f'{MEASUREMENT_KEY_SEPARATOR.join([first, second])}' for first in list1 for second in list2
f'{REPETITION_ID_SEPARATOR.join([first, second])}' for first in list1 for second in list2
]


def split_maybe_indexed_key(maybe_indexed_key: str) -> List[str]:
"""Given a measurement_key, splits into index (series of repetition_ids) and unindexed key
parts. For a key without index, returns the unaltered key in a list. Assumes that the
unindexed measurement key does not contain the MEASUREMENT_KEY_SEPARATOR. This is validated by
the `CircuitOperation` constructor."""
return maybe_indexed_key.rsplit(MEASUREMENT_KEY_SEPARATOR, maxsplit=1)


def get_unindexed_key(maybe_indexed_key: str) -> str:
"""Given a measurement_key, returns the unindexed key part (without the series of prefixed
repetition_ids). For an already unindexed key, returns the unaltered key."""
return split_maybe_indexed_key(maybe_indexed_key)[-1]


def remap_maybe_indexed_key(key_map: Dict[str, str], key: str) -> str:
"""Given a key map and a measurement_key (indexed or unindexed), returns the remapped key in
the same format. Does not modify the index (series of repetition_ids) part, if it exists."""
split_key = split_maybe_indexed_key(key)
split_key[-1] = key_map.get(split_key[-1], split_key[-1])
return MEASUREMENT_KEY_SEPARATOR.join(split_key)


@dataclasses.dataclass(frozen=True)
class CircuitOperation(ops.Operation):
"""An operation that encapsulates a circuit.
Expand All @@ -90,6 +68,7 @@ class CircuitOperation(ops.Operation):
The keys and values should be unindexed (i.e. without repetition_ids).
The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`.
param_resolver: Resolved values for parameters in the circuit.
parent_path: A tuple of identifiers for any parent CircuitOperations containing this one.
repetition_ids: List of identifiers for each repetition of the
CircuitOperation. If populated, the length should be equal to the
repetitions. If not populated and abs(`repetitions`) > 1, it is
Expand All @@ -104,6 +83,7 @@ class CircuitOperation(ops.Operation):
measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict)
param_resolver: study.ParamResolver = study.ParamResolver()
repetition_ids: Optional[List[str]] = dataclasses.field(default=None)
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)

def __post_init__(self):
if not isinstance(self.circuit, circuits.FrozenCircuit):
Expand All @@ -128,27 +108,12 @@ def __post_init__(self):

# Disallow mapping to keys containing the `MEASUREMENT_KEY_SEPARATOR`
for mapped_key in self.measurement_key_map.values():
if MEASUREMENT_KEY_SEPARATOR in mapped_key:
if value.MEASUREMENT_KEY_SEPARATOR in mapped_key:
raise ValueError(
f'Mapping to invalid key: {mapped_key}. "{MEASUREMENT_KEY_SEPARATOR}" '
f'Mapping to invalid key: {mapped_key}. "{value.MEASUREMENT_KEY_SEPARATOR}" '
'is not allowed for measurement keys in a CircuitOperation'
)

# Validate the keys for all direct child measurements. They are not allowed to contain
# `MEASUREMENT_KEY_SEPARATOR`
for _, op in self.circuit.findall_operations(
lambda op: not isinstance(op, CircuitOperation) and protocols.is_measurement(op)
):
for key in protocols.measurement_keys(op):
key = self.measurement_key_map.get(key, key)
if MEASUREMENT_KEY_SEPARATOR in key:
raise ValueError(
f'Measurement {op} found to have invalid key: {key}. '
f'"{MEASUREMENT_KEY_SEPARATOR}" is not allowed for measurement keys '
'in a CircuitOperation. Consider remapping the key using '
'`measurement_key_map` in the CircuitOperation constructor.'
)

# Disallow qid mapping dimension conflicts.
for q, q_new in self.qubit_map.items():
if q_new.dimension != q.dimension:
Expand Down Expand Up @@ -178,6 +143,7 @@ def __eq__(self, other) -> bool:
and self.param_resolver == other.param_resolver
and self.repetitions == other.repetitions
and self.repetition_ids == other.repetition_ids
and self.parent_path == other.parent_path
)

# Methods for getting post-mapping properties of the contained circuit.
Expand All @@ -195,12 +161,20 @@ def _qid_shape_(self) -> Tuple[int, ...]:
return tuple(q.dimension for q in self.qubits)

def _measurement_keys_(self) -> AbstractSet[str]:
circuit_keys = self.circuit.all_measurement_keys()
circuit_keys = [
value.MeasurementKey.parse_serialized(key_str)
for key_str in self.circuit.all_measurement_keys()
]
if self.repetition_ids is not None:
circuit_keys = cartesian_product_of_string_lists(
self.repetition_ids, list(circuit_keys)
)
return {remap_maybe_indexed_key(self.measurement_key_map, key) for key in circuit_keys}
circuit_keys = [
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
]
return {
str(protocols.with_measurement_key_mapping(key, self.measurement_key_map))
for key in circuit_keys
}

def _parameter_names_(self) -> AbstractSet[str]:
return {
Expand All @@ -225,32 +199,9 @@ def _decompose_(self) -> 'cirq.OP_TREE':
# If it's a measurement circuit with repetitions/repetition_ids, prefix the repetition_ids
# to measurements. Details at https://tinyurl.com/measurement-repeated-circuitop.
ops = [] # type: List[cirq.Operation]
for parent_id in self.repetition_ids:
for op in result.all_operations():
if isinstance(op, CircuitOperation):
# For a CircuitOperation, prefix the current repetition_id to the children
# repetition_ids.
ops.append(
op.with_repetition_ids(
# If `op.repetition_ids` is None, this will return `[parent_id]`.
cartesian_product_of_string_lists([parent_id], op.repetition_ids)
)
)
elif protocols.is_measurement(op):
# For a non-CircuitOperation measurement, prefix the current repetition_id
# to the children measurement keys. Implemented by creating a mapping and
# using the with_measurement_key_mapping protocol.
ops.append(
protocols.with_measurement_key_mapping(
op,
key_map={
key: f'{MEASUREMENT_KEY_SEPARATOR.join([parent_id, key])}'
for key in protocols.measurement_keys(op)
},
)
)
else:
ops.append(op)
for repetition_id in self.repetition_ids:
path = self.parent_path + (repetition_id,)
ops += protocols.with_key_path(result, path).all_operations()
return ops

# Methods for string representation of the operation.
Expand All @@ -265,6 +216,8 @@ def __repr__(self):
args += f'measurement_key_map={proper_repr(self.measurement_key_map)},\n'
if self.param_resolver:
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
if self.parent_path:
args += f'parent_path={proper_repr(self.parent_path)},\n'
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
Expand All @@ -291,6 +244,8 @@ def dict_str(d: Dict) -> str:
args.append(f'key_map={dict_str(self.measurement_key_map)}')
if self.param_resolver:
args.append(f'params={self.param_resolver.param_dict}')
if self.parent_path:
args.append(f'parent_path={self.parent_path}')
if self.repetition_ids != self._default_repetition_ids():
# Default repetition_ids need not be specified.
args.append(f'repetition_ids={self.repetition_ids}')
Expand All @@ -313,6 +268,7 @@ def __hash__(self):
frozenset(self.qubit_map.items()),
frozenset(self.measurement_key_map.items()),
self.param_resolver,
self.parent_path,
tuple([] if self.repetition_ids is None else self.repetition_ids),
)
),
Expand All @@ -330,6 +286,7 @@ def _json_dict_(self):
'measurement_key_map': self.measurement_key_map,
'param_resolver': self.param_resolver,
'repetition_ids': self.repetition_ids,
'parent_path': self.parent_path,
}

@classmethod
Expand All @@ -341,13 +298,15 @@ def _from_json_dict_(
measurement_key_map,
param_resolver,
repetition_ids,
parent_path=(),
**kwargs,
):
return (
cls(circuit)
.with_qubit_mapping(dict(qubit_map))
.with_measurement_key_mapping(measurement_key_map)
.with_params(param_resolver)
.with_key_path(tuple(parent_path))
.repeat(repetitions, repetition_ids)
)

Expand Down Expand Up @@ -408,13 +367,19 @@ def repeat(
)

# If `self.repetition_ids` is None, this will just return `repetition_ids`.
repetition_ids = cartesian_product_of_string_lists(repetition_ids, self.repetition_ids)
repetition_ids = _full_join_string_lists(repetition_ids, self.repetition_ids)

return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)

def __pow__(self, power: int) -> 'CircuitOperation':
return self.repeat(power)

def _with_key_path_(self, path: Tuple[str, ...]):
return dataclasses.replace(self, parent_path=path)

def with_key_path(self, path: Tuple[str, ...]):
return self._with_key_path_(path)

def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
return self.replace(repetition_ids=repetition_ids)

Expand Down Expand Up @@ -501,7 +466,7 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
"""
new_map = {}
for k in self.circuit.all_measurement_keys():
k = get_unindexed_key(k)
k = value.MeasurementKey.parse_serialized(k).name
k_new = self.measurement_key_map.get(k, k)
k_new = key_map.get(k_new, k_new)
if k_new != k:
Expand Down

0 comments on commit b4c445f

Please sign in to comment.