Skip to content

Commit

Permalink
Revert "Support qubits in MeasurementKey and update JSON serializat…
Browse files Browse the repository at this point in the history
…ion of keys" (quantumlib#4277)

* Revert "Support qubits in `MeasurementKey` and update JSON serialization of keys (quantumlib#4123)"

This reverts commit 81436fe.

* Safeguard against repr-inequality.

* assert_equivalent_repr
  • Loading branch information
95-martin-orion authored and rht committed May 1, 2023
1 parent f9bb8e0 commit bdff14f
Show file tree
Hide file tree
Showing 26 changed files with 83 additions and 889 deletions.
1 change: 0 additions & 1 deletion cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@
canonicalize_half_turns,
chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns,
default_measurement_key_str,
Duration,
DURATION_LIKE,
GenericMetaImplementAnyOneOf,
Expand Down
12 changes: 2 additions & 10 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,16 +1367,8 @@ def test_findall_operations_with_gate(circuit_cls):
(3, cirq.CZ(a, b), cirq.CZ),
]
assert list(c.findall_operations_with_gate_type(cirq.MeasurementGate)) == [
(
4,
cirq.MeasurementGate(1).on(a),
cirq.MeasurementGate(1, key=cirq.MeasurementKey(qubits=(a,))),
),
(
4,
cirq.MeasurementGate(1).on(b),
cirq.MeasurementGate(1, key=cirq.MeasurementKey(qubits=(b,))),
),
(4, cirq.MeasurementGate(1, key='a').on(a), cirq.MeasurementGate(1, key='a')),
(4, cirq.MeasurementGate(1, key='b').on(b), cirq.MeasurementGate(1, key='b')),
]


Expand Down
4 changes: 0 additions & 4 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,6 @@ def __str__(self) -> str:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])

@classmethod
def _from_json_dict_(cls, gate, qubits, **kwargs):
return gate.on(*qubits)

def _group_interchangeable_qubits(
self,
) -> Tuple[Union['cirq.Qid', Tuple[int, FrozenSet['cirq.Qid']]], ...]:
Expand Down
18 changes: 11 additions & 7 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Tuple, TYPE_CHECKING, Union
from typing import Callable, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand All @@ -24,9 +24,13 @@
import cirq


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)


def measure(
*target: 'cirq.Qid',
key: Union[str, value.MeasurementKey] = '',
key: Optional[Union[str, value.MeasurementKey]] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
"""Returns a single MeasurementGate applied to all the given qubits.
Expand All @@ -35,7 +39,7 @@ def measure(
Args:
*target: The qubits that the measurement gate should measure.
key: The string key of the measurement. If this is empty, defaults
key: The string key of the measurement. If this is None, it defaults
to a comma-separated list of the target qubits' str values.
invert_mask: A list of Truthy or Falsey values indicating whether
the corresponding qubits should be flipped. None indicates no
Expand All @@ -56,12 +60,14 @@ def measure(
elif not isinstance(qubit, raw_types.Qid):
raise ValueError('measure() was called with type different than Qid.')

if key is None:
key = _default_measurement_key(target)
qid_shape = protocols.qid_shape(target)
return MeasurementGate(len(target), key, invert_mask, qid_shape).on(*target)


def measure_each(
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = lambda x: ''
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
"""Returns a list of operations individually measuring the given qubits.
Expand All @@ -70,9 +76,7 @@ def measure_each(
Args:
*qubits: The qubits to measure.
key_func: Determines the key of the measurements of each qubit. Takes
the qubit and returns the key for that qubit. Defaults to empty string
key which would result in a comma-separated list of the target qubits'
str values.
the qubit and returns the key for that qubit. Defaults to str.
Returns:
A list of operations individually measuring the given qubits.
Expand Down
18 changes: 4 additions & 14 deletions cirq-core/cirq/ops/measure_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,12 @@ def test_measure_qubits():
with pytest.raises(ValueError, match='empty set of qubits'):
_ = cirq.measure()

assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='').on(a)
assert cirq.measure(a) != cirq.measure(a, key='a')
assert cirq.measure(a) != cirq.MeasurementGate(num_qubits=1, key='').on(b)
assert cirq.measure(a) != cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a) == cirq.MeasurementGate(
num_qubits=1, key=cirq.MeasurementKey(qubits=(a,))
).on(a)
assert cirq.measurement_key(cirq.measure(a)) == 'a'
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='').on(a, b)
assert cirq.measurement_key(cirq.measure(a, b)) == 'a,b'
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='').on(b, a)
assert cirq.measurement_key(cirq.measure(b, a)) == 'b,a'
assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='a,b').on(a, b)
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='b,a').on(b, a)
assert cirq.measure(a, key='b') == cirq.MeasurementGate(num_qubits=1, key='b').on(a)
assert cirq.measure(a, key='b') != cirq.MeasurementGate(num_qubits=1, key='b').on(b)
assert cirq.measure(a, invert_mask=(True,)) == cirq.MeasurementGate(
num_qubits=1, invert_mask=(True,)
num_qubits=1, key='a', invert_mask=(True,)
).on(a)
assert cirq.measure(*cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
num_qubits=3, key='a', qid_shape=(1, 2, 3)
Expand Down
42 changes: 16 additions & 26 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Sequence, TYPE_CHECKING, Union
from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, Union

import numpy as np

from cirq import protocols, value
from cirq.ops import raw_types, gate_operation
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -80,21 +80,12 @@ def key(self, key: Union[str, value.MeasurementKey]):
else:
self.mkey = value.MeasurementKey(name=key)

def on(self, *qubits: raw_types.Qid) -> raw_types.Operation:
"""Returns an application of this gate to the given qubits.
Args:
*qubits: The collection of qubits to potentially apply the gate to.
"""
maybe_rekeyed_gate = self.with_key(self.mkey.with_qubits(qubits))
return gate_operation.GateOperation(maybe_rekeyed_gate, list(qubits))

def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
if isinstance(key, value.MeasurementKey) and key == self.mkey:
if key == self.key:
return self
return MeasurementGate(
self.num_qubits(), key=key, invert_mask=self.invert_mask, qid_shape=self._qid_shape
Expand All @@ -114,7 +105,7 @@ def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
for b in bit_positions:
new_mask[b] = not new_mask[b]
return MeasurementGate(
self.num_qubits(), key=self.mkey, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
self.num_qubits(), key=self.key, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
)

def full_invert_mask(self):
Expand Down Expand Up @@ -161,8 +152,8 @@ def _circuit_diagram_info_(
if b:
symbols[i] = '!M'

# Mention the measurement key if it is non-trivial or there are no known qubits.
if self.mkey.name or self.mkey.path or not args.known_qubits:
# Mention the measurement key.
if not args.known_qubits or self.key != _default_measurement_key(args.known_qubits):
symbols[0] += f"('{self.key}')"

return protocols.CircuitDiagramInfo(tuple(symbols))
Expand Down Expand Up @@ -200,13 +191,8 @@ def _quil_(

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args = list(repr(q) for q in qubits)
if self.mkey.name or self.mkey.path:
if self.mkey == self.mkey.name:
args.append(f'key={self.mkey.name!r}')
else:
# Remove qubits from the `MeasurementKey` representation since we already have
# qubits from the op.
args.append(f'key={self.mkey.with_qubits(tuple())!r}')
if self.key != _default_measurement_key(qubits):
args.append(f'key={self.key!r}')
if self.invert_mask:
args.append(f'invert_mask={self.invert_mask!r}')
arg_list = ', '.join(args)
Expand All @@ -219,13 +205,13 @@ def __repr__(self):
return (
f'cirq.MeasurementGate('
f'{self.num_qubits()!r}, '
f'{self.mkey.name if self.mkey == self.mkey.name else self.mkey!r}, '
f'{self.key!r}, '
f'{self.invert_mask}'
f'{qid_shape_arg})'
)

def _value_equality_values_(self) -> Any:
return self.mkey, self.invert_mask, self._qid_shape
return self.key, self.invert_mask, self._qid_shape

def _json_dict_(self) -> Dict[str, Any]:
other = {}
Expand All @@ -234,7 +220,7 @@ def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'num_qubits': len(self._qid_shape),
'key': self.mkey,
'key': self.key,
'invert_mask': self.invert_mask,
**other,
}
Expand All @@ -243,7 +229,7 @@ def _json_dict_(self) -> Dict[str, Any]:
def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs):
return cls(
num_qubits=num_qubits,
key=value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key,
key=value.MeasurementKey.parse_serialized(key),
invert_mask=tuple(invert_mask),
qid_shape=None if qid_shape is None else tuple(qid_shape),
)
Expand All @@ -254,3 +240,7 @@ def _has_stabilizer_effect_(self) -> Optional[bool]:
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']) -> bool:
args.measure(qubits, self.key, self.full_invert_mask())
return True


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)
18 changes: 9 additions & 9 deletions cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
import cirq


def test_eval_repr():
# Basic safeguard against repr-inequality.
op = cirq.GateOperation(
gate=cirq.MeasurementGate(1, cirq.MeasurementKey(path=(), name='q0_1_0'), ()),
qubits=[cirq.GridQubit(0, 1)],
)
cirq.testing.assert_equivalent_repr(op)


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_measure_init(num_qubits):
assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits
Expand Down Expand Up @@ -266,15 +275,6 @@ def test_op_repr():
"key='out', "
"invert_mask=(False, True))"
)
assert repr(
cirq.measure(
a, b, key=cirq.MeasurementKey(name='out', path=('a', 'b')), invert_mask=(False, True)
)
) == (
"cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), "
"key=cirq.MeasurementKey(path=('a', 'b'), name='out'), "
"invert_mask=(False, True))"
)


def test_act_on_state_vector():
Expand Down
28 changes: 1 addition & 27 deletions cirq-core/cirq/protocols/json_test_data/Circuit.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,7 @@
"gate": {
"cirq_type": "MeasurementGate",
"num_qubits": 5,
"key": {
"cirq_type": "MeasurementKey",
"name": "",
"path": [],
"qubits": [
{
"cirq_type": "LineQubit",
"x": 0
},
{
"cirq_type": "LineQubit",
"x": 1
},
{
"cirq_type": "LineQubit",
"x": 2
},
{
"cirq_type": "LineQubit",
"x": 3
},
{
"cirq_type": "LineQubit",
"x": 4
}
]
},
"key": "0,1,2,3,4",
"invert_mask": []
},
"qubits": [
Expand Down

0 comments on commit bdff14f

Please sign in to comment.