Skip to content

Commit

Permalink
Rename measurement_key protocol to measurement_key_str (#4403)
Browse files Browse the repository at this point in the history
* Rename measurement_key protocol

* restrict changes to cirq-google

* Revert "restrict changes to cirq-google"

This reverts commit 4554226.

* Fix a wrong rename

* Rename protocols with _name instead of _str. Add tests and support backward compatible magic methods

* classes don't need to fallback to deprecated methods anymore

* Keep gate operation fallback and add warnings
  • Loading branch information
smitsanghavi committed Aug 16, 2021
1 parent 5a248eb commit 3690440
Show file tree
Hide file tree
Showing 33 changed files with 214 additions and 132 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,9 @@
json_serializable_dataclass,
kraus,
measurement_key,
measurement_key_name,
measurement_keys,
measurement_key_names,
mixture,
mul,
num_qubits,
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class AbstractCircuit(abc.ABC):
* to_text_diagram
* to_text_diagram_drawer
* qid_shape
* all_measurement_keys
* all_measurement_key_names
* to_quil
* to_qasm
* save_qasm
Expand Down Expand Up @@ -908,8 +908,8 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_keys(self) -> AbstractSet[str]:
return protocols.measurement_keys(self)
def all_measurement_key_names(self) -> AbstractSet[str]:
return protocols.measurement_key_names(self)

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
return self._with_sliced_moments(
Expand Down Expand Up @@ -1576,7 +1576,7 @@ class Circuit(AbstractCircuit):
* to_text_diagram
* to_text_diagram_drawer
* qid_shape
* all_measurement_keys
* all_measurement_key_names
* to_quil
* to_qasm
* save_qasm
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def _is_measurement_(self) -> bool:
return self.circuit._is_measurement_()

def _measurement_keys_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> AbstractSet[str]:
circuit_keys = [
value.MeasurementKey.parse_serialized(key_str)
for key_str in self.circuit.all_measurement_keys()
for key_str in self.circuit.all_measurement_key_names()
]
if self.repetition_ids is not None:
circuit_keys = [
Expand Down Expand Up @@ -509,14 +509,14 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
keys than this operation.
"""
new_map = {}
for k in self.circuit.all_measurement_keys():
for k in self.circuit.all_measurement_key_names():
k = value.MeasurementKey.parse_serialized(k).name
k_new = self.measurement_key_map.get(k, k)
k_new = key_map.get(k_new, k_new)
if k_new != k:
new_map[k] = k_new
new_op = self.replace(measurement_key_map=new_map)
if len(new_op._measurement_keys_()) != len(self._measurement_keys_()):
if len(new_op._measurement_key_names_()) != len(self._measurement_key_names_()):
raise ValueError(
f'Collision in measurement key map composition. Original map:\n'
f'{self.measurement_key_map}\nApplied changes: {key_map}'
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_with_measurement_keys():
op_with_keys = op_base.with_measurement_key_mapping({'ma': 'pa', 'x': 'z'})
assert op_with_keys.base_operation() == op_base
assert op_with_keys.measurement_key_map == {'ma': 'pa'}
assert cirq.measurement_keys(op_with_keys) == {'pa', 'mb'}
assert cirq.measurement_key_names(op_with_keys) == {'pa', 'mb'}

assert cirq.with_measurement_key_mapping(op_base, {'ma': 'pa'}) == op_with_keys

Expand Down Expand Up @@ -745,14 +745,14 @@ def test_decompose_repeated_nested_measurements():
'one:one:zero:D',
'one:one:one:D',
]
assert cirq.measurement_keys(op3) == set(expected_measurement_keys_in_order)
assert cirq.measurement_key_names(op3) == set(expected_measurement_keys_in_order)

expected_circuit = cirq.Circuit()
for key in expected_measurement_keys_in_order:
expected_circuit.append(cirq.measure(a, key=cirq.MeasurementKey.parse_serialized(key)))

assert cirq.Circuit(cirq.decompose(op3)) == expected_circuit
assert cirq.measurement_keys(expected_circuit) == set(expected_measurement_keys_in_order)
assert cirq.measurement_key_names(expected_circuit) == set(expected_measurement_keys_in_order)

# Verify that mapped_circuit gives the same operations.
assert op3.mapped_circuit(deep=True) == expected_circuit
Expand Down
25 changes: 14 additions & 11 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3861,16 +3861,19 @@ def test_measurement_key_mapping(circuit_cls):
cirq.measure(a, key='m1'),
cirq.measure(b, key='m2'),
)
assert c.all_measurement_keys() == {'m1', 'm2'}
assert c.all_measurement_key_names() == {'m1', 'm2'}

assert cirq.with_measurement_key_mapping(c, {'m1': 'p1'}).all_measurement_keys() == {'p1', 'm2'}
assert cirq.with_measurement_key_mapping(c, {'m1': 'p1'}).all_measurement_key_names() == {
'p1',
'm2',
}

assert cirq.with_measurement_key_mapping(
c, {'m1': 'p1', 'm2': 'p2'}
).all_measurement_keys() == {'p1', 'p2'}
).all_measurement_key_names() == {'p1', 'p2'}

c_swapped = cirq.with_measurement_key_mapping(c, {'m1': 'm2', 'm2': 'm1'})
assert c_swapped.all_measurement_keys() == {'m1', 'm2'}
assert c_swapped.all_measurement_key_names() == {'m1', 'm2'}

# Verify that the keys were actually swapped.
simulator = cirq.Simulator()
Expand All @@ -3883,7 +3886,7 @@ def test_measurement_key_mapping(circuit_cls):
{
'x': 'z',
},
).all_measurement_keys()
).all_measurement_key_names()
== {'m1', 'm2'}
)

Expand Down Expand Up @@ -4288,9 +4291,9 @@ def test_indexing_by_numpy_integer(circuit_cls):


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_all_measurement_keys(circuit_cls):
def test_all_measurement_key_names(circuit_cls):
class Unknown(cirq.SingleQubitGate):
def _measurement_key_(self):
def _measurement_key_name_(self):
return 'test'

a, b = cirq.LineQubit.range(2)
Expand All @@ -4305,10 +4308,10 @@ def _measurement_key_(self):
)

# Big case.
assert c.all_measurement_keys() == {'x', 'y', 'xy', 'test'}
assert c.all_measurement_key_names() == {'x', 'y', 'xy', 'test'}

# Empty case.
assert circuit_cls().all_measurement_keys() == set()
assert circuit_cls().all_measurement_key_names() == set()

# Order does not matter.
assert (
Expand All @@ -4319,7 +4322,7 @@ def _measurement_key_(self):
cirq.measure(b, key='y'),
]
)
).all_measurement_keys()
).all_measurement_key_names()
== {'x', 'y'}
)
assert (
Expand All @@ -4330,7 +4333,7 @@ def _measurement_key_(self):
cirq.measure(a, key='x'),
]
)
).all_measurement_keys()
).all_measurement_key_names()
== {'x', 'y'}
)

Expand Down
10 changes: 5 additions & 5 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_keys: Optional[AbstractSet[str]] = None
self._all_measurement_key_names: Optional[AbstractSet[str]] = None
self._are_all_measurements_terminal: Optional[bool] = None

@property
Expand Down Expand Up @@ -130,10 +130,10 @@ def has_measurements(self) -> bool:
self._has_measurements = super().has_measurements()
return self._has_measurements

def all_measurement_keys(self) -> AbstractSet[str]:
if self._all_measurement_keys is None:
self._all_measurement_keys = super().all_measurement_keys()
return self._all_measurement_keys
def all_measurement_key_names(self) -> AbstractSet[str]:
if self._all_measurement_key_names is None:
self._all_measurement_key_names = super().all_measurement_key_names()
return self._all_measurement_key_names

def are_all_measurements_terminal(self) -> bool:
if self._are_all_measurements_terminal is None:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/circuits/qasm_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[
meas_comments = {} # type: Dict[str, Optional[str]]
meas_i = 0
for meas in self.measurements:
key = protocols.measurement_key(meas)
key = protocols.measurement_key_name(meas)
if key in meas_key_id_map:
continue
meas_id = f'm_{key}'
Expand Down Expand Up @@ -274,7 +274,7 @@ def output(text):
# Pick an id for the creg that will store each measurement
already_output_keys: Set[str] = set()
for meas in self.measurements:
key = protocols.measurement_key(meas)
key = protocols.measurement_key_name(meas)
if key in already_output_keys:
continue
already_output_keys.add(key)
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/circuits/quil_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _generate_measurement_ids(self) -> Dict[str, str]:
measurement_id_map: Dict[str, str] = {}
for op in self.operations:
if isinstance(op.gate, ops.MeasurementGate):
key = protocols.measurement_key(op)
key = protocols.measurement_key_name(op)
if key in measurement_id_map:
continue
measurement_id_map[key] = f'm{index}'
Expand All @@ -152,7 +152,7 @@ def _write_quil(self, output_func: Callable[[str], None]) -> None:
if len(self.measurements) > 0:
measurements_declared: Set[str] = set()
for m in self.measurements:
key = protocols.measurement_key(m)
key = protocols.measurement_key_name(m)
if key in measurements_declared:
continue
measurements_declared.add(key)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _verify_unique_measurement_keys(operations: Iterable[ops.Operation]):
for op in operations:
if isinstance(op.gate, ops.MeasurementGate):
meas = op.gate
key = protocols.measurement_key(meas)
key = protocols.measurement_key_name(meas)
if key in seen:
raise ValueError(f'Measurement key {key} repeated')
seen.add(key)
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/gate_features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def test_qasm_output_args_format():
assert args.format('_{0}_', a) == '_aaa[0]_'
assert args.format('_{0}_', b) == '_bbb[0]_'

assert args.format('_{0:meas}_', cirq.measurement_key(m_a)) == '_m_a_'
assert args.format('_{0:meas}_', cirq.measurement_key(m_b)) == '_m_b_'
assert args.format('_{0:meas}_', cirq.measurement_key_name(m_a)) == '_m_a_'
assert args.format('_{0:meas}_', cirq.measurement_key_name(m_b)) == '_m_b_'

assert args.format('_{0}_', 89.1234567) == '_89.1235_'
assert args.format('_{0}_', 1.23) == '_1.23_'
Expand Down
21 changes: 19 additions & 2 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import numpy as np

from cirq import protocols, value
from cirq._compat import _warn_or_error
from cirq.ops import raw_types, gate_features
from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -229,15 +230,31 @@ def _is_measurement_(self) -> Optional[bool]:
# Let the protocol handle the fallback.
return NotImplemented

def _measurement_key_(self) -> Optional[str]:
def _measurement_key_name_(self) -> Optional[str]:
getter = getattr(self.gate, '_measurement_key_name_', None)
if getter is not None:
return getter()
getter = getattr(self.gate, '_measurement_key_', None)
if getter is not None:
_warn_or_error(
f'_measurement_key_ was used but is deprecated.\n'
f'It will be removed in cirq v0.13.\n'
f'Use _measurement_key_name_ instead.\n'
)
return getter()
return NotImplemented

def _measurement_keys_(self) -> Optional[Iterable[str]]:
def _measurement_key_names_(self) -> Optional[Iterable[str]]:
getter = getattr(self.gate, '_measurement_key_names_', None)
if getter is not None:
return getter()
getter = getattr(self.gate, '_measurement_keys_', None)
if getter is not None:
_warn_or_error(
f'_measurement_keys_ was used but is deprecated.\n'
f'It will be removed in cirq v0.13.\n'
f'Use _measurement_key_names_ instead.\n'
)
return getter()
return NotImplemented

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_channel():

def test_measurement_key():
a = cirq.NamedQubit('a')
assert cirq.measurement_key(cirq.measure(a, key='lock')) == 'lock'
assert cirq.measurement_key_name(cirq.measure(a, key='lock')) == 'lock'


def assert_mixtures_equal(actual, expected):
Expand Down Expand Up @@ -387,7 +387,7 @@ def test_with_measurement_key_mapping():
op = cirq.measure(a, key='m')

remap_op = cirq.with_measurement_key_mapping(op, {'m': 'k'})
assert cirq.measurement_keys(remap_op) == {'k'}
assert cirq.measurement_key_names(remap_op) == {'k'}
assert cirq.with_measurement_key_mapping(op, {'x': 'k'}) is op


Expand All @@ -396,7 +396,7 @@ def test_with_key_path():
op = cirq.measure(a, key='m')

remap_op = cirq.with_key_path(op, ('a', 'b'))
assert cirq.measurement_keys(remap_op) == {'a:b:m'}
assert cirq.measurement_key_names(remap_op) == {'a:b:m'}
assert cirq.with_key_path(remap_op, ('a', 'b')) is remap_op

assert cirq.with_key_path(op, tuple()) is op
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def num_qubits(self) -> int:
def _kraus_(self):
return self._kraus_ops

def _measurement_key_(self):
def _measurement_key_name_(self):
if self._key is None:
return NotImplemented
return self._key
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/ops/kraus_channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_kraus_channel_from_channel():
q0 = cirq.LineQubit(0)
dp = cirq.depolarize(0.1)
kc = cirq.KrausChannel.from_channel(dp, key='dp')
assert cirq.measurement_key(kc) == 'dp'
assert cirq.measurement_key_name(kc) == 'dp'

circuit = cirq.Circuit(kc.on(q0))
sim = cirq.Simulator(seed=0)
Expand Down Expand Up @@ -48,12 +48,12 @@ def test_kraus_channel_remap_keys():
dp = cirq.depolarize(0.1)
kc = cirq.KrausChannel.from_channel(dp)
with pytest.raises(TypeError):
_ = cirq.measurement_key(kc)
_ = cirq.measurement_key_name(kc)
assert cirq.with_measurement_key_mapping(kc, {'a': 'b'}) is NotImplemented

kc_x = cirq.KrausChannel.from_channel(dp, key='x')
assert cirq.with_measurement_key_mapping(kc_x, {'a': 'b'}) is kc_x
assert cirq.measurement_key(cirq.with_key_path(kc_x, ('path',))) == 'path:x'
assert cirq.measurement_key_name(cirq.with_key_path(kc_x, ('path',))) == 'path:x'

kc_a = cirq.KrausChannel.from_channel(dp, key='a')
kc_b = cirq.KrausChannel.from_channel(dp, key='b')
Expand All @@ -69,7 +69,7 @@ def test_kraus_channel_from_kraus():
np.array([[1, -1], [-1, 1]]) * 0.5,
]
x_meas = cirq.KrausChannel(ops, key='x_meas')
assert cirq.measurement_key(x_meas) == 'x_meas'
assert cirq.measurement_key_name(x_meas) == 'x_meas'

circuit = cirq.Circuit(cirq.H(q0), x_meas.on(q0))
sim = cirq.Simulator(seed=0)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def full_invert_mask(self):
def _is_measurement_(self) -> bool:
return True

def _measurement_key_(self):
def _measurement_key_name_(self):
return self.key

def _kraus_(self):
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/mixed_unitary_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def num_qubits(self) -> int:
def _mixture_(self):
return self._mixture

def _measurement_key_(self):
def _measurement_key_name_(self):
if self._key is None:
return NotImplemented
return self._key
Expand Down

0 comments on commit 3690440

Please sign in to comment.