From 2ee8706159246f3439b86913302e61eb4660c170 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 18 Jun 2022 22:56:36 -0700 Subject: [PATCH 1/6] frozenset --- cirq-core/cirq/circuits/circuit.py | 16 ++++--- cirq-core/cirq/circuits/circuit_operation.py | 24 +++++----- cirq-core/cirq/circuits/frozen_circuit.py | 32 +++++--------- cirq-core/cirq/circuits/moment.py | 5 +-- cirq-core/cirq/ops/gate_operation.py | 4 +- cirq-core/cirq/ops/raw_types.py | 6 +-- .../cirq/protocols/control_key_protocol.py | 19 +++++--- .../protocols/control_key_protocol_test.py | 11 ++++- .../protocols/measurement_key_protocol.py | 44 ++++++++++++------- .../measurement_key_protocol_test.py | 25 ++++++++--- 10 files changed, 111 insertions(+), 75 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index c6b00726a2c..f3906e2da35 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -918,10 +918,12 @@ def qid_shape( qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) return protocols.qid_shape(qids) - def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: - return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)} + def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: + return frozenset( + key for op in self.all_operations() for key in protocols.measurement_key_objs(op) + ) - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: """Returns the set of all measurement keys in this circuit. Returns: AbstractSet of `cirq.MeasurementKey` objects that are @@ -929,15 +931,17 @@ def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: """ return self.all_measurement_key_objs() - def all_measurement_key_names(self) -> AbstractSet[str]: + def all_measurement_key_names(self) -> FrozenSet[str]: """Returns the set of all measurement key names in this circuit. Returns: AbstractSet of strings that are the measurement key names in this circuit. """ - return {key for op in self.all_operations() for key in protocols.measurement_key_names(op)} + return frozenset( + key for op in self.all_operations() for key in protocols.measurement_key_names(op) + ) - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: return self.all_measurement_key_names() def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index f4049685204..a9b9f2365c1 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -111,10 +111,10 @@ class CircuitOperation(ops.Operation): """ _hash: Optional[int] = dataclasses.field(default=None, init=False) - _cached_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( + _cached_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = dataclasses.field( default=None, init=False ) - _cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( + _cached_control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = dataclasses.field( default=None, init=False ) _cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field( @@ -243,32 +243,34 @@ def _ensure_deterministic_loop_count(self): if self.repeat_until or isinstance(self.repetitions, sympy.Expr): raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) if circuit_keys and self.use_repetition_ids: self._ensure_deterministic_loop_count() if self.repetition_ids is not None: - circuit_keys = { + circuit_keys = frozenset( key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids for key in circuit_keys - } - circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} + ) + circuit_keys = frozenset( + key.with_key_path_prefix(*self.parent_path) for key in circuit_keys + ) object.__setattr__( self, '_cached_measurement_key_objs', - { + frozenset( protocols.with_measurement_key_mapping(key, self.measurement_key_map) for key in circuit_keys - }, + ), ) return self._cached_measurement_key_objs # type: ignore - def _measurement_key_names_(self) -> AbstractSet[str]: - return {str(key) for key in self._measurement_key_objs_()} + def _measurement_key_names_(self) -> FrozenSet[str]: + return frozenset(str(key) for key in self._measurement_key_objs_()) - def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: if self._cached_control_keys is None: keys = ( frozenset() diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 13548c2b51e..a354ff3a2d7 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -12,26 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """An immutable version of the Circuit data structure.""" -from typing import ( - TYPE_CHECKING, - AbstractSet, - FrozenSet, - Iterable, - Iterator, - Optional, - Sequence, - Tuple, - Union, -) - -from cirq.circuits import AbstractCircuit, Alignment, Circuit -from cirq.circuits.insert_strategy import InsertStrategy -from cirq.type_workarounds import NotImplementedType +from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union import numpy as np from cirq import ops, protocols - +from cirq.circuits import AbstractCircuit, Alignment, Circuit +from cirq.circuits.insert_strategy import InsertStrategy +from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: import cirq @@ -70,7 +58,7 @@ def __init__( self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None self._all_operations: Optional[Tuple[ops.Operation, ...]] = None self._has_measurements: Optional[bool] = None - self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None + self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None self._are_all_measurements_terminal: Optional[bool] = None self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None @@ -118,12 +106,12 @@ def has_measurements(self) -> bool: self._has_measurements = super().has_measurements() return self._has_measurements - def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: + def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: if self._all_measurement_key_objs is None: self._all_measurement_key_objs = super().all_measurement_key_objs() return self._all_measurement_key_objs - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return self.all_measurement_key_objs() def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: @@ -138,10 +126,10 @@ def are_all_measurements_terminal(self) -> bool: # End of memoized methods. - def all_measurement_key_names(self) -> AbstractSet[str]: - return {str(key) for key in self.all_measurement_key_objs()} + def all_measurement_key_names(self) -> FrozenSet[str]: + return frozenset(str(key) for key in self.all_measurement_key_objs()) - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: return self.all_measurement_key_names() def __add__(self, other) -> 'cirq.FrozenCircuit': diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index f2179a7c635..8fe7cf35c91 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -16,7 +16,6 @@ import itertools from typing import ( - AbstractSet, Any, Callable, Dict, @@ -238,8 +237,8 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): for op in self.operations ) - def _measurement_key_names_(self) -> AbstractSet[str]: - return {str(key) for key in self._measurement_key_objs_()} + def _measurement_key_names_(self) -> FrozenSet[str]: + return frozenset(str(key) for key in self._measurement_key_objs_()) def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: if self._measurement_key_objs is None: diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 95e69be6049..665636d5797 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -235,7 +235,7 @@ def _measurement_key_name_(self) -> Optional[str]: return getter() return NotImplemented - def _measurement_key_names_(self) -> Optional[AbstractSet[str]]: + def _measurement_key_names_(self) -> Optional[FrozenSet[str]]: getter = getattr(self.gate, '_measurement_key_names_', None) if getter is not None: return getter() @@ -247,7 +247,7 @@ def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']: return getter() return NotImplemented - def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]: + def _measurement_key_objs_(self) -> Optional[FrozenSet['cirq.MeasurementKey']]: getter = getattr(self.gate, '_measurement_key_objs_', None) if getter is not None: return getter() diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 63955a36b86..22b1c6b4839 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -820,10 +820,10 @@ def _has_kraus_(self) -> bool: def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]: return protocols.kraus(self.sub_operation, NotImplemented) - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: return protocols.measurement_key_names(self.sub_operation) - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return protocols.measurement_key_objs(self.sub_operation) def _is_measurement_(self) -> bool: @@ -905,7 +905,7 @@ def with_classical_controls( return self return self.sub_operation.with_classical_controls(*conditions) - def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: return protocols.control_keys(self.sub_operation) diff --git a/cirq-core/cirq/protocols/control_key_protocol.py b/cirq-core/cirq/protocols/control_key_protocol.py index ef39362eed7..5f8ffbce5f4 100644 --- a/cirq-core/cirq/protocols/control_key_protocol.py +++ b/cirq-core/cirq/protocols/control_key_protocol.py @@ -13,10 +13,11 @@ # limitations under the License. """Protocol for object that have control keys.""" -from typing import AbstractSet, Any, Iterable, TYPE_CHECKING +from typing import Any, FrozenSet, TYPE_CHECKING from typing_extensions import Protocol +from cirq import _compat from cirq._doc import doc_private from cirq.protocols import measurement_key_protocol @@ -34,7 +35,7 @@ class SupportsControlKey(Protocol): """ @doc_private - def _control_keys_(self) -> Iterable['cirq.MeasurementKey']: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: """Return the keys for controls referenced by the receiving object. Returns: @@ -43,7 +44,7 @@ def _control_keys_(self) -> Iterable['cirq.MeasurementKey']: """ -def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']: +def control_keys(val: Any) -> FrozenSet['cirq.MeasurementKey']: """Gets the keys that the value is classically controlled by. Args: @@ -56,12 +57,18 @@ def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']: getter = getattr(val, '_control_keys_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return set(result) + if not isinstance(result, FrozenSet): + _compat._warn_or_error( + f'The _control_keys_ implementation of {type(val)} must return a' + f' frozenset instead of {type(result)} by v0.16.' + ) + return frozenset(result) + return result - return set() + return frozenset() -def measurement_keys_touched(val: Any) -> AbstractSet['cirq.MeasurementKey']: +def measurement_keys_touched(val: Any) -> FrozenSet['cirq.MeasurementKey']: """Returns all the measurement keys used by the value. This would be the case if the value is or contains a measurement gate, or diff --git a/cirq-core/cirq/protocols/control_key_protocol_test.py b/cirq-core/cirq/protocols/control_key_protocol_test.py index 7abee9cea42..4b72aecf277 100644 --- a/cirq-core/cirq/protocols/control_key_protocol_test.py +++ b/cirq-core/cirq/protocols/control_key_protocol_test.py @@ -18,7 +18,7 @@ def test_control_key(): class Named: def _control_keys_(self): - return [cirq.MeasurementKey('key')] + return frozenset([cirq.MeasurementKey('key')]) class NoImpl: def _control_keys_(self): @@ -27,3 +27,12 @@ def _control_keys_(self): assert cirq.control_keys(Named()) == {cirq.MeasurementKey('key')} assert not cirq.control_keys(NoImpl()) assert not cirq.control_keys(5) + + +def test_control_key_enumerable_deprecated(): + class Deprecated: + def _control_keys_(self): + return [cirq.MeasurementKey('key')] + + with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'): + assert cirq.control_keys(Deprecated()) == {cirq.MeasurementKey('key')} diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index ac7fb637a13..26c2c36560d 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -13,11 +13,11 @@ # limitations under the License. """Protocol for object that have measurement keys.""" -from typing import AbstractSet, Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING from typing_extensions import Protocol -from cirq import value +from cirq import value, _compat from cirq._doc import doc_private if TYPE_CHECKING: @@ -68,7 +68,7 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': """ @doc_private - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: """Return the key objects for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -86,7 +86,7 @@ def _measurement_key_name_(self) -> str: """ @doc_private - def _measurement_key_names_(self) -> AbstractSet[str]: + def _measurement_key_names_(self) -> FrozenSet[str]: """Return the string keys for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -172,39 +172,51 @@ def measurement_key_name(val: Any, default: Any = RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( val: Any, -) -> Optional[AbstractSet['cirq.MeasurementKey']]: +) -> Optional[FrozenSet['cirq.MeasurementKey']]: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" getter = getattr(val, '_measurement_key_objs_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return set(result) + if not isinstance(result, FrozenSet): + _compat._warn_or_error( + f'The _control_keys_ implementation of {type(val)} must return a' + f' frozenset instead of {type(result)} by v0.16.' + ) + return frozenset(result) + return result getter = getattr(val, '_measurement_key_obj_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return {result} + return frozenset([result]) return result -def _measurement_key_names_from_magic_methods(val: Any) -> Optional[AbstractSet[str]]: +def _measurement_key_names_from_magic_methods(val: Any) -> Optional[FrozenSet[str]]: """Uses the measurement key related magic methods to get the key strings for this object.""" getter = getattr(val, '_measurement_key_names_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return set(result) + if not isinstance(result, FrozenSet): + _compat._warn_or_error( + f'The _control_keys_ implementation of {type(val)} must return a' + f' frozenset instead of {type(result)} by v0.16.' + ) + return frozenset(result) + return result getter = getattr(val, '_measurement_key_name_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: - return {result} + return frozenset([result]) return result -def measurement_key_objs(val: Any) -> AbstractSet['cirq.MeasurementKey']: +def measurement_key_objs(val: Any) -> FrozenSet['cirq.MeasurementKey']: """Gets the measurement key objects of measurements within the given value. Args: @@ -219,11 +231,11 @@ def measurement_key_objs(val: Any) -> AbstractSet['cirq.MeasurementKey']: return result key_strings = _measurement_key_names_from_magic_methods(val) if key_strings is not NotImplemented and key_strings is not None: - return {value.MeasurementKey.parse_serialized(key_str) for key_str in key_strings} - return set() + return frozenset(value.MeasurementKey.parse_serialized(key_str) for key_str in key_strings) + return frozenset() -def measurement_key_names(val: Any) -> AbstractSet[str]: +def measurement_key_names(val: Any) -> FrozenSet[str]: """Gets the measurement key strings of measurements within the given value. Args: @@ -244,8 +256,8 @@ def measurement_key_names(val: Any) -> AbstractSet[str]: return result key_objs = _measurement_key_objs_from_magic_methods(val) if key_objs is not NotImplemented and key_objs is not None: - return {str(key_obj) for key_obj in key_objs} - return set() + return frozenset(str(key_obj) for key_obj in key_objs) + return frozenset() def _is_measurement_from_magic_method(val: Any) -> Optional[bool]: diff --git a/cirq-core/cirq/protocols/measurement_key_protocol_test.py b/cirq-core/cirq/protocols/measurement_key_protocol_test.py index 38398e9b880..29731db5341 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol_test.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol_test.py @@ -158,10 +158,10 @@ def num_qubits(self) -> int: def test_measurement_keys(key_method, keys): class MeasurementKeysGate(cirq.Gate): def _measurement_key_names_(self): - return ['a', 'b'] + return frozenset(['a', 'b']) def _measurement_key_objs_(self): - return [cirq.MeasurementKey('c'), cirq.MeasurementKey('d')] + return frozenset([cirq.MeasurementKey('c'), cirq.MeasurementKey('d')]) def num_qubits(self) -> int: return 1 @@ -183,7 +183,7 @@ def num_qubits(self) -> int: def test_measurement_key_mapping(): class MultiKeyGate: def __init__(self, keys): - self._keys = set(keys) + self._keys = frozenset(keys) def _measurement_key_names_(self): return self._keys @@ -220,10 +220,10 @@ def _with_measurement_key_mapping_(self, key_map): def test_measurement_key_path(): class MultiKeyGate: def __init__(self, keys): - self._keys = set([cirq.MeasurementKey.parse_serialized(key) for key in keys]) + self._keys = frozenset([cirq.MeasurementKey.parse_serialized(key) for key in keys]) def _measurement_key_names_(self): - return {str(key) for key in self._keys} + return frozenset([str(key) for key in self._keys]) def _with_key_path_(self, path): return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys]) @@ -238,3 +238,18 @@ def _with_key_path_(self, path): assert cirq.measurement_key_names(mkg_cd) == {'c:d:a', 'c:d:b'} assert cirq.with_key_path(cirq.X, ('c', 'd')) is NotImplemented + + +def test_measurement_key_enumerable_deprecated(): + class Deprecated: + def _measurement_key_objs_(self): + return [cirq.MeasurementKey('key')] + + def _measurement_key_names_(self): + return ['key'] + + with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'): + assert cirq.measurement_key_objs(Deprecated()) == {cirq.MeasurementKey('key')} + + with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'): + assert cirq.measurement_key_names(Deprecated()) == {'key'} From fb81f404875ddbb4d32c136c6649f1f272345c1b Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sat, 18 Jun 2022 22:59:37 -0700 Subject: [PATCH 2/6] names --- cirq-core/cirq/protocols/measurement_key_protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 26c2c36560d..6032827d291 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -181,7 +181,7 @@ def _measurement_key_objs_from_magic_methods( if result is not NotImplemented and result is not None: if not isinstance(result, FrozenSet): _compat._warn_or_error( - f'The _control_keys_ implementation of {type(val)} must return a' + f'The _measurement_key_objs_ implementation of {type(val)} must return a' f' frozenset instead of {type(result)} by v0.16.' ) return frozenset(result) @@ -202,7 +202,7 @@ def _measurement_key_names_from_magic_methods(val: Any) -> Optional[FrozenSet[st if result is not NotImplemented and result is not None: if not isinstance(result, FrozenSet): _compat._warn_or_error( - f'The _control_keys_ implementation of {type(val)} must return a' + f'The _measurement_key_names_ implementation of {type(val)} must return a' f' frozenset instead of {type(result)} by v0.16.' ) return frozenset(result) From 54b9a33bbe2a41eaad6832b466d51d37c432c4f2 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 21 Jun 2022 15:15:28 -0700 Subject: [PATCH 3/6] Update docstrings, remove intermediate lists --- cirq-core/cirq/circuits/circuit.py | 4 ++-- cirq-core/cirq/protocols/measurement_key_protocol_test.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index f3906e2da35..0d80d861719 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -926,7 +926,7 @@ def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: """Returns the set of all measurement keys in this circuit. - Returns: AbstractSet of `cirq.MeasurementKey` objects that are + Returns: FrozenSet of `cirq.MeasurementKey` objects that are in this circuit. """ return self.all_measurement_key_objs() @@ -934,7 +934,7 @@ def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: def all_measurement_key_names(self) -> FrozenSet[str]: """Returns the set of all measurement key names in this circuit. - Returns: AbstractSet of strings that are the measurement key + Returns: FrozenSet of strings that are the measurement key names in this circuit. """ return frozenset( diff --git a/cirq-core/cirq/protocols/measurement_key_protocol_test.py b/cirq-core/cirq/protocols/measurement_key_protocol_test.py index 29731db5341..9b24681ba26 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol_test.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol_test.py @@ -220,10 +220,10 @@ def _with_measurement_key_mapping_(self, key_map): def test_measurement_key_path(): class MultiKeyGate: def __init__(self, keys): - self._keys = frozenset([cirq.MeasurementKey.parse_serialized(key) for key in keys]) + self._keys = frozenset(cirq.MeasurementKey.parse_serialized(key) for key in keys) def _measurement_key_names_(self): - return frozenset([str(key) for key in self._keys]) + return frozenset(str(key) for key in self._keys) def _with_key_path_(self, path): return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys]) From fce7980b115abd3866b5df2fecd880763eacfa4a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Jun 2022 22:16:41 -0700 Subject: [PATCH 4/6] Fix typings --- cirq-core/cirq/ops/gate_operation.py | 6 ++++-- cirq-core/cirq/protocols/control_key_protocol.py | 5 +++-- .../cirq/protocols/measurement_key_protocol.py | 15 ++++++++++----- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 665636d5797..f17612b2c49 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -235,7 +235,7 @@ def _measurement_key_name_(self) -> Optional[str]: return getter() return NotImplemented - def _measurement_key_names_(self) -> Optional[FrozenSet[str]]: + def _measurement_key_names_(self) -> Union[FrozenSet[str], NotImplementedType, None]: getter = getattr(self.gate, '_measurement_key_names_', None) if getter is not None: return getter() @@ -247,7 +247,9 @@ def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']: return getter() return NotImplemented - def _measurement_key_objs_(self) -> Optional[FrozenSet['cirq.MeasurementKey']]: + def _measurement_key_objs_( + self, + ) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: getter = getattr(self.gate, '_measurement_key_objs_', None) if getter is not None: return getter() diff --git a/cirq-core/cirq/protocols/control_key_protocol.py b/cirq-core/cirq/protocols/control_key_protocol.py index 5f8ffbce5f4..f8897734918 100644 --- a/cirq-core/cirq/protocols/control_key_protocol.py +++ b/cirq-core/cirq/protocols/control_key_protocol.py @@ -13,13 +13,14 @@ # limitations under the License. """Protocol for object that have control keys.""" -from typing import Any, FrozenSet, TYPE_CHECKING +from typing import Any, FrozenSet, TYPE_CHECKING, Union from typing_extensions import Protocol from cirq import _compat from cirq._doc import doc_private from cirq.protocols import measurement_key_protocol +from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: import cirq @@ -35,7 +36,7 @@ class SupportsControlKey(Protocol): """ @doc_private - def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: + def _control_keys_(self) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: """Return the keys for controls referenced by the receiving object. Returns: diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 6032827d291..e1164298361 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -13,12 +13,13 @@ # limitations under the License. """Protocol for object that have measurement keys.""" -from typing import Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import Protocol from cirq import value, _compat from cirq._doc import doc_private +from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: import cirq @@ -68,7 +69,9 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': """ @doc_private - def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: + def _measurement_key_objs_( + self, + ) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: """Return the key objects for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -86,7 +89,7 @@ def _measurement_key_name_(self) -> str: """ @doc_private - def _measurement_key_names_(self) -> FrozenSet[str]: + def _measurement_key_names_(self) -> Union[FrozenSet[str], NotImplementedType, None]: """Return the string keys for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -172,7 +175,7 @@ def measurement_key_name(val: Any, default: Any = RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( val: Any, -) -> Optional[FrozenSet['cirq.MeasurementKey']]: +) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" @@ -194,7 +197,9 @@ def _measurement_key_objs_from_magic_methods( return result -def _measurement_key_names_from_magic_methods(val: Any) -> Optional[FrozenSet[str]]: +def _measurement_key_names_from_magic_methods( + val: Any, +) -> Union[FrozenSet[str], NotImplementedType, None]: """Uses the measurement key related magic methods to get the key strings for this object.""" getter = getattr(val, '_measurement_key_names_', None) From 631b25a039216c46dbc31879b68164582604b14d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Jun 2022 22:24:36 -0700 Subject: [PATCH 5/6] Fix circuitop merge --- cirq-core/cirq/circuits/circuit_operation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 853244eadd4..ddfad400912 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -19,7 +19,6 @@ """ import math from typing import ( - AbstractSet, Callable, Mapping, Sequence, @@ -309,7 +308,7 @@ def _ensure_deterministic_loop_count(self): raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') @cached_property - def _measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: circuit_keys = protocols.measurement_key_objs(self.circuit) if circuit_keys and self.use_repetition_ids: self._ensure_deterministic_loop_count() @@ -320,19 +319,19 @@ def _measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: for key in circuit_keys } circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} - return { + return frozenset( protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map)) for key in circuit_keys - } + ) - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']: return self._measurement_key_objs def _measurement_key_names_(self) -> FrozenSet[str]: return frozenset(str(key) for key in self._measurement_key_objs_()) @cached_property - def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys(self) -> FrozenSet['cirq.MeasurementKey']: keys = ( frozenset() if not protocols.control_keys(self.circuit) @@ -342,13 +341,13 @@ def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']: keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() return keys - def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: return self._control_keys def _is_parameterized_(self) -> bool: return any(self._parameter_names_generator()) - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> FrozenSet[str]: return frozenset(self._parameter_names_generator()) def _parameter_names_generator(self) -> Iterator[str]: @@ -463,7 +462,7 @@ def __str__(self): ) args = [] - def dict_str(d: Dict) -> str: + def dict_str(d: Mapping) -> str: pairs = [f'{k}: {v}' for k, v in sorted(d.items())] return '{' + ', '.join(pairs) + '}' From 51f82ed01eb138f3e41e7a8415e827e58ee97919 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 23 Jun 2022 22:26:02 -0700 Subject: [PATCH 6/6] mypy --- cirq-core/cirq/circuits/circuit_operation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index ddfad400912..46036798d9b 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -313,12 +313,14 @@ def _measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']: if circuit_keys and self.use_repetition_ids: self._ensure_deterministic_loop_count() if self.repetition_ids is not None: - circuit_keys = { + circuit_keys = frozenset( key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids for key in circuit_keys - } - circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} + ) + circuit_keys = frozenset( + key.with_key_path_prefix(*self.parent_path) for key in circuit_keys + ) return frozenset( protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map)) for key in circuit_keys